use crate::server::MemoryMCPServer;
use anyhow::{Result, anyhow};
use serde_json::{Value, json};
use tracing::debug;
use tracing::info;
use uuid::Uuid;
impl MemoryMCPServer {
pub async fn update_episode_tool(&self, args: Value) -> Result<Value> {
debug!("Updating episode with args: {}", args);
let episode_id_str = args
.get("episode_id")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow!("Missing required field: episode_id"))?
.to_string();
let uuid = Uuid::parse_str(&episode_id_str)
.map_err(|_| anyhow!("Invalid episode ID format: {}", episode_id_str))?;
let description = args
.get("description")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let add_tags = args.get("add_tags").and_then(|v| v.as_array()).map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
});
let remove_tags = args
.get("remove_tags")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
});
let set_tags = args.get("set_tags").and_then(|v| v.as_array()).map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
});
let metadata = args.get("metadata").and_then(|v| v.as_object()).map(|obj| {
obj.iter()
.filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
.collect()
});
let mut updated_fields = Vec::new();
if let Some(desc) = description {
self.memory
.update_episode(uuid, Some(desc.clone()), None)
.await
.map_err(|e| anyhow!("Failed to update description: {}", e))?;
updated_fields.push("description");
}
if let Some(meta) = metadata {
self.memory
.update_episode(uuid, None, Some(meta))
.await
.map_err(|e| anyhow!("Failed to update metadata: {}", e))?;
updated_fields.push("metadata");
}
if let Some(tags) = add_tags {
self.memory
.add_episode_tags(uuid, tags)
.await
.map_err(|e| anyhow!("Failed to add tags: {}", e))?;
updated_fields.push("tags (added)");
}
if let Some(tags) = remove_tags {
self.memory
.remove_episode_tags(uuid, tags)
.await
.map_err(|e| anyhow!("Failed to remove tags: {}", e))?;
updated_fields.push("tags (removed)");
}
if let Some(tags) = set_tags {
self.memory
.set_episode_tags(uuid, tags)
.await
.map_err(|e| anyhow!("Failed to set tags: {}", e))?;
updated_fields.push("tags (set)");
}
if updated_fields.is_empty() {
return Ok(json!({
"success": true,
"episode_id": episode_id_str,
"message": "No changes specified. Episode unchanged.",
"updated_fields": []
}));
}
info!(
episode_id = %episode_id_str,
fields = %updated_fields.join(", "),
"Updated episode via MCP"
);
Ok(json!({
"success": true,
"episode_id": episode_id_str,
"message": format!("Successfully updated episode {}. Updated fields: {}", episode_id_str, updated_fields.join(", ")),
"updated_fields": updated_fields
}))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SandboxConfig;
    use do_memory_core::{TaskContext, TaskType};
    use std::sync::Arc;

    /// Builds a server over a default `SelfLearningMemory` with the default sandbox config.
    async fn new_server() -> MemoryMCPServer {
        let memory = do_memory_core::SelfLearningMemory::new();
        MemoryMCPServer::new(SandboxConfig::default(), Arc::new(memory))
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_update_episode_description() {
        let server = new_server().await;
        let episode_id = server
            .memory()
            .start_episode(
                "Original description".to_string(),
                TaskContext::default(),
                TaskType::Testing,
            )
            .await;
        let result = server
            .update_episode_tool(json!({
                "episode_id": episode_id.to_string(),
                "description": "Updated description"
            }))
            .await
            .unwrap();
        assert!(result["success"].as_bool().unwrap());
        assert!(
            result["message"]
                .as_str()
                .unwrap()
                .contains("Successfully updated")
        );
        let episode = server.memory().get_episode(episode_id).await.unwrap();
        assert_eq!(episode.task_description, "Updated description");
    }

    #[tokio::test]
    async fn test_update_episode_tags() {
        let server = new_server().await;
        let episode_id = server
            .memory()
            .start_episode(
                "Test task".to_string(),
                TaskContext::default(),
                TaskType::Testing,
            )
            .await;
        let result = server
            .update_episode_tool(json!({
                "episode_id": episode_id.to_string(),
                "add_tags": ["tag1", "tag2"]
            }))
            .await
            .unwrap();
        assert!(result["success"].as_bool().unwrap());
        let tags = server.memory().get_episode_tags(episode_id).await.unwrap();
        assert_eq!(tags.len(), 2);
    }

    #[tokio::test]
    async fn test_update_episode_invalid_id() {
        let server = new_server().await;
        let result = server
            .update_episode_tool(json!({
                "episode_id": "invalid-uuid",
                "description": "Test"
            }))
            .await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Invalid episode ID format")
        );
    }

    /// A request without `episode_id` must be rejected before any update is attempted.
    #[tokio::test]
    async fn test_update_episode_missing_id() {
        let server = new_server().await;
        let err = server
            .update_episode_tool(json!({ "description": "Test" }))
            .await
            .unwrap_err();
        assert!(
            err.to_string()
                .contains("Missing required field: episode_id")
        );
    }

    /// A request naming no updatable field succeeds and reports that nothing changed.
    #[tokio::test]
    async fn test_update_episode_no_changes() {
        let server = new_server().await;
        let episode_id = server
            .memory()
            .start_episode(
                "Unchanged task".to_string(),
                TaskContext::default(),
                TaskType::Testing,
            )
            .await;
        let result = server
            .update_episode_tool(json!({ "episode_id": episode_id.to_string() }))
            .await
            .unwrap();
        assert!(result["success"].as_bool().unwrap());
        assert!(result["updated_fields"].as_array().unwrap().is_empty());
        assert_eq!(
            result["message"].as_str(),
            Some("No changes specified. Episode unchanged.")
        );
    }
}