Skip to main content

mnemo_core/query/
merge.rs

1use serde::{Deserialize, Serialize};
2use uuid::Uuid;
3
4use crate::error::{Error, Result};
5use crate::model::checkpoint::Checkpoint;
6use crate::model::event::EventType;
7use crate::query::MnemoEngine;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "snake_case")]
11pub enum MergeStrategy {
12    FullMerge,
13    CherryPick,
14    Squash,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct MergeRequest {
19    pub thread_id: String,
20    pub agent_id: Option<String>,
21    pub source_branch: String,
22    pub target_branch: Option<String>,
23    pub strategy: Option<MergeStrategy>,
24    pub cherry_pick_ids: Option<Vec<Uuid>>,
25}
26
27impl MergeRequest {
28    pub fn new(thread_id: String, source_branch: String) -> Self {
29        Self {
30            thread_id,
31            agent_id: None,
32            source_branch,
33            target_branch: None,
34            strategy: None,
35            cherry_pick_ids: None,
36        }
37    }
38}
39
40#[non_exhaustive]
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct MergeResponse {
43    pub checkpoint_id: Uuid,
44    pub target_branch: String,
45    pub merged_memory_count: usize,
46}
47
48impl MergeResponse {
49    pub fn new(checkpoint_id: Uuid, target_branch: String, merged_memory_count: usize) -> Self {
50        Self {
51            checkpoint_id,
52            target_branch,
53            merged_memory_count,
54        }
55    }
56}
57
58pub async fn execute(engine: &MnemoEngine, request: MergeRequest) -> Result<MergeResponse> {
59    let agent_id = request
60        .agent_id
61        .unwrap_or_else(|| engine.default_agent_id.clone());
62    let target_branch = request.target_branch.unwrap_or_else(|| "main".to_string());
63    let strategy = request.strategy.unwrap_or(MergeStrategy::FullMerge);
64    let now = chrono::Utc::now().to_rfc3339();
65
66    // Get latest checkpoint on source branch
67    let source_cp = engine
68        .storage
69        .get_latest_checkpoint(&request.thread_id, &request.source_branch)
70        .await?
71        .ok_or_else(|| {
72            Error::NotFound(format!(
73                "no checkpoint on branch '{}' for thread '{}'",
74                request.source_branch, request.thread_id
75            ))
76        })?;
77
78    // Get latest checkpoint on target branch (may not exist yet)
79    let target_cp = engine
80        .storage
81        .get_latest_checkpoint(&request.thread_id, &target_branch)
82        .await?;
83
84    let target_parent_id = target_cp.as_ref().map(|cp| cp.id);
85
86    // Determine merged memory_refs based on strategy
87    let merged_refs: Vec<Uuid> = match strategy {
88        MergeStrategy::CherryPick => {
89            let cherry = request.cherry_pick_ids.unwrap_or_default();
90            let mut existing = target_cp
91                .as_ref()
92                .map(|cp| cp.memory_refs.clone())
93                .unwrap_or_default();
94            for id in &cherry {
95                if !existing.contains(id) {
96                    existing.push(*id);
97                }
98            }
99            existing
100        }
101        MergeStrategy::FullMerge | MergeStrategy::Squash => {
102            let mut merged = target_cp
103                .as_ref()
104                .map(|cp| cp.memory_refs.clone())
105                .unwrap_or_default();
106            for id in &source_cp.memory_refs {
107                if !merged.contains(id) {
108                    merged.push(*id);
109                }
110            }
111            merged
112        }
113    };
114
115    let merged_count = merged_refs.len();
116
117    // Merge state snapshots (target takes precedence, source fields added)
118    let merged_snapshot = if let Some(ref tcp) = target_cp {
119        let mut base = tcp.state_snapshot.clone();
120        if let (Some(base_obj), Some(source_obj)) =
121            (base.as_object_mut(), source_cp.state_snapshot.as_object())
122        {
123            for (k, v) in source_obj {
124                if !base_obj.contains_key(k) {
125                    base_obj.insert(k.clone(), v.clone());
126                }
127            }
128        }
129        base
130    } else {
131        source_cp.state_snapshot.clone()
132    };
133
134    let id = Uuid::now_v7();
135    let new_cp = Checkpoint {
136        id,
137        thread_id: request.thread_id.clone(),
138        agent_id: agent_id.clone(),
139        parent_id: target_parent_id,
140        branch_name: target_branch.clone(),
141        state_snapshot: merged_snapshot,
142        state_diff: Some(serde_json::json!({
143            "merge_source": request.source_branch,
144            "strategy": format!("{strategy:?}"),
145        })),
146        memory_refs: merged_refs,
147        event_cursor: source_cp.event_cursor,
148        label: Some(format!("merge from {}", request.source_branch)),
149        created_at: now.clone(),
150        metadata: serde_json::json!({
151            "source_branch": request.source_branch,
152            "source_checkpoint": source_cp.id.to_string(),
153        }),
154    };
155
156    engine.storage.insert_checkpoint(&new_cp).await?;
157
158    // Emit Merge event
159    let event = super::event_builder::build_event(
160        engine,
161        &agent_id,
162        EventType::Merge,
163        serde_json::json!({
164            "checkpoint_id": id.to_string(),
165            "source_branch": request.source_branch,
166            "target_branch": target_branch,
167            "strategy": format!("{strategy:?}"),
168        }),
169        &id.to_string(),
170        Some(request.thread_id),
171    )
172    .await;
173    if let Err(e) = engine.storage.insert_event(&event).await {
174        tracing::error!(event_id = %event.id, error = %e, "failed to insert audit event");
175    }
176
177    Ok(MergeResponse {
178        checkpoint_id: id,
179        target_branch,
180        merged_memory_count: merged_count,
181    })
182}