Skip to main content

dk_protocol/
file_write.rs

1use std::time::Instant;
2
3use tonic::{Response, Status};
4use tracing::{info, warn};
5
6use dk_engine::conflict::SymbolClaim;
7use crate::server::ProtocolServer;
8use crate::validation::{validate_file_path, MAX_FILE_SIZE};
9use crate::{ConflictWarning, FileWriteRequest, FileWriteResponse, SymbolChange};
10
11/// Handle a FileWrite RPC.
12///
13/// Writes a file through the session workspace overlay and optionally
14/// detects symbol changes by parsing the new content.
15pub async fn handle_file_write(
16    server: &ProtocolServer,
17    req: FileWriteRequest,
18) -> Result<Response<FileWriteResponse>, Status> {
19    validate_file_path(&req.path)?;
20
21    if req.content.len() > MAX_FILE_SIZE {
22        return Err(Status::invalid_argument("file content exceeds 50MB limit"));
23    }
24
25    let session = server.validate_session(&req.session_id)?;
26
27    let sid = req
28        .session_id
29        .parse::<uuid::Uuid>()
30        .map_err(|_| Status::invalid_argument("Invalid session ID"))?;
31    server.session_mgr().touch_session(&sid);
32
33    let engine = server.engine();
34
35    // Get workspace for this session
36    let ws = engine
37        .workspace_manager()
38        .get_workspace(&sid)
39        .ok_or_else(|| Status::not_found("Workspace not found for session"))?;
40
41    // Determine if the file is new (not in base tree) synchronously,
42    // then drop the git_repo before async work to keep future Send.
43    // Capture repo_id here to avoid a redundant second get_repo call later.
44    let (repo_id, is_new) = {
45        let (rid, git_repo) = engine
46            .get_repo(&session.codebase)
47            .await
48            .map_err(|e| Status::internal(format!("Repo error: {e}")))?;
49        let new = git_repo
50            .read_tree_entry(&ws.base_commit, &req.path)
51            .is_err();
52        (rid, new)
53    };
54    let repo_id_str = repo_id.to_string();
55
56    // Capture pre-write symbols for accurate change detection
57    let pre_write_symbols: std::collections::HashSet<String> = engine
58        .symbol_store()
59        .find_by_file(repo_id, &req.path)
60        .await
61        .unwrap_or_default()
62        .into_iter()
63        .map(|s| s.qualified_name)
64        .collect();
65
66    // Write through the overlay (async DB persist)
67    let new_hash = ws
68        .overlay
69        .write(&req.path, req.content.clone(), is_new)
70        .await
71        .map_err(|e| Status::internal(format!("Write failed: {e}")))?;
72
73    let changeset_id = ws.changeset_id;
74    let agent_name = ws.agent_name.clone();
75
76    // Drop workspace guard before further work
77    drop(ws);
78
79    // Also record in changeset_files so the verify pipeline can materialize them.
80    let op = if is_new { "add" } else { "modify" };
81    let content_str = std::str::from_utf8(&req.content).ok();
82    let _ = engine
83        .changeset_store()
84        .upsert_file(changeset_id, &req.path, op, content_str)
85        .await;
86
87    // Attempt to detect symbol changes by parsing the new content
88    let detected_changes = detect_symbol_changes(engine, &req.path, &req.content);
89
90    // Build symbol change details with accurate change types by comparing
91    // against pre-write symbols.
92    let symbol_changes: Vec<crate::SymbolChangeDetail> = detected_changes
93        .iter()
94        .map(|sc| {
95            let change_type = if is_new || !pre_write_symbols.contains(&sc.symbol_name) {
96                "added"
97            } else {
98                "modified"
99            };
100            crate::SymbolChangeDetail {
101                symbol_name: sc.symbol_name.clone(),
102                file_path: req.path.clone(),
103                change_type: change_type.to_string(),
104                kind: sc.change_type.clone(),
105            }
106        })
107        .collect();
108
109    // Detect deleted symbols (existed before but no longer present).
110    // Only infer deletions when the parser produced output — if parsing
111    // failed (e.g. incomplete syntax mid-edit), detected_changes is empty
112    // and we'd falsely report every pre-existing symbol as deleted.
113    let mut all_symbol_changes = symbol_changes;
114    if !detected_changes.is_empty() {
115        let detected_names: std::collections::HashSet<&str> = detected_changes
116            .iter()
117            .map(|sc| sc.symbol_name.as_str())
118            .collect();
119        for name in &pre_write_symbols {
120            if !detected_names.contains(name.as_str()) {
121                all_symbol_changes.push(crate::SymbolChangeDetail {
122                    symbol_name: name.clone(),
123                    file_path: req.path.clone(),
124                    change_type: "deleted".to_string(),
125                    kind: String::new(),
126                });
127            }
128        }
129    }
130
131    // ── Symbol claim tracking ──
132    // Build claims from "added" and "modified" symbol changes and check for
133    // cross-session conflicts. Two sessions modifying DIFFERENT symbols in the
134    // same file is NOT a conflict — only same-symbol is a true conflict.
135    let conflict_warnings = {
136        let claimable: Vec<&crate::SymbolChangeDetail> = all_symbol_changes
137            .iter()
138            .filter(|sc| sc.change_type == "added" || sc.change_type == "modified")
139            .collect();
140
141        // Check for conflicts before recording our claims
142        let qualified_names: Vec<String> = claimable.iter().map(|sc| sc.symbol_name.clone()).collect();
143        let conflicts = server.claim_tracker().check_conflicts(
144            repo_id,
145            &req.path,
146            sid,
147            &qualified_names,
148        );
149
150        // Record claims (even if conflicts exist — warning only at write time)
151        for sc in &claimable {
152            let kind = sc.kind.parse::<dk_core::SymbolKind>().unwrap_or(dk_core::SymbolKind::Function);
153            server.claim_tracker().record_claim(
154                repo_id,
155                &req.path,
156                SymbolClaim {
157                    session_id: sid,
158                    agent_name: agent_name.clone(),
159                    qualified_name: sc.symbol_name.clone(),
160                    kind,
161                    first_touched_at: Instant::now(),
162                },
163            );
164        }
165
166        // Build ConflictWarning proto messages
167        let warnings: Vec<ConflictWarning> = conflicts
168            .into_iter()
169            .map(|c| {
170                let msg = format!(
171                    "Symbol '{}' was already modified by agent '{}' (session {})",
172                    c.qualified_name, c.conflicting_agent, c.conflicting_session,
173                );
174                warn!(
175                    session_id = %sid,
176                    path = %req.path,
177                    symbol = %c.qualified_name,
178                    conflicting_agent = %c.conflicting_agent,
179                    "CONFLICT_WARNING: {msg}"
180                );
181                ConflictWarning {
182                    file_path: req.path.clone(),
183                    symbol_name: c.qualified_name,
184                    conflicting_agent: c.conflicting_agent,
185                    conflicting_session_id: c.conflicting_session.to_string(),
186                    message: msg,
187                }
188            })
189            .collect();
190        warnings
191    };
192
193    // Emit a file.modified (or file.added) event
194    let event_type = if is_new { "file.added" } else { "file.modified" };
195    server.event_bus().publish(crate::WatchEvent {
196        event_type: event_type.to_string(),
197        changeset_id: changeset_id.to_string(),
198        agent_id: session.agent_id.clone(),
199        affected_symbols: vec![],
200        details: format!("file {}: {}", op, req.path),
201        session_id: req.session_id.clone(),
202        affected_files: vec![crate::FileChange {
203            path: req.path.clone(),
204            operation: op.to_string(),
205        }],
206        symbol_changes: all_symbol_changes,
207        repo_id: repo_id_str,
208        event_id: uuid::Uuid::new_v4().to_string(),
209    });
210
211    info!(
212        session_id = %req.session_id,
213        path = %req.path,
214        hash = %new_hash,
215        changes = detected_changes.len(),
216        conflicts = conflict_warnings.len(),
217        "FILE_WRITE: completed"
218    );
219
220    Ok(Response::new(FileWriteResponse {
221        new_hash,
222        detected_changes,
223        conflict_warnings,
224    }))
225}
226
227/// Parse the file content and detect symbol-level changes.
228///
229/// This is best-effort: if the parser doesn't support the file type
230/// or parsing fails, we return an empty list.
231fn detect_symbol_changes(
232    engine: &dk_engine::repo::Engine,
233    path: &str,
234    content: &[u8],
235) -> Vec<SymbolChange> {
236    let file_path = std::path::Path::new(path);
237    let parser = engine.parser();
238
239    if !parser.supports_file(file_path) {
240        return Vec::new();
241    }
242
243    match parser.parse_file(file_path, content) {
244        Ok(analysis) => analysis
245            .symbols
246            .iter()
247            .map(|sym| SymbolChange {
248                symbol_name: sym.qualified_name.clone(),
249                change_type: sym.kind.to_string(),
250            })
251            .collect(),
252        Err(_) => Vec::new(),
253    }
254}