1mod engine;
38mod types;
39
40pub use engine::{
41 ActionDiff, ActionReplayResult, ReplayEngine, ReplayExecutionConfig, ReplayResult as EngineReplayResult,
42 SessionComparison,
43};
44pub use types::{
45 ActionType, Checkpoint, ComparisonReport, ComparisonSummary, Divergence, DivergenceType,
46 ModelResult, ReplayAction, ReplayManifest, ReplayOptions, ReplayResult, ReplaySession,
47 SessionSummary, StateSnapshot, REPLAY_SEGMENT_MAGIC, REPLAY_SEGMENT_VERSION,
48};
49
50use crate::error::Result;
51use crate::MemvidError;
52use uuid::Uuid;
53
/// Tunables for an [`ActiveSession`] recording.
///
/// The derived `Default` yields `auto_checkpoint_interval == 0` (automatic
/// checkpoints disabled), `max_actions_per_session == None`, and
/// `auto_record == false`.
///
/// This struct is embedded in [`ActiveSession`], which is persisted with
/// bincode, so the field order below is part of that encoding.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct ReplayConfig {
    /// Number of recorded actions after which [`ActiveSession::should_checkpoint`]
    /// reports `true`; `0` disables automatic checkpointing.
    pub auto_checkpoint_interval: u64,
    /// Optional cap on the number of actions per session.
    /// NOTE(review): not consulted anywhere in this module — presumably
    /// enforced by callers; confirm.
    pub max_actions_per_session: Option<u64>,
    /// Flag presumably controlling whether recording starts automatically.
    /// NOTE(review): not read in this module — confirm semantics with callers.
    pub auto_record: bool,
}
64
/// A replay session that is currently being recorded, together with the
/// bookkeeping used to decide when to take automatic checkpoints.
///
/// Persisted with bincode by `storage::serialize_active_session`, so the field
/// order below is part of that encoding.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct ActiveSession {
    /// The session being recorded; actions and checkpoints are appended to it.
    pub session: ReplaySession,
    /// Id given to the next checkpoint created by `create_checkpoint`; starts at 0.
    pub next_checkpoint_id: u64,
    /// Actions recorded since the last checkpoint; reset to 0 when one is created.
    pub actions_since_checkpoint: u64,
    /// Recording configuration (checkpoint interval and related settings).
    pub config: ReplayConfig,
}
77
78impl ActiveSession {
79 #[must_use]
81 pub fn new(name: Option<String>, config: ReplayConfig) -> Self {
82 Self {
83 session: ReplaySession::new(name),
84 next_checkpoint_id: 0,
85 actions_since_checkpoint: 0,
86 config,
87 }
88 }
89
90 pub fn record_action(&mut self, action: ReplayAction) {
92 self.session.add_action(action);
93 self.actions_since_checkpoint += 1;
94 }
95
96 #[must_use]
98 pub fn should_checkpoint(&self) -> bool {
99 self.config.auto_checkpoint_interval > 0
100 && self.actions_since_checkpoint >= self.config.auto_checkpoint_interval
101 }
102
103 pub fn create_checkpoint(&mut self, snapshot: StateSnapshot) -> Checkpoint {
105 let checkpoint = Checkpoint::new(
106 self.next_checkpoint_id,
107 self.session.next_sequence().saturating_sub(1),
108 snapshot,
109 );
110 self.session.add_checkpoint(checkpoint.clone());
111 self.next_checkpoint_id += 1;
112 self.actions_since_checkpoint = 0;
113 checkpoint
114 }
115
116 pub fn end(mut self) -> ReplaySession {
118 self.session.end();
119 self.session
120 }
121
122 #[must_use]
124 pub fn session_id(&self) -> Uuid {
125 self.session.session_id
126 }
127}
128
129pub mod storage {
131 use super::*;
132 use bincode::config::{self, Config};
133 use std::io::{Read, Write};
134
135 fn bincode_config() -> impl Config {
136 config::standard()
137 .with_fixed_int_encoding()
138 .with_little_endian()
139 }
140
    /// Fixed-size header at the start of a replay segment produced by
    /// `build_segment`.
    ///
    /// Encoded little-endian as magic, version, session count, total size
    /// (`ReplaySegmentHeader::SIZE` bytes in all).
    #[derive(Debug)]
    pub struct ReplaySegmentHeader {
        /// Segment marker; `read` rejects anything other than `REPLAY_SEGMENT_MAGIC`.
        pub magic: [u8; 8],
        /// Format version written by `new` (`REPLAY_SEGMENT_VERSION`); not
        /// validated by `read`.
        pub version: u32,
        /// Number of length-prefixed session records following the header.
        pub session_count: u32,
        /// Total segment length in bytes, header included (as computed by
        /// `build_segment`).
        pub total_size: u64,
    }
149
150 impl ReplaySegmentHeader {
151 #[must_use]
153 pub fn new(session_count: u32, total_size: u64) -> Self {
154 Self {
155 magic: *REPLAY_SEGMENT_MAGIC,
156 version: REPLAY_SEGMENT_VERSION,
157 session_count,
158 total_size,
159 }
160 }
161
162 pub fn write<W: Write>(&self, writer: &mut W) -> Result<()> {
164 writer.write_all(&self.magic)?;
165 writer.write_all(&self.version.to_le_bytes())?;
166 writer.write_all(&self.session_count.to_le_bytes())?;
167 writer.write_all(&self.total_size.to_le_bytes())?;
168 Ok(())
169 }
170
171 pub fn read<R: Read>(reader: &mut R) -> Result<Self> {
173 let mut magic = [0u8; 8];
174 reader.read_exact(&mut magic)?;
175 if &magic != REPLAY_SEGMENT_MAGIC {
176 return Err(MemvidError::InvalidToc {
177 reason: "Invalid replay segment magic".into(),
178 });
179 }
180
181 let mut version_bytes = [0u8; 4];
182 reader.read_exact(&mut version_bytes)?;
183 let version = u32::from_le_bytes(version_bytes);
184
185 let mut session_count_bytes = [0u8; 4];
186 reader.read_exact(&mut session_count_bytes)?;
187 let session_count = u32::from_le_bytes(session_count_bytes);
188
189 let mut total_size_bytes = [0u8; 8];
190 reader.read_exact(&mut total_size_bytes)?;
191 let total_size = u64::from_le_bytes(total_size_bytes);
192
193 Ok(Self {
194 magic,
195 version,
196 session_count,
197 total_size,
198 })
199 }
200
201 pub const SIZE: usize = 8 + 4 + 4 + 8; }
204
205 pub fn serialize_session(session: &ReplaySession) -> Result<Vec<u8>> {
207 bincode::serde::encode_to_vec(session, bincode_config())
208 .map_err(|e| MemvidError::InvalidToc {
209 reason: format!("Failed to serialize replay session: {}", e).into(),
210 })
211 }
212
213 pub fn deserialize_session(data: &[u8]) -> Result<ReplaySession> {
215 bincode::serde::decode_from_slice(data, bincode_config())
216 .map(|(session, _)| session)
217 .map_err(|e| MemvidError::InvalidToc {
218 reason: format!("Failed to deserialize replay session: {}", e).into(),
219 })
220 }
221
222 pub fn build_segment(sessions: &[ReplaySession]) -> Result<Vec<u8>> {
224 let mut session_data: Vec<Vec<u8>> = Vec::with_capacity(sessions.len());
225 let mut total_session_bytes = 0u64;
226
227 for session in sessions {
228 let data = serialize_session(session)?;
229 total_session_bytes += data.len() as u64 + 8; session_data.push(data);
231 }
232
233 let header = ReplaySegmentHeader::new(
234 sessions.len() as u32,
235 ReplaySegmentHeader::SIZE as u64 + total_session_bytes,
236 );
237
238 let mut segment = Vec::with_capacity(header.total_size as usize);
239 header.write(&mut segment)?;
240
241 for data in session_data {
243 segment.extend_from_slice(&(data.len() as u64).to_le_bytes());
244 segment.extend_from_slice(&data);
245 }
246
247 Ok(segment)
248 }
249
250 pub fn read_segment(data: &[u8]) -> Result<Vec<ReplaySession>> {
252 let mut cursor = std::io::Cursor::new(data);
253 let header = ReplaySegmentHeader::read(&mut cursor)?;
254
255 let mut sessions = Vec::with_capacity(header.session_count as usize);
256 for _ in 0..header.session_count {
257 let mut len_bytes = [0u8; 8];
258 cursor.read_exact(&mut len_bytes)?;
259 let len = u64::from_le_bytes(len_bytes) as usize;
260
261 let mut session_data = vec![0u8; len];
262 cursor.read_exact(&mut session_data)?;
263
264 let session = deserialize_session(&session_data)?;
265 sessions.push(session);
266 }
267
268 Ok(sessions)
269 }
270
    /// 8-byte marker prefixed to data produced by [`serialize_active_session`]
    /// and checked by [`deserialize_active_session`].
    pub const ACTIVE_SESSION_MAGIC: &[u8; 8] = b"MV2ACTIV";
273
274 pub fn serialize_active_session(session: &super::ActiveSession) -> Result<Vec<u8>> {
276 let mut data = Vec::new();
277 data.extend_from_slice(ACTIVE_SESSION_MAGIC);
278 let session_bytes = bincode::serde::encode_to_vec(session, bincode_config())
279 .map_err(|e| MemvidError::InvalidToc {
280 reason: format!("Failed to serialize active session: {}", e).into(),
281 })?;
282 data.extend_from_slice(&(session_bytes.len() as u64).to_le_bytes());
283 data.extend_from_slice(&session_bytes);
284 Ok(data)
285 }
286
287 pub fn deserialize_active_session(data: &[u8]) -> Result<super::ActiveSession> {
289 if data.len() < 16 {
290 return Err(MemvidError::InvalidToc {
291 reason: "Active session data too short".into(),
292 });
293 }
294 if &data[0..8] != ACTIVE_SESSION_MAGIC {
295 return Err(MemvidError::InvalidToc {
296 reason: "Invalid active session magic".into(),
297 });
298 }
299 let len = u64::from_le_bytes(data[8..16].try_into().unwrap()) as usize;
300 if data.len() < 16 + len {
301 return Err(MemvidError::InvalidToc {
302 reason: "Active session data truncated".into(),
303 });
304 }
305 bincode::serde::decode_from_slice(&data[16..16 + len], bincode_config())
306 .map(|(session, _)| session)
307 .map_err(|e| MemvidError::InvalidToc {
308 reason: format!("Failed to deserialize active session: {}", e).into(),
309 })
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// With an interval of 2, the checkpoint trigger must fire exactly after the
    /// second recorded action and clear once a checkpoint is taken.
    #[test]
    fn test_active_session() {
        let config = ReplayConfig {
            auto_checkpoint_interval: 2,
            ..Default::default()
        };
        let mut recorder = ActiveSession::new(Some("Test".to_string()), config);
        assert!(!recorder.should_checkpoint());

        for (seq, frame_id, expect_checkpoint) in [(0, 1, false), (1, 2, true)] {
            recorder.record_action(ReplayAction::new(seq, ActionType::Put { frame_id }));
            assert_eq!(recorder.should_checkpoint(), expect_checkpoint);
        }

        let checkpoint = recorder.create_checkpoint(StateSnapshot::default());
        assert_eq!(checkpoint.id, 0);
        assert!(!recorder.should_checkpoint());

        let finished = recorder.end();
        assert!(!finished.is_recording());
        assert_eq!(finished.actions.len(), 2);
        assert_eq!(finished.checkpoints.len(), 1);
    }

    /// Two sessions written into a segment must come back in order with their
    /// ids intact.
    #[test]
    fn test_segment_roundtrip() {
        let mut put_session = ReplaySession::new(Some("Session 1".to_string()));
        put_session.add_action(ReplayAction::new(0, ActionType::Put { frame_id: 1 }));
        put_session.end();

        let mut find_session = ReplaySession::new(Some("Session 2".to_string()));
        let find = ActionType::Find {
            query: "test".into(),
            mode: "lexical".into(),
            result_count: 5,
        };
        find_session.add_action(ReplayAction::new(0, find));
        find_session.end();

        let expected_ids = [put_session.session_id, find_session.session_id];
        let segment = storage::build_segment(&[put_session, find_session]).unwrap();
        let restored = storage::read_segment(&segment).unwrap();

        assert_eq!(restored.len(), 2);
        let restored_ids: Vec<_> = restored.iter().map(|s| s.session_id).collect();
        assert_eq!(restored_ids, expected_ids);
    }
}