// agent_io/memory/embeddings.rs
use async_trait::async_trait;
4use reqwest::Client;
5use serde::Deserialize;
6
7use crate::Result;
8
/// Abstraction over text-embedding backends (remote APIs or local mocks).
///
/// Implementors must be `Send + Sync` so a provider can be shared across tasks.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Embed a single text into a dense vector of `dimension()` components.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embed several texts in one call; returns one vector per input, in order.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;

    /// Number of components in each vector this provider produces.
    fn dimension(&self) -> usize;
}
21
/// Embedding provider backed by the OpenAI `/v1/embeddings` HTTP endpoint.
pub struct OpenAIEmbedding {
    // Reused reqwest client (connection pooling across requests).
    client: Client,
    // Bearer token sent in the Authorization header.
    api_key: String,
    // Model name passed in the request body, e.g. "text-embedding-3-small".
    model: String,
    // Vector length for the configured model; reported by `dimension()`.
    dimension: usize,
}
29
30impl OpenAIEmbedding {
31 pub fn new(api_key: impl Into<String>) -> Self {
33 Self {
34 client: Client::new(),
35 api_key: api_key.into(),
36 model: "text-embedding-3-small".to_string(),
37 dimension: 1536,
38 }
39 }
40
41 pub fn from_env() -> crate::Result<Self> {
43 let api_key = std::env::var("OPENAI_API_KEY")
44 .map_err(|_| crate::Error::Config("OPENAI_API_KEY not set".into()))?;
45 Ok(Self::new(api_key))
46 }
47
48 pub fn with_model(mut self, model: impl Into<String>, dimension: usize) -> Self {
50 self.model = model.into();
51 self.dimension = dimension;
52 self
53 }
54
55 pub fn large() -> crate::Result<Self> {
57 Ok(Self::from_env()?.with_model("text-embedding-3-large", 3072))
58 }
59
60 pub fn ada() -> crate::Result<Self> {
62 Ok(Self::from_env()?.with_model("text-embedding-ada-002", 1536))
63 }
64}
65
/// Top-level shape of the OpenAI embeddings response; only `data` is decoded,
/// other fields (model, usage) are ignored by serde.
#[derive(Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>,
}
70
/// One entry of the response `data` array; only the embedding vector is kept.
#[derive(Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>,
}
75
76#[async_trait]
77impl EmbeddingProvider for OpenAIEmbedding {
78 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
79 let embeddings = self.embed_batch(&[text]).await?;
80 embeddings
81 .into_iter()
82 .next()
83 .ok_or_else(|| crate::Error::Agent("No embedding returned".into()))
84 }
85
86 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
87 let response = self
88 .client
89 .post("https://api.openai.com/v1/embeddings")
90 .header("Authorization", format!("Bearer {}", self.api_key))
91 .header("Content-Type", "application/json")
92 .json(&serde_json::json!({
93 "model": self.model,
94 "input": texts,
95 }))
96 .send()
97 .await?;
98
99 if !response.status().is_success() {
100 let status = response.status();
101 let body = response.text().await.unwrap_or_default();
102 return Err(crate::Error::Agent(format!(
103 "OpenAI embedding error ({}): {}",
104 status, body
105 )));
106 }
107
108 let data: EmbeddingResponse = response.json().await?;
109 Ok(data.data.into_iter().map(|e| e.embedding).collect())
110 }
111
112 fn dimension(&self) -> usize {
113 self.dimension
114 }
115}
116
/// Deterministic, network-free embedding provider for use in tests.
#[allow(dead_code)]
pub struct MockEmbedding {
    // Length of every vector this mock produces.
    dimension: usize,
}

#[allow(dead_code)]
impl MockEmbedding {
    /// Create a mock whose vectors have `dimension` components.
    pub fn new(dimension: usize) -> Self {
        MockEmbedding { dimension }
    }
}

impl Default for MockEmbedding {
    /// Default to 384-component vectors.
    fn default() -> Self {
        Self::new(384)
    }
}
136
137#[async_trait]
138impl EmbeddingProvider for MockEmbedding {
139 async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
140 Ok(vec![0.1; self.dimension])
142 }
143
144 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
145 Ok(texts.iter().map(|_| vec![0.1; self.dimension]).collect())
146 }
147
148 fn dimension(&self) -> usize {
149 self.dimension
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_embedding() {
        let dim: usize = 128;
        let mock = MockEmbedding::new(dim);

        // Single embed returns a vector of the configured length.
        assert_eq!(mock.embed("test").await.unwrap().len(), dim);

        // Batch embed returns one vector per input, each of the same length.
        let vectors = mock.embed_batch(&["a", "b", "c"]).await.unwrap();
        assert_eq!(vectors.len(), 3);
        assert_eq!(vectors[0].len(), dim);
    }
}
168}