Skip to main content

dk_protocol/
connect.rs

1use tonic::{Response, Status};
2use tracing::{info, warn};
3
4use dk_engine::workspace::session_workspace::WorkspaceMode;
5
6use crate::server::ProtocolServer;
7use crate::{
8    ActiveSessionSummary, CodebaseSummary, ConnectRequest, ConnectResponse,
9    WorkspaceConcurrencyInfo,
10};
11
12/// Handle a CONNECT RPC.
13///
14/// 1. Validates the bearer token.
15/// 2. Looks up the repository by name.
16/// 3. Retrieves a high-level codebase summary (languages, symbol count, file count).
17/// 4. Reads the HEAD commit hash as the current codebase version.
18/// 5. Creates a stateful session and returns the session ID.
19/// 6. Creates a session workspace (isolated overlay for file changes).
20/// 7. Returns workspace ID and concurrency info.
21pub async fn handle_connect(
22    server: &ProtocolServer,
23    req: ConnectRequest,
24) -> Result<Response<ConnectResponse>, Status> {
25    // 1. Auth
26    let _authed_agent_id = server.validate_auth(&req.auth_token)?;
27
28    // Check for session resume.
29    //
30    // `take_snapshot` consumes the snapshot so it cannot be reused by a stale
31    // reconnect.  When a valid snapshot is found, its `codebase_version` is
32    // used as the default base_commit so the resumed session starts from the
33    // same commit the old session was on.
34    let mut resumed_snapshot: Option<crate::session::SessionSnapshot> = None;
35    if let Some(ref ws_config) = req.workspace_config {
36        if let Some(ref resume_id_str) = ws_config.resume_session_id {
37            match resume_id_str.parse::<uuid::Uuid>() {
38                Ok(resume_id) => {
39                    match server.session_mgr().take_snapshot(&resume_id) {
40                        Some(snapshot) => {
41                            if snapshot.codebase != req.codebase {
42                                return Err(Status::invalid_argument(format!(
43                                    "Cannot resume session from codebase '{}' into '{}'",
44                                    snapshot.codebase, req.codebase
45                                )));
46                            }
47                            info!(
48                                resume_from = %resume_id,
49                                agent_id = %snapshot.agent_id,
50                                base_version = %snapshot.codebase_version,
51                                "CONNECT: resuming from previous session snapshot"
52                            );
53                            resumed_snapshot = Some(snapshot);
54                        }
55                        None => {
56                            warn!(
57                                resume_session_id = %resume_id,
58                                "CONNECT: resume requested but no snapshot found \
59                                 (session may have expired beyond snapshot TTL)"
60                            );
61                        }
62                    }
63                }
64                Err(_) => {
65                    return Err(Status::invalid_argument(format!(
66                        "resume_session_id '{}' is not a valid UUID",
67                        resume_id_str
68                    )));
69                }
70            }
71        }
72    }
73
74    // Extract the requested base_commit early so it can be validated during
75    // the initial repo lookup (avoids a redundant `get_repo` call later).
76    // If resuming, default to the snapshot's codebase_version so the
77    // workspace starts from the same commit the old session was on.
78    let requested_base_commit = req
79        .workspace_config
80        .as_ref()
81        .and_then(|c| c.base_commit.clone())
82        .or_else(|| resumed_snapshot.as_ref().map(|s| s.codebase_version.clone()));
83
84    // 2-4. Resolve repo, get summary, read HEAD version, and validate
85    //      base_commit if one was provided.  Everything involving
86    //      `GitRepository` (which is !Sync) is scoped inside a block so
87    //      the future remains Send.
88    let engine = server.engine();
89
90    let (repo_id, version, summary) = {
91        let (repo_id, git_repo) = engine
92            .get_repo(&req.codebase)
93            .await
94            .map_err(|e| match e {
95                dk_core::Error::AmbiguousRepoName(_) => Status::invalid_argument(
96                    format!("Ambiguous repository name: use the full 'owner/repo' form ({e})"),
97                ),
98                _ => Status::not_found(format!("Repository not found: {e}")),
99            })?;
100
101        // HEAD commit hash (or "initial" for empty repos).
102        let version = git_repo
103            .head_hash()
104            .map_err(|e| Status::internal(format!("Failed to read HEAD: {e}")))?
105            .unwrap_or_else(|| "initial".to_string());
106
107        // Validate custom base_commit while we still hold git_repo, avoiding
108        // a second `get_repo` call.
109        if let Some(ref base) = requested_base_commit {
110            if base != &version && base != "initial" {
111                git_repo
112                    .list_tree_files(base)
113                    .map_err(|_| {
114                        Status::invalid_argument(format!(
115                            "base_commit '{base}' does not resolve to a valid commit"
116                        ))
117                    })?;
118            }
119        }
120
121        // Drop git_repo before the next .await to keep the future Send.
122        drop(git_repo);
123
124        let summary = engine
125            .codebase_summary(repo_id)
126            .await
127            .map_err(|e| Status::internal(format!("Failed to get summary: {e}")))?;
128
129        (repo_id, version, summary)
130    };
131
132    // 5. Create session (session_mgr is lock-free / DashMap-based).
133    let session_id = server.session_mgr().create_session(
134        req.agent_id.clone(),
135        req.codebase.clone(),
136        req.intent.clone(),
137        version.clone(),
138    );
139
140    // 5a. Resolve agent name: use provided name or auto-assign.
141    let agent_name = if req.agent_name.is_empty() {
142        engine.workspace_manager().next_agent_name(&repo_id)
143    } else {
144        req.agent_name.clone()
145    };
146
147    // 5b. Create a changeset (staging area for file changes).
148    let changeset = engine
149        .changeset_store()
150        .create(repo_id, Some(session_id), &req.agent_id, &req.intent, Some(&version), &agent_name)
151        .await
152        .map_err(|e| Status::internal(format!("failed to create changeset: {e}")))?;
153
154    // 6. Determine workspace mode from request config
155    let ws_mode = match req.workspace_config.as_ref().map(|c| c.mode()) {
156        Some(crate::WorkspaceMode::Persistent) => WorkspaceMode::Persistent { expires_at: None },
157        _ => WorkspaceMode::Ephemeral,
158    };
159
160    // Use the provided base_commit or default to current HEAD version
161    let base_commit = requested_base_commit.unwrap_or_else(|| version.clone());
162
163    // Create the session workspace
164    let workspace_id = engine
165        .workspace_manager()
166        .create_workspace(
167            session_id,
168            repo_id,
169            req.agent_id.clone(),
170            changeset.id,
171            req.intent.clone(),
172            base_commit,
173            ws_mode,
174            agent_name.clone(),
175        )
176        .await
177        .map_err(|e| Status::internal(format!("failed to create workspace: {e}")))?;
178
179    // 7. Build concurrency info
180    let other_session_ids = engine
181        .workspace_manager()
182        .active_sessions_for_repo(repo_id, Some(session_id));
183
184    let mut other_sessions = Vec::new();
185    for other_sid in &other_session_ids {
186        if let Some(other_ws) = engine.workspace_manager().get_workspace(other_sid) {
187            // Gather just the paths (avoids cloning file content)
188            let active_files: Vec<String> = other_ws.overlay.list_paths();
189
190            other_sessions.push(ActiveSessionSummary {
191                agent_id: other_ws.agent_id.clone(),
192                intent: other_ws.intent.clone(),
193                active_files,
194            });
195        }
196    }
197
198    let concurrency = WorkspaceConcurrencyInfo {
199        active_sessions: (other_session_ids.len() + 1) as u32, // include this session
200        other_sessions,
201    };
202
203    info!(
204        session_id = %session_id,
205        changeset_id = %changeset.id,
206        workspace_id = %workspace_id,
207        agent_id = %req.agent_id,
208        agent_name = %agent_name,
209        codebase = %req.codebase,
210        active_sessions = concurrency.active_sessions,
211        "CONNECT: session, changeset, and workspace created"
212    );
213
214    Ok(Response::new(ConnectResponse {
215        session_id: session_id.to_string(),
216        codebase_version: version,
217        summary: Some(CodebaseSummary {
218            languages: summary.languages,
219            total_symbols: summary.total_symbols,
220            total_files: summary.total_files,
221        }),
222        changeset_id: changeset.id.to_string(),
223        workspace_id: workspace_id.to_string(),
224        concurrency: Some(concurrency),
225    }))
226}