ai_lib_rust/embeddings/
types.rs

//! Embedding types and data structures.

use serde::{Deserialize, Serialize};

/// A single embedding vector with metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
    /// Position of this embedding in the request's input list.
    pub index: usize,
    /// The embedding vector itself.
    pub vector: Vec<f32>,
    /// Object type tag; defaults to "embedding" when absent from the payload.
    #[serde(default = "default_object_type")]
    pub object_type: String,
}

fn default_object_type() -> String {
    "embedding".to_string()
}

impl Embedding {
    /// Creates an embedding for the given input index and vector.
    pub fn new(index: usize, vector: Vec<f32>) -> Self {
        Self { index, vector, object_type: default_object_type() }
    }

    /// Returns the dimensionality of the vector.
    pub fn dimensions(&self) -> usize {
        self.vector.len()
    }
}

/// Request for generating embeddings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
    pub input: EmbeddingInput,
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encoding_format: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

/// Input to an embedding request: a single text or a batch of texts.
/// Serialized untagged, so a single input becomes a bare JSON string and a
/// batch becomes a JSON array of strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    Single(String),
    Batch(Vec<String>),
}

impl EmbeddingRequest {
    /// Builds a request for a single input text.
    pub fn single(model: impl Into<String>, text: impl Into<String>) -> Self {
        Self {
            input: EmbeddingInput::Single(text.into()),
            model: model.into(),
            dimensions: None,
            encoding_format: None,
            user: None,
        }
    }

    /// Builds a request for a batch of input texts.
    pub fn batch(model: impl Into<String>, texts: Vec<String>) -> Self {
        Self {
            input: EmbeddingInput::Batch(texts),
            model: model.into(),
            dimensions: None,
            encoding_format: None,
            user: None,
        }
    }

    /// Requests vectors of the given dimensionality (for models that support it).
    pub fn with_dimensions(mut self, dimensions: usize) -> Self {
        self.dimensions = Some(dimensions);
        self
    }
}
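
// A minimal usage sketch, not part of the original module: exercises the
// builder methods above and the serialization shape they produce. The model
// name is only an illustrative value.
#[cfg(test)]
#[test]
fn embedding_request_builder_sketch() {
    let req = EmbeddingRequest::single("text-embedding-3-small", "hello world")
        .with_dimensions(256);

    let json = serde_json::to_value(&req).expect("request should serialize");
    // The untagged enum serializes a single input as a bare string.
    assert_eq!(json["input"], "hello world");
    assert_eq!(json["dimensions"], 256);
    // Options left as None are skipped entirely rather than emitted as null.
    assert!(json.get("user").is_none());
    assert!(json.get("encoding_format").is_none());
}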

/// Token usage reported for an embedding request.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EmbeddingUsage {
    pub prompt_tokens: u32,
    pub total_tokens: u32,
}

impl EmbeddingUsage {
    /// Creates usage for a request whose tokens are all prompt tokens.
    pub fn new(prompt_tokens: u32) -> Self {
        Self { prompt_tokens, total_tokens: prompt_tokens }
    }

    /// Accumulates another usage record into this one.
    pub fn add(&mut self, other: &EmbeddingUsage) {
        self.prompt_tokens += other.prompt_tokens;
        self.total_tokens += other.total_tokens;
    }
}

/// Response containing one or more embeddings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    pub embeddings: Vec<Embedding>,
    pub model: String,
    pub usage: EmbeddingUsage,
    /// Object type tag; defaults to "list" when absent from the payload.
    #[serde(default = "default_list_type")]
    pub object: String,
}

fn default_list_type() -> String {
    "list".to_string()
}

impl EmbeddingResponse {
    /// Creates a response from embeddings, model name, and usage.
    pub fn new(embeddings: Vec<Embedding>, model: String, usage: EmbeddingUsage) -> Self {
        Self { embeddings, model, usage, object: default_list_type() }
    }

    /// Returns the first embedding, if any.
    pub fn first(&self) -> Option<&Embedding> {
        self.embeddings.first()
    }

    /// Returns the number of embeddings in the response.
    pub fn len(&self) -> usize {
        self.embeddings.len()
    }

    /// Returns `true` if the response contains no embeddings.
    pub fn is_empty(&self) -> bool {
        self.embeddings.is_empty()
    }

    /// Parses an OpenAI-style embeddings payload into a response.
    ///
    /// Missing indices, vectors, model names, or usage counts fall back to
    /// defaults rather than failing; only a missing `data` array is an error.
    pub fn from_openai_format(data: &serde_json::Value) -> crate::Result<Self> {
        let embeddings = data["data"]
            .as_array()
            .ok_or_else(|| crate::Error::parsing("Missing 'data' array"))?
            .iter()
            .map(|item| {
                let index = item["index"].as_u64().unwrap_or(0) as usize;
                let vector: Vec<f32> = item["embedding"]
                    .as_array()
                    .map(|arr| arr.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect())
                    .unwrap_or_default();
                Embedding::new(index, vector)
            })
            .collect();
        let model = data["model"].as_str().unwrap_or("unknown").to_string();
        let usage = EmbeddingUsage {
            prompt_tokens: data["usage"]["prompt_tokens"].as_u64().unwrap_or(0) as u32,
            total_tokens: data["usage"]["total_tokens"].as_u64().unwrap_or(0) as u32,
        };
        Ok(Self::new(embeddings, model, usage))
    }
}

/// Metadata describing a known embedding model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingModel {
    pub id: String,
    pub name: String,
    pub max_input_tokens: u32,
    pub dimensions: usize,
    pub provider: String,
}

impl EmbeddingModel {
    /// OpenAI's text-embedding-3-small model.
    pub fn text_embedding_3_small() -> Self {
        Self {
            id: "text-embedding-3-small".into(),
            name: "Text Embedding 3 Small".into(),
            max_input_tokens: 8191,
            dimensions: 1536,
            provider: "openai".into(),
        }
    }

    /// OpenAI's text-embedding-3-large model.
    pub fn text_embedding_3_large() -> Self {
        Self {
            id: "text-embedding-3-large".into(),
            name: "Text Embedding 3 Large".into(),
            max_input_tokens: 8191,
            dimensions: 3072,
            provider: "openai".into(),
        }
    }
}
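
// A minimal test sketch, not part of the original file: covers usage
// accumulation and the OpenAI-format parsing path above. The payload below is
// an illustrative example of the expected shape, not captured provider output.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn usage_accumulates_across_batches() {
        let mut usage = EmbeddingUsage::new(5);
        usage.add(&EmbeddingUsage::new(7));
        assert_eq!(usage.prompt_tokens, 12);
        assert_eq!(usage.total_tokens, 12);
    }

    #[test]
    fn parses_openai_style_payload() {
        let payload = serde_json::json!({
            "object": "list",
            "data": [
                { "object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3] },
                { "object": "embedding", "index": 1, "embedding": [0.4, 0.5, 0.6] }
            ],
            "model": "text-embedding-3-small",
            "usage": { "prompt_tokens": 8, "total_tokens": 8 }
        });

        let response = EmbeddingResponse::from_openai_format(&payload)
            .unwrap_or_else(|_| panic!("payload should parse"));
        assert_eq!(response.len(), 2);
        assert_eq!(response.first().unwrap().dimensions(), 3);
        assert_eq!(response.model, "text-embedding-3-small");
        assert_eq!(response.usage.total_tokens, 8);
    }
}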