1use crate::error::{CollabError, Result};
4use crate::history::VersionControl;
5use crate::models::{ConflictType, MergeConflict, MergeStatus, WorkspaceMerge};
6use chrono::Utc;
7use serde_json::Value;
8use sqlx::{Pool, Sqlite};
9use uuid::Uuid;
10
11pub struct MergeService {
13 db: Pool<Sqlite>,
14 version_control: VersionControl,
15}
16
17impl MergeService {
18 pub fn new(db: Pool<Sqlite>) -> Self {
20 Self {
21 db: db.clone(),
22 version_control: VersionControl::new(db),
23 }
24 }
25
26 pub async fn find_common_ancestor(
31 &self,
32 source_workspace_id: Uuid,
33 target_workspace_id: Uuid,
34 ) -> Result<Option<Uuid>> {
35 let source_ws_id_str = source_workspace_id.to_string();
37 let target_ws_id_str = target_workspace_id.to_string();
38 let fork = sqlx::query!(
39 r#"
40 SELECT fork_point_commit_id
41 FROM workspace_forks
42 WHERE source_workspace_id = ? AND forked_workspace_id = ?
43 "#,
44 source_ws_id_str,
45 target_ws_id_str
46 )
47 .fetch_optional(&self.db)
48 .await?;
49
50 if let Some(fork) = fork {
51 if let Some(commit_id_str) = fork.fork_point_commit_id.as_ref() {
52 if let Ok(commit_id) = Uuid::parse_str(commit_id_str) {
53 return Ok(Some(commit_id));
54 }
55 }
56 }
57
58 let target_ws_id_str2 = target_workspace_id.to_string();
60 let source_ws_id_str2 = source_workspace_id.to_string();
61 let fork = sqlx::query!(
62 r#"
63 SELECT fork_point_commit_id
64 FROM workspace_forks
65 WHERE source_workspace_id = ? AND forked_workspace_id = ?
66 "#,
67 target_ws_id_str2,
68 source_ws_id_str2
69 )
70 .fetch_optional(&self.db)
71 .await?;
72
73 if let Some(fork) = fork {
74 if let Some(commit_id_str) = fork.fork_point_commit_id.as_ref() {
75 if let Ok(commit_id) = Uuid::parse_str(commit_id_str) {
76 return Ok(Some(commit_id));
77 }
78 }
79 }
80
81 let source_commits =
84 self.version_control.get_history(source_workspace_id, Some(1000)).await?;
85 let target_commits =
86 self.version_control.get_history(target_workspace_id, Some(1000)).await?;
87
88 let source_commit_ids: std::collections::HashSet<Uuid> =
90 source_commits.iter().map(|c| c.id).collect();
91 let target_commit_ids: std::collections::HashSet<Uuid> =
92 target_commits.iter().map(|c| c.id).collect();
93
94 for source_commit in &source_commits {
97 if target_commit_ids.contains(&source_commit.id) {
98 return Ok(Some(source_commit.id));
99 }
100 }
101
102 if let (Some(source_latest), Some(target_latest)) =
105 (source_commits.first(), target_commits.first())
106 {
107 let source_ancestors = self.build_ancestor_set(source_latest.id).await?;
109 let target_ancestors = self.build_ancestor_set(target_latest.id).await?;
110
111 for ancestor in &source_ancestors {
113 if target_ancestors.contains(ancestor) {
114 return Ok(Some(*ancestor));
115 }
116 }
117 }
118
119 Ok(None)
121 }
122
123 pub async fn merge_workspaces(
128 &self,
129 source_workspace_id: Uuid,
130 target_workspace_id: Uuid,
131 user_id: Uuid,
132 ) -> Result<(Value, Vec<MergeConflict>)> {
133 let source_commit =
135 self.version_control.get_latest_commit(source_workspace_id).await?.ok_or_else(
136 || CollabError::Internal("Source workspace has no commits".to_string()),
137 )?;
138
139 let target_commit =
140 self.version_control.get_latest_commit(target_workspace_id).await?.ok_or_else(
141 || CollabError::Internal("Target workspace has no commits".to_string()),
142 )?;
143
144 let base_commit_id = self
146 .find_common_ancestor(source_workspace_id, target_workspace_id)
147 .await?
148 .ok_or_else(|| {
149 CollabError::Internal(
150 "Cannot find common ancestor. Workspaces must be related by fork.".to_string(),
151 )
152 })?;
153
154 let base_commit = self.version_control.get_commit(base_commit_id).await?;
155
156 let (merged_state, conflicts) = self.three_way_merge(
158 &base_commit.snapshot,
159 &source_commit.snapshot,
160 &target_commit.snapshot,
161 )?;
162
163 let mut merge = WorkspaceMerge::new(
165 source_workspace_id,
166 target_workspace_id,
167 base_commit_id,
168 source_commit.id,
169 target_commit.id,
170 );
171
172 if conflicts.is_empty() {
173 merge.status = MergeStatus::Completed;
174 } else {
175 merge.status = MergeStatus::Conflict;
176 merge.conflict_data = Some(serde_json::to_value(&conflicts)?);
177 }
178
179 let merge_id_str = merge.id.to_string();
181 let source_ws_id_str = merge.source_workspace_id.to_string();
182 let target_ws_id_str = merge.target_workspace_id.to_string();
183 let base_commit_id_str = merge.base_commit_id.to_string();
184 let source_commit_id_str = merge.source_commit_id.to_string();
185 let target_commit_id_str = merge.target_commit_id.to_string();
186 let merge_commit_id_str = merge.merge_commit_id.map(|id| id.to_string());
187 let status_str = serde_json::to_string(&merge.status)?;
188 let conflict_data_str =
189 merge.conflict_data.as_ref().map(|v| serde_json::to_string(v)).transpose()?;
190 let merged_by_str = merge.merged_by.map(|id| id.to_string());
191 let merged_at_str = merge.merged_at.map(|dt| dt.to_rfc3339());
192 let created_at_str = merge.created_at.to_rfc3339();
193
194 sqlx::query!(
195 r#"
196 INSERT INTO workspace_merges (
197 id, source_workspace_id, target_workspace_id,
198 base_commit_id, source_commit_id, target_commit_id,
199 merge_commit_id, status, conflict_data, merged_by, merged_at, created_at
200 )
201 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
202 "#,
203 merge_id_str,
204 source_ws_id_str,
205 target_ws_id_str,
206 base_commit_id_str,
207 source_commit_id_str,
208 target_commit_id_str,
209 merge_commit_id_str,
210 status_str,
211 conflict_data_str,
212 merged_by_str,
213 merged_at_str,
214 created_at_str
215 )
216 .execute(&self.db)
217 .await?;
218
219 Ok((merged_state, conflicts))
220 }
221
222 fn three_way_merge(
230 &self,
231 base: &Value,
232 source: &Value,
233 target: &Value,
234 ) -> Result<(Value, Vec<MergeConflict>)> {
235 let mut merged = target.clone();
236 let mut conflicts = Vec::new();
237
238 self.merge_value("", base, source, target, &mut merged, &mut conflicts)?;
239
240 Ok((merged, conflicts))
241 }
242
243 fn merge_value(
245 &self,
246 path: &str,
247 base: &Value,
248 source: &Value,
249 target: &Value,
250 merged: &mut Value,
251 conflicts: &mut Vec<MergeConflict>,
252 ) -> Result<()> {
253 match (base, source, target) {
254 (b, s, t) if b == s && s == t => {
256 }
258
259 (b, s, t) if b == s && t != b => {
261 }
263
264 (b, s, t) if b == t && s != b => {
266 *merged = source.clone();
267 }
268
269 (b, s, t) if s == t && s != b => {
271 *merged = source.clone();
272 }
273
274 (b, s, t) if s != t && s != b && t != b => {
276 conflicts.push(MergeConflict {
277 path: path.to_string(),
278 base_value: Some(b.clone()),
279 source_value: Some(s.clone()),
280 target_value: Some(t.clone()),
281 conflict_type: ConflictType::Modified,
282 });
283 }
285
286 (Value::Object(base_obj), Value::Object(source_obj), Value::Object(target_obj)) => {
288 if let Value::Object(merged_obj) = merged {
289 let all_keys: std::collections::HashSet<_> =
291 base_obj.keys().chain(source_obj.keys()).chain(target_obj.keys()).collect();
292
293 for key in all_keys {
294 let base_val = base_obj.get(key);
295 let source_val = source_obj.get(key);
296 let target_val = target_obj.get(key);
297
298 let new_path = if path.is_empty() {
299 key.clone()
300 } else {
301 format!("{}.{}", path, key)
302 };
303
304 match (base_val, source_val, target_val) {
305 (None, Some(s), None) => {
307 merged_obj.insert(key.clone(), s.clone());
308 }
309 (None, None, Some(t)) => {
311 merged_obj.insert(key.clone(), t.clone());
312 }
313 (None, Some(s), Some(t)) if s != t => {
315 conflicts.push(MergeConflict {
316 path: new_path.clone(),
317 base_value: None,
318 source_value: Some(s.clone()),
319 target_value: Some(t.clone()),
320 conflict_type: ConflictType::BothAdded,
321 });
322 }
324 (None, Some(s), Some(t)) if s == t => {
326 merged_obj.insert(key.clone(), s.clone());
327 }
328 (Some(b), Some(s), Some(t)) => {
330 if let Some(merged_val) = merged_obj.get_mut(key) {
331 self.merge_value(&new_path, b, s, t, merged_val, conflicts)?;
332 }
333 }
334 (Some(b), None, Some(t)) if b == t => {
336 merged_obj.remove(key);
337 }
338 (Some(b), Some(s), None) if b == s => {
340 merged_obj.remove(key);
341 }
342 (Some(b), None, Some(_t)) => {
344 conflicts.push(MergeConflict {
345 path: new_path.clone(),
346 base_value: Some(b.clone()),
347 source_value: source_val.cloned(),
348 target_value: target_val.cloned(),
349 conflict_type: ConflictType::DeletedModified,
350 });
351 }
352 (Some(b), Some(_s), None) => {
354 conflicts.push(MergeConflict {
355 path: new_path.clone(),
356 base_value: Some(b.clone()),
357 source_value: source_val.cloned(),
358 target_value: target_val.cloned(),
359 conflict_type: ConflictType::DeletedModified,
360 });
361 }
362 _ => {}
363 }
364 }
365 }
366 }
367
368 (Value::Array(base_arr), Value::Array(source_arr), Value::Array(target_arr)) => {
370 if base_arr != source_arr || base_arr != target_arr {
371 if source_arr != target_arr {
372 conflicts.push(MergeConflict {
373 path: path.to_string(),
374 base_value: Some(base.clone()),
375 source_value: Some(source.clone()),
376 target_value: Some(target.clone()),
377 conflict_type: ConflictType::Modified,
378 });
379 }
380 }
381 }
382
383 _ => {
384 }
386 }
387
388 Ok(())
389 }
390
391 pub async fn complete_merge(
393 &self,
394 merge_id: Uuid,
395 user_id: Uuid,
396 resolved_state: Value,
397 message: String,
398 ) -> Result<Uuid> {
399 let merge = self.get_merge(merge_id).await?;
401
402 if merge.status != MergeStatus::Conflict && merge.status != MergeStatus::Pending {
403 return Err(CollabError::InvalidInput(
404 "Merge is not in a state that can be completed".to_string(),
405 ));
406 }
407
408 let merge_commit = self
410 .version_control
411 .create_commit(
412 merge.target_workspace_id,
413 user_id,
414 message,
415 Some(merge.target_commit_id),
416 0, resolved_state.clone(),
419 serde_json::json!({
420 "type": "merge",
421 "source_workspace_id": merge.source_workspace_id,
422 "source_commit_id": merge.source_commit_id,
423 }),
424 )
425 .await?;
426
427 let now = Utc::now();
429 sqlx::query!(
430 r#"
431 UPDATE workspace_merges
432 SET merge_commit_id = ?, status = ?, merged_by = ?, merged_at = ?
433 WHERE id = ?
434 "#,
435 merge_commit.id,
436 MergeStatus::Completed,
437 user_id,
438 now,
439 merge_id
440 )
441 .execute(&self.db)
442 .await?;
443
444 Ok(merge_commit.id)
445 }
446
447 pub async fn get_merge(&self, merge_id: Uuid) -> Result<WorkspaceMerge> {
449 let merge_id_str = merge_id.to_string();
450 let row = sqlx::query!(
451 r#"
452 SELECT
453 id,
454 source_workspace_id,
455 target_workspace_id,
456 base_commit_id,
457 source_commit_id,
458 target_commit_id,
459 merge_commit_id,
460 status,
461 conflict_data,
462 merged_by,
463 merged_at,
464 created_at
465 FROM workspace_merges
466 WHERE id = ?
467 "#,
468 merge_id_str
469 )
470 .fetch_optional(&self.db)
471 .await?
472 .ok_or_else(|| CollabError::Internal(format!("Merge not found: {}", merge_id)))?;
473
474 Ok(WorkspaceMerge {
475 id: Uuid::parse_str(&row.id)
476 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
477 source_workspace_id: Uuid::parse_str(&row.source_workspace_id)
478 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
479 target_workspace_id: Uuid::parse_str(&row.target_workspace_id)
480 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
481 base_commit_id: Uuid::parse_str(&row.base_commit_id)
482 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
483 source_commit_id: Uuid::parse_str(&row.source_commit_id)
484 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
485 target_commit_id: Uuid::parse_str(&row.target_commit_id)
486 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
487 merge_commit_id: row.merge_commit_id.as_ref().and_then(|s| Uuid::parse_str(s).ok()),
488 status: serde_json::from_str(&row.status)
489 .map_err(|e| CollabError::Internal(format!("Invalid status: {}", e)))?,
490 conflict_data: row.conflict_data.as_ref().and_then(|s| serde_json::from_str(s).ok()),
491 merged_by: row.merged_by.as_ref().and_then(|s| Uuid::parse_str(s).ok()),
492 merged_at: row
493 .merged_at
494 .as_ref()
495 .map(|s| {
496 chrono::DateTime::parse_from_rfc3339(s)
497 .map(|dt| dt.with_timezone(&chrono::Utc))
498 .map_err(|e| CollabError::Internal(format!("Invalid timestamp: {}", e)))
499 })
500 .transpose()?,
501 created_at: chrono::DateTime::parse_from_rfc3339(&row.created_at)
502 .map_err(|e| CollabError::Internal(format!("Invalid timestamp: {}", e)))?
503 .with_timezone(&chrono::Utc),
504 })
505 }
506
507 pub async fn list_merges(&self, workspace_id: Uuid) -> Result<Vec<WorkspaceMerge>> {
509 let workspace_id_str = workspace_id.to_string();
510 let rows = sqlx::query!(
511 r#"
512 SELECT
513 id,
514 source_workspace_id,
515 target_workspace_id,
516 base_commit_id,
517 source_commit_id,
518 target_commit_id,
519 merge_commit_id,
520 status,
521 conflict_data,
522 merged_by,
523 merged_at,
524 created_at
525 FROM workspace_merges
526 WHERE source_workspace_id = ? OR target_workspace_id = ?
527 ORDER BY created_at DESC
528 "#,
529 workspace_id_str,
530 workspace_id_str
531 )
532 .fetch_all(&self.db)
533 .await?;
534
535 let merges: Result<Vec<WorkspaceMerge>> = rows
536 .into_iter()
537 .map(|row| {
538 Ok(WorkspaceMerge {
539 id: Uuid::parse_str(&row.id)
540 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
541 source_workspace_id: Uuid::parse_str(&row.source_workspace_id)
542 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
543 target_workspace_id: Uuid::parse_str(&row.target_workspace_id)
544 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
545 base_commit_id: Uuid::parse_str(&row.base_commit_id)
546 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
547 source_commit_id: Uuid::parse_str(&row.source_commit_id)
548 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
549 target_commit_id: Uuid::parse_str(&row.target_commit_id)
550 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
551 merge_commit_id: row
552 .merge_commit_id
553 .as_ref()
554 .and_then(|s| Uuid::parse_str(s).ok()),
555 status: serde_json::from_str(&row.status)
556 .map_err(|e| CollabError::Internal(format!("Invalid status: {}", e)))?,
557 conflict_data: row
558 .conflict_data
559 .as_ref()
560 .and_then(|s| serde_json::from_str(s).ok()),
561 merged_by: row.merged_by.as_ref().and_then(|s| Uuid::parse_str(s).ok()),
562 merged_at: row
563 .merged_at
564 .as_ref()
565 .map(|s| {
566 chrono::DateTime::parse_from_rfc3339(s)
567 .map(|dt| dt.with_timezone(&chrono::Utc))
568 .map_err(|e| {
569 CollabError::Internal(format!("Invalid timestamp: {}", e))
570 })
571 })
572 .transpose()?,
573 created_at: chrono::DateTime::parse_from_rfc3339(&row.created_at)
574 .map_err(|e| CollabError::Internal(format!("Invalid timestamp: {}", e)))?
575 .with_timezone(&chrono::Utc),
576 })
577 })
578 .collect();
579 let merges = merges?;
580
581 Ok(merges)
582 }
583
584 async fn build_ancestor_set(&self, commit_id: Uuid) -> Result<std::collections::HashSet<Uuid>> {
586 let mut ancestors = std::collections::HashSet::new();
587 let mut current_id = Some(commit_id);
588 let mut visited = std::collections::HashSet::new();
589
590 let max_depth = 1000;
592 let mut depth = 0;
593
594 while let Some(id) = current_id {
595 if visited.contains(&id) || depth > max_depth {
596 break; }
598 visited.insert(id);
599 ancestors.insert(id);
600
601 match self.version_control.get_commit(id).await {
603 Ok(commit) => {
604 current_id = commit.parent_id;
605 depth += 1;
606 }
607 Err(_) => break, }
609 }
610
611 Ok(ancestors)
612 }
613}