1use mem_types::{Embedder, EmbedderError};
4use serde::Deserialize;
5
6#[derive(Debug, Deserialize)]
7struct EmbedResponse {
8 data: Option<Vec<EmbedItem>>,
9}
10
11#[derive(Debug, Deserialize)]
12struct EmbedItem {
13 embedding: Vec<f32>,
14}
15
16pub struct OpenAiEmbedder {
18 client: reqwest::Client,
19 url: String,
20 api_key: Option<String>,
21 model: String,
22}
23
24impl OpenAiEmbedder {
25 pub fn new(url: String, api_key: Option<String>, model: Option<&str>) -> Self {
26 Self {
27 client: reqwest::Client::new(),
28 url,
29 api_key,
30 model: model.unwrap_or("text-embedding-3-small").to_string(),
31 }
32 }
33
34 pub fn from_env() -> Self {
35 let url = std::env::var("EMBED_API_URL")
36 .unwrap_or_else(|_| "https://api.openai.com/v1/embeddings".to_string());
37 let api_key = std::env::var("EMBED_API_KEY").ok();
38 let model = std::env::var("EMBED_MODEL").ok();
39 Self::new(url, api_key, model.as_deref())
40 }
41}
42
43#[async_trait::async_trait]
44impl Embedder for OpenAiEmbedder {
45 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbedderError> {
46 if texts.is_empty() {
47 return Ok(Vec::new());
48 }
49 let mut all = Vec::with_capacity(texts.len());
50 for text in texts {
51 let body = serde_json::json!({
52 "input": text,
53 "model": self.model
54 });
55 let mut req = self.client.post(&self.url).json(&body);
56 if let Some(ref key) = self.api_key {
57 req = req.bearer_auth(key);
58 }
59 let res = req
60 .send()
61 .await
62 .map_err(|e| EmbedderError::Other(e.to_string()))?;
63 let status = res.status();
64 let body = res
65 .text()
66 .await
67 .map_err(|e| EmbedderError::Other(e.to_string()))?;
68 if !status.is_success() {
69 return Err(EmbedderError::Other(format!(
70 "embed API error {}: {}",
71 status, body
72 )));
73 }
74 let parsed: EmbedResponse =
75 serde_json::from_str(&body).map_err(|e| EmbedderError::Other(e.to_string()))?;
76 let embedding = parsed
77 .data
78 .and_then(|d| d.into_iter().next())
79 .map(|i| i.embedding)
80 .ok_or(EmbedderError::EmptyResponse)?;
81 all.push(embedding);
82 }
83 Ok(all)
84 }
85}