Skip to main content

rab/agent/session/
repo.rs

1use super::model::{
2    SessionInfo, delete_session as delete_session_file, fork_session, load_session_info,
3};
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::sync::atomic::{AtomicUsize, Ordering};
7use std::sync::mpsc;
8
/// Upper bound on the worker threads used to load session files in parallel
/// (kept at the same limit as pi).
const MAX_CONCURRENT_LOADS: usize = 10;
11
12/// Session lifecycle management: create, open, list, delete, fork.
13///
14/// Default implementation uses JSONL files on disk.
15pub trait SessionRepo {
16    /// List sessions in a directory, optionally filtered by cwd.
17    /// `progress` receives `(loaded_count, total_count)` for UI updates.
18    fn list(
19        &self,
20        session_dir: &Path,
21        filter_cwd: Option<&Path>,
22        progress: Option<&dyn Fn(usize, usize)>,
23    ) -> Vec<SessionInfo>;
24
25    /// List sessions across all project directories under `~/.rab/sessions/`.
26    fn list_all(&self, progress: Option<&dyn Fn(usize, usize)>) -> Vec<SessionInfo>;
27
28    /// Delete a session file.
29    fn delete(&self, path: &Path) -> std::io::Result<()>;
30
31    /// Fork a session: create a new session file containing entries up to (and including)
32    /// the given entry_id, or all entries if entry_id is None.
33    fn fork(
34        &self,
35        source_path: &Path,
36        target_dir: &Path,
37        entry_id: Option<&str>,
38        position: Option<&str>,
39    ) -> std::io::Result<String>;
40
41    /// Load metadata for a single session file.
42    fn load_info(&self, path: &Path) -> Option<SessionInfo>;
43}
44
45// ── Default JSONL-based repo ───────────────────────────────────────
46
/// Stateless [`SessionRepo`] implementation backed by JSONL session files on disk.
///
/// The type carries no data; each trait method simply forwards to free functions
/// that read or write the session files.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultSessionRepo;

impl DefaultSessionRepo {
    /// Create a repo. Equivalent to [`DefaultSessionRepo::default`].
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
61
62impl SessionRepo for DefaultSessionRepo {
63    fn list(
64        &self,
65        session_dir: &Path,
66        filter_cwd: Option<&Path>,
67        progress: Option<&dyn Fn(usize, usize)>,
68    ) -> Vec<SessionInfo> {
69        list_sessions(session_dir, filter_cwd, progress)
70    }
71
72    fn list_all(&self, progress: Option<&dyn Fn(usize, usize)>) -> Vec<SessionInfo> {
73        let dir = directories::BaseDirs::new()
74            .map(|d| d.home_dir().join(".rab").join("sessions"))
75            .unwrap_or_else(|| PathBuf::from("/tmp/.rab/sessions"));
76
77        let mut all_sessions: Vec<SessionInfo> = Vec::new();
78
79        // Collect all session dirs + root
80        let mut dirs = vec![dir.clone()];
81        if let Ok(read_dir) = std::fs::read_dir(&dir) {
82            for entry in read_dir.flatten() {
83                let path = entry.path();
84                if path.is_dir() {
85                    dirs.push(path);
86                }
87            }
88        }
89
90        let total_dirs = dirs.len();
91        let mut loaded = 0;
92
93        for session_dir in &dirs {
94            let sessions = list_sessions_concurrent(session_dir, None);
95            loaded += 1;
96            if let Some(ref cb) = progress {
97                cb(loaded, total_dirs);
98            }
99            all_sessions.extend(sessions);
100        }
101
102        all_sessions.sort_by_key(|b| std::cmp::Reverse(b.created));
103        all_sessions
104    }
105
106    fn delete(&self, path: &Path) -> std::io::Result<()> {
107        delete_session_file(path)
108    }
109
110    fn fork(
111        &self,
112        source_path: &Path,
113        target_dir: &Path,
114        entry_id: Option<&str>,
115        position: Option<&str>,
116    ) -> std::io::Result<String> {
117        fork_session(source_path, target_dir, entry_id, position)
118    }
119
120    fn load_info(&self, path: &Path) -> Option<SessionInfo> {
121        load_session_info(path)
122    }
123}
124
125// ── Sequential listing (used by `list`) ────────────────────────────
126
127/// List session files sequentially with optional cwd filtering and progress callback.
128/// Uses the public `list_sessions` from `session.rs` for the core listing.
129fn list_sessions(
130    session_dir: &Path,
131    filter_cwd: Option<&Path>,
132    progress: Option<&dyn Fn(usize, usize)>,
133) -> Vec<SessionInfo> {
134    let sessions = crate::agent::session::list_sessions(session_dir);
135    let total = sessions.len();
136    let mut loaded = 0;
137
138    let filtered: Vec<SessionInfo> = sessions
139        .into_iter()
140        .filter(|s| {
141            loaded += 1;
142            if let Some(ref cb) = progress {
143                cb(loaded, total);
144            }
145            if let Some(filter) = filter_cwd {
146                s.cwd == filter.to_string_lossy().as_ref()
147            } else {
148                true
149            }
150        })
151        .collect();
152
153    filtered
154}
155
156// ── Concurrent listing (used by `list_all`) ────────────────────────
157
158/// List session files with concurrent loading (pi-compatible: up to 10 workers).
159/// Uses a channel to collect results; the calling thread gathers them.
160fn list_sessions_concurrent(session_dir: &Path, filter_cwd: Option<&Path>) -> Vec<SessionInfo> {
161    let dir = match std::fs::read_dir(session_dir) {
162        Ok(d) => d,
163        Err(_) => return vec![],
164    };
165
166    let file_paths: Vec<PathBuf> = dir
167        .flatten()
168        .filter(|e| e.path().extension().is_some_and(|ext| ext == "jsonl"))
169        .map(|e| e.path())
170        .collect();
171
172    let total = file_paths.len();
173    if total == 0 {
174        return vec![];
175    }
176
177    // For a single file, avoid threading overhead
178    if total == 1 {
179        let mut sessions = Vec::new();
180        if let Some(info) = load_session_info(&file_paths[0]) {
181            sessions.push(info);
182        }
183        return sessions;
184    }
185
186    let (tx, rx) = mpsc::channel::<Option<SessionInfo>>();
187    let next_index = Arc::new(AtomicUsize::new(0));
188    let filter_cwd_owned = Arc::new(filter_cwd.map(|p| p.to_path_buf()));
189    let file_paths = Arc::new(file_paths);
190
191    let worker_count = MAX_CONCURRENT_LOADS.min(total);
192
193    std::thread::scope(|scope| {
194        for _ in 0..worker_count {
195            let tx = tx.clone();
196            let next_index = Arc::clone(&next_index);
197            let filter_cwd_owned = Arc::clone(&filter_cwd_owned);
198            let file_paths = Arc::clone(&file_paths);
199
200            scope.spawn(move || {
201                loop {
202                    let idx = next_index.fetch_add(1, Ordering::Relaxed);
203                    if idx >= total {
204                        break;
205                    }
206
207                    let path = &file_paths[idx];
208
209                    // Quick cwd filter check
210                    let header = crate::agent::session::read_session_header(path);
211                    if let Some(ref h) = header
212                        && let Some(ref filter) = *filter_cwd_owned
213                        && h.cwd != filter.to_string_lossy().as_ref()
214                    {
215                        let _ = tx.send(None);
216                        continue;
217                    }
218
219                    let info = load_session_info(path);
220                    let _ = tx.send(info);
221                }
222            });
223        }
224        // Drop the original tx so rx doesn't block forever
225        drop(tx);
226    });
227
228    let mut sessions: Vec<SessionInfo> = Vec::with_capacity(total);
229    for info in rx.into_iter().flatten() {
230        sessions.push(info);
231    }
232
233    sessions.sort_by_key(|b| std::cmp::Reverse(b.created));
234    sessions
235}
236
#[cfg(test)]
mod tests {
    use super::super::model::SessionManager;
    use super::*;
    use crate::agent::types::{assistant_message, user_message};
    use tempfile::TempDir;

    fn make_user_msg(content: &str) -> yoagent::types::AgentMessage {
        user_message(content)
    }

    fn make_asst_msg(content: &str) -> yoagent::types::AgentMessage {
        assistant_message(content)
    }

    #[test]
    fn test_list_empty_dir() {
        let repo = DefaultSessionRepo::new();
        let tmp = TempDir::new().unwrap();
        let sessions = repo.list(tmp.path(), None, None);
        assert!(sessions.is_empty());
    }

    #[test]
    fn test_list_sessions_concurrent_with_files() {
        let tmp = TempDir::new().unwrap();
        let sessions_dir = tmp.path().join("sessions");
        let cwd = tmp.path().join("project");
        std::fs::create_dir_all(&cwd).unwrap();

        // Create a few session files
        for i in 0..3 {
            let mut sm = SessionManager::create(&cwd, Some(&sessions_dir));
            sm.append_message(&make_user_msg(&format!("msg {}", i)));
            sm.append_message(&make_asst_msg(&format!("response {}", i)));
            drop(sm);
        }

        // Files without a `.jsonl` extension must not be treated as sessions.
        std::fs::write(sessions_dir.join("notes.txt"), "not a session").unwrap();

        let sessions = list_sessions_concurrent(&sessions_dir, None);
        assert_eq!(sessions.len(), 3);
    }

    #[test]
    fn test_list_sessions_concurrent_empty_dir() {
        let tmp = TempDir::new().unwrap();
        let sessions = list_sessions_concurrent(tmp.path(), None);
        assert!(sessions.is_empty());
    }

    #[test]
    fn test_list_sessions_concurrent_missing_dir() {
        // An unreadable (here: nonexistent) directory yields no sessions rather than an error.
        let tmp = TempDir::new().unwrap();
        let sessions = list_sessions_concurrent(&tmp.path().join("does-not-exist"), None);
        assert!(sessions.is_empty());
    }

    #[test]
    fn test_list_sessions_concurrent_single_file() {
        let tmp = TempDir::new().unwrap();
        let sessions_dir = tmp.path().join("sessions");
        let cwd = tmp.path().join("project");
        std::fs::create_dir_all(&cwd).unwrap();

        let mut sm = SessionManager::create(&cwd, Some(&sessions_dir));
        sm.append_message(&make_user_msg("only"));
        sm.append_message(&make_asst_msg("one"));
        drop(sm);

        let sessions = list_sessions_concurrent(&sessions_dir, None);
        assert_eq!(sessions.len(), 1);
    }

    #[test]
    fn test_list_sessions_concurrent_filter_cwd() {
        let tmp = TempDir::new().unwrap();
        let sessions_dir = tmp.path().join("sessions");
        let cwd1 = tmp.path().join("project1");
        let cwd2 = tmp.path().join("project2");
        std::fs::create_dir_all(&cwd1).unwrap();
        std::fs::create_dir_all(&cwd2).unwrap();

        // Session in project1
        let mut sm1 = SessionManager::create(&cwd1, Some(&sessions_dir));
        sm1.append_message(&make_user_msg("p1 msg"));
        sm1.append_message(&make_asst_msg("p1 resp"));
        drop(sm1);

        // Session in project2
        let mut sm2 = SessionManager::create(&cwd2, Some(&sessions_dir));
        sm2.append_message(&make_user_msg("p2 msg"));
        sm2.append_message(&make_asst_msg("p2 resp"));
        drop(sm2);

        // Filter by project1
        let sessions = list_sessions_concurrent(&sessions_dir, Some(&cwd1));
        assert_eq!(sessions.len(), 1);
        assert!(sessions[0].cwd.ends_with("project1"));

        // Filter by project2
        let sessions = list_sessions_concurrent(&sessions_dir, Some(&cwd2));
        assert_eq!(sessions.len(), 1);
        assert!(sessions[0].cwd.ends_with("project2"));
    }
}