litellm_rs/core/embedding/
options.rs1use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8#[derive(Debug, Clone, Serialize, Deserialize, Default)]
24pub struct EmbeddingOptions {
25 #[serde(skip_serializing_if = "Option::is_none")]
27 pub user: Option<String>,
28
29 #[serde(skip_serializing_if = "Option::is_none")]
31 pub encoding_format: Option<String>,
32
33 #[serde(skip_serializing_if = "Option::is_none")]
36 pub dimensions: Option<u32>,
37
38 #[serde(skip_serializing_if = "Option::is_none")]
41 pub api_key: Option<String>,
42
43 #[serde(skip_serializing_if = "Option::is_none")]
46 pub api_base: Option<String>,
47
48 #[serde(skip_serializing_if = "Option::is_none")]
50 pub timeout: Option<u64>,
51
52 #[serde(skip_serializing_if = "Option::is_none")]
54 pub headers: Option<HashMap<String, String>>,
55
56 #[serde(skip_serializing_if = "Option::is_none")]
59 pub task_type: Option<String>,
60
61 #[serde(default)]
63 pub extra_params: HashMap<String, serde_json::Value>,
64}
65
66impl EmbeddingOptions {
67 pub fn new() -> Self {
69 Self::default()
70 }
71
72 pub fn with_user(mut self, user: impl Into<String>) -> Self {
74 self.user = Some(user.into());
75 self
76 }
77
78 pub fn with_encoding_format(mut self, format: impl Into<String>) -> Self {
80 self.encoding_format = Some(format.into());
81 self
82 }
83
84 pub fn with_dimensions(mut self, dimensions: u32) -> Self {
86 self.dimensions = Some(dimensions);
87 self
88 }
89
90 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
92 self.api_key = Some(api_key.into());
93 self
94 }
95
96 pub fn with_api_base(mut self, api_base: impl Into<String>) -> Self {
98 self.api_base = Some(api_base.into());
99 self
100 }
101
102 pub fn with_timeout(mut self, timeout: u64) -> Self {
104 self.timeout = Some(timeout);
105 self
106 }
107
108 pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
110 self.headers = Some(headers);
111 self
112 }
113
114 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
116 self.headers
117 .get_or_insert_with(HashMap::new)
118 .insert(key.into(), value.into());
119 self
120 }
121
122 pub fn with_task_type(mut self, task_type: impl Into<String>) -> Self {
124 self.task_type = Some(task_type.into());
125 self
126 }
127
128 pub fn with_extra_param(
130 mut self,
131 key: impl Into<String>,
132 value: impl Into<serde_json::Value>,
133 ) -> Self {
134 self.extra_params.insert(key.into(), value.into());
135 self
136 }
137
138 pub fn with_extra_params(mut self, params: HashMap<String, serde_json::Value>) -> Self {
140 self.extra_params.extend(params);
141 self
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_options_default() {
        let opts = EmbeddingOptions::default();
        assert!(opts.user.is_none());
        assert!(opts.encoding_format.is_none());
        assert!(opts.dimensions.is_none());
        assert!(opts.api_key.is_none());
        assert!(opts.api_base.is_none());
        assert!(opts.timeout.is_none());
        assert!(opts.headers.is_none());
        assert!(opts.task_type.is_none());
        assert!(opts.extra_params.is_empty());
    }

    #[test]
    fn test_embedding_options_builder() {
        let opts = EmbeddingOptions::new()
            .with_user("user-123")
            .with_encoding_format("float")
            .with_dimensions(1536)
            .with_api_key("sk-test")
            .with_api_base("https://api.example.com")
            .with_timeout(30)
            .with_task_type("RETRIEVAL_QUERY");

        assert_eq!(opts.user, Some("user-123".to_string()));
        assert_eq!(opts.encoding_format, Some("float".to_string()));
        assert_eq!(opts.dimensions, Some(1536));
        assert_eq!(opts.api_key, Some("sk-test".to_string()));
        assert_eq!(opts.api_base, Some("https://api.example.com".to_string()));
        assert_eq!(opts.timeout, Some(30));
        assert_eq!(opts.task_type, Some("RETRIEVAL_QUERY".to_string()));
    }

    #[test]
    fn test_embedding_options_headers() {
        let opts = EmbeddingOptions::new()
            .with_header("X-Custom-Header", "value1")
            .with_header("X-Another-Header", "value2");

        let headers = opts.headers.unwrap();
        assert_eq!(headers.get("X-Custom-Header"), Some(&"value1".to_string()));
        assert_eq!(headers.get("X-Another-Header"), Some(&"value2".to_string()));
    }

    #[test]
    fn test_embedding_options_bulk_headers() {
        let mut headers = HashMap::new();
        headers.insert("Header1".to_string(), "Value1".to_string());
        headers.insert("Header2".to_string(), "Value2".to_string());

        let opts = EmbeddingOptions::new().with_headers(headers.clone());
        assert_eq!(opts.headers, Some(headers));
    }

    #[test]
    fn test_embedding_options_header_after_bulk_headers() {
        // `with_header` must add to a map installed by `with_headers`.
        let mut headers = HashMap::new();
        headers.insert("A".to_string(), "1".to_string());

        let opts = EmbeddingOptions::new()
            .with_headers(headers)
            .with_header("B", "2");

        let headers = opts.headers.unwrap();
        assert_eq!(headers.len(), 2);
        assert_eq!(headers.get("A"), Some(&"1".to_string()));
        assert_eq!(headers.get("B"), Some(&"2".to_string()));
    }

    #[test]
    fn test_embedding_options_extra_params() {
        let opts = EmbeddingOptions::new()
            .with_extra_param("custom_field", serde_json::json!("value"))
            .with_extra_param("numeric_field", serde_json::json!(42));

        assert_eq!(
            opts.extra_params.get("custom_field"),
            Some(&serde_json::json!("value"))
        );
        assert_eq!(
            opts.extra_params.get("numeric_field"),
            Some(&serde_json::json!(42))
        );
    }

    #[test]
    fn test_embedding_options_extra_params_bulk_merge() {
        let mut bulk = HashMap::new();
        bulk.insert("a".to_string(), serde_json::json!(1));
        bulk.insert("b".to_string(), serde_json::json!(true));

        let opts = EmbeddingOptions::new()
            .with_extra_param("a", serde_json::json!(0))
            .with_extra_param("keep", serde_json::json!("x"))
            .with_extra_params(bulk);

        // Existing entries survive; colliding keys take the bulk value.
        assert_eq!(opts.extra_params.len(), 3);
        assert_eq!(opts.extra_params.get("a"), Some(&serde_json::json!(1)));
        assert_eq!(opts.extra_params.get("b"), Some(&serde_json::json!(true)));
        assert_eq!(opts.extra_params.get("keep"), Some(&serde_json::json!("x")));
    }

    #[test]
    fn test_embedding_options_serialization() {
        let opts = EmbeddingOptions::new()
            .with_dimensions(256)
            .with_encoding_format("base64");

        let json = serde_json::to_value(&opts).unwrap();
        assert_eq!(json["dimensions"], 256);
        assert_eq!(json["encoding_format"], "base64");
        assert!(!json.as_object().unwrap().contains_key("user"));
        assert!(!json.as_object().unwrap().contains_key("api_key"));
    }

    #[test]
    fn test_embedding_options_deserialization() {
        let json = r#"{
            "user": "test-user",
            "dimensions": 512,
            "encoding_format": "float"
        }"#;

        let opts: EmbeddingOptions = serde_json::from_str(json).unwrap();
        assert_eq!(opts.user, Some("test-user".to_string()));
        assert_eq!(opts.dimensions, Some(512));
        assert_eq!(opts.encoding_format, Some("float".to_string()));
        assert!(opts.extra_params.is_empty());
    }

    #[test]
    fn test_embedding_options_round_trip() {
        let opts = EmbeddingOptions::new()
            .with_user("u")
            .with_dimensions(64)
            .with_header("X-H", "v")
            .with_extra_param("k", serde_json::json!([1, 2]));

        let text = serde_json::to_string(&opts).unwrap();
        let back: EmbeddingOptions = serde_json::from_str(&text).unwrap();

        assert_eq!(back.user, opts.user);
        assert_eq!(back.dimensions, opts.dimensions);
        assert_eq!(back.headers, opts.headers);
        assert_eq!(back.extra_params, opts.extra_params);
        assert!(back.api_key.is_none());
    }

    #[test]
    fn test_embedding_options_clone() {
        let opts = EmbeddingOptions::new()
            .with_api_key("key")
            .with_dimensions(768);

        let cloned = opts.clone();
        assert_eq!(opts.api_key, cloned.api_key);
        assert_eq!(opts.dimensions, cloned.dimensions);
    }
}