1use std::io::{Read, Write};
4use std::path::Path;
5
6use crate::types::{VisionError, VisionResult, VisualMemoryStore, VisualObservation};
7
/// Magic number identifying an AVIS stream: the ASCII bytes `AVIS` read as a
/// big-endian `u32` (`0x41564953`). The header stores it little-endian, so
/// the first four bytes of a file are `S`, `I`, `V`, `A`.
const AVIS_MAGIC: u32 = u32::from_be_bytes(*b"AVIS");

/// Header format version written by `AvisWriter`; `AvisReader` rejects any
/// other value.
const FORMAT_VERSION: u16 = 1;

/// Size in bytes of the fixed header that precedes the JSON payload. Only the
/// first 48 bytes carry fields; the remainder is zero padding.
const HEADER_SIZE: usize = 64;
16
/// Encoder for the AVIS format: serializes a `VisualMemoryStore` as a
/// fixed-size little-endian header followed by a JSON payload.
pub struct AvisWriter;

/// Decoder for the AVIS format produced by [`AvisWriter`]: validates the
/// header and rebuilds a `VisualMemoryStore` from header fields and payload.
pub struct AvisReader;
22
23impl AvisWriter {
24 pub fn write_to_file(store: &VisualMemoryStore, path: &Path) -> VisionResult<()> {
26 if let Some(parent) = path.parent() {
27 std::fs::create_dir_all(parent)?;
28 }
29
30 let mut file = std::fs::File::create(path)?;
31 Self::write_to(store, &mut file)
32 }
33
34 pub fn write_to<W: Write>(store: &VisualMemoryStore, writer: &mut W) -> VisionResult<()> {
36 let payload = serde_json::to_vec(&SerializedStore {
38 observations: &store.observations,
39 embedding_dim: store.embedding_dim,
40 next_id: store.next_id,
41 session_count: store.session_count,
42 created_at: store.created_at,
43 updated_at: store.updated_at,
44 })
45 .map_err(|e| VisionError::Storage(format!("Serialization failed: {e}")))?;
46
47 let mut header = [0u8; HEADER_SIZE];
49 write_u32(&mut header[0..4], AVIS_MAGIC);
50 write_u16(&mut header[4..6], FORMAT_VERSION);
51 write_u16(&mut header[6..8], 0); write_u64(&mut header[8..16], store.observations.len() as u64);
53 write_u32(&mut header[16..20], store.embedding_dim);
54 write_u32(&mut header[20..24], store.session_count);
55 write_u64(&mut header[24..32], store.created_at);
56 write_u64(&mut header[32..40], store.updated_at);
57 write_u64(&mut header[40..48], payload.len() as u64); writer.write_all(&header)?;
60 writer.write_all(&payload)?;
61
62 Ok(())
63 }
64}
65
66impl AvisReader {
67 pub fn read_from_file(path: &Path) -> VisionResult<VisualMemoryStore> {
69 let mut file = std::fs::File::open(path)?;
70 Self::read_from(&mut file)
71 }
72
73 pub fn read_from<R: Read>(reader: &mut R) -> VisionResult<VisualMemoryStore> {
75 let mut header = [0u8; HEADER_SIZE];
77 reader.read_exact(&mut header)?;
78
79 let magic = read_u32(&header[0..4]);
80 if magic != AVIS_MAGIC {
81 return Err(VisionError::Storage(format!(
82 "Invalid magic: expected 0x{AVIS_MAGIC:08X}, got 0x{magic:08X}"
83 )));
84 }
85
86 let version = read_u16(&header[4..6]);
87 if version != FORMAT_VERSION {
88 return Err(VisionError::Storage(format!(
89 "Unsupported version: {version}"
90 )));
91 }
92
93 let _observation_count = read_u64(&header[8..16]);
94 let embedding_dim = read_u32(&header[16..20]);
95 let session_count = read_u32(&header[20..24]);
96 let created_at = read_u64(&header[24..32]);
97 let updated_at = read_u64(&header[32..40]);
98 let payload_len = read_u64(&header[40..48]) as usize;
99
100 let mut payload = vec![0u8; payload_len];
102 reader.read_exact(&mut payload)?;
103
104 let serialized: DeserializedStore = serde_json::from_slice(&payload)
105 .map_err(|e| VisionError::Storage(format!("Deserialization failed: {e}")))?;
106
107 let next_id = serialized.next_id;
108
109 Ok(VisualMemoryStore {
110 observations: serialized.observations,
111 embedding_dim,
112 next_id,
113 session_count,
114 created_at,
115 updated_at,
116 })
117 }
118}
119
/// Borrowed view of a `VisualMemoryStore` that `AvisWriter::write_to`
/// serializes to JSON as the file payload.
///
/// serde's derive emits fields in declaration order, so this order is the key
/// order of the JSON object written to disk.
#[derive(serde::Serialize)]
struct SerializedStore<'a> {
    observations: &'a [VisualObservation],
    embedding_dim: u32,
    next_id: u64,
    session_count: u32,
    created_at: u64,
    updated_at: u64,
}
129
/// Owned counterpart of `SerializedStore`, decoded from the JSON payload by
/// `AvisReader::read_from`.
///
/// Only `observations` and `next_id` are consumed. The other fields are still
/// declared, so a payload missing any of them fails to deserialize. The reader
/// takes those values from the binary header instead.
#[derive(serde::Deserialize)]
struct DeserializedStore {
    observations: Vec<VisualObservation>,
    // Never read: the header's copy of this value is used.
    #[allow(dead_code)]
    embedding_dim: u32,
    next_id: u64,
    // Never read: the header's copy of this value is used.
    #[allow(dead_code)]
    session_count: u32,
    // Never read: the header's copy of this value is used.
    #[allow(dead_code)]
    created_at: u64,
    // Never read: the header's copy of this value is used.
    #[allow(dead_code)]
    updated_at: u64,
}
143
/// Stores `val` little-endian in the first 2 bytes of `buf`; later bytes are
/// left untouched.
///
/// # Panics
///
/// Panics if `buf` is shorter than 2 bytes.
fn write_u16(buf: &mut [u8], val: u16) {
    let bytes = val.to_le_bytes();
    buf[..bytes.len()].copy_from_slice(&bytes);
}

/// Stores `val` little-endian in the first 4 bytes of `buf`; later bytes are
/// left untouched.
///
/// # Panics
///
/// Panics if `buf` is shorter than 4 bytes.
fn write_u32(buf: &mut [u8], val: u32) {
    let bytes = val.to_le_bytes();
    buf[..bytes.len()].copy_from_slice(&bytes);
}

/// Stores `val` little-endian in the first 8 bytes of `buf`; later bytes are
/// left untouched.
///
/// # Panics
///
/// Panics if `buf` is shorter than 8 bytes.
fn write_u64(buf: &mut [u8], val: u64) {
    let bytes = val.to_le_bytes();
    buf[..bytes.len()].copy_from_slice(&bytes);
}

/// Decodes a little-endian `u16` from the first 2 bytes of `buf`, ignoring
/// the rest.
///
/// # Panics
///
/// Panics if `buf` is shorter than 2 bytes.
fn read_u16(buf: &[u8]) -> u16 {
    let mut bytes = [0u8; 2];
    bytes.copy_from_slice(&buf[..2]);
    u16::from_le_bytes(bytes)
}

/// Decodes a little-endian `u32` from the first 4 bytes of `buf`, ignoring
/// the rest.
///
/// # Panics
///
/// Panics if `buf` is shorter than 4 bytes.
fn read_u32(buf: &[u8]) -> u32 {
    let mut bytes = [0u8; 4];
    bytes.copy_from_slice(&buf[..4]);
    u32::from_le_bytes(bytes)
}

/// Decodes a little-endian `u64` from the first 8 bytes of `buf`, ignoring
/// the rest.
///
/// # Panics
///
/// Panics if `buf` is shorter than 8 bytes.
fn read_u64(buf: &[u8]) -> u64 {
    let mut bytes = [0u8; 8];
    bytes.copy_from_slice(&buf[..8]);
    u64::from_le_bytes(bytes)
}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{CaptureSource, ObservationMeta};

    /// Builds a small, fully populated observation with the given `id`.
    /// `test_roundtrip_with_observations` expects `VisualMemoryStore::add` to
    /// assign fresh ids regardless of the one passed here.
    fn make_test_observation(id: u64) -> VisualObservation {
        VisualObservation {
            id,
            timestamp: 1708345678,
            session_id: 1,
            source: CaptureSource::File {
                path: "/test/image.png".to_string(),
            },
            embedding: vec![0.1, 0.2, 0.3],
            thumbnail: vec![0xFF, 0xD8, 0xFF],
            metadata: ObservationMeta {
                width: 512,
                height: 512,
                original_width: 1920,
                original_height: 1080,
                labels: vec!["test".to_string()],
                description: Some("Test observation".to_string()),
            },
            memory_link: None,
        }
    }

    /// Serializes `store` into an in-memory buffer.
    fn encode(store: &VisualMemoryStore) -> Vec<u8> {
        let mut buf = Vec::new();
        AvisWriter::write_to(store, &mut buf).unwrap();
        buf
    }

    /// Serializes `store` in memory and reads it straight back.
    fn roundtrip(store: &VisualMemoryStore) -> VisualMemoryStore {
        let buf = encode(store);
        AvisReader::read_from(&mut &buf[..]).unwrap()
    }

    #[test]
    fn test_roundtrip_empty() {
        let store = VisualMemoryStore::new(512);
        let loaded = roundtrip(&store);
        assert_eq!(loaded.count(), 0);
        assert_eq!(loaded.embedding_dim, 512);
    }

    #[test]
    fn test_roundtrip_with_observations() {
        let mut store = VisualMemoryStore::new(512);
        store.add(make_test_observation(0));
        store.add(make_test_observation(0));

        let loaded = roundtrip(&store);
        assert_eq!(loaded.count(), 2);
        assert_eq!(loaded.observations[0].id, 1);
        assert_eq!(loaded.observations[1].id, 2);
    }

    #[test]
    fn test_header_fields_roundtrip() {
        let mut store = VisualMemoryStore::new(512);
        store.session_count = 3;
        store.created_at = 1_700_000_000;
        store.updated_at = 1_700_000_500;

        let loaded = roundtrip(&store);
        assert_eq!(loaded.embedding_dim, 512);
        assert_eq!(loaded.session_count, 3);
        assert_eq!(loaded.created_at, 1_700_000_000);
        assert_eq!(loaded.updated_at, 1_700_000_500);
        assert_eq!(loaded.next_id, store.next_id);
    }

    #[test]
    fn test_invalid_magic() {
        // An all-zero header cannot carry AVIS_MAGIC.
        let buf = [0u8; HEADER_SIZE + 10];
        let result = AvisReader::read_from(&mut &buf[..]);
        assert!(result.is_err());
    }

    #[test]
    fn test_unsupported_version() {
        let store = VisualMemoryStore::new(512);
        let mut buf = encode(&store);
        buf[4..6].copy_from_slice(&(FORMAT_VERSION + 1).to_le_bytes());
        assert!(AvisReader::read_from(&mut &buf[..]).is_err());
    }

    #[test]
    fn test_truncated_header() {
        let buf = [0u8; HEADER_SIZE - 1];
        assert!(AvisReader::read_from(&mut &buf[..]).is_err());
    }

    #[test]
    fn test_truncated_payload() {
        let mut store = VisualMemoryStore::new(512);
        store.add(make_test_observation(0));
        let mut buf = encode(&store);
        buf.truncate(buf.len() - 1);
        assert!(AvisReader::read_from(&mut &buf[..]).is_err());
    }

    #[test]
    fn test_file_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.avis");

        let mut store = VisualMemoryStore::new(512);
        store.add(make_test_observation(0));

        AvisWriter::write_to_file(&store, &path).unwrap();
        let loaded = AvisReader::read_from_file(&path).unwrap();
        assert_eq!(loaded.count(), 1);
    }
}