Skip to main content

lean_ctx/core/context_os/
shared_sessions.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::{Arc, Mutex};
4
5use tokio::sync::RwLock;
6
7use crate::core::project_hash;
8use crate::core::session::SessionState;
9
/// Identifies one shared session: a project plus a workspace and a channel
/// within it.
///
/// Equality and hashing cover all three fields, so the key can be used
/// directly in a `HashMap`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct SharedSessionKey {
    /// Hash of the project root (see `project_hash::hash_project_root`).
    pub project_hash: String,
    /// Workspace identifier within the project.
    pub workspace_id: String,
    /// Channel identifier within the workspace.
    pub channel_id: String,
}
16
17impl SharedSessionKey {
18    pub fn new(project_root: &str, workspace_id: &str, channel_id: &str) -> Self {
19        Self {
20            project_hash: project_hash::hash_project_root(project_root),
21            workspace_id: normalize_id(workspace_id, "default"),
22            channel_id: normalize_id(channel_id, "default"),
23        }
24    }
25}
26
27pub struct SharedSessionStore {
28    sessions: Mutex<HashMap<SharedSessionKey, Arc<RwLock<SessionState>>>>,
29}
30
31impl Default for SharedSessionStore {
32    fn default() -> Self {
33        Self {
34            sessions: Mutex::new(HashMap::new()),
35        }
36    }
37}
38
39impl SharedSessionStore {
40    pub fn new() -> Self {
41        Self::default()
42    }
43
44    pub fn get_or_load(
45        &self,
46        project_root: &str,
47        workspace_id: &str,
48        channel_id: &str,
49    ) -> Arc<RwLock<SessionState>> {
50        let key = SharedSessionKey::new(project_root, workspace_id, channel_id);
51
52        if let Some(existing) = self.sessions.lock().ok().and_then(|m| m.get(&key).cloned()) {
53            return existing;
54        }
55
56        let loaded = load_session_from_disk(project_root, &key)
57            .or_else(|| SessionState::load_latest_for_project_root(project_root))
58            .unwrap_or_default();
59
60        let mut loaded = loaded;
61        loaded.project_root = Some(project_root.to_string());
62
63        let arc = Arc::new(RwLock::new(loaded));
64        if let Ok(mut m) = self.sessions.lock() {
65            m.insert(key, arc.clone());
66        }
67        arc
68    }
69
70    pub fn persist_best_effort(
71        &self,
72        project_root: &str,
73        workspace_id: &str,
74        channel_id: &str,
75        session: &SessionState,
76    ) {
77        let key = SharedSessionKey::new(project_root, workspace_id, channel_id);
78        let Some(dir) = shared_session_dir(&key) else {
79            return;
80        };
81        let _ = std::fs::create_dir_all(&dir);
82        let state_path = dir.join("session.json");
83        let tmp = dir.join("session.json.tmp");
84
85        if let Ok(json) = serde_json::to_string_pretty(session) {
86            let _ = std::fs::write(&tmp, json);
87            let _ = std::fs::rename(&tmp, &state_path);
88        }
89
90        // Persist a compaction snapshot alongside the shared session (premium UX).
91        let snap = if session.task.is_some() {
92            Some(session.build_compaction_snapshot())
93        } else {
94            None
95        };
96        if let Some(snapshot) = snap {
97            let _ = std::fs::write(dir.join("snapshot.txt"), snapshot);
98        }
99    }
100}
101
/// Sanitises a caller-supplied identifier, returning `fallback` when nothing
/// usable remains.
///
/// Surrounding whitespace is trimmed and every character other than ASCII
/// letters, digits, `-`, `_` and `.` is dropped. `fallback` is returned when
/// the result is empty, or when it is exactly `.` or `..`.
///
/// The dot-segment check matters because these ids are used as directory
/// names by `shared_session_dir`: `..` would resolve outside the per-project
/// session directory, and an empty id would collapse a path level.
fn normalize_id(s: &str, fallback: &str) -> String {
    // Keep IDs URL/header safe.
    let cleaned: String = s
        .trim()
        .chars()
        .filter(|c| c.is_ascii_alphanumeric() || matches!(*c, '-' | '_' | '.'))
        .collect();

    // Empty covers both blank input and input made only of filtered-out
    // characters; "." and ".." are path-navigation segments, not names.
    if matches!(cleaned.as_str(), "" | "." | "..") {
        fallback.to_string()
    } else {
        cleaned
    }
}
113
114fn shared_session_dir(key: &SharedSessionKey) -> Option<PathBuf> {
115    let data = crate::core::data_dir::lean_ctx_data_dir().ok()?;
116    Some(
117        data.join("context-os")
118            .join("sessions")
119            .join(&key.project_hash)
120            .join(&key.workspace_id)
121            .join(&key.channel_id),
122    )
123}
124
125fn load_session_from_disk(project_root: &str, key: &SharedSessionKey) -> Option<SessionState> {
126    let dir = shared_session_dir(key)?;
127    let state_path = dir.join("session.json");
128    let json = std::fs::read_to_string(&state_path).ok()?;
129    let mut session: SessionState = serde_json::from_str(&json).ok()?;
130    // Safety: enforce project_root from the current server.
131    session.project_root = Some(project_root.to_string());
132    Some(session)
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::session::FileTouched;

    /// Builds a `FileTouched` entry for `path` with the fixed field values
    /// these tests use (one read, "full" mode, unmodified).
    fn touched(path: String) -> FileTouched {
        FileTouched {
            path,
            file_ref: None,
            read_count: 1,
            modified: false,
            last_mode: "full".to_string(),
            tokens: 0,
            stale: false,
            context_item_id: None,
        }
    }

    /// Blank ids fall back; disallowed characters are stripped.
    #[test]
    fn normalize_id_filters_weird_chars() {
        let cases = [
            ("  ", "x"),
            ("abc-123_DEF", "abc-123_DEF"),
            ("a b$c", "abc"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_id(input, "x"), expected);
        }
    }

    /// The same inputs always produce an equal key.
    #[test]
    fn key_is_stable() {
        let first = SharedSessionKey::new("/tmp/proj", "ws", "ch");
        let second = SharedSessionKey::new("/tmp/proj", "ws", "ch");
        assert_eq!(first, second);
    }

    /// Several tasks appending to the same key must all land in one session.
    #[tokio::test]
    async fn concurrent_session_access_no_data_race() {
        const ROOT: &str = "/tmp/test-concurrent";
        const WRITES_PER_TASK: usize = 10;
        let n_tasks: usize = 8;
        let store = Arc::new(SharedSessionStore::new());

        let handles: Vec<_> = (0..n_tasks)
            .map(|task_idx| {
                let store = Arc::clone(&store);
                tokio::spawn(async move {
                    for i in 0..WRITES_PER_TASK {
                        let session = store.get_or_load(ROOT, "ws-shared", "ch-shared");
                        session
                            .write()
                            .await
                            .files_touched
                            .push(touched(format!("file-{task_idx}-{i}.rs")));
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        let final_arc = store.get_or_load(ROOT, "ws-shared", "ch-shared");
        let final_session = final_arc.read().await;
        assert_eq!(
            final_session.files_touched.len(),
            n_tasks * WRITES_PER_TASK,
            "all concurrent mutations must be persisted"
        );
    }

    /// Sessions that differ only in workspace id must not see each other's
    /// changes.
    #[tokio::test]
    async fn different_workspace_channels_are_isolated() {
        let store = SharedSessionStore::new();

        for (workspace, file) in [("ws-a", "fileA.rs"), ("ws-b", "fileB.rs")] {
            let session = store.get_or_load("/tmp/proj-iso", workspace, "ch-1");
            session
                .write()
                .await
                .files_touched
                .push(touched(file.to_string()));
        }

        let session_a = store.get_or_load("/tmp/proj-iso", "ws-a", "ch-1");
        let session_b = store.get_or_load("/tmp/proj-iso", "ws-b", "ch-1");

        let state_a = session_a.read().await;
        assert_eq!(state_a.files_touched.len(), 1);
        assert_eq!(state_a.files_touched[0].path, "fileA.rs");

        let state_b = session_b.read().await;
        assert_eq!(state_b.files_touched.len(), 1);
        assert_eq!(state_b.files_touched[0].path, "fileB.rs");
    }
}