1use std::time::Instant;
2
3use tonic::{Response, Status};
4use tracing::{info, warn};
5
6use dk_engine::conflict::SymbolClaim;
7use crate::server::ProtocolServer;
8use crate::validation::{validate_file_path, MAX_FILE_SIZE};
9use crate::{ConflictWarning, FileWriteRequest, FileWriteResponse, SymbolChange};
10
11pub async fn handle_file_write(
16 server: &ProtocolServer,
17 req: FileWriteRequest,
18) -> Result<Response<FileWriteResponse>, Status> {
19 validate_file_path(&req.path)?;
20
21 if req.content.len() > MAX_FILE_SIZE {
22 return Err(Status::invalid_argument("file content exceeds 50MB limit"));
23 }
24
25 let session = server.validate_session(&req.session_id)?;
26
27 let sid = req
28 .session_id
29 .parse::<uuid::Uuid>()
30 .map_err(|_| Status::invalid_argument("Invalid session ID"))?;
31 server.session_mgr().touch_session(&sid);
32
33 let engine = server.engine();
34
35 let ws = engine
37 .workspace_manager()
38 .get_workspace(&sid)
39 .ok_or_else(|| Status::not_found("Workspace not found for session"))?;
40
41 let (repo_id, is_new, old_content) = {
45 let (rid, git_repo) = engine
46 .get_repo(&session.codebase)
47 .await
48 .map_err(|e| Status::internal(format!("Repo error: {e}")))?;
49 match git_repo.read_tree_entry(&ws.base_commit, &req.path) {
50 Ok(bytes) => (rid, false, bytes),
51 Err(e) => {
52 warn!(
55 path = %req.path,
56 base_commit = %ws.base_commit,
57 error = %e,
58 "read_tree_entry failed — treating file as new"
59 );
60 (rid, true, Vec::new())
61 }
62 }
63 };
64 let repo_id_str = repo_id.to_string();
65
66 let new_hash = ws
68 .overlay
69 .write(&req.path, req.content.clone(), is_new)
70 .await
71 .map_err(|e| Status::internal(format!("Write failed: {e}")))?;
72
73 let changeset_id = ws.changeset_id;
74 let agent_name = ws.agent_name.clone();
75
76 drop(ws);
78
79 let op = if is_new { "add" } else { "modify" };
81 let content_str = std::str::from_utf8(&req.content).ok();
82 let _ = engine
83 .changeset_store()
84 .upsert_file(changeset_id, &req.path, op, content_str)
85 .await;
86
87 let (detected_changes, all_symbol_changes) =
90 detect_symbol_changes_diffed(engine, &req.path, &old_content, &req.content, is_new);
91
92 let conflict_warnings = {
97 let claimable: Vec<&crate::SymbolChangeDetail> = all_symbol_changes
98 .iter()
99 .filter(|sc| sc.change_type == "added" || sc.change_type == "modified")
100 .collect();
101
102 let qualified_names: Vec<String> = claimable.iter().map(|sc| sc.symbol_name.clone()).collect();
104 let conflicts = server.claim_tracker().check_conflicts(
105 repo_id,
106 &req.path,
107 sid,
108 &qualified_names,
109 );
110
111 for sc in &claimable {
113 let kind = sc.kind.parse::<dk_core::SymbolKind>().unwrap_or(dk_core::SymbolKind::Function);
114 server.claim_tracker().record_claim(
115 repo_id,
116 &req.path,
117 SymbolClaim {
118 session_id: sid,
119 agent_name: agent_name.clone(),
120 qualified_name: sc.symbol_name.clone(),
121 kind,
122 first_touched_at: Instant::now(),
123 },
124 );
125 }
126
127 let warnings: Vec<ConflictWarning> = conflicts
129 .into_iter()
130 .map(|c| {
131 let msg = format!(
132 "Symbol '{}' was already modified by agent '{}' (session {})",
133 c.qualified_name, c.conflicting_agent, c.conflicting_session,
134 );
135 warn!(
136 session_id = %sid,
137 path = %req.path,
138 symbol = %c.qualified_name,
139 conflicting_agent = %c.conflicting_agent,
140 "CONFLICT_WARNING: {msg}"
141 );
142 ConflictWarning {
143 file_path: req.path.clone(),
144 symbol_name: c.qualified_name,
145 conflicting_agent: c.conflicting_agent,
146 conflicting_session_id: c.conflicting_session.to_string(),
147 message: msg,
148 }
149 })
150 .collect();
151 warnings
152 };
153
154 let event_type = if is_new { "file.added" } else { "file.modified" };
156 server.event_bus().publish(crate::WatchEvent {
157 event_type: event_type.to_string(),
158 changeset_id: changeset_id.to_string(),
159 agent_id: session.agent_id.clone(),
160 affected_symbols: vec![],
161 details: format!("file {}: {}", op, req.path),
162 session_id: req.session_id.clone(),
163 affected_files: vec![crate::FileChange {
164 path: req.path.clone(),
165 operation: op.to_string(),
166 }],
167 symbol_changes: all_symbol_changes,
168 repo_id: repo_id_str,
169 event_id: uuid::Uuid::new_v4().to_string(),
170 });
171
172 info!(
173 session_id = %req.session_id,
174 path = %req.path,
175 hash = %new_hash,
176 changes = detected_changes.len(),
177 conflicts = conflict_warnings.len(),
178 "FILE_WRITE: completed"
179 );
180
181 Ok(Response::new(FileWriteResponse {
182 new_hash,
183 detected_changes,
184 conflict_warnings,
185 }))
186}
187
188fn detect_symbol_changes_diffed(
195 engine: &dk_engine::repo::Engine,
196 path: &str,
197 old_content: &[u8],
198 new_content: &[u8],
199 is_new_file: bool,
200) -> (Vec<SymbolChange>, Vec<crate::SymbolChangeDetail>) {
201 let file_path = std::path::Path::new(path);
202 let parser = engine.parser();
203
204 if !parser.supports_file(file_path) {
205 return (Vec::new(), Vec::new());
206 }
207
208 let new_symbols = match parser.parse_file(file_path, new_content) {
210 Ok(analysis) => analysis.symbols,
211 Err(_) => return (Vec::new(), Vec::new()),
212 };
213
214 if is_new_file || old_content.is_empty() {
216 let changes: Vec<SymbolChange> = new_symbols
217 .iter()
218 .map(|sym| SymbolChange {
219 symbol_name: sym.qualified_name.clone(),
220 change_type: sym.kind.to_string(),
221 })
222 .collect();
223 let details: Vec<crate::SymbolChangeDetail> = new_symbols
224 .iter()
225 .map(|sym| crate::SymbolChangeDetail {
226 symbol_name: sym.qualified_name.clone(),
227 file_path: path.to_string(),
228 change_type: "added".to_string(),
229 kind: sym.kind.to_string(),
230 })
231 .collect();
232 return (changes, details);
233 }
234
235 let old_symbols = match parser.parse_file(file_path, old_content) {
237 Ok(analysis) => analysis.symbols,
238 Err(_) => {
239 let changes: Vec<SymbolChange> = new_symbols
241 .iter()
242 .map(|sym| SymbolChange {
243 symbol_name: sym.qualified_name.clone(),
244 change_type: sym.kind.to_string(),
245 })
246 .collect();
247 let details: Vec<crate::SymbolChangeDetail> = new_symbols
248 .iter()
249 .map(|sym| crate::SymbolChangeDetail {
250 symbol_name: sym.qualified_name.clone(),
251 file_path: path.to_string(),
252 change_type: "modified".to_string(),
253 kind: sym.kind.to_string(),
254 })
255 .collect();
256 return (changes, details);
257 }
258 };
259
260 let mut old_symbol_text: std::collections::HashMap<&str, &[u8]> = std::collections::HashMap::new();
264 for sym in &old_symbols {
265 let start = sym.span.start_byte as usize;
266 let end = sym.span.end_byte as usize;
267 if start <= end && end <= old_content.len() {
268 old_symbol_text.entry(sym.qualified_name.as_str()).or_insert(&old_content[start..end]);
269 }
270 }
271
272 let mut detected_changes = Vec::new();
273 let mut all_details = Vec::new();
274
275 let mut seen_new: std::collections::HashSet<&str> = std::collections::HashSet::new();
277
278 for sym in &new_symbols {
280 if !seen_new.insert(sym.qualified_name.as_str()) {
281 continue; }
283 let start = sym.span.start_byte as usize;
284 let end = sym.span.end_byte as usize;
285 let new_text = if start <= end && end <= new_content.len() {
286 &new_content[start..end]
287 } else {
288 continue; };
290
291 match old_symbol_text.get(sym.qualified_name.as_str()) {
292 None => {
293 detected_changes.push(SymbolChange {
295 symbol_name: sym.qualified_name.clone(),
296 change_type: sym.kind.to_string(),
297 });
298 all_details.push(crate::SymbolChangeDetail {
299 symbol_name: sym.qualified_name.clone(),
300 file_path: path.to_string(),
301 change_type: "added".to_string(),
302 kind: sym.kind.to_string(),
303 });
304 }
305 Some(old_text) => {
306 if *old_text != new_text {
307 detected_changes.push(SymbolChange {
309 symbol_name: sym.qualified_name.clone(),
310 change_type: sym.kind.to_string(),
311 });
312 all_details.push(crate::SymbolChangeDetail {
313 symbol_name: sym.qualified_name.clone(),
314 file_path: path.to_string(),
315 change_type: "modified".to_string(),
316 kind: sym.kind.to_string(),
317 });
318 }
319 }
321 }
322 }
323
324 let new_names: std::collections::HashSet<&str> = new_symbols
326 .iter()
327 .map(|s| s.qualified_name.as_str())
328 .collect();
329 let old_names: std::collections::HashSet<&str> = old_symbols
330 .iter()
331 .map(|s| s.qualified_name.as_str())
332 .collect();
333 for old_name in &old_names {
334 if !new_names.contains(old_name) {
335 if let Some(old_sym) = old_symbols.iter().find(|s| s.qualified_name.as_str() == *old_name) {
336 all_details.push(crate::SymbolChangeDetail {
337 symbol_name: old_sym.qualified_name.clone(),
338 file_path: path.to_string(),
339 change_type: "deleted".to_string(),
340 kind: old_sym.kind.to_string(),
341 });
342 }
343 }
344 }
345
346 (detected_changes, all_details)
347}