1use crate::error::{CollabError, Result};
4use crate::history::VersionControl;
5use crate::models::{ConflictType, MergeConflict, MergeStatus, WorkspaceMerge};
6use chrono::Utc;
7use serde_json::Value;
8use sqlx::{Pool, Sqlite};
9use uuid::Uuid;
10
11pub struct MergeService {
13 db: Pool<Sqlite>,
14 version_control: VersionControl,
15}
16
17impl MergeService {
18 #[must_use]
20 pub fn new(db: Pool<Sqlite>) -> Self {
21 Self {
22 db: db.clone(),
23 version_control: VersionControl::new(db),
24 }
25 }
26
27 pub async fn find_common_ancestor(
32 &self,
33 source_workspace_id: Uuid,
34 target_workspace_id: Uuid,
35 ) -> Result<Option<Uuid>> {
36 let source_ws_id_str = source_workspace_id.to_string();
38 let target_ws_id_str = target_workspace_id.to_string();
39 let fork = sqlx::query!(
40 r#"
41 SELECT fork_point_commit_id
42 FROM workspace_forks
43 WHERE source_workspace_id = ? AND forked_workspace_id = ?
44 "#,
45 source_ws_id_str,
46 target_ws_id_str
47 )
48 .fetch_optional(&self.db)
49 .await?;
50
51 if let Some(fork) = fork {
52 if let Some(commit_id_str) = fork.fork_point_commit_id.as_ref() {
53 if let Ok(commit_id) = Uuid::parse_str(commit_id_str) {
54 return Ok(Some(commit_id));
55 }
56 }
57 }
58
59 let target_ws_id_str2 = target_workspace_id.to_string();
61 let source_ws_id_str2 = source_workspace_id.to_string();
62 let fork = sqlx::query!(
63 r#"
64 SELECT fork_point_commit_id
65 FROM workspace_forks
66 WHERE source_workspace_id = ? AND forked_workspace_id = ?
67 "#,
68 target_ws_id_str2,
69 source_ws_id_str2
70 )
71 .fetch_optional(&self.db)
72 .await?;
73
74 if let Some(fork) = fork {
75 if let Some(commit_id_str) = fork.fork_point_commit_id.as_ref() {
76 if let Ok(commit_id) = Uuid::parse_str(commit_id_str) {
77 return Ok(Some(commit_id));
78 }
79 }
80 }
81
82 let source_commits =
85 self.version_control.get_history(source_workspace_id, Some(1000)).await?;
86 let target_commits =
87 self.version_control.get_history(target_workspace_id, Some(1000)).await?;
88
89 let source_commit_ids: std::collections::HashSet<Uuid> =
91 source_commits.iter().map(|c| c.id).collect();
92 let target_commit_ids: std::collections::HashSet<Uuid> =
93 target_commits.iter().map(|c| c.id).collect();
94
95 for source_commit in &source_commits {
98 if target_commit_ids.contains(&source_commit.id) {
99 return Ok(Some(source_commit.id));
100 }
101 }
102
103 if let (Some(source_latest), Some(target_latest)) =
106 (source_commits.first(), target_commits.first())
107 {
108 let source_ancestors = self.build_ancestor_set(source_latest.id).await?;
110 let target_ancestors = self.build_ancestor_set(target_latest.id).await?;
111
112 for ancestor in &source_ancestors {
114 if target_ancestors.contains(ancestor) {
115 return Ok(Some(*ancestor));
116 }
117 }
118 }
119
120 Ok(None)
122 }
123
124 pub async fn merge_workspaces(
129 &self,
130 source_workspace_id: Uuid,
131 target_workspace_id: Uuid,
132 user_id: Uuid,
133 ) -> Result<(Value, Vec<MergeConflict>)> {
134 let source_commit =
136 self.version_control.get_latest_commit(source_workspace_id).await?.ok_or_else(
137 || CollabError::Internal("Source workspace has no commits".to_string()),
138 )?;
139
140 let target_commit =
141 self.version_control.get_latest_commit(target_workspace_id).await?.ok_or_else(
142 || CollabError::Internal("Target workspace has no commits".to_string()),
143 )?;
144
145 let base_commit_id = self
147 .find_common_ancestor(source_workspace_id, target_workspace_id)
148 .await?
149 .ok_or_else(|| {
150 CollabError::Internal(
151 "Cannot find common ancestor. Workspaces must be related by fork.".to_string(),
152 )
153 })?;
154
155 let base_commit = self.version_control.get_commit(base_commit_id).await?;
156
157 let (merged_state, conflicts) = self.three_way_merge(
159 &base_commit.snapshot,
160 &source_commit.snapshot,
161 &target_commit.snapshot,
162 )?;
163
164 let mut merge = WorkspaceMerge::new(
166 source_workspace_id,
167 target_workspace_id,
168 base_commit_id,
169 source_commit.id,
170 target_commit.id,
171 );
172
173 if conflicts.is_empty() {
174 merge.status = MergeStatus::Completed;
175 } else {
176 merge.status = MergeStatus::Conflict;
177 merge.conflict_data = Some(serde_json::to_value(&conflicts)?);
178 }
179
180 let merge_id_str = merge.id.to_string();
182 let source_ws_id_str = merge.source_workspace_id.to_string();
183 let target_ws_id_str = merge.target_workspace_id.to_string();
184 let base_commit_id_str = merge.base_commit_id.to_string();
185 let source_commit_id_str = merge.source_commit_id.to_string();
186 let target_commit_id_str = merge.target_commit_id.to_string();
187 let merge_commit_id_str = merge.merge_commit_id.map(|id| id.to_string());
188 let status_str = serde_json::to_string(&merge.status)?;
189 let conflict_data_str =
190 merge.conflict_data.as_ref().map(serde_json::to_string).transpose()?;
191 let merged_by_str = merge.merged_by.map(|id| id.to_string());
192 let merged_at_str = merge.merged_at.map(|dt| dt.to_rfc3339());
193 let created_at_str = merge.created_at.to_rfc3339();
194
195 sqlx::query!(
196 r#"
197 INSERT INTO workspace_merges (
198 id, source_workspace_id, target_workspace_id,
199 base_commit_id, source_commit_id, target_commit_id,
200 merge_commit_id, status, conflict_data, merged_by, merged_at, created_at
201 )
202 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
203 "#,
204 merge_id_str,
205 source_ws_id_str,
206 target_ws_id_str,
207 base_commit_id_str,
208 source_commit_id_str,
209 target_commit_id_str,
210 merge_commit_id_str,
211 status_str,
212 conflict_data_str,
213 merged_by_str,
214 merged_at_str,
215 created_at_str
216 )
217 .execute(&self.db)
218 .await?;
219
220 Ok((merged_state, conflicts))
221 }
222
223 fn three_way_merge(
231 &self,
232 base: &Value,
233 source: &Value,
234 target: &Value,
235 ) -> Result<(Value, Vec<MergeConflict>)> {
236 let mut merged = target.clone();
237 let mut conflicts = Vec::new();
238
239 self.merge_value("", base, source, target, &mut merged, &mut conflicts)?;
240
241 Ok((merged, conflicts))
242 }
243
244 fn merge_value(
246 &self,
247 path: &str,
248 base: &Value,
249 source: &Value,
250 target: &Value,
251 merged: &mut Value,
252 conflicts: &mut Vec<MergeConflict>,
253 ) -> Result<()> {
254 match (base, source, target) {
255 (b, s, t) if b == s && s == t => {
257 }
259
260 (b, s, t) if b == s && t != b => {
262 }
264
265 (b, s, t) if b == t && s != b => {
267 *merged = source.clone();
268 }
269
270 (b, s, t) if s == t && s != b => {
272 *merged = source.clone();
273 }
274
275 (b, s, t) if s != t && s != b && t != b => {
277 conflicts.push(MergeConflict {
278 path: path.to_string(),
279 base_value: Some(b.clone()),
280 source_value: Some(s.clone()),
281 target_value: Some(t.clone()),
282 conflict_type: ConflictType::Modified,
283 });
284 }
286
287 (Value::Object(base_obj), Value::Object(source_obj), Value::Object(target_obj)) => {
289 if let Value::Object(merged_obj) = merged {
290 let all_keys: std::collections::HashSet<_> =
292 base_obj.keys().chain(source_obj.keys()).chain(target_obj.keys()).collect();
293
294 for key in all_keys {
295 let base_val = base_obj.get(key);
296 let source_val = source_obj.get(key);
297 let target_val = target_obj.get(key);
298
299 let new_path = if path.is_empty() {
300 key.clone()
301 } else {
302 format!("{path}.{key}")
303 };
304
305 match (base_val, source_val, target_val) {
306 (None, Some(s), None) => {
308 merged_obj.insert(key.clone(), s.clone());
309 }
310 (None, None, Some(t)) => {
312 merged_obj.insert(key.clone(), t.clone());
313 }
314 (None, Some(s), Some(t)) if s != t => {
316 conflicts.push(MergeConflict {
317 path: new_path.clone(),
318 base_value: None,
319 source_value: Some(s.clone()),
320 target_value: Some(t.clone()),
321 conflict_type: ConflictType::BothAdded,
322 });
323 }
325 (None, Some(s), Some(t)) if s == t => {
327 merged_obj.insert(key.clone(), s.clone());
328 }
329 (Some(b), Some(s), Some(t)) => {
331 if let Some(merged_val) = merged_obj.get_mut(key) {
332 self.merge_value(&new_path, b, s, t, merged_val, conflicts)?;
333 }
334 }
335 (Some(b), None, Some(t)) if b == t => {
337 merged_obj.remove(key);
338 }
339 (Some(b), Some(s), None) if b == s => {
341 merged_obj.remove(key);
342 }
343 (Some(b), None, Some(_t)) => {
345 conflicts.push(MergeConflict {
346 path: new_path.clone(),
347 base_value: Some(b.clone()),
348 source_value: source_val.cloned(),
349 target_value: target_val.cloned(),
350 conflict_type: ConflictType::DeletedModified,
351 });
352 }
353 (Some(b), Some(_s), None) => {
355 conflicts.push(MergeConflict {
356 path: new_path.clone(),
357 base_value: Some(b.clone()),
358 source_value: source_val.cloned(),
359 target_value: target_val.cloned(),
360 conflict_type: ConflictType::DeletedModified,
361 });
362 }
363 _ => {}
364 }
365 }
366 }
367 }
368
369 (Value::Array(base_arr), Value::Array(source_arr), Value::Array(target_arr)) => {
371 if (base_arr != source_arr || base_arr != target_arr) && source_arr != target_arr {
372 conflicts.push(MergeConflict {
373 path: path.to_string(),
374 base_value: Some(base.clone()),
375 source_value: Some(source.clone()),
376 target_value: Some(target.clone()),
377 conflict_type: ConflictType::Modified,
378 });
379 }
380 }
381
382 _ => {
383 }
385 }
386
387 Ok(())
388 }
389
390 pub async fn complete_merge(
392 &self,
393 merge_id: Uuid,
394 user_id: Uuid,
395 resolved_state: Value,
396 message: String,
397 ) -> Result<Uuid> {
398 let merge = self.get_merge(merge_id).await?;
400
401 if merge.status != MergeStatus::Conflict && merge.status != MergeStatus::Pending {
402 return Err(CollabError::InvalidInput(
403 "Merge is not in a state that can be completed".to_string(),
404 ));
405 }
406
407 let merge_commit = self
409 .version_control
410 .create_commit(
411 merge.target_workspace_id,
412 user_id,
413 message,
414 Some(merge.target_commit_id),
415 0, resolved_state.clone(),
418 serde_json::json!({
419 "type": "merge",
420 "source_workspace_id": merge.source_workspace_id,
421 "source_commit_id": merge.source_commit_id,
422 }),
423 )
424 .await?;
425
426 let now = Utc::now();
428 sqlx::query!(
429 r#"
430 UPDATE workspace_merges
431 SET merge_commit_id = ?, status = ?, merged_by = ?, merged_at = ?
432 WHERE id = ?
433 "#,
434 merge_commit.id,
435 MergeStatus::Completed,
436 user_id,
437 now,
438 merge_id
439 )
440 .execute(&self.db)
441 .await?;
442
443 Ok(merge_commit.id)
444 }
445
446 pub async fn get_merge(&self, merge_id: Uuid) -> Result<WorkspaceMerge> {
448 let merge_id_str = merge_id.to_string();
449 let row = sqlx::query!(
450 r#"
451 SELECT
452 id,
453 source_workspace_id,
454 target_workspace_id,
455 base_commit_id,
456 source_commit_id,
457 target_commit_id,
458 merge_commit_id,
459 status,
460 conflict_data,
461 merged_by,
462 merged_at,
463 created_at
464 FROM workspace_merges
465 WHERE id = ?
466 "#,
467 merge_id_str
468 )
469 .fetch_optional(&self.db)
470 .await?
471 .ok_or_else(|| CollabError::Internal(format!("Merge not found: {merge_id}")))?;
472
473 Ok(WorkspaceMerge {
474 id: Uuid::parse_str(&row.id)
475 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {e}")))?,
476 source_workspace_id: Uuid::parse_str(&row.source_workspace_id)
477 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {e}")))?,
478 target_workspace_id: Uuid::parse_str(&row.target_workspace_id)
479 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {e}")))?,
480 base_commit_id: Uuid::parse_str(&row.base_commit_id)
481 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {e}")))?,
482 source_commit_id: Uuid::parse_str(&row.source_commit_id)
483 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {e}")))?,
484 target_commit_id: Uuid::parse_str(&row.target_commit_id)
485 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {e}")))?,
486 merge_commit_id: row.merge_commit_id.as_ref().and_then(|s| Uuid::parse_str(s).ok()),
487 status: serde_json::from_str(&row.status)
488 .map_err(|e| CollabError::Internal(format!("Invalid status: {e}")))?,
489 conflict_data: row.conflict_data.as_ref().and_then(|s| serde_json::from_str(s).ok()),
490 merged_by: row.merged_by.as_ref().and_then(|s| Uuid::parse_str(s).ok()),
491 merged_at: row
492 .merged_at
493 .as_ref()
494 .map(|s| {
495 chrono::DateTime::parse_from_rfc3339(s)
496 .map(|dt| dt.with_timezone(&Utc))
497 .map_err(|e| CollabError::Internal(format!("Invalid timestamp: {e}")))
498 })
499 .transpose()?,
500 created_at: chrono::DateTime::parse_from_rfc3339(&row.created_at)
501 .map_err(|e| CollabError::Internal(format!("Invalid timestamp: {e}")))?
502 .with_timezone(&Utc),
503 })
504 }
505
506 pub async fn list_merges(&self, workspace_id: Uuid) -> Result<Vec<WorkspaceMerge>> {
508 let workspace_id_str = workspace_id.to_string();
509 let rows = sqlx::query!(
510 r#"
511 SELECT
512 id,
513 source_workspace_id,
514 target_workspace_id,
515 base_commit_id,
516 source_commit_id,
517 target_commit_id,
518 merge_commit_id,
519 status,
520 conflict_data,
521 merged_by,
522 merged_at,
523 created_at
524 FROM workspace_merges
525 WHERE source_workspace_id = ? OR target_workspace_id = ?
526 ORDER BY created_at DESC
527 "#,
528 workspace_id_str,
529 workspace_id_str
530 )
531 .fetch_all(&self.db)
532 .await?;
533
534 let merges: Result<Vec<WorkspaceMerge>> = rows
535 .into_iter()
536 .map(|row| {
537 Ok(WorkspaceMerge {
538 id: Uuid::parse_str(&row.id)
539 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {e}")))?,
540 source_workspace_id: Uuid::parse_str(&row.source_workspace_id)
541 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {e}")))?,
542 target_workspace_id: Uuid::parse_str(&row.target_workspace_id)
543 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {e}")))?,
544 base_commit_id: Uuid::parse_str(&row.base_commit_id)
545 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {e}")))?,
546 source_commit_id: Uuid::parse_str(&row.source_commit_id)
547 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {e}")))?,
548 target_commit_id: Uuid::parse_str(&row.target_commit_id)
549 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {e}")))?,
550 merge_commit_id: row
551 .merge_commit_id
552 .as_ref()
553 .and_then(|s| Uuid::parse_str(s).ok()),
554 status: serde_json::from_str(&row.status)
555 .map_err(|e| CollabError::Internal(format!("Invalid status: {e}")))?,
556 conflict_data: row
557 .conflict_data
558 .as_ref()
559 .and_then(|s| serde_json::from_str(s).ok()),
560 merged_by: row.merged_by.as_ref().and_then(|s| Uuid::parse_str(s).ok()),
561 merged_at: row
562 .merged_at
563 .as_ref()
564 .map(|s| {
565 chrono::DateTime::parse_from_rfc3339(s)
566 .map(|dt| dt.with_timezone(&Utc))
567 .map_err(|e| {
568 CollabError::Internal(format!("Invalid timestamp: {e}"))
569 })
570 })
571 .transpose()?,
572 created_at: chrono::DateTime::parse_from_rfc3339(&row.created_at)
573 .map_err(|e| CollabError::Internal(format!("Invalid timestamp: {e}")))?
574 .with_timezone(&Utc),
575 })
576 })
577 .collect();
578 let merges = merges?;
579
580 Ok(merges)
581 }
582
583 async fn build_ancestor_set(&self, commit_id: Uuid) -> Result<std::collections::HashSet<Uuid>> {
585 let mut ancestors = std::collections::HashSet::new();
586 let mut current_id = Some(commit_id);
587 let mut visited = std::collections::HashSet::new();
588
589 let max_depth = 1000;
591 let mut depth = 0;
592
593 while let Some(id) = current_id {
594 if visited.contains(&id) || depth > max_depth {
595 break; }
597 visited.insert(id);
598 ancestors.insert(id);
599
600 match self.version_control.get_commit(id).await {
602 Ok(commit) => {
603 current_id = commit.parent_id;
604 depth += 1;
605 }
606 Err(_) => break, }
608 }
609
610 Ok(ancestors)
611 }
612}