Skip to main content

inferd_proto/embed/
request.rs

//! Embed request envelope and validation.
//!
//! Per ADR 0017 §"Embed request". Required fields: `id` and `input`
//! (`input` must be non-empty, and each entry must be non-empty).
//! Optional: `dimensions` (Matryoshka truncation length, validated
//! against the model's supported set at the backend layer) and `task`
//! (task-prefix hint; unrecognised values parse as `EmbedTask::Other`
//! and are rejected by `EmbedRequest::resolve`).
7
8use crate::error::ProtoError;
9use serde::{Deserialize, Serialize};
10
11/// Task-prefix hint for embedding models trained with task-aware
12/// prefixes (e.g. EmbeddingGemma). Backends that don't recognise the
13/// task ignore the field; the daemon applies the engine-specific
14/// prefix on behalf of the consumer per ADR 0013.
15///
16/// Forward-compatibility: unknown task variants land in `Other` so
17/// older daemons / clients tolerate task hints added in later v0.x
18/// revisions without rejecting at parse time.
19#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(rename_all = "snake_case")]
21pub enum EmbedTask {
22    /// Encoding a query that will be matched against indexed documents.
23    RetrievalQuery,
24    /// Encoding a document that will be indexed for retrieval.
25    RetrievalDocument,
26    /// Encoding text for similarity scoring (symmetric).
27    Similarity,
28    /// Encoding text whose label will be predicted.
29    Classification,
30    /// Encoding text for unsupervised clustering.
31    Clustering,
32    /// Encoding for question-answering pair scoring.
33    QuestionAnswering,
34    /// Encoding text for fact-verification scoring.
35    FactVerification,
36    /// Encoding code for code-search retrieval.
37    CodeRetrievalQuery,
38    /// Forward-compatible escape hatch — unknown task strings deserialise here.
39    #[serde(other)]
40    Other,
41}
42
43/// The embed request envelope sent by clients.
44///
45/// `Default` is intentionally available for `..Default::default()`
46/// shorthand; callers must populate `id` and `input` before sending.
47#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
48pub struct EmbedRequest {
49    /// Caller-assigned correlation id; echoed on the response frame.
50    #[serde(default, skip_serializing_if = "String::is_empty")]
51    pub id: String,
52
53    /// One or more input strings to embed. Each is encoded
54    /// independently; the response's `embeddings[i]` corresponds to
55    /// `input[i]`.
56    pub input: Vec<String>,
57
58    /// Matryoshka truncation length. EmbeddingGemma supports
59    /// `768 | 512 | 256 | 128`; backends validate against their own
60    /// supported set and emit `invalid_request` if the value is
61    /// rejected. Omitted means "model default".
62    #[serde(default, skip_serializing_if = "Option::is_none")]
63    pub dimensions: Option<u32>,
64
65    /// Task-prefix hint. Backends that don't apply task prefixes
66    /// ignore this field.
67    #[serde(default, skip_serializing_if = "Option::is_none")]
68    pub task: Option<EmbedTask>,
69}
70
71/// `EmbedRequest` with semantic validation completed.
72///
73/// `dimensions` validation against the active backend's supported set
74/// happens at the backend layer (different models support different
75/// MRL widths), not here.
76#[derive(Debug, Clone, PartialEq)]
77pub struct EmbedResolved {
78    /// Caller-assigned correlation id.
79    pub id: String,
80    /// Validated input strings.
81    pub input: Vec<String>,
82    /// Truncation length, if set.
83    pub dimensions: Option<u32>,
84    /// Task hint, if set.
85    pub task: Option<EmbedTask>,
86}
87
88impl EmbedRequest {
89    /// Validate the request envelope. Rejects empty `input` and
90    /// empty inner strings. Does NOT validate `dimensions` against any
91    /// model-specific supported set — backends do that.
92    pub fn resolve(self) -> Result<EmbedResolved, ProtoError> {
93        if self.input.is_empty() {
94            return Err(ProtoError::InvalidRequest("input must not be empty".into()));
95        }
96        for (i, s) in self.input.iter().enumerate() {
97            if s.is_empty() {
98                return Err(ProtoError::InvalidRequest(format!(
99                    "input[{i}] must not be empty"
100                )));
101            }
102        }
103        if matches!(self.task, Some(EmbedTask::Other)) {
104            return Err(ProtoError::InvalidRequest(
105                "task uses an unknown variant".into(),
106            ));
107        }
108        Ok(EmbedResolved {
109            id: self.id,
110            input: self.input,
111            dimensions: self.dimensions,
112            task: self.task,
113        })
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a request with id `"r1"`, the given inputs, and both
    /// optional fields unset.
    fn request_with(input: Vec<String>) -> EmbedRequest {
        EmbedRequest {
            id: "r1".into(),
            input,
            ..Default::default()
        }
    }

    #[test]
    fn rejects_empty_input() {
        let outcome = request_with(Vec::new()).resolve();
        assert!(matches!(outcome, Err(ProtoError::InvalidRequest(_))));
    }

    #[test]
    fn rejects_empty_inner_string() {
        let outcome = request_with(vec!["hello".into(), String::new()]).resolve();
        assert!(matches!(outcome, Err(ProtoError::InvalidRequest(_))));
    }

    #[test]
    fn accepts_minimal_request() {
        let resolved = request_with(vec!["hello".into()]).resolve().unwrap();
        assert_eq!(resolved.input.len(), 1);
        assert_eq!(resolved.dimensions, None);
        assert_eq!(resolved.task, None);
    }

    #[test]
    fn parses_full_request_json() {
        let s = r#"{
            "id": "r1",
            "input": ["a", "b"],
            "dimensions": 256,
            "task": "retrieval_document"
        }"#;
        let parsed: EmbedRequest = serde_json::from_str(s).unwrap();
        assert_eq!(parsed.input.len(), 2);
        assert_eq!(parsed.dimensions, Some(256));
        assert_eq!(parsed.task, Some(EmbedTask::RetrievalDocument));
    }

    #[test]
    fn unknown_task_round_trips_as_other() {
        let s = r#"{"id":"r1","input":["x"],"task":"some_future_task"}"#;
        let parsed: EmbedRequest = serde_json::from_str(s).unwrap();
        assert_eq!(parsed.task, Some(EmbedTask::Other));
        // `resolve` must refuse `Other` so the daemon never silently
        // applies a task it doesn't know.
        assert!(parsed.resolve().is_err());
    }

    #[test]
    fn skips_serializing_optional_fields_when_unset() {
        let encoded = serde_json::to_string(&request_with(vec!["hello".into()])).unwrap();
        for absent_key in ["dimensions", "task"] {
            assert!(!encoded.contains(absent_key));
        }
    }
}