use crate::error::ProtoError;
use serde::{Deserialize, Serialize};
/// The task an embedding request is intended for.
///
/// On the wire each variant is a `snake_case` string (e.g. `"retrieval_query"`,
/// `"fact_verification"`, `"code_retrieval_query"`), per `rename_all` below.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EmbedTask {
RetrievalQuery,
RetrievalDocument,
Similarity,
Classification,
Clustering,
QuestionAnswering,
FactVerification,
CodeRetrievalQuery,
/// Catch-all for any task string this version does not recognise.
///
/// `#[serde(other)]` makes deserialization map unknown strings here instead of
/// failing the whole request. The original string is NOT preserved, so this is
/// lossy: serializing `Other` emits `"other"`. `EmbedRequest::resolve` rejects
/// a request whose task is `Other`.
#[serde(other)]
Other,
}
/// An embedding request in its raw, unvalidated wire form.
///
/// Call [`EmbedRequest::resolve`] to validate it and obtain an [`EmbedResolved`].
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct EmbedRequest {
    /// Caller-supplied request identifier.
    ///
    /// Optional on the wire: a missing `id` deserializes to `""`, and an empty
    /// `id` is omitted when serializing.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub id: String,
    /// Texts to embed.
    ///
    /// Required on the wire (there is no `#[serde(default)]`). `resolve` rejects
    /// an empty list and any empty element.
    pub input: Vec<String>,
    /// Optional requested number of embedding dimensions. Omitted when
    /// serializing if `None`, and defaults to `None` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<u32>,
    /// Optional task hint. Omitted when serializing if `None`, and defaults to
    /// `None` when absent. Unknown task strings deserialize as
    /// [`EmbedTask::Other`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub task: Option<EmbedTask>,
}
#[derive(Debug, Clone, PartialEq)]
pub struct EmbedResolved {
pub id: String,
pub input: Vec<String>,
pub dimensions: Option<u32>,
pub task: Option<EmbedTask>,
}
impl EmbedRequest {
    /// Validates the request and converts it into an [`EmbedResolved`].
    ///
    /// All fields are moved over unchanged; this only checks them.
    ///
    /// # Errors
    ///
    /// Returns [`ProtoError::InvalidRequest`], checking in this order, when:
    /// - `input` is empty;
    /// - any element of `input` is the empty string (the message names the
    ///   index of the first such element);
    /// - `task` is [`EmbedTask::Other`], i.e. the wire value was not a known task;
    /// - `dimensions` is `Some(0)`.
    pub fn resolve(self) -> Result<EmbedResolved, ProtoError> {
        if self.input.is_empty() {
            return Err(ProtoError::InvalidRequest("input must not be empty".into()));
        }
        if let Some(i) = self.input.iter().position(String::is_empty) {
            return Err(ProtoError::InvalidRequest(format!(
                "input[{i}] must not be empty"
            )));
        }
        if matches!(self.task, Some(EmbedTask::Other)) {
            return Err(ProtoError::InvalidRequest(
                "task uses an unknown variant".into(),
            ));
        }
        // A zero-dimension embedding cannot carry any information. Reject it
        // like the other degenerate inputs above. This check runs last so the
        // existing checks keep their error precedence.
        if self.dimensions == Some(0) {
            return Err(ProtoError::InvalidRequest(
                "dimensions must be greater than zero".into(),
            ));
        }
        let Self {
            id,
            input,
            dimensions,
            task,
        } = self;
        Ok(EmbedResolved {
            id,
            input,
            dimensions,
            task,
        })
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for request (de)serialization and `EmbedRequest::resolve`.

    use super::*;

    #[test]
    fn rejects_empty_input() {
        let req = EmbedRequest {
            id: "r1".into(),
            input: vec![],
            ..Default::default()
        };
        let err = req.resolve().unwrap_err();
        assert!(matches!(
            &err,
            ProtoError::InvalidRequest(msg) if msg == "input must not be empty"
        ));
    }

    #[test]
    fn rejects_empty_inner_string() {
        let req = EmbedRequest {
            id: "r1".into(),
            input: vec!["hello".into(), String::new()],
            ..Default::default()
        };
        let err = req.resolve().unwrap_err();
        // The message must identify the offending element by index.
        assert!(matches!(
            &err,
            ProtoError::InvalidRequest(msg) if msg == "input[1] must not be empty"
        ));
    }

    #[test]
    fn accepts_minimal_request() {
        let req = EmbedRequest {
            id: "r1".into(),
            input: vec!["hello".into()],
            ..Default::default()
        };
        let resolved = req.resolve().unwrap();
        assert_eq!(resolved.input.len(), 1);
        assert!(resolved.dimensions.is_none());
        assert!(resolved.task.is_none());
    }

    #[test]
    fn resolve_carries_every_field_through() {
        let req = EmbedRequest {
            id: "r1".into(),
            input: vec!["a".into(), "b".into()],
            dimensions: Some(256),
            task: Some(EmbedTask::Similarity),
        };
        let resolved = req.resolve().unwrap();
        assert_eq!(
            resolved,
            EmbedResolved {
                id: "r1".into(),
                input: vec!["a".into(), "b".into()],
                dimensions: Some(256),
                task: Some(EmbedTask::Similarity),
            }
        );
    }

    #[test]
    fn parses_full_request_json() {
        let s = r#"{
            "id": "r1",
            "input": ["a", "b"],
            "dimensions": 256,
            "task": "retrieval_document"
        }"#;
        let req: EmbedRequest = serde_json::from_str(s).unwrap();
        assert_eq!(req.id, "r1");
        assert_eq!(req.input, vec!["a".to_string(), "b".to_string()]);
        assert_eq!(req.dimensions, Some(256));
        assert_eq!(req.task, Some(EmbedTask::RetrievalDocument));
    }

    #[test]
    fn unknown_task_deserializes_as_other_and_is_rejected() {
        let s = r#"{"id":"r1","input":["x"],"task":"some_future_task"}"#;
        let req: EmbedRequest = serde_json::from_str(s).unwrap();
        assert_eq!(req.task, Some(EmbedTask::Other));
        let err = req.resolve().unwrap_err();
        assert!(matches!(
            &err,
            ProtoError::InvalidRequest(msg) if msg == "task uses an unknown variant"
        ));
    }

    #[test]
    fn missing_input_fails_to_deserialize() {
        // `input` has no `#[serde(default)]`, so it is required on the wire.
        let res: Result<EmbedRequest, _> = serde_json::from_str(r#"{"id":"r1"}"#);
        assert!(res.is_err());
    }

    #[test]
    fn id_defaults_to_empty_and_is_omitted_when_empty() {
        let req: EmbedRequest = serde_json::from_str(r#"{"input":["x"]}"#).unwrap();
        assert!(req.id.is_empty());
        assert_eq!(serde_json::to_string(&req).unwrap(), r#"{"input":["x"]}"#);
    }

    #[test]
    fn skips_serializing_optional_fields_when_unset() {
        let req = EmbedRequest {
            id: "r1".into(),
            input: vec!["hello".into()],
            dimensions: None,
            task: None,
        };
        let s = serde_json::to_string(&req).unwrap();
        assert_eq!(s, r#"{"id":"r1","input":["hello"]}"#);
    }
}