a3s_code_core/context/
embedding.rs1use anyhow::Result;
7use async_trait::async_trait;
8
/// A dense embedding vector, stored as one `f32` component per dimension.
pub type Embedding = std::vec::Vec<f32>;
11
12#[async_trait]
14pub trait EmbeddingProvider: Send + Sync {
15 fn name(&self) -> &str;
17
18 fn dimension(&self) -> usize;
20
21 async fn embed(&self, text: &str) -> Result<Embedding>;
23
24 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>> {
28 let mut results = Vec::with_capacity(texts.len());
29 for text in texts {
30 results.push(self.embed(text).await?);
31 }
32 Ok(results)
33 }
34}
35
36pub struct OpenAiEmbeddingProvider {
41 client: reqwest::Client,
42 base_url: String,
43 model: String,
44 dimension: usize,
45}
46
47impl OpenAiEmbeddingProvider {
48 pub fn new(
54 api_key: impl Into<String>,
55 model: impl Into<String>,
56 dimension: usize,
57 ) -> Result<Self> {
58 Self::with_base_url(api_key, model, dimension, "https://api.openai.com/v1")
59 }
60
61 pub fn with_base_url(
63 api_key: impl Into<String>,
64 model: impl Into<String>,
65 dimension: usize,
66 base_url: impl Into<String>,
67 ) -> Result<Self> {
68 let api_key = api_key.into();
69 let mut headers = reqwest::header::HeaderMap::new();
70 headers.insert(
71 reqwest::header::AUTHORIZATION,
72 format!("Bearer {}", api_key)
73 .parse()
74 .map_err(|e| anyhow::anyhow!("Invalid API key header: {}", e))?,
75 );
76 headers.insert(
77 reqwest::header::CONTENT_TYPE,
78 "application/json".parse().unwrap(),
79 );
80
81 let client = reqwest::Client::builder()
82 .default_headers(headers)
83 .timeout(std::time::Duration::from_secs(30))
84 .build()?;
85
86 Ok(Self {
87 client,
88 base_url: base_url.into().trim_end_matches('/').to_string(),
89 model: model.into(),
90 dimension,
91 })
92 }
93}
94
95#[async_trait]
96impl EmbeddingProvider for OpenAiEmbeddingProvider {
97 fn name(&self) -> &str {
98 "openai-embedding"
99 }
100
101 fn dimension(&self) -> usize {
102 self.dimension
103 }
104
105 async fn embed(&self, text: &str) -> Result<Embedding> {
106 let mut results = self.embed_batch(&[text]).await?;
107 results
108 .pop()
109 .ok_or_else(|| anyhow::anyhow!("Empty embedding response"))
110 }
111
112 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>> {
113 if texts.is_empty() {
114 return Ok(Vec::new());
115 }
116
117 let url = format!("{}/embeddings", self.base_url);
118 let body = serde_json::json!({
119 "model": self.model,
120 "input": texts,
121 });
122
123 let response = self
124 .client
125 .post(&url)
126 .json(&body)
127 .send()
128 .await
129 .map_err(|e| anyhow::anyhow!("Embedding API request failed: {}", e))?;
130
131 if !response.status().is_success() {
132 let status = response.status();
133 let body = response.text().await.unwrap_or_default();
134 return Err(anyhow::anyhow!(
135 "Embedding API returned HTTP {}: {}",
136 status,
137 body
138 ));
139 }
140
141 let json: serde_json::Value = response.json().await?;
142 let data = json["data"]
143 .as_array()
144 .ok_or_else(|| anyhow::anyhow!("Invalid embedding response: missing 'data' array"))?;
145
146 let mut embeddings = Vec::with_capacity(data.len());
147 for item in data {
148 let embedding: Vec<f32> = item["embedding"]
149 .as_array()
150 .ok_or_else(|| anyhow::anyhow!("Invalid embedding item: missing 'embedding'"))?
151 .iter()
152 .filter_map(|v| v.as_f64().map(|f| f as f32))
153 .collect();
154
155 if embedding.len() != self.dimension {
156 return Err(anyhow::anyhow!(
157 "Embedding dimension mismatch: expected {}, got {}",
158 self.dimension,
159 embedding.len()
160 ));
161 }
162
163 embeddings.push(embedding);
164 }
165
166 Ok(embeddings)
167 }
168}
169
170impl std::fmt::Debug for OpenAiEmbeddingProvider {
171 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 f.debug_struct("OpenAiEmbeddingProvider")
173 .field("base_url", &self.base_url)
174 .field("model", &self.model)
175 .field("dimension", &self.dimension)
176 .finish()
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, network-free embedding provider for tests.
    ///
    /// Adds each input byte (scaled to `[0, 1]`) into bucket `i % dim`, then
    /// L2-normalises the vector. Equal texts therefore produce equal vectors.
    pub(crate) struct MockEmbeddingProvider {
        dim: usize,
    }

    impl MockEmbeddingProvider {
        pub fn new(dim: usize) -> Self {
            Self { dim }
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockEmbeddingProvider {
        fn name(&self) -> &str {
            "mock-embedding"
        }

        fn dimension(&self) -> usize {
            self.dim
        }

        async fn embed(&self, text: &str) -> Result<Embedding> {
            let mut embedding = vec![0.0f32; self.dim];
            // A zero-dimensional mock has no buckets; `i % 0` below would panic.
            if self.dim == 0 {
                return Ok(embedding);
            }
            for (i, byte) in text.bytes().enumerate() {
                embedding[i % self.dim] += (byte as f32) / 255.0;
            }
            // Normalise to unit length; an all-zero vector (empty text) is left as-is.
            let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 0.0 {
                for v in &mut embedding {
                    *v /= norm;
                }
            }
            Ok(embedding)
        }
    }

    #[test]
    fn test_embedding_type() {
        let emb: Embedding = vec![0.1, 0.2, 0.3];
        assert_eq!(emb.len(), 3);
    }

    #[tokio::test]
    async fn test_mock_embedding_provider() {
        let provider = MockEmbeddingProvider::new(8);
        assert_eq!(provider.name(), "mock-embedding");
        assert_eq!(provider.dimension(), 8);

        let emb = provider.embed("hello world").await.unwrap();
        assert_eq!(emb.len(), 8);

        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_mock_embedding_deterministic() {
        let provider = MockEmbeddingProvider::new(8);
        let emb1 = provider.embed("test input").await.unwrap();
        let emb2 = provider.embed("test input").await.unwrap();
        assert_eq!(emb1, emb2);
    }

    #[tokio::test]
    async fn test_mock_embedding_different_texts() {
        let provider = MockEmbeddingProvider::new(8);
        let emb1 = provider.embed("hello").await.unwrap();
        let emb2 = provider.embed("world").await.unwrap();
        assert_ne!(emb1, emb2);
    }

    #[tokio::test]
    async fn test_mock_embedding_zero_dimension() {
        let provider = MockEmbeddingProvider::new(0);
        let emb = provider.embed("hello").await.unwrap();
        assert!(emb.is_empty());
    }

    #[tokio::test]
    async fn test_mock_embedding_empty_text() {
        let provider = MockEmbeddingProvider::new(4);
        let emb = provider.embed("").await.unwrap();
        assert_eq!(emb, vec![0.0f32; 4]);
    }

    #[tokio::test]
    async fn test_embed_batch_default() {
        let provider = MockEmbeddingProvider::new(4);
        let results = provider
            .embed_batch(&["hello", "world", "test"])
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        for emb in &results {
            assert_eq!(emb.len(), 4);
        }
    }

    #[tokio::test]
    async fn test_embed_batch_empty() {
        let provider = MockEmbeddingProvider::new(4);
        let results = provider.embed_batch(&[]).await.unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_openai_embedding_provider_debug() {
        let provider = OpenAiEmbeddingProvider {
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1".to_string(),
            model: "text-embedding-3-small".to_string(),
            dimension: 1536,
        };
        let debug = format!("{:?}", provider);
        assert!(debug.contains("OpenAiEmbeddingProvider"));
        assert!(debug.contains("text-embedding-3-small"));
        assert!(debug.contains("1536"));
    }

    #[test]
    fn test_openai_with_base_url_trims_trailing_slash() {
        let provider = OpenAiEmbeddingProvider::with_base_url(
            "sk-test",
            "text-embedding-3-small",
            3,
            "http://localhost:8080/v1/",
        )
        .unwrap();
        assert_eq!(provider.base_url, "http://localhost:8080/v1");
        assert_eq!(provider.model, "text-embedding-3-small");
        assert_eq!(provider.dimension(), 3);
        assert_eq!(provider.name(), "openai-embedding");
    }

    #[test]
    fn test_openai_new_rejects_invalid_api_key() {
        // A newline is not a valid HTTP header byte, so the Authorization header fails.
        let result = OpenAiEmbeddingProvider::new("bad\nkey", "text-embedding-3-small", 1536);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_openai_embed_batch_empty_skips_request() {
        // An empty batch must return before any request is sent, so this URL is never
        // contacted and the test needs no network.
        let provider =
            OpenAiEmbeddingProvider::with_base_url("sk-test", "m", 3, "http://127.0.0.1:9")
                .unwrap();
        let results = provider.embed_batch(&[]).await.unwrap();
        assert!(results.is_empty());
    }
}