// inferd_proto/embed/response.rs

//! Embed response frame schema.
//!
//! Per ADR 0017 §"Embed response". Single terminal frame per request
//! — embeddings are not streamed. Two variants: `Embeddings` (success)
//! and `Error` (failure).

use serde::{Deserialize, Serialize};

9/// Token-count usage report carried on `embeddings` frames.
10///
11/// Embed requests have no output tokens (the output is a vector, not
12/// a generation), so only `input_tokens` is reported.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14pub struct EmbedUsage {
15    /// Tokens consumed by the input strings (sum across the batch).
16    pub input_tokens: u32,
17}
18
19/// Embed-specific error-code taxonomy.
20///
21/// Superset of v1's `ErrorCode` (kept independent so the v1 enum
22/// stays frozen per ADR 0008). The only embed-specific addition is
23/// `embed_unsupported`, returned in the belt-and-braces case where a
24/// daemon configured with a generation-only backend somehow receives
25/// an embed request (the embed socket should not have been bound in
26/// that configuration — the error is a fail-safe).
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
28#[serde(rename_all = "snake_case")]
29pub enum EmbedErrorCode {
30    /// Admission queue full at submit time.
31    QueueFull,
32    /// Selected backend errored before or during embedding.
33    BackendUnavailable,
34    /// Request failed validation (empty input, unsupported dimensions,
35    /// unknown task, etc.).
36    InvalidRequest,
37    /// Frame exceeded the 64 MiB cap.
38    FrameTooLarge,
39    /// Daemon-side bug or unexpected condition.
40    Internal,
41    /// The active backend doesn't support embeddings.
42    EmbedUnsupported,
43}
44
45/// One frame on the embed response stream.
46///
47/// Always terminal — there are exactly two outcomes, success
48/// (`Embeddings`) or failure (`Error`). The connection stays open for
49/// the next request.
50#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
51#[serde(tag = "type", rename_all = "snake_case")]
52pub enum EmbedResponse {
53    /// Successful embedding result.
54    Embeddings {
55        /// Request id.
56        id: String,
57        /// One vector per input string, in the same order as the
58        /// request's `input`. Inner vectors all share the same
59        /// `dimensions` length.
60        embeddings: Vec<Vec<f32>>,
61        /// Actual length of each inner vector after any MRL truncation.
62        dimensions: u32,
63        /// Backend-reported model name (e.g. `"embeddinggemma-300m"`).
64        model: String,
65        /// Token-count usage.
66        usage: EmbedUsage,
67        /// `Backend::name()` of the adapter that served this request.
68        ///
69        /// Diagnostic only — apps must not branch on this (ADR 0007).
70        backend: String,
71    },
72    /// Failure terminal frame.
73    Error {
74        /// Request id.
75        id: String,
76        /// Machine-readable classification.
77        code: EmbedErrorCode,
78        /// Human-readable description.
79        message: String,
80    },
81}
82
83impl EmbedResponse {
84    /// Correlation id of the frame regardless of variant.
85    pub fn id(&self) -> &str {
86        match self {
87            EmbedResponse::Embeddings { id, .. } | EmbedResponse::Error { id, .. } => id,
88        }
89    }
90
91    /// `true` if this frame represents a successful embedding result.
92    pub fn is_ok(&self) -> bool {
93        matches!(self, EmbedResponse::Embeddings { .. })
94    }
95}
96
#[cfg(test)]
mod tests {
    //! Wire-format tests for the embed response frames: JSON round
    //! trips for both variants plus checks on the `type` tag and the
    //! snake_case error codes.

    use super::*;

    /// A success frame survives a JSON round trip and reports itself
    /// as ok with the original id.
    #[test]
    fn embeddings_variant_round_trips() {
        let resp = EmbedResponse::Embeddings {
            id: "r1".into(),
            embeddings: vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]],
            dimensions: 3,
            model: "embeddinggemma-300m".into(),
            usage: EmbedUsage { input_tokens: 12 },
            backend: "llamacpp".into(),
        };
        let s = serde_json::to_string(&resp).unwrap();
        let back: EmbedResponse = serde_json::from_str(&s).unwrap();
        assert_eq!(resp, back);
        assert!(resp.is_ok());
        assert_eq!(resp.id(), "r1");
    }

    /// An error frame survives a JSON round trip and is not ok.
    #[test]
    fn error_variant_round_trips() {
        let resp = EmbedResponse::Error {
            id: "r1".into(),
            code: EmbedErrorCode::InvalidRequest,
            message: "dimensions must be one of [128, 256, 512, 768]".into(),
        };
        let s = serde_json::to_string(&resp).unwrap();
        let back: EmbedResponse = serde_json::from_str(&s).unwrap();
        assert_eq!(resp, back);
        assert!(!resp.is_ok());
        assert_eq!(resp.id(), "r1");
    }

    /// The success variant is tagged `"type": "embeddings"` and its
    /// fields sit at the top level of the object.
    #[test]
    fn embeddings_serializes_with_type_tag() {
        let resp = EmbedResponse::Embeddings {
            id: "r1".into(),
            embeddings: vec![vec![0.1]],
            dimensions: 1,
            model: "m".into(),
            usage: EmbedUsage { input_tokens: 1 },
            backend: "llamacpp".into(),
        };
        let v: serde_json::Value = serde_json::to_value(&resp).unwrap();
        assert_eq!(v["type"], "embeddings");
        assert_eq!(v["dimensions"], 1);
    }

    /// The failure variant is tagged `"type": "error"` and carries the
    /// snake_case error code alongside it.
    #[test]
    fn error_serializes_with_type_tag() {
        let resp = EmbedResponse::Error {
            id: "r1".into(),
            code: EmbedErrorCode::InvalidRequest,
            message: "bad request".into(),
        };
        let v: serde_json::Value = serde_json::to_value(&resp).unwrap();
        assert_eq!(v["type"], "error");
        assert_eq!(v["code"], "invalid_request");
    }

    /// Error codes are emitted as snake_case JSON strings.
    #[test]
    fn error_code_serializes_snake_case() {
        let s = serde_json::to_string(&EmbedErrorCode::EmbedUnsupported).unwrap();
        assert_eq!(s, "\"embed_unsupported\"");
    }
}