1use tonic::{Response, Status};
2use tracing::{info, warn};
3
4use dk_engine::conflict::{AcquireOutcome, SymbolClaim};
5use crate::server::ProtocolServer;
6use crate::validation::{validate_file_path, MAX_FILE_SIZE};
7use crate::{ConflictWarning, FileWriteRequest, FileWriteResponse, SymbolChange};
8
9pub async fn handle_file_write(
14 server: &ProtocolServer,
15 req: FileWriteRequest,
16) -> Result<Response<FileWriteResponse>, Status> {
17 validate_file_path(&req.path)?;
18
19 if req.content.len() > MAX_FILE_SIZE {
20 return Err(Status::invalid_argument("file content exceeds 50MB limit"));
21 }
22
23 let session = server.validate_session(&req.session_id)?;
24
25 let sid = req
26 .session_id
27 .parse::<uuid::Uuid>()
28 .map_err(|_| Status::invalid_argument("Invalid session ID"))?;
29 server.session_mgr().touch_session(&sid);
30
31 let engine = server.engine();
32
33 let ws = engine
35 .workspace_manager()
36 .get_workspace(&sid)
37 .ok_or_else(|| Status::not_found("Workspace not found for session"))?;
38
39 let (repo_id, is_new, old_content) = {
43 let (rid, git_repo) = engine
44 .get_repo(&session.codebase)
45 .await
46 .map_err(|e| Status::internal(format!("Repo error: {e}")))?;
47 match git_repo.read_tree_entry(&ws.base_commit, &req.path) {
48 Ok(bytes) => (rid, false, bytes),
49 Err(e) => {
50 warn!(
53 path = %req.path,
54 base_commit = %ws.base_commit,
55 error = %e,
56 "read_tree_entry failed — treating file as new"
57 );
58 (rid, true, Vec::new())
59 }
60 }
61 };
62 let repo_id_str = repo_id.to_string();
63 let changeset_id = ws.changeset_id;
64 let agent_name = ws.agent_name.clone();
65
66 drop(ws);
68
69 let op = if is_new { "add" } else { "modify" };
70
71 let (detected_changes, all_symbol_changes) =
73 detect_symbol_changes_diffed(engine, &req.path, &old_content, &req.content, is_new);
74
75 let claimable: Vec<&crate::SymbolChangeDetail> = all_symbol_changes
80 .iter()
81 .filter(|sc| sc.change_type == "added" || sc.change_type == "modified" || sc.change_type == "deleted")
82 .collect();
83
84 let mut acquired: Vec<String> = Vec::new();
85 let mut locked_symbols: Vec<ConflictWarning> = Vec::new();
86
87 for sc in &claimable {
88 let kind = sc.kind.parse::<dk_core::SymbolKind>().unwrap_or(dk_core::SymbolKind::Function);
89 match server.claim_tracker().acquire_lock(
90 repo_id,
91 &req.path,
92 SymbolClaim {
93 session_id: sid,
94 agent_name: agent_name.clone(),
95 qualified_name: sc.symbol_name.clone(),
96 kind,
97 first_touched_at: chrono::Utc::now(),
98 },
99 ).await {
100 Ok(AcquireOutcome::Fresh) => acquired.push(sc.symbol_name.clone()),
101 Ok(AcquireOutcome::ReAcquired) => {} Err(sl) => {
103 warn!(
104 session_id = %sid,
105 path = %req.path,
106 symbol = %sl.qualified_name,
107 locked_by = %sl.locked_by_agent,
108 "SYMBOL_LOCKED: write rejected"
109 );
110 locked_symbols.push(ConflictWarning {
111 file_path: req.path.clone(),
112 symbol_name: sl.qualified_name.clone(),
113 conflicting_agent: sl.locked_by_agent.clone(),
114 conflicting_session_id: sl.locked_by_session.to_string(),
115 message: format!(
116 "SYMBOL_LOCKED: '{}' is locked by agent '{}'. Call dk_watch(filter: '{}') to wait, then dk_file_read and retry.",
117 sl.qualified_name, sl.locked_by_agent, crate::merge::EVENT_LOCK_RELEASED,
118 ),
119 });
120 }
121 }
122 }
123
124 if !locked_symbols.is_empty() {
125 for name in &acquired {
128 server.claim_tracker().release_lock(repo_id, &req.path, sid, name).await;
129 server.event_bus().publish(crate::WatchEvent {
130 event_type: crate::merge::EVENT_LOCK_RELEASED.to_string(),
131 changeset_id: String::new(),
132 agent_id: agent_name.clone(),
133 affected_symbols: vec![name.clone()],
134 details: format!("Symbol lock rolled back on {}", req.path),
135 session_id: req.session_id.clone(),
136 affected_files: vec![crate::FileChange {
137 path: req.path.clone(),
138 operation: "unlock".to_string(),
139 }],
140 symbol_changes: vec![],
141 repo_id: repo_id_str.clone(),
142 event_id: uuid::Uuid::new_v4().to_string(),
143 });
144 }
145
146 info!(
147 session_id = %sid,
148 path = %req.path,
149 locked_count = locked_symbols.len(),
150 rolled_back = acquired.len(),
151 "FILE_WRITE: rejected — symbols locked, rolled back partial locks"
152 );
153
154 return Ok(Response::new(FileWriteResponse {
155 new_hash: String::new(),
156 detected_changes: Vec::new(),
157 conflict_warnings: locked_symbols,
158 }));
159 }
160
161 let ws = match engine.workspace_manager().get_workspace(&sid) {
164 Some(ws) => ws,
165 None => {
166 for name in &acquired {
167 server.claim_tracker().release_lock(repo_id, &req.path, sid, name).await;
168 server.event_bus().publish(crate::WatchEvent {
169 event_type: crate::merge::EVENT_LOCK_RELEASED.to_string(),
170 changeset_id: String::new(),
171 agent_id: agent_name.clone(),
172 affected_symbols: vec![name.clone()],
173 details: format!("Symbol lock released on error in {}", req.path),
174 session_id: req.session_id.clone(),
175 affected_files: vec![crate::FileChange {
176 path: req.path.clone(),
177 operation: "unlock".to_string(),
178 }],
179 symbol_changes: vec![],
180 repo_id: repo_id_str.clone(),
181 event_id: uuid::Uuid::new_v4().to_string(),
182 });
183 }
184 return Err(Status::not_found("Workspace not found for session"));
185 }
186 };
187
188 let new_hash = match ws.overlay.write(&req.path, req.content.clone(), is_new).await {
189 Ok(hash) => hash,
190 Err(e) => {
191 for name in &acquired {
192 server.claim_tracker().release_lock(repo_id, &req.path, sid, name).await;
193 server.event_bus().publish(crate::WatchEvent {
194 event_type: crate::merge::EVENT_LOCK_RELEASED.to_string(),
195 changeset_id: String::new(),
196 agent_id: agent_name.clone(),
197 affected_symbols: vec![name.clone()],
198 details: format!("Symbol lock released on error in {}", req.path),
199 session_id: req.session_id.clone(),
200 affected_files: vec![crate::FileChange {
201 path: req.path.clone(),
202 operation: "unlock".to_string(),
203 }],
204 symbol_changes: vec![],
205 repo_id: repo_id_str.clone(),
206 event_id: uuid::Uuid::new_v4().to_string(),
207 });
208 }
209 return Err(Status::internal(format!("Write failed: {e}")));
210 }
211 };
212
213 drop(ws);
214
215 let content_str = std::str::from_utf8(&req.content).ok();
216 let _ = engine
217 .changeset_store()
218 .upsert_file(changeset_id, &req.path, op, content_str)
219 .await;
220
221 let conflict_warnings: Vec<ConflictWarning> = Vec::new();
222
223 let event_type = if is_new { "file.added" } else { "file.modified" };
225 server.event_bus().publish(crate::WatchEvent {
226 event_type: event_type.to_string(),
227 changeset_id: changeset_id.to_string(),
228 agent_id: session.agent_id.clone(),
229 affected_symbols: vec![],
230 details: format!("file {}: {}", op, req.path),
231 session_id: req.session_id.clone(),
232 affected_files: vec![crate::FileChange {
233 path: req.path.clone(),
234 operation: op.to_string(),
235 }],
236 symbol_changes: all_symbol_changes,
237 repo_id: repo_id_str,
238 event_id: uuid::Uuid::new_v4().to_string(),
239 });
240
241 info!(
242 session_id = %req.session_id,
243 path = %req.path,
244 hash = %new_hash,
245 changes = detected_changes.len(),
246 conflicts = conflict_warnings.len(),
247 "FILE_WRITE: completed"
248 );
249
250 Ok(Response::new(FileWriteResponse {
251 new_hash,
252 detected_changes,
253 conflict_warnings,
254 }))
255}
256
257fn detect_symbol_changes_diffed(
264 engine: &dk_engine::repo::Engine,
265 path: &str,
266 old_content: &[u8],
267 new_content: &[u8],
268 is_new_file: bool,
269) -> (Vec<SymbolChange>, Vec<crate::SymbolChangeDetail>) {
270 let file_path = std::path::Path::new(path);
271 let parser = engine.parser();
272
273 if !parser.supports_file(file_path) {
274 return (Vec::new(), Vec::new());
275 }
276
277 let new_symbols = match parser.parse_file(file_path, new_content) {
279 Ok(analysis) => analysis.symbols,
280 Err(_) => return (Vec::new(), Vec::new()),
281 };
282
283 if is_new_file || old_content.is_empty() {
285 let changes: Vec<SymbolChange> = new_symbols
286 .iter()
287 .map(|sym| SymbolChange {
288 symbol_name: sym.qualified_name.clone(),
289 change_type: sym.kind.to_string(),
290 })
291 .collect();
292 let details: Vec<crate::SymbolChangeDetail> = new_symbols
293 .iter()
294 .map(|sym| crate::SymbolChangeDetail {
295 symbol_name: sym.qualified_name.clone(),
296 file_path: path.to_string(),
297 change_type: "added".to_string(),
298 kind: sym.kind.to_string(),
299 })
300 .collect();
301 return (changes, details);
302 }
303
304 let old_symbols = match parser.parse_file(file_path, old_content) {
306 Ok(analysis) => analysis.symbols,
307 Err(_) => {
308 let changes: Vec<SymbolChange> = new_symbols
310 .iter()
311 .map(|sym| SymbolChange {
312 symbol_name: sym.qualified_name.clone(),
313 change_type: sym.kind.to_string(),
314 })
315 .collect();
316 let details: Vec<crate::SymbolChangeDetail> = new_symbols
317 .iter()
318 .map(|sym| crate::SymbolChangeDetail {
319 symbol_name: sym.qualified_name.clone(),
320 file_path: path.to_string(),
321 change_type: "modified".to_string(),
322 kind: sym.kind.to_string(),
323 })
324 .collect();
325 return (changes, details);
326 }
327 };
328
329 let mut old_symbol_text: std::collections::HashMap<&str, &[u8]> = std::collections::HashMap::new();
333 for sym in &old_symbols {
334 let start = sym.span.start_byte as usize;
335 let end = sym.span.end_byte as usize;
336 if start <= end && end <= old_content.len() {
337 old_symbol_text.entry(sym.qualified_name.as_str()).or_insert(&old_content[start..end]);
338 }
339 }
340
341 let mut detected_changes = Vec::new();
342 let mut all_details = Vec::new();
343
344 let mut seen_new: std::collections::HashSet<&str> = std::collections::HashSet::new();
346
347 for sym in &new_symbols {
349 if !seen_new.insert(sym.qualified_name.as_str()) {
350 continue; }
352 let start = sym.span.start_byte as usize;
353 let end = sym.span.end_byte as usize;
354 let new_text = if start <= end && end <= new_content.len() {
355 &new_content[start..end]
356 } else {
357 continue; };
359
360 match old_symbol_text.get(sym.qualified_name.as_str()) {
361 None => {
362 detected_changes.push(SymbolChange {
364 symbol_name: sym.qualified_name.clone(),
365 change_type: sym.kind.to_string(),
366 });
367 all_details.push(crate::SymbolChangeDetail {
368 symbol_name: sym.qualified_name.clone(),
369 file_path: path.to_string(),
370 change_type: "added".to_string(),
371 kind: sym.kind.to_string(),
372 });
373 }
374 Some(old_text) => {
375 if *old_text != new_text {
376 detected_changes.push(SymbolChange {
378 symbol_name: sym.qualified_name.clone(),
379 change_type: sym.kind.to_string(),
380 });
381 all_details.push(crate::SymbolChangeDetail {
382 symbol_name: sym.qualified_name.clone(),
383 file_path: path.to_string(),
384 change_type: "modified".to_string(),
385 kind: sym.kind.to_string(),
386 });
387 }
388 }
390 }
391 }
392
393 let new_names: std::collections::HashSet<&str> = new_symbols
395 .iter()
396 .map(|s| s.qualified_name.as_str())
397 .collect();
398 let old_names: std::collections::HashSet<&str> = old_symbols
399 .iter()
400 .map(|s| s.qualified_name.as_str())
401 .collect();
402 for old_name in &old_names {
403 if !new_names.contains(old_name) {
404 if let Some(old_sym) = old_symbols.iter().find(|s| s.qualified_name.as_str() == *old_name) {
405 all_details.push(crate::SymbolChangeDetail {
406 symbol_name: old_sym.qualified_name.clone(),
407 file_path: path.to_string(),
408 change_type: "deleted".to_string(),
409 kind: old_sym.kind.to_string(),
410 });
411 }
412 }
413 }
414
415 (detected_changes, all_details)
416}