dk_protocol/
file_write.rs1use std::time::Instant;
2
3use tonic::{Response, Status};
4use tracing::{info, warn};
5
6use dk_engine::conflict::SymbolClaim;
7use crate::server::ProtocolServer;
8use crate::validation::{validate_file_path, MAX_FILE_SIZE};
9use crate::{ConflictWarning, FileWriteRequest, FileWriteResponse, SymbolChange};
10
11pub async fn handle_file_write(
16 server: &ProtocolServer,
17 req: FileWriteRequest,
18) -> Result<Response<FileWriteResponse>, Status> {
19 validate_file_path(&req.path)?;
20
21 if req.content.len() > MAX_FILE_SIZE {
22 return Err(Status::invalid_argument("file content exceeds 50MB limit"));
23 }
24
25 let session = server.validate_session(&req.session_id)?;
26
27 let sid = req
28 .session_id
29 .parse::<uuid::Uuid>()
30 .map_err(|_| Status::invalid_argument("Invalid session ID"))?;
31 server.session_mgr().touch_session(&sid);
32
33 let engine = server.engine();
34
35 let ws = engine
37 .workspace_manager()
38 .get_workspace(&sid)
39 .ok_or_else(|| Status::not_found("Workspace not found for session"))?;
40
41 let (repo_id, is_new) = {
45 let (rid, git_repo) = engine
46 .get_repo(&session.codebase)
47 .await
48 .map_err(|e| Status::internal(format!("Repo error: {e}")))?;
49 let new = git_repo
50 .read_tree_entry(&ws.base_commit, &req.path)
51 .is_err();
52 (rid, new)
53 };
54 let repo_id_str = repo_id.to_string();
55
56 let pre_write_symbols: std::collections::HashSet<String> = engine
58 .symbol_store()
59 .find_by_file(repo_id, &req.path)
60 .await
61 .unwrap_or_default()
62 .into_iter()
63 .map(|s| s.qualified_name)
64 .collect();
65
66 let new_hash = ws
68 .overlay
69 .write(&req.path, req.content.clone(), is_new)
70 .await
71 .map_err(|e| Status::internal(format!("Write failed: {e}")))?;
72
73 let changeset_id = ws.changeset_id;
74 let agent_name = ws.agent_name.clone();
75
76 drop(ws);
78
79 let op = if is_new { "add" } else { "modify" };
81 let content_str = std::str::from_utf8(&req.content).ok();
82 let _ = engine
83 .changeset_store()
84 .upsert_file(changeset_id, &req.path, op, content_str)
85 .await;
86
87 let detected_changes = detect_symbol_changes(engine, &req.path, &req.content);
89
90 let symbol_changes: Vec<crate::SymbolChangeDetail> = detected_changes
93 .iter()
94 .map(|sc| {
95 let change_type = if is_new || !pre_write_symbols.contains(&sc.symbol_name) {
96 "added"
97 } else {
98 "modified"
99 };
100 crate::SymbolChangeDetail {
101 symbol_name: sc.symbol_name.clone(),
102 file_path: req.path.clone(),
103 change_type: change_type.to_string(),
104 kind: sc.change_type.clone(),
105 }
106 })
107 .collect();
108
109 let mut all_symbol_changes = symbol_changes;
114 if !detected_changes.is_empty() {
115 let detected_names: std::collections::HashSet<&str> = detected_changes
116 .iter()
117 .map(|sc| sc.symbol_name.as_str())
118 .collect();
119 for name in &pre_write_symbols {
120 if !detected_names.contains(name.as_str()) {
121 all_symbol_changes.push(crate::SymbolChangeDetail {
122 symbol_name: name.clone(),
123 file_path: req.path.clone(),
124 change_type: "deleted".to_string(),
125 kind: String::new(),
126 });
127 }
128 }
129 }
130
131 let conflict_warnings = {
136 let claimable: Vec<&crate::SymbolChangeDetail> = all_symbol_changes
137 .iter()
138 .filter(|sc| sc.change_type == "added" || sc.change_type == "modified")
139 .collect();
140
141 let qualified_names: Vec<String> = claimable.iter().map(|sc| sc.symbol_name.clone()).collect();
143 let conflicts = server.claim_tracker().check_conflicts(
144 repo_id,
145 &req.path,
146 sid,
147 &qualified_names,
148 );
149
150 for sc in &claimable {
152 let kind = sc.kind.parse::<dk_core::SymbolKind>().unwrap_or(dk_core::SymbolKind::Function);
153 server.claim_tracker().record_claim(
154 repo_id,
155 &req.path,
156 SymbolClaim {
157 session_id: sid,
158 agent_name: agent_name.clone(),
159 qualified_name: sc.symbol_name.clone(),
160 kind,
161 first_touched_at: Instant::now(),
162 },
163 );
164 }
165
166 let warnings: Vec<ConflictWarning> = conflicts
168 .into_iter()
169 .map(|c| {
170 let msg = format!(
171 "Symbol '{}' was already modified by agent '{}' (session {})",
172 c.qualified_name, c.conflicting_agent, c.conflicting_session,
173 );
174 warn!(
175 session_id = %sid,
176 path = %req.path,
177 symbol = %c.qualified_name,
178 conflicting_agent = %c.conflicting_agent,
179 "CONFLICT_WARNING: {msg}"
180 );
181 ConflictWarning {
182 file_path: req.path.clone(),
183 symbol_name: c.qualified_name,
184 conflicting_agent: c.conflicting_agent,
185 conflicting_session_id: c.conflicting_session.to_string(),
186 message: msg,
187 }
188 })
189 .collect();
190 warnings
191 };
192
193 let event_type = if is_new { "file.added" } else { "file.modified" };
195 server.event_bus().publish(crate::WatchEvent {
196 event_type: event_type.to_string(),
197 changeset_id: changeset_id.to_string(),
198 agent_id: session.agent_id.clone(),
199 affected_symbols: vec![],
200 details: format!("file {}: {}", op, req.path),
201 session_id: req.session_id.clone(),
202 affected_files: vec![crate::FileChange {
203 path: req.path.clone(),
204 operation: op.to_string(),
205 }],
206 symbol_changes: all_symbol_changes,
207 repo_id: repo_id_str,
208 event_id: uuid::Uuid::new_v4().to_string(),
209 });
210
211 info!(
212 session_id = %req.session_id,
213 path = %req.path,
214 hash = %new_hash,
215 changes = detected_changes.len(),
216 conflicts = conflict_warnings.len(),
217 "FILE_WRITE: completed"
218 );
219
220 Ok(Response::new(FileWriteResponse {
221 new_hash,
222 detected_changes,
223 conflict_warnings,
224 }))
225}
226
227fn detect_symbol_changes(
232 engine: &dk_engine::repo::Engine,
233 path: &str,
234 content: &[u8],
235) -> Vec<SymbolChange> {
236 let file_path = std::path::Path::new(path);
237 let parser = engine.parser();
238
239 if !parser.supports_file(file_path) {
240 return Vec::new();
241 }
242
243 match parser.parse_file(file_path, content) {
244 Ok(analysis) => analysis
245 .symbols
246 .iter()
247 .map(|sym| SymbolChange {
248 symbol_name: sym.qualified_name.clone(),
249 change_type: sym.kind.to_string(),
250 })
251 .collect(),
252 Err(_) => Vec::new(),
253 }
254}