Skip to main content

chronicle/restore/mod.rs

1use anyhow::{Context, Result};
2use rusqlite::Connection;
3use std::path::Path;
4
5use crate::db::{models::Snapshot, queries};
6
/// A single filesystem change needed to bring one file back to its state at a
/// target event, as planned by [`restore_to_event`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RestoreAction {
    /// The file exists on disk and should be replaced with the snapshot contents.
    Overwrite { path: String },
    /// The file is missing on disk and should be created from the snapshot contents.
    Create { path: String },
    /// The file exists on disk but did not exist at the target event.
    Delete { path: String },
}
13
14pub fn restore_to_event(
15    conn: &Connection,
16    session_id: &str,
17    event_id: i64,
18) -> Result<Vec<RestoreAction>> {
19    // Get each file modified AFTER the target event, along with its desired
20    // state at the target event. content_after = desired state (None = should not exist).
21    let snapshots = queries::get_restore_targets(conn, session_id, event_id)?;
22    let mut actions = Vec::new();
23
24    for snap in &snapshots {
25        let path = Path::new(&snap.file_path);
26        let file_exists = path.exists();
27
28        match &snap.content_after {
29            Some(_) => {
30                // File should exist with this content at the target event
31                if file_exists {
32                    actions.push(RestoreAction::Overwrite {
33                        path: snap.file_path.clone(),
34                    });
35                } else {
36                    actions.push(RestoreAction::Create {
37                        path: snap.file_path.clone(),
38                    });
39                }
40            }
41            None => {
42                // File should not exist at the target event (created after it)
43                if file_exists {
44                    actions.push(RestoreAction::Delete {
45                        path: snap.file_path.clone(),
46                    });
47                }
48            }
49        }
50    }
51
52    Ok(actions)
53}
54
55pub fn execute_restore(conn: &Connection, session_id: &str, event_id: i64) -> Result<()> {
56    let snapshots = queries::get_restore_targets(conn, session_id, event_id)?;
57
58    let now = chrono::Utc::now().timestamp_millis();
59    let checkpoint_event = crate::db::models::Event {
60        id: 0,
61        session_id: session_id.to_string(),
62        timestamp: now,
63        event_type: "RestoreCheckpoint".to_string(),
64        tool_name: None,
65        tool_use_id: None,
66        agent_id: None,
67        agent_type: None,
68        input_json: Some(format!(r#"{{"restored_to_event_id":{event_id}}}"#).into_bytes()),
69        output_json: None,
70    };
71    let checkpoint_id = queries::insert_event(conn, &checkpoint_event)?;
72
73    for snap in &snapshots {
74        let current_content = std::fs::read(&snap.file_path).ok();
75        let checkpoint_snap = Snapshot {
76            id: 0,
77            event_id: checkpoint_id,
78            file_path: snap.file_path.clone(),
79            content_before: current_content
80                .as_ref()
81                .map(|c| zstd::encode_all(c.as_slice(), 3))
82                .transpose()?,
83            content_after: current_content
84                .map(|c| zstd::encode_all(c.as_slice(), 3))
85                .transpose()?,
86            diff_unified: "(checkpoint)".to_string(),
87        };
88        queries::insert_snapshot(conn, &checkpoint_snap)?;
89    }
90
91    // Phase 1: Write all restored files to temp paths
92    let mut temp_files: Vec<(String, std::path::PathBuf)> = Vec::new();
93    let mut files_to_delete: Vec<String> = Vec::new();
94
95    for snap in &snapshots {
96        if let Some(ref compressed) = snap.content_after {
97            let content = zstd::decode_all(compressed.as_slice())?;
98            let path = Path::new(&snap.file_path);
99            if let Some(parent) = path.parent() {
100                std::fs::create_dir_all(parent)?;
101            }
102            let temp_path = path.with_extension("chronicle_tmp");
103            std::fs::write(&temp_path, &content)?;
104            temp_files.push((snap.file_path.clone(), temp_path));
105        } else if Path::new(&snap.file_path).exists() {
106            files_to_delete.push(snap.file_path.clone());
107        }
108    }
109
110    // Phase 2: Rename all temp files into place
111    for (target, temp) in &temp_files {
112        std::fs::rename(temp, target)
113            .with_context(|| format!("Failed to rename {} to {}", temp.display(), target))?;
114    }
115
116    // Phase 3: Delete files that should no longer exist
117    for path in &files_to_delete {
118        std::fs::remove_file(path)
119            .with_context(|| format!("Failed to delete {path}"))?;
120    }
121
122    Ok(())
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::models::{Event, Session};
    use crate::db::{queries, schema};

    /// Deletes the wrapped path when dropped, so temp files are cleaned up even
    /// when an assertion panics partway through a test.
    struct RemoveOnDrop(std::path::PathBuf);

    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            // Best-effort: the file may already be gone.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Opens an in-memory database with the schema applied and one session `s1`.
    fn setup_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        schema::initialize(&conn).unwrap();
        queries::upsert_session(
            &conn,
            &Session {
                id: "s1".into(),
                started_at: 1000,
                ended_at: None,
                cwd: "/tmp".into(),
                model: None,
                permission_mode: None,
            },
        )
        .unwrap();
        conn
    }

    #[test]
    fn test_restore_plan_deletes_file_created_after_target() {
        let conn = setup_db();

        // Event 1: some baseline event (no snapshots)
        let eid1 = queries::insert_event(
            &conn,
            &Event {
                id: 0,
                session_id: "s1".into(),
                timestamp: 1001,
                event_type: "PostToolUse".into(),
                tool_name: Some("Read".into()),
                tool_use_id: Some("tu0".into()),
                agent_id: None,
                agent_type: None,
                input_json: None,
                output_json: None,
            },
        )
        .unwrap();

        // Event 2: creates a new file
        let eid2 = queries::insert_event(
            &conn,
            &Event {
                id: 0,
                session_id: "s1".into(),
                timestamp: 1002,
                event_type: "PostToolUse".into(),
                tool_name: Some("Write".into()),
                tool_use_id: Some("tu1".into()),
                agent_id: None,
                agent_type: None,
                input_json: None,
                output_json: None,
            },
        )
        .unwrap();

        // Create a file to simulate it existing on disk. The process id keeps
        // concurrent test runs from racing on the same path.
        let tmp_path = std::env::temp_dir().join(format!(
            "chronicle_test_restore_delete_{}.rs",
            std::process::id()
        ));
        std::fs::write(&tmp_path, "fn main() {}").unwrap();
        let _cleanup = RemoveOnDrop(tmp_path.clone());
        let tmp_path_str = tmp_path.to_str().unwrap().to_string();

        queries::insert_snapshot(
            &conn,
            &Snapshot {
                id: 0,
                event_id: eid2,
                file_path: tmp_path_str.clone(),
                content_before: None,
                content_after: Some(zstd::encode_all(b"fn main() {}".as_slice(), 3).unwrap()),
                diff_unified: "+fn main() {}".into(),
            },
        )
        .unwrap();

        // Restoring to event 1 (before the file was created) should delete it
        let actions = restore_to_event(&conn, "s1", eid1).unwrap();
        assert_eq!(actions.len(), 1);
        assert!(matches!(
            &actions[0],
            RestoreAction::Delete { path } if path == &tmp_path_str
        ));

        // Restoring to event 2 (when the file was created) should have no actions
        let actions2 = restore_to_event(&conn, "s1", eid2).unwrap();
        assert!(actions2.is_empty());
    }
}