1use crate::error::{CollabError, Result};
4use crate::history::VersionControl;
5use crate::models::{ConflictType, MergeConflict, MergeStatus, WorkspaceMerge};
6use chrono::Utc;
7use serde_json::Value;
8use sqlx::{Pool, Sqlite};
9use uuid::Uuid;
10
11pub struct MergeService {
13 db: Pool<Sqlite>,
14 version_control: VersionControl,
15}
16
17impl MergeService {
18 pub fn new(db: Pool<Sqlite>) -> Self {
20 Self {
21 db: db.clone(),
22 version_control: VersionControl::new(db),
23 }
24 }
25
26 pub async fn find_common_ancestor(
31 &self,
32 source_workspace_id: Uuid,
33 target_workspace_id: Uuid,
34 ) -> Result<Option<Uuid>> {
35 let source_ws_id_str = source_workspace_id.to_string();
37 let target_ws_id_str = target_workspace_id.to_string();
38 let fork = sqlx::query!(
39 r#"
40 SELECT fork_point_commit_id
41 FROM workspace_forks
42 WHERE source_workspace_id = ? AND forked_workspace_id = ?
43 "#,
44 source_ws_id_str,
45 target_ws_id_str
46 )
47 .fetch_optional(&self.db)
48 .await?;
49
50 if let Some(fork) = fork {
51 if let Some(commit_id_str) = fork.fork_point_commit_id {
52 if let Ok(commit_id) = Uuid::parse_str(&commit_id_str) {
53 return Ok(Some(commit_id));
54 }
55 }
56 }
57
58 let target_ws_id_str2 = target_workspace_id.to_string();
60 let source_ws_id_str2 = source_workspace_id.to_string();
61 let fork = sqlx::query!(
62 r#"
63 SELECT fork_point_commit_id
64 FROM workspace_forks
65 WHERE source_workspace_id = ? AND forked_workspace_id = ?
66 "#,
67 target_ws_id_str2,
68 source_ws_id_str2
69 )
70 .fetch_optional(&self.db)
71 .await?;
72
73 if let Some(fork) = fork {
74 if let Some(commit_id_str) = fork.fork_point_commit_id {
75 if let Ok(commit_id) = Uuid::parse_str(&commit_id_str) {
76 return Ok(Some(commit_id));
77 }
78 }
79 }
80
81 Ok(None)
84 }
85
86 pub async fn merge_workspaces(
91 &self,
92 source_workspace_id: Uuid,
93 target_workspace_id: Uuid,
94 user_id: Uuid,
95 ) -> Result<(Value, Vec<MergeConflict>)> {
96 let source_commit =
98 self.version_control.get_latest_commit(source_workspace_id).await?.ok_or_else(
99 || CollabError::Internal("Source workspace has no commits".to_string()),
100 )?;
101
102 let target_commit =
103 self.version_control.get_latest_commit(target_workspace_id).await?.ok_or_else(
104 || CollabError::Internal("Target workspace has no commits".to_string()),
105 )?;
106
107 let base_commit_id = self
109 .find_common_ancestor(source_workspace_id, target_workspace_id)
110 .await?
111 .ok_or_else(|| {
112 CollabError::Internal(
113 "Cannot find common ancestor. Workspaces must be related by fork.".to_string(),
114 )
115 })?;
116
117 let base_commit = self.version_control.get_commit(base_commit_id).await?;
118
119 let (merged_state, conflicts) = self.three_way_merge(
121 &base_commit.snapshot,
122 &source_commit.snapshot,
123 &target_commit.snapshot,
124 )?;
125
126 let mut merge = WorkspaceMerge::new(
128 source_workspace_id,
129 target_workspace_id,
130 base_commit_id,
131 source_commit.id,
132 target_commit.id,
133 );
134
135 if conflicts.is_empty() {
136 merge.status = MergeStatus::Completed;
137 } else {
138 merge.status = MergeStatus::Conflict;
139 merge.conflict_data = Some(serde_json::to_value(&conflicts)?);
140 }
141
142 let merge_id_str = merge.id.to_string();
144 let source_ws_id_str = merge.source_workspace_id.to_string();
145 let target_ws_id_str = merge.target_workspace_id.to_string();
146 let base_commit_id_str = merge.base_commit_id.to_string();
147 let source_commit_id_str = merge.source_commit_id.to_string();
148 let target_commit_id_str = merge.target_commit_id.to_string();
149 let merge_commit_id_str = merge.merge_commit_id.map(|id| id.to_string());
150 let status_str = serde_json::to_string(&merge.status)?;
151 let conflict_data_str =
152 merge.conflict_data.as_ref().map(|v| serde_json::to_string(v)).transpose()?;
153 let merged_by_str = merge.merged_by.map(|id| id.to_string());
154 let merged_at_str = merge.merged_at.map(|dt| dt.to_rfc3339());
155 let created_at_str = merge.created_at.to_rfc3339();
156
157 sqlx::query!(
158 r#"
159 INSERT INTO workspace_merges (
160 id, source_workspace_id, target_workspace_id,
161 base_commit_id, source_commit_id, target_commit_id,
162 merge_commit_id, status, conflict_data, merged_by, merged_at, created_at
163 )
164 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
165 "#,
166 merge_id_str,
167 source_ws_id_str,
168 target_ws_id_str,
169 base_commit_id_str,
170 source_commit_id_str,
171 target_commit_id_str,
172 merge_commit_id_str,
173 status_str,
174 conflict_data_str,
175 merged_by_str,
176 merged_at_str,
177 created_at_str
178 )
179 .execute(&self.db)
180 .await?;
181
182 Ok((merged_state, conflicts))
183 }
184
185 fn three_way_merge(
193 &self,
194 base: &Value,
195 source: &Value,
196 target: &Value,
197 ) -> Result<(Value, Vec<MergeConflict>)> {
198 let mut merged = target.clone();
199 let mut conflicts = Vec::new();
200
201 self.merge_value("", base, source, target, &mut merged, &mut conflicts)?;
202
203 Ok((merged, conflicts))
204 }
205
206 fn merge_value(
208 &self,
209 path: &str,
210 base: &Value,
211 source: &Value,
212 target: &Value,
213 merged: &mut Value,
214 conflicts: &mut Vec<MergeConflict>,
215 ) -> Result<()> {
216 match (base, source, target) {
217 (b, s, t) if b == s && s == t => {
219 }
221
222 (b, s, t) if b == s && t != b => {
224 }
226
227 (b, s, t) if b == t && s != b => {
229 *merged = source.clone();
230 }
231
232 (b, s, t) if s == t && s != b => {
234 *merged = source.clone();
235 }
236
237 (b, s, t) if s != t && s != b && t != b => {
239 conflicts.push(MergeConflict {
240 path: path.to_string(),
241 base_value: Some(b.clone()),
242 source_value: Some(s.clone()),
243 target_value: Some(t.clone()),
244 conflict_type: ConflictType::Modified,
245 });
246 }
248
249 (Value::Object(base_obj), Value::Object(source_obj), Value::Object(target_obj)) => {
251 if let Value::Object(merged_obj) = merged {
252 let all_keys: std::collections::HashSet<_> =
254 base_obj.keys().chain(source_obj.keys()).chain(target_obj.keys()).collect();
255
256 for key in all_keys {
257 let base_val = base_obj.get(key);
258 let source_val = source_obj.get(key);
259 let target_val = target_obj.get(key);
260
261 let new_path = if path.is_empty() {
262 key.clone()
263 } else {
264 format!("{}.{}", path, key)
265 };
266
267 match (base_val, source_val, target_val) {
268 (None, Some(s), None) => {
270 merged_obj.insert(key.clone(), s.clone());
271 }
272 (None, None, Some(t)) => {
274 merged_obj.insert(key.clone(), t.clone());
275 }
276 (None, Some(s), Some(t)) if s != t => {
278 conflicts.push(MergeConflict {
279 path: new_path.clone(),
280 base_value: None,
281 source_value: Some(s.clone()),
282 target_value: Some(t.clone()),
283 conflict_type: ConflictType::BothAdded,
284 });
285 }
287 (None, Some(s), Some(t)) if s == t => {
289 merged_obj.insert(key.clone(), s.clone());
290 }
291 (Some(b), Some(s), Some(t)) => {
293 if let Some(merged_val) = merged_obj.get_mut(key) {
294 self.merge_value(&new_path, b, s, t, merged_val, conflicts)?;
295 }
296 }
297 (Some(b), None, Some(t)) if b == t => {
299 merged_obj.remove(key);
300 }
301 (Some(b), Some(s), None) if b == s => {
303 merged_obj.remove(key);
304 }
305 (Some(b), None, Some(_t)) => {
307 conflicts.push(MergeConflict {
308 path: new_path.clone(),
309 base_value: Some(b.clone()),
310 source_value: source_val.cloned(),
311 target_value: target_val.cloned(),
312 conflict_type: ConflictType::DeletedModified,
313 });
314 }
315 (Some(b), Some(_s), None) => {
317 conflicts.push(MergeConflict {
318 path: new_path.clone(),
319 base_value: Some(b.clone()),
320 source_value: source_val.cloned(),
321 target_value: target_val.cloned(),
322 conflict_type: ConflictType::DeletedModified,
323 });
324 }
325 _ => {}
326 }
327 }
328 }
329 }
330
331 (Value::Array(base_arr), Value::Array(source_arr), Value::Array(target_arr)) => {
333 if base_arr != source_arr || base_arr != target_arr {
334 if source_arr != target_arr {
335 conflicts.push(MergeConflict {
336 path: path.to_string(),
337 base_value: Some(base.clone()),
338 source_value: Some(source.clone()),
339 target_value: Some(target.clone()),
340 conflict_type: ConflictType::Modified,
341 });
342 }
343 }
344 }
345
346 _ => {
347 }
349 }
350
351 Ok(())
352 }
353
354 pub async fn complete_merge(
356 &self,
357 merge_id: Uuid,
358 user_id: Uuid,
359 resolved_state: Value,
360 message: String,
361 ) -> Result<Uuid> {
362 let merge = self.get_merge(merge_id).await?;
364
365 if merge.status != MergeStatus::Conflict && merge.status != MergeStatus::Pending {
366 return Err(CollabError::InvalidInput(
367 "Merge is not in a state that can be completed".to_string(),
368 ));
369 }
370
371 let merge_commit = self
373 .version_control
374 .create_commit(
375 merge.target_workspace_id,
376 user_id,
377 message,
378 Some(merge.target_commit_id),
379 0, resolved_state.clone(),
382 serde_json::json!({
383 "type": "merge",
384 "source_workspace_id": merge.source_workspace_id,
385 "source_commit_id": merge.source_commit_id,
386 }),
387 )
388 .await?;
389
390 let now = Utc::now();
392 sqlx::query!(
393 r#"
394 UPDATE workspace_merges
395 SET merge_commit_id = ?, status = ?, merged_by = ?, merged_at = ?
396 WHERE id = ?
397 "#,
398 merge_commit.id,
399 MergeStatus::Completed,
400 user_id,
401 now,
402 merge_id
403 )
404 .execute(&self.db)
405 .await?;
406
407 Ok(merge_commit.id)
408 }
409
410 pub async fn get_merge(&self, merge_id: Uuid) -> Result<WorkspaceMerge> {
412 let merge_id_str = merge_id.to_string();
413 let row = sqlx::query!(
414 r#"
415 SELECT
416 id,
417 source_workspace_id,
418 target_workspace_id,
419 base_commit_id,
420 source_commit_id,
421 target_commit_id,
422 merge_commit_id,
423 status,
424 conflict_data,
425 merged_by,
426 merged_at,
427 created_at
428 FROM workspace_merges
429 WHERE id = ?
430 "#,
431 merge_id_str
432 )
433 .fetch_optional(&self.db)
434 .await?
435 .ok_or_else(|| CollabError::Internal(format!("Merge not found: {}", merge_id)))?;
436
437 Ok(WorkspaceMerge {
438 id: Uuid::parse_str(&row.id)
439 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
440 source_workspace_id: Uuid::parse_str(&row.source_workspace_id)
441 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
442 target_workspace_id: Uuid::parse_str(&row.target_workspace_id)
443 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
444 base_commit_id: Uuid::parse_str(&row.base_commit_id)
445 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
446 source_commit_id: Uuid::parse_str(&row.source_commit_id)
447 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
448 target_commit_id: Uuid::parse_str(&row.target_commit_id)
449 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
450 merge_commit_id: row.merge_commit_id.and_then(|s| Uuid::parse_str(&s).ok()),
451 status: serde_json::from_str(&row.status)
452 .map_err(|e| CollabError::Internal(format!("Invalid status: {}", e)))?,
453 conflict_data: row.conflict_data.and_then(|s| serde_json::from_str(&s).ok()),
454 merged_by: row.merged_by.and_then(|s| Uuid::parse_str(&s).ok()),
455 merged_at: row
456 .merged_at
457 .map(|s| {
458 chrono::DateTime::parse_from_rfc3339(&s)
459 .map(|dt| dt.with_timezone(&chrono::Utc))
460 .map_err(|e| CollabError::Internal(format!("Invalid timestamp: {}", e)))
461 })
462 .transpose()?,
463 created_at: chrono::DateTime::parse_from_rfc3339(&row.created_at)
464 .map_err(|e| CollabError::Internal(format!("Invalid timestamp: {}", e)))?
465 .with_timezone(&chrono::Utc),
466 })
467 }
468
469 pub async fn list_merges(&self, workspace_id: Uuid) -> Result<Vec<WorkspaceMerge>> {
471 let workspace_id_str = workspace_id.to_string();
472 let rows = sqlx::query!(
473 r#"
474 SELECT
475 id,
476 source_workspace_id,
477 target_workspace_id,
478 base_commit_id,
479 source_commit_id,
480 target_commit_id,
481 merge_commit_id,
482 status,
483 conflict_data,
484 merged_by,
485 merged_at,
486 created_at
487 FROM workspace_merges
488 WHERE source_workspace_id = ? OR target_workspace_id = ?
489 ORDER BY created_at DESC
490 "#,
491 workspace_id_str,
492 workspace_id_str
493 )
494 .fetch_all(&self.db)
495 .await?;
496
497 let merges: Result<Vec<WorkspaceMerge>> = rows
498 .into_iter()
499 .map(|row| {
500 Ok(WorkspaceMerge {
501 id: Uuid::parse_str(&row.id)
502 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
503 source_workspace_id: Uuid::parse_str(&row.source_workspace_id)
504 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
505 target_workspace_id: Uuid::parse_str(&row.target_workspace_id)
506 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
507 base_commit_id: Uuid::parse_str(&row.base_commit_id)
508 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
509 source_commit_id: Uuid::parse_str(&row.source_commit_id)
510 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
511 target_commit_id: Uuid::parse_str(&row.target_commit_id)
512 .map_err(|e| CollabError::Internal(format!("Invalid UUID: {}", e)))?,
513 merge_commit_id: row.merge_commit_id.and_then(|s| Uuid::parse_str(&s).ok()),
514 status: serde_json::from_str(&row.status)
515 .map_err(|e| CollabError::Internal(format!("Invalid status: {}", e)))?,
516 conflict_data: row.conflict_data.and_then(|s| serde_json::from_str(&s).ok()),
517 merged_by: row.merged_by.and_then(|s| Uuid::parse_str(&s).ok()),
518 merged_at: row
519 .merged_at
520 .map(|s| {
521 chrono::DateTime::parse_from_rfc3339(&s)
522 .map(|dt| dt.with_timezone(&chrono::Utc))
523 .map_err(|e| {
524 CollabError::Internal(format!("Invalid timestamp: {}", e))
525 })
526 })
527 .transpose()?,
528 created_at: chrono::DateTime::parse_from_rfc3339(&row.created_at)
529 .map_err(|e| CollabError::Internal(format!("Invalid timestamp: {}", e)))?
530 .with_timezone(&chrono::Utc),
531 })
532 })
533 .collect();
534 let merges = merges?;
535
536 Ok(merges)
537 }
538}