inferd_proto/embed/
request.rs1use crate::error::ProtoError;
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(rename_all = "snake_case")]
21pub enum EmbedTask {
22 RetrievalQuery,
24 RetrievalDocument,
26 Similarity,
28 Classification,
30 Clustering,
32 QuestionAnswering,
34 FactVerification,
36 CodeRetrievalQuery,
38 #[serde(other)]
40 Other,
41}
42
43#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
48pub struct EmbedRequest {
49 #[serde(default, skip_serializing_if = "String::is_empty")]
51 pub id: String,
52
53 pub input: Vec<String>,
57
58 #[serde(default, skip_serializing_if = "Option::is_none")]
63 pub dimensions: Option<u32>,
64
65 #[serde(default, skip_serializing_if = "Option::is_none")]
68 pub task: Option<EmbedTask>,
69}
70
71#[derive(Debug, Clone, PartialEq)]
77pub struct EmbedResolved {
78 pub id: String,
80 pub input: Vec<String>,
82 pub dimensions: Option<u32>,
84 pub task: Option<EmbedTask>,
86}
87
88impl EmbedRequest {
89 pub fn resolve(self) -> Result<EmbedResolved, ProtoError> {
93 if self.input.is_empty() {
94 return Err(ProtoError::InvalidRequest("input must not be empty".into()));
95 }
96 for (i, s) in self.input.iter().enumerate() {
97 if s.is_empty() {
98 return Err(ProtoError::InvalidRequest(format!(
99 "input[{i}] must not be empty"
100 )));
101 }
102 }
103 if matches!(self.task, Some(EmbedTask::Other)) {
104 return Err(ProtoError::InvalidRequest(
105 "task uses an unknown variant".into(),
106 ));
107 }
108 Ok(EmbedResolved {
109 id: self.id,
110 input: self.input,
111 dimensions: self.dimensions,
112 task: self.task,
113 })
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with id `"r1"`, the given inputs, and both optional
    /// fields left unset.
    fn request_with(input: &[&str]) -> EmbedRequest {
        EmbedRequest {
            id: "r1".into(),
            input: input.iter().map(|text| text.to_string()).collect(),
            ..Default::default()
        }
    }

    /// A request whose `input` list is empty must be refused.
    #[test]
    fn rejects_empty_input() {
        let outcome = request_with(&[]).resolve();
        assert!(matches!(outcome, Err(ProtoError::InvalidRequest(_))));
    }

    /// A single empty string anywhere in `input` must be refused.
    #[test]
    fn rejects_empty_inner_string() {
        let outcome = request_with(&["hello", ""]).resolve();
        assert!(matches!(outcome, Err(ProtoError::InvalidRequest(_))));
    }

    /// One non-empty input and no optional fields is enough to resolve.
    #[test]
    fn accepts_minimal_request() {
        let ok = request_with(&["hello"]).resolve().unwrap();
        assert_eq!(ok.input.len(), 1);
        assert!(ok.dimensions.is_none());
        assert!(ok.task.is_none());
    }

    /// Every field is picked up when present in the JSON document.
    #[test]
    fn parses_full_request_json() {
        let json = r#"{
            "id": "r1",
            "input": ["a", "b"],
            "dimensions": 256,
            "task": "retrieval_document"
        }"#;
        let parsed: EmbedRequest = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.input.len(), 2);
        assert_eq!(parsed.dimensions, Some(256));
        assert_eq!(parsed.task, Some(EmbedTask::RetrievalDocument));
    }

    /// An unrecognized task string parses as `Other` and then fails
    /// validation.
    #[test]
    fn unknown_task_round_trips_as_other() {
        let json = r#"{"id":"r1","input":["x"],"task":"some_future_task"}"#;
        let parsed: EmbedRequest = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.task, Some(EmbedTask::Other));
        assert!(parsed.resolve().is_err());
    }

    /// Unset optional fields leave no trace in the serialized output.
    #[test]
    fn skips_serializing_optional_fields_when_unset() {
        let json = serde_json::to_string(&request_with(&["hello"])).unwrap();
        for key in ["dimensions", "task"] {
            assert!(!json.contains(key));
        }
    }
}