Skip to main content

ncp_runtime/
result.rs

1use std::collections::HashMap;
2
3use anyhow::{bail, Context, Result};
4
/// Largest declared element/entry count accepted for a definite-length CBOR
/// array or map; headers declaring more are rejected before their contents
/// are decoded.
///
/// This stops a malicious or buggy brick from requesting an enormous
/// collection before the `max_output_bytes` limit comes into play.
const MAX_COLLECTION_LEN: u64 = 1_000 * 1_000;
8
/// Decoded brick result after boundary validation.
///
/// Values come from [`decode_result`] (which enforces the per-variant field
/// rules) or from [`trap_failure`] (for invocation errors).
#[derive(Debug, Clone, PartialEq)]
pub enum BrickResult {
    /// The brick returned an output and no error object.
    Success {
        output: CborValue,
    },
    /// The brick returned an output together with an error object; when
    /// produced by [`decode_result`], `error.error_class` is `LOW_CONFIDENCE`.
    LowConfidence {
        output: CborValue,
        error: ErrorObject,
    },
    /// The brick returned only an error object; when produced by
    /// [`decode_result`], `error.error_class` is never `LOW_CONFIDENCE`.
    Failure {
        error: ErrorObject,
    },
}

/// Decoded error object from a brick result.
///
/// Only `error_class` and `message` are required in the encoded error map.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ErrorObject {
    /// Error category string, e.g. `LOW_CONFIDENCE` or `COMPUTATION_ERROR`.
    pub error_class: String,
    /// Free-form description of the error.
    pub message: String,
    /// Optional `retry_advice` text, if the brick supplied one.
    // Not read by this module's non-test code; kept public for callers, so the
    // dead-code lint is silenced.
    #[allow(dead_code)]
    pub retry_advice: Option<String>,
    /// Optional `severity` text, if the brick supplied one.
    // Not read by this module's non-test code; kept public for callers, so the
    // dead-code lint is silenced.
    #[allow(dead_code)]
    pub severity: Option<String>,
}

/// Minimal CBOR value representation for runtime use.
///
/// Every integer width is represented as [`CborValue::Integer`] and every
/// float width as [`CborValue::Float`]; there are no variants for tags or
/// `undefined`. `PartialEq` compares floats with IEEE semantics, so a value
/// containing `NaN` is not equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub enum CborValue {
    Null,
    Bool(bool),
    Integer(i64),
    Float(f64),
    Text(String),
    Bytes(Vec<u8>),
    Array(Vec<CborValue>),
    /// Key/value pairs in encoded order; duplicate keys are not rejected at
    /// this level.
    Map(Vec<(CborValue, CborValue)>),
}

impl BrickResult {
    /// Name of the variant: `"Success"`, `"LowConfidence"`, or `"Failure"`.
    ///
    /// These match the `type` discriminant strings accepted by
    /// [`decode_result`]. The string is `'static`, so it can outlive the
    /// result it was read from.
    #[must_use]
    pub fn result_type(&self) -> &'static str {
        match self {
            Self::Success { .. } => "Success",
            Self::LowConfidence { .. } => "LowConfidence",
            Self::Failure { .. } => "Failure",
        }
    }

    /// Output payload: `Some` for `Success` and `LowConfidence`, `None` for
    /// `Failure`.
    #[must_use]
    pub fn output(&self) -> Option<&CborValue> {
        match self {
            Self::Success { output } | Self::LowConfidence { output, .. } => Some(output),
            Self::Failure { .. } => None,
        }
    }

    /// Error object: `Some` for `LowConfidence` and `Failure`, `None` for
    /// `Success`.
    #[must_use]
    pub fn error(&self) -> Option<&ErrorObject> {
        match self {
            Self::LowConfidence { error, .. } | Self::Failure { error } => Some(error),
            Self::Success { .. } => None,
        }
    }
}
74
75/// Decode raw CBOR result bytes and validate §9.2 structural boundaries.
76pub fn decode_result(cbor_bytes: &[u8]) -> Result<BrickResult> {
77    let mut decoder = minicbor::Decoder::new(cbor_bytes);
78
79    // Top-level must be a definite-length map
80    let map_len = match decoder.map() {
81        Ok(Some(len)) => len,
82        Ok(None) => bail!("result is an indefinite-length map (must be definite)"),
83        Err(e) => bail!("result is not a valid CBOR map: {e}"),
84    };
85
86    if map_len > MAX_COLLECTION_LEN {
87        bail!("result map has {map_len} entries (max {MAX_COLLECTION_LEN})");
88    }
89
90    // Collect all top-level key-value pairs, rejecting duplicates
91    let mut fields: HashMap<String, CborValue> = HashMap::new();
92    for _ in 0..map_len {
93        let key = decode_text(&mut decoder).context("result map key must be a text string")?;
94        let value = decode_value(&mut decoder).context("decoding result map value")?;
95        if fields.insert(key.clone(), value).is_some() {
96            bail!("duplicate top-level key in result map: '{key}'");
97        }
98    }
99
100    // Extract discriminant
101    let type_val = fields
102        .get("type")
103        .ok_or_else(|| anyhow::anyhow!("result missing 'type' discriminant field"))?;
104    let type_str = match type_val {
105        CborValue::Text(s) => s.as_str(),
106        _ => bail!("result 'type' field must be a text string"),
107    };
108
109    match type_str {
110        "Success" => validate_success(&fields),
111        "LowConfidence" => validate_low_confidence(&fields),
112        "Failure" => validate_failure(&fields),
113        other => {
114            bail!("unknown result type '{other}' (expected Success, LowConfidence, or Failure)")
115        }
116    }
117}
118
119fn validate_success(fields: &HashMap<String, CborValue>) -> Result<BrickResult> {
120    // MUST have output
121    let output = fields
122        .get("output")
123        .ok_or_else(|| anyhow::anyhow!("Success result missing 'output' field"))?
124        .clone();
125
126    // MUST NOT have error
127    if fields.contains_key("error") {
128        bail!("Success result MUST NOT have 'error' field");
129    }
130
131    // MUST NOT have carry_state_side_effects
132    if fields.contains_key("carry_state_side_effects") {
133        bail!("Success result MUST NOT have 'carry_state_side_effects' field");
134    }
135
136    // Phase 2: carry_state_next must be null or absent (carry_state_class=none)
137    if let Some(v) = fields.get("carry_state_next") {
138        if !matches!(v, CborValue::Null) {
139            bail!("carry_state_next must be null/absent in Phase 2 (carry_state_class=none)");
140        }
141    }
142
143    Ok(BrickResult::Success { output })
144}
145
146fn validate_low_confidence(fields: &HashMap<String, CborValue>) -> Result<BrickResult> {
147    // MUST have output
148    let output = fields
149        .get("output")
150        .ok_or_else(|| anyhow::anyhow!("LowConfidence result missing 'output' field"))?
151        .clone();
152
153    // MUST have error
154    let error_val = fields
155        .get("error")
156        .ok_or_else(|| anyhow::anyhow!("LowConfidence result missing 'error' field"))?;
157    let error = parse_error_object(error_val).context("parsing LowConfidence error object")?;
158
159    // error.error_class MUST be LOW_CONFIDENCE
160    if error.error_class != "LOW_CONFIDENCE" {
161        bail!(
162            "LowConfidence result error.error_class must be 'LOW_CONFIDENCE', got '{}'",
163            error.error_class
164        );
165    }
166
167    // MUST NOT have carry_state_side_effects
168    if fields.contains_key("carry_state_side_effects") {
169        bail!("LowConfidence result MUST NOT have 'carry_state_side_effects' field");
170    }
171
172    // Phase 2: carry_state_next must be null or absent (carry_state_class=none)
173    if let Some(v) = fields.get("carry_state_next") {
174        if !matches!(v, CborValue::Null) {
175            bail!("carry_state_next must be null/absent in Phase 2 (carry_state_class=none)");
176        }
177    }
178
179    Ok(BrickResult::LowConfidence { output, error })
180}
181
182fn validate_failure(fields: &HashMap<String, CborValue>) -> Result<BrickResult> {
183    // MUST have error
184    let error_val = fields
185        .get("error")
186        .ok_or_else(|| anyhow::anyhow!("Failure result missing 'error' field"))?;
187    let error = parse_error_object(error_val).context("parsing Failure error object")?;
188
189    // error.error_class MUST NOT be LOW_CONFIDENCE
190    if error.error_class == "LOW_CONFIDENCE" {
191        bail!("Failure result error.error_class MUST NOT be 'LOW_CONFIDENCE'");
192    }
193
194    // MUST NOT have output
195    if fields.contains_key("output") {
196        bail!("Failure result MUST NOT have 'output' field");
197    }
198
199    // MUST NOT have carry_state_next
200    if fields.contains_key("carry_state_next") {
201        bail!("Failure result MUST NOT have 'carry_state_next' field");
202    }
203
204    // carry_state_side_effects: manifest-dependent, enforced in orchestration layer.
205
206    Ok(BrickResult::Failure { error })
207}
208
209/// Parse an error object from a CborValue (must be a map with error_class + message).
210fn parse_error_object(val: &CborValue) -> Result<ErrorObject> {
211    let map = match val {
212        CborValue::Map(pairs) => pairs,
213        _ => bail!("error field must be a CBOR map"),
214    };
215
216    let mut error_class: Option<String> = None;
217    let mut message: Option<String> = None;
218    let mut retry_advice: Option<String> = None;
219    let mut severity: Option<String> = None;
220
221    for (k, v) in map {
222        let key = match k {
223            CborValue::Text(s) => s.as_str(),
224            _ => bail!("error map key must be a text string"),
225        };
226        match key {
227            "error_class" => {
228                if error_class.is_some() {
229                    bail!("duplicate key 'error_class' in error object");
230                }
231                error_class = Some(extract_text(v).context("error.error_class must be text")?);
232            }
233            "message" => {
234                if message.is_some() {
235                    bail!("duplicate key 'message' in error object");
236                }
237                message = Some(extract_text(v).context("error.message must be text")?);
238            }
239            "retry_advice" => {
240                if retry_advice.is_some() {
241                    bail!("duplicate key 'retry_advice' in error object");
242                }
243                retry_advice = Some(extract_text(v).context("error.retry_advice must be text")?);
244            }
245            "severity" => {
246                if severity.is_some() {
247                    bail!("duplicate key 'severity' in error object");
248                }
249                severity = Some(extract_text(v).context("error.severity must be text")?);
250            }
251            _ => {} // ignore unknown fields for forward compatibility
252        }
253    }
254
255    let error_class =
256        error_class.ok_or_else(|| anyhow::anyhow!("error object missing 'error_class' field"))?;
257    let message = message.ok_or_else(|| anyhow::anyhow!("error object missing 'message' field"))?;
258
259    Ok(ErrorObject {
260        error_class,
261        message,
262        retry_advice,
263        severity,
264    })
265}
266
267fn extract_text(val: &CborValue) -> Result<String> {
268    match val {
269        CborValue::Text(s) => Ok(s.clone()),
270        _ => bail!("expected text string"),
271    }
272}
273
274// ── CBOR Decoding ───────────────────────────────────────────────────
275
276fn decode_text(d: &mut minicbor::Decoder<'_>) -> Result<String> {
277    d.str()
278        .map(|s| s.to_string())
279        .map_err(|e| anyhow::anyhow!("expected CBOR text string: {e}"))
280}
281
282fn decode_value(d: &mut minicbor::Decoder<'_>) -> Result<CborValue> {
283    use minicbor::data::Type;
284
285    match d
286        .datatype()
287        .map_err(|e| anyhow::anyhow!("cannot peek CBOR type: {e}"))?
288    {
289        Type::Null => {
290            d.null()
291                .map_err(|e| anyhow::anyhow!("decoding null: {e}"))?;
292            Ok(CborValue::Null)
293        }
294        Type::Undefined => {
295            d.undefined()
296                .map_err(|e| anyhow::anyhow!("consuming undefined: {e}"))?;
297            bail!("CBOR undefined is not allowed in NCP results");
298        }
299        Type::Bool => {
300            let b = d
301                .bool()
302                .map_err(|e| anyhow::anyhow!("decoding bool: {e}"))?;
303            Ok(CborValue::Bool(b))
304        }
305        Type::U8 | Type::U16 | Type::U32 | Type::U64 => {
306            let n = d.u64().map_err(|e| anyhow::anyhow!("decoding uint: {e}"))?;
307            if n > i64::MAX as u64 {
308                bail!("CBOR uint too large for i64: {n}");
309            }
310            Ok(CborValue::Integer(n as i64))
311        }
312        Type::I8 | Type::I16 | Type::I32 | Type::I64 => {
313            let n = d.i64().map_err(|e| anyhow::anyhow!("decoding int: {e}"))?;
314            Ok(CborValue::Integer(n))
315        }
316        Type::F16 | Type::F32 | Type::F64 => {
317            let f = d
318                .f64()
319                .map_err(|e| anyhow::anyhow!("decoding float: {e}"))?;
320            Ok(CborValue::Float(f))
321        }
322        Type::String => {
323            let s = decode_text(d)?;
324            Ok(CborValue::Text(s))
325        }
326        Type::Bytes => {
327            let b = d
328                .bytes()
329                .map_err(|e| anyhow::anyhow!("decoding bytes: {e}"))?
330                .to_vec();
331            Ok(CborValue::Bytes(b))
332        }
333        Type::Array => {
334            let len = d
335                .array()
336                .map_err(|e| anyhow::anyhow!("decoding array: {e}"))?
337                .ok_or_else(|| anyhow::anyhow!("indefinite-length arrays not supported"))?;
338            if len > MAX_COLLECTION_LEN {
339                bail!("CBOR array has {len} elements (max {MAX_COLLECTION_LEN})");
340            }
341            let mut items = Vec::with_capacity(len as usize);
342            for _ in 0..len {
343                items.push(decode_value(d)?);
344            }
345            Ok(CborValue::Array(items))
346        }
347        Type::Map => {
348            let len = d
349                .map()
350                .map_err(|e| anyhow::anyhow!("decoding map: {e}"))?
351                .ok_or_else(|| anyhow::anyhow!("indefinite-length maps not supported"))?;
352            if len > MAX_COLLECTION_LEN {
353                bail!("CBOR map has {len} entries (max {MAX_COLLECTION_LEN})");
354            }
355            let mut pairs = Vec::with_capacity(len as usize);
356            for _ in 0..len {
357                let k = decode_value(d)?;
358                let v = decode_value(d)?;
359                pairs.push((k, v));
360            }
361            Ok(CborValue::Map(pairs))
362        }
363        Type::Tag => {
364            let tag = d.tag().map_err(|e| anyhow::anyhow!("decoding tag: {e}"))?;
365            bail!("CBOR tags are not supported in Phase 2 results (tag={tag:?})");
366        }
367        other => bail!("unsupported CBOR type: {other:?}"),
368    }
369}
370
371// ── Trap Detection ──────────────────────────────────────────────────
372
373/// Create a Failure result for a WASM trap (invoke error).
374pub fn trap_failure(error_class: &str, message: String) -> BrickResult {
375    BrickResult::Failure {
376        error: ErrorObject {
377            error_class: error_class.to_string(),
378            message,
379            retry_advice: None,
380            severity: None,
381        },
382    }
383}
384
385#[cfg(test)]
386mod tests {
387    use super::*;
388    use minicbor::encode::Encoder;
389
390    /// Helper: encode a result map to CBOR bytes.
391    fn encode_result(fields: &[(&str, EncodableValue)]) -> Vec<u8> {
392        let mut buf = Vec::new();
393        let mut enc = Encoder::new(&mut buf);
394        enc.map(fields.len() as u64).unwrap();
395        for (key, val) in fields {
396            enc.str(key).unwrap();
397            encode_test_value(&mut enc, val);
398        }
399        buf
400    }
401
402    #[allow(dead_code)]
403    enum EncodableValue {
404        Text(String),
405        Int(i64),
406        Float(f64),
407        Null,
408        Map(Vec<(String, EncodableValue)>),
409    }
410
411    fn encode_test_value(enc: &mut Encoder<&mut Vec<u8>>, val: &EncodableValue) {
412        match val {
413            EncodableValue::Text(s) => {
414                enc.str(s).unwrap();
415            }
416            EncodableValue::Int(n) => {
417                enc.i64(*n).unwrap();
418            }
419            EncodableValue::Float(f) => {
420                enc.f64(*f).unwrap();
421            }
422            EncodableValue::Null => {
423                enc.null().unwrap();
424            }
425            EncodableValue::Map(pairs) => {
426                enc.map(pairs.len() as u64).unwrap();
427                for (k, v) in pairs {
428                    enc.str(k).unwrap();
429                    encode_test_value(enc, v);
430                }
431            }
432        }
433    }
434
435    fn text(s: &str) -> EncodableValue {
436        EncodableValue::Text(s.to_string())
437    }
438    fn output_map() -> EncodableValue {
439        EncodableValue::Map(vec![
440            ("label".into(), text("positive")),
441            ("confidence".into(), EncodableValue::Float(0.95)),
442        ])
443    }
444    fn error_obj(class: &str) -> EncodableValue {
445        EncodableValue::Map(vec![
446            ("error_class".into(), text(class)),
447            ("message".into(), text("something went wrong")),
448        ])
449    }
450
451    // ── Valid variants ──────────────────────────────────────────────
452
453    #[test]
454    fn valid_success() {
455        let bytes = encode_result(&[("type", text("Success")), ("output", output_map())]);
456        let result = decode_result(&bytes).unwrap();
457        assert_eq!(result.result_type(), "Success");
458        assert!(result.output().is_some());
459        assert!(result.error().is_none());
460    }
461
462    #[test]
463    fn valid_low_confidence() {
464        let bytes = encode_result(&[
465            ("type", text("LowConfidence")),
466            ("output", output_map()),
467            ("error", error_obj("LOW_CONFIDENCE")),
468        ]);
469        let result = decode_result(&bytes).unwrap();
470        assert_eq!(result.result_type(), "LowConfidence");
471        assert!(result.output().is_some());
472        assert_eq!(result.error().unwrap().error_class, "LOW_CONFIDENCE");
473    }
474
475    #[test]
476    fn valid_failure() {
477        let bytes = encode_result(&[
478            ("type", text("Failure")),
479            ("error", error_obj("COMPUTATION_ERROR")),
480        ]);
481        let result = decode_result(&bytes).unwrap();
482        assert_eq!(result.result_type(), "Failure");
483        assert!(result.output().is_none());
484        assert_eq!(result.error().unwrap().error_class, "COMPUTATION_ERROR");
485    }
486
487    #[test]
488    fn valid_success_with_null_carry_state_next() {
489        let bytes = encode_result(&[
490            ("type", text("Success")),
491            ("output", output_map()),
492            ("carry_state_next", EncodableValue::Null),
493        ]);
494        let result = decode_result(&bytes).unwrap();
495        assert_eq!(result.result_type(), "Success");
496    }
497
498    // ── Invalid variants ────────────────────────────────────────────
499
500    #[test]
501    fn invalid_success_with_error() {
502        let bytes = encode_result(&[
503            ("type", text("Success")),
504            ("output", output_map()),
505            ("error", error_obj("LOW_CONFIDENCE")),
506        ]);
507        let err = decode_result(&bytes).unwrap_err();
508        assert!(err.to_string().contains("MUST NOT have 'error'"));
509    }
510
511    #[test]
512    fn invalid_failure_with_output() {
513        let bytes = encode_result(&[
514            ("type", text("Failure")),
515            ("error", error_obj("COMPUTATION_ERROR")),
516            ("output", output_map()),
517        ]);
518        let err = decode_result(&bytes).unwrap_err();
519        assert!(err.to_string().contains("MUST NOT have 'output'"));
520    }
521
522    #[test]
523    fn invalid_low_confidence_without_error() {
524        let bytes = encode_result(&[("type", text("LowConfidence")), ("output", output_map())]);
525        let err = decode_result(&bytes).unwrap_err();
526        assert!(err.to_string().contains("missing 'error'"));
527    }
528
529    #[test]
530    fn invalid_low_confidence_wrong_error_class() {
531        let bytes = encode_result(&[
532            ("type", text("LowConfidence")),
533            ("output", output_map()),
534            ("error", error_obj("COMPUTATION_ERROR")),
535        ]);
536        let err = decode_result(&bytes).unwrap_err();
537        assert!(err.to_string().contains("must be 'LOW_CONFIDENCE'"));
538    }
539
540    #[test]
541    fn invalid_missing_type() {
542        let bytes = encode_result(&[("output", output_map())]);
543        let err = decode_result(&bytes).unwrap_err();
544        assert!(err.to_string().contains("missing 'type'"));
545    }
546
547    #[test]
548    fn invalid_unknown_type() {
549        let bytes = encode_result(&[("type", text("Unknown")), ("output", output_map())]);
550        let err = decode_result(&bytes).unwrap_err();
551        assert!(err.to_string().contains("unknown result type"));
552    }
553
554    #[test]
555    fn invalid_error_missing_message() {
556        let error_no_msg =
557            EncodableValue::Map(vec![("error_class".into(), text("COMPUTATION_ERROR"))]);
558        let bytes = encode_result(&[("type", text("Failure")), ("error", error_no_msg)]);
559        let err = decode_result(&bytes).unwrap_err();
560        assert!(
561            err.chain()
562                .any(|c| c.to_string().contains("missing 'message'")),
563            "expected cause not found in error chain: {err:?}"
564        );
565    }
566
567    #[test]
568    fn invalid_failure_with_low_confidence_class() {
569        let bytes = encode_result(&[
570            ("type", text("Failure")),
571            ("error", error_obj("LOW_CONFIDENCE")),
572        ]);
573        let err = decode_result(&bytes).unwrap_err();
574        assert!(err.to_string().contains("MUST NOT be 'LOW_CONFIDENCE'"));
575    }
576
577    #[test]
578    fn invalid_duplicate_top_level_key() {
579        let mut buf = Vec::new();
580        let mut enc = Encoder::new(&mut buf);
581        enc.map(3).unwrap();
582        enc.str("type").unwrap();
583        enc.str("Success").unwrap();
584        enc.str("output").unwrap();
585        enc.str("hello").unwrap();
586        enc.str("type").unwrap();
587        enc.str("Failure").unwrap();
588        let err = decode_result(&buf).unwrap_err();
589        assert!(err.to_string().contains("duplicate top-level key"));
590    }
591
592    #[test]
593    fn invalid_failure_with_carry_state_next() {
594        let bytes = encode_result(&[
595            ("type", text("Failure")),
596            ("error", error_obj("COMPUTATION_ERROR")),
597            ("carry_state_next", EncodableValue::Null),
598        ]);
599        let err = decode_result(&bytes).unwrap_err();
600        assert!(err.to_string().contains("MUST NOT have 'carry_state_next'"));
601    }
602
603    #[test]
604    fn invalid_success_with_non_null_carry_state_next() {
605        let bytes = encode_result(&[
606            ("type", text("Success")),
607            ("output", output_map()),
608            ("carry_state_next", text("some_state")),
609        ]);
610        let err = decode_result(&bytes).unwrap_err();
611        assert!(err.to_string().contains("carry_state_next must be null"));
612    }
613
614    #[test]
615    fn invalid_error_duplicate_key() {
616        let mut buf = Vec::new();
617        let mut enc = Encoder::new(&mut buf);
618        enc.map(2).unwrap();
619        enc.str("type").unwrap();
620        enc.str("Failure").unwrap();
621        enc.str("error").unwrap();
622        enc.map(3).unwrap();
623        enc.str("error_class").unwrap();
624        enc.str("COMPUTATION_ERROR").unwrap();
625        enc.str("error_class").unwrap();
626        enc.str("LOW_CONFIDENCE").unwrap();
627        enc.str("message").unwrap();
628        enc.str("oops").unwrap();
629        let err = decode_result(&buf).unwrap_err();
630        assert!(
631            err.chain()
632                .any(|c| c.to_string().contains("duplicate key 'error_class'")),
633            "expected cause not found in error chain: {err:?}"
634        );
635    }
636}