1#![allow(clippy::unwrap_used)]
3mod engine;
40mod types;
41
42pub use engine::{
43 ActionDiff, ActionReplayResult, ReplayEngine, ReplayExecutionConfig,
44 ReplayResult as EngineReplayResult, SessionComparison,
45};
46pub use types::{
47 ActionType, Checkpoint, ComparisonReport, ComparisonSummary, Divergence, DivergenceType,
48 ModelResult, REPLAY_SEGMENT_MAGIC, REPLAY_SEGMENT_VERSION, ReplayAction, ReplayManifest,
49 ReplayOptions, ReplayResult, ReplaySession, SessionSummary, StateSnapshot,
50};
51
52use crate::MemvidError;
53use crate::error::Result;
54use uuid::Uuid;
55
/// Settings that control how an [`ActiveSession`] records actions.
///
/// NOTE: this struct is bincode-encoded (positionally, not self-describing) as part of
/// `ActiveSession` by `storage::serialize_active_session`, so adding, removing or
/// reordering fields changes the persisted layout.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct ReplayConfig {
    /// Number of actions after which `ActiveSession::should_checkpoint` reports `true`.
    /// `0` (the `Default`) disables automatic checkpoints.
    pub auto_checkpoint_interval: u64,
    /// Optional cap on actions per session.
    /// NOTE(review): not enforced by `ActiveSession::record_action` in this file — confirm
    /// where (or whether) it is checked.
    pub max_actions_per_session: Option<u64>,
    /// Presumably: whether recording starts automatically.
    /// NOTE(review): not read anywhere in this file; meaning inferred from the name — confirm.
    pub auto_record: bool,
}
66
/// A replay session that is still being recorded, together with its checkpoint bookkeeping.
///
/// NOTE: persisted with bincode by `storage::serialize_active_session`. That encoding is
/// positional, so field order and types are part of the stored format.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct ActiveSession {
    /// The session that accumulates recorded actions and checkpoints.
    pub session: ReplaySession,
    /// Id that the next `create_checkpoint` call will assign; starts at 0.
    pub next_checkpoint_id: u64,
    /// Actions recorded since the last checkpoint; reset to 0 by `create_checkpoint`.
    pub actions_since_checkpoint: u64,
    /// Recording configuration (auto-checkpoint interval, etc.).
    pub config: ReplayConfig,
}
79
80impl ActiveSession {
81 #[must_use]
83 pub fn new(name: Option<String>, config: ReplayConfig) -> Self {
84 Self {
85 session: ReplaySession::new(name),
86 next_checkpoint_id: 0,
87 actions_since_checkpoint: 0,
88 config,
89 }
90 }
91
92 pub fn record_action(&mut self, action: ReplayAction) {
94 self.session.add_action(action);
95 self.actions_since_checkpoint += 1;
96 }
97
98 #[must_use]
100 pub fn should_checkpoint(&self) -> bool {
101 self.config.auto_checkpoint_interval > 0
102 && self.actions_since_checkpoint >= self.config.auto_checkpoint_interval
103 }
104
105 pub fn create_checkpoint(&mut self, snapshot: StateSnapshot) -> Checkpoint {
107 let checkpoint = Checkpoint::new(
108 self.next_checkpoint_id,
109 self.session.next_sequence().saturating_sub(1),
110 snapshot,
111 );
112 self.session.add_checkpoint(checkpoint.clone());
113 self.next_checkpoint_id += 1;
114 self.actions_since_checkpoint = 0;
115 checkpoint
116 }
117
118 #[must_use]
120 pub fn end(mut self) -> ReplaySession {
121 self.session.end();
122 self.session
123 }
124
125 #[must_use]
127 pub fn session_id(&self) -> Uuid {
128 self.session.session_id
129 }
130}
131
132pub mod storage {
134 use super::{MemvidError, REPLAY_SEGMENT_MAGIC, REPLAY_SEGMENT_VERSION, ReplaySession, Result};
135 use bincode::config::{self, Config};
136 use std::io::{Read, Write};
137
138 fn bincode_config() -> impl Config {
139 config::standard()
140 .with_fixed_int_encoding()
141 .with_little_endian()
142 }
143
144 #[derive(Debug)]
146 pub struct ReplaySegmentHeader {
147 pub magic: [u8; 8],
148 pub version: u32,
149 pub session_count: u32,
150 pub total_size: u64,
151 }
152
153 impl ReplaySegmentHeader {
154 #[must_use]
156 pub fn new(session_count: u32, total_size: u64) -> Self {
157 Self {
158 magic: *REPLAY_SEGMENT_MAGIC,
159 version: REPLAY_SEGMENT_VERSION,
160 session_count,
161 total_size,
162 }
163 }
164
165 pub fn write<W: Write>(&self, writer: &mut W) -> Result<()> {
167 writer.write_all(&self.magic)?;
168 writer.write_all(&self.version.to_le_bytes())?;
169 writer.write_all(&self.session_count.to_le_bytes())?;
170 writer.write_all(&self.total_size.to_le_bytes())?;
171 Ok(())
172 }
173
174 pub fn read<R: Read>(reader: &mut R) -> Result<Self> {
176 let mut magic = [0u8; 8];
177 reader.read_exact(&mut magic)?;
178 if &magic != REPLAY_SEGMENT_MAGIC {
179 return Err(MemvidError::InvalidToc {
180 reason: "Invalid replay segment magic".into(),
181 });
182 }
183
184 let mut version_bytes = [0u8; 4];
185 reader.read_exact(&mut version_bytes)?;
186 let version = u32::from_le_bytes(version_bytes);
187
188 let mut session_count_bytes = [0u8; 4];
189 reader.read_exact(&mut session_count_bytes)?;
190 let session_count = u32::from_le_bytes(session_count_bytes);
191
192 let mut total_size_bytes = [0u8; 8];
193 reader.read_exact(&mut total_size_bytes)?;
194 let total_size = u64::from_le_bytes(total_size_bytes);
195
196 Ok(Self {
197 magic,
198 version,
199 session_count,
200 total_size,
201 })
202 }
203
204 pub const SIZE: usize = 8 + 4 + 4 + 8; }
207
208 pub fn serialize_session(session: &ReplaySession) -> Result<Vec<u8>> {
210 bincode::serde::encode_to_vec(session, bincode_config()).map_err(|e| {
211 MemvidError::InvalidToc {
212 reason: format!("Failed to serialize replay session: {e}").into(),
213 }
214 })
215 }
216
217 pub fn deserialize_session(data: &[u8]) -> Result<ReplaySession> {
219 bincode::serde::decode_from_slice(data, bincode_config())
220 .map(|(session, _)| session)
221 .map_err(|e| MemvidError::InvalidToc {
222 reason: format!("Failed to deserialize replay session: {e}").into(),
223 })
224 }
225
226 pub fn build_segment(sessions: &[ReplaySession]) -> Result<Vec<u8>> {
228 let mut session_data: Vec<Vec<u8>> = Vec::with_capacity(sessions.len());
229 let mut total_session_bytes = 0u64;
230
231 for session in sessions {
232 let data = serialize_session(session)?;
233 total_session_bytes += data.len() as u64 + 8; session_data.push(data);
235 }
236
237 let header = ReplaySegmentHeader::new(
238 u32::try_from(sessions.len()).unwrap_or(u32::MAX),
239 ReplaySegmentHeader::SIZE as u64 + total_session_bytes,
240 );
241
242 let mut segment = Vec::with_capacity(usize::try_from(header.total_size).unwrap_or(0));
243 header.write(&mut segment)?;
244
245 for data in session_data {
247 segment.extend_from_slice(&(data.len() as u64).to_le_bytes());
248 segment.extend_from_slice(&data);
249 }
250
251 Ok(segment)
252 }
253
254 pub fn read_segment(data: &[u8]) -> Result<Vec<ReplaySession>> {
256 let mut cursor = std::io::Cursor::new(data);
257 let header = ReplaySegmentHeader::read(&mut cursor)?;
258
259 let mut sessions = Vec::with_capacity(header.session_count as usize);
260 for _ in 0..header.session_count {
261 let mut len_bytes = [0u8; 8];
262 cursor.read_exact(&mut len_bytes)?;
263 let len = usize::try_from(u64::from_le_bytes(len_bytes)).unwrap_or(0);
264
265 let mut session_data = vec![0u8; len];
266 cursor.read_exact(&mut session_data)?;
267
268 let session = deserialize_session(&session_data)?;
269 sessions.push(session);
270 }
271
272 Ok(sessions)
273 }
274
275 pub const ACTIVE_SESSION_MAGIC: &[u8; 8] = b"MV2ACTIV";
277
278 pub fn serialize_active_session(session: &super::ActiveSession) -> Result<Vec<u8>> {
280 let mut data = Vec::new();
281 data.extend_from_slice(ACTIVE_SESSION_MAGIC);
282 let session_bytes =
283 bincode::serde::encode_to_vec(session, bincode_config()).map_err(|e| {
284 MemvidError::InvalidToc {
285 reason: format!("Failed to serialize active session: {e}").into(),
286 }
287 })?;
288 data.extend_from_slice(&(session_bytes.len() as u64).to_le_bytes());
289 data.extend_from_slice(&session_bytes);
290 Ok(data)
291 }
292
293 pub fn deserialize_active_session(data: &[u8]) -> Result<super::ActiveSession> {
295 if data.len() < 16 {
296 return Err(MemvidError::InvalidToc {
297 reason: "Active session data too short".into(),
298 });
299 }
300 if &data[0..8] != ACTIVE_SESSION_MAGIC {
301 return Err(MemvidError::InvalidToc {
302 reason: "Invalid active session magic".into(),
303 });
304 }
305 let len = usize::try_from(u64::from_le_bytes(data[8..16].try_into().unwrap())).unwrap_or(0);
306 if data.len() < 16 + len {
307 return Err(MemvidError::InvalidToc {
308 reason: "Active session data truncated".into(),
309 });
310 }
311 bincode::serde::decode_from_slice(&data[16..16 + len], bincode_config())
312 .map(|(session, _)| session)
313 .map_err(|e| MemvidError::InvalidToc {
314 reason: format!("Failed to deserialize active session: {e}").into(),
315 })
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an ended, named session holding a single `Put` action.
    fn ended_session(name: &str) -> ReplaySession {
        let mut session = ReplaySession::new(Some(name.to_string()));
        session.add_action(ReplayAction::new(0, ActionType::Put { frame_id: 1 }));
        session.end();
        session
    }

    #[test]
    fn test_active_session() {
        let mut active = ActiveSession::new(
            Some("Test".to_string()),
            ReplayConfig {
                auto_checkpoint_interval: 2,
                ..Default::default()
            },
        );

        assert!(!active.should_checkpoint());

        active.record_action(ReplayAction::new(0, ActionType::Put { frame_id: 1 }));
        assert!(!active.should_checkpoint());

        active.record_action(ReplayAction::new(1, ActionType::Put { frame_id: 2 }));
        assert!(active.should_checkpoint());

        let checkpoint = active.create_checkpoint(StateSnapshot::default());
        assert_eq!(checkpoint.id, 0);
        assert!(!active.should_checkpoint());

        let session = active.end();
        assert!(!session.is_recording());
        assert_eq!(session.actions.len(), 2);
        assert_eq!(session.checkpoints.len(), 1);
    }

    #[test]
    fn test_segment_roundtrip() {
        let session1 = ended_session("Session 1");

        let mut session2 = ReplaySession::new(Some("Session 2".to_string()));
        session2.add_action(ReplayAction::new(
            0,
            ActionType::Find {
                query: "test".into(),
                mode: "lexical".into(),
                result_count: 5,
            },
        ));
        session2.end();

        let segment = storage::build_segment(&[session1.clone(), session2.clone()]).unwrap();
        let restored = storage::read_segment(&segment).unwrap();

        assert_eq!(restored.len(), 2);
        assert_eq!(restored[0].session_id, session1.session_id);
        assert_eq!(restored[1].session_id, session2.session_id);
    }

    #[test]
    fn test_empty_segment_roundtrip() {
        let segment = storage::build_segment(&[]).unwrap();
        assert_eq!(segment.len(), storage::ReplaySegmentHeader::SIZE);
        assert!(storage::read_segment(&segment).unwrap().is_empty());
    }

    #[test]
    fn test_read_segment_rejects_bad_magic() {
        let mut segment = storage::build_segment(&[ended_session("Magic")]).unwrap();
        segment[0] ^= 0xFF;
        assert!(storage::read_segment(&segment).is_err());
    }

    #[test]
    fn test_read_segment_rejects_truncated_data() {
        let segment = storage::build_segment(&[ended_session("Truncated")]).unwrap();
        assert!(storage::read_segment(&segment[..segment.len() - 1]).is_err());
        assert!(storage::read_segment(&segment[..storage::ReplaySegmentHeader::SIZE]).is_err());
    }

    #[test]
    fn test_active_session_serialization_roundtrip() {
        let mut active = ActiveSession::new(
            Some("Active".to_string()),
            ReplayConfig {
                auto_checkpoint_interval: 3,
                max_actions_per_session: Some(10),
                auto_record: true,
            },
        );
        active.record_action(ReplayAction::new(0, ActionType::Put { frame_id: 7 }));

        let bytes = storage::serialize_active_session(&active).unwrap();
        assert_eq!(&bytes[..8], storage::ACTIVE_SESSION_MAGIC);

        let restored = storage::deserialize_active_session(&bytes).unwrap();
        assert_eq!(restored.session_id(), active.session_id());
        assert_eq!(restored.next_checkpoint_id, 0);
        assert_eq!(restored.actions_since_checkpoint, 1);
        assert_eq!(restored.config.auto_checkpoint_interval, 3);
        assert_eq!(restored.config.max_actions_per_session, Some(10));
        assert!(restored.config.auto_record);
        assert_eq!(restored.session.actions.len(), 1);
    }

    #[test]
    fn test_deserialize_active_session_rejects_malformed_input() {
        let active = ActiveSession::new(Some("Malformed".to_string()), ReplayConfig::default());
        let bytes = storage::serialize_active_session(&active).unwrap();

        assert!(storage::deserialize_active_session(&[]).is_err());
        assert!(storage::deserialize_active_session(&bytes[..bytes.len() - 1]).is_err());

        let mut bad_magic = bytes.clone();
        bad_magic[0] ^= 0xFF;
        assert!(storage::deserialize_active_session(&bad_magic).is_err());
    }
}