1use async_trait::async_trait;
2use mem7_config::EmbeddingConfig;
3use mem7_error::{Mem7Error, Result};
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use tracing::debug;
7
8use crate::EmbeddingClient;
9
10pub struct OpenAICompatibleEmbedding {
12 client: Client,
13 config: EmbeddingConfig,
14}
15
16impl OpenAICompatibleEmbedding {
17 pub fn new(config: EmbeddingConfig) -> Self {
18 let client = Client::new();
19 Self { client, config }
20 }
21}
22
23#[derive(Debug, Serialize)]
24struct EmbeddingRequest {
25 model: String,
26 input: Vec<String>,
27}
28
29#[derive(Debug, Deserialize)]
30struct EmbeddingResponse {
31 data: Vec<EmbeddingData>,
32}
33
34#[derive(Debug, Deserialize)]
35struct EmbeddingData {
36 embedding: Vec<f32>,
37}
38
39#[async_trait]
40impl EmbeddingClient for OpenAICompatibleEmbedding {
41 async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
42 let url = format!("{}/embeddings", self.config.base_url.trim_end_matches('/'));
43
44 let body = EmbeddingRequest {
45 model: self.config.model.clone(),
46 input: texts.to_vec(),
47 };
48
49 debug!(url = %url, model = %self.config.model, count = texts.len(), "sending embedding request");
50
51 let resp = self
52 .client
53 .post(&url)
54 .header("Authorization", format!("Bearer {}", self.config.api_key))
55 .header("Content-Type", "application/json")
56 .json(&body)
57 .send()
58 .await?;
59
60 if !resp.status().is_success() {
61 let status = resp.status();
62 let text = resp.text().await.unwrap_or_default();
63 return Err(Mem7Error::Embedding(format!("HTTP {status}: {text}")));
64 }
65
66 let data: EmbeddingResponse = resp.json().await?;
67 Ok(data.data.into_iter().map(|d| d.embedding).collect())
68 }
69}