1use anyhow::{Context, Result};
2use rusqlite::Connection;
3use std::path::Path;
4
5use crate::db::{models::Snapshot, queries};
6
/// A single filesystem change that a restore would perform on one file.
///
/// Values of this type are produced by [`restore_to_event`] as a plan: that
/// function only checks whether each path exists and does not modify files.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RestoreAction {
    /// A file exists at `path` and the snapshot holds content for it, so the
    /// file would be replaced.
    Overwrite { path: String },
    /// No file exists at `path` but the snapshot holds content for it, so the
    /// file would be written anew.
    Create { path: String },
    /// A file exists at `path` but the snapshot holds no content for it, so
    /// the file would be removed.
    Delete { path: String },
}
13
14pub fn restore_to_event(
15 conn: &Connection,
16 session_id: &str,
17 event_id: i64,
18) -> Result<Vec<RestoreAction>> {
19 let snapshots = queries::get_restore_targets(conn, session_id, event_id)?;
22 let mut actions = Vec::new();
23
24 for snap in &snapshots {
25 let path = Path::new(&snap.file_path);
26 let file_exists = path.exists();
27
28 match &snap.content_after {
29 Some(_) => {
30 if file_exists {
32 actions.push(RestoreAction::Overwrite {
33 path: snap.file_path.clone(),
34 });
35 } else {
36 actions.push(RestoreAction::Create {
37 path: snap.file_path.clone(),
38 });
39 }
40 }
41 None => {
42 if file_exists {
44 actions.push(RestoreAction::Delete {
45 path: snap.file_path.clone(),
46 });
47 }
48 }
49 }
50 }
51
52 Ok(actions)
53}
54
55pub fn execute_restore(conn: &Connection, session_id: &str, event_id: i64) -> Result<()> {
56 let snapshots = queries::get_restore_targets(conn, session_id, event_id)?;
57
58 let now = chrono::Utc::now().timestamp_millis();
59 let checkpoint_event = crate::db::models::Event {
60 id: 0,
61 session_id: session_id.to_string(),
62 timestamp: now,
63 event_type: "RestoreCheckpoint".to_string(),
64 tool_name: None,
65 tool_use_id: None,
66 agent_id: None,
67 agent_type: None,
68 input_json: Some(format!(r#"{{"restored_to_event_id":{event_id}}}"#).into_bytes()),
69 output_json: None,
70 };
71 let checkpoint_id = queries::insert_event(conn, &checkpoint_event)?;
72
73 for snap in &snapshots {
74 let current_content = std::fs::read(&snap.file_path).ok();
75 let checkpoint_snap = Snapshot {
76 id: 0,
77 event_id: checkpoint_id,
78 file_path: snap.file_path.clone(),
79 content_before: current_content
80 .as_ref()
81 .map(|c| zstd::encode_all(c.as_slice(), 3))
82 .transpose()?,
83 content_after: current_content
84 .map(|c| zstd::encode_all(c.as_slice(), 3))
85 .transpose()?,
86 diff_unified: "(checkpoint)".to_string(),
87 };
88 queries::insert_snapshot(conn, &checkpoint_snap)?;
89 }
90
91 let mut temp_files: Vec<(String, std::path::PathBuf)> = Vec::new();
93 let mut files_to_delete: Vec<String> = Vec::new();
94
95 for snap in &snapshots {
96 if let Some(ref compressed) = snap.content_after {
97 let content = zstd::decode_all(compressed.as_slice())?;
98 let path = Path::new(&snap.file_path);
99 if let Some(parent) = path.parent() {
100 std::fs::create_dir_all(parent)?;
101 }
102 let temp_path = path.with_extension("chronicle_tmp");
103 std::fs::write(&temp_path, &content)?;
104 temp_files.push((snap.file_path.clone(), temp_path));
105 } else if Path::new(&snap.file_path).exists() {
106 files_to_delete.push(snap.file_path.clone());
107 }
108 }
109
110 for (target, temp) in &temp_files {
112 std::fs::rename(temp, target)
113 .with_context(|| format!("Failed to rename {} to {}", temp.display(), target))?;
114 }
115
116 for path in &files_to_delete {
118 std::fs::remove_file(path)
119 .with_context(|| format!("Failed to delete {path}"))?;
120 }
121
122 Ok(())
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::models::{Event, Session};
    use crate::db::{queries, schema};

    /// Removes the wrapped file when dropped, so the fixture is cleaned up
    /// even when an assertion panics partway through a test.
    struct TempFile(std::path::PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Opens an in-memory database, initializes the schema and inserts the
    /// session `"s1"` that the tests attach events to.
    fn setup_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        schema::initialize(&conn).unwrap();
        queries::upsert_session(
            &conn,
            &Session {
                id: "s1".into(),
                started_at: 1000,
                ended_at: None,
                cwd: "/tmp".into(),
                model: None,
                permission_mode: None,
            },
        )
        .unwrap();
        conn
    }

    /// Inserts a `PostToolUse` event for session `"s1"` and returns its id.
    fn insert_tool_event(
        conn: &Connection,
        timestamp: i64,
        tool_name: &str,
        tool_use_id: &str,
    ) -> i64 {
        queries::insert_event(
            conn,
            &Event {
                id: 0,
                session_id: "s1".into(),
                timestamp,
                event_type: "PostToolUse".into(),
                tool_name: Some(tool_name.into()),
                tool_use_id: Some(tool_use_id.into()),
                agent_id: None,
                agent_type: None,
                input_json: None,
                output_json: None,
            },
        )
        .unwrap()
    }

    #[test]
    fn test_restore_plan_deletes_file_created_after_target() {
        let conn = setup_db();

        let eid1 = insert_tool_event(&conn, 1001, "Read", "tu0");
        let eid2 = insert_tool_event(&conn, 1002, "Write", "tu1");

        // Include the process id so concurrent test runs sharing the system
        // temp directory do not fight over the same fixture file.
        let fixture = TempFile(std::env::temp_dir().join(format!(
            "chronicle_test_restore_delete_{}.rs",
            std::process::id()
        )));
        std::fs::write(&fixture.0, "fn main() {}").unwrap();
        let tmp_path_str = fixture.0.to_str().unwrap().to_string();

        // The file is first recorded (as newly created) by the second event.
        queries::insert_snapshot(
            &conn,
            &Snapshot {
                id: 0,
                event_id: eid2,
                file_path: tmp_path_str.clone(),
                content_before: None,
                content_after: Some(zstd::encode_all(b"fn main() {}".as_slice(), 3).unwrap()),
                diff_unified: "+fn main() {}".into(),
            },
        )
        .unwrap();

        // Restoring to the first event must plan to delete the later-created file.
        let actions = restore_to_event(&conn, "s1", eid1).unwrap();
        assert_eq!(
            actions,
            vec![RestoreAction::Delete {
                path: tmp_path_str.clone()
            }]
        );

        // Restoring to the event that created the file plans no changes.
        let actions2 = restore_to_event(&conn, "s1", eid2).unwrap();
        assert!(actions2.is_empty());
    }
}