dk_protocol/
pre_submit.rs1use std::collections::HashMap;
2
3use tonic::{Response, Status};
4use tracing::info;
5
6use crate::server::ProtocolServer;
7use crate::{PreSubmitCheckRequest, PreSubmitCheckResponse, SemanticConflict};
8
9pub async fn handle_pre_submit_check(
17 server: &ProtocolServer,
18 req: PreSubmitCheckRequest,
19) -> Result<Response<PreSubmitCheckResponse>, Status> {
20 let session = server.validate_session(&req.session_id)?;
21
22 let sid = req
23 .session_id
24 .parse::<uuid::Uuid>()
25 .map_err(|_| Status::invalid_argument("Invalid session ID"))?;
26 server.session_mgr().touch_session(&sid);
27
28 let engine = server.engine();
29
30 let ws = engine
32 .workspace_manager()
33 .get_workspace(&sid)
34 .ok_or_else(|| Status::not_found("Workspace not found for session"))?;
35
36 let (_repo_id, git_repo) = engine
38 .get_repo(&session.codebase)
39 .await
40 .map_err(|e| Status::internal(format!("Repo error: {e}")))?;
41
42 let head_hash = git_repo
43 .head_hash()
44 .map_err(|e| Status::internal(format!("Failed to read HEAD: {e}")))?
45 .unwrap_or_else(|| "initial".to_string());
46
47 let overlay = ws.overlay_for_tree();
48 let files_modified = overlay.len() as u32;
49 let symbols_changed = ws.graph.change_count() as u32;
50
51 if head_hash == ws.base_commit || overlay.is_empty() {
53 info!(
54 session_id = %req.session_id,
55 files_modified,
56 symbols_changed,
57 "PRE_SUBMIT_CHECK: clean (fast-forward possible)"
58 );
59
60 return Ok(Response::new(PreSubmitCheckResponse {
61 has_conflicts: false,
62 potential_conflicts: Vec::new(),
63 files_modified,
64 symbols_changed,
65 }));
66 }
67
68 let parser = engine.parser();
74 let paths: Vec<&String> = overlay.iter().map(|(p, _)| p).collect();
75
76 let mut base_entries: HashMap<&str, Option<Vec<u8>>> = HashMap::with_capacity(paths.len());
77 for path in &paths {
78 base_entries.insert(path.as_str(), git_repo.read_tree_entry(&ws.base_commit, path).ok());
79 }
80
81 let mut head_entries: HashMap<&str, Option<Vec<u8>>> = HashMap::with_capacity(paths.len());
82 for path in &paths {
83 head_entries.insert(path.as_str(), git_repo.read_tree_entry(&head_hash, path).ok());
84 }
85
86 let mut conflicts = Vec::new();
87
88 for (path, maybe_content) in &overlay {
89 let base_content = base_entries.get(path.as_str()).and_then(|v| v.as_ref());
90 let head_content = head_entries.get(path.as_str()).and_then(|v| v.as_ref());
91
92 match maybe_content {
93 None => {
94 if let (Some(base), Some(head)) = (base_content, head_content) {
96 if base != head {
97 conflicts.push(SemanticConflict {
98 file_path: path.clone(),
99 symbol_name: "<entire file>".to_string(),
100 our_change: "deleted".to_string(),
101 their_change: "modified".to_string(),
102 });
103 }
104 }
105 }
106 Some(overlay_content) => {
107 match (base_content, head_content) {
108 (Some(base), Some(head)) => {
109 if base != head {
110 let analysis =
111 dk_engine::workspace::conflict::analyze_file_conflict(
112 path,
113 base,
114 head,
115 overlay_content,
116 parser,
117 );
118
119 if let dk_engine::workspace::conflict::MergeAnalysis::Conflict {
120 conflicts: file_conflicts,
121 } = analysis
122 {
123 for c in file_conflicts {
124 conflicts.push(SemanticConflict {
125 file_path: c.file_path,
126 symbol_name: c.symbol_name,
127 our_change: format!("{:?}", c.our_change),
128 their_change: format!("{:?}", c.their_change),
129 });
130 }
131 }
132 }
133 }
134 (None, Some(head_blob)) => {
135 if *overlay_content != *head_blob {
136 conflicts.push(SemanticConflict {
137 file_path: path.clone(),
138 symbol_name: "<entire file>".to_string(),
139 our_change: "added".to_string(),
140 their_change: "added".to_string(),
141 });
142 }
143 }
144 (Some(_), None) => {
145 conflicts.push(SemanticConflict {
146 file_path: path.clone(),
147 symbol_name: "<entire file>".to_string(),
148 our_change: "modified".to_string(),
149 their_change: "deleted".to_string(),
150 });
151 }
152 (None, None) => {
153 }
155 }
156 }
157 }
158 }
159
160 let has_conflicts = !conflicts.is_empty();
161
162 info!(
163 session_id = %req.session_id,
164 has_conflicts,
165 conflict_count = conflicts.len(),
166 files_modified,
167 symbols_changed,
168 "PRE_SUBMIT_CHECK: completed"
169 );
170
171 Ok(Response::new(PreSubmitCheckResponse {
172 has_conflicts,
173 potential_conflicts: conflicts,
174 files_modified,
175 symbols_changed,
176 }))
177}