1use std::collections::HashMap;
2
3use anyhow::{bail, Context, Result};
4
/// Upper bound on the declared length of any CBOR array or map accepted by the
/// decoder, including the top-level result map itself.
const MAX_COLLECTION_LEN: u64 = 1_000 * 1_000;
8
9#[derive(Debug, Clone)]
11pub enum BrickResult {
12 Success {
13 output: CborValue,
14 },
15 LowConfidence {
16 output: CborValue,
17 error: ErrorObject,
18 },
19 Failure {
20 error: ErrorObject,
21 },
22}
23
/// Structured error attached to `LowConfidence` and `Failure` results.
///
/// Derives `PartialEq`/`Eq` (all fields are strings) so errors can be compared
/// directly, e.g. in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ErrorObject {
    /// Error category, e.g. `"LOW_CONFIDENCE"`.
    pub error_class: String,
    /// Free-form description of the error.
    pub message: String,
    /// Optional retry advice, if the result supplied one.
    // Populated but never read in this module; the allow silences dead_code.
    #[allow(dead_code)]
    pub retry_advice: Option<String>,
    /// Optional severity label, if the result supplied one.
    // Populated but never read in this module; the allow silences dead_code.
    #[allow(dead_code)]
    pub severity: Option<String>,
}
34
/// Owned representation of a decoded CBOR data item.
///
/// Only the subset produced by `decode_value` is represented: integers that
/// fit in `i64`, floats widened to `f64`, text and byte strings, and
/// definite-length arrays and maps. Map entries keep their wire order and keys
/// may be any variant.
///
/// Derives `PartialEq` so decoded values can be compared directly. It does not
/// derive `Eq` because `Float` holds an `f64`, and `NaN != NaN`.
#[derive(Debug, Clone, PartialEq)]
pub enum CborValue {
    Null,
    Bool(bool),
    Integer(i64),
    Float(f64),
    Text(String),
    Bytes(Vec<u8>),
    Array(Vec<CborValue>),
    Map(Vec<(CborValue, CborValue)>),
}
47
48impl BrickResult {
49 pub fn result_type(&self) -> &str {
51 match self {
52 Self::Success { .. } => "Success",
53 Self::LowConfidence { .. } => "LowConfidence",
54 Self::Failure { .. } => "Failure",
55 }
56 }
57
58 pub fn output(&self) -> Option<&CborValue> {
60 match self {
61 Self::Success { output } | Self::LowConfidence { output, .. } => Some(output),
62 Self::Failure { .. } => None,
63 }
64 }
65
66 pub fn error(&self) -> Option<&ErrorObject> {
68 match self {
69 Self::LowConfidence { error, .. } | Self::Failure { error } => Some(error),
70 Self::Success { .. } => None,
71 }
72 }
73}
74
75pub fn decode_result(cbor_bytes: &[u8]) -> Result<BrickResult> {
77 let mut decoder = minicbor::Decoder::new(cbor_bytes);
78
79 let map_len = match decoder.map() {
81 Ok(Some(len)) => len,
82 Ok(None) => bail!("result is an indefinite-length map (must be definite)"),
83 Err(e) => bail!("result is not a valid CBOR map: {e}"),
84 };
85
86 if map_len > MAX_COLLECTION_LEN {
87 bail!("result map has {map_len} entries (max {MAX_COLLECTION_LEN})");
88 }
89
90 let mut fields: HashMap<String, CborValue> = HashMap::new();
92 for _ in 0..map_len {
93 let key = decode_text(&mut decoder).context("result map key must be a text string")?;
94 let value = decode_value(&mut decoder).context("decoding result map value")?;
95 if fields.insert(key.clone(), value).is_some() {
96 bail!("duplicate top-level key in result map: '{key}'");
97 }
98 }
99
100 let type_val = fields
102 .get("type")
103 .ok_or_else(|| anyhow::anyhow!("result missing 'type' discriminant field"))?;
104 let type_str = match type_val {
105 CborValue::Text(s) => s.as_str(),
106 _ => bail!("result 'type' field must be a text string"),
107 };
108
109 match type_str {
110 "Success" => validate_success(&fields),
111 "LowConfidence" => validate_low_confidence(&fields),
112 "Failure" => validate_failure(&fields),
113 other => {
114 bail!("unknown result type '{other}' (expected Success, LowConfidence, or Failure)")
115 }
116 }
117}
118
119fn validate_success(fields: &HashMap<String, CborValue>) -> Result<BrickResult> {
120 let output = fields
122 .get("output")
123 .ok_or_else(|| anyhow::anyhow!("Success result missing 'output' field"))?
124 .clone();
125
126 if fields.contains_key("error") {
128 bail!("Success result MUST NOT have 'error' field");
129 }
130
131 if fields.contains_key("carry_state_side_effects") {
133 bail!("Success result MUST NOT have 'carry_state_side_effects' field");
134 }
135
136 if let Some(v) = fields.get("carry_state_next") {
138 if !matches!(v, CborValue::Null) {
139 bail!("carry_state_next must be null/absent in Phase 2 (carry_state_class=none)");
140 }
141 }
142
143 Ok(BrickResult::Success { output })
144}
145
146fn validate_low_confidence(fields: &HashMap<String, CborValue>) -> Result<BrickResult> {
147 let output = fields
149 .get("output")
150 .ok_or_else(|| anyhow::anyhow!("LowConfidence result missing 'output' field"))?
151 .clone();
152
153 let error_val = fields
155 .get("error")
156 .ok_or_else(|| anyhow::anyhow!("LowConfidence result missing 'error' field"))?;
157 let error = parse_error_object(error_val).context("parsing LowConfidence error object")?;
158
159 if error.error_class != "LOW_CONFIDENCE" {
161 bail!(
162 "LowConfidence result error.error_class must be 'LOW_CONFIDENCE', got '{}'",
163 error.error_class
164 );
165 }
166
167 if fields.contains_key("carry_state_side_effects") {
169 bail!("LowConfidence result MUST NOT have 'carry_state_side_effects' field");
170 }
171
172 if let Some(v) = fields.get("carry_state_next") {
174 if !matches!(v, CborValue::Null) {
175 bail!("carry_state_next must be null/absent in Phase 2 (carry_state_class=none)");
176 }
177 }
178
179 Ok(BrickResult::LowConfidence { output, error })
180}
181
182fn validate_failure(fields: &HashMap<String, CborValue>) -> Result<BrickResult> {
183 let error_val = fields
185 .get("error")
186 .ok_or_else(|| anyhow::anyhow!("Failure result missing 'error' field"))?;
187 let error = parse_error_object(error_val).context("parsing Failure error object")?;
188
189 if error.error_class == "LOW_CONFIDENCE" {
191 bail!("Failure result error.error_class MUST NOT be 'LOW_CONFIDENCE'");
192 }
193
194 if fields.contains_key("output") {
196 bail!("Failure result MUST NOT have 'output' field");
197 }
198
199 if fields.contains_key("carry_state_next") {
201 bail!("Failure result MUST NOT have 'carry_state_next' field");
202 }
203
204 Ok(BrickResult::Failure { error })
207}
208
209fn parse_error_object(val: &CborValue) -> Result<ErrorObject> {
211 let map = match val {
212 CborValue::Map(pairs) => pairs,
213 _ => bail!("error field must be a CBOR map"),
214 };
215
216 let mut error_class: Option<String> = None;
217 let mut message: Option<String> = None;
218 let mut retry_advice: Option<String> = None;
219 let mut severity: Option<String> = None;
220
221 for (k, v) in map {
222 let key = match k {
223 CborValue::Text(s) => s.as_str(),
224 _ => bail!("error map key must be a text string"),
225 };
226 match key {
227 "error_class" => {
228 if error_class.is_some() {
229 bail!("duplicate key 'error_class' in error object");
230 }
231 error_class = Some(extract_text(v).context("error.error_class must be text")?);
232 }
233 "message" => {
234 if message.is_some() {
235 bail!("duplicate key 'message' in error object");
236 }
237 message = Some(extract_text(v).context("error.message must be text")?);
238 }
239 "retry_advice" => {
240 if retry_advice.is_some() {
241 bail!("duplicate key 'retry_advice' in error object");
242 }
243 retry_advice = Some(extract_text(v).context("error.retry_advice must be text")?);
244 }
245 "severity" => {
246 if severity.is_some() {
247 bail!("duplicate key 'severity' in error object");
248 }
249 severity = Some(extract_text(v).context("error.severity must be text")?);
250 }
251 _ => {} }
253 }
254
255 let error_class =
256 error_class.ok_or_else(|| anyhow::anyhow!("error object missing 'error_class' field"))?;
257 let message = message.ok_or_else(|| anyhow::anyhow!("error object missing 'message' field"))?;
258
259 Ok(ErrorObject {
260 error_class,
261 message,
262 retry_advice,
263 severity,
264 })
265}
266
267fn extract_text(val: &CborValue) -> Result<String> {
268 match val {
269 CborValue::Text(s) => Ok(s.clone()),
270 _ => bail!("expected text string"),
271 }
272}
273
274fn decode_text(d: &mut minicbor::Decoder<'_>) -> Result<String> {
277 d.str()
278 .map(|s| s.to_string())
279 .map_err(|e| anyhow::anyhow!("expected CBOR text string: {e}"))
280}
281
282fn decode_value(d: &mut minicbor::Decoder<'_>) -> Result<CborValue> {
283 use minicbor::data::Type;
284
285 match d
286 .datatype()
287 .map_err(|e| anyhow::anyhow!("cannot peek CBOR type: {e}"))?
288 {
289 Type::Null => {
290 d.null()
291 .map_err(|e| anyhow::anyhow!("decoding null: {e}"))?;
292 Ok(CborValue::Null)
293 }
294 Type::Undefined => {
295 d.undefined()
296 .map_err(|e| anyhow::anyhow!("consuming undefined: {e}"))?;
297 bail!("CBOR undefined is not allowed in NCP results");
298 }
299 Type::Bool => {
300 let b = d
301 .bool()
302 .map_err(|e| anyhow::anyhow!("decoding bool: {e}"))?;
303 Ok(CborValue::Bool(b))
304 }
305 Type::U8 | Type::U16 | Type::U32 | Type::U64 => {
306 let n = d.u64().map_err(|e| anyhow::anyhow!("decoding uint: {e}"))?;
307 if n > i64::MAX as u64 {
308 bail!("CBOR uint too large for i64: {n}");
309 }
310 Ok(CborValue::Integer(n as i64))
311 }
312 Type::I8 | Type::I16 | Type::I32 | Type::I64 => {
313 let n = d.i64().map_err(|e| anyhow::anyhow!("decoding int: {e}"))?;
314 Ok(CborValue::Integer(n))
315 }
316 Type::F16 | Type::F32 | Type::F64 => {
317 let f = d
318 .f64()
319 .map_err(|e| anyhow::anyhow!("decoding float: {e}"))?;
320 Ok(CborValue::Float(f))
321 }
322 Type::String => {
323 let s = decode_text(d)?;
324 Ok(CborValue::Text(s))
325 }
326 Type::Bytes => {
327 let b = d
328 .bytes()
329 .map_err(|e| anyhow::anyhow!("decoding bytes: {e}"))?
330 .to_vec();
331 Ok(CborValue::Bytes(b))
332 }
333 Type::Array => {
334 let len = d
335 .array()
336 .map_err(|e| anyhow::anyhow!("decoding array: {e}"))?
337 .ok_or_else(|| anyhow::anyhow!("indefinite-length arrays not supported"))?;
338 if len > MAX_COLLECTION_LEN {
339 bail!("CBOR array has {len} elements (max {MAX_COLLECTION_LEN})");
340 }
341 let mut items = Vec::with_capacity(len as usize);
342 for _ in 0..len {
343 items.push(decode_value(d)?);
344 }
345 Ok(CborValue::Array(items))
346 }
347 Type::Map => {
348 let len = d
349 .map()
350 .map_err(|e| anyhow::anyhow!("decoding map: {e}"))?
351 .ok_or_else(|| anyhow::anyhow!("indefinite-length maps not supported"))?;
352 if len > MAX_COLLECTION_LEN {
353 bail!("CBOR map has {len} entries (max {MAX_COLLECTION_LEN})");
354 }
355 let mut pairs = Vec::with_capacity(len as usize);
356 for _ in 0..len {
357 let k = decode_value(d)?;
358 let v = decode_value(d)?;
359 pairs.push((k, v));
360 }
361 Ok(CborValue::Map(pairs))
362 }
363 Type::Tag => {
364 let tag = d.tag().map_err(|e| anyhow::anyhow!("decoding tag: {e}"))?;
365 bail!("CBOR tags are not supported in Phase 2 results (tag={tag:?})");
366 }
367 other => bail!("unsupported CBOR type: {other:?}"),
368 }
369}
370
371pub fn trap_failure(error_class: &str, message: String) -> BrickResult {
375 BrickResult::Failure {
376 error: ErrorObject {
377 error_class: error_class.to_string(),
378 message,
379 retry_advice: None,
380 severity: None,
381 },
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use minicbor::encode::Encoder;

    /// Encodes `fields` as a definite-length CBOR map with text keys, in order.
    fn encode_result(fields: &[(&str, EncodableValue)]) -> Vec<u8> {
        let mut buf = Vec::new();
        let mut enc = Encoder::new(&mut buf);
        enc.map(fields.len() as u64).unwrap();
        for (key, val) in fields {
            enc.str(key).unwrap();
            encode_test_value(&mut enc, val);
        }
        buf
    }

    /// Minimal value model used to build test inputs.
    enum EncodableValue {
        Text(String),
        Int(i64),
        Float(f64),
        Null,
        Map(Vec<(String, EncodableValue)>),
    }

    /// Appends `val` to `enc`; maps are encoded with definite length and text keys.
    fn encode_test_value(enc: &mut Encoder<&mut Vec<u8>>, val: &EncodableValue) {
        match val {
            EncodableValue::Text(s) => {
                enc.str(s).unwrap();
            }
            EncodableValue::Int(n) => {
                enc.i64(*n).unwrap();
            }
            EncodableValue::Float(f) => {
                enc.f64(*f).unwrap();
            }
            EncodableValue::Null => {
                enc.null().unwrap();
            }
            EncodableValue::Map(pairs) => {
                enc.map(pairs.len() as u64).unwrap();
                for (k, v) in pairs {
                    enc.str(k).unwrap();
                    encode_test_value(enc, v);
                }
            }
        }
    }

    fn text(s: &str) -> EncodableValue {
        EncodableValue::Text(s.to_string())
    }

    fn output_map() -> EncodableValue {
        EncodableValue::Map(vec![
            ("label".into(), text("positive")),
            ("confidence".into(), EncodableValue::Float(0.95)),
        ])
    }

    fn error_obj(class: &str) -> EncodableValue {
        EncodableValue::Map(vec![
            ("error_class".into(), text(class)),
            ("message".into(), text("something went wrong")),
        ])
    }

    #[test]
    fn valid_success() {
        let bytes = encode_result(&[("type", text("Success")), ("output", output_map())]);
        let result = decode_result(&bytes).unwrap();
        assert_eq!(result.result_type(), "Success");
        assert!(result.output().is_some());
        assert!(result.error().is_none());
    }

    #[test]
    fn valid_low_confidence() {
        let bytes = encode_result(&[
            ("type", text("LowConfidence")),
            ("output", output_map()),
            ("error", error_obj("LOW_CONFIDENCE")),
        ]);
        let result = decode_result(&bytes).unwrap();
        assert_eq!(result.result_type(), "LowConfidence");
        assert!(result.output().is_some());
        assert_eq!(result.error().unwrap().error_class, "LOW_CONFIDENCE");
    }

    #[test]
    fn valid_failure() {
        let bytes = encode_result(&[
            ("type", text("Failure")),
            ("error", error_obj("COMPUTATION_ERROR")),
        ]);
        let result = decode_result(&bytes).unwrap();
        assert_eq!(result.result_type(), "Failure");
        assert!(result.output().is_none());
        assert_eq!(result.error().unwrap().error_class, "COMPUTATION_ERROR");
    }

    #[test]
    fn valid_success_with_null_carry_state_next() {
        let bytes = encode_result(&[
            ("type", text("Success")),
            ("output", output_map()),
            ("carry_state_next", EncodableValue::Null),
        ]);
        let result = decode_result(&bytes).unwrap();
        assert_eq!(result.result_type(), "Success");
    }

    #[test]
    fn valid_error_object_optional_fields_and_unknown_keys() {
        let error = EncodableValue::Map(vec![
            ("error_class".into(), text("COMPUTATION_ERROR")),
            ("message".into(), text("boom")),
            ("retry_advice".into(), text("retry later")),
            ("severity".into(), text("high")),
            ("details".into(), EncodableValue::Int(7)),
        ]);
        let bytes = encode_result(&[("type", text("Failure")), ("error", error)]);
        let result = decode_result(&bytes).unwrap();
        let err = result.error().unwrap();
        assert_eq!(err.message, "boom");
        assert_eq!(err.retry_advice.as_deref(), Some("retry later"));
        assert_eq!(err.severity.as_deref(), Some("high"));
    }

    #[test]
    fn invalid_success_with_error() {
        let bytes = encode_result(&[
            ("type", text("Success")),
            ("output", output_map()),
            ("error", error_obj("LOW_CONFIDENCE")),
        ]);
        let err = decode_result(&bytes).unwrap_err();
        assert!(err.to_string().contains("MUST NOT have 'error'"));
    }

    #[test]
    fn invalid_success_missing_output() {
        let bytes = encode_result(&[("type", text("Success"))]);
        let err = decode_result(&bytes).unwrap_err();
        assert!(err.to_string().contains("missing 'output'"));
    }

    #[test]
    fn invalid_failure_with_output() {
        let bytes = encode_result(&[
            ("type", text("Failure")),
            ("error", error_obj("COMPUTATION_ERROR")),
            ("output", output_map()),
        ]);
        let err = decode_result(&bytes).unwrap_err();
        assert!(err.to_string().contains("MUST NOT have 'output'"));
    }

    #[test]
    fn invalid_low_confidence_without_error() {
        let bytes = encode_result(&[("type", text("LowConfidence")), ("output", output_map())]);
        let err = decode_result(&bytes).unwrap_err();
        assert!(err.to_string().contains("missing 'error'"));
    }

    #[test]
    fn invalid_low_confidence_wrong_error_class() {
        let bytes = encode_result(&[
            ("type", text("LowConfidence")),
            ("output", output_map()),
            ("error", error_obj("COMPUTATION_ERROR")),
        ]);
        let err = decode_result(&bytes).unwrap_err();
        assert!(err.to_string().contains("must be 'LOW_CONFIDENCE'"));
    }

    #[test]
    fn invalid_low_confidence_with_side_effects() {
        let bytes = encode_result(&[
            ("type", text("LowConfidence")),
            ("output", output_map()),
            ("error", error_obj("LOW_CONFIDENCE")),
            ("carry_state_side_effects", EncodableValue::Null),
        ]);
        let err = decode_result(&bytes).unwrap_err();
        assert!(err
            .to_string()
            .contains("MUST NOT have 'carry_state_side_effects'"));
    }

    #[test]
    fn invalid_missing_type() {
        let bytes = encode_result(&[("output", output_map())]);
        let err = decode_result(&bytes).unwrap_err();
        assert!(err.to_string().contains("missing 'type'"));
    }

    #[test]
    fn invalid_type_not_text() {
        let bytes = encode_result(&[("type", EncodableValue::Int(1)), ("output", output_map())]);
        let err = decode_result(&bytes).unwrap_err();
        assert!(err.to_string().contains("'type' field must be a text string"));
    }

    #[test]
    fn invalid_unknown_type() {
        let bytes = encode_result(&[("type", text("Unknown")), ("output", output_map())]);
        let err = decode_result(&bytes).unwrap_err();
        assert!(err.to_string().contains("unknown result type"));
    }

    #[test]
    fn invalid_top_level_not_a_map() {
        let mut buf = Vec::new();
        let mut enc = Encoder::new(&mut buf);
        enc.str("Success").unwrap();
        let err = decode_result(&buf).unwrap_err();
        assert!(err.to_string().contains("not a valid CBOR map"));
    }

    #[test]
    fn invalid_indefinite_top_level_map() {
        let mut buf = Vec::new();
        let mut enc = Encoder::new(&mut buf);
        enc.begin_map().unwrap();
        enc.str("type").unwrap();
        enc.str("Success").unwrap();
        enc.end().unwrap();
        let err = decode_result(&buf).unwrap_err();
        assert!(err.to_string().contains("indefinite-length map"));
    }

    #[test]
    fn invalid_top_level_non_text_key() {
        let mut buf = Vec::new();
        let mut enc = Encoder::new(&mut buf);
        enc.map(1).unwrap();
        enc.u8(1).unwrap();
        enc.str("Success").unwrap();
        let err = decode_result(&buf).unwrap_err();
        assert!(err.to_string().contains("result map key must be a text string"));
    }

    #[test]
    fn invalid_uint_too_large_for_i64() {
        let mut buf = Vec::new();
        let mut enc = Encoder::new(&mut buf);
        enc.map(2).unwrap();
        enc.str("type").unwrap();
        enc.str("Success").unwrap();
        enc.str("output").unwrap();
        enc.u64(u64::MAX).unwrap();
        let err = decode_result(&buf).unwrap_err();
        assert!(
            err.chain()
                .any(|c| c.to_string().contains("too large for i64")),
            "expected cause not found in error chain: {err:?}"
        );
    }

    #[test]
    fn invalid_undefined_value() {
        let mut buf = Vec::new();
        let mut enc = Encoder::new(&mut buf);
        enc.map(2).unwrap();
        enc.str("type").unwrap();
        enc.str("Success").unwrap();
        enc.str("output").unwrap();
        enc.undefined().unwrap();
        let err = decode_result(&buf).unwrap_err();
        assert!(
            err.chain()
                .any(|c| c.to_string().contains("undefined is not allowed")),
            "expected cause not found in error chain: {err:?}"
        );
    }

    #[test]
    fn invalid_indefinite_nested_array() {
        let mut buf = Vec::new();
        let mut enc = Encoder::new(&mut buf);
        enc.map(2).unwrap();
        enc.str("type").unwrap();
        enc.str("Success").unwrap();
        enc.str("output").unwrap();
        enc.begin_array().unwrap();
        enc.u8(1).unwrap();
        enc.end().unwrap();
        assert!(decode_result(&buf).is_err());
    }

    #[test]
    fn invalid_array_too_long() {
        let mut buf = Vec::new();
        let mut enc = Encoder::new(&mut buf);
        enc.map(2).unwrap();
        enc.str("type").unwrap();
        enc.str("Success").unwrap();
        enc.str("output").unwrap();
        enc.array(MAX_COLLECTION_LEN + 1).unwrap();
        let err = decode_result(&buf).unwrap_err();
        assert!(
            err.chain().any(|c| c.to_string().contains("elements (max")),
            "expected cause not found in error chain: {err:?}"
        );
    }

    #[test]
    fn invalid_error_missing_message() {
        let error_no_msg =
            EncodableValue::Map(vec![("error_class".into(), text("COMPUTATION_ERROR"))]);
        let bytes = encode_result(&[("type", text("Failure")), ("error", error_no_msg)]);
        let err = decode_result(&bytes).unwrap_err();
        assert!(
            err.chain()
                .any(|c| c.to_string().contains("missing 'message'")),
            "expected cause not found in error chain: {err:?}"
        );
    }

    #[test]
    fn invalid_error_non_text_key() {
        let mut buf = Vec::new();
        let mut enc = Encoder::new(&mut buf);
        enc.map(2).unwrap();
        enc.str("type").unwrap();
        enc.str("Failure").unwrap();
        enc.str("error").unwrap();
        enc.map(1).unwrap();
        enc.u8(1).unwrap();
        enc.str("oops").unwrap();
        let err = decode_result(&buf).unwrap_err();
        assert!(
            err.chain()
                .any(|c| c.to_string().contains("error map key must be a text string")),
            "expected cause not found in error chain: {err:?}"
        );
    }

    #[test]
    fn invalid_failure_with_low_confidence_class() {
        let bytes = encode_result(&[
            ("type", text("Failure")),
            ("error", error_obj("LOW_CONFIDENCE")),
        ]);
        let err = decode_result(&bytes).unwrap_err();
        assert!(err.to_string().contains("MUST NOT be 'LOW_CONFIDENCE'"));
    }

    #[test]
    fn invalid_duplicate_top_level_key() {
        let mut buf = Vec::new();
        let mut enc = Encoder::new(&mut buf);
        enc.map(3).unwrap();
        enc.str("type").unwrap();
        enc.str("Success").unwrap();
        enc.str("output").unwrap();
        enc.str("hello").unwrap();
        enc.str("type").unwrap();
        enc.str("Failure").unwrap();
        let err = decode_result(&buf).unwrap_err();
        assert!(err.to_string().contains("duplicate top-level key"));
    }

    #[test]
    fn invalid_failure_with_carry_state_next() {
        let bytes = encode_result(&[
            ("type", text("Failure")),
            ("error", error_obj("COMPUTATION_ERROR")),
            ("carry_state_next", EncodableValue::Null),
        ]);
        let err = decode_result(&bytes).unwrap_err();
        assert!(err.to_string().contains("MUST NOT have 'carry_state_next'"));
    }

    #[test]
    fn invalid_success_with_non_null_carry_state_next() {
        let bytes = encode_result(&[
            ("type", text("Success")),
            ("output", output_map()),
            ("carry_state_next", text("some_state")),
        ]);
        let err = decode_result(&bytes).unwrap_err();
        assert!(err.to_string().contains("carry_state_next must be null"));
    }

    #[test]
    fn invalid_error_duplicate_key() {
        let mut buf = Vec::new();
        let mut enc = Encoder::new(&mut buf);
        enc.map(2).unwrap();
        enc.str("type").unwrap();
        enc.str("Failure").unwrap();
        enc.str("error").unwrap();
        enc.map(3).unwrap();
        enc.str("error_class").unwrap();
        enc.str("COMPUTATION_ERROR").unwrap();
        enc.str("error_class").unwrap();
        enc.str("LOW_CONFIDENCE").unwrap();
        enc.str("message").unwrap();
        enc.str("oops").unwrap();
        let err = decode_result(&buf).unwrap_err();
        assert!(
            err.chain()
                .any(|c| c.to_string().contains("duplicate key 'error_class'")),
            "expected cause not found in error chain: {err:?}"
        );
    }
}