1use std::time::Instant;
2
3use tonic::{Response, Status};
4use tracing::{info, warn};
5
6use dk_engine::conflict::{AcquireOutcome, SymbolClaim};
7use crate::server::ProtocolServer;
8use crate::validation::{validate_file_path, MAX_FILE_SIZE};
9use crate::{ConflictWarning, FileWriteRequest, FileWriteResponse, SymbolChange};
10
11pub async fn handle_file_write(
16 server: &ProtocolServer,
17 req: FileWriteRequest,
18) -> Result<Response<FileWriteResponse>, Status> {
19 validate_file_path(&req.path)?;
20
21 if req.content.len() > MAX_FILE_SIZE {
22 return Err(Status::invalid_argument("file content exceeds 50MB limit"));
23 }
24
25 let session = server.validate_session(&req.session_id)?;
26
27 let sid = req
28 .session_id
29 .parse::<uuid::Uuid>()
30 .map_err(|_| Status::invalid_argument("Invalid session ID"))?;
31 server.session_mgr().touch_session(&sid);
32
33 let engine = server.engine();
34
35 let ws = engine
37 .workspace_manager()
38 .get_workspace(&sid)
39 .ok_or_else(|| Status::not_found("Workspace not found for session"))?;
40
41 let (repo_id, is_new, old_content) = {
45 let (rid, git_repo) = engine
46 .get_repo(&session.codebase)
47 .await
48 .map_err(|e| Status::internal(format!("Repo error: {e}")))?;
49 match git_repo.read_tree_entry(&ws.base_commit, &req.path) {
50 Ok(bytes) => (rid, false, bytes),
51 Err(e) => {
52 warn!(
55 path = %req.path,
56 base_commit = %ws.base_commit,
57 error = %e,
58 "read_tree_entry failed — treating file as new"
59 );
60 (rid, true, Vec::new())
61 }
62 }
63 };
64 let repo_id_str = repo_id.to_string();
65 let changeset_id = ws.changeset_id;
66 let agent_name = ws.agent_name.clone();
67
68 drop(ws);
70
71 let op = if is_new { "add" } else { "modify" };
72
73 let (detected_changes, all_symbol_changes) =
75 detect_symbol_changes_diffed(engine, &req.path, &old_content, &req.content, is_new);
76
77 let claimable: Vec<&crate::SymbolChangeDetail> = all_symbol_changes
82 .iter()
83 .filter(|sc| sc.change_type == "added" || sc.change_type == "modified" || sc.change_type == "deleted")
84 .collect();
85
86 let mut acquired: Vec<String> = Vec::new();
87 let mut locked_symbols: Vec<ConflictWarning> = Vec::new();
88
89 for sc in &claimable {
90 let kind = sc.kind.parse::<dk_core::SymbolKind>().unwrap_or(dk_core::SymbolKind::Function);
91 match server.claim_tracker().acquire_lock(
92 repo_id,
93 &req.path,
94 SymbolClaim {
95 session_id: sid,
96 agent_name: agent_name.clone(),
97 qualified_name: sc.symbol_name.clone(),
98 kind,
99 first_touched_at: Instant::now(),
100 },
101 ) {
102 Ok(AcquireOutcome::Fresh) => acquired.push(sc.symbol_name.clone()),
103 Ok(AcquireOutcome::ReAcquired) => {} Err(sl) => {
105 warn!(
106 session_id = %sid,
107 path = %req.path,
108 symbol = %sl.qualified_name,
109 locked_by = %sl.locked_by_agent,
110 "SYMBOL_LOCKED: write rejected"
111 );
112 locked_symbols.push(ConflictWarning {
113 file_path: req.path.clone(),
114 symbol_name: sl.qualified_name.clone(),
115 conflicting_agent: sl.locked_by_agent.clone(),
116 conflicting_session_id: sl.locked_by_session.to_string(),
117 message: format!(
118 "SYMBOL_LOCKED: '{}' is locked by agent '{}'. Call dk_watch(filter: '{}') to wait, then dk_file_read and retry.",
119 sl.qualified_name, sl.locked_by_agent, crate::merge::EVENT_LOCK_RELEASED,
120 ),
121 });
122 }
123 }
124 }
125
126 if !locked_symbols.is_empty() {
127 for name in &acquired {
130 server.claim_tracker().release_lock(repo_id, &req.path, sid, name);
131 server.event_bus().publish(crate::WatchEvent {
132 event_type: crate::merge::EVENT_LOCK_RELEASED.to_string(),
133 changeset_id: String::new(),
134 agent_id: agent_name.clone(),
135 affected_symbols: vec![name.clone()],
136 details: format!("Symbol lock rolled back on {}", req.path),
137 session_id: req.session_id.clone(),
138 affected_files: vec![crate::FileChange {
139 path: req.path.clone(),
140 operation: "unlock".to_string(),
141 }],
142 symbol_changes: vec![],
143 repo_id: repo_id_str.clone(),
144 event_id: uuid::Uuid::new_v4().to_string(),
145 });
146 }
147
148 info!(
149 session_id = %sid,
150 path = %req.path,
151 locked_count = locked_symbols.len(),
152 rolled_back = acquired.len(),
153 "FILE_WRITE: rejected — symbols locked, rolled back partial locks"
154 );
155
156 return Ok(Response::new(FileWriteResponse {
157 new_hash: String::new(),
158 detected_changes: Vec::new(),
159 conflict_warnings: locked_symbols,
160 }));
161 }
162
163 let ws = match engine.workspace_manager().get_workspace(&sid) {
166 Some(ws) => ws,
167 None => {
168 for name in &acquired {
169 server.claim_tracker().release_lock(repo_id, &req.path, sid, name);
170 server.event_bus().publish(crate::WatchEvent {
171 event_type: crate::merge::EVENT_LOCK_RELEASED.to_string(),
172 changeset_id: String::new(),
173 agent_id: agent_name.clone(),
174 affected_symbols: vec![name.clone()],
175 details: format!("Symbol lock released on error in {}", req.path),
176 session_id: req.session_id.clone(),
177 affected_files: vec![crate::FileChange {
178 path: req.path.clone(),
179 operation: "unlock".to_string(),
180 }],
181 symbol_changes: vec![],
182 repo_id: repo_id_str.clone(),
183 event_id: uuid::Uuid::new_v4().to_string(),
184 });
185 }
186 return Err(Status::not_found("Workspace not found for session"));
187 }
188 };
189
190 let new_hash = match ws.overlay.write(&req.path, req.content.clone(), is_new).await {
191 Ok(hash) => hash,
192 Err(e) => {
193 for name in &acquired {
194 server.claim_tracker().release_lock(repo_id, &req.path, sid, name);
195 server.event_bus().publish(crate::WatchEvent {
196 event_type: crate::merge::EVENT_LOCK_RELEASED.to_string(),
197 changeset_id: String::new(),
198 agent_id: agent_name.clone(),
199 affected_symbols: vec![name.clone()],
200 details: format!("Symbol lock released on error in {}", req.path),
201 session_id: req.session_id.clone(),
202 affected_files: vec![crate::FileChange {
203 path: req.path.clone(),
204 operation: "unlock".to_string(),
205 }],
206 symbol_changes: vec![],
207 repo_id: repo_id_str.clone(),
208 event_id: uuid::Uuid::new_v4().to_string(),
209 });
210 }
211 return Err(Status::internal(format!("Write failed: {e}")));
212 }
213 };
214
215 drop(ws);
216
217 let content_str = std::str::from_utf8(&req.content).ok();
218 let _ = engine
219 .changeset_store()
220 .upsert_file(changeset_id, &req.path, op, content_str)
221 .await;
222
223 let conflict_warnings: Vec<ConflictWarning> = Vec::new();
224
225 let event_type = if is_new { "file.added" } else { "file.modified" };
227 server.event_bus().publish(crate::WatchEvent {
228 event_type: event_type.to_string(),
229 changeset_id: changeset_id.to_string(),
230 agent_id: session.agent_id.clone(),
231 affected_symbols: vec![],
232 details: format!("file {}: {}", op, req.path),
233 session_id: req.session_id.clone(),
234 affected_files: vec![crate::FileChange {
235 path: req.path.clone(),
236 operation: op.to_string(),
237 }],
238 symbol_changes: all_symbol_changes,
239 repo_id: repo_id_str,
240 event_id: uuid::Uuid::new_v4().to_string(),
241 });
242
243 info!(
244 session_id = %req.session_id,
245 path = %req.path,
246 hash = %new_hash,
247 changes = detected_changes.len(),
248 conflicts = conflict_warnings.len(),
249 "FILE_WRITE: completed"
250 );
251
252 Ok(Response::new(FileWriteResponse {
253 new_hash,
254 detected_changes,
255 conflict_warnings,
256 }))
257}
258
259fn detect_symbol_changes_diffed(
266 engine: &dk_engine::repo::Engine,
267 path: &str,
268 old_content: &[u8],
269 new_content: &[u8],
270 is_new_file: bool,
271) -> (Vec<SymbolChange>, Vec<crate::SymbolChangeDetail>) {
272 let file_path = std::path::Path::new(path);
273 let parser = engine.parser();
274
275 if !parser.supports_file(file_path) {
276 return (Vec::new(), Vec::new());
277 }
278
279 let new_symbols = match parser.parse_file(file_path, new_content) {
281 Ok(analysis) => analysis.symbols,
282 Err(_) => return (Vec::new(), Vec::new()),
283 };
284
285 if is_new_file || old_content.is_empty() {
287 let changes: Vec<SymbolChange> = new_symbols
288 .iter()
289 .map(|sym| SymbolChange {
290 symbol_name: sym.qualified_name.clone(),
291 change_type: sym.kind.to_string(),
292 })
293 .collect();
294 let details: Vec<crate::SymbolChangeDetail> = new_symbols
295 .iter()
296 .map(|sym| crate::SymbolChangeDetail {
297 symbol_name: sym.qualified_name.clone(),
298 file_path: path.to_string(),
299 change_type: "added".to_string(),
300 kind: sym.kind.to_string(),
301 })
302 .collect();
303 return (changes, details);
304 }
305
306 let old_symbols = match parser.parse_file(file_path, old_content) {
308 Ok(analysis) => analysis.symbols,
309 Err(_) => {
310 let changes: Vec<SymbolChange> = new_symbols
312 .iter()
313 .map(|sym| SymbolChange {
314 symbol_name: sym.qualified_name.clone(),
315 change_type: sym.kind.to_string(),
316 })
317 .collect();
318 let details: Vec<crate::SymbolChangeDetail> = new_symbols
319 .iter()
320 .map(|sym| crate::SymbolChangeDetail {
321 symbol_name: sym.qualified_name.clone(),
322 file_path: path.to_string(),
323 change_type: "modified".to_string(),
324 kind: sym.kind.to_string(),
325 })
326 .collect();
327 return (changes, details);
328 }
329 };
330
331 let mut old_symbol_text: std::collections::HashMap<&str, &[u8]> = std::collections::HashMap::new();
335 for sym in &old_symbols {
336 let start = sym.span.start_byte as usize;
337 let end = sym.span.end_byte as usize;
338 if start <= end && end <= old_content.len() {
339 old_symbol_text.entry(sym.qualified_name.as_str()).or_insert(&old_content[start..end]);
340 }
341 }
342
343 let mut detected_changes = Vec::new();
344 let mut all_details = Vec::new();
345
346 let mut seen_new: std::collections::HashSet<&str> = std::collections::HashSet::new();
348
349 for sym in &new_symbols {
351 if !seen_new.insert(sym.qualified_name.as_str()) {
352 continue; }
354 let start = sym.span.start_byte as usize;
355 let end = sym.span.end_byte as usize;
356 let new_text = if start <= end && end <= new_content.len() {
357 &new_content[start..end]
358 } else {
359 continue; };
361
362 match old_symbol_text.get(sym.qualified_name.as_str()) {
363 None => {
364 detected_changes.push(SymbolChange {
366 symbol_name: sym.qualified_name.clone(),
367 change_type: sym.kind.to_string(),
368 });
369 all_details.push(crate::SymbolChangeDetail {
370 symbol_name: sym.qualified_name.clone(),
371 file_path: path.to_string(),
372 change_type: "added".to_string(),
373 kind: sym.kind.to_string(),
374 });
375 }
376 Some(old_text) => {
377 if *old_text != new_text {
378 detected_changes.push(SymbolChange {
380 symbol_name: sym.qualified_name.clone(),
381 change_type: sym.kind.to_string(),
382 });
383 all_details.push(crate::SymbolChangeDetail {
384 symbol_name: sym.qualified_name.clone(),
385 file_path: path.to_string(),
386 change_type: "modified".to_string(),
387 kind: sym.kind.to_string(),
388 });
389 }
390 }
392 }
393 }
394
395 let new_names: std::collections::HashSet<&str> = new_symbols
397 .iter()
398 .map(|s| s.qualified_name.as_str())
399 .collect();
400 let old_names: std::collections::HashSet<&str> = old_symbols
401 .iter()
402 .map(|s| s.qualified_name.as_str())
403 .collect();
404 for old_name in &old_names {
405 if !new_names.contains(old_name) {
406 if let Some(old_sym) = old_symbols.iter().find(|s| s.qualified_name.as_str() == *old_name) {
407 all_details.push(crate::SymbolChangeDetail {
408 symbol_name: old_sym.qualified_name.clone(),
409 file_path: path.to_string(),
410 change_type: "deleted".to_string(),
411 kind: old_sym.kind.to_string(),
412 });
413 }
414 }
415 }
416
417 (detected_changes, all_details)
418}