do_memory_mcp/mcp/tools/checkpoint/
tool.rs1use super::types::{
4 CheckpointEpisodeInput, CheckpointEpisodeOutput, GetHandoffPackInput, GetHandoffPackOutput,
5 HandoffPackResponse, ResumeFromHandoffInput, ResumeFromHandoffOutput,
6};
7use crate::types::Tool;
8use anyhow::{Result, anyhow};
9use do_memory_core::SelfLearningMemory;
10use do_memory_core::memory::checkpoint::{
11 checkpoint_episode, checkpoint_episode_with_note, get_handoff_pack, resume_from_handoff,
12};
13use serde_json::json;
14use std::sync::Arc;
15use tracing::{info, instrument};
16use uuid::Uuid;
17
18#[derive(Clone)]
20pub struct CheckpointTools {
21 memory: Arc<SelfLearningMemory>,
22}
23
24impl CheckpointTools {
25 pub fn new(memory: Arc<SelfLearningMemory>) -> Self {
27 Self { memory }
28 }
29
30 pub fn checkpoint_episode_tool() -> Tool {
32 Tool::new(
33 "checkpoint_episode".to_string(),
34 "Create a checkpoint for an in-progress episode. Use this when switching agents, pausing long-running tasks, or before risky operations.".to_string(),
35 json!({
36 "type": "object",
37 "properties": {
38 "episode_id": {
39 "type": "string",
40 "description": "Episode ID to checkpoint (UUID format)"
41 },
42 "reason": {
43 "type": "string",
44 "description": "Why the checkpoint is being created (e.g., 'Agent switch', 'Long-running task pause')"
45 },
46 "note": {
47 "type": "string",
48 "description": "Optional additional context about the checkpoint"
49 }
50 },
51 "required": ["episode_id", "reason"]
52 }),
53 )
54 }
55
56 pub fn get_handoff_pack_tool() -> Tool {
58 Tool::new(
59 "get_handoff_pack".to_string(),
60 "Generate a handoff pack from a checkpoint. Contains lessons learned, relevant patterns, and suggested next steps for transferring work to another agent.".to_string(),
61 json!({
62 "type": "object",
63 "properties": {
64 "checkpoint_id": {
65 "type": "string",
66 "description": "Checkpoint ID to generate handoff pack from (UUID format)"
67 }
68 },
69 "required": ["checkpoint_id"]
70 }),
71 )
72 }
73
74 pub fn resume_from_handoff_tool() -> Tool {
76 Tool::new(
77 "resume_from_handoff".to_string(),
78 "Resume work from a handoff pack. Creates a new episode initialized with context from a previous checkpoint for seamless task continuation.".to_string(),
79 json!({
80 "type": "object",
81 "properties": {
82 "handoff_pack": {
83 "type": "object",
84 "description": "The handoff pack to resume from (obtained from get_handoff_pack)"
85 }
86 },
87 "required": ["handoff_pack"]
88 }),
89 )
90 }
91
92 #[instrument(skip(self, input), fields(episode_id = %input.episode_id))]
109 pub async fn checkpoint_episode(
110 &self,
111 input: CheckpointEpisodeInput,
112 ) -> Result<CheckpointEpisodeOutput> {
113 info!(
114 "Creating checkpoint for episode: {} (reason: {})",
115 input.episode_id, input.reason
116 );
117
118 let episode_id =
120 Uuid::parse_str(&input.episode_id).map_err(|e| anyhow!("Invalid episode ID: {}", e))?;
121
122 let checkpoint = if let Some(note) = &input.note {
124 checkpoint_episode_with_note(
125 &self.memory,
126 episode_id,
127 input.reason.clone(),
128 Some(note.clone()),
129 )
130 .await
131 } else {
132 checkpoint_episode(&self.memory, episode_id, input.reason.clone()).await
133 };
134
135 match checkpoint {
136 Ok(checkpoint) => {
137 info!(
138 "Created checkpoint {} for episode {} at step {}",
139 checkpoint.checkpoint_id, episode_id, checkpoint.step_number
140 );
141
142 Ok(CheckpointEpisodeOutput {
143 success: true,
144 checkpoint_id: checkpoint.checkpoint_id.to_string(),
145 episode_id: input.episode_id,
146 step_number: checkpoint.step_number,
147 message: format!(
148 "Created checkpoint at step {} with reason: {}",
149 checkpoint.step_number, input.reason
150 ),
151 })
152 }
153 Err(e) => {
154 info!("Failed to create checkpoint: {}", e);
155 Ok(CheckpointEpisodeOutput {
156 success: false,
157 checkpoint_id: String::new(),
158 episode_id: input.episode_id,
159 step_number: 0,
160 message: format!("Failed to create checkpoint: {}", e),
161 })
162 }
163 }
164 }
165
166 #[instrument(skip(self, input), fields(checkpoint_id = %input.checkpoint_id))]
182 pub async fn get_handoff_pack(
183 &self,
184 input: GetHandoffPackInput,
185 ) -> Result<GetHandoffPackOutput> {
186 info!(
187 "Getting handoff pack for checkpoint: {}",
188 input.checkpoint_id
189 );
190
191 let checkpoint_id = Uuid::parse_str(&input.checkpoint_id)
193 .map_err(|e| anyhow!("Invalid checkpoint ID: {}", e))?;
194
195 match get_handoff_pack(&self.memory, checkpoint_id).await {
197 Ok(handoff) => {
198 info!(
199 "Generated handoff pack with {} steps, {} patterns, {} heuristics",
200 handoff.step_count(),
201 handoff.relevant_patterns.len(),
202 handoff.relevant_heuristics.len()
203 );
204
205 Ok(GetHandoffPackOutput {
206 success: true,
207 handoff_pack: Some(HandoffPackResponse::from(handoff)),
208 message: "Successfully generated handoff pack".to_string(),
209 })
210 }
211 Err(e) => {
212 info!("Failed to get handoff pack: {}", e);
213 Ok(GetHandoffPackOutput {
214 success: false,
215 handoff_pack: None,
216 message: format!("Failed to get handoff pack: {}", e),
217 })
218 }
219 }
220 }
221
222 #[instrument(skip(self, input))]
232 pub async fn resume_from_handoff(
233 &self,
234 input: ResumeFromHandoffInput,
235 ) -> Result<ResumeFromHandoffOutput> {
236 info!(
237 "Resuming from handoff pack: checkpoint_id={}",
238 input.handoff_pack.checkpoint_id
239 );
240
241 let checkpoint_id = input.handoff_pack.checkpoint_id;
242 let episode_id = input.handoff_pack.episode_id;
243
244 match resume_from_handoff(&self.memory, input.handoff_pack).await {
246 Ok(new_episode_id) => {
247 info!("Created new episode {} for resumption", new_episode_id);
248
249 Ok(ResumeFromHandoffOutput {
250 success: true,
251 new_episode_id: Some(new_episode_id.to_string()),
252 checkpoint_id: checkpoint_id.to_string(),
253 original_episode_id: episode_id.to_string(),
254 message: format!(
255 "Successfully resumed work in new episode {}",
256 new_episode_id
257 ),
258 })
259 }
260 Err(e) => {
261 info!("Failed to resume from handoff: {}", e);
262 Ok(ResumeFromHandoffOutput {
263 success: false,
264 new_episode_id: None,
265 checkpoint_id: checkpoint_id.to_string(),
266 original_episode_id: episode_id.to_string(),
267 message: format!("Failed to resume from handoff: {}", e),
268 })
269 }
270 }
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `CheckpointTools` backed by a fresh `SelfLearningMemory::new()`.
    fn make_tools() -> CheckpointTools {
        CheckpointTools::new(Arc::new(SelfLearningMemory::new()))
    }

    /// A malformed episode ID must surface as `Err`, not as a `success: false` output.
    #[tokio::test]
    async fn test_checkpoint_episode_invalid_uuid() {
        let outcome = make_tools()
            .checkpoint_episode(CheckpointEpisodeInput {
                episode_id: "not-a-uuid".to_string(),
                reason: "test".to_string(),
                note: None,
            })
            .await;

        assert!(outcome.is_err());
    }

    /// A malformed checkpoint ID must surface as `Err`, not as a `success: false` output.
    #[tokio::test]
    async fn test_get_handoff_pack_invalid_uuid() {
        let outcome = make_tools()
            .get_handoff_pack(GetHandoffPackInput {
                checkpoint_id: "not-a-uuid".to_string(),
            })
            .await;

        assert!(outcome.is_err());
    }
}