Skip to main content

dk_protocol/
merge.rs

1use tonic::Status;
2use uuid::Uuid;
3
4use dk_engine::workspace::merge::{merge_workspace, WorkspaceMergeResult};
5
6use crate::server::ProtocolServer;
7use crate::{merge_response, ConflictDetail, MergeConflict, MergeRequest, MergeResponse, MergeSuccess};
8
/// Conflict type for true write-write semantic conflicts.
///
/// This is the value written into `ConflictDetail::conflict_type` for every
/// conflict reported in a `MergeConflict` response by `handle_merge`.
const CONFLICT_TYPE_TRUE: &str = "true_conflict";
11
/// Prepare text for a protobuf `string` field by dropping every NUL (`'\0'`)
/// character and returning the result as a new owned `String`.
///
/// Any Rust `String` is already valid UTF-8. However, text produced by
/// tree-sitter AST parsing may still contain embedded NULs, and it may
/// contain `U+FFFD` replacement characters left by lossy conversions.
///
/// Only the NULs are removed. Every other character, `U+FFFD` included, is
/// kept as-is, so the value round-trips cleanly through protobuf
/// serialization and deserialization.
fn sanitize_for_proto(s: &str) -> String {
    s.chars().filter(|&ch| ch != '\0').collect()
}
21
22pub async fn handle_merge(
23    server: &ProtocolServer,
24    req: MergeRequest,
25) -> Result<MergeResponse, Status> {
26    let session = server.validate_session(&req.session_id)?;
27    let engine = server.engine();
28
29    let sid = req
30        .session_id
31        .parse::<Uuid>()
32        .map_err(|_| Status::invalid_argument("Invalid session ID"))?;
33
34    // Resolve repo_id_str for enriched events (non-fatal — empty string on failure)
35    let repo_id_str = match engine.get_repo(&session.codebase).await {
36        Ok((rid, _)) => rid.to_string(),
37        Err(_) => String::new(),
38    };
39
40    let changeset_id = req.changeset_id.parse::<Uuid>()
41        .map_err(|_| Status::invalid_argument("invalid changeset_id"))?;
42
43    // Get changeset and verify it's approved
44    let changeset = engine.changeset_store().get(changeset_id).await
45        .map_err(|e| Status::not_found(e.to_string()))?;
46
47    if changeset.state != "approved" {
48        return Err(Status::failed_precondition(format!(
49            "changeset is '{}', must be 'approved' to merge",
50            changeset.state
51        )));
52    }
53
54    // Get workspace for this session
55    let ws = engine
56        .workspace_manager()
57        .get_workspace(&sid)
58        .ok_or_else(|| Status::not_found("Workspace not found for session"))?;
59
60    // Get git repo — also use this repo_id for lock release (the first get_repo
61    // call is non-fatal and may return empty string, but this one propagates errors)
62    let (repo_id, git_repo) = engine.get_repo(&session.codebase).await
63        .map_err(|e| Status::internal(e.to_string()))?;
64
65    let agent = changeset.agent_id.as_deref().unwrap_or("agent");
66
67    let (effective_name, effective_email) =
68        dk_core::resolve_author(&req.author_name, &req.author_email, agent);
69
70    // Capture affected files from workspace overlay before merge/drop
71    let affected_files: Vec<crate::FileChange> = ws.overlay.list_changes()
72        .iter()
73        .map(|(path, entry)| {
74            let operation = match entry {
75                dk_engine::workspace::overlay::OverlayEntry::Added { .. } => "add",
76                dk_engine::workspace::overlay::OverlayEntry::Modified { .. } => "modify",
77                dk_engine::workspace::overlay::OverlayEntry::Deleted => "delete",
78            };
79            crate::FileChange {
80                path: path.clone(),
81                operation: operation.to_string(),
82            }
83        })
84        .collect();
85
86    // Use the programmatic workspace merge instead of git add -A
87    let merge_result = merge_workspace(
88        &ws,
89        &git_repo,
90        engine.parser(),
91        &req.commit_message,
92        &effective_name,
93        &effective_email,
94    )
95    .map_err(|e| Status::internal(format!("merge failed: {e}")))?;
96
97    // Drop workspace guard before further async work
98    drop(ws);
99
100    match merge_result {
101        WorkspaceMergeResult::FastMerge { commit_hash } => {
102            // Release locks first — git commit is already in the tree,
103            // so locks must be freed regardless of changeset-store state.
104            release_locks_and_emit(server, repo_id, sid, &req.session_id);
105
106            // Update changeset status to merged
107            engine.changeset_store().set_merged(changeset_id, &commit_hash).await
108                .map_err(|e| Status::internal(e.to_string()))?;
109
110            // Publish merge event
111            server.event_bus().publish(crate::WatchEvent {
112                event_type: "changeset.merged".to_string(),
113                changeset_id: changeset_id.to_string(),
114                agent_id: changeset.agent_id.clone().unwrap_or_default(),
115                affected_symbols: vec![],
116                details: format!("fast-merged as {}", commit_hash),
117                session_id: req.session_id.clone(),
118                affected_files: affected_files.clone(),
119                symbol_changes: vec![],
120                repo_id: repo_id_str.clone(),
121                event_id: Uuid::new_v4().to_string(),
122            });
123
124            Ok(MergeResponse {
125                result: Some(merge_response::Result::Success(MergeSuccess {
126                    commit_hash: commit_hash.clone(),
127                    merged_version: commit_hash,
128                    auto_rebased: false,
129                    auto_rebased_files: Vec::new(),
130                })),
131            })
132        }
133
134        WorkspaceMergeResult::RebaseMerge {
135            commit_hash,
136            auto_rebased_files,
137        } => {
138            // Release locks first — git commit is already in the tree.
139            release_locks_and_emit(server, repo_id, sid, &req.session_id);
140
141            // Update changeset status to merged
142            engine.changeset_store().set_merged(changeset_id, &commit_hash).await
143                .map_err(|e| Status::internal(e.to_string()))?;
144
145            // Publish merge event
146            server.event_bus().publish(crate::WatchEvent {
147                event_type: "changeset.merged".to_string(),
148                changeset_id: changeset_id.to_string(),
149                agent_id: changeset.agent_id.clone().unwrap_or_default(),
150                affected_symbols: vec![],
151                details: format!(
152                    "rebase-merged as {} (auto-rebased {} files)",
153                    commit_hash,
154                    auto_rebased_files.len()
155                ),
156                session_id: req.session_id.clone(),
157                affected_files,
158                symbol_changes: vec![],
159                repo_id: repo_id_str.clone(),
160                event_id: Uuid::new_v4().to_string(),
161            });
162
163            Ok(MergeResponse {
164                result: Some(merge_response::Result::Success(MergeSuccess {
165                    commit_hash: commit_hash.clone(),
166                    merged_version: commit_hash,
167                    auto_rebased: true,
168                    auto_rebased_files,
169                })),
170            })
171        }
172
173        WorkspaceMergeResult::Conflicts { conflicts } => {
174            // Intentionally NOT releasing locks here. The agent retains its locks
175            // while resolving conflicts (dk_resolve → retry dk_merge). Locks are
176            // released when the session is closed (dk_close) or times out (30 min GC).
177            let conflict_details: Vec<ConflictDetail> = conflicts
178                .iter()
179                .map(|c| {
180                    let file = sanitize_for_proto(&c.file_path);
181                    let symbol = sanitize_for_proto(&c.symbol_name);
182                    ConflictDetail {
183                        file_path: file,
184                        symbols: vec![symbol.clone()],
185                        your_agent: agent.to_string(),
186                        // TODO: resolve their_agent from the session/changeset store
187                        // once SemanticConflict carries agent attribution.
188                        their_agent: String::new(),
189                        conflict_type: CONFLICT_TYPE_TRUE.to_string(),
190                        description: format!(
191                            "Symbol '{}' — our change: {:?}, their change: {:?}",
192                            symbol, c.our_change, c.their_change
193                        ),
194                    }
195                })
196                .collect();
197
198            let suggested_action = "adapt".to_string();
199            let available_actions = vec!["adapt".to_string(), "keep_mine".to_string(), "keep_theirs".to_string()];
200
201            debug_assert!(
202                available_actions.iter().any(|a| a == &suggested_action),
203                "suggested_action '{}' is not in available_actions {:?}",
204                suggested_action, available_actions
205            );
206
207            Ok(MergeResponse {
208                result: Some(merge_response::Result::Conflict(MergeConflict {
209                    changeset_id: changeset_id.to_string(),
210                    conflicts: conflict_details,
211                    suggested_action,
212                    available_actions,
213                })),
214            })
215        }
216    }
217}
218
219/// Release all symbol locks for a session and emit `symbol.lock.released` events
220/// so blocked agents can wake up and retry their writes.
221fn release_locks_and_emit(
222    server: &ProtocolServer,
223    repo_id: Uuid,
224    session_id: Uuid,
225    session_id_str: &str,
226) {
227    let released = server.claim_tracker().release_locks(repo_id, session_id);
228
229    if released.is_empty() {
230        return;
231    }
232
233    // Group released locks by file_path for efficient event emission
234    let mut by_file: std::collections::HashMap<String, Vec<String>> = std::collections::HashMap::new();
235    for lock in &released {
236        by_file
237            .entry(lock.file_path.clone())
238            .or_default()
239            .push(lock.qualified_name.clone());
240    }
241
242    for (file_path, symbols) in by_file {
243        server.event_bus().publish(crate::WatchEvent {
244            event_type: EVENT_LOCK_RELEASED.to_string(),
245            changeset_id: String::new(),
246            agent_id: released.first().map(|r| r.agent_name.clone()).unwrap_or_default(),
247            affected_symbols: symbols,
248            details: format!("Symbol locks released on {}", file_path),
249            session_id: session_id_str.to_string(),
250            affected_files: vec![crate::FileChange {
251                path: file_path,
252                operation: "unlock".to_string(),
253            }],
254            symbol_changes: vec![],
255            repo_id: repo_id.to_string(),
256            event_id: Uuid::new_v4().to_string(),
257        });
258    }
259}
260
// ── Event type constants ────────────────────────────────────────────

/// Event type published when a changeset is successfully merged.
///
/// This covers both fast merges and rebase merges.
pub const EVENT_MERGED: &str = "changeset.merged";

/// Event type published when symbol locks are released (after merge, close,
/// or timeout).
///
/// Blocked agents watch for this to retry their `dk_file_write`. In this file,
/// `release_locks_and_emit` publishes one such event per file whose locks were
/// released.
pub const EVENT_LOCK_RELEASED: &str = "symbol.lock.released";
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn merged_event_type() {
        assert_eq!(EVENT_MERGED, "changeset.merged");
    }

    #[test]
    fn merged_event_type_uses_dot_separator() {
        assert!(
            EVENT_MERGED.contains('.'),
            "event type should use dot separator"
        );
        assert!(
            EVENT_MERGED.starts_with("changeset."),
            "event type should start with 'changeset.'"
        );
    }

    #[test]
    fn merged_event_type_is_not_underscore_format() {
        // Verify the event was renamed from "changeset_merged" to "changeset.merged"
        assert_ne!(EVENT_MERGED, "changeset_merged");
        assert_eq!(EVENT_MERGED, "changeset.merged");
    }

    #[test]
    fn lock_released_event_type() {
        // Blocked agents subscribe to this exact string; pin it so a rename
        // cannot silently break them.
        assert_eq!(EVENT_LOCK_RELEASED, "symbol.lock.released");
        assert!(
            EVENT_LOCK_RELEASED.starts_with("symbol."),
            "lock event type should start with 'symbol.'"
        );
    }

    #[test]
    fn event_types_are_distinct() {
        assert_ne!(EVENT_MERGED, EVENT_LOCK_RELEASED);
    }

    #[test]
    fn merge_success_construction() {
        let resp = MergeResponse {
            result: Some(merge_response::Result::Success(MergeSuccess {
                commit_hash: "abc123".to_string(),
                merged_version: "abc123".to_string(),
                auto_rebased: false,
                auto_rebased_files: Vec::new(),
            })),
        };
        match resp.result {
            Some(merge_response::Result::Success(s)) => {
                assert_eq!(s.commit_hash, "abc123");
                assert!(!s.auto_rebased);
                assert!(s.auto_rebased_files.is_empty());
            }
            _ => panic!("expected MergeSuccess"),
        }
    }

    #[test]
    fn merge_success_with_rebase() {
        let resp = MergeResponse {
            result: Some(merge_response::Result::Success(MergeSuccess {
                commit_hash: "def456".to_string(),
                merged_version: "def456".to_string(),
                auto_rebased: true,
                auto_rebased_files: vec!["src/main.rs".to_string()],
            })),
        };
        match resp.result {
            Some(merge_response::Result::Success(s)) => {
                assert!(s.auto_rebased);
                assert_eq!(s.auto_rebased_files, vec!["src/main.rs"]);
            }
            _ => panic!("expected MergeSuccess"),
        }
    }

    #[test]
    fn merge_conflict_construction() {
        // their_agent is currently not populated by the server (SemanticConflict
        // does not carry agent attribution yet), so the test mirrors real
        // behavior by using an empty string.
        let detail = ConflictDetail {
            file_path: "src/lib.rs".to_string(),
            symbols: vec!["process_data".to_string()],
            your_agent: "agent-1".to_string(),
            their_agent: String::new(),
            conflict_type: CONFLICT_TYPE_TRUE.to_string(),
            description: "both agents modified process_data".to_string(),
        };
        let resp = MergeResponse {
            result: Some(merge_response::Result::Conflict(MergeConflict {
                changeset_id: "cs-001".to_string(),
                conflicts: vec![detail],
                suggested_action: "adapt".to_string(),
                available_actions: vec![
                    "adapt".to_string(),
                    "keep_mine".to_string(),
                    "keep_theirs".to_string(),
                ],
            })),
        };
        match resp.result {
            Some(merge_response::Result::Conflict(c)) => {
                assert_eq!(c.changeset_id, "cs-001");
                assert_eq!(c.conflicts.len(), 1);
                assert_eq!(c.conflicts[0].file_path, "src/lib.rs");
                assert_eq!(c.conflicts[0].symbols, vec!["process_data"]);
                assert_eq!(c.conflicts[0].your_agent, "agent-1");
                assert!(c.conflicts[0].their_agent.is_empty());
                assert_eq!(c.suggested_action, "adapt");
                assert_eq!(c.available_actions.len(), 3);
            }
            _ => panic!("expected MergeConflict"),
        }
    }

    #[test]
    fn conflict_detail_fields() {
        let detail = ConflictDetail {
            file_path: "src/handler.rs".to_string(),
            symbols: vec!["handle_request".to_string(), "parse_input".to_string()],
            your_agent: "agent-a".to_string(),
            their_agent: "agent-b".to_string(),
            conflict_type: CONFLICT_TYPE_TRUE.to_string(),
            description: "multiple symbols in conflict".to_string(),
        };
        assert_eq!(detail.symbols.len(), 2);
        assert_eq!(detail.conflict_type, CONFLICT_TYPE_TRUE);
    }

    #[test]
    fn sanitize_for_proto_strips_null_bytes() {
        assert_eq!(sanitize_for_proto("hello\0world"), "helloworld");
        assert_eq!(sanitize_for_proto("\0\0"), "");
        assert_eq!(sanitize_for_proto("clean"), "clean");
    }

    #[test]
    fn sanitize_for_proto_handles_empty_input() {
        assert_eq!(sanitize_for_proto(""), "");
    }

    #[test]
    fn sanitize_for_proto_preserves_valid_utf8() {
        // Multi-byte UTF-8 characters must survive sanitization
        assert_eq!(sanitize_for_proto("fn résumé()"), "fn résumé()");
        assert_eq!(sanitize_for_proto("日本語"), "日本語");
        // Replacement character from String::from_utf8_lossy is valid UTF-8
        assert_eq!(sanitize_for_proto("bad\u{FFFD}char"), "bad\u{FFFD}char");
    }
}