use crate::net::atp::protocol::frames::Frame;
use sha2::{Digest, Sha256};
use std::collections::BTreeMap;
/// Incrementally hashes an ordered sequence of protocol [`Frame`]s into one SHA-256 digest.
///
/// Every frame is absorbed together with its 0-based position, so both the frame contents
/// and the order in which frames are added affect the result.
#[derive(Clone, Debug)]
pub struct TranscriptHasher {
    /// Running SHA-256 state, seeded with the transcript label by [`TranscriptHasher::new`].
    hasher: Sha256,
    /// Number of frames absorbed so far (also the index given to the next frame).
    frame_count: u64,
}

impl TranscriptHasher {
    /// Fixed label (including its trailing NUL byte) absorbed before any frame data.
    const LABEL: &'static [u8] = b"ATP-TRANSCRIPT-V1\x00";

    /// Creates a hasher with no frames absorbed; only the transcript label has been hashed.
    pub fn new() -> Self {
        let mut state = Sha256::new();
        state.update(Self::LABEL);
        TranscriptHasher {
            hasher: state,
            frame_count: 0,
        }
    }

    /// Absorbs one frame into the transcript.
    ///
    /// Fields are hashed in this fixed order, integers little-endian: the frame index
    /// (`u64`), the protocol version, the frame type (`u16`), the header's payload length,
    /// the extension count (`u32`), then each extension in ascending id order as id
    /// (`u16`), data length (`u32`) and data bytes, and finally the SHA-256 digest of the
    /// payload.
    pub fn update_frame(&mut self, frame: &Frame) {
        let header = &frame.header;
        let index = self.frame_count;
        self.frame_count += 1;

        // Position goes first so that reordering frames changes the digest.
        self.hasher.update(index.to_le_bytes());
        self.hasher.update(header.version.0.to_le_bytes());
        self.hasher.update((header.frame_type as u16).to_le_bytes());
        self.hasher.update(header.payload_length.value().to_le_bytes());

        // Ascending id order makes the digest independent of the extension container's
        // iteration order.
        let by_id: BTreeMap<u16, &Vec<u8>> = header
            .extensions
            .iter()
            .map(|(id, data)| (*id, data))
            .collect();

        // NOTE(review): the `as u32` casts silently truncate counts/lengths >= 2^32.
        self.hasher.update((by_id.len() as u32).to_le_bytes());
        for (id, data) in by_id {
            self.hasher.update(id.to_le_bytes());
            self.hasher.update((data.len() as u32).to_le_bytes());
            self.hasher.update(data.as_slice());
        }

        // The payload contributes its own SHA-256 digest rather than its raw bytes.
        self.hasher.update(Sha256::digest(&frame.payload));
    }

    /// Returns the digest of everything absorbed so far without consuming the hasher.
    ///
    /// A clone of the running state is finalized, so more frames can still be added.
    pub fn current_hash(&self) -> TranscriptHash {
        let snapshot = self.hasher.clone();
        TranscriptHash(snapshot.finalize().into())
    }

    /// Consumes the hasher and returns the final transcript digest.
    pub fn finalize(self) -> TranscriptHash {
        let digest = self.hasher.finalize();
        TranscriptHash(digest.into())
    }

    /// Number of frames absorbed so far.
    pub fn frame_count(&self) -> u64 {
        self.frame_count
    }

    /// Captures the current digest together with the number of frames it covers.
    pub fn checkpoint(&self) -> TranscriptCheckpoint {
        let frame_count = self.frame_count();
        let hash = self.current_hash();
        TranscriptCheckpoint { hash, frame_count }
    }
}

impl Default for TranscriptHasher {
    /// Equivalent to [`TranscriptHasher::new`].
    fn default() -> Self {
        TranscriptHasher::new()
    }
}
/// A 32-byte SHA-256 transcript digest.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TranscriptHash(pub [u8; 32]);

impl TranscriptHash {
    /// Encodes the digest as lowercase hexadecimal (64 characters).
    pub fn to_hex(&self) -> String {
        hex::encode(self.as_bytes())
    }

    /// Parses a digest from its hexadecimal encoding.
    ///
    /// # Errors
    ///
    /// Returns [`TranscriptError::InvalidHash`] if `hex` is not valid hex, or if it does
    /// not decode to exactly 32 bytes.
    pub fn from_hex(hex: &str) -> Result<Self, TranscriptError> {
        let decoded = match hex::decode(hex) {
            Ok(bytes) => bytes,
            Err(_) => {
                return Err(TranscriptError::InvalidHash(
                    "invalid hex encoding".to_string(),
                ))
            }
        };
        let mut digest = [0u8; 32];
        if decoded.len() == digest.len() {
            digest.copy_from_slice(&decoded);
            Ok(TranscriptHash(digest))
        } else {
            Err(TranscriptError::InvalidHash(format!(
                "expected 32 bytes, got {}",
                decoded.len()
            )))
        }
    }

    /// Borrows the raw digest bytes.
    pub fn as_bytes(&self) -> &[u8; 32] {
        &self.0
    }
}

impl std::fmt::Display for TranscriptHash {
    /// Formats the digest as its hex encoding (same text as [`TranscriptHash::to_hex`]).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.to_hex())
    }
}
/// Snapshot of a transcript: its digest after exactly `frame_count` frames.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TranscriptCheckpoint {
    /// Transcript digest at the moment the snapshot was taken.
    pub hash: TranscriptHash,
    /// Number of frames the digest covers.
    pub frame_count: u64,
}
/// Transcript of a single session: a [`TranscriptHasher`] plus every checkpoint taken so far.
#[derive(Clone, Debug)]
pub struct SessionTranscript {
    /// Running hash over every frame passed to [`SessionTranscript::add_frame`].
    hasher: TranscriptHasher,
    /// Checkpoints recorded by [`SessionTranscript::checkpoint`], oldest first.
    checkpoints: Vec<TranscriptCheckpoint>,
}

impl SessionTranscript {
    /// Creates an empty transcript with no frames and no recorded checkpoints.
    pub fn new() -> Self {
        SessionTranscript {
            checkpoints: Vec::new(),
            hasher: TranscriptHasher::new(),
        }
    }

    /// Absorbs `frame` into the running transcript hash.
    pub fn add_frame(&mut self, frame: &Frame) {
        self.hasher.update_frame(frame)
    }

    /// Snapshots the current digest and frame count, records the snapshot, and returns it.
    pub fn checkpoint(&mut self) -> TranscriptCheckpoint {
        let recorded = self.hasher.checkpoint();
        self.checkpoints.push(recorded.clone());
        recorded
    }

    /// Returns `true` if the digest of the frames added so far equals `expected`.
    pub fn verify_hash(&self, expected: &TranscriptHash) -> bool {
        *expected == self.current_hash()
    }

    /// Every checkpoint recorded so far, oldest first.
    pub fn checkpoints(&self) -> &[TranscriptCheckpoint] {
        self.checkpoints.as_slice()
    }

    /// Digest of the frames added so far; the transcript stays usable.
    pub fn current_hash(&self) -> TranscriptHash {
        self.hasher.current_hash()
    }

    /// Number of frames added so far.
    pub fn frame_count(&self) -> u64 {
        self.hasher.frame_count()
    }

    /// Consumes the transcript and returns the final digest; recorded checkpoints are dropped.
    pub fn finalize(self) -> TranscriptHash {
        let SessionTranscript { hasher, .. } = self;
        hasher.finalize()
    }
}

impl Default for SessionTranscript {
    /// Equivalent to [`SessionTranscript::new`].
    fn default() -> Self {
        SessionTranscript::new()
    }
}
#[derive(Debug, thiserror::Error)]
pub enum TranscriptError {
#[error("invalid transcript hash: {0}")]
InvalidHash(String),
#[error("transcript verification failed: expected {expected}, got {actual}")]
VerificationFailed {
expected: TranscriptHash,
actual: TranscriptHash,
},
#[error("frame sequence error: expected frame {expected}, got {actual}")]
SequenceError {
expected: u64,
actual: u64,
},
}
/// Checks a stream of [`TranscriptCheckpoint`]s against an expected list, in order.
#[derive(Debug, Clone)]
pub struct TranscriptVerifier {
    /// Checkpoints that must be presented, in this order.
    expected_checkpoints: Vec<TranscriptCheckpoint>,
    /// Index of the next expected checkpoint; equals the number verified so far.
    current_index: usize,
}

impl TranscriptVerifier {
    /// Creates a verifier that expects `expected_checkpoints` in the given order.
    pub fn new(expected_checkpoints: Vec<TranscriptCheckpoint>) -> Self {
        Self {
            expected_checkpoints,
            current_index: 0,
        }
    }

    /// Verifies `checkpoint` against the next expected checkpoint and advances on success.
    ///
    /// On any error the verifier does not advance, so the next call is compared against
    /// the same expected checkpoint.
    ///
    /// # Errors
    ///
    /// * [`TranscriptError::SequenceError`] if every expected checkpoint was already
    ///   verified. `expected` is the frame count of the last expected checkpoint (0 if
    ///   the list is empty), and `actual` is `checkpoint.frame_count`.
    /// * [`TranscriptError::SequenceError`] if `checkpoint.frame_count` differs from the
    ///   expected checkpoint's frame count.
    /// * [`TranscriptError::VerificationFailed`] if the frame counts match but the hashes
    ///   differ.
    pub fn verify_checkpoint(
        &mut self,
        checkpoint: &TranscriptCheckpoint,
    ) -> Result<(), TranscriptError> {
        let expected = match self.expected_checkpoints.get(self.current_index) {
            Some(expected) => expected,
            None => {
                // Surplus checkpoint. Report frame counts, which is the unit the error
                // message uses. The checkpoint indices previously reported were always
                // equal on this path ("expected frame N, got N") and told the caller
                // nothing.
                return Err(TranscriptError::SequenceError {
                    expected: self
                        .expected_checkpoints
                        .last()
                        .map_or(0, |last| last.frame_count),
                    actual: checkpoint.frame_count,
                });
            }
        };
        if checkpoint.frame_count != expected.frame_count {
            return Err(TranscriptError::SequenceError {
                expected: expected.frame_count,
                actual: checkpoint.frame_count,
            });
        }
        if checkpoint.hash != expected.hash {
            return Err(TranscriptError::VerificationFailed {
                expected: expected.hash,
                actual: checkpoint.hash,
            });
        }
        self.current_index += 1;
        Ok(())
    }

    /// Returns `true` once every expected checkpoint has been verified.
    pub fn is_complete(&self) -> bool {
        self.current_index == self.expected_checkpoints.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::net::atp::protocol::frames::{FrameType, ProtocolVersion};

    /// Builds the `Handshake("hello")` and `HandshakeAck("world")` frames shared by several tests.
    fn hello_world_frames() -> (Frame, Frame) {
        let hello =
            Frame::new(ProtocolVersion::V0, FrameType::Handshake, b"hello".to_vec()).unwrap();
        let world = Frame::new(
            ProtocolVersion::V0,
            FrameType::HandshakeAck,
            b"world".to_vec(),
        )
        .unwrap();
        (hello, world)
    }

    /// Builds a `Capabilities("test")` frame, inserting `extensions` in the given order.
    fn capabilities_frame(extensions: &[(u16, &str)]) -> Frame {
        let mut frame = Frame::new(
            ProtocolVersion::V0,
            FrameType::Capabilities,
            b"test".to_vec(),
        )
        .unwrap();
        for &(id, data) in extensions {
            frame.header.extensions.insert(id, data.as_bytes().to_vec());
        }
        frame
    }

    /// Absorbs `frames` into a fresh hasher, in order, and returns the final digest.
    fn hash_frames(frames: &[&Frame]) -> TranscriptHash {
        let mut hasher = TranscriptHasher::new();
        for frame in frames {
            hasher.update_frame(frame);
        }
        hasher.finalize()
    }

    fn checkpoint(byte: u8, frame_count: u64) -> TranscriptCheckpoint {
        TranscriptCheckpoint {
            hash: TranscriptHash([byte; 32]),
            frame_count,
        }
    }

    #[test]
    fn test_transcript_deterministic() {
        let (frame1, frame2) = hello_world_frames();
        assert_eq!(
            hash_frames(&[&frame1, &frame2]),
            hash_frames(&[&frame1, &frame2])
        );
    }

    #[test]
    fn test_transcript_order_sensitive() {
        let (frame1, frame2) = hello_world_frames();
        assert_ne!(
            hash_frames(&[&frame1, &frame2]),
            hash_frames(&[&frame2, &frame1])
        );
    }

    #[test]
    fn test_transcript_payload_sensitive() {
        let hello =
            Frame::new(ProtocolVersion::V0, FrameType::Handshake, b"hello".to_vec()).unwrap();
        let hellp =
            Frame::new(ProtocolVersion::V0, FrameType::Handshake, b"hellp".to_vec()).unwrap();
        assert_ne!(hash_frames(&[&hello]), hash_frames(&[&hellp]));
    }

    #[test]
    fn test_empty_transcript() {
        let hasher = TranscriptHasher::new();
        assert_eq!(hasher.frame_count(), 0);
        assert_eq!(hasher.current_hash(), TranscriptHasher::default().finalize());
        assert_eq!(hash_frames(&[]), TranscriptHasher::new().finalize());
    }

    #[test]
    fn test_current_hash_matches_finalize() {
        let (frame1, _) = hello_world_frames();
        let mut hasher = TranscriptHasher::new();
        hasher.update_frame(&frame1);
        let snapshot = hasher.current_hash();
        assert_eq!(hasher.frame_count(), 1);
        assert_eq!(snapshot, hasher.finalize());
    }

    #[test]
    fn test_transcript_with_extensions() {
        let frame1 = capabilities_frame(&[(3, "ext3"), (1, "ext1"), (2, "ext2")]);
        let frame2 = capabilities_frame(&[(2, "ext2"), (3, "ext3"), (1, "ext1")]);
        assert_eq!(hash_frames(&[&frame1]), hash_frames(&[&frame2]));
    }

    #[test]
    fn test_transcript_extension_content_sensitive() {
        let without = hash_frames(&[&capabilities_frame(&[])]);
        let with_ext1 = hash_frames(&[&capabilities_frame(&[(1, "ext1")])]);
        let other_data = hash_frames(&[&capabilities_frame(&[(1, "ext2")])]);
        let other_id = hash_frames(&[&capabilities_frame(&[(2, "ext1")])]);
        assert_ne!(without, with_ext1);
        assert_ne!(with_ext1, other_data);
        assert_ne!(with_ext1, other_id);
    }

    #[test]
    fn test_session_transcript() {
        let (frame1, frame2) = hello_world_frames();
        let mut transcript = SessionTranscript::new();
        transcript.add_frame(&frame1);
        let checkpoint1 = transcript.checkpoint();
        transcript.add_frame(&frame2);
        let checkpoint2 = transcript.checkpoint();
        assert_eq!(checkpoint1.frame_count, 1);
        assert_eq!(checkpoint2.frame_count, 2);
        assert_ne!(checkpoint1.hash, checkpoint2.hash);
        assert_eq!(
            transcript.checkpoints(),
            &[checkpoint1.clone(), checkpoint2.clone()][..]
        );
        assert_eq!(transcript.frame_count(), 2);
        assert!(transcript.verify_hash(&checkpoint2.hash));
        assert!(!transcript.verify_hash(&checkpoint1.hash));
        assert_eq!(transcript.current_hash(), hash_frames(&[&frame1, &frame2]));
        assert_eq!(transcript.finalize(), checkpoint2.hash);
    }

    #[test]
    fn test_transcript_hash_hex() {
        let hash = TranscriptHash([0xab; 32]);
        let hex = hash.to_hex();
        assert_eq!(hex.len(), 64);
        assert!(hex.chars().all(|c| "0123456789abcdef".contains(c)));
        assert_eq!(hash.to_string(), hex);
        let parsed = TranscriptHash::from_hex(&hex).unwrap();
        assert_eq!(parsed, hash);
    }

    #[test]
    fn test_transcript_hash_from_hex_rejects_bad_input() {
        let too_short = "ab".repeat(31);
        let too_long = "ab".repeat(33);
        let inputs: [&str; 5] = ["", "zz", "abc", &too_short, &too_long];
        for input in &inputs {
            assert!(
                matches!(
                    TranscriptHash::from_hex(input),
                    Err(TranscriptError::InvalidHash(_))
                ),
                "input {:?} should be rejected",
                input
            );
        }
    }

    #[test]
    fn test_transcript_verifier() {
        let checkpoints = vec![checkpoint(1, 1), checkpoint(2, 2)];
        let mut verifier = TranscriptVerifier::new(checkpoints.clone());
        assert!(!verifier.is_complete());
        verifier.verify_checkpoint(&checkpoints[0]).unwrap();
        verifier.verify_checkpoint(&checkpoints[1]).unwrap();
        assert!(verifier.is_complete());

        let mut bad_verifier = TranscriptVerifier::new(checkpoints.clone());
        let bad_checkpoint = checkpoint(99, 1);
        assert!(matches!(
            bad_verifier.verify_checkpoint(&bad_checkpoint),
            Err(TranscriptError::VerificationFailed { .. })
        ));
        // A failed check does not advance, so the correct checkpoint is still accepted.
        bad_verifier.verify_checkpoint(&checkpoints[0]).unwrap();
        assert!(!bad_verifier.is_complete());
    }

    #[test]
    fn test_transcript_verifier_frame_count_mismatch() {
        let mut verifier = TranscriptVerifier::new(vec![checkpoint(1, 1)]);
        assert!(matches!(
            verifier.verify_checkpoint(&checkpoint(1, 2)),
            Err(TranscriptError::SequenceError {
                expected: 1,
                actual: 2
            })
        ));
        assert!(!verifier.is_complete());
    }

    #[test]
    fn test_transcript_verifier_rejects_surplus_checkpoint() {
        let only = checkpoint(1, 1);
        let mut verifier = TranscriptVerifier::new(vec![only.clone()]);
        verifier.verify_checkpoint(&only).unwrap();
        assert!(verifier.is_complete());
        assert!(matches!(
            verifier.verify_checkpoint(&only),
            Err(TranscriptError::SequenceError { .. })
        ));
        assert!(TranscriptVerifier::new(Vec::new()).is_complete());
    }
}