Skip to main content

litellm_rs/core/embedding/
options.rs

//! Embedding options - Python LiteLLM compatible
//!
//! This module provides embedding configuration options with a builder pattern.

5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8/// Embedding options - Python LiteLLM compatible
9///
10/// These options control how embeddings are generated and allow customization
11/// of API behavior, authentication, and encoding format.
12///
13/// # Example
14///
15/// ```rust
16/// use litellm_rs::core::embedding::EmbeddingOptions;
17///
18/// let options = EmbeddingOptions::new()
19///     .with_api_key("sk-...")
20///     .with_dimensions(1536)
21///     .with_encoding_format("float");
22/// ```
23#[derive(Debug, Clone, Serialize, Deserialize, Default)]
24pub struct EmbeddingOptions {
25    /// User identifier for tracking
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub user: Option<String>,
28
29    /// Encoding format for the embeddings (e.g., "float", "base64")
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub encoding_format: Option<String>,
32
33    /// Number of dimensions for the embedding output
34    /// Only supported by models that allow dimension reduction
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub dimensions: Option<u32>,
37
38    /// API key to use for this request
39    /// Overrides environment variable configuration
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub api_key: Option<String>,
42
43    /// Custom API base URL
44    /// Allows using custom endpoints or proxies
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub api_base: Option<String>,
47
48    /// Request timeout in seconds
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub timeout: Option<u64>,
51
52    /// Custom headers to include in the request
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub headers: Option<HashMap<String, String>>,
55
56    /// Task type for specialized embeddings (e.g., for Vertex AI)
57    /// Values: "RETRIEVAL_QUERY", "RETRIEVAL_DOCUMENT", "SEMANTIC_SIMILARITY", etc.
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub task_type: Option<String>,
60
61    /// Extra provider-specific parameters
62    #[serde(default)]
63    pub extra_params: HashMap<String, serde_json::Value>,
64}
65
66impl EmbeddingOptions {
67    /// Create new empty options
68    pub fn new() -> Self {
69        Self::default()
70    }
71
72    /// Set user identifier
73    pub fn with_user(mut self, user: impl Into<String>) -> Self {
74        self.user = Some(user.into());
75        self
76    }
77
78    /// Set encoding format
79    pub fn with_encoding_format(mut self, format: impl Into<String>) -> Self {
80        self.encoding_format = Some(format.into());
81        self
82    }
83
84    /// Set output dimensions
85    pub fn with_dimensions(mut self, dimensions: u32) -> Self {
86        self.dimensions = Some(dimensions);
87        self
88    }
89
90    /// Set API key
91    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
92        self.api_key = Some(api_key.into());
93        self
94    }
95
96    /// Set API base URL
97    pub fn with_api_base(mut self, api_base: impl Into<String>) -> Self {
98        self.api_base = Some(api_base.into());
99        self
100    }
101
102    /// Set request timeout in seconds
103    pub fn with_timeout(mut self, timeout: u64) -> Self {
104        self.timeout = Some(timeout);
105        self
106    }
107
108    /// Set custom headers
109    pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
110        self.headers = Some(headers);
111        self
112    }
113
114    /// Add a single custom header
115    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
116        self.headers
117            .get_or_insert_with(HashMap::new)
118            .insert(key.into(), value.into());
119        self
120    }
121
122    /// Set task type for specialized embeddings
123    pub fn with_task_type(mut self, task_type: impl Into<String>) -> Self {
124        self.task_type = Some(task_type.into());
125        self
126    }
127
128    /// Add an extra parameter
129    pub fn with_extra_param(
130        mut self,
131        key: impl Into<String>,
132        value: impl Into<serde_json::Value>,
133    ) -> Self {
134        self.extra_params.insert(key.into(), value.into());
135        self
136    }
137
138    /// Set multiple extra parameters
139    pub fn with_extra_params(mut self, params: HashMap<String, serde_json::Value>) -> Self {
140        self.extra_params.extend(params);
141        self
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_options_default() {
        let opts = EmbeddingOptions::default();
        assert!(opts.user.is_none());
        assert!(opts.encoding_format.is_none());
        assert!(opts.dimensions.is_none());
        assert!(opts.api_key.is_none());
        assert!(opts.api_base.is_none());
        assert!(opts.timeout.is_none());
        assert!(opts.headers.is_none());
        assert!(opts.task_type.is_none());
        assert!(opts.extra_params.is_empty());
    }

    #[test]
    fn test_embedding_options_builder() {
        let opts = EmbeddingOptions::new()
            .with_user("user-123")
            .with_encoding_format("float")
            .with_dimensions(1536)
            .with_api_key("sk-test")
            .with_api_base("https://api.example.com")
            .with_timeout(30)
            .with_task_type("RETRIEVAL_QUERY");

        assert_eq!(opts.user, Some("user-123".to_string()));
        assert_eq!(opts.encoding_format, Some("float".to_string()));
        assert_eq!(opts.dimensions, Some(1536));
        assert_eq!(opts.api_key, Some("sk-test".to_string()));
        assert_eq!(opts.api_base, Some("https://api.example.com".to_string()));
        assert_eq!(opts.timeout, Some(30));
        assert_eq!(opts.task_type, Some("RETRIEVAL_QUERY".to_string()));
    }

    #[test]
    fn test_embedding_options_headers() {
        let opts = EmbeddingOptions::new()
            .with_header("X-Custom-Header", "value1")
            .with_header("X-Another-Header", "value2");

        let headers = opts.headers.unwrap();
        assert_eq!(headers.get("X-Custom-Header"), Some(&"value1".to_string()));
        assert_eq!(headers.get("X-Another-Header"), Some(&"value2".to_string()));
    }

    #[test]
    fn test_embedding_options_header_overwrites_same_key() {
        let opts = EmbeddingOptions::new()
            .with_header("X-Custom-Header", "old")
            .with_header("X-Custom-Header", "new");

        let headers = opts.headers.unwrap();
        assert_eq!(headers.len(), 1);
        assert_eq!(headers.get("X-Custom-Header"), Some(&"new".to_string()));
    }

    #[test]
    fn test_embedding_options_bulk_headers() {
        let mut headers = HashMap::new();
        headers.insert("Header1".to_string(), "Value1".to_string());
        headers.insert("Header2".to_string(), "Value2".to_string());

        let opts = EmbeddingOptions::new().with_headers(headers.clone());
        assert_eq!(opts.headers, Some(headers));
    }

    #[test]
    fn test_embedding_options_bulk_headers_replace_existing() {
        let mut replacement = HashMap::new();
        replacement.insert("X-New".to_string(), "2".to_string());

        let opts = EmbeddingOptions::new()
            .with_header("X-Old", "1")
            .with_headers(replacement.clone());

        // `with_headers` replaces the whole map rather than merging into it.
        assert_eq!(opts.headers, Some(replacement));
    }

    #[test]
    fn test_embedding_options_extra_params() {
        let opts = EmbeddingOptions::new()
            .with_extra_param("custom_field", serde_json::json!("value"))
            .with_extra_param("numeric_field", serde_json::json!(42));

        assert_eq!(
            opts.extra_params.get("custom_field"),
            Some(&serde_json::json!("value"))
        );
        assert_eq!(
            opts.extra_params.get("numeric_field"),
            Some(&serde_json::json!(42))
        );
    }

    #[test]
    fn test_embedding_options_bulk_extra_params_merge() {
        let mut params = HashMap::new();
        params.insert("shared".to_string(), serde_json::json!("from_bulk"));
        params.insert("bulk_only".to_string(), serde_json::json!(true));

        let opts = EmbeddingOptions::new()
            .with_extra_param("kept", serde_json::json!(1))
            .with_extra_param("shared", serde_json::json!("from_single"))
            .with_extra_params(params);

        // `with_extra_params` merges: earlier keys survive, duplicates take the new value.
        assert_eq!(opts.extra_params.len(), 3);
        assert_eq!(opts.extra_params.get("kept"), Some(&serde_json::json!(1)));
        assert_eq!(
            opts.extra_params.get("shared"),
            Some(&serde_json::json!("from_bulk"))
        );
        assert_eq!(
            opts.extra_params.get("bulk_only"),
            Some(&serde_json::json!(true))
        );
    }

    #[test]
    fn test_embedding_options_serialization() {
        let opts = EmbeddingOptions::new()
            .with_dimensions(256)
            .with_encoding_format("base64");

        let json = serde_json::to_value(&opts).unwrap();
        assert_eq!(json["dimensions"], 256);
        assert_eq!(json["encoding_format"], "base64");
        // None values should be skipped
        assert!(!json.as_object().unwrap().contains_key("user"));
        assert!(!json.as_object().unwrap().contains_key("api_key"));
    }

    #[test]
    fn test_embedding_options_deserialization() {
        let json = r#"{
            "user": "test-user",
            "dimensions": 512,
            "encoding_format": "float"
        }"#;

        let opts: EmbeddingOptions = serde_json::from_str(json).unwrap();
        assert_eq!(opts.user, Some("test-user".to_string()));
        assert_eq!(opts.dimensions, Some(512));
        assert_eq!(opts.encoding_format, Some("float".to_string()));
        // A missing `extra_params` key falls back to an empty map.
        assert!(opts.extra_params.is_empty());
    }

    #[test]
    fn test_embedding_options_serde_round_trip() {
        let opts = EmbeddingOptions::new()
            .with_task_type("RETRIEVAL_DOCUMENT")
            .with_header("X-Trace", "abc")
            .with_extra_param("custom_field", serde_json::json!({"nested": [1, 2]}));

        let text = serde_json::to_string(&opts).unwrap();
        let back: EmbeddingOptions = serde_json::from_str(&text).unwrap();

        assert_eq!(back.task_type, opts.task_type);
        assert_eq!(back.headers, opts.headers);
        assert_eq!(back.extra_params, opts.extra_params);
    }

    #[test]
    fn test_embedding_options_clone() {
        let opts = EmbeddingOptions::new()
            .with_api_key("key")
            .with_dimensions(768);

        let cloned = opts.clone();
        assert_eq!(opts.api_key, cloned.api_key);
        assert_eq!(opts.dimensions, cloned.dimensions);
    }
}