Skip to main content

dk_protocol/
file_write.rs

1use tonic::{Response, Status};
2use tracing::{info, warn};
3
4use dk_engine::conflict::{AcquireOutcome, SymbolClaim};
5use crate::server::ProtocolServer;
6use crate::validation::{validate_file_path, MAX_FILE_SIZE};
7use crate::{ConflictWarning, FileWriteRequest, FileWriteResponse, SymbolChange};
8
9/// Handle a FileWrite RPC.
10///
11/// Writes a file through the session workspace overlay and optionally
12/// detects symbol changes by parsing the new content.
13pub async fn handle_file_write(
14    server: &ProtocolServer,
15    req: FileWriteRequest,
16) -> Result<Response<FileWriteResponse>, Status> {
17    validate_file_path(&req.path)?;
18
19    if req.content.len() > MAX_FILE_SIZE {
20        return Err(Status::invalid_argument("file content exceeds 50MB limit"));
21    }
22
23    let session = server.validate_session(&req.session_id)?;
24
25    let sid = req
26        .session_id
27        .parse::<uuid::Uuid>()
28        .map_err(|_| Status::invalid_argument("Invalid session ID"))?;
29    server.session_mgr().touch_session(&sid);
30
31    let engine = server.engine();
32
33    // Get workspace for this session
34    let ws = engine
35        .workspace_manager()
36        .get_workspace(&sid)
37        .ok_or_else(|| Status::not_found("Workspace not found for session"))?;
38
39    // Determine if the file is new (not in base tree) and read old content
40    // in a single get_repo call. Drop git_repo before async work to keep
41    // future Send.
42    let (repo_id, is_new, old_content) = {
43        let (rid, git_repo) = engine
44            .get_repo(&session.codebase)
45            .await
46            .map_err(|e| Status::internal(format!("Repo error: {e}")))?;
47        match git_repo.read_tree_entry(&ws.base_commit, &req.path) {
48            Ok(bytes) => (rid, false, bytes),
49            Err(e) => {
50                // File not in base tree — treat as new. Log the error in case
51                // it's a transient git failure rather than a genuine "not found".
52                warn!(
53                    path = %req.path,
54                    base_commit = %ws.base_commit,
55                    error = %e,
56                    "read_tree_entry failed — treating file as new"
57                );
58                (rid, true, Vec::new())
59            }
60        }
61    };
62    let repo_id_str = repo_id.to_string();
63    let changeset_id = ws.changeset_id;
64    let agent_name = ws.agent_name.clone();
65
66    // Drop workspace guard — overlay write is deferred until after lock acquisition
67    drop(ws);
68
69    let op = if is_new { "add" } else { "modify" };
70
71    // Detect symbol changes from req.content directly — no overlay needed yet.
72    let (detected_changes, all_symbol_changes) =
73        detect_symbol_changes_diffed(engine, &req.path, &old_content, &req.content, is_new);
74
75    // ── Symbol locking (acquire with rollback) ──
76    // Attempt to acquire locks for each changed symbol. If any fails, roll back
77    // all previously acquired locks and reject the write. No overlay write, no
78    // changeset store entry — completely clean rejection.
79    let claimable: Vec<&crate::SymbolChangeDetail> = all_symbol_changes
80        .iter()
81        .filter(|sc| sc.change_type == "added" || sc.change_type == "modified" || sc.change_type == "deleted")
82        .collect();
83
84    let mut acquired: Vec<String> = Vec::new();
85    let mut locked_symbols: Vec<ConflictWarning> = Vec::new();
86
87    for sc in &claimable {
88        let kind = sc.kind.parse::<dk_core::SymbolKind>().unwrap_or(dk_core::SymbolKind::Function);
89        match server.claim_tracker().acquire_lock(
90            repo_id,
91            &req.path,
92            SymbolClaim {
93                session_id: sid,
94                agent_name: agent_name.clone(),
95                qualified_name: sc.symbol_name.clone(),
96                kind,
97                first_touched_at: chrono::Utc::now(),
98            },
99        ).await {
100            Ok(AcquireOutcome::Fresh) => acquired.push(sc.symbol_name.clone()),
101            Ok(AcquireOutcome::ReAcquired) => {} // already held — exclude from rollback
102            Err(sl) => {
103                warn!(
104                    session_id = %sid,
105                    path = %req.path,
106                    symbol = %sl.qualified_name,
107                    locked_by = %sl.locked_by_agent,
108                    "SYMBOL_LOCKED: write rejected"
109                );
110                locked_symbols.push(ConflictWarning {
111                    file_path: req.path.clone(),
112                    symbol_name: sl.qualified_name.clone(),
113                    conflicting_agent: sl.locked_by_agent.clone(),
114                    conflicting_session_id: sl.locked_by_session.to_string(),
115                    message: format!(
116                        "SYMBOL_LOCKED: '{}' is locked by agent '{}'. Call dk_watch(filter: '{}') to wait, then dk_file_read and retry.",
117                        sl.qualified_name, sl.locked_by_agent, crate::merge::EVENT_LOCK_RELEASED,
118                    ),
119                });
120            }
121        }
122    }
123
124    if !locked_symbols.is_empty() {
125        // Roll back any locks acquired before the failure and emit events
126        // so any agent that raced and observed the transient lock can wake up.
127        for name in &acquired {
128            server.claim_tracker().release_lock(repo_id, &req.path, sid, name).await;
129            server.event_bus().publish(crate::WatchEvent {
130                event_type: crate::merge::EVENT_LOCK_RELEASED.to_string(),
131                changeset_id: String::new(),
132                agent_id: agent_name.clone(),
133                affected_symbols: vec![name.clone()],
134                details: format!("Symbol lock rolled back on {}", req.path),
135                session_id: req.session_id.clone(),
136                affected_files: vec![crate::FileChange {
137                    path: req.path.clone(),
138                    operation: "unlock".to_string(),
139                }],
140                symbol_changes: vec![],
141                repo_id: repo_id_str.clone(),
142                event_id: uuid::Uuid::new_v4().to_string(),
143            });
144        }
145
146        info!(
147            session_id = %sid,
148            path = %req.path,
149            locked_count = locked_symbols.len(),
150            rolled_back = acquired.len(),
151            "FILE_WRITE: rejected — symbols locked, rolled back partial locks"
152        );
153
154        return Ok(Response::new(FileWriteResponse {
155            new_hash: String::new(),
156            detected_changes: Vec::new(),
157            conflict_warnings: locked_symbols,
158        }));
159    }
160
161    // All locks acquired — now write the overlay and changeset store.
162    // If either fails, release all acquired locks before propagating the error.
163    let ws = match engine.workspace_manager().get_workspace(&sid) {
164        Some(ws) => ws,
165        None => {
166            for name in &acquired {
167                server.claim_tracker().release_lock(repo_id, &req.path, sid, name).await;
168                server.event_bus().publish(crate::WatchEvent {
169                    event_type: crate::merge::EVENT_LOCK_RELEASED.to_string(),
170                    changeset_id: String::new(),
171                    agent_id: agent_name.clone(),
172                    affected_symbols: vec![name.clone()],
173                    details: format!("Symbol lock released on error in {}", req.path),
174                    session_id: req.session_id.clone(),
175                    affected_files: vec![crate::FileChange {
176                        path: req.path.clone(),
177                        operation: "unlock".to_string(),
178                    }],
179                    symbol_changes: vec![],
180                    repo_id: repo_id_str.clone(),
181                    event_id: uuid::Uuid::new_v4().to_string(),
182                });
183            }
184            return Err(Status::not_found("Workspace not found for session"));
185        }
186    };
187
188    let new_hash = match ws.overlay.write(&req.path, req.content.clone(), is_new).await {
189        Ok(hash) => hash,
190        Err(e) => {
191            for name in &acquired {
192                server.claim_tracker().release_lock(repo_id, &req.path, sid, name).await;
193                server.event_bus().publish(crate::WatchEvent {
194                    event_type: crate::merge::EVENT_LOCK_RELEASED.to_string(),
195                    changeset_id: String::new(),
196                    agent_id: agent_name.clone(),
197                    affected_symbols: vec![name.clone()],
198                    details: format!("Symbol lock released on error in {}", req.path),
199                    session_id: req.session_id.clone(),
200                    affected_files: vec![crate::FileChange {
201                        path: req.path.clone(),
202                        operation: "unlock".to_string(),
203                    }],
204                    symbol_changes: vec![],
205                    repo_id: repo_id_str.clone(),
206                    event_id: uuid::Uuid::new_v4().to_string(),
207                });
208            }
209            return Err(Status::internal(format!("Write failed: {e}")));
210        }
211    };
212
213    drop(ws);
214
215    let content_str = std::str::from_utf8(&req.content).ok();
216    let _ = engine
217        .changeset_store()
218        .upsert_file(changeset_id, &req.path, op, content_str)
219        .await;
220
221    let conflict_warnings: Vec<ConflictWarning> = Vec::new();
222
223    // Emit a file.modified (or file.added) event
224    let event_type = if is_new { "file.added" } else { "file.modified" };
225    server.event_bus().publish(crate::WatchEvent {
226        event_type: event_type.to_string(),
227        changeset_id: changeset_id.to_string(),
228        agent_id: session.agent_id.clone(),
229        affected_symbols: vec![],
230        details: format!("file {}: {}", op, req.path),
231        session_id: req.session_id.clone(),
232        affected_files: vec![crate::FileChange {
233            path: req.path.clone(),
234            operation: op.to_string(),
235        }],
236        symbol_changes: all_symbol_changes,
237        repo_id: repo_id_str,
238        event_id: uuid::Uuid::new_v4().to_string(),
239    });
240
241    info!(
242        session_id = %req.session_id,
243        path = %req.path,
244        hash = %new_hash,
245        changes = detected_changes.len(),
246        conflicts = conflict_warnings.len(),
247        "FILE_WRITE: completed"
248    );
249
250    Ok(Response::new(FileWriteResponse {
251        new_hash,
252        detected_changes,
253        conflict_warnings,
254    }))
255}
256
257/// Parse both old and new file content, diff per-symbol source text,
258/// and return only symbols that actually changed.
259///
260/// Returns `(detected_changes, all_symbol_change_details)`:
261/// - `detected_changes`: `SymbolChange` for the gRPC response (only truly changed symbols)
262/// - `all_symbol_change_details`: `SymbolChangeDetail` for claims + events (added/modified/deleted)
263fn detect_symbol_changes_diffed(
264    engine: &dk_engine::repo::Engine,
265    path: &str,
266    old_content: &[u8],
267    new_content: &[u8],
268    is_new_file: bool,
269) -> (Vec<SymbolChange>, Vec<crate::SymbolChangeDetail>) {
270    let file_path = std::path::Path::new(path);
271    let parser = engine.parser();
272
273    if !parser.supports_file(file_path) {
274        return (Vec::new(), Vec::new());
275    }
276
277    // Parse new file
278    let new_symbols = match parser.parse_file(file_path, new_content) {
279        Ok(analysis) => analysis.symbols,
280        Err(_) => return (Vec::new(), Vec::new()),
281    };
282
283    // If file is new, all symbols are "added"
284    if is_new_file || old_content.is_empty() {
285        let changes: Vec<SymbolChange> = new_symbols
286            .iter()
287            .map(|sym| SymbolChange {
288                symbol_name: sym.qualified_name.clone(),
289                change_type: sym.kind.to_string(),
290            })
291            .collect();
292        let details: Vec<crate::SymbolChangeDetail> = new_symbols
293            .iter()
294            .map(|sym| crate::SymbolChangeDetail {
295                symbol_name: sym.qualified_name.clone(),
296                file_path: path.to_string(),
297                change_type: "added".to_string(),
298                kind: sym.kind.to_string(),
299            })
300            .collect();
301        return (changes, details);
302    }
303
304    // Parse old file to get baseline symbols
305    let old_symbols = match parser.parse_file(file_path, old_content) {
306        Ok(analysis) => analysis.symbols,
307        Err(_) => {
308            // Can't parse old file — fall back to treating all new symbols as modified
309            let changes: Vec<SymbolChange> = new_symbols
310                .iter()
311                .map(|sym| SymbolChange {
312                    symbol_name: sym.qualified_name.clone(),
313                    change_type: sym.kind.to_string(),
314                })
315                .collect();
316            let details: Vec<crate::SymbolChangeDetail> = new_symbols
317                .iter()
318                .map(|sym| crate::SymbolChangeDetail {
319                    symbol_name: sym.qualified_name.clone(),
320                    file_path: path.to_string(),
321                    change_type: "modified".to_string(),
322                    kind: sym.kind.to_string(),
323                })
324                .collect();
325            return (changes, details);
326        }
327    };
328
329    // Build a map of old symbol qualified_name → source text.
330    // Use entry().or_insert() to keep the first occurrence when duplicate
331    // qualified names exist (e.g., overloaded methods in Java/Kotlin/C#).
332    let mut old_symbol_text: std::collections::HashMap<&str, &[u8]> = std::collections::HashMap::new();
333    for sym in &old_symbols {
334        let start = sym.span.start_byte as usize;
335        let end = sym.span.end_byte as usize;
336        if start <= end && end <= old_content.len() {
337            old_symbol_text.entry(sym.qualified_name.as_str()).or_insert(&old_content[start..end]);
338        }
339    }
340
341    let mut detected_changes = Vec::new();
342    let mut all_details = Vec::new();
343
344    // Deduplicate new symbols while preserving original parse order.
345    let mut seen_new: std::collections::HashSet<&str> = std::collections::HashSet::new();
346
347    // Compare each deduplicated new symbol against its old version
348    for sym in &new_symbols {
349        if !seen_new.insert(sym.qualified_name.as_str()) {
350            continue; // duplicate qualified name — already handled
351        }
352        let start = sym.span.start_byte as usize;
353        let end = sym.span.end_byte as usize;
354        let new_text = if start <= end && end <= new_content.len() {
355            &new_content[start..end]
356        } else {
357            continue; // invalid or inverted span, skip
358        };
359
360        match old_symbol_text.get(sym.qualified_name.as_str()) {
361            None => {
362                // Symbol not in old file — added
363                detected_changes.push(SymbolChange {
364                    symbol_name: sym.qualified_name.clone(),
365                    change_type: sym.kind.to_string(),
366                });
367                all_details.push(crate::SymbolChangeDetail {
368                    symbol_name: sym.qualified_name.clone(),
369                    file_path: path.to_string(),
370                    change_type: "added".to_string(),
371                    kind: sym.kind.to_string(),
372                });
373            }
374            Some(old_text) => {
375                if *old_text != new_text {
376                    // Symbol text changed — modified
377                    detected_changes.push(SymbolChange {
378                        symbol_name: sym.qualified_name.clone(),
379                        change_type: sym.kind.to_string(),
380                    });
381                    all_details.push(crate::SymbolChangeDetail {
382                        symbol_name: sym.qualified_name.clone(),
383                        file_path: path.to_string(),
384                        change_type: "modified".to_string(),
385                        kind: sym.kind.to_string(),
386                    });
387                }
388                // else: symbol text identical — skip (no claim needed)
389            }
390        }
391    }
392
393    // Detect deleted symbols (deduplicated to avoid double-reporting overloads)
394    let new_names: std::collections::HashSet<&str> = new_symbols
395        .iter()
396        .map(|s| s.qualified_name.as_str())
397        .collect();
398    let old_names: std::collections::HashSet<&str> = old_symbols
399        .iter()
400        .map(|s| s.qualified_name.as_str())
401        .collect();
402    for old_name in &old_names {
403        if !new_names.contains(old_name) {
404            if let Some(old_sym) = old_symbols.iter().find(|s| s.qualified_name.as_str() == *old_name) {
405                all_details.push(crate::SymbolChangeDetail {
406                    symbol_name: old_sym.qualified_name.clone(),
407                    file_path: path.to_string(),
408                    change_type: "deleted".to_string(),
409                    kind: old_sym.kind.to_string(),
410                });
411            }
412        }
413    }
414
415    (detected_changes, all_details)
416}