// agent_io/memory/embeddings.rs
1//! Embedding provider trait and implementations
2
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::Deserialize;
6
7use crate::Result;
8
/// A source of vector embeddings for text.
///
/// Implementations must be `Send + Sync` so a provider can be shared across
/// async tasks. Batch results are expected to be returned in the same order
/// as the input texts (the OpenAI implementation's `embed` relies on this by
/// taking the first element of a one-item batch).
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Generate an embedding vector for a single text.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Generate embeddings for multiple texts, one vector per input,
    /// in input order.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;

    /// Number of components in each embedding vector this provider produces.
    fn dimension(&self) -> usize;
}
21
/// Embedding provider backed by the OpenAI `/v1/embeddings` HTTP API.
pub struct OpenAIEmbedding {
    // Reused HTTP client (connection pooling across requests).
    client: Client,
    // Bearer token sent in the Authorization header.
    api_key: String,
    // Model name sent in the request body, e.g. "text-embedding-3-small".
    model: String,
    // Vector size produced by `model`; must be kept in sync with it
    // (see `with_model`).
    dimension: usize,
}
29
30impl OpenAIEmbedding {
31    /// Create a new OpenAI embedding provider
32    pub fn new(api_key: impl Into<String>) -> Self {
33        Self {
34            client: Client::new(),
35            api_key: api_key.into(),
36            model: "text-embedding-3-small".to_string(),
37            dimension: 1536,
38        }
39    }
40
41    /// Create from environment variable
42    pub fn from_env() -> crate::Result<Self> {
43        let api_key = std::env::var("OPENAI_API_KEY")
44            .map_err(|_| crate::Error::Config("OPENAI_API_KEY not set".into()))?;
45        Ok(Self::new(api_key))
46    }
47
48    /// Use a specific model
49    pub fn with_model(mut self, model: impl Into<String>, dimension: usize) -> Self {
50        self.model = model.into();
51        self.dimension = dimension;
52        self
53    }
54
55    /// Use text-embedding-3-large model
56    pub fn large() -> crate::Result<Self> {
57        Ok(Self::from_env()?.with_model("text-embedding-3-large", 3072))
58    }
59
60    /// Use text-embedding-ada-002 model
61    pub fn ada() -> crate::Result<Self> {
62        Ok(Self::from_env()?.with_model("text-embedding-ada-002", 1536))
63    }
64}
65
/// Top-level shape of the OpenAI embeddings response; only the `data`
/// array is deserialized, other fields (usage, model, ...) are ignored.
#[derive(Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>,
}
70
/// One entry of the response `data` array; only the embedding vector
/// itself is kept (the per-entry `index` field is not deserialized).
#[derive(Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>,
}
75
76#[async_trait]
77impl EmbeddingProvider for OpenAIEmbedding {
78    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
79        let embeddings = self.embed_batch(&[text]).await?;
80        embeddings
81            .into_iter()
82            .next()
83            .ok_or_else(|| crate::Error::Agent("No embedding returned".into()))
84    }
85
86    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
87        let response = self
88            .client
89            .post("https://api.openai.com/v1/embeddings")
90            .header("Authorization", format!("Bearer {}", self.api_key))
91            .header("Content-Type", "application/json")
92            .json(&serde_json::json!({
93                "model": self.model,
94                "input": texts,
95            }))
96            .send()
97            .await?;
98
99        if !response.status().is_success() {
100            let status = response.status();
101            let body = response.text().await.unwrap_or_default();
102            return Err(crate::Error::Agent(format!(
103                "OpenAI embedding error ({}): {}",
104                status, body
105            )));
106        }
107
108        let data: EmbeddingResponse = response.json().await?;
109        Ok(data.data.into_iter().map(|e| e.embedding).collect())
110    }
111
112    fn dimension(&self) -> usize {
113        self.dimension
114    }
115}
116
/// In-memory embedding provider for tests: produces constant vectors of a
/// configurable dimension without any network access.
#[allow(dead_code)]
pub struct MockEmbedding {
    // Length of every vector this mock returns.
    dimension: usize,
}
122
123#[allow(dead_code)]
124impl MockEmbedding {
125    /// Create a new mock embedding provider
126    pub fn new(dimension: usize) -> Self {
127        Self { dimension }
128    }
129}
130
131impl Default for MockEmbedding {
132    fn default() -> Self {
133        Self::new(384)
134    }
135}
136
137#[async_trait]
138impl EmbeddingProvider for MockEmbedding {
139    async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
140        // Return a deterministic embedding based on text length
141        Ok(vec![0.1; self.dimension])
142    }
143
144    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
145        Ok(texts.iter().map(|_| vec![0.1; self.dimension]).collect())
146    }
147
148    fn dimension(&self) -> usize {
149        self.dimension
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// The mock provider reports its dimension and returns constant
    /// vectors of that length, for single and batch requests alike.
    #[tokio::test]
    async fn test_mock_embedding() {
        let embedder = MockEmbedding::new(128);
        assert_eq!(embedder.dimension(), 128);

        let embedding = embedder.embed("test").await.unwrap();
        assert_eq!(embedding.len(), 128);
        // Mock vectors are filled with 0.1 regardless of input text.
        assert!(embedding.iter().all(|&v| (v - 0.1).abs() < f32::EPSILON));

        let batch = embedder.embed_batch(&["a", "b", "c"]).await.unwrap();
        assert_eq!(batch.len(), 3);
        assert_eq!(batch[0].len(), 128);

        // An empty batch yields an empty result rather than an error.
        let empty = embedder.embed_batch(&[]).await.unwrap();
        assert!(empty.is_empty());
    }
}