1use std::io::{Read, Write};
4use std::path::Path;
5
6use crate::types::{VisionError, VisionResult, VisualMemoryStore, VisualObservation};
7
/// File signature: the ASCII bytes `AVIS` read as a big-endian `u32` (0x4156_4953).
///
/// The header stores it little-endian, so the first four bytes on disk are `SIVA`.
const AVIS_MAGIC: u32 = u32::from_be_bytes(*b"AVIS");

/// The only header layout version written by `AvisWriter` and accepted by `AvisReader`.
const FORMAT_VERSION: u16 = 1;

/// Size in bytes of the fixed header that precedes the JSON payload.
///
/// Layout (all integers little-endian):
/// - `0..4`   magic (`AVIS_MAGIC`)
/// - `4..6`   format version
/// - `6..8`   written as zero, ignored on read
/// - `8..16`  observation count
/// - `16..20` embedding dimension
/// - `20..24` session count
/// - `24..32` created-at
/// - `32..40` updated-at
/// - `40..48` payload length in bytes
/// - `48..64` unused, left zeroed
const HEADER_SIZE: usize = 64;
16
/// Encoder for the AVIS format: a fixed 64-byte header followed by a JSON payload.
///
/// Stateless marker type; all functionality lives in associated functions.
pub struct AvisWriter;
19
/// Decoder for the AVIS format written by `AvisWriter`.
///
/// Stateless marker type; all functionality lives in associated functions.
pub struct AvisReader;
22
23impl AvisWriter {
24 pub fn write_to_file(store: &VisualMemoryStore, path: &Path) -> VisionResult<()> {
26 if let Some(parent) = path.parent() {
27 std::fs::create_dir_all(parent)?;
28 }
29
30 let mut file = std::fs::File::create(path)?;
31 Self::write_to(store, &mut file)
32 }
33
34 pub fn write_to<W: Write>(store: &VisualMemoryStore, writer: &mut W) -> VisionResult<()> {
36 let payload = serde_json::to_vec(&SerializedStore {
38 observations: &store.observations,
39 embedding_dim: store.embedding_dim,
40 next_id: store.next_id,
41 session_count: store.session_count,
42 created_at: store.created_at,
43 updated_at: store.updated_at,
44 })
45 .map_err(|e| VisionError::Storage(format!("Serialization failed: {e}")))?;
46
47 let mut header = [0u8; HEADER_SIZE];
49 write_u32(&mut header[0..4], AVIS_MAGIC);
50 write_u16(&mut header[4..6], FORMAT_VERSION);
51 write_u16(&mut header[6..8], 0); write_u64(&mut header[8..16], store.observations.len() as u64);
53 write_u32(&mut header[16..20], store.embedding_dim);
54 write_u32(&mut header[20..24], store.session_count);
55 write_u64(&mut header[24..32], store.created_at);
56 write_u64(&mut header[32..40], store.updated_at);
57 write_u64(&mut header[40..48], payload.len() as u64); writer.write_all(&header)?;
60 writer.write_all(&payload)?;
61
62 Ok(())
63 }
64}
65
66impl AvisReader {
67 pub fn read_from_file(path: &Path) -> VisionResult<VisualMemoryStore> {
69 let mut file = std::fs::File::open(path)?;
70 Self::read_from(&mut file)
71 }
72
73 pub fn read_from<R: Read>(reader: &mut R) -> VisionResult<VisualMemoryStore> {
75 let mut header = [0u8; HEADER_SIZE];
77 reader.read_exact(&mut header)?;
78
79 let magic = read_u32(&header[0..4]);
80 if magic != AVIS_MAGIC {
81 return Err(VisionError::Storage(format!(
82 "Invalid magic: expected 0x{AVIS_MAGIC:08X}, got 0x{magic:08X}"
83 )));
84 }
85
86 let version = read_u16(&header[4..6]);
87 if version != FORMAT_VERSION {
88 return Err(VisionError::Storage(format!(
89 "Unsupported version: {version}"
90 )));
91 }
92
93 let _observation_count = read_u64(&header[8..16]);
94 let embedding_dim = read_u32(&header[16..20]);
95 let session_count = read_u32(&header[20..24]);
96 let created_at = read_u64(&header[24..32]);
97 let updated_at = read_u64(&header[32..40]);
98 let payload_len = read_u64(&header[40..48]) as usize;
99
100 let mut payload = vec![0u8; payload_len];
102 reader.read_exact(&mut payload)?;
103
104 let serialized: DeserializedStore = serde_json::from_slice(&payload)
105 .map_err(|e| VisionError::Storage(format!("Deserialization failed: {e}")))?;
106
107 let next_id = serialized.next_id;
108
109 Ok(VisualMemoryStore {
110 observations: serialized.observations,
111 embedding_dim,
112 next_id,
113 session_count,
114 created_at,
115 updated_at,
116 })
117 }
118}
119
/// JSON payload written by `AvisWriter`: a borrowed view of a `VisualMemoryStore`.
///
/// `observations` is serialized straight from the store's slice rather than
/// from an owned copy. Field names become the JSON keys and must stay in sync
/// with `DeserializedStore`.
///
/// `embedding_dim`, `session_count`, `created_at` and `updated_at` are also
/// written to the binary header. `AvisReader::read_from` uses the header copies
/// and takes only `observations` and `next_id` from this payload.
#[derive(serde::Serialize)]
struct SerializedStore<'a> {
    observations: &'a [VisualObservation],
    embedding_dim: u32,
    next_id: u64,
    session_count: u32,
    created_at: u64,
    updated_at: u64,
}
129
/// Owned counterpart of `SerializedStore`, used to parse the JSON payload.
///
/// Field names are the JSON keys and must match `SerializedStore`. Only
/// `observations` and `next_id` are consumed by `AvisReader::read_from`. The
/// other fields are still parsed, so a payload missing them is rejected, but
/// the reader takes those values from the binary header instead.
#[derive(serde::Deserialize)]
struct DeserializedStore {
    observations: Vec<VisualObservation>,
    // Parsed but unread: the reader uses the header's embedding_dim.
    #[allow(dead_code)]
    embedding_dim: u32,
    next_id: u64,
    // Parsed but unread: the reader uses the header's session_count.
    #[allow(dead_code)]
    session_count: u32,
    // Parsed but unread: the reader uses the header's created_at.
    #[allow(dead_code)]
    created_at: u64,
    // Parsed but unread: the reader uses the header's updated_at.
    #[allow(dead_code)]
    updated_at: u64,
}
143
/// Stores `val` little-endian in the first 2 bytes of `buf`.
///
/// # Panics
///
/// Panics if `buf` is shorter than 2 bytes.
fn write_u16(buf: &mut [u8], val: u16) {
    put_le_bytes(buf, &val.to_le_bytes());
}

/// Stores `val` little-endian in the first 4 bytes of `buf`.
///
/// # Panics
///
/// Panics if `buf` is shorter than 4 bytes.
fn write_u32(buf: &mut [u8], val: u32) {
    put_le_bytes(buf, &val.to_le_bytes());
}

/// Stores `val` little-endian in the first 8 bytes of `buf`.
///
/// # Panics
///
/// Panics if `buf` is shorter than 8 bytes.
fn write_u64(buf: &mut [u8], val: u64) {
    put_le_bytes(buf, &val.to_le_bytes());
}

/// Copies `bytes` into the front of `buf` and leaves the rest untouched.
/// Panics if `buf` is shorter than `bytes`.
fn put_le_bytes(buf: &mut [u8], bytes: &[u8]) {
    let (front, _rest) = buf.split_at_mut(bytes.len());
    front.copy_from_slice(bytes);
}
/// Decodes a little-endian `u16` from the first 2 bytes of `buf`.
///
/// # Panics
///
/// Panics if `buf` is shorter than 2 bytes.
fn read_u16(buf: &[u8]) -> u16 {
    let mut raw = [0u8; 2];
    raw.copy_from_slice(&buf[..2]);
    u16::from_le_bytes(raw)
}

/// Decodes a little-endian `u32` from the first 4 bytes of `buf`.
///
/// # Panics
///
/// Panics if `buf` is shorter than 4 bytes.
fn read_u32(buf: &[u8]) -> u32 {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&buf[..4]);
    u32::from_le_bytes(raw)
}

/// Decodes a little-endian `u64` from the first 8 bytes of `buf`.
///
/// # Panics
///
/// Panics if `buf` is shorter than 8 bytes.
fn read_u64(buf: &[u8]) -> u64 {
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&buf[..8]);
    u64::from_le_bytes(raw)
}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{CaptureSource, ObservationMeta};

    /// Builds a small, fully populated observation. The `id` argument is stored
    /// as given. `test_roundtrip_with_observations` relies on
    /// `VisualMemoryStore::add` assigning the ids that are actually persisted.
    fn make_test_observation(id: u64) -> VisualObservation {
        VisualObservation {
            id,
            timestamp: 1708345678,
            session_id: 1,
            source: CaptureSource::File {
                path: "/test/image.png".to_string(),
            },
            embedding: vec![0.1, 0.2, 0.3],
            thumbnail: vec![0xFF, 0xD8, 0xFF],
            metadata: ObservationMeta {
                width: 512,
                height: 512,
                original_width: 1920,
                original_height: 1080,
                labels: vec!["test".to_string()],
                description: Some("Test observation".to_string()),
                quality_score: 0.85,
            },
            memory_link: None,
        }
    }

    /// Encodes `store` into an in-memory AVIS stream.
    fn encode(store: &VisualMemoryStore) -> Vec<u8> {
        let mut buf = Vec::new();
        AvisWriter::write_to(store, &mut buf).unwrap();
        buf
    }

    #[test]
    fn test_roundtrip_empty() {
        let store = VisualMemoryStore::new(512);
        let mut buf = Vec::new();
        AvisWriter::write_to(&store, &mut buf).unwrap();

        let loaded = AvisReader::read_from(&mut &buf[..]).unwrap();
        assert_eq!(loaded.count(), 0);
        assert_eq!(loaded.embedding_dim, 512);
    }

    #[test]
    fn test_roundtrip_with_observations() {
        let mut store = VisualMemoryStore::new(512);
        store.add(make_test_observation(0));
        store.add(make_test_observation(0));

        let mut buf = Vec::new();
        AvisWriter::write_to(&store, &mut buf).unwrap();

        let loaded = AvisReader::read_from(&mut &buf[..]).unwrap();
        assert_eq!(loaded.count(), 2);
        assert_eq!(loaded.observations[0].id, 1);
        assert_eq!(loaded.observations[1].id, 2);
    }

    #[test]
    fn test_header_metadata_roundtrip() {
        let mut store = VisualMemoryStore::new(256);
        store.session_count = 3;
        store.created_at = 1_700_000_000;
        store.updated_at = 1_700_000_123;

        let buf = encode(&store);
        let loaded = AvisReader::read_from(&mut &buf[..]).unwrap();
        assert_eq!(loaded.embedding_dim, 256);
        assert_eq!(loaded.session_count, 3);
        assert_eq!(loaded.created_at, 1_700_000_000);
        assert_eq!(loaded.updated_at, 1_700_000_123);
    }

    #[test]
    fn test_invalid_magic() {
        let mut buf = [0u8; HEADER_SIZE + 10];
        buf[0..4].copy_from_slice(&[0x00, 0x00, 0x00, 0x00]);
        let result = AvisReader::read_from(&mut &buf[..]);
        assert!(result.is_err());
    }

    #[test]
    fn test_unsupported_version() {
        let mut buf = encode(&VisualMemoryStore::new(512));
        // Version field lives at bytes 4..6 of the header.
        write_u16(&mut buf[4..6], FORMAT_VERSION + 1);
        assert!(AvisReader::read_from(&mut &buf[..]).is_err());
    }

    #[test]
    fn test_truncated_header() {
        let buf = [0u8; HEADER_SIZE - 1];
        assert!(AvisReader::read_from(&mut &buf[..]).is_err());
    }

    #[test]
    fn test_truncated_payload() {
        let mut store = VisualMemoryStore::new(512);
        store.add(make_test_observation(0));
        let mut buf = encode(&store);
        // Drop the final payload byte so the declared length cannot be met.
        buf.pop();
        assert!(AvisReader::read_from(&mut &buf[..]).is_err());
    }

    #[test]
    fn test_file_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.avis");

        let mut store = VisualMemoryStore::new(512);
        store.add(make_test_observation(0));

        AvisWriter::write_to_file(&store, &path).unwrap();
        let loaded = AvisReader::read_from_file(&path).unwrap();
        assert_eq!(loaded.count(), 1);
    }
}