mockforge_intelligence/intelligent_behavior/
embedding_client.rs1use mockforge_foundation::Result;
6
7pub struct EmbeddingClient {
9 provider: String,
11 model: String,
13 api_key: Option<String>,
15 endpoint: String,
17 client: reqwest::Client,
19}
20
21impl EmbeddingClient {
22 pub fn new(
24 provider: impl Into<String>,
25 model: impl Into<String>,
26 api_key: Option<String>,
27 endpoint: Option<String>,
28 ) -> Self {
29 let provider = provider.into();
30 let endpoint = endpoint.unwrap_or_else(|| match provider.as_str() {
31 "openai" => "https://api.openai.com/v1/embeddings".to_string(),
32 _ => "http://localhost:8080/v1/embeddings".to_string(),
33 });
34
35 Self {
36 provider,
37 model: model.into(),
38 api_key,
39 endpoint,
40 client: reqwest::Client::new(),
41 }
42 }
43
44 pub async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
46 match self.provider.as_str() {
47 "openai" | "openai-compatible" => self.generate_openai_embedding(text).await,
48 _ => Err(mockforge_foundation::Error::internal(format!(
49 "Unsupported embedding provider: {}",
50 self.provider
51 ))),
52 }
53 }
54
55 pub async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
57 let mut embeddings = Vec::new();
58 for text in texts {
59 let embedding = self.generate_embedding(&text).await?;
60 embeddings.push(embedding);
61 }
62 Ok(embeddings)
63 }
64
65 async fn generate_openai_embedding(&self, text: &str) -> Result<Vec<f32>> {
67 let api_key = self
68 .api_key
69 .clone()
70 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
71 .ok_or_else(|| mockforge_foundation::Error::internal("OpenAI API key not found"))?;
72
73 let request_body = serde_json::json!({
74 "model": self.model,
75 "input": text,
76 });
77
78 let mut request =
79 self.client.post(&self.endpoint).header("Content-Type", "application/json");
80
81 if !api_key.is_empty() {
82 request = request.header("Authorization", format!("Bearer {}", api_key));
83 }
84
85 let response = request.json(&request_body).send().await.map_err(|e| {
86 mockforge_foundation::Error::internal(format!("Embedding API request failed: {}", e))
87 })?;
88
89 if !response.status().is_success() {
90 let error_text = response.text().await.unwrap_or_default();
91 return Err(mockforge_foundation::Error::internal(format!(
92 "Embedding API error: {}",
93 error_text
94 )));
95 }
96
97 let response_json: serde_json::Value = response.json().await.map_err(|e| {
98 mockforge_foundation::Error::config(format!(
99 "Failed to parse embedding response: {}",
100 e
101 ))
102 })?;
103
104 let embedding: Vec<f32> = response_json["data"][0]["embedding"]
106 .as_array()
107 .ok_or_else(|| {
108 mockforge_foundation::Error::internal("Invalid embedding response format")
109 })?
110 .iter()
111 .filter_map(|v| v.as_f64().map(|f| f as f32))
112 .collect();
113
114 if embedding.is_empty() {
115 return Err(mockforge_foundation::Error::internal("Empty embedding returned"));
116 }
117
118 Ok(embedding)
119 }
120}
121
/// Cosine similarity of two equal-length vectors.
///
/// Returns `0.0` instead of dividing when the slices differ in length or
/// when either vector has zero magnitude (this includes empty slices).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let (norm_a, norm_b) = (norm(a), norm(b));

        if norm_a != 0.0 && norm_b != 0.0 {
            let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
            return dot / (norm_a * norm_b);
        }
    }

    0.0
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![1.0, 0.0];
        let d = vec![0.0, 1.0];
        assert!((cosine_similarity(&c, &d) - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_opposite_vectors() {
        assert!((cosine_similarity(&[1.0, 2.0], &[-1.0, -2.0]) + 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        // Mismatched lengths and zero-magnitude vectors are defined as 0.0.
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
        let empty: [f32; 0] = [];
        assert_eq!(cosine_similarity(&empty, &empty), 0.0);
    }

    #[test]
    fn test_embedding_client_creation() {
        let client = EmbeddingClient::new(
            "openai",
            "text-embedding-ada-002",
            Some("test_key".to_string()),
            None,
        );
        assert_eq!(client.provider, "openai");
        assert_eq!(client.model, "text-embedding-ada-002");
        assert_eq!(client.api_key.as_deref(), Some("test_key"));
    }

    #[test]
    fn test_embedding_client_endpoint_selection() {
        let openai = EmbeddingClient::new("openai", "m", None, None);
        assert_eq!(openai.endpoint, "https://api.openai.com/v1/embeddings");

        let compatible = EmbeddingClient::new("openai-compatible", "m", None, None);
        assert_eq!(compatible.endpoint, "http://localhost:8080/v1/embeddings");

        // An explicit endpoint always overrides the provider default.
        let custom = EmbeddingClient::new(
            "openai",
            "m",
            None,
            Some("http://example.test/embed".to_string()),
        );
        assert_eq!(custom.endpoint, "http://example.test/embed");
    }
}