mnemo_core/query/
merge.rs1use serde::{Deserialize, Serialize};
2use uuid::Uuid;
3
4use crate::error::{Error, Result};
5use crate::model::checkpoint::Checkpoint;
6use crate::model::event::EventType;
7use crate::query::MnemoEngine;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "snake_case")]
11pub enum MergeStrategy {
12 FullMerge,
13 CherryPick,
14 Squash,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct MergeRequest {
19 pub thread_id: String,
20 pub agent_id: Option<String>,
21 pub source_branch: String,
22 pub target_branch: Option<String>,
23 pub strategy: Option<MergeStrategy>,
24 pub cherry_pick_ids: Option<Vec<Uuid>>,
25}
26
27impl MergeRequest {
28 pub fn new(thread_id: String, source_branch: String) -> Self {
29 Self {
30 thread_id,
31 agent_id: None,
32 source_branch,
33 target_branch: None,
34 strategy: None,
35 cherry_pick_ids: None,
36 }
37 }
38}
39
40#[non_exhaustive]
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct MergeResponse {
43 pub checkpoint_id: Uuid,
44 pub target_branch: String,
45 pub merged_memory_count: usize,
46}
47
48impl MergeResponse {
49 pub fn new(checkpoint_id: Uuid, target_branch: String, merged_memory_count: usize) -> Self {
50 Self {
51 checkpoint_id,
52 target_branch,
53 merged_memory_count,
54 }
55 }
56}
57
58pub async fn execute(engine: &MnemoEngine, request: MergeRequest) -> Result<MergeResponse> {
59 let agent_id = request
60 .agent_id
61 .unwrap_or_else(|| engine.default_agent_id.clone());
62 let target_branch = request.target_branch.unwrap_or_else(|| "main".to_string());
63 let strategy = request.strategy.unwrap_or(MergeStrategy::FullMerge);
64 let now = chrono::Utc::now().to_rfc3339();
65
66 let source_cp = engine
68 .storage
69 .get_latest_checkpoint(&request.thread_id, &request.source_branch)
70 .await?
71 .ok_or_else(|| {
72 Error::NotFound(format!(
73 "no checkpoint on branch '{}' for thread '{}'",
74 request.source_branch, request.thread_id
75 ))
76 })?;
77
78 let target_cp = engine
80 .storage
81 .get_latest_checkpoint(&request.thread_id, &target_branch)
82 .await?;
83
84 let target_parent_id = target_cp.as_ref().map(|cp| cp.id);
85
86 let merged_refs: Vec<Uuid> = match strategy {
88 MergeStrategy::CherryPick => {
89 let cherry = request.cherry_pick_ids.unwrap_or_default();
90 let mut existing = target_cp
91 .as_ref()
92 .map(|cp| cp.memory_refs.clone())
93 .unwrap_or_default();
94 for id in &cherry {
95 if !existing.contains(id) {
96 existing.push(*id);
97 }
98 }
99 existing
100 }
101 MergeStrategy::FullMerge | MergeStrategy::Squash => {
102 let mut merged = target_cp
103 .as_ref()
104 .map(|cp| cp.memory_refs.clone())
105 .unwrap_or_default();
106 for id in &source_cp.memory_refs {
107 if !merged.contains(id) {
108 merged.push(*id);
109 }
110 }
111 merged
112 }
113 };
114
115 let merged_count = merged_refs.len();
116
117 let merged_snapshot = if let Some(ref tcp) = target_cp {
119 let mut base = tcp.state_snapshot.clone();
120 if let (Some(base_obj), Some(source_obj)) =
121 (base.as_object_mut(), source_cp.state_snapshot.as_object())
122 {
123 for (k, v) in source_obj {
124 if !base_obj.contains_key(k) {
125 base_obj.insert(k.clone(), v.clone());
126 }
127 }
128 }
129 base
130 } else {
131 source_cp.state_snapshot.clone()
132 };
133
134 let id = Uuid::now_v7();
135 let new_cp = Checkpoint {
136 id,
137 thread_id: request.thread_id.clone(),
138 agent_id: agent_id.clone(),
139 parent_id: target_parent_id,
140 branch_name: target_branch.clone(),
141 state_snapshot: merged_snapshot,
142 state_diff: Some(serde_json::json!({
143 "merge_source": request.source_branch,
144 "strategy": format!("{strategy:?}"),
145 })),
146 memory_refs: merged_refs,
147 event_cursor: source_cp.event_cursor,
148 label: Some(format!("merge from {}", request.source_branch)),
149 created_at: now.clone(),
150 metadata: serde_json::json!({
151 "source_branch": request.source_branch,
152 "source_checkpoint": source_cp.id.to_string(),
153 }),
154 };
155
156 engine.storage.insert_checkpoint(&new_cp).await?;
157
158 let event = super::event_builder::build_event(
160 engine,
161 &agent_id,
162 EventType::Merge,
163 serde_json::json!({
164 "checkpoint_id": id.to_string(),
165 "source_branch": request.source_branch,
166 "target_branch": target_branch,
167 "strategy": format!("{strategy:?}"),
168 }),
169 &id.to_string(),
170 Some(request.thread_id),
171 )
172 .await;
173 if let Err(e) = engine.storage.insert_event(&event).await {
174 tracing::error!(event_id = %event.id, error = %e, "failed to insert audit event");
175 }
176
177 Ok(MergeResponse {
178 checkpoint_id: id,
179 target_branch,
180 merged_memory_count: merged_count,
181 })
182}