Skip to main content

mockforge_intelligence/intelligent_behavior/
embedding_client.rs

1//! Embedding client for vector memory
2//!
3//! This module provides a client for generating embeddings for semantic search.
4
5use mockforge_foundation::Result;
6
7/// Embedding client for generating vector embeddings
8pub struct EmbeddingClient {
9    /// Provider type
10    provider: String,
11    /// Model name
12    model: String,
13    /// API key (optional)
14    api_key: Option<String>,
15    /// API endpoint
16    endpoint: String,
17    /// HTTP client
18    client: reqwest::Client,
19}
20
21impl EmbeddingClient {
22    /// Create a new embedding client
23    pub fn new(
24        provider: impl Into<String>,
25        model: impl Into<String>,
26        api_key: Option<String>,
27        endpoint: Option<String>,
28    ) -> Self {
29        let provider = provider.into();
30        let endpoint = endpoint.unwrap_or_else(|| match provider.as_str() {
31            "openai" => "https://api.openai.com/v1/embeddings".to_string(),
32            _ => "http://localhost:8080/v1/embeddings".to_string(),
33        });
34
35        Self {
36            provider,
37            model: model.into(),
38            api_key,
39            endpoint,
40            client: reqwest::Client::new(),
41        }
42    }
43
44    /// Generate an embedding for text
45    pub async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
46        match self.provider.as_str() {
47            "openai" | "openai-compatible" => self.generate_openai_embedding(text).await,
48            _ => Err(mockforge_foundation::Error::internal(format!(
49                "Unsupported embedding provider: {}",
50                self.provider
51            ))),
52        }
53    }
54
55    /// Generate embeddings for multiple texts
56    pub async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
57        let mut embeddings = Vec::new();
58        for text in texts {
59            let embedding = self.generate_embedding(&text).await?;
60            embeddings.push(embedding);
61        }
62        Ok(embeddings)
63    }
64
65    /// Generate embedding using OpenAI API
66    async fn generate_openai_embedding(&self, text: &str) -> Result<Vec<f32>> {
67        let api_key = self
68            .api_key
69            .clone()
70            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
71            .ok_or_else(|| mockforge_foundation::Error::internal("OpenAI API key not found"))?;
72
73        let request_body = serde_json::json!({
74            "model": self.model,
75            "input": text,
76        });
77
78        let mut request =
79            self.client.post(&self.endpoint).header("Content-Type", "application/json");
80
81        if !api_key.is_empty() {
82            request = request.header("Authorization", format!("Bearer {}", api_key));
83        }
84
85        let response = request.json(&request_body).send().await.map_err(|e| {
86            mockforge_foundation::Error::internal(format!("Embedding API request failed: {}", e))
87        })?;
88
89        if !response.status().is_success() {
90            let error_text = response.text().await.unwrap_or_default();
91            return Err(mockforge_foundation::Error::internal(format!(
92                "Embedding API error: {}",
93                error_text
94            )));
95        }
96
97        let response_json: serde_json::Value = response.json().await.map_err(|e| {
98            mockforge_foundation::Error::config(format!(
99                "Failed to parse embedding response: {}",
100                e
101            ))
102        })?;
103
104        // Extract embedding vector
105        let embedding: Vec<f32> = response_json["data"][0]["embedding"]
106            .as_array()
107            .ok_or_else(|| {
108                mockforge_foundation::Error::internal("Invalid embedding response format")
109            })?
110            .iter()
111            .filter_map(|v| v.as_f64().map(|f| f as f32))
112            .collect();
113
114        if embedding.is_empty() {
115            return Err(mockforge_foundation::Error::internal("Empty embedding returned"));
116        }
117
118        Ok(embedding)
119    }
120}
121
/// Calculate the cosine similarity between two vectors.
///
/// For finite inputs the result lies in `[-1.0, 1.0]`:
/// * `1.0` for vectors pointing the same way,
/// * `0.0` for orthogonal vectors,
/// * `-1.0` for opposite vectors.
///
/// Returns `0.0` when the slices differ in length, or when either vector has zero magnitude
/// (including empty slices), since the similarity is undefined there. Otherwise a NaN
/// component yields NaN.
///
/// Sums are accumulated in `f64`. This way large components cannot overflow (in `f32`,
/// squaring values near `1e20` gives infinity and then NaN), and long vectors keep their
/// precision.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass over both slices: (dot product, |a|^2, |b|^2).
    let (dot, norm_a_sq, norm_b_sq) =
        a.iter().zip(b).fold((0.0_f64, 0.0_f64, 0.0_f64), |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        });

    let magnitude_a = norm_a_sq.sqrt();
    let magnitude_b = norm_b_sq.sqrt();

    if magnitude_a == 0.0 || magnitude_b == 0.0 {
        return 0.0;
    }

    // Rounding can push the ratio marginally outside [-1, 1]; clamp so callers can rely on
    // the range. `clamp` leaves NaN unchanged.
    (dot / (magnitude_a * magnitude_b)).clamp(-1.0, 1.0) as f32
}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing floating-point similarity scores.
    const EPS: f32 = 0.001;

    /// Identical vectors score ~1.0 and orthogonal vectors score ~0.0.
    #[test]
    fn test_cosine_similarity() {
        let unit_x = [1.0_f32, 0.0, 0.0];
        let same = cosine_similarity(&unit_x, &unit_x);
        assert!((same - 1.0).abs() < EPS);

        let (x_axis, y_axis) = ([1.0_f32, 0.0], [0.0_f32, 1.0]);
        let orthogonal = cosine_similarity(&x_axis, &y_axis);
        assert!(orthogonal.abs() < EPS);
    }

    /// The constructor stores the provider and model it was given.
    #[test]
    fn test_embedding_client_creation() {
        let model_name = "text-embedding-ada-002";
        let client = EmbeddingClient::new(
            "openai",
            model_name,
            Some(String::from("test_key")),
            None,
        );

        assert_eq!(client.provider, "openai");
        assert_eq!(client.model, model_name);
    }
}