Skip to main content

dk_protocol/
connect.rs

1use tonic::{Response, Status};
2use tracing::{info, warn};
3
4use dk_engine::workspace::session_workspace::WorkspaceMode;
5
6use crate::server::ProtocolServer;
7use crate::{
8    ActiveSessionSummary, CodebaseSummary, ConnectRequest, ConnectResponse,
9    WorkspaceConcurrencyInfo,
10};
11
12/// Handle a CONNECT RPC.
13///
14/// 1. Validates the bearer token.
15/// 2. Looks up the repository by name.
16/// 3. Retrieves a high-level codebase summary (languages, symbol count, file count).
17/// 4. Reads the HEAD commit hash as the current codebase version.
18/// 5. Creates a stateful session and returns the session ID.
19/// 6. Creates a session workspace (isolated overlay for file changes).
20/// 7. Returns workspace ID and concurrency info.
21pub async fn handle_connect(
22    server: &ProtocolServer,
23    req: ConnectRequest,
24) -> Result<Response<ConnectResponse>, Status> {
25    // 1. Auth
26    let _authed_agent_id = server.validate_auth(&req.auth_token)?;
27
28    // Check for session resume.
29    //
30    // `take_snapshot` consumes the snapshot so it cannot be reused by a stale
31    // reconnect.  When a valid snapshot is found, its `codebase_version` is
32    // used as the default base_commit so the resumed session starts from the
33    // same commit the old session was on.
34    let mut resumed_snapshot: Option<crate::session::SessionSnapshot> = None;
35    if let Some(ref ws_config) = req.workspace_config {
36        if let Some(ref resume_id_str) = ws_config.resume_session_id {
37            match resume_id_str.parse::<uuid::Uuid>() {
38                Ok(resume_id) => {
39                    match server.session_mgr().take_snapshot(&resume_id) {
40                        Some(snapshot) => {
41                            if snapshot.codebase != req.codebase {
42                                return Err(Status::invalid_argument(format!(
43                                    "Cannot resume session from codebase '{}' into '{}'",
44                                    snapshot.codebase, req.codebase
45                                )));
46                            }
47                            info!(
48                                resume_from = %resume_id,
49                                agent_id = %snapshot.agent_id,
50                                base_version = %snapshot.codebase_version,
51                                "CONNECT: resuming from previous session snapshot"
52                            );
53                            resumed_snapshot = Some(snapshot);
54                        }
55                        None => {
56                            warn!(
57                                resume_session_id = %resume_id,
58                                "CONNECT: resume requested but no snapshot found \
59                                 (session may have expired beyond snapshot TTL)"
60                            );
61                        }
62                    }
63                }
64                Err(_) => {
65                    return Err(Status::invalid_argument(format!(
66                        "resume_session_id '{}' is not a valid UUID",
67                        resume_id_str
68                    )));
69                }
70            }
71        }
72    }
73
74    // Extract the requested base_commit early so it can be validated during
75    // the initial repo lookup (avoids a redundant `get_repo` call later).
76    // If resuming, default to the snapshot's codebase_version so the
77    // workspace starts from the same commit the old session was on.
78    let requested_base_commit = req
79        .workspace_config
80        .as_ref()
81        .and_then(|c| c.base_commit.clone())
82        .or_else(|| resumed_snapshot.as_ref().map(|s| s.codebase_version.clone()));
83
84    // 2-4. Resolve repo, get summary, read HEAD version, and validate
85    //      base_commit if one was provided.  Everything involving
86    //      `GitRepository` (which is !Sync) is scoped inside a block so
87    //      the future remains Send.
88    let engine = server.engine();
89
90    let (repo_id, version, summary) = {
91        let (repo_id, git_repo) = engine
92            .get_repo(&req.codebase)
93            .await
94            .map_err(|e| Status::not_found(format!("Repository not found: {e}")))?;
95
96        // HEAD commit hash (or "initial" for empty repos).
97        let version = git_repo
98            .head_hash()
99            .map_err(|e| Status::internal(format!("Failed to read HEAD: {e}")))?
100            .unwrap_or_else(|| "initial".to_string());
101
102        // Validate custom base_commit while we still hold git_repo, avoiding
103        // a second `get_repo` call.
104        if let Some(ref base) = requested_base_commit {
105            if base != &version && base != "initial" {
106                git_repo
107                    .list_tree_files(base)
108                    .map_err(|_| {
109                        Status::invalid_argument(format!(
110                            "base_commit '{base}' does not resolve to a valid commit"
111                        ))
112                    })?;
113            }
114        }
115
116        // Drop git_repo before the next .await to keep the future Send.
117        drop(git_repo);
118
119        let summary = engine
120            .codebase_summary(repo_id)
121            .await
122            .map_err(|e| Status::internal(format!("Failed to get summary: {e}")))?;
123
124        (repo_id, version, summary)
125    };
126
127    // 5. Create session (session_mgr is lock-free / DashMap-based).
128    let session_id = server.session_mgr().create_session(
129        req.agent_id.clone(),
130        req.codebase.clone(),
131        req.intent.clone(),
132        version.clone(),
133    );
134
135    // 5a. Resolve agent name: use provided name or auto-assign.
136    let agent_name = if req.agent_name.is_empty() {
137        engine.workspace_manager().next_agent_name(&repo_id)
138    } else {
139        req.agent_name.clone()
140    };
141
142    // 5b. Create a changeset (staging area for file changes).
143    let changeset = engine
144        .changeset_store()
145        .create(repo_id, Some(session_id), &req.agent_id, &req.intent, Some(&version), &agent_name)
146        .await
147        .map_err(|e| Status::internal(format!("failed to create changeset: {e}")))?;
148
149    // 6. Determine workspace mode from request config
150    let ws_mode = match req.workspace_config.as_ref().map(|c| c.mode()) {
151        Some(crate::WorkspaceMode::Persistent) => WorkspaceMode::Persistent { expires_at: None },
152        _ => WorkspaceMode::Ephemeral,
153    };
154
155    // Use the provided base_commit or default to current HEAD version
156    let base_commit = requested_base_commit.unwrap_or_else(|| version.clone());
157
158    // Create the session workspace
159    let workspace_id = engine
160        .workspace_manager()
161        .create_workspace(
162            session_id,
163            repo_id,
164            req.agent_id.clone(),
165            changeset.id,
166            req.intent.clone(),
167            base_commit,
168            ws_mode,
169            agent_name.clone(),
170        )
171        .await
172        .map_err(|e| Status::internal(format!("failed to create workspace: {e}")))?;
173
174    // 7. Build concurrency info
175    let other_session_ids = engine
176        .workspace_manager()
177        .active_sessions_for_repo(repo_id, Some(session_id));
178
179    let mut other_sessions = Vec::new();
180    for other_sid in &other_session_ids {
181        if let Some(other_ws) = engine.workspace_manager().get_workspace(other_sid) {
182            // Gather just the paths (avoids cloning file content)
183            let active_files: Vec<String> = other_ws.overlay.list_paths();
184
185            other_sessions.push(ActiveSessionSummary {
186                agent_id: other_ws.agent_id.clone(),
187                intent: other_ws.intent.clone(),
188                active_files,
189            });
190        }
191    }
192
193    let concurrency = WorkspaceConcurrencyInfo {
194        active_sessions: (other_session_ids.len() + 1) as u32, // include this session
195        other_sessions,
196    };
197
198    info!(
199        session_id = %session_id,
200        changeset_id = %changeset.id,
201        workspace_id = %workspace_id,
202        agent_id = %req.agent_id,
203        agent_name = %agent_name,
204        codebase = %req.codebase,
205        active_sessions = concurrency.active_sessions,
206        "CONNECT: session, changeset, and workspace created"
207    );
208
209    Ok(Response::new(ConnectResponse {
210        session_id: session_id.to_string(),
211        codebase_version: version,
212        summary: Some(CodebaseSummary {
213            languages: summary.languages,
214            total_symbols: summary.total_symbols,
215            total_files: summary.total_files,
216        }),
217        changeset_id: changeset.id.to_string(),
218        workspace_id: workspace_id.to_string(),
219        concurrency: Some(concurrency),
220    }))
221}