// mentedb_embedding/http_provider.rs
use std::collections::HashMap;
use std::fmt;

use mentedb_core::MenteError;
use mentedb_core::error::MenteResult;
use serde::{Deserialize, Serialize};

use crate::provider::AsyncEmbeddingProvider;

11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct HttpEmbeddingConfig {
14 pub api_url: String,
16 pub api_key: String,
18 pub model_name: String,
20 pub dimensions: usize,
22 pub headers: HashMap<String, String>,
24}
25
26impl HttpEmbeddingConfig {
27 pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
31 let model = model.into();
32 let dimensions = match model.as_str() {
33 "text-embedding-3-small" => 1536,
34 "text-embedding-3-large" => 3072,
35 "text-embedding-ada-002" => 1536,
36 _ => 1536,
37 };
38
39 let mut headers = HashMap::new();
40 headers.insert("Content-Type".to_string(), "application/json".to_string());
41
42 Self {
43 api_url: "https://api.openai.com/v1/embeddings".to_string(),
44 api_key: api_key.into(),
45 model_name: model,
46 dimensions,
47 headers,
48 }
49 }
50
51 pub fn cohere(api_key: impl Into<String>, model: impl Into<String>) -> Self {
55 let model = model.into();
56 let dimensions = match model.as_str() {
57 "embed-english-v3.0" => 1024,
58 "embed-multilingual-v3.0" => 1024,
59 "embed-english-light-v3.0" => 384,
60 "embed-multilingual-light-v3.0" => 384,
61 _ => 1024,
62 };
63
64 let mut headers = HashMap::new();
65 headers.insert("Content-Type".to_string(), "application/json".to_string());
66
67 Self {
68 api_url: "https://api.cohere.ai/v1/embed".to_string(),
69 api_key: api_key.into(),
70 model_name: model,
71 dimensions,
72 headers,
73 }
74 }
75
76 pub fn voyage(api_key: impl Into<String>, model: impl Into<String>) -> Self {
80 let model = model.into();
81 let dimensions = match model.as_str() {
82 "voyage-2" => 1024,
83 "voyage-large-2" => 1536,
84 "voyage-code-2" => 1536,
85 "voyage-lite-02-instruct" => 1024,
86 _ => 1024,
87 };
88
89 let mut headers = HashMap::new();
90 headers.insert("Content-Type".to_string(), "application/json".to_string());
91
92 Self {
93 api_url: "https://api.voyageai.com/v1/embeddings".to_string(),
94 api_key: api_key.into(),
95 model_name: model,
96 dimensions,
97 headers,
98 }
99 }
100
101 pub fn with_dimensions(mut self, dimensions: usize) -> Self {
103 self.dimensions = dimensions;
104 self
105 }
106
107 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
109 self.headers.insert(key.into(), value.into());
110 self
111 }
112}
113
114pub struct HttpEmbeddingProvider {
119 config: HttpEmbeddingConfig,
120}
121
122impl HttpEmbeddingProvider {
123 pub fn new(config: HttpEmbeddingConfig) -> Self {
125 Self { config }
126 }
127
128 pub fn config(&self) -> &HttpEmbeddingConfig {
130 &self.config
131 }
132}
133
134impl AsyncEmbeddingProvider for HttpEmbeddingProvider {
135 async fn embed(&self, _text: &str) -> MenteResult<Vec<f32>> {
136 Err(MenteError::Storage(
137 "HTTP embedding requires the 'reqwest' feature".to_string(),
138 ))
139 }
140
141 async fn embed_batch(&self, _texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
142 Err(MenteError::Storage(
143 "HTTP embedding requires the 'reqwest' feature".to_string(),
144 ))
145 }
146
147 fn dimensions(&self) -> usize {
148 self.config.dimensions
149 }
150
151 fn model_name(&self) -> &str {
152 &self.config.model_name
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_config() {
        let config = HttpEmbeddingConfig::openai("sk-test", "text-embedding-3-small");
        assert_eq!(config.api_url, "https://api.openai.com/v1/embeddings");
        assert_eq!(config.dimensions, 1536);
        assert_eq!(config.model_name, "text-embedding-3-small");
    }

    /// Covers every branch of the OpenAI dimension table, not just the default.
    #[test]
    fn test_openai_dimension_table() {
        let dims = |model: &str| HttpEmbeddingConfig::openai("k", model).dimensions;
        assert_eq!(dims("text-embedding-3-large"), 3072);
        assert_eq!(dims("text-embedding-ada-002"), 1536);
        assert_eq!(dims("some-future-model"), 1536);
    }

    #[test]
    fn test_cohere_config() {
        let config = HttpEmbeddingConfig::cohere("key", "embed-english-v3.0");
        assert_eq!(config.api_url, "https://api.cohere.ai/v1/embed");
        assert_eq!(config.dimensions, 1024);
    }

    #[test]
    fn test_cohere_light_models() {
        let dims = |model: &str| HttpEmbeddingConfig::cohere("k", model).dimensions;
        assert_eq!(dims("embed-english-light-v3.0"), 384);
        assert_eq!(dims("embed-multilingual-light-v3.0"), 384);
        assert_eq!(dims("embed-multilingual-v3.0"), 1024);
    }

    #[test]
    fn test_voyage_config() {
        let config = HttpEmbeddingConfig::voyage("key", "voyage-2");
        assert_eq!(config.api_url, "https://api.voyageai.com/v1/embeddings");
        assert_eq!(config.dimensions, 1024);
    }

    #[test]
    fn test_voyage_large_models() {
        let dims = |model: &str| HttpEmbeddingConfig::voyage("k", model).dimensions;
        assert_eq!(dims("voyage-large-2"), 1536);
        assert_eq!(dims("voyage-code-2"), 1536);
        assert_eq!(dims("voyage-lite-02-instruct"), 1024);
    }

    #[test]
    fn test_with_dimensions_override() {
        let config =
            HttpEmbeddingConfig::openai("key", "text-embedding-3-small").with_dimensions(256);
        assert_eq!(config.dimensions, 256);
    }

    #[test]
    fn test_with_header_keeps_preset_content_type() {
        let config = HttpEmbeddingConfig::openai("key", "text-embedding-3-small")
            .with_header("X-Trace", "abc");
        assert_eq!(config.headers.get("X-Trace").map(String::as_str), Some("abc"));
        assert_eq!(
            config.headers.get("Content-Type").map(String::as_str),
            Some("application/json")
        );
    }

    #[test]
    fn test_provider_reports_config() {
        let provider =
            HttpEmbeddingProvider::new(HttpEmbeddingConfig::voyage("key", "voyage-code-2"));
        assert_eq!(provider.dimensions(), 1536);
        assert_eq!(provider.model_name(), "voyage-code-2");
        assert_eq!(provider.config().api_url, "https://api.voyageai.com/v1/embeddings");
    }
}