Skip to main content

reddb_server/runtime/ai/
grpc_ask_message.rs

1//! `GrpcAskMessage` — pure builder pinning the typed gRPC `AskReply`
2//! shape (issue #407, PRD #391).
3//!
4//! Deep module: no I/O, no transport, no codegen dependency. Defines
5//! the typed gRPC message shape and a pure converter from the
6//! canonical [`super::ask_response_envelope::AskResult`] to the typed
7//! reply. The proto edit + service-impl wiring slice is deferred —
8//! pinning the shape here means the wiring slice cannot quietly drop
9//! `citations`, reorder field tags, or re-shape `validation` without
10//! the tests in this file failing first.
11//!
12//! ## Why a separate module
13//!
14//! The current gRPC `Ask` RPC (`service_impl.rs::ask`) returns a
15//! generic `PayloadReply { payload_json }` — the legacy bucketed-only
16//! shape. PRD #391's AC for #407 requires the full ASK schema as a
17//! typed gRPC message so existing gRPC clients (Go's `pgx`-style
18//! drivers, JVM `io.grpc.*`, dotnet `Grpc.Net.Client`) get field-typed
19//! deserialisation rather than a JSON blob to parse twice.
20//!
21//! Two things must stay aligned across the slice:
22//!
23//! 1. The proto field numbers are an external API — once a tag is
24//!    shipped, it cannot change without breaking compiled clients.
//!    Pinned by the [`proto_tags`] constants + tests.
26//! 2. Field set must match [`super::ask_response_envelope`] one-to-one
27//!    so the JSON envelope (used by JSON-RPC #406, MCP non-stream
28//!    #409, PG-wire #408) and the gRPC reply describe the same data.
29//!    Pinned by `field_set_matches_json_envelope`.
30//!
31//! ## Field tags (proto3)
32//!
33//! `AskReply` (top-level):
34//! - 1 `string answer`
35//! - 2 `string sources_flat_json`   (JSON-encoded array, same bytes as the envelope's `sources_flat`)
36//! - 3 `repeated Citation citations`
37//! - 4 `Validation validation`
38//! - 5 `string provider`
39//! - 6 `string model`
40//! - 7 `uint32 prompt_tokens`
41//! - 8 `uint32 completion_tokens`
42//! - 9 `double cost_usd`
43//! - 10 `bool cache_hit`
44//! - 11 `string mode`               ("strict" | "lenient" — effective)
45//! - 12 `uint32 retry_count`
46//!
47//! `Citation`:
48//! - 1 `uint32 marker`
49//! - 2 `string urn`
50//!
51//! `Validation`:
52//! - 1 `bool ok`
53//! - 2 `repeated ValidationItem warnings`
54//! - 3 `repeated ValidationItem errors`
55//!
56//! `ValidationItem`:
57//! - 1 `string kind`     ("malformed" | "out_of_range")
58//! - 2 `string detail`
59//!
60//! `sources_flat` is carried as a single JSON string (`sources_flat_json`)
61//! rather than a `repeated SourceRow` to keep parity with the envelope
62//! shape and avoid forcing per-row payload re-encoding. Clients that
63//! want structured rows parse the JSON; the same bytes already flow on
64//! JSON-RPC #406, MCP #409, and PG-wire #408.
65//!
//! The determinism seed (#400) is *not* part of the reply. This mirrors
//! the JSON envelope's omission — see the `ask_response_envelope` rationale.
68
69use super::ask_response_envelope::{
70    AskResult, Citation as EnvCitation, Mode, SourceRow, Validation as EnvValidation,
71    ValidationError, ValidationWarning,
72};
73
/// One citation row in the typed gRPC reply (the `Citation` message).
///
/// Derives `Eq` and `Hash` alongside `PartialEq`: both fields have total
/// equality, so callers can de-duplicate citations or use them as map
/// keys without wrapping the type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GrpcCitation {
    /// Citation marker number (proto tag 1).
    pub marker: u32,
    /// URN of the cited source (proto tag 2).
    pub urn: String,
}
80
/// One validation item (the `ValidationItem` message), used for both
/// warnings and errors.
///
/// Derives `Eq` and `Hash` alongside `PartialEq`: both fields are plain
/// strings, so items can be de-duplicated or used as map keys directly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GrpcValidationItem {
    /// Item category (proto tag 1); the module docs list `"malformed"`
    /// and `"out_of_range"`.
    pub kind: String,
    /// Free-form description of the problem (proto tag 2).
    pub detail: String,
}
87
/// Validation block of the typed reply (the `Validation` message).
///
/// Field order follows the proto tags listed in the module docs:
/// `ok` = 1, `warnings` = 2, `errors` = 3.
#[derive(Debug, Clone, PartialEq)]
pub struct GrpcValidation {
    /// Overall validation verdict, copied from the envelope's `ok` flag.
    pub ok: bool,
    /// Validation warnings, one item per envelope warning.
    pub warnings: Vec<GrpcValidationItem>,
    /// Validation errors, one item per envelope error.
    pub errors: Vec<GrpcValidationItem>,
}
95
/// Typed gRPC `AskReply` body.
///
/// Fields are declared in proto-tag order (1 through 12, see
/// `proto_tags::ask_reply`). The struct deliberately carries no seed or
/// temperature field; a test in this file destructures it exhaustively
/// so that adding a field forces a conscious update.
#[derive(Debug, Clone, PartialEq)]
pub struct GrpcAskReply {
    /// Tag 1 — the answer text.
    pub answer: String,
    /// Tag 2 — the source rows encoded as a single JSON array string.
    pub sources_flat_json: String,
    /// Tag 3 — citations, emitted in ascending `marker` order by `build`.
    pub citations: Vec<GrpcCitation>,
    /// Tag 4 — validation verdict with its warnings and errors.
    pub validation: GrpcValidation,
    /// Tag 5 — provider name, copied from the `AskResult`.
    pub provider: String,
    /// Tag 6 — model name, copied from the `AskResult`.
    pub model: String,
    /// Tag 7 — prompt token count.
    pub prompt_tokens: u32,
    /// Tag 8 — completion token count.
    pub completion_tokens: u32,
    /// Tag 9 — cost in USD.
    pub cost_usd: f64,
    /// Tag 10 — cache-hit flag, copied from the `AskResult`.
    pub cache_hit: bool,
    /// Tag 11 — effective mode label: `"strict"` or `"lenient"`.
    pub mode: String,
    /// Tag 12 — retry count, copied from the `AskResult`.
    pub retry_count: u32,
}
112
/// Proto field tags for `AskReply` and its nested messages (`Citation`,
/// `Validation`, `ValidationItem`) — pinned constants. Editing any of
/// these is a wire-breaking change and the tests in this module will
/// catch it.
pub mod proto_tags {
    /// Field tags of the top-level `AskReply` message (1 through 12).
    pub mod ask_reply {
        pub const ANSWER: u32 = 1;
        pub const SOURCES_FLAT_JSON: u32 = 2;
        pub const CITATIONS: u32 = 3;
        pub const VALIDATION: u32 = 4;
        pub const PROVIDER: u32 = 5;
        pub const MODEL: u32 = 6;
        pub const PROMPT_TOKENS: u32 = 7;
        pub const COMPLETION_TOKENS: u32 = 8;
        pub const COST_USD: u32 = 9;
        pub const CACHE_HIT: u32 = 10;
        pub const MODE: u32 = 11;
        pub const RETRY_COUNT: u32 = 12;
    }
    /// Field tags of the nested `Citation` message.
    pub mod citation {
        pub const MARKER: u32 = 1;
        pub const URN: u32 = 2;
    }
    /// Field tags of the nested `Validation` message.
    pub mod validation {
        pub const OK: u32 = 1;
        pub const WARNINGS: u32 = 2;
        pub const ERRORS: u32 = 3;
    }
    /// Field tags of the `ValidationItem` message (shared by warnings
    /// and errors).
    pub mod validation_item {
        pub const KIND: u32 = 1;
        pub const DETAIL: u32 = 2;
    }
}
145
146/// Build the typed gRPC reply from the canonical `AskResult`.
147///
148/// Citation ordering, `sources_flat` ordering, and field semantics
149/// match [`super::ask_response_envelope::build`] one-to-one. Running
150/// this on byte-equal input is byte-equal output (pinned by
151/// `build_is_deterministic_across_calls`) — required by the ASK
152/// determinism contract (#400).
153pub fn build(result: &AskResult) -> GrpcAskReply {
154    let mut citations: Vec<GrpcCitation> = result
155        .citations
156        .iter()
157        .map(|c: &EnvCitation| GrpcCitation {
158            marker: c.marker,
159            urn: c.urn.clone(),
160        })
161        .collect();
162    citations.sort_by_key(|c| c.marker);
163
164    GrpcAskReply {
165        answer: result.answer.clone(),
166        sources_flat_json: sources_flat_json(&result.sources_flat),
167        citations,
168        validation: validation_from(&result.validation),
169        provider: result.provider.clone(),
170        model: result.model.clone(),
171        prompt_tokens: result.prompt_tokens,
172        completion_tokens: result.completion_tokens,
173        cost_usd: result.cost_usd,
174        cache_hit: result.cache_hit,
175        mode: mode_str(result.effective_mode).to_string(),
176        retry_count: result.retry_count,
177    }
178}
179
180fn mode_str(mode: Mode) -> &'static str {
181    match mode {
182        Mode::Strict => "strict",
183        Mode::Lenient => "lenient",
184    }
185}
186
187fn validation_from(v: &EnvValidation) -> GrpcValidation {
188    GrpcValidation {
189        ok: v.ok,
190        warnings: v.warnings.iter().map(warning_item).collect(),
191        errors: v.errors.iter().map(error_item).collect(),
192    }
193}
194
195fn warning_item(w: &ValidationWarning) -> GrpcValidationItem {
196    GrpcValidationItem {
197        kind: w.kind.clone(),
198        detail: w.detail.clone(),
199    }
200}
201
202fn error_item(e: &ValidationError) -> GrpcValidationItem {
203    GrpcValidationItem {
204        kind: e.kind.clone(),
205        detail: e.detail.clone(),
206    }
207}
208
209fn sources_flat_json(rows: &[SourceRow]) -> String {
210    // Order preserved verbatim — post-RRF rank is the contract since
211    // citation `[^N]` indexes into the array, and reordering would
212    // silently break grounding. Keys alphabetised (`payload`, `urn`)
213    // to match the envelope's `BTreeMap`-backed encoder.
214    let mut out = String::from("[");
215    for (i, r) in rows.iter().enumerate() {
216        if i > 0 {
217            out.push(',');
218        }
219        out.push('{');
220        out.push_str("\"payload\":");
221        push_json_string(&mut out, &r.payload);
222        out.push(',');
223        out.push_str("\"urn\":");
224        push_json_string(&mut out, &r.urn);
225        out.push('}');
226    }
227    out.push(']');
228    out
229}
230
/// Append `s` to `out` as a double-quoted JSON string literal.
///
/// Escapes `"` and `\`, uses the short forms `\n`, `\r`, `\t`, and writes
/// every other control character below U+0020 as a lowercase `\u00xx`
/// escape. All other characters (including non-ASCII) pass through
/// unchanged. Existing contents of `out` are preserved.
fn push_json_string(out: &mut String, s: &str) {
    const HEX: &[u8; 16] = b"0123456789abcdef";

    // At least the input plus two quotes will be written; reserving up
    // front avoids repeated regrowth for large payloads.
    out.reserve(s.len() + 2);
    out.push('"');
    for ch in s.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            c if (c as u32) < 0x20 => {
                // The code point is below 0x20, so the escape is always
                // `\u00` followed by two hex digits. Writing the digits
                // directly avoids a temporary `String` per character.
                let code = c as u32 as usize;
                out.push_str("\\u00");
                out.push(char::from(HEX[code >> 4]));
                out.push(char::from(HEX[code & 0x0f]));
            }
            c => out.push(c),
        }
    }
    out.push('"');
}
248
#[cfg(test)]
mod tests {
    use super::proto_tags::*;
    use super::*;
    use crate::runtime::ai::ask_response_envelope::{
        AskResult, Citation as EnvCitation, Mode, SourceRow, Validation as EnvValidation,
        ValidationError, ValidationWarning,
    };

    /// Baseline fixture: one source row, one citation, clean validation,
    /// strict mode. Individual tests mutate the fields they care about.
    fn sample_result() -> AskResult {
        AskResult {
            answer: "The capital is Lisbon [^1].".to_string(),
            sources_flat: vec![SourceRow {
                urn: "urn:reddb:row:cities/42".to_string(),
                payload: "{\"name\":\"Lisbon\"}".to_string(),
            }],
            citations: vec![EnvCitation {
                marker: 1,
                urn: "urn:reddb:row:cities/42".to_string(),
            }],
            validation: EnvValidation {
                ok: true,
                warnings: vec![],
                errors: vec![],
            },
            cache_hit: false,
            provider: "openai".to_string(),
            model: "gpt-4o-mini".to_string(),
            prompt_tokens: 123,
            completion_tokens: 17,
            cost_usd: 0.0042,
            effective_mode: Mode::Strict,
            retry_count: 0,
        }
    }

    #[test]
    fn build_emits_every_top_level_field() {
        let r = sample_result();
        let reply = build(&r);
        assert_eq!(reply.answer, r.answer);
        assert_eq!(reply.provider, r.provider);
        assert_eq!(reply.model, r.model);
        assert_eq!(reply.prompt_tokens, r.prompt_tokens);
        assert_eq!(reply.completion_tokens, r.completion_tokens);
        assert_eq!(reply.cost_usd, r.cost_usd);
        assert_eq!(reply.cache_hit, r.cache_hit);
        assert_eq!(reply.retry_count, r.retry_count);
        assert_eq!(reply.mode, "strict");
        assert!(reply.validation.ok);
        assert_eq!(reply.citations.len(), 1);
        assert_eq!(reply.citations[0].marker, 1);
        assert_eq!(reply.citations[0].urn, r.citations[0].urn);
        // Pin the exact wire bytes, not just the brackets: key order and
        // escaping are both part of the envelope-parity contract.
        assert_eq!(
            reply.sources_flat_json,
            r#"[{"payload":"{\"name\":\"Lisbon\"}","urn":"urn:reddb:row:cities/42"}]"#
        );
    }

    #[test]
    fn mode_strict_serialises_as_strict() {
        let mut r = sample_result();
        r.effective_mode = Mode::Strict;
        assert_eq!(build(&r).mode, "strict");
    }

    #[test]
    fn mode_lenient_serialises_as_lenient() {
        let mut r = sample_result();
        r.effective_mode = Mode::Lenient;
        assert_eq!(build(&r).mode, "lenient");
    }

    #[test]
    fn citations_sorted_by_marker_ascending() {
        let mut r = sample_result();
        r.citations = vec![
            EnvCitation {
                marker: 3,
                urn: "urn:c".to_string(),
            },
            EnvCitation {
                marker: 1,
                urn: "urn:a".to_string(),
            },
            EnvCitation {
                marker: 2,
                urn: "urn:b".to_string(),
            },
        ];
        let reply = build(&r);
        assert_eq!(
            reply.citations.iter().map(|c| c.marker).collect::<Vec<_>>(),
            vec![1, 2, 3]
        );
        // The URN must travel with its marker through the sort.
        assert_eq!(
            reply
                .citations
                .iter()
                .map(|c| c.urn.as_str())
                .collect::<Vec<_>>(),
            vec!["urn:a", "urn:b", "urn:c"]
        );
    }

    #[test]
    fn citation_same_marker_is_stable() {
        let mut r = sample_result();
        r.citations = vec![
            EnvCitation {
                marker: 1,
                urn: "urn:first".to_string(),
            },
            EnvCitation {
                marker: 1,
                urn: "urn:second".to_string(),
            },
        ];
        let reply = build(&r);
        assert_eq!(reply.citations[0].urn, "urn:first");
        assert_eq!(reply.citations[1].urn, "urn:second");
    }

    #[test]
    fn sources_flat_preserves_order_verbatim() {
        let mut r = sample_result();
        r.sources_flat = vec![
            SourceRow {
                urn: "urn:b".to_string(),
                payload: "{}".to_string(),
            },
            SourceRow {
                urn: "urn:a".to_string(),
                payload: "{}".to_string(),
            },
        ];
        let reply = build(&r);
        let pos_b = reply.sources_flat_json.find("urn:b").unwrap();
        let pos_a = reply.sources_flat_json.find("urn:a").unwrap();
        assert!(pos_b < pos_a, "RRF order must be preserved");
    }

    #[test]
    fn empty_sources_serialises_as_empty_array() {
        let mut r = sample_result();
        r.sources_flat = vec![];
        assert_eq!(build(&r).sources_flat_json, "[]");
    }

    #[test]
    fn empty_citations_yields_empty_vec_not_panic() {
        let mut r = sample_result();
        r.citations = vec![];
        assert!(build(&r).citations.is_empty());
    }

    #[test]
    fn sources_flat_json_keys_alphabetised() {
        let mut r = sample_result();
        r.sources_flat = vec![SourceRow {
            urn: "urn:x".to_string(),
            payload: "p".to_string(),
        }];
        let json = build(&r).sources_flat_json;
        let pos_payload = json.find("\"payload\"").unwrap();
        let pos_urn = json.find("\"urn\"").unwrap();
        assert!(pos_payload < pos_urn, "envelope parity: payload before urn");
    }

    #[test]
    fn sources_flat_json_escapes_quotes_and_backslashes() {
        let mut r = sample_result();
        r.sources_flat = vec![SourceRow {
            urn: "urn:row".to_string(),
            payload: "{\"k\":\"v\\\"\"}".to_string(),
        }];
        let json = build(&r).sources_flat_json;
        // Round-trip via serde_json: must parse to a JSON array of one
        // object whose fields decode back to the exact input strings.
        // Checking only the array length would let a lossy escaper pass.
        let parsed: crate::serde_json::Value = crate::serde_json::from_str(&json).unwrap();
        let arr = parsed.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        let payload = arr[0]["payload"].as_str().unwrap();
        assert_eq!(payload, r.sources_flat[0].payload.as_str());
        let urn = arr[0]["urn"].as_str().unwrap();
        assert_eq!(urn, "urn:row");
    }

    #[test]
    fn sources_flat_json_escapes_control_chars() {
        let mut r = sample_result();
        r.sources_flat = vec![SourceRow {
            urn: "urn:row".to_string(),
            payload: "line1\nline2\ttab\u{0001}ctrl".to_string(),
        }];
        let json = build(&r).sources_flat_json;
        // On the wire the control characters must be escaped, never raw.
        assert!(!json.contains('\n'));
        assert!(!json.contains('\t'));
        assert!(!json.contains('\u{0001}'));
        assert!(json.contains("\\u0001"));
        let parsed: crate::serde_json::Value = crate::serde_json::from_str(&json).unwrap();
        let arr = parsed.as_array().unwrap();
        let payload = arr[0]["payload"].as_str().unwrap();
        assert!(payload.contains('\n'));
        assert!(payload.contains('\t'));
        assert!(payload.contains('\u{0001}'));
    }

    #[test]
    fn validation_warnings_and_errors_roundtrip() {
        let mut r = sample_result();
        r.validation = EnvValidation {
            ok: false,
            warnings: vec![ValidationWarning {
                kind: "malformed".to_string(),
                detail: "missing marker".to_string(),
            }],
            errors: vec![ValidationError {
                kind: "out_of_range".to_string(),
                detail: "marker > sources".to_string(),
            }],
        };
        let reply = build(&r);
        assert!(!reply.validation.ok);
        assert_eq!(reply.validation.warnings.len(), 1);
        assert_eq!(reply.validation.warnings[0].kind, "malformed");
        assert_eq!(reply.validation.warnings[0].detail, "missing marker");
        assert_eq!(reply.validation.errors.len(), 1);
        assert_eq!(reply.validation.errors[0].kind, "out_of_range");
        assert_eq!(reply.validation.errors[0].detail, "marker > sources");
    }

    #[test]
    fn cache_hit_records_zero_cost_and_tokens_when_zero() {
        let mut r = sample_result();
        r.cache_hit = true;
        r.prompt_tokens = 0;
        r.completion_tokens = 0;
        r.cost_usd = 0.0;
        let reply = build(&r);
        assert!(reply.cache_hit);
        assert_eq!(reply.prompt_tokens, 0);
        assert_eq!(reply.completion_tokens, 0);
        assert_eq!(reply.cost_usd, 0.0);
    }

    #[test]
    fn build_is_deterministic_across_calls() {
        let r = sample_result();
        assert_eq!(build(&r), build(&r));
    }

    #[test]
    fn does_not_expose_seed_or_temperature() {
        // Compile-time pin: `GrpcAskReply` has no seed/temperature fields.
        // Adding one would break this destructuring.
        let r = sample_result();
        let GrpcAskReply {
            answer: _,
            sources_flat_json: _,
            citations: _,
            validation: _,
            provider: _,
            model: _,
            prompt_tokens: _,
            completion_tokens: _,
            cost_usd: _,
            cache_hit: _,
            mode: _,
            retry_count: _,
        } = build(&r);
    }

    #[test]
    fn ask_reply_proto_tags_pinned() {
        assert_eq!(ask_reply::ANSWER, 1);
        assert_eq!(ask_reply::SOURCES_FLAT_JSON, 2);
        assert_eq!(ask_reply::CITATIONS, 3);
        assert_eq!(ask_reply::VALIDATION, 4);
        assert_eq!(ask_reply::PROVIDER, 5);
        assert_eq!(ask_reply::MODEL, 6);
        assert_eq!(ask_reply::PROMPT_TOKENS, 7);
        assert_eq!(ask_reply::COMPLETION_TOKENS, 8);
        assert_eq!(ask_reply::COST_USD, 9);
        assert_eq!(ask_reply::CACHE_HIT, 10);
        assert_eq!(ask_reply::MODE, 11);
        assert_eq!(ask_reply::RETRY_COUNT, 12);
    }

    #[test]
    fn ask_reply_proto_tags_are_unique_and_contiguous() {
        let tags = [
            ask_reply::ANSWER,
            ask_reply::SOURCES_FLAT_JSON,
            ask_reply::CITATIONS,
            ask_reply::VALIDATION,
            ask_reply::PROVIDER,
            ask_reply::MODEL,
            ask_reply::PROMPT_TOKENS,
            ask_reply::COMPLETION_TOKENS,
            ask_reply::COST_USD,
            ask_reply::CACHE_HIT,
            ask_reply::MODE,
            ask_reply::RETRY_COUNT,
        ];
        let mut sorted = tags.to_vec();
        sorted.sort();
        sorted.dedup();
        assert_eq!(sorted.len(), tags.len(), "duplicate proto field tag");
        assert_eq!(sorted, (1u32..=tags.len() as u32).collect::<Vec<_>>());
    }

    #[test]
    fn nested_message_proto_tags_pinned() {
        assert_eq!(citation::MARKER, 1);
        assert_eq!(citation::URN, 2);
        assert_eq!(validation::OK, 1);
        assert_eq!(validation::WARNINGS, 2);
        assert_eq!(validation::ERRORS, 3);
        assert_eq!(validation_item::KIND, 1);
        assert_eq!(validation_item::DETAIL, 2);
    }

    #[test]
    fn field_set_matches_json_envelope() {
        // Parity check: every top-level key in the JSON envelope must
        // map to a GrpcAskReply field. If the envelope grows, this
        // test forces a matching GrpcAskReply field + proto tag.
        let r = sample_result();
        let envelope = crate::runtime::ai::ask_response_envelope::build(&r);
        let keys: Vec<&str> = envelope
            .as_object()
            .unwrap()
            .keys()
            .map(|s| s.as_str())
            .collect();
        let expected = [
            "answer",
            "cache_hit",
            "citations",
            "completion_tokens",
            "cost_usd",
            "mode",
            "model",
            "prompt_tokens",
            "provider",
            "retry_count",
            "sources_flat",
            "validation",
        ];
        for k in &expected {
            assert!(keys.contains(k), "envelope missing key {k}");
        }
        assert_eq!(keys.len(), expected.len(), "envelope keys drift detected");
    }
}