use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A request to embed one or more texts.
///
/// Only `input` carries the payload. Every other field is optional and, through
/// the derived [`Default`], starts out as `None` (or an empty map for
/// `provider_params`).
#[derive(Debug, Clone, Default)]
pub struct EmbeddingRequest {
    /// Texts to embed.
    pub input: Vec<String>,
    /// Model identifier; `None` means the caller did not choose one.
    pub model: Option<String>,
    /// Requested number of dimensions; `None` means no explicit request.
    pub dimensions: Option<u32>,
    /// Requested [`EmbeddingFormat`]; `None` means no explicit request.
    pub encoding_format: Option<EmbeddingFormat>,
    /// Caller-supplied user string, if any.
    ///
    /// NOTE(review): presumably an end-user identifier that is forwarded to the
    /// backend; nothing in this file reads it, so confirm against the consumer.
    pub user: Option<String>,
    /// Free-form extra parameters keyed by name.
    ///
    /// `EmbeddingRequest::with_task_type` stores its label here under the
    /// `"task_type"` key.
    pub provider_params: HashMap<String, serde_json::Value>,
}
impl EmbeddingRequest {
    /// Builds a request for the given `input` texts.
    ///
    /// All remaining fields take their [`Default`] value: the optional fields
    /// are `None` and `provider_params` is empty.
    #[must_use]
    pub fn new(input: Vec<String>) -> Self {
        Self {
            input,
            ..Self::default()
        }
    }

    /// Builds a request whose `input` holds exactly one text.
    #[must_use]
    pub fn single(text: impl Into<String>) -> Self {
        Self::new(vec![text.into()])
    }

    /// Single-text request tagged with [`EmbeddingTaskType::RetrievalQuery`].
    ///
    /// Shorthand for `Self::single(text).with_task_type(..)`.
    #[must_use]
    pub fn query(text: impl Into<String>) -> Self {
        Self::single(text).with_task_type(EmbeddingTaskType::RetrievalQuery)
    }

    /// Single-text request tagged with [`EmbeddingTaskType::RetrievalDocument`].
    ///
    /// Shorthand for `Self::single(text).with_task_type(..)`.
    #[must_use]
    pub fn document(text: impl Into<String>) -> Self {
        Self::single(text).with_task_type(EmbeddingTaskType::RetrievalDocument)
    }

    /// Single-text request tagged with [`EmbeddingTaskType::SemanticSimilarity`].
    ///
    /// Shorthand for `Self::single(text).with_task_type(..)`.
    #[must_use]
    pub fn similarity(text: impl Into<String>) -> Self {
        Self::single(text).with_task_type(EmbeddingTaskType::SemanticSimilarity)
    }

    /// Single-text request tagged with [`EmbeddingTaskType::Classification`].
    ///
    /// Shorthand for `Self::single(text).with_task_type(..)`.
    #[must_use]
    pub fn classification(text: impl Into<String>) -> Self {
        Self::single(text).with_task_type(EmbeddingTaskType::Classification)
    }

    /// Single-text request tagged with [`EmbeddingTaskType::Clustering`].
    ///
    /// Shorthand for `Self::single(text).with_task_type(..)`.
    #[must_use]
    pub fn clustering(text: impl Into<String>) -> Self {
        Self::single(text).with_task_type(EmbeddingTaskType::Clustering)
    }

    /// Single-text request tagged with [`EmbeddingTaskType::QuestionAnswering`].
    ///
    /// Shorthand for `Self::single(text).with_task_type(..)`.
    #[must_use]
    pub fn question_answering(text: impl Into<String>) -> Self {
        Self::single(text).with_task_type(EmbeddingTaskType::QuestionAnswering)
    }

    /// Single-text request tagged with [`EmbeddingTaskType::FactVerification`].
    ///
    /// Shorthand for `Self::single(text).with_task_type(..)`.
    #[must_use]
    pub fn fact_verification(text: impl Into<String>) -> Self {
        Self::single(text).with_task_type(EmbeddingTaskType::FactVerification)
    }

    /// Sets `model`, replacing any previous value.
    #[must_use = "builder methods return the updated request; dropping it discards the change"]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets `dimensions`, replacing any previous value.
    #[must_use = "builder methods return the updated request; dropping it discards the change"]
    pub fn with_dimensions(mut self, dimensions: u32) -> Self {
        self.dimensions = Some(dimensions);
        self
    }

    /// Sets `encoding_format`, replacing any previous value.
    #[must_use = "builder methods return the updated request; dropping it discards the change"]
    pub fn with_encoding_format(mut self, format: EmbeddingFormat) -> Self {
        self.encoding_format = Some(format);
        self
    }

    /// Sets `user`, replacing any previous value.
    #[must_use = "builder methods return the updated request; dropping it discards the change"]
    pub fn with_user(mut self, user: impl Into<String>) -> Self {
        self.user = Some(user.into());
        self
    }

    /// Inserts one entry into `provider_params`.
    ///
    /// An existing entry with the same `key` is overwritten.
    #[must_use = "builder methods return the updated request; dropping it discards the change"]
    pub fn with_provider_param(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.provider_params.insert(key.into(), value);
        self
    }

    /// Records `task_type` as a string under the `"task_type"` key of
    /// `provider_params`, overwriting any earlier value for that key.
    ///
    /// The stored label is the variant name in `SCREAMING_SNAKE_CASE`, except
    /// for [`EmbeddingTaskType::Unspecified`], which is stored as
    /// `"TASK_TYPE_UNSPECIFIED"`.
    ///
    /// NOTE(review): the serde attributes on `EmbeddingTaskType` would render
    /// `Unspecified` as `"UNSPECIFIED"`, so this label intentionally or
    /// accidentally differs from the serde form — confirm which is expected.
    #[must_use = "builder methods return the updated request; dropping it discards the change"]
    pub fn with_task_type(mut self, task_type: EmbeddingTaskType) -> Self {
        // Deliberately exhaustive (no `_` arm): adding a variant to
        // `EmbeddingTaskType` must force a decision about its label here.
        let label = match task_type {
            EmbeddingTaskType::RetrievalQuery => "RETRIEVAL_QUERY",
            EmbeddingTaskType::RetrievalDocument => "RETRIEVAL_DOCUMENT",
            EmbeddingTaskType::SemanticSimilarity => "SEMANTIC_SIMILARITY",
            EmbeddingTaskType::Classification => "CLASSIFICATION",
            EmbeddingTaskType::Clustering => "CLUSTERING",
            EmbeddingTaskType::QuestionAnswering => "QUESTION_ANSWERING",
            EmbeddingTaskType::FactVerification => "FACT_VERIFICATION",
            EmbeddingTaskType::Unspecified => "TASK_TYPE_UNSPECIFIED",
        };
        self.provider_params.insert(
            "task_type".to_owned(),
            serde_json::Value::String(label.to_owned()),
        );
        self
    }

    /// Merges `params` into `provider_params`.
    ///
    /// On a key collision the value from `params` replaces the existing one
    /// (this includes a `"task_type"` entry set by [`Self::with_task_type`]).
    #[must_use = "builder methods return the updated request; dropping it discards the change"]
    pub fn with_provider_params(mut self, params: HashMap<String, serde_json::Value>) -> Self {
        self.provider_params.extend(params);
        self
    }
}
/// Value type of `EmbeddingRequest::encoding_format`.
///
/// Serde renders the variants in lowercase, i.e. `"float"` and `"base64"`.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the type can be used
/// as a key in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingFormat {
    /// Serialized as `"float"`.
    Float,
    /// Serialized as `"base64"`.
    Base64,
}
/// Result of an embedding call: the vectors plus bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    /// One `f32` vector per embedding.
    ///
    /// NOTE(review): presumably in the same order as the request's `input`;
    /// nothing in this file enforces that — confirm against the producer.
    pub embeddings: Vec<Vec<f32>>,
    /// Model string associated with this response.
    pub model: String,
    /// Token accounting, when available.
    pub usage: Option<EmbeddingUsage>,
    /// Free-form extra data keyed by name.
    ///
    /// Because of `#[serde(default)]`, a payload without a `metadata` key
    /// deserializes to an empty map instead of failing.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}
impl EmbeddingResponse {
    /// Creates a response holding `embeddings` and the `model` string.
    ///
    /// `usage` starts as `None` and `metadata` starts empty.
    #[must_use]
    pub fn new(embeddings: Vec<Vec<f32>>, model: String) -> Self {
        Self {
            embeddings,
            model,
            usage: None,
            metadata: HashMap::new(),
        }
    }

    /// Number of embedding vectors in the response.
    #[must_use]
    pub fn count(&self) -> usize {
        self.embeddings.len()
    }

    /// Length of the first embedding vector, or `None` when there are none.
    ///
    /// Only the first vector is inspected; the remaining vectors are not
    /// checked for having the same length.
    #[must_use]
    pub fn dimension(&self) -> Option<usize> {
        self.embeddings.first().map(Vec::len)
    }

    /// `true` when the response contains no embedding vectors.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.embeddings.is_empty()
    }

    /// Embedding at `index`, or `None` when `index` is out of range.
    #[must_use]
    pub fn get(&self, index: usize) -> Option<&Vec<f32>> {
        self.embeddings.get(index)
    }

    /// Inserts one entry into `metadata`, overwriting any existing entry
    /// with the same `key`.
    #[must_use = "builder methods return the updated response; dropping it discards the change"]
    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
        self.metadata.insert(key, value);
        self
    }

    /// Sets `usage`, replacing any previous value.
    #[must_use = "builder methods return the updated response; dropping it discards the change"]
    pub fn with_usage(mut self, usage: EmbeddingUsage) -> Self {
        self.usage = Some(usage);
        self
    }
}
/// Token accounting attached to an `EmbeddingResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingUsage {
    /// Prompt-side token count.
    ///
    /// NOTE(review): presumably the tokens consumed by the request input —
    /// the value is supplied by the caller of `EmbeddingUsage::new`; confirm.
    pub prompt_tokens: u32,
    /// Overall token count; stored as given, not derived from `prompt_tokens`.
    pub total_tokens: u32,
}
impl EmbeddingUsage {
    /// Creates a usage record from the two counts.
    ///
    /// The values are stored as given; no relationship between
    /// `prompt_tokens` and `total_tokens` is checked.
    #[must_use]
    pub fn new(prompt_tokens: u32, total_tokens: u32) -> Self {
        Self {
            prompt_tokens,
            total_tokens,
        }
    }
}
/// Task an embedding is intended for.
///
/// Accepted by `EmbeddingRequest::with_task_type` and listed in
/// `EmbeddingModelInfo::supported_tasks`.
///
/// Serde renders each variant in `SCREAMING_SNAKE_CASE`, e.g. `RetrievalQuery`
/// becomes `"RETRIEVAL_QUERY"`.
///
/// NOTE(review): under that rule `Unspecified` serializes as `"UNSPECIFIED"`,
/// whereas `EmbeddingRequest::with_task_type` writes `"TASK_TYPE_UNSPECIFIED"`
/// for the same variant. Confirm whether the two spellings are meant to differ
/// before relying on either one.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the type can be used
/// as a key in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum EmbeddingTaskType {
    /// Serialized as `"RETRIEVAL_QUERY"`.
    RetrievalQuery,
    /// Serialized as `"RETRIEVAL_DOCUMENT"`.
    RetrievalDocument,
    /// Serialized as `"SEMANTIC_SIMILARITY"`.
    SemanticSimilarity,
    /// Serialized as `"CLASSIFICATION"`.
    Classification,
    /// Serialized as `"CLUSTERING"`.
    Clustering,
    /// Serialized as `"QUESTION_ANSWERING"`.
    QuestionAnswering,
    /// Serialized as `"FACT_VERIFICATION"`.
    FactVerification,
    /// Serialized as `"UNSPECIFIED"`; also the task that
    /// `EmbeddingModelInfo::new` seeds `supported_tasks` with.
    Unspecified,
}
/// Descriptive record for an embedding model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingModelInfo {
    /// Model identifier.
    pub id: String,
    /// Model name.
    ///
    /// NOTE(review): presumably a human-readable label as opposed to `id` —
    /// confirm against where these records are built.
    pub name: String,
    /// Embedding dimension recorded for the model.
    pub dimension: usize,
    /// Maximum input size recorded for the model, in tokens.
    pub max_input_tokens: usize,
    /// Tasks recorded as supported. `EmbeddingModelInfo::new` seeds this with
    /// `EmbeddingTaskType::Unspecified`; `with_task` appends to it.
    pub supported_tasks: Vec<EmbeddingTaskType>,
    /// Flag set by `with_custom_dimensions`; `false` after `new`.
    pub supports_custom_dimensions: bool,
}
impl EmbeddingModelInfo {
    /// Creates a model record from its identifying and sizing data.
    ///
    /// `supported_tasks` starts as `[EmbeddingTaskType::Unspecified]` and
    /// `supports_custom_dimensions` starts as `false`.
    #[must_use]
    pub fn new(id: String, name: String, dimension: usize, max_input_tokens: usize) -> Self {
        Self {
            id,
            name,
            dimension,
            max_input_tokens,
            supported_tasks: vec![EmbeddingTaskType::Unspecified],
            supports_custom_dimensions: false,
        }
    }

    /// Adds `task` to `supported_tasks`, keeping insertion order.
    ///
    /// A task that is already listed is not added again, so re-adding
    /// `EmbeddingTaskType::Unspecified` (which `new` already seeds) or calling
    /// this twice with the same task leaves a single entry.
    #[must_use = "builder methods return the updated record; dropping it discards the change"]
    pub fn with_task(mut self, task: EmbeddingTaskType) -> Self {
        if !self.supported_tasks.contains(&task) {
            self.supported_tasks.push(task);
        }
        self
    }

    /// Sets `supports_custom_dimensions` to `true`.
    #[must_use = "builder methods return the updated record; dropping it discards the change"]
    pub fn with_custom_dimensions(mut self) -> Self {
        self.supports_custom_dimensions = true;
        self
    }
}
/// A group of embedding requests together with the options for handling them.
#[derive(Debug, Clone)]
pub struct BatchEmbeddingRequest {
    /// The individual requests in the batch.
    pub requests: Vec<EmbeddingRequest>,
    /// Options that apply to the batch as a whole.
    pub batch_options: BatchOptions,
}
/// Options attached to a `BatchEmbeddingRequest`.
///
/// The derived [`Default`] leaves both optional fields as `None` and
/// `fail_fast` as `false`.
///
/// NOTE(review): nothing in this file reads these options, so the per-field
/// descriptions below are inferred from the names — confirm against the code
/// that executes batches.
#[derive(Debug, Clone, Default)]
pub struct BatchOptions {
    /// Presumably an upper bound on requests processed at once; `None` means
    /// no value was given.
    pub max_concurrency: Option<usize>,
    /// Presumably a per-request time limit; `None` means no value was given.
    pub request_timeout: Option<std::time::Duration>,
    /// Presumably whether the batch stops at the first failure.
    pub fail_fast: bool,
}
/// Outcome of a batch of embedding requests.
#[derive(Debug, Clone)]
pub struct BatchEmbeddingResponse {
    /// One entry per outcome: `Ok` with the response, or `Err` with an error
    /// message string.
    ///
    /// NOTE(review): presumably positionally aligned with
    /// `BatchEmbeddingRequest::requests`; nothing in this file enforces that.
    pub responses: Vec<Result<EmbeddingResponse, String>>,
    /// Free-form extra data about the batch, keyed by name.
    pub metadata: HashMap<String, serde_json::Value>,
}