Skip to main content

vexil_runtime/
handshake.rs

//! Schema handshake helpers for connection-time identity checking.

3use crate::{BitReader, BitWriter, DecodeError};
4
/// Schema identity exchanged during connection-time negotiation.
///
/// Compatibility between two peers is decided by [`SchemaHandshake::check`],
/// which compares only the `hash` fields; the `version` string is not part of
/// that comparison and is only reported back in
/// [`HandshakeResult::VersionMismatch`].
///
/// Note that the derived `PartialEq`/`Eq`/`Hash` impls consider *both*
/// fields, so two handshakes with the same hash but different versions are
/// not `==` even though `check` reports them as a match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SchemaHandshake {
    /// 32-byte hash identifying the schema.
    pub hash: [u8; 32],
    /// Human-readable version string of the schema.
    pub version: String,
}

/// Outcome of comparing a local [`SchemaHandshake`] against a remote one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HandshakeResult {
    /// Both sides reported the same schema hash (version strings are not
    /// compared).
    Match,
    /// The two schema hashes differ.
    ///
    /// Despite its name, this variant is produced whenever the hashes
    /// differ, even if the two version strings happen to be identical.
    VersionMismatch {
        /// Version string of the local side.
        local_version: String,
        /// Version string reported by the remote side.
        remote_version: String,
        /// Schema hash of the local side.
        local_hash: [u8; 32],
        /// Schema hash reported by the remote side.
        remote_hash: [u8; 32],
    },
}

24impl SchemaHandshake {
25    pub fn new(hash: [u8; 32], version: &str) -> Self {
26        Self {
27            hash,
28            version: version.to_string(),
29        }
30    }
31
32    pub fn encode(&self) -> Vec<u8> {
33        let mut w = BitWriter::new();
34        w.write_raw_bytes(&self.hash);
35        w.write_string(&self.version);
36        w.finish()
37    }
38
39    pub fn decode(bytes: &[u8]) -> Result<Self, DecodeError> {
40        let mut r = BitReader::new(bytes);
41        let hash_bytes = r.read_raw_bytes(32)?;
42        let hash: [u8; 32] = hash_bytes
43            .try_into()
44            .map_err(|_| DecodeError::UnexpectedEof)?;
45        let version = r.read_string()?;
46        Ok(Self { hash, version })
47    }
48
49    pub fn check(&self, remote: &SchemaHandshake) -> HandshakeResult {
50        if self.hash == remote.hash {
51            HandshakeResult::Match
52        } else {
53            HandshakeResult::VersionMismatch {
54                local_version: self.version.clone(),
55                remote_version: remote.version.clone(),
56                local_hash: self.hash,
57                remote_hash: remote.hash,
58            }
59        }
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_decode_roundtrip() {
        let hash = [0xABu8; 32];
        let hs = SchemaHandshake::new(hash, "1.2.3");
        let bytes = hs.encode();
        let decoded = SchemaHandshake::decode(&bytes).unwrap();
        assert_eq!(decoded.hash, hash);
        assert_eq!(decoded.version, "1.2.3");
    }

    #[test]
    fn encode_decode_roundtrip_empty_version() {
        let hs = SchemaHandshake::new([0x11u8; 32], "");
        let decoded = SchemaHandshake::decode(&hs.encode()).unwrap();
        assert_eq!(decoded, hs);
    }

    #[test]
    fn decode_rejects_truncated_hash() {
        // Fewer than 32 bytes cannot hold the hash.
        assert!(SchemaHandshake::decode(&[]).is_err());
        assert!(SchemaHandshake::decode(&[0u8; 31]).is_err());
    }

    #[test]
    fn check_matching_hashes() {
        let hash = [0x42u8; 32];
        let local = SchemaHandshake::new(hash, "1.0.0");
        let remote = SchemaHandshake::new(hash, "1.0.0");
        assert_eq!(local.check(&remote), HandshakeResult::Match);
    }

    #[test]
    fn check_matching_hashes_ignores_version() {
        // Identity is the hash alone; differing version strings still match.
        let local = SchemaHandshake::new([0x42; 32], "1.0.0");
        let remote = SchemaHandshake::new([0x42; 32], "2.0.0");
        assert_eq!(local.check(&remote), HandshakeResult::Match);
    }

    #[test]
    fn check_different_hashes() {
        let local = SchemaHandshake::new([0x01; 32], "1.0.0");
        let remote = SchemaHandshake::new([0x02; 32], "1.1.0");
        // Compare the whole value so the reported hashes are verified too,
        // without a catch-all arm that would hide future variants.
        assert_eq!(
            local.check(&remote),
            HandshakeResult::VersionMismatch {
                local_version: "1.0.0".to_string(),
                remote_version: "1.1.0".to_string(),
                local_hash: [0x01; 32],
                remote_hash: [0x02; 32],
            }
        );
    }

    #[test]
    fn check_different_hashes_same_version_is_mismatch() {
        let local = SchemaHandshake::new([0x01; 32], "1.0.0");
        let remote = SchemaHandshake::new([0x02; 32], "1.0.0");
        assert!(matches!(
            local.check(&remote),
            HandshakeResult::VersionMismatch { .. }
        ));
    }

    #[test]
    fn wire_size_is_compact() {
        let hs = SchemaHandshake::new([0; 32], "1.0.0");
        let bytes = hs.encode();
        // 32 (hash) + 1 (LEB128 len) + 5 (version) = 38
        assert_eq!(bytes.len(), 38);
    }
}