use crate::mcp::tools::episode_relationships::types::{
AddEpisodeRelationshipInput, AddEpisodeRelationshipOutput, CheckRelationshipExistsInput,
CheckRelationshipExistsOutput, DependencyGraphInput, DependencyGraphOutput,
FindRelatedEpisodesInput, FindRelatedEpisodesOutput, GetEpisodeRelationshipsInput,
GetEpisodeRelationshipsOutput, GetTopologicalOrderInput, GetTopologicalOrderOutput,
RelatedEpisode, RelationshipEdge, RelationshipNode, RemoveEpisodeRelationshipInput,
RemoveEpisodeRelationshipOutput, TopologicalEpisode, ValidateNoCyclesInput,
ValidateNoCyclesOutput,
};
use anyhow::{Result, anyhow};
use do_memory_core::SelfLearningMemory;
use do_memory_core::episode::{
Direction, EpisodeRelationship, RelationshipMetadata, RelationshipType,
};
use do_memory_core::memory::relationship_query::RelationshipFilter;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info, instrument};
use uuid::Uuid;
mod graph_ops;
use graph_ops::relationship_to_edge;
/// Tool handlers for episode-relationship operations.
///
/// Every method parses its string-based input, delegates to the shared
/// [`SelfLearningMemory`] instance, and wraps the outcome in the matching
/// output struct. Cloning only duplicates the `Arc` handle.
#[derive(Clone)]
pub struct EpisodeRelationshipTools {
// Shared handle to the memory system that all handlers delegate to.
memory: Arc<SelfLearningMemory>,
}
impl EpisodeRelationshipTools {
/// Creates a tool set whose handlers operate on the given shared memory handle.
pub fn new(memory: Arc<SelfLearningMemory>) -> Self {
Self { memory }
}
/// Creates a typed relationship between two episodes.
///
/// Both episode ids are parsed as UUIDs and the relationship type is parsed
/// from its string form. The optional `reason`, `created_by` and `priority`
/// inputs are copied into the relationship metadata before the relationship
/// is stored through the memory system.
///
/// # Errors
///
/// Returns an error if either id is not a valid UUID, if the relationship
/// type cannot be parsed, or if the memory system fails to add the
/// relationship.
#[instrument(skip(self, input), fields(from = %input.from_episode_id, to = %input.to_episode_id, rel_type = %input.relationship_type))]
pub async fn add_relationship(
    &self,
    input: AddEpisodeRelationshipInput,
) -> Result<AddEpisodeRelationshipOutput> {
    info!(
        "Adding {} relationship from {} to {}",
        input.relationship_type, input.from_episode_id, input.to_episode_id
    );

    // Validate every textual argument before touching the memory system.
    let source_id = Uuid::parse_str(&input.from_episode_id)
        .map_err(|err| anyhow!("Invalid from_episode_id: {}", err))?;
    let target_id = Uuid::parse_str(&input.to_episode_id)
        .map_err(|err| anyhow!("Invalid to_episode_id: {}", err))?;
    let kind = RelationshipType::parse(&input.relationship_type)
        .map_err(|err| anyhow!("Invalid relationship_type: {}", err))?;

    // Only values the caller actually supplied replace what
    // `RelationshipMetadata::new()` produced.
    let mut metadata = RelationshipMetadata::new();
    if input.reason.is_some() {
        metadata.reason = input.reason.clone();
    }
    if input.created_by.is_some() {
        metadata.created_by = input.created_by.clone();
    }
    if let Some(level) = input.priority {
        metadata.priority = Some(level);
    }

    let created_id = self
        .memory
        .add_episode_relationship(source_id, target_id, kind, metadata)
        .await?;
    info!(
        "Successfully created relationship {} from {} to {}",
        created_id, source_id, target_id
    );

    // The response echoes the caller's original strings, not re-formatted ids.
    let message = format!("Relationship created successfully: {}", created_id);
    Ok(AddEpisodeRelationshipOutput {
        success: true,
        relationship_id: created_id.to_string(),
        from_episode_id: input.from_episode_id,
        to_episode_id: input.to_episode_id,
        relationship_type: input.relationship_type,
        message,
    })
}
/// Removes a relationship identified by its id.
///
/// # Errors
///
/// Returns an error if `relationship_id` is not a valid UUID or if the memory
/// system fails to remove the relationship.
#[instrument(skip(self, input), fields(relationship_id = %input.relationship_id))]
pub async fn remove_relationship(
    &self,
    input: RemoveEpisodeRelationshipInput,
) -> Result<RemoveEpisodeRelationshipOutput> {
    info!("Removing relationship: {}", input.relationship_id);

    let target_id = Uuid::parse_str(&input.relationship_id)
        .map_err(|err| anyhow!("Invalid relationship_id: {}", err))?;

    self.memory.remove_episode_relationship(target_id).await?;
    info!("Successfully removed relationship: {}", target_id);

    // Echo the id exactly as the caller supplied it.
    let message = String::from("Relationship removed successfully");
    Ok(RemoveEpisodeRelationshipOutput {
        success: true,
        relationship_id: input.relationship_id,
        message,
    })
}
/// Lists the relationships attached to an episode, split by direction.
///
/// `input.direction` selects what is queried: `"outgoing"` or `"incoming"`;
/// any other value, including `None`, queries both directions. When
/// `input.relationship_type` is set, only relationships of that type are
/// returned.
///
/// The type filter is parsed once, before the memory system is queried, so an
/// invalid filter is always rejected, even when the episode has no
/// relationships at all.
///
/// A relationship whose `from_episode_id` equals the queried episode is
/// reported as outgoing; every other relationship is reported as incoming.
///
/// # Errors
///
/// Returns an error if `episode_id` is not a valid UUID, if the relationship
/// type filter cannot be parsed, or if the memory lookup fails.
#[instrument(skip(self, input), fields(episode_id = %input.episode_id))]
pub async fn get_relationships(
    &self,
    input: GetEpisodeRelationshipsInput,
) -> Result<GetEpisodeRelationshipsOutput> {
    debug!("Getting relationships for episode: {}", input.episode_id);
    let episode_id =
        Uuid::parse_str(&input.episode_id).map_err(|e| anyhow!("Invalid episode_id: {}", e))?;

    // Validate the filter up front: it does not depend on the stored data, so
    // it must neither be re-parsed per relationship nor escape validation when
    // there are no relationships to iterate over.
    let type_filter = match input.relationship_type {
        Some(ref raw) => Some(
            RelationshipType::parse(raw)
                .map_err(|e| anyhow!("Invalid relationship_type filter: {}", e))?,
        ),
        None => None,
    };

    // NOTE(review): unrecognised direction strings fall back to `Both` rather
    // than being rejected; kept as-is because callers may rely on it.
    let direction = match input.direction.as_deref() {
        Some("outgoing") => Direction::Outgoing,
        Some("incoming") => Direction::Incoming,
        _ => Direction::Both,
    };

    let relationships = self
        .memory
        .get_episode_relationships(episode_id, direction)
        .await?;

    let mut outgoing = Vec::new();
    let mut incoming = Vec::new();
    for rel in relationships {
        if let Some(ref wanted) = type_filter {
            if rel.relationship_type != *wanted {
                continue;
            }
        }
        // Build the edge only for relationships that survive the filter.
        let edge = relationship_to_edge(&rel);
        if rel.from_episode_id == episode_id {
            outgoing.push(edge);
        } else {
            incoming.push(edge);
        }
    }

    let total_count = outgoing.len() + incoming.len();
    debug!(
        "Found {} outgoing and {} incoming relationships for episode {}",
        outgoing.len(),
        incoming.len(),
        episode_id
    );
    Ok(GetEpisodeRelationshipsOutput {
        success: true,
        episode_id: input.episode_id,
        outgoing,
        incoming,
        total_count,
        message: format!("Found {} relationship(s)", total_count),
    })
}
#[instrument(skip(self, input), fields(episode_id = %input.episode_id))]
pub async fn find_related(
&self,
input: FindRelatedEpisodesInput,
) -> Result<FindRelatedEpisodesOutput> {
info!("Finding related episodes for: {}", input.episode_id);
let episode_id =
Uuid::parse_str(&input.episode_id).map_err(|e| anyhow!("Invalid episode_id: {}", e))?;
let mut filter = RelationshipFilter::new();
if let Some(ref rel_type_str) = input.relationship_type {
let rel_type = RelationshipType::parse(rel_type_str)
.map_err(|e| anyhow!("Invalid relationship_type: {}", e))?;
filter = filter.with_type(rel_type);
}
if let Some(limit) = input.limit {
filter = filter.with_limit(limit);
}
let related_ids = self
.memory
.find_related_episodes(episode_id, filter)
.await?;
let relationships = self
.memory
.get_episode_relationships(episode_id, Direction::Both)
.await?;
let mut related_episodes = Vec::new();
for related_id in related_ids {
if let Ok(episode) = self.memory.get_episode(related_id).await {
if let Some(rel) = relationships
.iter()
.find(|r| r.from_episode_id == related_id || r.to_episode_id == related_id)
{
let direction = if rel.from_episode_id == episode_id {
"outgoing"
} else {
"incoming"
};
related_episodes.push(RelatedEpisode {
episode_id: related_id.to_string(),
task_description: episode.task_description.clone(),
task_type: format!("{:?}", episode.task_type),
relationship_type: rel.relationship_type.as_str().to_string(),
direction: direction.to_string(),
reason: if input.include_metadata.unwrap_or(false) {
rel.metadata.reason.clone()
} else {
None
},
priority: if input.include_metadata.unwrap_or(false) {
rel.metadata.priority
} else {
None
},
});
}
}
}
let count = related_episodes.len();
info!("Found {} related episodes for {}", count, episode_id);
Ok(FindRelatedEpisodesOutput {
success: true,
episode_id: input.episode_id,
related_episodes,
count,
message: format!("Found {} related episode(s)", count),
})
}
/// Reports whether a relationship of the given type exists between two episodes.
///
/// # Errors
///
/// Returns an error if either id is not a valid UUID, if the relationship
/// type cannot be parsed, or if the memory lookup fails.
#[instrument(skip(self, input), fields(from = %input.from_episode_id, to = %input.to_episode_id))]
pub async fn check_exists(
    &self,
    input: CheckRelationshipExistsInput,
) -> Result<CheckRelationshipExistsOutput> {
    debug!(
        "Checking if relationship exists from {} to {}",
        input.from_episode_id, input.to_episode_id
    );

    let source_id = Uuid::parse_str(&input.from_episode_id)
        .map_err(|err| anyhow!("Invalid from_episode_id: {}", err))?;
    let target_id = Uuid::parse_str(&input.to_episode_id)
        .map_err(|err| anyhow!("Invalid to_episode_id: {}", err))?;
    let kind = RelationshipType::parse(&input.relationship_type)
        .map_err(|err| anyhow!("Invalid relationship_type: {}", err))?;

    let exists = self
        .memory
        .relationship_exists(source_id, target_id, kind)
        .await?;

    // Human-readable verdict that mirrors the boolean `exists` field.
    let verdict = if exists {
        "Relationship exists"
    } else {
        "Relationship does not exist"
    };
    Ok(CheckRelationshipExistsOutput {
        success: true,
        exists,
        from_episode_id: input.from_episode_id,
        to_episode_id: input.to_episode_id,
        relationship_type: input.relationship_type,
        message: verdict.to_string(),
    })
}
}