Skip to main content

do_memory_mcp/mcp/tools/checkpoint/
tool.rs

1//! Checkpoint tool implementations.
2
3use super::types::{
4    CheckpointEpisodeInput, CheckpointEpisodeOutput, GetHandoffPackInput, GetHandoffPackOutput,
5    HandoffPackResponse, ResumeFromHandoffInput, ResumeFromHandoffOutput,
6};
7use crate::types::Tool;
8use anyhow::{Result, anyhow};
9use do_memory_core::SelfLearningMemory;
10use do_memory_core::memory::checkpoint::{
11    checkpoint_episode, checkpoint_episode_with_note, get_handoff_pack, resume_from_handoff,
12};
13use serde_json::json;
14use std::sync::Arc;
15use tracing::{info, instrument};
16use uuid::Uuid;
17
18/// Checkpoint tools for episode handoffs
19#[derive(Clone)]
20pub struct CheckpointTools {
21    memory: Arc<SelfLearningMemory>,
22}
23
24impl CheckpointTools {
25    /// Create a new checkpoint tools instance
26    pub fn new(memory: Arc<SelfLearningMemory>) -> Self {
27        Self { memory }
28    }
29
30    /// Get the tool definition for checkpoint_episode
31    pub fn checkpoint_episode_tool() -> Tool {
32        Tool::new(
33            "checkpoint_episode".to_string(),
34            "Create a checkpoint for an in-progress episode. Use this when switching agents, pausing long-running tasks, or before risky operations.".to_string(),
35            json!({
36                "type": "object",
37                "properties": {
38                    "episode_id": {
39                        "type": "string",
40                        "description": "Episode ID to checkpoint (UUID format)"
41                    },
42                    "reason": {
43                        "type": "string",
44                        "description": "Why the checkpoint is being created (e.g., 'Agent switch', 'Long-running task pause')"
45                    },
46                    "note": {
47                        "type": "string",
48                        "description": "Optional additional context about the checkpoint"
49                    }
50                },
51                "required": ["episode_id", "reason"]
52            }),
53        )
54    }
55
56    /// Get the tool definition for get_handoff_pack
57    pub fn get_handoff_pack_tool() -> Tool {
58        Tool::new(
59            "get_handoff_pack".to_string(),
60            "Generate a handoff pack from a checkpoint. Contains lessons learned, relevant patterns, and suggested next steps for transferring work to another agent.".to_string(),
61            json!({
62                "type": "object",
63                "properties": {
64                    "checkpoint_id": {
65                        "type": "string",
66                        "description": "Checkpoint ID to generate handoff pack from (UUID format)"
67                    }
68                },
69                "required": ["checkpoint_id"]
70            }),
71        )
72    }
73
74    /// Get the tool definition for resume_from_handoff
75    pub fn resume_from_handoff_tool() -> Tool {
76        Tool::new(
77            "resume_from_handoff".to_string(),
78            "Resume work from a handoff pack. Creates a new episode initialized with context from a previous checkpoint for seamless task continuation.".to_string(),
79            json!({
80                "type": "object",
81                "properties": {
82                    "handoff_pack": {
83                        "type": "object",
84                        "description": "The handoff pack to resume from (obtained from get_handoff_pack)"
85                    }
86                },
87                "required": ["handoff_pack"]
88            }),
89        )
90    }
91
92    /// Create a checkpoint for an episode
93    ///
94    /// # Arguments
95    ///
96    /// * `input` - Input containing episode ID and reason
97    ///
98    /// # Returns
99    ///
100    /// Returns the checkpoint ID and step number.
101    ///
102    /// # Errors
103    ///
104    /// Returns an error if:
105    /// - Episode ID is invalid (not a UUID)
106    /// - Episode does not exist
107    /// - Episode is already completed
108    #[instrument(skip(self, input), fields(episode_id = %input.episode_id))]
109    pub async fn checkpoint_episode(
110        &self,
111        input: CheckpointEpisodeInput,
112    ) -> Result<CheckpointEpisodeOutput> {
113        info!(
114            "Creating checkpoint for episode: {} (reason: {})",
115            input.episode_id, input.reason
116        );
117
118        // Parse episode ID
119        let episode_id =
120            Uuid::parse_str(&input.episode_id).map_err(|e| anyhow!("Invalid episode ID: {}", e))?;
121
122        // Create checkpoint
123        let checkpoint = if let Some(note) = &input.note {
124            checkpoint_episode_with_note(
125                &self.memory,
126                episode_id,
127                input.reason.clone(),
128                Some(note.clone()),
129            )
130            .await
131        } else {
132            checkpoint_episode(&self.memory, episode_id, input.reason.clone()).await
133        };
134
135        match checkpoint {
136            Ok(checkpoint) => {
137                info!(
138                    "Created checkpoint {} for episode {} at step {}",
139                    checkpoint.checkpoint_id, episode_id, checkpoint.step_number
140                );
141
142                Ok(CheckpointEpisodeOutput {
143                    success: true,
144                    checkpoint_id: checkpoint.checkpoint_id.to_string(),
145                    episode_id: input.episode_id,
146                    step_number: checkpoint.step_number,
147                    message: format!(
148                        "Created checkpoint at step {} with reason: {}",
149                        checkpoint.step_number, input.reason
150                    ),
151                })
152            }
153            Err(e) => {
154                info!("Failed to create checkpoint: {}", e);
155                Ok(CheckpointEpisodeOutput {
156                    success: false,
157                    checkpoint_id: String::new(),
158                    episode_id: input.episode_id,
159                    step_number: 0,
160                    message: format!("Failed to create checkpoint: {}", e),
161                })
162            }
163        }
164    }
165
166    /// Get a handoff pack from a checkpoint
167    ///
168    /// # Arguments
169    ///
170    /// * `input` - Input containing checkpoint ID
171    ///
172    /// # Returns
173    ///
174    /// Returns the handoff pack with lessons learned and guidance.
175    ///
176    /// # Errors
177    ///
178    /// Returns an error if:
179    /// - Checkpoint ID is invalid (not a UUID)
180    /// - Checkpoint does not exist
181    #[instrument(skip(self, input), fields(checkpoint_id = %input.checkpoint_id))]
182    pub async fn get_handoff_pack(
183        &self,
184        input: GetHandoffPackInput,
185    ) -> Result<GetHandoffPackOutput> {
186        info!(
187            "Getting handoff pack for checkpoint: {}",
188            input.checkpoint_id
189        );
190
191        // Parse checkpoint ID
192        let checkpoint_id = Uuid::parse_str(&input.checkpoint_id)
193            .map_err(|e| anyhow!("Invalid checkpoint ID: {}", e))?;
194
195        // Get handoff pack
196        match get_handoff_pack(&self.memory, checkpoint_id).await {
197            Ok(handoff) => {
198                info!(
199                    "Generated handoff pack with {} steps, {} patterns, {} heuristics",
200                    handoff.step_count(),
201                    handoff.relevant_patterns.len(),
202                    handoff.relevant_heuristics.len()
203                );
204
205                Ok(GetHandoffPackOutput {
206                    success: true,
207                    handoff_pack: Some(HandoffPackResponse::from(handoff)),
208                    message: "Successfully generated handoff pack".to_string(),
209                })
210            }
211            Err(e) => {
212                info!("Failed to get handoff pack: {}", e);
213                Ok(GetHandoffPackOutput {
214                    success: false,
215                    handoff_pack: None,
216                    message: format!("Failed to get handoff pack: {}", e),
217                })
218            }
219        }
220    }
221
222    /// Resume work from a handoff pack
223    ///
224    /// # Arguments
225    ///
226    /// * `input` - Input containing the handoff pack
227    ///
228    /// # Returns
229    ///
230    /// Returns the new episode ID for resumption.
231    #[instrument(skip(self, input))]
232    pub async fn resume_from_handoff(
233        &self,
234        input: ResumeFromHandoffInput,
235    ) -> Result<ResumeFromHandoffOutput> {
236        info!(
237            "Resuming from handoff pack: checkpoint_id={}",
238            input.handoff_pack.checkpoint_id
239        );
240
241        let checkpoint_id = input.handoff_pack.checkpoint_id;
242        let episode_id = input.handoff_pack.episode_id;
243
244        // Resume from handoff
245        match resume_from_handoff(&self.memory, input.handoff_pack).await {
246            Ok(new_episode_id) => {
247                info!("Created new episode {} for resumption", new_episode_id);
248
249                Ok(ResumeFromHandoffOutput {
250                    success: true,
251                    new_episode_id: Some(new_episode_id.to_string()),
252                    checkpoint_id: checkpoint_id.to_string(),
253                    original_episode_id: episode_id.to_string(),
254                    message: format!(
255                        "Successfully resumed work in new episode {}",
256                        new_episode_id
257                    ),
258                })
259            }
260            Err(e) => {
261                info!("Failed to resume from handoff: {}", e);
262                Ok(ResumeFromHandoffOutput {
263                    success: false,
264                    new_episode_id: None,
265                    checkpoint_id: checkpoint_id.to_string(),
266                    original_episode_id: episode_id.to_string(),
267                    message: format!("Failed to resume from handoff: {}", e),
268                })
269            }
270        }
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Build checkpoint tools over a fresh, empty memory instance.
    fn tools() -> CheckpointTools {
        CheckpointTools::new(Arc::new(SelfLearningMemory::new()))
    }

    #[tokio::test]
    async fn test_checkpoint_episode_invalid_uuid() {
        let input = CheckpointEpisodeInput {
            episode_id: "not-a-uuid".to_string(),
            reason: "test".to_string(),
            note: None,
        };

        // `.err()` avoids requiring `Debug` on the success type.
        let err = tools()
            .checkpoint_episode(input)
            .await
            .err()
            .expect("a non-UUID episode ID must be rejected with Err");
        assert!(
            err.to_string().starts_with("Invalid episode ID"),
            "unexpected error: {err}"
        );
    }

    #[tokio::test]
    async fn test_checkpoint_episode_unknown_episode_reports_failure_in_band() {
        let missing = Uuid::nil().to_string();
        let input = CheckpointEpisodeInput {
            episode_id: missing.clone(),
            reason: "test".to_string(),
            note: None,
        };

        // A well-formed ID must not produce Err; core failures come back as success: false.
        let output = tools()
            .checkpoint_episode(input)
            .await
            .expect("well-formed UUID should not yield Err");
        assert!(!output.success);
        assert!(output.checkpoint_id.is_empty());
        assert_eq!(output.episode_id, missing);
    }

    #[tokio::test]
    async fn test_get_handoff_pack_invalid_uuid() {
        let input = GetHandoffPackInput {
            checkpoint_id: "not-a-uuid".to_string(),
        };

        let err = tools()
            .get_handoff_pack(input)
            .await
            .err()
            .expect("a non-UUID checkpoint ID must be rejected with Err");
        assert!(
            err.to_string().starts_with("Invalid checkpoint ID"),
            "unexpected error: {err}"
        );
    }

    #[tokio::test]
    async fn test_get_handoff_pack_unknown_checkpoint_reports_failure_in_band() {
        let input = GetHandoffPackInput {
            checkpoint_id: Uuid::nil().to_string(),
        };

        let output = tools()
            .get_handoff_pack(input)
            .await
            .expect("well-formed UUID should not yield Err");
        assert!(!output.success);
        assert!(output.handoff_pack.is_none());
    }
}
305}