Skip to main content

mnemo_core/query/
branch.rs

1use serde::{Deserialize, Serialize};
2use uuid::Uuid;
3
4use crate::error::{Error, Result};
5use crate::model::checkpoint::Checkpoint;
6use crate::model::event::EventType;
7use crate::query::MnemoEngine;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct BranchRequest {
11    pub thread_id: String,
12    pub agent_id: Option<String>,
13    pub new_branch_name: String,
14    pub source_checkpoint_id: Option<Uuid>,
15    pub source_branch: Option<String>,
16}
17
18impl BranchRequest {
19    pub fn new(thread_id: String, new_branch_name: String) -> Self {
20        Self {
21            thread_id,
22            agent_id: None,
23            new_branch_name,
24            source_checkpoint_id: None,
25            source_branch: None,
26        }
27    }
28}
29
30#[non_exhaustive]
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct BranchResponse {
33    pub checkpoint_id: Uuid,
34    pub branch_name: String,
35    pub source_checkpoint_id: Uuid,
36}
37
38impl BranchResponse {
39    pub fn new(checkpoint_id: Uuid, branch_name: String, source_checkpoint_id: Uuid) -> Self {
40        Self {
41            checkpoint_id,
42            branch_name,
43            source_checkpoint_id,
44        }
45    }
46}
47
48pub async fn execute(engine: &MnemoEngine, request: BranchRequest) -> Result<BranchResponse> {
49    let agent_id = request
50        .agent_id
51        .unwrap_or_else(|| engine.default_agent_id.clone());
52    let now = chrono::Utc::now().to_rfc3339();
53
54    // Find source checkpoint
55    let source_cp = if let Some(cp_id) = request.source_checkpoint_id {
56        engine
57            .storage
58            .get_checkpoint(cp_id)
59            .await?
60            .ok_or_else(|| Error::NotFound(format!("checkpoint {cp_id} not found")))?
61    } else {
62        let source_branch = request.source_branch.as_deref().unwrap_or("main");
63        engine
64            .storage
65            .get_latest_checkpoint(&request.thread_id, source_branch)
66            .await?
67            .ok_or_else(|| {
68                Error::NotFound(format!(
69                    "no checkpoint found on branch '{source_branch}' for thread '{}'",
70                    request.thread_id
71                ))
72            })?
73    };
74
75    // Create new checkpoint on the new branch with parent = source
76    let id = Uuid::now_v7();
77    let new_cp = Checkpoint {
78        id,
79        thread_id: request.thread_id.clone(),
80        agent_id: agent_id.clone(),
81        parent_id: Some(source_cp.id),
82        branch_name: request.new_branch_name.clone(),
83        state_snapshot: source_cp.state_snapshot.clone(),
84        state_diff: None,
85        memory_refs: source_cp.memory_refs.clone(),
86        event_cursor: source_cp.event_cursor,
87        label: Some(format!("branch from {}", source_cp.id)),
88        created_at: now.clone(),
89        metadata: serde_json::json!({"branched_from": source_cp.id.to_string()}),
90    };
91
92    engine.storage.insert_checkpoint(&new_cp).await?;
93
94    // Emit Branch event
95    let event = super::event_builder::build_event(
96        engine,
97        &agent_id,
98        EventType::Branch,
99        serde_json::json!({
100            "checkpoint_id": id.to_string(),
101            "new_branch": request.new_branch_name,
102            "source_checkpoint": source_cp.id.to_string(),
103        }),
104        &id.to_string(),
105        Some(request.thread_id),
106    )
107    .await;
108    if let Err(e) = engine.storage.insert_event(&event).await {
109        tracing::error!(event_id = %event.id, error = %e, "failed to insert audit event");
110    }
111
112    Ok(BranchResponse {
113        checkpoint_id: id,
114        branch_name: request.new_branch_name,
115        source_checkpoint_id: source_cp.id,
116    })
117}