Skip to main content

a3s_code_core/context/
embedding.rs

1//! Embedding Provider Extension Point
2//!
3//! Defines the trait for generating vector embeddings from text.
4//! Implementations can use LLM APIs (OpenAI, Anthropic) or local models.
5
6use anyhow::Result;
7use async_trait::async_trait;
8
9/// Vector embedding (dense float array).
10///
/// Length is provider-specific; validate against
/// [`EmbeddingProvider::dimension`] when comparing or storing vectors.
pub type Embedding = Vec<f32>;
11
12/// Trait for generating text embeddings.
13///
/// Implementors wrap a concrete backend (a remote API or a local model)
/// and must be `Send + Sync` so one provider can be shared across tasks.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Provider name for logging
    fn name(&self) -> &str;

    /// Embedding dimension (e.g., 1536 for OpenAI text-embedding-3-small)
    fn dimension(&self) -> usize;

    /// Generate embedding for a single text
    async fn embed(&self, text: &str) -> Result<Embedding>;

    /// Generate embeddings for multiple texts (batch)
    ///
    /// Default implementation calls `embed()` sequentially; providers with a
    /// native batch endpoint should override this for fewer round-trips.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>> {
        let mut out = Vec::with_capacity(texts.len());
        for &text in texts {
            out.push(self.embed(text).await?);
        }
        Ok(out)
    }
}
35
36/// OpenAI-compatible embedding provider
37///
38/// Works with OpenAI, Azure OpenAI, and any API that implements
39/// the `/v1/embeddings` endpoint.
pub struct OpenAiEmbeddingProvider {
    // HTTP client preconfigured with auth + content-type default headers.
    client: reqwest::Client,
    // API root without trailing slash (e.g. "https://api.openai.com/v1").
    base_url: String,
    // Model name sent in the request body (e.g. "text-embedding-3-small").
    model: String,
    // Expected length of each returned vector; responses are validated against it.
    dimension: usize,
}
46
47impl OpenAiEmbeddingProvider {
    /// Create a new OpenAI embedding provider pointed at the public
    /// OpenAI endpoint (`https://api.openai.com/v1`).
    ///
    /// - `api_key`: API key for authentication
    /// - `model`: Model name (e.g., "text-embedding-3-small")
    /// - `dimension`: Embedding dimension (e.g., 1536)
    pub fn new(
        api_key: impl Into<String>,
        model: impl Into<String>,
        dimension: usize,
    ) -> Result<Self> {
        Self::with_base_url(api_key, model, dimension, "https://api.openai.com/v1")
    }

    /// Create with a custom base URL (for Azure, local proxies, etc.)
    ///
    /// # Errors
    ///
    /// Returns an error if the API key contains bytes that are invalid in an
    /// HTTP header, or if the HTTP client fails to build.
    pub fn with_base_url(
        api_key: impl Into<String>,
        model: impl Into<String>,
        dimension: usize,
        base_url: impl Into<String>,
    ) -> Result<Self> {
        let api_key = api_key.into();

        // Build the Authorization header up front so an invalid key fails fast.
        let mut auth = reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
            .map_err(|e| anyhow::anyhow!("Invalid API key header: {}", e))?;
        // Mark the credential as sensitive so it is redacted from `Debug`
        // output and header-map logging in the HTTP stack.
        auth.set_sensitive(true);

        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(reqwest::header::AUTHORIZATION, auth);
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            // `from_static` is infallible for this known-valid literal
            // (replaces the previous `parse().unwrap()`).
            reqwest::header::HeaderValue::from_static("application/json"),
        );

        let client = reqwest::Client::builder()
            .default_headers(headers)
            .timeout(std::time::Duration::from_secs(30))
            .build()?;

        Ok(Self {
            client,
            // Normalize so `format!("{}/embeddings", base_url)` never doubles a slash.
            base_url: base_url.into().trim_end_matches('/').to_string(),
            model: model.into(),
            dimension,
        })
    }
}
94
95#[async_trait]
impl EmbeddingProvider for OpenAiEmbeddingProvider {
    fn name(&self) -> &str {
        "openai-embedding"
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Single-text embed delegates to the batch path so both share one
    /// request/response implementation.
    async fn embed(&self, text: &str) -> Result<Embedding> {
        let mut results = self.embed_batch(&[text]).await?;
        results
            .pop()
            .ok_or_else(|| anyhow::anyhow!("Empty embedding response"))
    }

    /// POSTs `{model, input}` to `{base_url}/embeddings` and parses the
    /// standard OpenAI response shape (`data[i].embedding` float arrays).
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>> {
        // Avoid a pointless network round-trip for an empty batch.
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let url = format!("{}/embeddings", self.base_url);
        let body = serde_json::json!({
            "model": self.model,
            "input": texts,
        });

        let response = self
            .client
            .post(&url)
            .json(&body)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("Embedding API request failed: {}", e))?;

        if !response.status().is_success() {
            // Include the response body: API error details live there.
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(anyhow::anyhow!(
                "Embedding API returned HTTP {}: {}",
                status,
                body
            ));
        }

        let json: serde_json::Value = response.json().await?;
        let data = json["data"]
            .as_array()
            .ok_or_else(|| anyhow::anyhow!("Invalid embedding response: missing 'data' array"))?;

        let mut embeddings = Vec::with_capacity(data.len());
        for item in data {
            let values = item["embedding"]
                .as_array()
                .ok_or_else(|| anyhow::anyhow!("Invalid embedding item: missing 'embedding'"))?;

            // Convert strictly: a non-numeric element is a protocol error and
            // must fail loudly. (The previous `filter_map` silently dropped
            // bad elements, surfacing only as a confusing dimension mismatch.)
            let mut embedding = Vec::with_capacity(values.len());
            for v in values {
                let f = v.as_f64().ok_or_else(|| {
                    anyhow::anyhow!("Invalid embedding value: expected number, got {}", v)
                })?;
                embedding.push(f as f32);
            }

            if embedding.len() != self.dimension {
                return Err(anyhow::anyhow!(
                    "Embedding dimension mismatch: expected {}, got {}",
                    self.dimension,
                    embedding.len()
                ));
            }

            embeddings.push(embedding);
        }

        Ok(embeddings)
    }
}
169
170impl std::fmt::Debug for OpenAiEmbeddingProvider {
171    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172        f.debug_struct("OpenAiEmbeddingProvider")
173            .field("base_url", &self.base_url)
174            .field("model", &self.model)
175            .field("dimension", &self.dimension)
176            .finish()
177    }
178}
179
180// ============================================================================
181// Tests
182// ============================================================================
183
184#[cfg(test)]
mod tests {
    use super::*;

    // -- Mock embedding provider for testing --

    /// Simple mock that returns deterministic embeddings based on text hash
    pub(crate) struct MockEmbeddingProvider {
        dim: usize,
    }

    impl MockEmbeddingProvider {
        pub fn new(dim: usize) -> Self {
            Self { dim }
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockEmbeddingProvider {
        fn name(&self) -> &str {
            "mock-embedding"
        }

        fn dimension(&self) -> usize {
            self.dim
        }

        async fn embed(&self, text: &str) -> Result<Embedding> {
            // Guard: `i % self.dim` below would panic on a modulo-by-zero; a
            // zero-dimension provider just yields an empty vector.
            if self.dim == 0 {
                return Ok(Vec::new());
            }
            // Deterministic pseudo-embedding from text bytes
            let mut embedding = vec![0.0f32; self.dim];
            for (i, byte) in text.bytes().enumerate() {
                embedding[i % self.dim] += (byte as f32) / 255.0;
            }
            // Normalize (skipped for the all-zero case, e.g. empty text)
            let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 0.0 {
                for v in &mut embedding {
                    *v /= norm;
                }
            }
            Ok(embedding)
        }
    }

    #[test]
    fn test_embedding_type() {
        let emb: Embedding = vec![0.1, 0.2, 0.3];
        assert_eq!(emb.len(), 3);
    }

    #[tokio::test]
    async fn test_mock_embedding_provider() {
        let provider = MockEmbeddingProvider::new(8);
        assert_eq!(provider.name(), "mock-embedding");
        assert_eq!(provider.dimension(), 8);

        let emb = provider.embed("hello world").await.unwrap();
        assert_eq!(emb.len(), 8);

        // Verify normalization (unit vector)
        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_mock_embedding_deterministic() {
        let provider = MockEmbeddingProvider::new(8);
        let emb1 = provider.embed("test input").await.unwrap();
        let emb2 = provider.embed("test input").await.unwrap();
        assert_eq!(emb1, emb2);
    }

    #[tokio::test]
    async fn test_mock_embedding_different_texts() {
        let provider = MockEmbeddingProvider::new(8);
        let emb1 = provider.embed("hello").await.unwrap();
        let emb2 = provider.embed("world").await.unwrap();
        assert_ne!(emb1, emb2);
    }

    #[tokio::test]
    async fn test_mock_embedding_zero_dim() {
        // Regression: dim == 0 used to panic with a modulo-by-zero.
        let provider = MockEmbeddingProvider::new(0);
        let emb = provider.embed("anything").await.unwrap();
        assert!(emb.is_empty());
    }

    #[tokio::test]
    async fn test_embed_batch_default() {
        let provider = MockEmbeddingProvider::new(4);
        let results = provider
            .embed_batch(&["hello", "world", "test"])
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        for emb in &results {
            assert_eq!(emb.len(), 4);
        }
    }

    #[tokio::test]
    async fn test_embed_batch_empty() {
        let provider = MockEmbeddingProvider::new(4);
        let results = provider.embed_batch(&[]).await.unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_openai_embedding_provider_debug() {
        let provider = OpenAiEmbeddingProvider {
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1".to_string(),
            model: "text-embedding-3-small".to_string(),
            dimension: 1536,
        };
        let debug = format!("{:?}", provider);
        assert!(debug.contains("OpenAiEmbeddingProvider"));
        assert!(debug.contains("text-embedding-3-small"));
        assert!(debug.contains("1536"));
    }
}