// phago_embeddings/embedder.rs
1//! Core embedder trait and types.
2
3use thiserror::Error;
4
5/// Embedding error types.
6#[derive(Debug, Error)]
7pub enum EmbeddingError {
8    #[error("Model not loaded: {0}")]
9    ModelNotLoaded(String),
10
11    #[error("Tokenization failed: {0}")]
12    TokenizationFailed(String),
13
14    #[error("Inference failed: {0}")]
15    InferenceFailed(String),
16
17    #[error("API error: {0}")]
18    ApiError(String),
19
20    #[error("Invalid input: {0}")]
21    InvalidInput(String),
22
23    #[error("Dimension mismatch: expected {expected}, got {got}")]
24    DimensionMismatch { expected: usize, got: usize },
25
26    #[error("IO error: {0}")]
27    Io(#[from] std::io::Error),
28}
29
30/// Result type for embedding operations.
31pub type EmbeddingResult<T> = Result<T, EmbeddingError>;
32
33/// Core trait for embedding providers.
34///
35/// Implementors convert text to dense vectors for semantic similarity.
36pub trait Embedder: Send + Sync {
37    /// Embed a single text string.
38    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>>;
39
40    /// Embed multiple texts in a batch (more efficient).
41    fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
42        // Default implementation: embed one by one
43        texts.iter().map(|t| self.embed(t)).collect()
44    }
45
46    /// Get the embedding dimension.
47    fn dimension(&self) -> usize;
48
49    /// Get the model name/identifier.
50    fn model_name(&self) -> &str;
51
52    /// Compute cosine similarity between two vectors.
53    fn similarity(&self, a: &[f32], b: &[f32]) -> EmbeddingResult<f32> {
54        if a.len() != b.len() {
55            return Err(EmbeddingError::DimensionMismatch {
56                expected: a.len(),
57                got: b.len(),
58            });
59        }
60
61        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
62        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
63        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
64
65        if norm_a == 0.0 || norm_b == 0.0 {
66            return Ok(0.0);
67        }
68
69        Ok(dot / (norm_a * norm_b))
70    }
71
72    /// Find the most similar text from a list.
73    fn most_similar<'a>(
74        &self,
75        query: &str,
76        candidates: &[&'a str],
77    ) -> EmbeddingResult<Option<(&'a str, f32)>> {
78        if candidates.is_empty() {
79            return Ok(None);
80        }
81
82        let query_vec = self.embed(query)?;
83        let candidate_vecs = self.embed_batch(candidates)?;
84
85        let mut best: Option<(usize, f32)> = None;
86        for (i, vec) in candidate_vecs.iter().enumerate() {
87            let sim = self.similarity(&query_vec, vec)?;
88            if best.is_none() || sim > best.unwrap().1 {
89                best = Some((i, sim));
90            }
91        }
92
93        Ok(best.map(|(i, sim)| (candidates[i], sim)))
94    }
95
96    /// Find top-k most similar texts.
97    fn top_k_similar<'a>(
98        &self,
99        query: &str,
100        candidates: &[&'a str],
101        k: usize,
102    ) -> EmbeddingResult<Vec<(&'a str, f32)>> {
103        if candidates.is_empty() || k == 0 {
104            return Ok(vec![]);
105        }
106
107        let query_vec = self.embed(query)?;
108        let candidate_vecs = self.embed_batch(candidates)?;
109
110        let mut scores: Vec<(usize, f32)> = candidate_vecs
111            .iter()
112            .enumerate()
113            .map(|(i, vec)| {
114                let sim = self.similarity(&query_vec, vec).unwrap_or(0.0);
115                (i, sim)
116            })
117            .collect();
118
119        // Sort by similarity descending
120        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
121
122        Ok(scores
123            .into_iter()
124            .take(k)
125            .map(|(i, sim)| (candidates[i], sim))
126            .collect())
127    }
128}
129
/// An embedding vector bundled with optional provenance metadata.
#[derive(Debug, Clone)]
pub struct Embedding {
    /// The vector representation.
    pub vector: Vec<f32>,
    /// Original text (optional).
    pub text: Option<String>,
    /// Token count.
    pub tokens: usize,
}

impl Embedding {
    /// Wrap a raw vector with no associated text and a zero token count.
    pub fn new(vector: Vec<f32>) -> Self {
        Embedding { tokens: 0, text: None, vector }
    }

    /// Wrap a vector together with its source text and token count.
    pub fn with_text(vector: Vec<f32>, text: String, tokens: usize) -> Self {
        Embedding { tokens, text: Some(text), vector }
    }

    /// Number of components in the vector.
    pub fn dimension(&self) -> usize {
        self.vector.len()
    }
}