Skip to main content

dk_protocol/
file_write.rs

1use std::time::Instant;
2
3use tonic::{Response, Status};
4use tracing::{info, warn};
5
6use dk_engine::conflict::{AcquireOutcome, SymbolClaim};
7use crate::server::ProtocolServer;
8use crate::validation::{validate_file_path, MAX_FILE_SIZE};
9use crate::{ConflictWarning, FileWriteRequest, FileWriteResponse, SymbolChange};
10
11/// Handle a FileWrite RPC.
12///
13/// Writes a file through the session workspace overlay and optionally
14/// detects symbol changes by parsing the new content.
15pub async fn handle_file_write(
16    server: &ProtocolServer,
17    req: FileWriteRequest,
18) -> Result<Response<FileWriteResponse>, Status> {
19    validate_file_path(&req.path)?;
20
21    if req.content.len() > MAX_FILE_SIZE {
22        return Err(Status::invalid_argument("file content exceeds 50MB limit"));
23    }
24
25    let session = server.validate_session(&req.session_id)?;
26
27    let sid = req
28        .session_id
29        .parse::<uuid::Uuid>()
30        .map_err(|_| Status::invalid_argument("Invalid session ID"))?;
31    server.session_mgr().touch_session(&sid);
32
33    let engine = server.engine();
34
35    // Get workspace for this session
36    let ws = engine
37        .workspace_manager()
38        .get_workspace(&sid)
39        .ok_or_else(|| Status::not_found("Workspace not found for session"))?;
40
41    // Determine if the file is new (not in base tree) and read old content
42    // in a single get_repo call. Drop git_repo before async work to keep
43    // future Send.
44    let (repo_id, is_new, old_content) = {
45        let (rid, git_repo) = engine
46            .get_repo(&session.codebase)
47            .await
48            .map_err(|e| Status::internal(format!("Repo error: {e}")))?;
49        match git_repo.read_tree_entry(&ws.base_commit, &req.path) {
50            Ok(bytes) => (rid, false, bytes),
51            Err(e) => {
52                // File not in base tree — treat as new. Log the error in case
53                // it's a transient git failure rather than a genuine "not found".
54                warn!(
55                    path = %req.path,
56                    base_commit = %ws.base_commit,
57                    error = %e,
58                    "read_tree_entry failed — treating file as new"
59                );
60                (rid, true, Vec::new())
61            }
62        }
63    };
64    let repo_id_str = repo_id.to_string();
65    let changeset_id = ws.changeset_id;
66    let agent_name = ws.agent_name.clone();
67
68    // Drop workspace guard — overlay write is deferred until after lock acquisition
69    drop(ws);
70
71    let op = if is_new { "add" } else { "modify" };
72
73    // Detect symbol changes from req.content directly — no overlay needed yet.
74    let (detected_changes, all_symbol_changes) =
75        detect_symbol_changes_diffed(engine, &req.path, &old_content, &req.content, is_new);
76
77    // ── Symbol locking (acquire with rollback) ──
78    // Attempt to acquire locks for each changed symbol. If any fails, roll back
79    // all previously acquired locks and reject the write. No overlay write, no
80    // changeset store entry — completely clean rejection.
81    let claimable: Vec<&crate::SymbolChangeDetail> = all_symbol_changes
82        .iter()
83        .filter(|sc| sc.change_type == "added" || sc.change_type == "modified" || sc.change_type == "deleted")
84        .collect();
85
86    let mut acquired: Vec<String> = Vec::new();
87    let mut locked_symbols: Vec<ConflictWarning> = Vec::new();
88
89    for sc in &claimable {
90        let kind = sc.kind.parse::<dk_core::SymbolKind>().unwrap_or(dk_core::SymbolKind::Function);
91        match server.claim_tracker().acquire_lock(
92            repo_id,
93            &req.path,
94            SymbolClaim {
95                session_id: sid,
96                agent_name: agent_name.clone(),
97                qualified_name: sc.symbol_name.clone(),
98                kind,
99                first_touched_at: Instant::now(),
100            },
101        ) {
102            Ok(AcquireOutcome::Fresh) => acquired.push(sc.symbol_name.clone()),
103            Ok(AcquireOutcome::ReAcquired) => {} // already held — exclude from rollback
104            Err(sl) => {
105                warn!(
106                    session_id = %sid,
107                    path = %req.path,
108                    symbol = %sl.qualified_name,
109                    locked_by = %sl.locked_by_agent,
110                    "SYMBOL_LOCKED: write rejected"
111                );
112                locked_symbols.push(ConflictWarning {
113                    file_path: req.path.clone(),
114                    symbol_name: sl.qualified_name.clone(),
115                    conflicting_agent: sl.locked_by_agent.clone(),
116                    conflicting_session_id: sl.locked_by_session.to_string(),
117                    message: format!(
118                        "SYMBOL_LOCKED: '{}' is locked by agent '{}'. Call dk_watch(filter: '{}') to wait, then dk_file_read and retry.",
119                        sl.qualified_name, sl.locked_by_agent, crate::merge::EVENT_LOCK_RELEASED,
120                    ),
121                });
122            }
123        }
124    }
125
126    if !locked_symbols.is_empty() {
127        // Roll back any locks acquired before the failure and emit events
128        // so any agent that raced and observed the transient lock can wake up.
129        for name in &acquired {
130            server.claim_tracker().release_lock(repo_id, &req.path, sid, name);
131            server.event_bus().publish(crate::WatchEvent {
132                event_type: crate::merge::EVENT_LOCK_RELEASED.to_string(),
133                changeset_id: String::new(),
134                agent_id: agent_name.clone(),
135                affected_symbols: vec![name.clone()],
136                details: format!("Symbol lock rolled back on {}", req.path),
137                session_id: req.session_id.clone(),
138                affected_files: vec![crate::FileChange {
139                    path: req.path.clone(),
140                    operation: "unlock".to_string(),
141                }],
142                symbol_changes: vec![],
143                repo_id: repo_id_str.clone(),
144                event_id: uuid::Uuid::new_v4().to_string(),
145            });
146        }
147
148        info!(
149            session_id = %sid,
150            path = %req.path,
151            locked_count = locked_symbols.len(),
152            rolled_back = acquired.len(),
153            "FILE_WRITE: rejected — symbols locked, rolled back partial locks"
154        );
155
156        return Ok(Response::new(FileWriteResponse {
157            new_hash: String::new(),
158            detected_changes: Vec::new(),
159            conflict_warnings: locked_symbols,
160        }));
161    }
162
163    // All locks acquired — now write the overlay and changeset store.
164    // If either fails, release all acquired locks before propagating the error.
165    let ws = match engine.workspace_manager().get_workspace(&sid) {
166        Some(ws) => ws,
167        None => {
168            for name in &acquired {
169                server.claim_tracker().release_lock(repo_id, &req.path, sid, name);
170                server.event_bus().publish(crate::WatchEvent {
171                    event_type: crate::merge::EVENT_LOCK_RELEASED.to_string(),
172                    changeset_id: String::new(),
173                    agent_id: agent_name.clone(),
174                    affected_symbols: vec![name.clone()],
175                    details: format!("Symbol lock released on error in {}", req.path),
176                    session_id: req.session_id.clone(),
177                    affected_files: vec![crate::FileChange {
178                        path: req.path.clone(),
179                        operation: "unlock".to_string(),
180                    }],
181                    symbol_changes: vec![],
182                    repo_id: repo_id_str.clone(),
183                    event_id: uuid::Uuid::new_v4().to_string(),
184                });
185            }
186            return Err(Status::not_found("Workspace not found for session"));
187        }
188    };
189
190    let new_hash = match ws.overlay.write(&req.path, req.content.clone(), is_new).await {
191        Ok(hash) => hash,
192        Err(e) => {
193            for name in &acquired {
194                server.claim_tracker().release_lock(repo_id, &req.path, sid, name);
195                server.event_bus().publish(crate::WatchEvent {
196                    event_type: crate::merge::EVENT_LOCK_RELEASED.to_string(),
197                    changeset_id: String::new(),
198                    agent_id: agent_name.clone(),
199                    affected_symbols: vec![name.clone()],
200                    details: format!("Symbol lock released on error in {}", req.path),
201                    session_id: req.session_id.clone(),
202                    affected_files: vec![crate::FileChange {
203                        path: req.path.clone(),
204                        operation: "unlock".to_string(),
205                    }],
206                    symbol_changes: vec![],
207                    repo_id: repo_id_str.clone(),
208                    event_id: uuid::Uuid::new_v4().to_string(),
209                });
210            }
211            return Err(Status::internal(format!("Write failed: {e}")));
212        }
213    };
214
215    drop(ws);
216
217    let content_str = std::str::from_utf8(&req.content).ok();
218    let _ = engine
219        .changeset_store()
220        .upsert_file(changeset_id, &req.path, op, content_str)
221        .await;
222
223    let conflict_warnings: Vec<ConflictWarning> = Vec::new();
224
225    // Emit a file.modified (or file.added) event
226    let event_type = if is_new { "file.added" } else { "file.modified" };
227    server.event_bus().publish(crate::WatchEvent {
228        event_type: event_type.to_string(),
229        changeset_id: changeset_id.to_string(),
230        agent_id: session.agent_id.clone(),
231        affected_symbols: vec![],
232        details: format!("file {}: {}", op, req.path),
233        session_id: req.session_id.clone(),
234        affected_files: vec![crate::FileChange {
235            path: req.path.clone(),
236            operation: op.to_string(),
237        }],
238        symbol_changes: all_symbol_changes,
239        repo_id: repo_id_str,
240        event_id: uuid::Uuid::new_v4().to_string(),
241    });
242
243    info!(
244        session_id = %req.session_id,
245        path = %req.path,
246        hash = %new_hash,
247        changes = detected_changes.len(),
248        conflicts = conflict_warnings.len(),
249        "FILE_WRITE: completed"
250    );
251
252    Ok(Response::new(FileWriteResponse {
253        new_hash,
254        detected_changes,
255        conflict_warnings,
256    }))
257}
258
259/// Parse both old and new file content, diff per-symbol source text,
260/// and return only symbols that actually changed.
261///
262/// Returns `(detected_changes, all_symbol_change_details)`:
263/// - `detected_changes`: `SymbolChange` for the gRPC response (only truly changed symbols)
264/// - `all_symbol_change_details`: `SymbolChangeDetail` for claims + events (added/modified/deleted)
265fn detect_symbol_changes_diffed(
266    engine: &dk_engine::repo::Engine,
267    path: &str,
268    old_content: &[u8],
269    new_content: &[u8],
270    is_new_file: bool,
271) -> (Vec<SymbolChange>, Vec<crate::SymbolChangeDetail>) {
272    let file_path = std::path::Path::new(path);
273    let parser = engine.parser();
274
275    if !parser.supports_file(file_path) {
276        return (Vec::new(), Vec::new());
277    }
278
279    // Parse new file
280    let new_symbols = match parser.parse_file(file_path, new_content) {
281        Ok(analysis) => analysis.symbols,
282        Err(_) => return (Vec::new(), Vec::new()),
283    };
284
285    // If file is new, all symbols are "added"
286    if is_new_file || old_content.is_empty() {
287        let changes: Vec<SymbolChange> = new_symbols
288            .iter()
289            .map(|sym| SymbolChange {
290                symbol_name: sym.qualified_name.clone(),
291                change_type: sym.kind.to_string(),
292            })
293            .collect();
294        let details: Vec<crate::SymbolChangeDetail> = new_symbols
295            .iter()
296            .map(|sym| crate::SymbolChangeDetail {
297                symbol_name: sym.qualified_name.clone(),
298                file_path: path.to_string(),
299                change_type: "added".to_string(),
300                kind: sym.kind.to_string(),
301            })
302            .collect();
303        return (changes, details);
304    }
305
306    // Parse old file to get baseline symbols
307    let old_symbols = match parser.parse_file(file_path, old_content) {
308        Ok(analysis) => analysis.symbols,
309        Err(_) => {
310            // Can't parse old file — fall back to treating all new symbols as modified
311            let changes: Vec<SymbolChange> = new_symbols
312                .iter()
313                .map(|sym| SymbolChange {
314                    symbol_name: sym.qualified_name.clone(),
315                    change_type: sym.kind.to_string(),
316                })
317                .collect();
318            let details: Vec<crate::SymbolChangeDetail> = new_symbols
319                .iter()
320                .map(|sym| crate::SymbolChangeDetail {
321                    symbol_name: sym.qualified_name.clone(),
322                    file_path: path.to_string(),
323                    change_type: "modified".to_string(),
324                    kind: sym.kind.to_string(),
325                })
326                .collect();
327            return (changes, details);
328        }
329    };
330
331    // Build a map of old symbol qualified_name → source text.
332    // Use entry().or_insert() to keep the first occurrence when duplicate
333    // qualified names exist (e.g., overloaded methods in Java/Kotlin/C#).
334    let mut old_symbol_text: std::collections::HashMap<&str, &[u8]> = std::collections::HashMap::new();
335    for sym in &old_symbols {
336        let start = sym.span.start_byte as usize;
337        let end = sym.span.end_byte as usize;
338        if start <= end && end <= old_content.len() {
339            old_symbol_text.entry(sym.qualified_name.as_str()).or_insert(&old_content[start..end]);
340        }
341    }
342
343    let mut detected_changes = Vec::new();
344    let mut all_details = Vec::new();
345
346    // Deduplicate new symbols while preserving original parse order.
347    let mut seen_new: std::collections::HashSet<&str> = std::collections::HashSet::new();
348
349    // Compare each deduplicated new symbol against its old version
350    for sym in &new_symbols {
351        if !seen_new.insert(sym.qualified_name.as_str()) {
352            continue; // duplicate qualified name — already handled
353        }
354        let start = sym.span.start_byte as usize;
355        let end = sym.span.end_byte as usize;
356        let new_text = if start <= end && end <= new_content.len() {
357            &new_content[start..end]
358        } else {
359            continue; // invalid or inverted span, skip
360        };
361
362        match old_symbol_text.get(sym.qualified_name.as_str()) {
363            None => {
364                // Symbol not in old file — added
365                detected_changes.push(SymbolChange {
366                    symbol_name: sym.qualified_name.clone(),
367                    change_type: sym.kind.to_string(),
368                });
369                all_details.push(crate::SymbolChangeDetail {
370                    symbol_name: sym.qualified_name.clone(),
371                    file_path: path.to_string(),
372                    change_type: "added".to_string(),
373                    kind: sym.kind.to_string(),
374                });
375            }
376            Some(old_text) => {
377                if *old_text != new_text {
378                    // Symbol text changed — modified
379                    detected_changes.push(SymbolChange {
380                        symbol_name: sym.qualified_name.clone(),
381                        change_type: sym.kind.to_string(),
382                    });
383                    all_details.push(crate::SymbolChangeDetail {
384                        symbol_name: sym.qualified_name.clone(),
385                        file_path: path.to_string(),
386                        change_type: "modified".to_string(),
387                        kind: sym.kind.to_string(),
388                    });
389                }
390                // else: symbol text identical — skip (no claim needed)
391            }
392        }
393    }
394
395    // Detect deleted symbols (deduplicated to avoid double-reporting overloads)
396    let new_names: std::collections::HashSet<&str> = new_symbols
397        .iter()
398        .map(|s| s.qualified_name.as_str())
399        .collect();
400    let old_names: std::collections::HashSet<&str> = old_symbols
401        .iter()
402        .map(|s| s.qualified_name.as_str())
403        .collect();
404    for old_name in &old_names {
405        if !new_names.contains(old_name) {
406            if let Some(old_sym) = old_symbols.iter().find(|s| s.qualified_name.as_str() == *old_name) {
407                all_details.push(crate::SymbolChangeDetail {
408                    symbol_name: old_sym.qualified_name.clone(),
409                    file_path: path.to_string(),
410                    change_type: "deleted".to_string(),
411                    kind: old_sym.kind.to_string(),
412                });
413            }
414        }
415    }
416
417    (detected_changes, all_details)
418}