Skip to main content

mentedb_embedding/
http_provider.rs

1//! Generic HTTP-based embedding provider for OpenAI, Cohere, Voyage, and other APIs.
2
3use std::collections::HashMap;
4
5use mentedb_core::MenteError;
6use mentedb_core::error::MenteResult;
7use serde::{Deserialize, Serialize};
8
9use crate::provider::AsyncEmbeddingProvider;
10
11/// Configuration for an HTTP-based embedding API.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct HttpEmbeddingConfig {
14    /// The API endpoint URL.
15    pub api_url: String,
16    /// The API key for authentication.
17    pub api_key: String,
18    /// The model name to request.
19    pub model_name: String,
20    /// The dimensionality of the returned embeddings.
21    pub dimensions: usize,
22    /// Additional headers to include in requests.
23    pub headers: HashMap<String, String>,
24}
25
26impl HttpEmbeddingConfig {
27    /// Create a configuration for OpenAI's embedding API.
28    ///
29    /// Default dimensions: 1536 for text-embedding-ada-002, 3072 for text-embedding-3-large.
30    pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
31        let model = model.into();
32        let dimensions = match model.as_str() {
33            "text-embedding-3-small" => 1536,
34            "text-embedding-3-large" => 3072,
35            "text-embedding-ada-002" => 1536,
36            _ => 1536,
37        };
38
39        let mut headers = HashMap::new();
40        headers.insert("Content-Type".to_string(), "application/json".to_string());
41
42        Self {
43            api_url: "https://api.openai.com/v1/embeddings".to_string(),
44            api_key: api_key.into(),
45            model_name: model,
46            dimensions,
47            headers,
48        }
49    }
50
51    /// Create a configuration for Cohere's embedding API.
52    ///
53    /// Default dimensions: 1024 for embed-english-v3.0.
54    pub fn cohere(api_key: impl Into<String>, model: impl Into<String>) -> Self {
55        let model = model.into();
56        let dimensions = match model.as_str() {
57            "embed-english-v3.0" => 1024,
58            "embed-multilingual-v3.0" => 1024,
59            "embed-english-light-v3.0" => 384,
60            "embed-multilingual-light-v3.0" => 384,
61            _ => 1024,
62        };
63
64        let mut headers = HashMap::new();
65        headers.insert("Content-Type".to_string(), "application/json".to_string());
66
67        Self {
68            api_url: "https://api.cohere.ai/v1/embed".to_string(),
69            api_key: api_key.into(),
70            model_name: model,
71            dimensions,
72            headers,
73        }
74    }
75
76    /// Create a configuration for Voyage AI's embedding API.
77    ///
78    /// Default dimensions: 1024 for voyage-2.
79    pub fn voyage(api_key: impl Into<String>, model: impl Into<String>) -> Self {
80        let model = model.into();
81        let dimensions = match model.as_str() {
82            "voyage-2" => 1024,
83            "voyage-large-2" => 1536,
84            "voyage-code-2" => 1536,
85            "voyage-lite-02-instruct" => 1024,
86            _ => 1024,
87        };
88
89        let mut headers = HashMap::new();
90        headers.insert("Content-Type".to_string(), "application/json".to_string());
91
92        Self {
93            api_url: "https://api.voyageai.com/v1/embeddings".to_string(),
94            api_key: api_key.into(),
95            model_name: model,
96            dimensions,
97            headers,
98        }
99    }
100
101    /// Override the embedding dimensions.
102    pub fn with_dimensions(mut self, dimensions: usize) -> Self {
103        self.dimensions = dimensions;
104        self
105    }
106
107    /// Add a custom header.
108    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
109        self.headers.insert(key.into(), value.into());
110        self
111    }
112}
113
114/// HTTP-based embedding provider.
115///
116/// Currently requires an external HTTP client feature to function.
117/// The structure and configuration are fully usable for setup and validation.
118pub struct HttpEmbeddingProvider {
119    config: HttpEmbeddingConfig,
120}
121
122impl HttpEmbeddingProvider {
123    /// Create a new HTTP embedding provider with the given configuration.
124    pub fn new(config: HttpEmbeddingConfig) -> Self {
125        Self { config }
126    }
127
128    /// Get a reference to the provider's configuration.
129    pub fn config(&self) -> &HttpEmbeddingConfig {
130        &self.config
131    }
132}
133
134impl AsyncEmbeddingProvider for HttpEmbeddingProvider {
135    async fn embed(&self, _text: &str) -> MenteResult<Vec<f32>> {
136        Err(MenteError::Storage(
137            "HTTP embedding requires the 'reqwest' feature".to_string(),
138        ))
139    }
140
141    async fn embed_batch(&self, _texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
142        Err(MenteError::Storage(
143            "HTTP embedding requires the 'reqwest' feature".to_string(),
144        ))
145    }
146
147    fn dimensions(&self) -> usize {
148        self.config.dimensions
149    }
150
151    fn model_name(&self) -> &str {
152        &self.config.model_name
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_config() {
        let config = HttpEmbeddingConfig::openai("sk-test", "text-embedding-3-small");
        assert_eq!(config.api_url, "https://api.openai.com/v1/embeddings");
        assert_eq!(config.dimensions, 1536);
        assert_eq!(config.model_name, "text-embedding-3-small");
        assert_eq!(config.api_key, "sk-test");
    }

    #[test]
    fn test_cohere_config() {
        let config = HttpEmbeddingConfig::cohere("key", "embed-english-v3.0");
        assert_eq!(config.api_url, "https://api.cohere.ai/v1/embed");
        assert_eq!(config.dimensions, 1024);
    }

    #[test]
    fn test_voyage_config() {
        let config = HttpEmbeddingConfig::voyage("key", "voyage-2");
        assert_eq!(config.api_url, "https://api.voyageai.com/v1/embeddings");
        assert_eq!(config.dimensions, 1024);
    }

    #[test]
    fn test_with_dimensions_override() {
        let config =
            HttpEmbeddingConfig::openai("key", "text-embedding-3-small").with_dimensions(256);
        assert_eq!(config.dimensions, 256);
    }

    #[test]
    fn test_non_default_model_dimensions() {
        assert_eq!(
            HttpEmbeddingConfig::openai("key", "text-embedding-3-large").dimensions,
            3072
        );
        assert_eq!(
            HttpEmbeddingConfig::cohere("key", "embed-multilingual-light-v3.0").dimensions,
            384
        );
        assert_eq!(
            HttpEmbeddingConfig::voyage("key", "voyage-code-2").dimensions,
            1536
        );
    }

    #[test]
    fn test_unknown_model_falls_back_to_default_dimensions() {
        assert_eq!(HttpEmbeddingConfig::openai("key", "unknown").dimensions, 1536);
        assert_eq!(HttpEmbeddingConfig::cohere("key", "unknown").dimensions, 1024);
        assert_eq!(HttpEmbeddingConfig::voyage("key", "unknown").dimensions, 1024);
    }

    #[test]
    fn test_presets_send_json_content_type() {
        let presets = [
            HttpEmbeddingConfig::openai("key", "text-embedding-3-small"),
            HttpEmbeddingConfig::cohere("key", "embed-english-v3.0"),
            HttpEmbeddingConfig::voyage("key", "voyage-2"),
        ];
        for config in presets.iter() {
            assert_eq!(
                config.headers.get("Content-Type").map(String::as_str),
                Some("application/json")
            );
            assert_eq!(config.headers.len(), 1);
        }
    }

    #[test]
    fn test_with_header_adds_and_overrides() {
        let config = HttpEmbeddingConfig::openai("key", "text-embedding-3-small")
            .with_header("X-Org", "acme")
            .with_header("Content-Type", "application/json; charset=utf-8");
        assert_eq!(config.headers.get("X-Org").map(String::as_str), Some("acme"));
        assert_eq!(
            config.headers.get("Content-Type").map(String::as_str),
            Some("application/json; charset=utf-8")
        );
        assert_eq!(config.headers.len(), 2);
    }

    #[test]
    fn test_provider_exposes_config() {
        let provider = HttpEmbeddingProvider::new(HttpEmbeddingConfig::voyage("key", "voyage-2"));
        assert_eq!(provider.config().model_name, "voyage-2");
        assert_eq!(provider.config().dimensions, 1024);
        assert_eq!(provider.config().api_key, "key");
    }
}
188}