Skip to main content

dk_engine/workspace/
session_manager.rs

//! WorkspaceManager — manages all active session workspaces.
//!
//! Provides creation, lookup, destruction, and garbage collection of
//! workspaces. Uses `DashMap` (a sharded concurrent map with per-shard
//! locks) so multiple agent sessions can access workspaces concurrently
//! without a single global lock.
6
7use std::sync::atomic::{AtomicU32, Ordering};
8
9use dashmap::DashMap;
10use dk_core::{AgentId, RepoId, Result};
11use serde::Serialize;
12use sqlx::PgPool;
13use tokio::time::Instant;
14use uuid::Uuid;
15
16use crate::workspace::session_workspace::{
17    SessionId, SessionWorkspace, WorkspaceMode,
18};
19
20// ── SessionInfo ─────────────────────────────────────────────────────
21
22/// Lightweight snapshot of a session workspace, suitable for JSON serialization.
23#[derive(Debug, Clone, Serialize)]
24pub struct SessionInfo {
25    pub session_id: Uuid,
26    pub agent_id: String,
27    pub agent_name: String,
28    pub intent: String,
29    pub repo_id: Uuid,
30    pub changeset_id: Uuid,
31    pub state: String,
32    pub elapsed_secs: u64,
33}
34
35// ── WorkspaceManager ─────────────────────────────────────────────────
36
37/// Central registry of all active session workspaces.
38///
39/// Thread-safe via `DashMap`; every public method is either `&self` or
40/// returns a scoped reference guard.
41pub struct WorkspaceManager {
42    workspaces: DashMap<SessionId, SessionWorkspace>,
43    agent_counters: DashMap<Uuid, AtomicU32>,
44    db: PgPool,
45}
46
47impl WorkspaceManager {
48    /// Create a new, empty workspace manager.
49    pub fn new(db: PgPool) -> Self {
50        Self {
51            workspaces: DashMap::new(),
52            agent_counters: DashMap::new(),
53            db,
54        }
55    }
56
57    /// Auto-assign the next agent name for a repository.
58    ///
59    /// Returns "agent-1", "agent-2", etc. incrementing per repo.
60    pub fn next_agent_name(&self, repo_id: &Uuid) -> String {
61        let counter = self
62            .agent_counters
63            .entry(*repo_id)
64            .or_insert_with(|| AtomicU32::new(0));
65        let n = counter.value().fetch_add(1, Ordering::Relaxed) + 1;
66        format!("agent-{n}")
67    }
68
69    /// Create a new workspace for a session and register it.
70    #[allow(clippy::too_many_arguments)]
71    pub async fn create_workspace(
72        &self,
73        session_id: SessionId,
74        repo_id: RepoId,
75        agent_id: AgentId,
76        changeset_id: uuid::Uuid,
77        intent: String,
78        base_commit: String,
79        mode: WorkspaceMode,
80        agent_name: String,
81    ) -> Result<SessionId> {
82        let ws = SessionWorkspace::new(
83            session_id,
84            repo_id,
85            agent_id,
86            changeset_id,
87            intent,
88            base_commit,
89            mode,
90            agent_name,
91            self.db.clone(),
92        )
93        .await?;
94
95        self.workspaces.insert(session_id, ws);
96        Ok(session_id)
97    }
98
99    /// Get an immutable reference to a workspace.
100    pub fn get_workspace(
101        &self,
102        session_id: &SessionId,
103    ) -> Option<dashmap::mapref::one::Ref<'_, SessionId, SessionWorkspace>> {
104        self.workspaces.get(session_id)
105    }
106
107    /// Get a mutable reference to a workspace.
108    pub fn get_workspace_mut(
109        &self,
110        session_id: &SessionId,
111    ) -> Option<dashmap::mapref::one::RefMut<'_, SessionId, SessionWorkspace>> {
112        self.workspaces.get_mut(session_id)
113    }
114
115    /// Remove and drop a workspace.
116    pub fn destroy_workspace(&self, session_id: &SessionId) -> Option<SessionWorkspace> {
117        self.workspaces.remove(session_id).map(|(_, ws)| ws)
118    }
119
120    /// Count active workspaces for a specific repository.
121    pub fn active_count(&self, repo_id: RepoId) -> usize {
122        self.workspaces
123            .iter()
124            .filter(|entry| entry.value().repo_id == repo_id)
125            .count()
126    }
127
128    /// Return session IDs of all active workspaces for a repo,
129    /// optionally excluding one session.
130    pub fn active_sessions_for_repo(
131        &self,
132        repo_id: RepoId,
133        exclude_session: Option<SessionId>,
134    ) -> Vec<SessionId> {
135        self.workspaces
136            .iter()
137            .filter(|entry| {
138                entry.value().repo_id == repo_id
139                    && exclude_session.is_none_or(|ex| *entry.key() != ex)
140            })
141            .map(|entry| *entry.key())
142            .collect()
143    }
144
145    /// Garbage-collect expired persistent workspaces.
146    ///
147    /// Ephemeral workspaces are not GC'd here — they are destroyed when
148    /// the session disconnects. This only handles persistent workspaces
149    /// whose `expires_at` deadline has passed.
150    pub fn gc_expired(&self) -> Vec<SessionId> {
151        let now = Instant::now();
152        let mut expired = Vec::new();
153
154        // Collect IDs first to avoid holding DashMap guards during removal.
155        self.workspaces.iter().for_each(|entry| {
156            if let WorkspaceMode::Persistent {
157                expires_at: Some(deadline),
158            } = &entry.value().mode
159            {
160                if now >= *deadline {
161                    expired.push(*entry.key());
162                }
163            }
164        });
165
166        for sid in &expired {
167            self.workspaces.remove(sid);
168        }
169
170        expired
171    }
172
173    /// Destroy workspaces for sessions that no longer exist.
174    /// Call this when a session disconnects or during periodic cleanup.
175    pub fn cleanup_disconnected(&self, active_session_ids: &[uuid::Uuid]) {
176        let to_remove: Vec<uuid::Uuid> = self.workspaces.iter()
177            .filter(|entry| !active_session_ids.contains(entry.key()))
178            .map(|entry| *entry.key())
179            .collect();
180        for sid in to_remove {
181            self.workspaces.remove(&sid);
182        }
183    }
184
185    /// Remove workspaces that are idle beyond `idle_ttl` or alive beyond `max_ttl`.
186    ///
187    /// Returns the list of expired session IDs. This complements [`gc_expired`]
188    /// (which handles persistent workspace deadlines) by enforcing activity-based
189    /// and hard-maximum lifetime limits on **all** workspaces.
190    pub fn gc_expired_sessions(
191        &self,
192        idle_ttl: std::time::Duration,
193        max_ttl: std::time::Duration,
194    ) -> Vec<SessionId> {
195        let now = Instant::now();
196        let mut expired = Vec::new();
197
198        self.workspaces.retain(|_session_id, ws| {
199            let idle = now.duration_since(ws.last_active);
200            let total = now.duration_since(ws.created_at);
201
202            if idle > idle_ttl || total > max_ttl {
203                expired.push(ws.session_id);
204                false // remove
205            } else {
206                true // keep
207            }
208        });
209
210        expired
211    }
212
213    /// Insert a pre-built workspace (test-only).
214    ///
215    /// Allows unit tests to insert workspaces with manipulated timestamps
216    /// without requiring a live database connection.
217    #[doc(hidden)]
218    pub fn insert_test_workspace(&self, ws: SessionWorkspace) {
219        let sid = ws.session_id;
220        self.workspaces.insert(sid, ws);
221    }
222
223    /// Total number of active workspaces across all repos.
224    pub fn total_active(&self) -> usize {
225        self.workspaces.len()
226    }
227
228    /// Describe which other sessions have modified a given file.
229    ///
230    /// Returns a formatted string like `"fn create_task modified by agent-2"`
231    /// or `"modified by agent-2, agent-3"`. Returns an empty string if no
232    /// other session has touched the file.
233    pub fn describe_other_modifiers(
234        &self,
235        file_path: &str,
236        repo_id: RepoId,
237        exclude_session: SessionId,
238    ) -> String {
239        let mut parts: Vec<String> = Vec::new();
240
241        for entry in self.workspaces.iter() {
242            let ws = entry.value();
243            if ws.repo_id != repo_id || ws.session_id == exclude_session {
244                continue;
245            }
246
247            // Check if this other session has the file in its overlay
248            if !ws.overlay.list_paths().contains(&file_path.to_string()) {
249                continue;
250            }
251
252            // Get changed symbols for this file from the session graph
253            let symbols = ws.graph.changed_symbols_for_file(file_path);
254            let agent = &ws.agent_name;
255
256            if symbols.is_empty() {
257                parts.push(format!("modified by {agent}"));
258            } else {
259                // Take up to 3 symbol names to keep it concise
260                let sym_list: Vec<&str> = symbols.iter().take(3).map(|s| s.as_str()).collect();
261                let sym_str = sym_list.join(", ");
262                if symbols.len() > 3 {
263                    parts.push(format!("{sym_str},... modified by {agent}"));
264                } else {
265                    parts.push(format!("{sym_str} modified by {agent}"));
266                }
267            }
268        }
269
270        parts.join("; ")
271    }
272
273    /// List all active sessions for a given repository.
274    pub fn list_sessions(&self, repo_id: RepoId) -> Vec<SessionInfo> {
275        let now = Instant::now();
276        self.workspaces
277            .iter()
278            .filter(|entry| entry.value().repo_id == repo_id)
279            .map(|entry| {
280                let ws = entry.value();
281                SessionInfo {
282                    session_id: ws.session_id,
283                    agent_id: ws.agent_id.clone(),
284                    agent_name: ws.agent_name.clone(),
285                    intent: ws.intent.clone(),
286                    repo_id: ws.repo_id,
287                    changeset_id: ws.changeset_id,
288                    state: ws.state.as_str().to_string(),
289                    elapsed_secs: now.duration_since(ws.created_at).as_secs(),
290                }
291            })
292            .collect()
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a manager backed by a lazily-connecting pool. The tests in this
    /// module never call into the database, so no server is required.
    fn lazy_manager() -> WorkspaceManager {
        let db = PgPool::connect_lazy("postgres://localhost/nonexistent").unwrap();
        WorkspaceManager::new(db)
    }

    /// Build an ephemeral test workspace labelled with `agent_name`.
    fn ephemeral_workspace(
        session_id: SessionId,
        repo_id: RepoId,
        agent_name: &str,
    ) -> SessionWorkspace {
        let mut ws = SessionWorkspace::new_test(
            session_id,
            repo_id,
            format!("{agent_name}-id"),
            "test intent".to_string(),
            "abc123".to_string(),
            WorkspaceMode::Ephemeral,
        );
        ws.agent_name = agent_name.to_string();
        ws
    }

    #[test]
    fn session_info_serializes_to_json() {
        let info = SessionInfo {
            session_id: Uuid::nil(),
            agent_id: "test-agent".to_string(),
            agent_name: "agent-1".to_string(),
            intent: "fix bug".to_string(),
            repo_id: Uuid::nil(),
            changeset_id: Uuid::nil(),
            state: "active".to_string(),
            elapsed_secs: 42,
        };

        let json = serde_json::to_value(&info).expect("SessionInfo should serialize to JSON");

        assert_eq!(json["agent_id"], "test-agent");
        assert_eq!(json["agent_name"], "agent-1");
        assert_eq!(json["intent"], "fix bug");
        assert_eq!(json["state"], "active");
        assert_eq!(json["elapsed_secs"], 42);
        assert_eq!(
            json["session_id"],
            "00000000-0000-0000-0000-000000000000"
        );
    }

    #[test]
    fn session_info_all_fields_present_in_json() {
        let info = SessionInfo {
            session_id: Uuid::new_v4(),
            agent_id: "claude".to_string(),
            agent_name: "agent-1".to_string(),
            intent: "refactor".to_string(),
            repo_id: Uuid::new_v4(),
            changeset_id: Uuid::new_v4(),
            state: "submitted".to_string(),
            elapsed_secs: 100,
        };

        let json = serde_json::to_value(&info).expect("serialize");
        let obj = json.as_object().expect("should be an object");

        let expected_keys = [
            "session_id",
            "agent_id",
            "agent_name",
            "intent",
            "repo_id",
            "changeset_id",
            "state",
            "elapsed_secs",
        ];
        for key in &expected_keys {
            assert!(obj.contains_key(*key), "missing key: {}", key);
        }
        assert_eq!(obj.len(), expected_keys.len(), "unexpected extra keys in SessionInfo JSON");
    }

    #[test]
    fn session_info_clone_preserves_values() {
        let info = SessionInfo {
            session_id: Uuid::new_v4(),
            agent_id: "agent-1".to_string(),
            agent_name: "feature-bot".to_string(),
            intent: "deploy".to_string(),
            repo_id: Uuid::new_v4(),
            changeset_id: Uuid::new_v4(),
            state: "active".to_string(),
            elapsed_secs: 5,
        };

        let cloned = info.clone();
        assert_eq!(info.session_id, cloned.session_id);
        assert_eq!(info.agent_id, cloned.agent_id);
        assert_eq!(info.agent_name, cloned.agent_name);
        assert_eq!(info.intent, cloned.intent);
        assert_eq!(info.repo_id, cloned.repo_id);
        assert_eq!(info.changeset_id, cloned.changeset_id);
        assert_eq!(info.state, cloned.state);
        assert_eq!(info.elapsed_secs, cloned.elapsed_secs);
    }

    #[tokio::test]
    async fn next_agent_name_increments_per_repo() {
        let mgr = lazy_manager();
        let repo1 = Uuid::new_v4();
        let repo2 = Uuid::new_v4();

        assert_eq!(mgr.next_agent_name(&repo1), "agent-1");
        assert_eq!(mgr.next_agent_name(&repo1), "agent-2");
        assert_eq!(mgr.next_agent_name(&repo1), "agent-3");

        // Different repo starts at 1
        assert_eq!(mgr.next_agent_name(&repo2), "agent-1");
        assert_eq!(mgr.next_agent_name(&repo2), "agent-2");

        // Original repo continues
        assert_eq!(mgr.next_agent_name(&repo1), "agent-4");
    }

    // A lazily-connecting pool is enough here: list_sessions never touches
    // the database, so this no longer needs to be an ignored placeholder.
    #[tokio::test]
    async fn list_sessions_returns_empty_for_unknown_repo() {
        let mgr = lazy_manager();
        assert!(mgr.list_sessions(Uuid::new_v4()).is_empty());
    }

    #[tokio::test]
    async fn list_sessions_only_reports_matching_repo() {
        let mgr = lazy_manager();
        let repo_a = Uuid::new_v4();
        let repo_b = Uuid::new_v4();
        let session_a = Uuid::new_v4();

        mgr.insert_test_workspace(ephemeral_workspace(session_a, repo_a, "agent-1"));
        mgr.insert_test_workspace(ephemeral_workspace(Uuid::new_v4(), repo_b, "agent-2"));

        let sessions = mgr.list_sessions(repo_a);
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0].session_id, session_a);
        assert_eq!(sessions[0].repo_id, repo_a);
        assert_eq!(sessions[0].agent_name, "agent-1");

        assert_eq!(mgr.active_count(repo_a), 1);
        assert_eq!(mgr.total_active(), 2);
    }

    #[tokio::test]
    async fn active_sessions_for_repo_honours_exclusion() {
        let mgr = lazy_manager();
        let repo_id = Uuid::new_v4();
        let s1 = Uuid::new_v4();
        let s2 = Uuid::new_v4();

        mgr.insert_test_workspace(ephemeral_workspace(s1, repo_id, "agent-1"));
        mgr.insert_test_workspace(ephemeral_workspace(s2, repo_id, "agent-2"));

        let mut all = mgr.active_sessions_for_repo(repo_id, None);
        all.sort();
        let mut expected = vec![s1, s2];
        expected.sort();
        assert_eq!(all, expected);

        assert_eq!(mgr.active_sessions_for_repo(repo_id, Some(s1)), vec![s2]);
    }

    #[tokio::test]
    async fn cleanup_disconnected_keeps_only_active_sessions() {
        let mgr = lazy_manager();
        let repo_id = Uuid::new_v4();
        let alive = Uuid::new_v4();
        let gone = Uuid::new_v4();

        mgr.insert_test_workspace(ephemeral_workspace(alive, repo_id, "agent-1"));
        mgr.insert_test_workspace(ephemeral_workspace(gone, repo_id, "agent-2"));

        mgr.cleanup_disconnected(&[alive]);

        assert!(mgr.get_workspace(&alive).is_some());
        assert!(mgr.get_workspace(&gone).is_none());
        assert_eq!(mgr.total_active(), 1);
    }

    #[tokio::test]
    async fn gc_expired_ignores_ephemeral_workspaces() {
        let mgr = lazy_manager();
        let session = Uuid::new_v4();

        mgr.insert_test_workspace(ephemeral_workspace(session, Uuid::new_v4(), "agent-1"));

        assert!(mgr.gc_expired().is_empty());
        assert_eq!(mgr.total_active(), 1);
    }

    #[tokio::test]
    async fn gc_expired_sessions_keeps_fresh_workspaces() {
        let hour = std::time::Duration::from_secs(3600);
        let mgr = lazy_manager();
        let session = Uuid::new_v4();

        mgr.insert_test_workspace(ephemeral_workspace(session, Uuid::new_v4(), "agent-1"));

        assert!(mgr.gc_expired_sessions(hour, hour).is_empty());
        assert!(mgr.get_workspace(&session).is_some());
    }

    #[tokio::test]
    async fn describe_other_modifiers_empty_when_no_other_sessions() {
        let mgr = lazy_manager();
        let repo_id = Uuid::new_v4();
        let session_id = Uuid::new_v4();

        let result = mgr.describe_other_modifiers("src/lib.rs", repo_id, session_id);
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn describe_other_modifiers_shows_agent_name() {
        use crate::workspace::session_workspace::{SessionWorkspace, WorkspaceMode};

        let mgr = lazy_manager();
        let repo_id = Uuid::new_v4();

        let session1 = Uuid::new_v4();
        let session2 = Uuid::new_v4();

        let mut ws2 = SessionWorkspace::new_test(
            session2,
            repo_id,
            "agent-2-id".to_string(),
            "fix bug".to_string(),
            "abc123".to_string(),
            WorkspaceMode::Ephemeral,
        );
        ws2.agent_name = "agent-2".to_string();
        ws2.overlay.write_local("src/lib.rs", b"content".to_vec(), false);

        mgr.insert_test_workspace(ws2);

        let result = mgr.describe_other_modifiers("src/lib.rs", repo_id, session1);
        assert_eq!(result, "modified by agent-2");

        let result2 = mgr.describe_other_modifiers("src/other.rs", repo_id, session1);
        assert!(result2.is_empty());

        let result3 = mgr.describe_other_modifiers("src/lib.rs", repo_id, session2);
        assert!(result3.is_empty());
    }

    #[tokio::test]
    async fn describe_other_modifiers_includes_symbols() {
        use crate::workspace::session_workspace::{SessionWorkspace, WorkspaceMode};
        use dk_core::{Span, Symbol, SymbolKind, Visibility};
        use std::path::PathBuf;

        let mgr = lazy_manager();
        let repo_id = Uuid::new_v4();

        let session1 = Uuid::new_v4();
        let session2 = Uuid::new_v4();

        let mut ws2 = SessionWorkspace::new_test(
            session2,
            repo_id,
            "agent-2-id".to_string(),
            "add feature".to_string(),
            "abc123".to_string(),
            WorkspaceMode::Ephemeral,
        );
        ws2.agent_name = "agent-2".to_string();
        ws2.overlay
            .write_local("src/tasks.rs", b"fn create_task() {}".to_vec(), true);
        ws2.graph.add_symbol(Symbol {
            id: Uuid::new_v4(),
            name: "create_task".to_string(),
            qualified_name: "create_task".to_string(),
            kind: SymbolKind::Function,
            visibility: Visibility::Public,
            file_path: PathBuf::from("src/tasks.rs"),
            span: Span {
                start_byte: 0,
                end_byte: 20,
            },
            signature: None,
            doc_comment: None,
            parent: None,
            last_modified_by: None,
            last_modified_intent: None,
        });

        mgr.insert_test_workspace(ws2);

        let result = mgr.describe_other_modifiers("src/tasks.rs", repo_id, session1);
        assert_eq!(result, "create_task modified by agent-2");
    }
}