Skip to main content

dk_protocol/
file_write.rs

1use std::time::Instant;
2
3use tonic::{Response, Status};
4use tracing::{info, warn};
5
6use dk_engine::conflict::SymbolClaim;
7use crate::server::ProtocolServer;
8use crate::validation::{validate_file_path, MAX_FILE_SIZE};
9use crate::{ConflictWarning, FileWriteRequest, FileWriteResponse, SymbolChange};
10
11/// Handle a FileWrite RPC.
12///
13/// Writes a file through the session workspace overlay and optionally
14/// detects symbol changes by parsing the new content.
15pub async fn handle_file_write(
16    server: &ProtocolServer,
17    req: FileWriteRequest,
18) -> Result<Response<FileWriteResponse>, Status> {
19    validate_file_path(&req.path)?;
20
21    if req.content.len() > MAX_FILE_SIZE {
22        return Err(Status::invalid_argument("file content exceeds 50MB limit"));
23    }
24
25    let session = server.validate_session(&req.session_id)?;
26
27    let sid = req
28        .session_id
29        .parse::<uuid::Uuid>()
30        .map_err(|_| Status::invalid_argument("Invalid session ID"))?;
31    server.session_mgr().touch_session(&sid);
32
33    let engine = server.engine();
34
35    // Get workspace for this session
36    let ws = engine
37        .workspace_manager()
38        .get_workspace(&sid)
39        .ok_or_else(|| Status::not_found("Workspace not found for session"))?;
40
41    // Determine if the file is new (not in base tree) and read old content
42    // in a single get_repo call. Drop git_repo before async work to keep
43    // future Send.
44    let (repo_id, is_new, old_content) = {
45        let (rid, git_repo) = engine
46            .get_repo(&session.codebase)
47            .await
48            .map_err(|e| Status::internal(format!("Repo error: {e}")))?;
49        match git_repo.read_tree_entry(&ws.base_commit, &req.path) {
50            Ok(bytes) => (rid, false, bytes),
51            Err(e) => {
52                // File not in base tree — treat as new. Log the error in case
53                // it's a transient git failure rather than a genuine "not found".
54                warn!(
55                    path = %req.path,
56                    base_commit = %ws.base_commit,
57                    error = %e,
58                    "read_tree_entry failed — treating file as new"
59                );
60                (rid, true, Vec::new())
61            }
62        }
63    };
64    let repo_id_str = repo_id.to_string();
65
66    // Write through the overlay (async DB persist)
67    let new_hash = ws
68        .overlay
69        .write(&req.path, req.content.clone(), is_new)
70        .await
71        .map_err(|e| Status::internal(format!("Write failed: {e}")))?;
72
73    let changeset_id = ws.changeset_id;
74    let agent_name = ws.agent_name.clone();
75
76    // Drop workspace guard before further work
77    drop(ws);
78
79    // Also record in changeset_files so the verify pipeline can materialize them.
80    let op = if is_new { "add" } else { "modify" };
81    let content_str = std::str::from_utf8(&req.content).ok();
82    let _ = engine
83        .changeset_store()
84        .upsert_file(changeset_id, &req.path, op, content_str)
85        .await;
86
87    // Detect symbol changes by diffing old vs new file content.
88    // Only symbols whose source text actually changed are reported.
89    let (detected_changes, all_symbol_changes) =
90        detect_symbol_changes_diffed(engine, &req.path, &old_content, &req.content, is_new);
91
92    // ── Symbol claim tracking ──
93    // Build claims from "added" and "modified" symbol changes and check for
94    // cross-session conflicts. Two sessions modifying DIFFERENT symbols in the
95    // same file is NOT a conflict — only same-symbol is a true conflict.
96    let conflict_warnings = {
97        let claimable: Vec<&crate::SymbolChangeDetail> = all_symbol_changes
98            .iter()
99            .filter(|sc| sc.change_type == "added" || sc.change_type == "modified")
100            .collect();
101
102        // Check for conflicts before recording our claims
103        let qualified_names: Vec<String> = claimable.iter().map(|sc| sc.symbol_name.clone()).collect();
104        let conflicts = server.claim_tracker().check_conflicts(
105            repo_id,
106            &req.path,
107            sid,
108            &qualified_names,
109        );
110
111        // Record claims (even if conflicts exist — warning only at write time)
112        for sc in &claimable {
113            let kind = sc.kind.parse::<dk_core::SymbolKind>().unwrap_or(dk_core::SymbolKind::Function);
114            server.claim_tracker().record_claim(
115                repo_id,
116                &req.path,
117                SymbolClaim {
118                    session_id: sid,
119                    agent_name: agent_name.clone(),
120                    qualified_name: sc.symbol_name.clone(),
121                    kind,
122                    first_touched_at: Instant::now(),
123                },
124            );
125        }
126
127        // Build ConflictWarning proto messages
128        let warnings: Vec<ConflictWarning> = conflicts
129            .into_iter()
130            .map(|c| {
131                let msg = format!(
132                    "Symbol '{}' was already modified by agent '{}' (session {})",
133                    c.qualified_name, c.conflicting_agent, c.conflicting_session,
134                );
135                warn!(
136                    session_id = %sid,
137                    path = %req.path,
138                    symbol = %c.qualified_name,
139                    conflicting_agent = %c.conflicting_agent,
140                    "CONFLICT_WARNING: {msg}"
141                );
142                ConflictWarning {
143                    file_path: req.path.clone(),
144                    symbol_name: c.qualified_name,
145                    conflicting_agent: c.conflicting_agent,
146                    conflicting_session_id: c.conflicting_session.to_string(),
147                    message: msg,
148                }
149            })
150            .collect();
151        warnings
152    };
153
154    // Emit a file.modified (or file.added) event
155    let event_type = if is_new { "file.added" } else { "file.modified" };
156    server.event_bus().publish(crate::WatchEvent {
157        event_type: event_type.to_string(),
158        changeset_id: changeset_id.to_string(),
159        agent_id: session.agent_id.clone(),
160        affected_symbols: vec![],
161        details: format!("file {}: {}", op, req.path),
162        session_id: req.session_id.clone(),
163        affected_files: vec![crate::FileChange {
164            path: req.path.clone(),
165            operation: op.to_string(),
166        }],
167        symbol_changes: all_symbol_changes,
168        repo_id: repo_id_str,
169        event_id: uuid::Uuid::new_v4().to_string(),
170    });
171
172    info!(
173        session_id = %req.session_id,
174        path = %req.path,
175        hash = %new_hash,
176        changes = detected_changes.len(),
177        conflicts = conflict_warnings.len(),
178        "FILE_WRITE: completed"
179    );
180
181    Ok(Response::new(FileWriteResponse {
182        new_hash,
183        detected_changes,
184        conflict_warnings,
185    }))
186}
187
188/// Parse both old and new file content, diff per-symbol source text,
189/// and return only symbols that actually changed.
190///
191/// Returns `(detected_changes, all_symbol_change_details)`:
192/// - `detected_changes`: `SymbolChange` for the gRPC response (only truly changed symbols)
193/// - `all_symbol_change_details`: `SymbolChangeDetail` for claims + events (added/modified/deleted)
194fn detect_symbol_changes_diffed(
195    engine: &dk_engine::repo::Engine,
196    path: &str,
197    old_content: &[u8],
198    new_content: &[u8],
199    is_new_file: bool,
200) -> (Vec<SymbolChange>, Vec<crate::SymbolChangeDetail>) {
201    let file_path = std::path::Path::new(path);
202    let parser = engine.parser();
203
204    if !parser.supports_file(file_path) {
205        return (Vec::new(), Vec::new());
206    }
207
208    // Parse new file
209    let new_symbols = match parser.parse_file(file_path, new_content) {
210        Ok(analysis) => analysis.symbols,
211        Err(_) => return (Vec::new(), Vec::new()),
212    };
213
214    // If file is new, all symbols are "added"
215    if is_new_file || old_content.is_empty() {
216        let changes: Vec<SymbolChange> = new_symbols
217            .iter()
218            .map(|sym| SymbolChange {
219                symbol_name: sym.qualified_name.clone(),
220                change_type: sym.kind.to_string(),
221            })
222            .collect();
223        let details: Vec<crate::SymbolChangeDetail> = new_symbols
224            .iter()
225            .map(|sym| crate::SymbolChangeDetail {
226                symbol_name: sym.qualified_name.clone(),
227                file_path: path.to_string(),
228                change_type: "added".to_string(),
229                kind: sym.kind.to_string(),
230            })
231            .collect();
232        return (changes, details);
233    }
234
235    // Parse old file to get baseline symbols
236    let old_symbols = match parser.parse_file(file_path, old_content) {
237        Ok(analysis) => analysis.symbols,
238        Err(_) => {
239            // Can't parse old file — fall back to treating all new symbols as modified
240            let changes: Vec<SymbolChange> = new_symbols
241                .iter()
242                .map(|sym| SymbolChange {
243                    symbol_name: sym.qualified_name.clone(),
244                    change_type: sym.kind.to_string(),
245                })
246                .collect();
247            let details: Vec<crate::SymbolChangeDetail> = new_symbols
248                .iter()
249                .map(|sym| crate::SymbolChangeDetail {
250                    symbol_name: sym.qualified_name.clone(),
251                    file_path: path.to_string(),
252                    change_type: "modified".to_string(),
253                    kind: sym.kind.to_string(),
254                })
255                .collect();
256            return (changes, details);
257        }
258    };
259
260    // Build a map of old symbol qualified_name → source text.
261    // Use entry().or_insert() to keep the first occurrence when duplicate
262    // qualified names exist (e.g., overloaded methods in Java/Kotlin/C#).
263    let mut old_symbol_text: std::collections::HashMap<&str, &[u8]> = std::collections::HashMap::new();
264    for sym in &old_symbols {
265        let start = sym.span.start_byte as usize;
266        let end = sym.span.end_byte as usize;
267        if start <= end && end <= old_content.len() {
268            old_symbol_text.entry(sym.qualified_name.as_str()).or_insert(&old_content[start..end]);
269        }
270    }
271
272    let mut detected_changes = Vec::new();
273    let mut all_details = Vec::new();
274
275    // Deduplicate new symbols while preserving original parse order.
276    let mut seen_new: std::collections::HashSet<&str> = std::collections::HashSet::new();
277
278    // Compare each deduplicated new symbol against its old version
279    for sym in &new_symbols {
280        if !seen_new.insert(sym.qualified_name.as_str()) {
281            continue; // duplicate qualified name — already handled
282        }
283        let start = sym.span.start_byte as usize;
284        let end = sym.span.end_byte as usize;
285        let new_text = if start <= end && end <= new_content.len() {
286            &new_content[start..end]
287        } else {
288            continue; // invalid or inverted span, skip
289        };
290
291        match old_symbol_text.get(sym.qualified_name.as_str()) {
292            None => {
293                // Symbol not in old file — added
294                detected_changes.push(SymbolChange {
295                    symbol_name: sym.qualified_name.clone(),
296                    change_type: sym.kind.to_string(),
297                });
298                all_details.push(crate::SymbolChangeDetail {
299                    symbol_name: sym.qualified_name.clone(),
300                    file_path: path.to_string(),
301                    change_type: "added".to_string(),
302                    kind: sym.kind.to_string(),
303                });
304            }
305            Some(old_text) => {
306                if *old_text != new_text {
307                    // Symbol text changed — modified
308                    detected_changes.push(SymbolChange {
309                        symbol_name: sym.qualified_name.clone(),
310                        change_type: sym.kind.to_string(),
311                    });
312                    all_details.push(crate::SymbolChangeDetail {
313                        symbol_name: sym.qualified_name.clone(),
314                        file_path: path.to_string(),
315                        change_type: "modified".to_string(),
316                        kind: sym.kind.to_string(),
317                    });
318                }
319                // else: symbol text identical — skip (no claim needed)
320            }
321        }
322    }
323
324    // Detect deleted symbols (deduplicated to avoid double-reporting overloads)
325    let new_names: std::collections::HashSet<&str> = new_symbols
326        .iter()
327        .map(|s| s.qualified_name.as_str())
328        .collect();
329    let old_names: std::collections::HashSet<&str> = old_symbols
330        .iter()
331        .map(|s| s.qualified_name.as_str())
332        .collect();
333    for old_name in &old_names {
334        if !new_names.contains(old_name) {
335            if let Some(old_sym) = old_symbols.iter().find(|s| s.qualified_name.as_str() == *old_name) {
336                all_details.push(crate::SymbolChangeDetail {
337                    symbol_name: old_sym.qualified_name.clone(),
338                    file_path: path.to_string(),
339                    change_type: "deleted".to_string(),
340                    kind: old_sym.kind.to_string(),
341                });
342            }
343        }
344    }
345
346    (detected_changes, all_details)
347}