1use std::fs::File;
22use std::io::{self, Read, Seek, SeekFrom, Write};
23
/// Eight-byte tag written at offset 0 of every encoded shm header.
pub const SHM_MAGIC: &[u8; 8] = b"RDBSHM01";
/// Layout version stamped into headers built with [`ShmHeader::new`].
pub const SHM_VERSION: u32 = 1;
/// Size in bytes of an encoded [`ShmHeader`], including its trailing 8-byte checksum.
pub const SHM_HEADER_SIZE: usize = 64;
/// Length in bytes that [`initialize_shm_file`] sets the shm file to.
pub const SHM_FILE_SIZE: u64 = 4 * 1024;
28
29#[derive(Debug, Clone)]
30pub struct ShmHeader {
31 pub version: u32,
32 pub owner_pid: u32,
33 pub generation: u64,
34 pub reader_count: u64,
35 pub last_heartbeat_ms: u64,
36}
37
38impl ShmHeader {
39 pub fn new(owner_pid: u32, generation: u64, reader_count: u64, last_heartbeat_ms: u64) -> Self {
40 Self {
41 version: SHM_VERSION,
42 owner_pid,
43 generation,
44 reader_count,
45 last_heartbeat_ms,
46 }
47 }
48
49 pub fn encode(&self) -> [u8; SHM_HEADER_SIZE] {
50 let mut buf = [0u8; SHM_HEADER_SIZE];
51 buf[0..8].copy_from_slice(SHM_MAGIC);
52 buf[8..12].copy_from_slice(&self.version.to_le_bytes());
53 buf[12..16].copy_from_slice(&self.owner_pid.to_le_bytes());
54 buf[16..24].copy_from_slice(&self.generation.to_le_bytes());
55 buf[24..32].copy_from_slice(&self.reader_count.to_le_bytes());
56 buf[32..40].copy_from_slice(&self.last_heartbeat_ms.to_le_bytes());
57 let checksum = fold_checksum(&buf[..56]);
58 buf[56..64].copy_from_slice(&checksum.to_le_bytes());
59 buf
60 }
61
62 pub fn decode(buf: &[u8; SHM_HEADER_SIZE]) -> io::Result<Self> {
63 if &buf[0..8] != SHM_MAGIC {
64 return Err(io::Error::new(
65 io::ErrorKind::InvalidData,
66 "shm magic mismatch",
67 ));
68 }
69 let stored_checksum = u64::from_le_bytes(buf[56..64].try_into().unwrap());
70 let computed = fold_checksum(&buf[..56]);
71 if stored_checksum != computed {
72 return Err(io::Error::new(
73 io::ErrorKind::InvalidData,
74 "shm checksum mismatch",
75 ));
76 }
77 Ok(Self {
78 version: u32::from_le_bytes(buf[8..12].try_into().unwrap()),
79 owner_pid: u32::from_le_bytes(buf[12..16].try_into().unwrap()),
80 generation: u64::from_le_bytes(buf[16..24].try_into().unwrap()),
81 reader_count: u64::from_le_bytes(buf[24..32].try_into().unwrap()),
82 last_heartbeat_ms: u64::from_le_bytes(buf[32..40].try_into().unwrap()),
83 })
84 }
85}
86
87pub fn initialize_shm_file(file: &mut File, header: &ShmHeader) -> io::Result<()> {
88 file.set_len(SHM_FILE_SIZE)?;
89 write_shm_header_to_file(file, header)
90}
91
92pub fn read_shm_header_from_file(file: &mut File) -> io::Result<ShmHeader> {
93 let mut buf = [0u8; SHM_HEADER_SIZE];
94 file.seek(SeekFrom::Start(0))?;
95 file.read_exact(&mut buf)?;
96 ShmHeader::decode(&buf)
97}
98
99pub fn write_shm_header_to_file(file: &mut File, header: &ShmHeader) -> io::Result<()> {
100 let buf = header.encode();
101 file.seek(SeekFrom::Start(0))?;
102 file.write_all(&buf)?;
103 file.sync_data()?;
104 Ok(())
105}
106
/// 64-bit FNV-1a hash of `bytes`, used as the shm header checksum.
///
/// Each byte is XOR-ed into the running hash, which is then multiplied by the
/// FNV prime with wrapping arithmetic.
fn fold_checksum(bytes: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    bytes.iter().fold(FNV_OFFSET_BASIS, |hash, &b| {
        (hash ^ u64::from(b)).wrapping_mul(FNV_PRIME)
    })
}
115
116#[cfg(test)]
117mod tests {
118 use super::*;
119
120 #[test]
121 fn shm_header_round_trips() {
122 let header = ShmHeader::new(42, 7, 3, 99);
123
124 let encoded = header.encode();
125 assert_eq!(&encoded[0..8], SHM_MAGIC);
126 assert_eq!(encoded.len(), SHM_HEADER_SIZE);
127
128 let decoded = ShmHeader::decode(&encoded).expect("decode");
129 assert_eq!(decoded.version, header.version);
130 assert_eq!(decoded.owner_pid, header.owner_pid);
131 assert_eq!(decoded.generation, header.generation);
132 assert_eq!(decoded.reader_count, header.reader_count);
133 assert_eq!(decoded.last_heartbeat_ms, header.last_heartbeat_ms);
134 }
135
136 #[test]
137 fn shm_header_rejects_checksum_mismatch() {
138 let header = ShmHeader::new(1, 1, 0, 1);
139 let mut encoded = header.encode();
140 encoded[20] ^= 0xff;
141
142 let err = ShmHeader::decode(&encoded).expect_err("checksum must fail");
143 assert_eq!(err.kind(), io::ErrorKind::InvalidData);
144 }
145
146 #[test]
147 fn shm_file_helpers_initialize_and_rewrite_header() {
148 let path = std::env::temp_dir().join(format!(
149 "reddb-shm-file-helper-{}-{}.shm",
150 std::process::id(),
151 unique_test_suffix()
152 ));
153 let mut file = File::options()
154 .read(true)
155 .write(true)
156 .create(true)
157 .truncate(true)
158 .open(&path)
159 .expect("create shm file");
160
161 let header = ShmHeader::new(11, 2, 3, 4);
162 initialize_shm_file(&mut file, &header).expect("initialize");
163 assert_eq!(
164 file.metadata().expect("metadata").len(),
165 SHM_FILE_SIZE,
166 "helper owns the fixed shm file size"
167 );
168
169 let decoded = read_shm_header_from_file(&mut file).expect("read initialized header");
170 assert_eq!(decoded.owner_pid, 11);
171 assert_eq!(decoded.generation, 2);
172 assert_eq!(decoded.reader_count, 3);
173 assert_eq!(decoded.last_heartbeat_ms, 4);
174
175 let next = ShmHeader::new(12, 3, 0, 9);
176 write_shm_header_to_file(&mut file, &next).expect("rewrite");
177 let decoded = read_shm_header_from_file(&mut file).expect("read rewritten header");
178 assert_eq!(decoded.owner_pid, 12);
179 assert_eq!(decoded.generation, 3);
180 assert_eq!(decoded.reader_count, 0);
181 assert_eq!(decoded.last_heartbeat_ms, 9);
182
183 drop(file);
184 let _ = std::fs::remove_file(path);
185 }
186
187 fn unique_test_suffix() -> u128 {
188 std::time::SystemTime::now()
189 .duration_since(std::time::UNIX_EPOCH)
190 .expect("clock")
191 .as_nanos()
192 }
193}