use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct EmbeddingOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub dimensions: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub api_key: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub api_base: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub timeout: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub headers: Option<HashMap<String, String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub task_type: Option<String>,
#[serde(default)]
pub extra_params: HashMap<String, serde_json::Value>,
}
impl EmbeddingOptions {
pub fn new() -> Self {
Self::default()
}
pub fn with_user(mut self, user: impl Into<String>) -> Self {
self.user = Some(user.into());
self
}
pub fn with_encoding_format(mut self, format: impl Into<String>) -> Self {
self.encoding_format = Some(format.into());
self
}
pub fn with_dimensions(mut self, dimensions: u32) -> Self {
self.dimensions = Some(dimensions);
self
}
pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
self.api_key = Some(api_key.into());
self
}
pub fn with_api_base(mut self, api_base: impl Into<String>) -> Self {
self.api_base = Some(api_base.into());
self
}
pub fn with_timeout(mut self, timeout: u64) -> Self {
self.timeout = Some(timeout);
self
}
pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
self.headers = Some(headers);
self
}
pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.headers
.get_or_insert_with(HashMap::new)
.insert(key.into(), value.into());
self
}
pub fn with_task_type(mut self, task_type: impl Into<String>) -> Self {
self.task_type = Some(task_type.into());
self
}
pub fn with_extra_param(
mut self,
key: impl Into<String>,
value: impl Into<serde_json::Value>,
) -> Self {
self.extra_params.insert(key.into(), value.into());
self
}
pub fn with_extra_params(mut self, params: HashMap<String, serde_json::Value>) -> Self {
self.extra_params.extend(params);
self
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_embedding_options_default() {
let opts = EmbeddingOptions::default();
assert!(opts.user.is_none());
assert!(opts.encoding_format.is_none());
assert!(opts.dimensions.is_none());
assert!(opts.api_key.is_none());
assert!(opts.api_base.is_none());
assert!(opts.timeout.is_none());
assert!(opts.headers.is_none());
assert!(opts.task_type.is_none());
assert!(opts.extra_params.is_empty());
}
#[test]
fn test_embedding_options_builder() {
let opts = EmbeddingOptions::new()
.with_user("user-123")
.with_encoding_format("float")
.with_dimensions(1536)
.with_api_key("sk-test")
.with_api_base("https://api.example.com")
.with_timeout(30)
.with_task_type("RETRIEVAL_QUERY");
assert_eq!(opts.user, Some("user-123".to_string()));
assert_eq!(opts.encoding_format, Some("float".to_string()));
assert_eq!(opts.dimensions, Some(1536));
assert_eq!(opts.api_key, Some("sk-test".to_string()));
assert_eq!(opts.api_base, Some("https://api.example.com".to_string()));
assert_eq!(opts.timeout, Some(30));
assert_eq!(opts.task_type, Some("RETRIEVAL_QUERY".to_string()));
}
#[test]
fn test_embedding_options_headers() {
let opts = EmbeddingOptions::new()
.with_header("X-Custom-Header", "value1")
.with_header("X-Another-Header", "value2");
let headers = opts.headers.unwrap();
assert_eq!(headers.get("X-Custom-Header"), Some(&"value1".to_string()));
assert_eq!(headers.get("X-Another-Header"), Some(&"value2".to_string()));
}
#[test]
fn test_embedding_options_bulk_headers() {
let mut headers = HashMap::new();
headers.insert("Header1".to_string(), "Value1".to_string());
headers.insert("Header2".to_string(), "Value2".to_string());
let opts = EmbeddingOptions::new().with_headers(headers.clone());
assert_eq!(opts.headers, Some(headers));
}
#[test]
fn test_embedding_options_extra_params() {
let opts = EmbeddingOptions::new()
.with_extra_param("custom_field", serde_json::json!("value"))
.with_extra_param("numeric_field", serde_json::json!(42));
assert_eq!(
opts.extra_params.get("custom_field"),
Some(&serde_json::json!("value"))
);
assert_eq!(
opts.extra_params.get("numeric_field"),
Some(&serde_json::json!(42))
);
}
#[test]
fn test_embedding_options_serialization() {
let opts = EmbeddingOptions::new()
.with_dimensions(256)
.with_encoding_format("base64");
let json = serde_json::to_value(&opts).unwrap();
assert_eq!(json["dimensions"], 256);
assert_eq!(json["encoding_format"], "base64");
assert!(!json.as_object().unwrap().contains_key("user"));
assert!(!json.as_object().unwrap().contains_key("api_key"));
}
#[test]
fn test_embedding_options_deserialization() {
let json = r#"{
"user": "test-user",
"dimensions": 512,
"encoding_format": "float"
}"#;
let opts: EmbeddingOptions = serde_json::from_str(json).unwrap();
assert_eq!(opts.user, Some("test-user".to_string()));
assert_eq!(opts.dimensions, Some(512));
assert_eq!(opts.encoding_format, Some("float".to_string()));
}
#[test]
fn test_embedding_options_clone() {
let opts = EmbeddingOptions::new()
.with_api_key("key")
.with_dimensions(768);
let cloned = opts.clone();
assert_eq!(opts.api_key, cloned.api_key);
assert_eq!(opts.dimensions, cloned.dimensions);
}
}