use crate::agent::AgentCallback;
use crate::error::ReactError;
use chrono::{DateTime, Utc};
use futures::future::BoxFuture;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Role of a node in the agent topology; selects the icon and shape used when the
/// topology is rendered as Mermaid or Graphviz DOT.
///
/// This is a fieldless enum, so it derives `Copy` (pass it by value freely) and `Hash`
/// (use it as a map/set key, e.g. to group nodes by role).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum NodeType {
    Orchestrator,
    /// Role given to a newly seen *calling* node when a call is recorded.
    Worker,
    Planner,
    External,
    /// Role given to a newly seen *called* node when a call is recorded.
    Tool,
}
/// A vertex of the topology graph, identified by `id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopologyNode {
    /// Unique key of the node; also emitted as the node identifier when rendering.
    pub id: String,
    /// Text displayed for the node when rendering; starts out equal to `id`.
    pub label: String,
    /// Role of the node, which picks its rendered icon/shape.
    pub node_type: NodeType,
    /// Arbitrary key/value annotations attached to the node.
    pub metadata: HashMap<String, String>,
}

impl TopologyNode {
    /// Builds a node whose label mirrors its id and whose metadata is empty.
    pub fn new(id: impl Into<String>, node_type: NodeType) -> Self {
        let id = id.into();
        let label = id.clone();
        Self {
            id,
            label,
            node_type,
            metadata: HashMap::new(),
        }
    }

    /// Builder step: replaces the display label.
    pub fn with_label(self, label: impl Into<String>) -> Self {
        Self {
            label: label.into(),
            ..self
        }
    }

    /// Builder step: sets one metadata entry, overwriting any previous value for `key`.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (k, v) = (key.into(), value.into());
        self.metadata.insert(k, v);
        self
    }
}
/// Aggregate of every recorded call from node `from` to node `to`; repeated calls
/// between the same pair update this single edge rather than adding new ones.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopologyEdge {
    /// Id of the calling node.
    pub from: String,
    /// Id of the called node.
    pub to: String,
    /// Label passed with the most recently recorded call (each call overwrites it).
    pub label: Option<String>,
    /// Number of calls recorded on this edge.
    pub call_count: u64,
    /// Time (UTC) of the most recently recorded call.
    pub last_called: DateTime<Utc>,
    /// Sum of the `duration_ms` values recorded for this edge; calls recorded without
    /// a duration contribute 0.
    pub total_duration_ms: u64,
}
pub struct TopologyTracker {
nodes: RwLock<HashMap<String, TopologyNode>>,
edges: RwLock<HashMap<(String, String), TopologyEdge>>,
}
impl TopologyTracker {
pub fn new() -> Self {
Self {
nodes: RwLock::new(HashMap::new()),
edges: RwLock::new(HashMap::new()),
}
}
pub fn add_node(&self, node: TopologyNode) {
if let Ok(mut nodes) = self.nodes.write() {
nodes.insert(node.id.clone(), node);
}
}
pub fn record_call(&self, from: &str, to: &str, label: &str) {
self.record_call_with_duration(from, to, label, 0);
}
pub fn record_call_with_duration(&self, from: &str, to: &str, label: &str, duration_ms: u64) {
self.ensure_node(from, NodeType::Worker);
self.ensure_node(to, NodeType::Tool);
if let Ok(mut edges) = self.edges.write() {
let key = (from.to_string(), to.to_string());
let edge = edges.entry(key).or_insert_with(|| TopologyEdge {
from: from.to_string(),
to: to.to_string(),
label: Some(label.to_string()),
call_count: 0,
last_called: Utc::now(),
total_duration_ms: 0,
});
edge.call_count += 1;
edge.last_called = Utc::now();
edge.total_duration_ms += duration_ms;
edge.label = Some(label.to_string());
}
}
fn ensure_node(&self, id: &str, default_type: NodeType) {
if let Ok(mut nodes) = self.nodes.write() {
nodes
.entry(id.to_string())
.or_insert_with(|| TopologyNode::new(id, default_type));
}
}
pub fn nodes(&self) -> Vec<TopologyNode> {
self.nodes
.read()
.map(|n| n.values().cloned().collect())
.unwrap_or_default()
}
pub fn edges(&self) -> Vec<TopologyEdge> {
self.edges
.read()
.map(|e| e.values().cloned().collect())
.unwrap_or_default()
}
pub fn stats(&self) -> TopologyStats {
let nodes = self.nodes.read().map(|n| n.len()).unwrap_or(0);
let edges = self.edges.read().map(|e| e.len()).unwrap_or(0);
let total_calls = self
.edges
.read()
.map(|e| e.values().map(|edge| edge.call_count).sum())
.unwrap_or(0);
TopologyStats {
node_count: nodes,
edge_count: edges,
total_calls,
}
}
pub fn clear(&self) {
if let Ok(mut nodes) = self.nodes.write() {
nodes.clear();
}
if let Ok(mut edges) = self.edges.write() {
edges.clear();
}
}
pub fn to_mermaid(&self) -> String {
let mut lines = vec!["graph TD".to_string()];
if let Ok(nodes) = self.nodes.read() {
for node in nodes.values() {
let icon = match node.node_type {
NodeType::Orchestrator => "🎯",
NodeType::Worker => "⚙️",
NodeType::Planner => "📐",
NodeType::External => "🌐",
NodeType::Tool => "🔧",
};
let shape = match node.node_type {
NodeType::Orchestrator => {
format!("{}{{{{\"{} {}\"}}}}", node.id, icon, node.label)
}
NodeType::Tool => format!("{}[/\"{} {}\"/]", node.id, icon, node.label),
NodeType::External => format!("{}((\"{} {}\"))", node.id, icon, node.label),
_ => format!("{}[\"{} {}\"]", node.id, icon, node.label),
};
lines.push(format!(" {}", shape));
}
}
if let Ok(edges) = self.edges.read() {
for edge in edges.values() {
let label = if let Some(l) = &edge.label {
let truncated = if l.len() > 30 {
format!("{}...", &l[..30])
} else {
l.clone()
};
if edge.call_count > 1 {
format!("{} (x{})", truncated, edge.call_count)
} else {
truncated
}
} else {
format!("x{}", edge.call_count)
};
lines.push(format!(" {} -->|\"{}\"| {}", edge.from, label, edge.to));
}
}
lines.join("\n")
}
pub fn to_dot(&self) -> String {
let mut lines = vec![
"digraph AgentTopology {".to_string(),
" rankdir=TB;".to_string(),
" node [shape=box, style=rounded];".to_string(),
String::new(),
];
if let Ok(nodes) = self.nodes.read() {
for node in nodes.values() {
let (shape, color) = match node.node_type {
NodeType::Orchestrator => ("diamond", "#FFB74D"),
NodeType::Worker => ("box", "#64B5F6"),
NodeType::Planner => ("hexagon", "#81C784"),
NodeType::External => ("ellipse", "#CE93D8"),
NodeType::Tool => ("component", "#90A4AE"),
};
lines.push(format!(
" \"{}\" [label=\"{}\", shape={}, style=\"filled,rounded\", fillcolor=\"{}\"];",
node.id, node.label, shape, color
));
}
}
lines.push(String::new());
if let Ok(edges) = self.edges.read() {
for edge in edges.values() {
let label = if let Some(l) = &edge.label {
let truncated = if l.len() > 30 {
format!("{}...", &l[..30])
} else {
l.clone()
};
format!("{} (x{})", truncated, edge.call_count)
} else {
format!("x{}", edge.call_count)
};
let penwidth = if edge.call_count > 5 {
"3.0"
} else if edge.call_count > 1 {
"2.0"
} else {
"1.0"
};
lines.push(format!(
" \"{}\" -> \"{}\" [label=\"{}\", penwidth={}];",
edge.from, edge.to, label, penwidth
));
}
}
lines.push("}".to_string());
lines.join("\n")
}
pub fn to_json(&self) -> Result<String, serde_json::Error> {
let data = TopologyData {
nodes: self.nodes(),
edges: self.edges(),
stats: self.stats(),
};
serde_json::to_string_pretty(&data)
}
}
impl Default for TopologyTracker {
fn default() -> Self {
Self::new()
}
}
/// Serializable snapshot of a tracker (all nodes, all edges and summary stats), as
/// produced by `TopologyTracker::to_json`.
#[derive(Debug, Serialize, Deserialize)]
pub struct TopologyData {
    /// Every node, as returned by `TopologyTracker::nodes`.
    pub nodes: Vec<TopologyNode>,
    /// Every edge, as returned by `TopologyTracker::edges`.
    pub edges: Vec<TopologyEdge>,
    /// Counts, as returned by `TopologyTracker::stats`.
    pub stats: TopologyStats,
}
/// Summary counts for a tracker, as returned by `TopologyTracker::stats`.
///
/// Derives `Default` (all zeros, i.e. an empty tracker) and `PartialEq`/`Eq` so stats
/// can be compared directly, e.g. in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct TopologyStats {
    /// Number of distinct nodes.
    pub node_count: usize,
    /// Number of distinct `(from, to)` edges.
    pub edge_count: usize,
    /// Sum of `call_count` over all edges.
    pub total_calls: u64,
}
/// [`AgentCallback`] that records each tool invocation as an `agent -> tool` edge in a
/// shared [`TopologyTracker`].
pub struct TopologyCallback {
    tracker: Arc<TopologyTracker>,
}

impl TopologyCallback {
    /// Creates a callback that writes into `tracker`. Keep another clone of the `Arc`
    /// to read or render the topology afterwards.
    pub fn new(tracker: Arc<TopologyTracker>) -> Self {
        Self { tracker }
    }
}

impl AgentCallback for TopologyCallback {
    fn on_tool_start<'a>(
        &'a self,
        agent: &'a str,
        tool: &'a str,
        _args: &'a serde_json::Value,
    ) -> BoxFuture<'a, ()> {
        Box::pin(async move {
            // `record_call` creates the `agent` (Worker) and `tool` (Tool) nodes only when
            // they are missing. Calling `add_node` here instead would replace any node the
            // caller registered up front, discarding its type, label and metadata on every
            // tool call.
            self.tracker.record_call(agent, tool, "call");
        })
    }

    fn on_tool_end<'a>(
        &'a self,
        _agent: &'a str,
        _tool: &'a str,
        _result: &'a str,
    ) -> BoxFuture<'a, ()> {
        // Nothing extra to record: the call was counted in `on_tool_start`
        // (assuming that hook fired for this invocation).
        Box::pin(async {})
    }

    fn on_tool_error<'a>(
        &'a self,
        agent: &'a str,
        tool: &'a str,
        _err: &'a ReactError,
    ) -> BoxFuture<'a, ()> {
        Box::pin(async move {
            // NOTE(review): if the runtime also fires `on_tool_start` for the failing
            // invocation, this increments the edge's call_count a second time for the
            // same call. Confirm that is intended.
            self.tracker.record_call(agent, tool, "call failed");
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_topology_tracker_basic() {
        let tracker = TopologyTracker::new();
        tracker.add_node(TopologyNode::new("agent_a", NodeType::Orchestrator));
        tracker.add_node(TopologyNode::new("agent_b", NodeType::Worker));
        tracker.record_call("agent_a", "agent_b", "dispatch task");
        let stats = tracker.stats();
        assert_eq!(stats.node_count, 2);
        assert_eq!(stats.edge_count, 1);
        assert_eq!(stats.total_calls, 1);
    }

    #[test]
    fn test_topology_multiple_calls() {
        let tracker = TopologyTracker::new();
        tracker.add_node(TopologyNode::new("agent", NodeType::Worker));
        tracker.record_call("agent", "calc", "1+1");
        tracker.record_call("agent", "calc", "2+2");
        tracker.record_call("agent", "calc", "3+3");
        let edges = tracker.edges();
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].call_count, 3);
    }

    /// Durations accumulate per edge, and the label tracks the most recent call.
    #[test]
    fn test_record_call_with_duration_accumulates() {
        let tracker = TopologyTracker::new();
        tracker.record_call_with_duration("a", "b", "first", 10);
        tracker.record_call_with_duration("a", "b", "second", 15);
        let edges = tracker.edges();
        assert_eq!(edges.len(), 1);
        let edge = &edges[0];
        assert_eq!(edge.call_count, 2);
        assert_eq!(edge.total_duration_ms, 25);
        assert_eq!(edge.label.as_deref(), Some("second"));
    }

    /// Recording a call must not clobber a pre-registered node, and must give newly
    /// created callees the Tool role.
    #[test]
    fn test_record_call_keeps_registered_nodes() {
        let tracker = TopologyTracker::new();
        tracker.add_node(TopologyNode::new("boss", NodeType::Orchestrator).with_label("Boss"));
        tracker.record_call("boss", "search", "query");
        let nodes = tracker.nodes();
        let boss = nodes
            .iter()
            .find(|n| n.id == "boss")
            .expect("boss was registered");
        assert_eq!(boss.node_type, NodeType::Orchestrator);
        assert_eq!(boss.label, "Boss");
        let search = nodes
            .iter()
            .find(|n| n.id == "search")
            .expect("search is auto-created");
        assert_eq!(search.node_type, NodeType::Tool);
    }

    #[test]
    fn test_topology_to_mermaid() {
        let tracker = TopologyTracker::new();
        tracker.add_node(TopologyNode::new("orchestrator", NodeType::Orchestrator));
        tracker.add_node(TopologyNode::new("worker", NodeType::Worker));
        tracker.record_call("orchestrator", "worker", "execute");
        let mermaid = tracker.to_mermaid();
        assert!(mermaid.starts_with("graph TD"));
        assert!(mermaid.contains("orchestrator"));
        assert!(mermaid.contains("worker"));
    }

    #[test]
    fn test_topology_to_dot() {
        let tracker = TopologyTracker::new();
        tracker.add_node(TopologyNode::new("agent", NodeType::Worker));
        tracker.record_call("agent", "tool1", "use");
        let dot = tracker.to_dot();
        assert!(dot.starts_with("digraph"));
        assert!(dot.contains("agent"));
        assert!(dot.contains("tool1"));
    }

    #[test]
    fn test_topology_to_json() {
        let tracker = TopologyTracker::new();
        tracker.add_node(TopologyNode::new("a", NodeType::Worker));
        tracker.record_call("a", "b", "call");
        let json = tracker.to_json().unwrap();
        let data: TopologyData = serde_json::from_str(&json).unwrap();
        assert!(!data.nodes.is_empty());
        assert!(!data.edges.is_empty());
    }

    #[test]
    fn test_topology_clear() {
        let tracker = TopologyTracker::new();
        tracker.add_node(TopologyNode::new("a", NodeType::Worker));
        tracker.record_call("a", "b", "call");
        assert!(tracker.stats().node_count > 0);
        tracker.clear();
        assert_eq!(tracker.stats().node_count, 0);
        assert_eq!(tracker.stats().edge_count, 0);
    }

    #[test]
    fn test_topology_node_builder() {
        let node = TopologyNode::new("test", NodeType::External)
            .with_label("Test Service")
            .with_metadata("url", "http://localhost");
        assert_eq!(node.id, "test");
        assert_eq!(node.label, "Test Service");
        assert_eq!(node.metadata.get("url").unwrap(), "http://localhost");
    }
}