Skip to main content

mem_embed/
openai.rs

1//! HTTP client for OpenAI-compatible embedding API.
2
3use mem_types::{Embedder, EmbedderError};
4use serde::Deserialize;
5
6#[derive(Debug, Deserialize)]
7struct EmbedResponse {
8    data: Option<Vec<EmbedItem>>,
9}
10
11#[derive(Debug, Deserialize)]
12struct EmbedItem {
13    embedding: Vec<f32>,
14}
15
16/// Embedder that calls an OpenAI-compatible embedding endpoint (e.g. POST /embeddings).
17pub struct OpenAiEmbedder {
18    client: reqwest::Client,
19    url: String,
20    api_key: Option<String>,
21    model: String,
22}
23
24impl OpenAiEmbedder {
25    pub fn new(url: String, api_key: Option<String>, model: Option<&str>) -> Self {
26        Self {
27            client: reqwest::Client::new(),
28            url,
29            api_key,
30            model: model.unwrap_or("text-embedding-3-small").to_string(),
31        }
32    }
33
34    pub fn from_env() -> Self {
35        let url = std::env::var("EMBED_API_URL")
36            .unwrap_or_else(|_| "https://api.openai.com/v1/embeddings".to_string());
37        let api_key = std::env::var("EMBED_API_KEY").ok();
38        let model = std::env::var("EMBED_MODEL").ok();
39        Self::new(url, api_key, model.as_deref())
40    }
41}
42
43#[async_trait::async_trait]
44impl Embedder for OpenAiEmbedder {
45    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbedderError> {
46        if texts.is_empty() {
47            return Ok(Vec::new());
48        }
49        let mut all = Vec::with_capacity(texts.len());
50        for text in texts {
51            let body = serde_json::json!({
52                "input": text,
53                "model": self.model
54            });
55            let mut req = self.client.post(&self.url).json(&body);
56            if let Some(ref key) = self.api_key {
57                req = req.bearer_auth(key);
58            }
59            let res = req
60                .send()
61                .await
62                .map_err(|e| EmbedderError::Other(e.to_string()))?;
63            let status = res.status();
64            let body = res
65                .text()
66                .await
67                .map_err(|e| EmbedderError::Other(e.to_string()))?;
68            if !status.is_success() {
69                return Err(EmbedderError::Other(format!(
70                    "embed API error {}: {}",
71                    status, body
72                )));
73            }
74            let parsed: EmbedResponse =
75                serde_json::from_str(&body).map_err(|e| EmbedderError::Other(e.to_string()))?;
76            let embedding = parsed
77                .data
78                .and_then(|d| d.into_iter().next())
79                .map(|i| i.embedding)
80                .ok_or(EmbedderError::EmptyResponse)?;
81            all.push(embedding);
82        }
83        Ok(all)
84    }
85}