mnemo_core/query/
branch.rs1use serde::{Deserialize, Serialize};
2use uuid::Uuid;
3
4use crate::error::{Error, Result};
5use crate::model::checkpoint::Checkpoint;
6use crate::model::event::EventType;
7use crate::query::MnemoEngine;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct BranchRequest {
11 pub thread_id: String,
12 pub agent_id: Option<String>,
13 pub new_branch_name: String,
14 pub source_checkpoint_id: Option<Uuid>,
15 pub source_branch: Option<String>,
16}
17
18impl BranchRequest {
19 pub fn new(thread_id: String, new_branch_name: String) -> Self {
20 Self {
21 thread_id,
22 agent_id: None,
23 new_branch_name,
24 source_checkpoint_id: None,
25 source_branch: None,
26 }
27 }
28}
29
30#[non_exhaustive]
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct BranchResponse {
33 pub checkpoint_id: Uuid,
34 pub branch_name: String,
35 pub source_checkpoint_id: Uuid,
36}
37
38impl BranchResponse {
39 pub fn new(checkpoint_id: Uuid, branch_name: String, source_checkpoint_id: Uuid) -> Self {
40 Self {
41 checkpoint_id,
42 branch_name,
43 source_checkpoint_id,
44 }
45 }
46}
47
48pub async fn execute(engine: &MnemoEngine, request: BranchRequest) -> Result<BranchResponse> {
49 let agent_id = request
50 .agent_id
51 .unwrap_or_else(|| engine.default_agent_id.clone());
52 let now = chrono::Utc::now().to_rfc3339();
53
54 let source_cp = if let Some(cp_id) = request.source_checkpoint_id {
56 engine
57 .storage
58 .get_checkpoint(cp_id)
59 .await?
60 .ok_or_else(|| Error::NotFound(format!("checkpoint {cp_id} not found")))?
61 } else {
62 let source_branch = request.source_branch.as_deref().unwrap_or("main");
63 engine
64 .storage
65 .get_latest_checkpoint(&request.thread_id, source_branch)
66 .await?
67 .ok_or_else(|| {
68 Error::NotFound(format!(
69 "no checkpoint found on branch '{source_branch}' for thread '{}'",
70 request.thread_id
71 ))
72 })?
73 };
74
75 let id = Uuid::now_v7();
77 let new_cp = Checkpoint {
78 id,
79 thread_id: request.thread_id.clone(),
80 agent_id: agent_id.clone(),
81 parent_id: Some(source_cp.id),
82 branch_name: request.new_branch_name.clone(),
83 state_snapshot: source_cp.state_snapshot.clone(),
84 state_diff: None,
85 memory_refs: source_cp.memory_refs.clone(),
86 event_cursor: source_cp.event_cursor,
87 label: Some(format!("branch from {}", source_cp.id)),
88 created_at: now.clone(),
89 metadata: serde_json::json!({"branched_from": source_cp.id.to_string()}),
90 };
91
92 engine.storage.insert_checkpoint(&new_cp).await?;
93
94 let event = super::event_builder::build_event(
96 engine,
97 &agent_id,
98 EventType::Branch,
99 serde_json::json!({
100 "checkpoint_id": id.to_string(),
101 "new_branch": request.new_branch_name,
102 "source_checkpoint": source_cp.id.to_string(),
103 }),
104 &id.to_string(),
105 Some(request.thread_id),
106 )
107 .await;
108 if let Err(e) = engine.storage.insert_event(&event).await {
109 tracing::error!(event_id = %event.id, error = %e, "failed to insert audit event");
110 }
111
112 Ok(BranchResponse {
113 checkpoint_id: id,
114 branch_name: request.new_branch_name,
115 source_checkpoint_id: source_cp.id,
116 })
117}