vexil_runtime/
handshake.rs1use crate::{BitReader, BitWriter, DecodeError};
4
/// Handshake payload: a 32-byte hash plus a version string.
///
/// Equality is field-wise over both `hash` and `version`. `Eq` is derived in
/// addition to `PartialEq` because every field has total equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SchemaHandshake {
    /// 32-byte hash; this is the only field `check` compares.
    // NOTE(review): presumably a digest of the schema definition — confirm against the producer.
    pub hash: [u8; 32],
    /// Version string; carried for reporting, not compared by `check`.
    pub version: String,
}
11
/// Outcome of comparing a local handshake with a remote one.
///
/// `Eq` is derived in addition to `PartialEq` because every payload field
/// (`String`, `[u8; 32]`) has total equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HandshakeResult {
    /// The two handshakes carry identical hashes.
    Match,
    /// The hashes differ; both sides' details are carried for reporting.
    VersionMismatch {
        /// Version string of the local side.
        local_version: String,
        /// Version string of the remote side.
        remote_version: String,
        /// Hash of the local side.
        local_hash: [u8; 32],
        /// Hash of the remote side.
        remote_hash: [u8; 32],
    },
}
23
24impl SchemaHandshake {
25 pub fn new(hash: [u8; 32], version: &str) -> Self {
26 Self {
27 hash,
28 version: version.to_string(),
29 }
30 }
31
32 pub fn encode(&self) -> Vec<u8> {
33 let mut w = BitWriter::new();
34 w.write_raw_bytes(&self.hash);
35 w.write_string(&self.version);
36 w.finish()
37 }
38
39 pub fn decode(bytes: &[u8]) -> Result<Self, DecodeError> {
40 let mut r = BitReader::new(bytes);
41 let hash_bytes = r.read_raw_bytes(32)?;
42 let hash: [u8; 32] = hash_bytes
43 .try_into()
44 .map_err(|_| DecodeError::UnexpectedEof)?;
45 let version = r.read_string()?;
46 Ok(Self { hash, version })
47 }
48
49 pub fn check(&self, remote: &SchemaHandshake) -> HandshakeResult {
50 if self.hash == remote.hash {
51 HandshakeResult::Match
52 } else {
53 HandshakeResult::VersionMismatch {
54 local_version: self.version.clone(),
55 remote_version: remote.version.clone(),
56 local_hash: self.hash,
57 remote_hash: remote.hash,
58 }
59 }
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding and then decoding must reproduce both fields.
    #[test]
    fn encode_decode_roundtrip() {
        let hash = [0xABu8; 32];
        let hs = SchemaHandshake::new(hash, "1.2.3");
        let bytes = hs.encode();
        let decoded =
            SchemaHandshake::decode(&bytes).expect("freshly encoded handshake must decode");
        assert_eq!(decoded.hash, hash);
        assert_eq!(decoded.version, "1.2.3");
    }

    /// Identical hashes yield `Match`.
    #[test]
    fn check_matching_hashes() {
        let hash = [0x42u8; 32];
        let local = SchemaHandshake::new(hash, "1.0.0");
        let remote = SchemaHandshake::new(hash, "1.0.0");
        assert_eq!(local.check(&remote), HandshakeResult::Match);
    }

    /// `check` compares hashes only, so differing version strings still match.
    #[test]
    fn check_matching_hashes_ignores_version() {
        let hash = [0x42u8; 32];
        let local = SchemaHandshake::new(hash, "1.0.0");
        let remote = SchemaHandshake::new(hash, "2.0.0");
        assert_eq!(local.check(&remote), HandshakeResult::Match);
    }

    /// Differing hashes yield `VersionMismatch` with every field taken from
    /// the correct side.
    #[test]
    fn check_different_hashes() {
        let local = SchemaHandshake::new([0x01; 32], "1.0.0");
        let remote = SchemaHandshake::new([0x02; 32], "1.1.0");
        // Comparing against the full expected value also verifies the two
        // hash fields and avoids a catch-all match arm.
        assert_eq!(
            local.check(&remote),
            HandshakeResult::VersionMismatch {
                local_version: "1.0.0".to_string(),
                remote_version: "1.1.0".to_string(),
                local_hash: [0x01; 32],
                remote_hash: [0x02; 32],
            }
        );
    }

    /// Pins the encoded size for a 5-character version string.
    #[test]
    fn wire_size_is_compact() {
        let hs = SchemaHandshake::new([0; 32], "1.0.0");
        let bytes = hs.encode();
        // 32 hash bytes + 6 bytes for the string.
        // NOTE(review): presumably a 1-byte length prefix plus 5 UTF-8 bytes — confirm
        // against `BitWriter::write_string`.
        assert_eq!(bytes.len(), 38);
    }
}