Skip to main content

serdes_ai_embeddings/
model.rs

1//! Embedding model trait and types.
2
3use crate::embedding::Embedding;
4use crate::error::{EmbeddingError, EmbeddingResult};
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7
8/// Result of an embedding operation.
9#[derive(Debug, Clone)]
10pub struct EmbeddingOutput {
11    /// The embedding vectors.
12    pub embeddings: Vec<Embedding>,
13    /// Total tokens used.
14    pub total_tokens: Option<u64>,
15    /// Model name used.
16    pub model: String,
17}
18
19impl EmbeddingOutput {
20    /// Create a new embedding output.
21    pub fn new(embeddings: Vec<Embedding>, model: impl Into<String>) -> Self {
22        Self {
23            embeddings,
24            total_tokens: None,
25            model: model.into(),
26        }
27    }
28
29    /// Set the token count.
30    pub fn with_tokens(mut self, tokens: u64) -> Self {
31        self.total_tokens = Some(tokens);
32        self
33    }
34
35    /// Get the first (single) embedding.
36    pub fn embedding(&self) -> Option<&Embedding> {
37        self.embeddings.first()
38    }
39
40    /// Get the dimensionality.
41    pub fn dimensions(&self) -> Option<usize> {
42        self.embeddings.first().map(|e| e.dimensions())
43    }
44
45    /// Check if empty.
46    pub fn is_empty(&self) -> bool {
47        self.embeddings.is_empty()
48    }
49
50    /// Get number of embeddings.
51    pub fn len(&self) -> usize {
52        self.embeddings.len()
53    }
54}
55
/// Input type for embedding operations.
#[derive(Debug, Clone)]
pub enum EmbedInput {
    /// Single text query.
    Query(String),
    /// Multiple documents.
    Documents(Vec<String>),
}

impl EmbedInput {
    /// Returns the number of texts to embed.
    ///
    /// A [`Query`](Self::Query) always counts as exactly one text, even when its
    /// string is empty.
    pub fn len(&self) -> usize {
        match self {
            Self::Query(_) => 1,
            Self::Documents(docs) => docs.len(),
        }
    }

    /// Returns `true` when there are no texts to embed, i.e. exactly when `len() == 0`.
    ///
    /// Only an empty [`Documents`](Self::Documents) list is empty. A
    /// [`Query`](Self::Query) always holds one text, even if that text is `""`.
    pub fn is_empty(&self) -> bool {
        // Must agree with `len()`: clippy's `len_zero` lint rewrites `len() == 0`
        // into `is_empty()`, so a divergence would silently change caller behaviour.
        self.len() == 0
    }

    /// Consumes the input and returns its texts as owned strings.
    pub fn into_texts(self) -> Vec<String> {
        match self {
            Self::Query(q) => vec![q],
            Self::Documents(docs) => docs,
        }
    }

    /// Returns the texts as borrowed string slices, in order.
    pub fn texts(&self) -> Vec<&str> {
        match self {
            Self::Query(q) => vec![q.as_str()],
            Self::Documents(docs) => docs.iter().map(String::as_str).collect(),
        }
    }
}

impl From<&str> for EmbedInput {
    /// Wraps a borrowed string as a single [`EmbedInput::Query`].
    fn from(s: &str) -> Self {
        Self::Query(s.to_owned())
    }
}

impl From<String> for EmbedInput {
    /// Wraps an owned string as a single [`EmbedInput::Query`].
    fn from(s: String) -> Self {
        Self::Query(s)
    }
}

impl From<Vec<String>> for EmbedInput {
    /// Wraps owned strings as [`EmbedInput::Documents`].
    fn from(docs: Vec<String>) -> Self {
        Self::Documents(docs)
    }
}

impl From<Vec<&str>> for EmbedInput {
    /// Copies borrowed strings into [`EmbedInput::Documents`].
    fn from(docs: Vec<&str>) -> Self {
        Self::Documents(docs.into_iter().map(String::from).collect())
    }
}
122
123/// Encoding format for embeddings.
124#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
125#[serde(rename_all = "lowercase")]
126pub enum EncodingFormat {
127    /// Standard floating point.
128    #[default]
129    Float,
130    /// Base64 encoded.
131    Base64,
132}
133
134/// Settings for embedding requests.
135#[derive(Debug, Clone, Default)]
136pub struct EmbeddingSettings {
137    /// Override dimensions (if model supports).
138    pub dimensions: Option<usize>,
139    /// Encoding format.
140    pub encoding_format: Option<EncodingFormat>,
141    /// User identifier for tracking.
142    pub user: Option<String>,
143    /// Input type hint (query vs document).
144    pub input_type: Option<InputType>,
145    /// Truncation mode.
146    pub truncation: Option<TruncationMode>,
147}
148
149impl EmbeddingSettings {
150    /// Create new default settings.
151    pub fn new() -> Self {
152        Self::default()
153    }
154
155    /// Set dimensions.
156    pub fn dimensions(mut self, dims: usize) -> Self {
157        self.dimensions = Some(dims);
158        self
159    }
160
161    /// Set encoding format.
162    pub fn encoding_format(mut self, format: EncodingFormat) -> Self {
163        self.encoding_format = Some(format);
164        self
165    }
166
167    /// Set user identifier.
168    pub fn user(mut self, user: impl Into<String>) -> Self {
169        self.user = Some(user.into());
170        self
171    }
172
173    /// Set input type.
174    pub fn input_type(mut self, input_type: InputType) -> Self {
175        self.input_type = Some(input_type);
176        self
177    }
178
179    /// Set truncation mode.
180    pub fn truncation(mut self, mode: TruncationMode) -> Self {
181        self.truncation = Some(mode);
182        self
183    }
184}
185
186/// Type of input for embeddings.
187#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
188#[serde(rename_all = "snake_case")]
189pub enum InputType {
190    /// Search query (optimized for retrieval).
191    SearchQuery,
192    /// Document to be indexed.
193    SearchDocument,
194    /// Classification input.
195    Classification,
196    /// Clustering input.
197    Clustering,
198}
199
/// Strategy for handling inputs longer than the model accepts.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TruncationMode {
    /// Cut text from the end so the input fits the model's context (the default).
    #[default]
    End,
    /// Cut text from the start instead.
    Start,
    /// Do not truncate; an over-long input is meant to be an error.
    None,
}
211
212/// Core trait for embedding models.
213#[async_trait]
214pub trait EmbeddingModel: Send + Sync {
215    /// Get the model name.
216    fn name(&self) -> &str;
217
218    /// Get the default embedding dimensions.
219    fn dimensions(&self) -> usize;
220
221    /// Get the maximum tokens per input.
222    fn max_tokens(&self) -> usize {
223        8192 // Default, override for specific models
224    }
225
226    /// Generate embeddings.
227    async fn embed(
228        &self,
229        input: EmbedInput,
230        settings: &EmbeddingSettings,
231    ) -> EmbeddingResult<EmbeddingOutput>;
232
233    /// Embed a single query.
234    async fn embed_query(&self, query: &str) -> EmbeddingResult<EmbeddingOutput> {
235        self.embed(
236            EmbedInput::Query(query.to_string()),
237            &EmbeddingSettings::default().input_type(InputType::SearchQuery),
238        )
239        .await
240    }
241
242    /// Embed multiple documents.
243    async fn embed_documents(&self, docs: Vec<String>) -> EmbeddingResult<EmbeddingOutput> {
244        self.embed(
245            EmbedInput::Documents(docs),
246            &EmbeddingSettings::default().input_type(InputType::SearchDocument),
247        )
248        .await
249    }
250
251    /// Count tokens in text (if supported).
252    async fn count_tokens(&self, _text: &str) -> EmbeddingResult<u64> {
253        Err(EmbeddingError::NotSupported("Token counting".into()))
254    }
255}
256
257/// Boxed embedding model for dynamic dispatch.
258pub type BoxedEmbeddingModel = Box<dyn EmbeddingModel>;
259
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embed_input_from_str() {
        let input: EmbedInput = "hello".into();
        assert!(matches!(input, EmbedInput::Query(_)));
        assert_eq!(input.len(), 1);
        assert!(!input.is_empty());
        assert_eq!(input.texts(), vec!["hello"]);
    }

    #[test]
    fn test_embed_input_from_vec() {
        let input: EmbedInput = vec!["a", "b", "c"].into();
        assert!(matches!(input, EmbedInput::Documents(_)));
        assert_eq!(input.len(), 3);
        assert_eq!(input.into_texts(), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_embed_input_empty_documents() {
        let input = EmbedInput::Documents(Vec::new());
        assert_eq!(input.len(), 0);
        assert!(input.is_empty());
        assert!(input.texts().is_empty());
    }

    #[test]
    fn test_embedding_output() {
        let embeddings = vec![
            Embedding::new(vec![1.0, 2.0, 3.0]),
            Embedding::new(vec![4.0, 5.0, 6.0]),
        ];
        let output = EmbeddingOutput::new(embeddings, "test-model").with_tokens(100);

        assert_eq!(output.len(), 2);
        assert!(!output.is_empty());
        assert_eq!(output.dimensions(), Some(3));
        assert_eq!(output.embedding().map(|e| e.dimensions()), Some(3));
        assert_eq!(output.total_tokens, Some(100));
        assert_eq!(output.model, "test-model");
    }

    #[test]
    fn test_embedding_output_empty() {
        let output = EmbeddingOutput::new(Vec::new(), "test-model");

        assert!(output.is_empty());
        assert_eq!(output.len(), 0);
        assert!(output.embedding().is_none());
        assert_eq!(output.dimensions(), None);
        assert_eq!(output.total_tokens, None);
    }

    #[test]
    fn test_embedding_settings() {
        let settings = EmbeddingSettings::new()
            .dimensions(1536)
            .input_type(InputType::SearchQuery)
            .user("user-123");

        assert_eq!(settings.dimensions, Some(1536));
        assert_eq!(settings.input_type, Some(InputType::SearchQuery));
        // The `user` builder was previously exercised but never checked.
        assert_eq!(settings.user.as_deref(), Some("user-123"));
        // Options that were not set must stay unset.
        assert_eq!(settings.encoding_format, None);
        assert_eq!(settings.truncation, None);
    }

    #[test]
    fn test_enum_defaults() {
        assert_eq!(EncodingFormat::default(), EncodingFormat::Float);
        assert_eq!(TruncationMode::default(), TruncationMode::End);
    }
}