Skip to main content

do_memory_mcp/server/tools/
episode_update.rs

//! Episode update MCP tool
//!
//! Provides the `update_episode` tool for modifying episodes through the MCP server.
4
5use crate::server::MemoryMCPServer;
6use anyhow::{Result, anyhow};
7use serde_json::{Value, json};
8
9use tracing::debug;
10use tracing::info;
11use uuid::Uuid;
12
13impl MemoryMCPServer {
14    /// Update an existing episode with new information
15    ///
16    /// This tool allows AI agents to programmatically update episodes
17    /// to modify descriptions, tags, and metadata.
18    ///
19    /// # Arguments (from JSON)
20    ///
21    /// * `episode_id` - UUID of the episode to update
22    /// * `description` - Optional new task description
23    /// * `add_tags` - Optional tags to add to the episode
24    /// * `remove_tags` - Optional tags to remove from the episode
25    /// * `set_tags` - Optional tags to replace all existing tags
26    /// * `metadata` - Optional metadata key-value pairs to merge
27    pub async fn update_episode_tool(&self, args: Value) -> Result<Value> {
28        debug!("Updating episode with args: {}", args);
29
30        // Extract episode_id
31        let episode_id_str = args
32            .get("episode_id")
33            .and_then(|v| v.as_str())
34            .ok_or_else(|| anyhow!("Missing required field: episode_id"))?
35            .to_string();
36
37        let uuid = Uuid::parse_str(&episode_id_str)
38            .map_err(|_| anyhow!("Invalid episode ID format: {}", episode_id_str))?;
39
40        // Extract optional fields
41        let description = args
42            .get("description")
43            .and_then(|v| v.as_str())
44            .map(|s| s.to_string());
45
46        let add_tags = args.get("add_tags").and_then(|v| v.as_array()).map(|arr| {
47            arr.iter()
48                .filter_map(|v| v.as_str().map(|s| s.to_string()))
49                .collect()
50        });
51
52        let remove_tags = args
53            .get("remove_tags")
54            .and_then(|v| v.as_array())
55            .map(|arr| {
56                arr.iter()
57                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
58                    .collect()
59            });
60
61        let set_tags = args.get("set_tags").and_then(|v| v.as_array()).map(|arr| {
62            arr.iter()
63                .filter_map(|v| v.as_str().map(|s| s.to_string()))
64                .collect()
65        });
66
67        let metadata = args.get("metadata").and_then(|v| v.as_object()).map(|obj| {
68            obj.iter()
69                .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
70                .collect()
71        });
72
73        // Track updated fields
74        let mut updated_fields = Vec::new();
75
76        // Update description if provided
77        if let Some(desc) = description {
78            self.memory
79                .update_episode(uuid, Some(desc.clone()), None)
80                .await
81                .map_err(|e| anyhow!("Failed to update description: {}", e))?;
82            updated_fields.push("description");
83        }
84
85        // Update metadata if provided
86        if let Some(meta) = metadata {
87            self.memory
88                .update_episode(uuid, None, Some(meta))
89                .await
90                .map_err(|e| anyhow!("Failed to update metadata: {}", e))?;
91            updated_fields.push("metadata");
92        }
93
94        // Add tags
95        if let Some(tags) = add_tags {
96            self.memory
97                .add_episode_tags(uuid, tags)
98                .await
99                .map_err(|e| anyhow!("Failed to add tags: {}", e))?;
100            updated_fields.push("tags (added)");
101        }
102
103        // Remove tags
104        if let Some(tags) = remove_tags {
105            self.memory
106                .remove_episode_tags(uuid, tags)
107                .await
108                .map_err(|e| anyhow!("Failed to remove tags: {}", e))?;
109            updated_fields.push("tags (removed)");
110        }
111
112        // Set tags (replace all)
113        if let Some(tags) = set_tags {
114            self.memory
115                .set_episode_tags(uuid, tags)
116                .await
117                .map_err(|e| anyhow!("Failed to set tags: {}", e))?;
118            updated_fields.push("tags (set)");
119        }
120
121        if updated_fields.is_empty() {
122            return Ok(json!({
123                "success": true,
124                "episode_id": episode_id_str,
125                "message": "No changes specified. Episode unchanged.",
126                "updated_fields": []
127            }));
128        }
129
130        info!(
131            episode_id = %episode_id_str,
132            fields = %updated_fields.join(", "),
133            "Updated episode via MCP"
134        );
135
136        Ok(json!({
137            "success": true,
138            "episode_id": episode_id_str,
139            "message": format!("Successfully updated episode {}. Updated fields: {}", episode_id_str, updated_fields.join(", ")),
140            "updated_fields": updated_fields
141        }))
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SandboxConfig;
    use do_memory_core::{TaskContext, TaskType};
    use std::sync::Arc;

    /// Build a server around a freshly constructed `SelfLearningMemory`.
    async fn new_server() -> MemoryMCPServer {
        let memory = do_memory_core::SelfLearningMemory::new();
        MemoryMCPServer::new(SandboxConfig::default(), Arc::new(memory))
            .await
            .expect("server should start with the default sandbox config")
    }

    #[tokio::test]
    async fn test_update_episode_description() {
        let server = new_server().await;

        let episode_id = server
            .memory()
            .start_episode(
                "Original description".to_string(),
                TaskContext::default(),
                TaskType::Testing,
            )
            .await;

        let result = server
            .update_episode_tool(json!({
                "episode_id": episode_id.to_string(),
                "description": "Updated description"
            }))
            .await
            .expect("description update should succeed");

        assert!(result["success"].as_bool().unwrap());
        assert!(result["message"]
            .as_str()
            .unwrap()
            .contains("Successfully updated"));
        assert_eq!(result["updated_fields"], json!(["description"]));

        let episode = server.memory().get_episode(episode_id).await.unwrap();
        assert_eq!(episode.task_description, "Updated description");
    }

    #[tokio::test]
    async fn test_update_episode_tags() {
        let server = new_server().await;

        let episode_id = server
            .memory()
            .start_episode(
                "Test task".to_string(),
                TaskContext::default(),
                TaskType::Testing,
            )
            .await;

        let result = server
            .update_episode_tool(json!({
                "episode_id": episode_id.to_string(),
                "add_tags": ["tag1", "tag2"]
            }))
            .await
            .expect("adding tags should succeed");

        assert!(result["success"].as_bool().unwrap());
        assert_eq!(result["updated_fields"], json!(["tags (added)"]));

        // Check the stored tags themselves, not just how many there are.
        let tags = server.memory().get_episode_tags(episode_id).await.unwrap();
        assert_eq!(tags.len(), 2);
        assert!(tags.contains(&"tag1".to_string()));
        assert!(tags.contains(&"tag2".to_string()));
    }

    #[tokio::test]
    async fn test_update_episode_no_changes() {
        let server = new_server().await;

        let episode_id = server
            .memory()
            .start_episode(
                "Untouched".to_string(),
                TaskContext::default(),
                TaskType::Testing,
            )
            .await;

        let result = server
            .update_episode_tool(json!({ "episode_id": episode_id.to_string() }))
            .await
            .expect("a request with no changes should still succeed");

        assert!(result["success"].as_bool().unwrap());
        assert_eq!(
            result["message"],
            json!("No changes specified. Episode unchanged.")
        );
        assert_eq!(result["updated_fields"], json!([]));

        let episode = server.memory().get_episode(episode_id).await.unwrap();
        assert_eq!(episode.task_description, "Untouched");
    }

    #[tokio::test]
    async fn test_update_episode_missing_id() {
        let server = new_server().await;

        let err = server
            .update_episode_tool(json!({ "description": "Test" }))
            .await
            .expect_err("a request without episode_id must fail");

        assert!(err
            .to_string()
            .contains("Missing required field: episode_id"));
    }

    #[tokio::test]
    async fn test_update_episode_invalid_id() {
        let server = new_server().await;

        let err = server
            .update_episode_tool(json!({
                "episode_id": "invalid-uuid",
                "description": "Test"
            }))
            .await
            .expect_err("a malformed episode_id must fail");

        assert!(err.to_string().contains("Invalid episode ID format"));
    }
}