do_memory_mcp/server/tools/
episode_update.rs1use crate::server::MemoryMCPServer;
6use anyhow::{Result, anyhow};
7use serde_json::{Value, json};
8
9use tracing::debug;
10use tracing::info;
11use uuid::Uuid;
12
13impl MemoryMCPServer {
14 pub async fn update_episode_tool(&self, args: Value) -> Result<Value> {
28 debug!("Updating episode with args: {}", args);
29
30 let episode_id_str = args
32 .get("episode_id")
33 .and_then(|v| v.as_str())
34 .ok_or_else(|| anyhow!("Missing required field: episode_id"))?
35 .to_string();
36
37 let uuid = Uuid::parse_str(&episode_id_str)
38 .map_err(|_| anyhow!("Invalid episode ID format: {}", episode_id_str))?;
39
40 let description = args
42 .get("description")
43 .and_then(|v| v.as_str())
44 .map(|s| s.to_string());
45
46 let add_tags = args.get("add_tags").and_then(|v| v.as_array()).map(|arr| {
47 arr.iter()
48 .filter_map(|v| v.as_str().map(|s| s.to_string()))
49 .collect()
50 });
51
52 let remove_tags = args
53 .get("remove_tags")
54 .and_then(|v| v.as_array())
55 .map(|arr| {
56 arr.iter()
57 .filter_map(|v| v.as_str().map(|s| s.to_string()))
58 .collect()
59 });
60
61 let set_tags = args.get("set_tags").and_then(|v| v.as_array()).map(|arr| {
62 arr.iter()
63 .filter_map(|v| v.as_str().map(|s| s.to_string()))
64 .collect()
65 });
66
67 let metadata = args.get("metadata").and_then(|v| v.as_object()).map(|obj| {
68 obj.iter()
69 .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
70 .collect()
71 });
72
73 let mut updated_fields = Vec::new();
75
76 if let Some(desc) = description {
78 self.memory
79 .update_episode(uuid, Some(desc.clone()), None)
80 .await
81 .map_err(|e| anyhow!("Failed to update description: {}", e))?;
82 updated_fields.push("description");
83 }
84
85 if let Some(meta) = metadata {
87 self.memory
88 .update_episode(uuid, None, Some(meta))
89 .await
90 .map_err(|e| anyhow!("Failed to update metadata: {}", e))?;
91 updated_fields.push("metadata");
92 }
93
94 if let Some(tags) = add_tags {
96 self.memory
97 .add_episode_tags(uuid, tags)
98 .await
99 .map_err(|e| anyhow!("Failed to add tags: {}", e))?;
100 updated_fields.push("tags (added)");
101 }
102
103 if let Some(tags) = remove_tags {
105 self.memory
106 .remove_episode_tags(uuid, tags)
107 .await
108 .map_err(|e| anyhow!("Failed to remove tags: {}", e))?;
109 updated_fields.push("tags (removed)");
110 }
111
112 if let Some(tags) = set_tags {
114 self.memory
115 .set_episode_tags(uuid, tags)
116 .await
117 .map_err(|e| anyhow!("Failed to set tags: {}", e))?;
118 updated_fields.push("tags (set)");
119 }
120
121 if updated_fields.is_empty() {
122 return Ok(json!({
123 "success": true,
124 "episode_id": episode_id_str,
125 "message": "No changes specified. Episode unchanged.",
126 "updated_fields": []
127 }));
128 }
129
130 info!(
131 episode_id = %episode_id_str,
132 fields = %updated_fields.join(", "),
133 "Updated episode via MCP"
134 );
135
136 Ok(json!({
137 "success": true,
138 "episode_id": episode_id_str,
139 "message": format!("Successfully updated episode {}. Updated fields: {}", episode_id_str, updated_fields.join(", ")),
140 "updated_fields": updated_fields
141 }))
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SandboxConfig;
    use do_memory_core::{TaskContext, TaskType};
    use std::sync::Arc;

    /// Builds an MCP server backed by a fresh `SelfLearningMemory` and the default sandbox config.
    async fn new_test_server() -> MemoryMCPServer {
        let backing = Arc::new(do_memory_core::SelfLearningMemory::new());
        MemoryMCPServer::new(SandboxConfig::default(), backing)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_update_episode_description() {
        let server = new_test_server().await;

        let episode_id = server
            .memory()
            .start_episode(
                "Original description".to_string(),
                TaskContext::default(),
                TaskType::Testing,
            )
            .await;

        let request = json!({
            "episode_id": episode_id.to_string(),
            "description": "Updated description"
        });
        let response = server.update_episode_tool(request).await.unwrap();

        assert!(response["success"].as_bool().unwrap());
        let message = response["message"].as_str().unwrap();
        assert!(message.contains("Successfully updated"));

        let stored = server.memory().get_episode(episode_id).await.unwrap();
        assert_eq!(stored.task_description, "Updated description");
    }

    #[tokio::test]
    async fn test_update_episode_tags() {
        let server = new_test_server().await;

        let episode_id = server
            .memory()
            .start_episode(
                "Test task".to_string(),
                TaskContext::default(),
                TaskType::Testing,
            )
            .await;

        let request = json!({
            "episode_id": episode_id.to_string(),
            "add_tags": ["tag1", "tag2"]
        });
        let response = server.update_episode_tool(request).await.unwrap();

        assert!(response["success"].as_bool().unwrap());

        let stored_tags = server.memory().get_episode_tags(episode_id).await.unwrap();
        assert_eq!(stored_tags.len(), 2);
    }

    #[tokio::test]
    async fn test_update_episode_invalid_id() {
        let server = new_test_server().await;

        let request = json!({
            "episode_id": "invalid-uuid",
            "description": "Test"
        });
        let error = server.update_episode_tool(request).await.unwrap_err();

        assert!(error.to_string().contains("Invalid episode ID format"));
    }
}