Skip to main content

mem7_embedding/
openai.rs

1use async_trait::async_trait;
2use mem7_config::EmbeddingConfig;
3use mem7_error::{Mem7Error, Result};
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use tracing::debug;
7
8use crate::EmbeddingClient;
9
10/// An OpenAI-compatible embedding client that works with both OpenAI and vLLM.
11pub struct OpenAICompatibleEmbedding {
12    client: Client,
13    config: EmbeddingConfig,
14}
15
16impl OpenAICompatibleEmbedding {
17    pub fn new(config: EmbeddingConfig) -> Self {
18        let client = Client::new();
19        Self { client, config }
20    }
21}
22
23#[derive(Debug, Serialize)]
24struct EmbeddingRequest {
25    model: String,
26    input: Vec<String>,
27}
28
29#[derive(Debug, Deserialize)]
30struct EmbeddingResponse {
31    data: Vec<EmbeddingData>,
32}
33
34#[derive(Debug, Deserialize)]
35struct EmbeddingData {
36    embedding: Vec<f32>,
37}
38
39#[async_trait]
40impl EmbeddingClient for OpenAICompatibleEmbedding {
41    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
42        let url = format!("{}/embeddings", self.config.base_url.trim_end_matches('/'));
43
44        let body = EmbeddingRequest {
45            model: self.config.model.clone(),
46            input: texts.to_vec(),
47        };
48
49        debug!(url = %url, model = %self.config.model, count = texts.len(), "sending embedding request");
50
51        let resp = self
52            .client
53            .post(&url)
54            .header("Authorization", format!("Bearer {}", self.config.api_key))
55            .header("Content-Type", "application/json")
56            .json(&body)
57            .send()
58            .await?;
59
60        if !resp.status().is_success() {
61            let status = resp.status();
62            let text = resp.text().await.unwrap_or_default();
63            return Err(Mem7Error::Embedding(format!("HTTP {status}: {text}")));
64        }
65
66        let data: EmbeddingResponse = resp.json().await?;
67        Ok(data.data.into_iter().map(|d| d.embedding).collect())
68    }
69}