1use blake3::Hasher;
2
3use crate::cert::CERT_VERSION;
4use crate::crypto::merkle_node;
5use crate::error::A1Error;
6use crate::registry::fresh_nonce;
7
/// Key-derivation context string used when hashing a single reasoning step
/// into its Merkle leaf (see `ReasoningStep::leaf_hash`).
const DOMAIN_PROVENANCE_LEAF: &str = "a1::provenance::leaf::v1";
/// Key-derivation context string used when computing a root's chain binding
/// (see `compute_chain_binding`).
const DOMAIN_PROVENANCE_ROOT: &str = "a1::provenance::root::v1";
/// Key-derivation context string used when hashing key/value metadata pairs
/// (see `hash_metadata`).
const DOMAIN_PROVENANCE_META: &str = "a1::provenance::meta::v1";
11
/// Category tag attached to every recorded reasoning step.
///
/// The enum is `#[repr(u8)]` with explicit discriminants `1..=8`; that byte is
/// what [`ReasoningStepKind::as_u8`] returns and what gets mixed into a step's
/// leaf hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ReasoningStepKind {
    Thought = 1,
    /// Kind assigned by `ReasoningTrace::record_tool_call`.
    ToolCall = 2,
    /// Kind assigned by `ReasoningTrace::record_observation`.
    Observation = 3,
    Decision = 4,
    PlanStep = 5,
    FinalAction = 6,
    Error = 7,
    Retrieval = 8,
}

impl ReasoningStepKind {
    /// Returns the variant's explicit `u8` discriminant.
    pub fn as_u8(self) -> u8 {
        self as u8
    }

    /// Returns the fixed `snake_case` label for this variant.
    ///
    /// The match is exhaustive on purpose: adding a variant without a label
    /// is a compile error.
    pub fn name(self) -> &'static str {
        match self {
            ReasoningStepKind::Thought => "thought",
            ReasoningStepKind::ToolCall => "tool_call",
            ReasoningStepKind::Observation => "observation",
            ReasoningStepKind::Decision => "decision",
            ReasoningStepKind::PlanStep => "plan_step",
            ReasoningStepKind::FinalAction => "final_action",
            ReasoningStepKind::Error => "error",
            ReasoningStepKind::Retrieval => "retrieval",
        }
    }
}

impl std::fmt::Display for ReasoningStepKind {
    /// Writes the same label as [`ReasoningStepKind::name`], verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = ReasoningStepKind::name(*self);
        f.write_str(label)
    }
}
68
69#[derive(Debug, Clone, PartialEq, Eq)]
97#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
98pub struct ReasoningStep {
99 pub index: u32,
101 pub kind: ReasoningStepKind,
103 #[cfg_attr(feature = "serde", serde(with = "hex_32"))]
105 pub content_hash: [u8; 32],
106 pub timestamp_unix: u64,
108 #[cfg_attr(feature = "serde", serde(with = "hex_32"))]
111 pub metadata_hash: [u8; 32],
112}
113
114impl ReasoningStep {
115 pub fn leaf_hash(&self) -> [u8; 32] {
121 let mut h = Hasher::new_derive_key(DOMAIN_PROVENANCE_LEAF);
122 h.update(&[CERT_VERSION]);
123 h.update(&self.index.to_le_bytes());
124 h.update(&[self.kind.as_u8()]);
125 h.update(&self.content_hash);
126 h.update(&self.timestamp_unix.to_be_bytes());
127 h.update(&self.metadata_hash);
128 h.finalize().into()
129 }
130}
131
132#[derive(Debug, Clone, PartialEq, Eq)]
152#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
153pub struct ProvenanceRoot {
154 pub step_count: u32,
156 #[cfg_attr(feature = "serde", serde(with = "hex_32"))]
158 pub merkle_root: [u8; 32],
159 #[cfg_attr(feature = "serde", serde(with = "hex_16"))]
164 pub trace_id: [u8; 16],
165 pub started_at_unix: u64,
167 pub finalized_at_unix: u64,
169 #[cfg_attr(feature = "serde", serde(with = "hex_32"))]
173 pub chain_binding: [u8; 32],
174}
175
176impl ProvenanceRoot {
177 pub fn verify_chain_binding(&self, chain_fingerprint: &[u8; 32]) -> bool {
182 let expected = compute_chain_binding(self, chain_fingerprint);
183 subtle::ConstantTimeEq::ct_eq(&expected[..], &self.chain_binding[..]).unwrap_u8() == 1
184 }
185
186 pub fn merkle_root_hex(&self) -> String {
188 hex::encode(self.merkle_root)
189 }
190
191 pub fn trace_id_hex(&self) -> String {
193 hex::encode(self.trace_id)
194 }
195}
196
197fn compute_chain_binding(root: &ProvenanceRoot, chain_fp: &[u8; 32]) -> [u8; 32] {
198 let mut h = Hasher::new_derive_key(DOMAIN_PROVENANCE_ROOT);
199 h.update(&root.step_count.to_le_bytes());
200 h.update(&root.merkle_root);
201 h.update(&root.trace_id);
202 h.update(&root.started_at_unix.to_be_bytes());
203 h.update(&root.finalized_at_unix.to_be_bytes());
204 h.update(chain_fp);
205 h.finalize().into()
206}
207
208#[derive(Debug, Clone)]
226#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
227pub struct ProvenanceStepProof {
228 pub step: ReasoningStep,
230 pub siblings: Vec<[u8; 32]>,
232 pub step_count: u32,
234}
235
236impl ProvenanceStepProof {
237 pub fn verify(&self, root: &ProvenanceRoot) -> bool {
239 if self.step.index >= self.step_count {
240 return false;
241 }
242 if self.step_count != root.step_count {
243 return false;
244 }
245
246 let leaf_count = next_power_of_two(self.step_count as usize);
247 let expected_depth = leaf_count.trailing_zeros() as usize;
248
249 if self.siblings.len() != expected_depth {
250 return false;
251 }
252
253 let mut current = self.step.leaf_hash();
254 let mut idx = self.step.index as usize;
255
256 for sibling in &self.siblings {
257 if idx.is_multiple_of(2) {
258 current = merkle_node(¤t, sibling);
259 } else {
260 current = merkle_node(sibling, ¤t);
261 }
262 idx >>= 1;
263 }
264
265 subtle::ConstantTimeEq::ct_eq(¤t[..], &root.merkle_root[..]).unwrap_u8() == 1
266 }
267}
268
269pub struct ReasoningTrace {
311 steps: Vec<ReasoningStep>,
312 trace_id: [u8; 16],
313 started_at_unix: u64,
314}
315
316impl ReasoningTrace {
317 pub fn new(started_at_unix: u64) -> Self {
319 Self {
320 steps: Vec::new(),
321 trace_id: fresh_nonce(),
322 started_at_unix,
323 }
324 }
325
326 pub fn record(
331 &mut self,
332 kind: ReasoningStepKind,
333 content: &[u8],
334 timestamp_unix: u64,
335 ) -> &ReasoningStep {
336 let content_hash = blake3::hash(content).into();
337 self.record_hashed(kind, content_hash, [0u8; 32], timestamp_unix)
338 }
339
340 pub fn record_tool_call(
342 &mut self,
343 tool_name: &str,
344 input: &[u8],
345 timestamp_unix: u64,
346 ) -> &ReasoningStep {
347 let content_hash = blake3::hash(input).into();
348 let metadata_hash = hash_metadata(&[("tool", tool_name)]);
349 self.record_hashed(
350 ReasoningStepKind::ToolCall,
351 content_hash,
352 metadata_hash,
353 timestamp_unix,
354 )
355 }
356
357 pub fn record_observation(
359 &mut self,
360 tool_name: &str,
361 output: &[u8],
362 timestamp_unix: u64,
363 ) -> &ReasoningStep {
364 let content_hash = blake3::hash(output).into();
365 let metadata_hash = hash_metadata(&[("tool", tool_name)]);
366 self.record_hashed(
367 ReasoningStepKind::Observation,
368 content_hash,
369 metadata_hash,
370 timestamp_unix,
371 )
372 }
373
374 pub fn record_hashed(
379 &mut self,
380 kind: ReasoningStepKind,
381 content_hash: [u8; 32],
382 metadata_hash: [u8; 32],
383 timestamp_unix: u64,
384 ) -> &ReasoningStep {
385 let index = self.steps.len() as u32;
386 self.steps.push(ReasoningStep {
387 index,
388 kind,
389 content_hash,
390 timestamp_unix,
391 metadata_hash,
392 });
393 self.steps.last().expect("just pushed")
394 }
395
396 pub fn len(&self) -> usize {
398 self.steps.len()
399 }
400
401 pub fn is_empty(&self) -> bool {
402 self.steps.is_empty()
403 }
404
405 pub fn finalize(
413 &self,
414 finalized_at_unix: u64,
415 chain_fingerprint: &[u8; 32],
416 ) -> Result<ProvenanceRoot, A1Error> {
417 if self.steps.is_empty() {
418 return Err(A1Error::EmptyTree);
419 }
420
421 let merkle_root = build_merkle_root(&self.steps);
422
423 let mut root = ProvenanceRoot {
424 step_count: self.steps.len() as u32,
425 merkle_root,
426 trace_id: self.trace_id,
427 started_at_unix: self.started_at_unix,
428 finalized_at_unix,
429 chain_binding: [0u8; 32],
430 };
431
432 root.chain_binding = compute_chain_binding(&root, chain_fingerprint);
433 Ok(root)
434 }
435
436 pub fn step_proof(&self, index: usize) -> Option<ProvenanceStepProof> {
440 if index >= self.steps.len() {
441 return None;
442 }
443
444 let leaf_count = next_power_of_two(self.steps.len());
445 let mut leaves: Vec<[u8; 32]> = self.steps.iter().map(|s| s.leaf_hash()).collect();
446 let last = *leaves.last().expect("non-empty");
447 leaves.resize(leaf_count, last);
448
449 let depth = leaf_count.trailing_zeros() as usize;
450 let mut siblings = Vec::with_capacity(depth);
451 let mut layer = leaves;
452 let mut idx = index;
453
454 for _ in 0..depth {
455 let sibling_idx = if idx.is_multiple_of(2) {
456 idx + 1
457 } else {
458 idx - 1
459 };
460 siblings.push(layer[sibling_idx]);
461 let next_len = layer.len() / 2;
462 let mut next = Vec::with_capacity(next_len);
463 for i in 0..next_len {
464 next.push(merkle_node(&layer[2 * i], &layer[2 * i + 1]));
465 }
466 layer = next;
467 idx >>= 1;
468 }
469
470 Some(ProvenanceStepProof {
471 step: self.steps[index].clone(),
472 siblings,
473 step_count: self.steps.len() as u32,
474 })
475 }
476
477 pub fn steps(&self) -> &[ReasoningStep] {
479 &self.steps
480 }
481}
482
483fn build_merkle_root(steps: &[ReasoningStep]) -> [u8; 32] {
486 assert!(!steps.is_empty());
487
488 let leaf_count = next_power_of_two(steps.len());
489 let mut layer: Vec<[u8; 32]> = steps.iter().map(|s| s.leaf_hash()).collect();
490 let last = *layer.last().expect("non-empty");
491 layer.resize(leaf_count, last);
492
493 while layer.len() > 1 {
494 let next_len = layer.len() / 2;
495 let mut next = Vec::with_capacity(next_len);
496 for i in 0..next_len {
497 next.push(merkle_node(&layer[2 * i], &layer[2 * i + 1]));
498 }
499 layer = next;
500 }
501
502 layer[0]
503}
504
/// Returns the smallest power of two that is greater than or equal to `n`,
/// with `0` and `1` both mapping to `1`.
///
/// Delegates to the standard library instead of a shift loop: a manual
/// `p <<= 1` loop wraps `p` to zero once `n` exceeds the largest power of two
/// representable in `usize`, and then never terminates.
///
/// # Panics
/// Panics when no power of two `>= n` fits in `usize`.
fn next_power_of_two(n: usize) -> usize {
    n.checked_next_power_of_two()
        .expect("next power of two overflows usize")
}
515
516fn hash_metadata(pairs: &[(&str, &str)]) -> [u8; 32] {
517 let mut h = Hasher::new_derive_key(DOMAIN_PROVENANCE_META);
518 h.update(&(pairs.len() as u32).to_le_bytes());
519 for (k, v) in pairs {
520 h.update(&(k.len() as u32).to_le_bytes());
521 h.update(k.as_bytes());
522 h.update(&(v.len() as u32).to_le_bytes());
523 h.update(v.as_bytes());
524 }
525 h.finalize().into()
526}
527
/// Serde adapter (for `#[serde(with = "hex_32")]`) that represents a
/// `[u8; 32]` as a hex string.
#[cfg(feature = "serde")]
mod hex_32 {
    use serde::{Deserialize, Deserializer, Serializer};

    /// Serializes the 32 bytes as one hex string.
    pub fn serialize<S>(v: &[u8; 32], s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let encoded = hex::encode(v);
        s.serialize_str(&encoded)
    }

    /// Deserializes a hex string into exactly 32 bytes.
    ///
    /// Fails with a custom serde error when the string is not valid hex or
    /// does not decode to 32 bytes.
    pub fn deserialize<'de, D>(d: D) -> Result<[u8; 32], D::Error>
    where
        D: Deserializer<'de>,
    {
        let text = String::deserialize(d)?;
        match hex::decode(text) {
            Err(bad_hex) => Err(serde::de::Error::custom(bad_hex)),
            Ok(bytes) => bytes
                .try_into()
                .map_err(|_| serde::de::Error::custom("expected 32-byte hex")),
        }
    }
}
544
/// Serde adapter (for `#[serde(with = "hex_16")]`) that represents a
/// `[u8; 16]` as a hex string.
#[cfg(feature = "serde")]
mod hex_16 {
    use serde::{Deserialize, Deserializer, Serializer};

    /// Serializes the 16 bytes as one hex string.
    pub fn serialize<S>(v: &[u8; 16], s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let encoded = hex::encode(v);
        s.serialize_str(&encoded)
    }

    /// Deserializes a hex string into exactly 16 bytes.
    ///
    /// Fails with a custom serde error when the string is not valid hex or
    /// does not decode to 16 bytes.
    pub fn deserialize<'de, D>(d: D) -> Result<[u8; 16], D::Error>
    where
        D: Deserializer<'de>,
    {
        let text = String::deserialize(d)?;
        match hex::decode(text) {
            Err(bad_hex) => Err(serde::de::Error::custom(bad_hex)),
            Ok(bytes) => bytes
                .try_into()
                .map_err(|_| serde::de::Error::custom("expected 16-byte hex")),
        }
    }
}
559
#[cfg(test)]
mod tests {
    use super::*;

    /// Start timestamp shared by every trace built in these tests.
    const TRACE_START: u64 = 1_700_000_000;
    /// Finalization timestamp used whenever a root is needed.
    const TRACE_END: u64 = 1_700_001_000;

    /// Fixed, recognisable chain fingerprint: `0xAB`, thirty zero bytes, `0xCD`.
    fn fake_chain_fp() -> [u8; 32] {
        let mut bytes = [0u8; 32];
        bytes[0] = 0xAB;
        bytes[31] = 0xCD;
        bytes
    }

    /// Builds a trace of `n` `Thought` steps with content `"step {i}"` and
    /// timestamps `TRACE_START + i`.
    fn build_trace(n: usize) -> ReasoningTrace {
        let mut trace = ReasoningTrace::new(TRACE_START);
        for (i, timestamp) in (0..n).zip(TRACE_START..) {
            let content = format!("step {i}");
            trace.record(ReasoningStepKind::Thought, content.as_bytes(), timestamp);
        }
        trace
    }

    /// Builds an `n`-step trace and finalizes it against `fake_chain_fp()`.
    fn finalized(n: usize) -> (ReasoningTrace, ProvenanceRoot) {
        let trace = build_trace(n);
        let root = trace.finalize(TRACE_END, &fake_chain_fp()).unwrap();
        (trace, root)
    }

    #[test]
    fn single_step_trace_finalizes() {
        let (_trace, root) = finalized(1);
        assert_eq!(root.step_count, 1);
        assert!(root.verify_chain_binding(&fake_chain_fp()));
    }

    #[test]
    fn chain_binding_fails_wrong_fp() {
        let (_trace, root) = finalized(3);
        let mut other_fp = fake_chain_fp();
        other_fp[0] ^= 0xFF;
        assert!(!root.verify_chain_binding(&other_fp));
    }

    #[test]
    fn empty_trace_returns_error() {
        let empty = ReasoningTrace::new(TRACE_START);
        let outcome = empty.finalize(TRACE_END, &fake_chain_fp());
        assert!(outcome.is_err());
    }

    #[test]
    fn merkle_proof_verifies_each_step() {
        // Sizes on both sides of several powers of two exercise the padding.
        for n in [1usize, 2, 3, 4, 5, 7, 8, 9, 15, 16] {
            let (trace, root) = finalized(n);

            for i in 0..n {
                let proof = trace.step_proof(i).expect("step exists");
                assert!(
                    proof.verify(&root),
                    "proof failed for step {i} in trace of {n}"
                );
            }
        }
    }

    #[test]
    fn step_proof_out_of_range_is_none() {
        let trace = build_trace(3);
        for beyond in [3usize, 100] {
            assert!(trace.step_proof(beyond).is_none());
        }
    }

    #[test]
    fn tampered_step_content_fails_proof() {
        let (trace, root) = finalized(4);
        let mut forged = trace.step_proof(2).unwrap();
        forged.step.content_hash[0] ^= 0x01;
        assert!(!forged.verify(&root));
    }

    #[test]
    fn reordered_step_index_fails_proof() {
        let (trace, root) = finalized(4);
        let mut forged = trace.step_proof(1).unwrap();
        forged.step.index = 3;
        assert!(!forged.verify(&root));
    }

    #[test]
    fn different_traces_produce_different_roots() {
        let (_first_trace, first) = finalized(3);
        let (_second_trace, second) = finalized(3);
        assert_ne!(first.trace_id, second.trace_id);
        assert_ne!(first.chain_binding, second.chain_binding);
    }

    #[test]
    fn tool_call_and_observation_record_metadata() {
        let mut trace = ReasoningTrace::new(TRACE_START);

        let call = trace.record_tool_call("search", b"AAPL price", 1_700_000_001);
        assert_eq!(call.kind, ReasoningStepKind::ToolCall);
        assert_ne!(call.metadata_hash, [0u8; 32]);

        let observation = trace.record_observation("search", b"182.50", 1_700_000_002);
        assert_eq!(observation.kind, ReasoningStepKind::Observation);
        assert_ne!(observation.metadata_hash, [0u8; 32]);
    }

    #[test]
    fn leaf_hash_is_index_sensitive() {
        let at_zero = ReasoningStep {
            index: 0,
            kind: ReasoningStepKind::Thought,
            content_hash: [1u8; 32],
            timestamp_unix: TRACE_START,
            metadata_hash: [0u8; 32],
        };
        // Identical in every field except the index.
        let at_one = ReasoningStep {
            index: 1,
            ..at_zero.clone()
        };
        assert_ne!(at_zero.leaf_hash(), at_one.leaf_hash());
    }
}