use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use yoagent::agent_loop::{agent_loop, AgentLoopConfig};
use yoagent::provider::mock::*;
use yoagent::provider::MockProvider;
use yoagent::sub_agent::SubAgentTool;
use yoagent::*;
fn make_config(provider: MockProvider) -> AgentLoopConfig {
AgentLoopConfig {
provider: std::sync::Arc::new(provider),
model: "mock".into(),
api_key: "test".into(),
thinking_level: ThinkingLevel::Off,
max_tokens: None,
temperature: None,
model_config: None,
convert_to_llm: None,
transform_context: None,
get_steering_messages: None,
get_follow_up_messages: None,
context_config: None,
compaction_strategy: None,
execution_limits: None,
cache_config: CacheConfig::default(),
tool_execution: ToolExecutionStrategy::default(),
retry_config: yoagent::RetryConfig::default(),
before_turn: None,
after_turn: None,
on_error: None,
input_filters: vec![],
turn_delay: None,
}
}
/// Drains every event already buffered in `rx` without waiting.
///
/// Collection stops at the first `try_recv` error, whether the channel is
/// empty or closed.
fn collect_events(mut rx: mpsc::UnboundedReceiver<AgentEvent>) -> Vec<AgentEvent> {
    std::iter::from_fn(|| rx.try_recv().ok()).collect()
}
/// A sub-agent with a single canned text response returns that text as its
/// tool result, and tags `details["sub_agent"]` with its own name.
#[tokio::test]
async fn test_sub_agent_basic() {
    let researcher = SubAgentTool::new(
        "researcher",
        Arc::new(MockProvider::text("Research result: Rust is great")),
    )
    .with_description("Researches topics")
    .with_system_prompt("You are a research assistant.")
    .with_model("mock")
    .with_api_key("test");

    let ctx = ToolContext {
        tool_call_id: "tc-1".into(),
        tool_name: "researcher".into(),
        cancel: CancellationToken::new(),
        on_update: None,
        on_progress: None,
    };
    let output = researcher
        .execute(serde_json::json!({"task": "Tell me about Rust"}), ctx)
        .await
        .expect("sub-agent should succeed");

    match &output.content[0] {
        Content::Text { text } => assert_eq!(text.as_str(), "Research result: Rust is great"),
        _ => panic!("Expected text content"),
    }
    assert_eq!(output.details["sub_agent"], "researcher");
}
/// Minimal tool for the sub-agent tests.
///
/// It replies with `"echoed: <text>"`, where `<text>` is the string `text`
/// argument. If that argument is missing or is not a string, it uses
/// `"(empty)"`. The result carries no `details`.
struct EchoTool;

#[async_trait::async_trait]
impl AgentTool for EchoTool {
    fn name(&self) -> &str {
        "echo"
    }

    fn label(&self) -> &str {
        "Echo"
    }

    fn description(&self) -> &str {
        "Echoes input"
    }

    fn parameters_schema(&self) -> serde_json::Value {
        // A single optional string property named `text`.
        serde_json::json!({
            "type": "object",
            "properties": {
                "text": {"type": "string"}
            }
        })
    }

    async fn execute(
        &self,
        params: serde_json::Value,
        _ctx: ToolContext,
    ) -> Result<ToolResult, ToolError> {
        let input = params
            .get("text")
            .and_then(serde_json::Value::as_str)
            .unwrap_or("(empty)");
        let reply = Content::Text {
            text: format!("echoed: {input}"),
        };
        Ok(ToolResult {
            content: vec![reply],
            details: serde_json::Value::Null,
        })
    }
}
/// A sub-agent with its own tools is run against a scripted provider.
///
/// The script first issues an `echo` tool call, then a final text reply. The
/// test expects that final text to be returned as the sub-agent's result.
#[tokio::test]
async fn test_sub_agent_with_tools() {
    let script = vec![
        MockResponse::ToolCalls(vec![MockToolCall {
            provider_metadata: None,
            name: "echo".into(),
            arguments: serde_json::json!({"text": "hello"}),
        }]),
        MockResponse::Text("The echo returned: echoed: hello".into()),
    ];
    let tools: Vec<Arc<dyn AgentTool>> = vec![Arc::new(EchoTool)];
    let agent = SubAgentTool::new("echo_agent", Arc::new(MockProvider::new(script)))
        .with_description("Agent that echoes")
        .with_system_prompt("Use the echo tool.")
        .with_model("mock")
        .with_api_key("test")
        .with_tools(tools);

    let ctx = ToolContext {
        tool_call_id: "tc-1".into(),
        tool_name: "echo_agent".into(),
        cancel: CancellationToken::new(),
        on_update: None,
        on_progress: None,
    };
    let output = agent
        .execute(serde_json::json!({"task": "Echo hello"}), ctx)
        .await
        .expect("sub-agent should succeed");

    match &output.content[0] {
        Content::Text { text } => {
            assert_eq!(text.as_str(), "The echo returned: echoed: hello")
        }
        _ => panic!("Expected text content"),
    }
}
/// Executing a sub-agent with an already-cancelled token must still yield
/// `Ok`. The returned text must not be the provider's canned reply.
#[tokio::test]
async fn test_sub_agent_cancellation() {
    let agent = SubAgentTool::new(
        "cancelled_agent",
        Arc::new(MockProvider::text("Should not appear")),
    )
    .with_model("mock")
    .with_api_key("test");

    // Cancel up front, before `execute` is even called.
    let token = CancellationToken::new();
    token.cancel();
    let ctx = ToolContext {
        tool_call_id: "tc-1".into(),
        tool_name: "cancelled_agent".into(),
        cancel: token,
        on_update: None,
        on_progress: None,
    };

    let output = agent
        .execute(serde_json::json!({"task": "Do something"}), ctx)
        .await
        .expect("should return a result even when cancelled");
    let first = match &output.content[0] {
        Content::Text { text } => text.as_str(),
        _ => panic!("Expected text content"),
    };
    assert_ne!(
        first, "Should not appear",
        "Sub-agent ran despite cancellation"
    );
}
/// With `with_max_turns(1)`, the sub-agent must stop before it consumes the
/// scripted second response. That response follows a tool call, so it would
/// be the second turn.
#[tokio::test]
async fn test_sub_agent_max_turns() {
    let script = vec![
        MockResponse::ToolCalls(vec![MockToolCall {
            provider_metadata: None,
            name: "echo".into(),
            arguments: serde_json::json!({"text": "loop"}),
        }]),
        MockResponse::Text("Should not reach".into()),
    ];
    let tools: Vec<Arc<dyn AgentTool>> = vec![Arc::new(EchoTool)];
    let agent = SubAgentTool::new("limited_agent", Arc::new(MockProvider::new(script)))
        .with_model("mock")
        .with_api_key("test")
        .with_tools(tools)
        .with_max_turns(1);

    let ctx = ToolContext {
        tool_call_id: "tc-1".into(),
        tool_name: "limited_agent".into(),
        cancel: CancellationToken::new(),
        on_update: None,
        on_progress: None,
    };
    let output = agent
        .execute(serde_json::json!({"task": "Keep going"}), ctx)
        .await
        .expect("sub-agent should succeed");

    let first = match &output.content[0] {
        Content::Text { text } => text.as_str(),
        _ => panic!("Expected text content"),
    };
    assert_ne!(first, "Should not reach");
}
#[tokio::test]
async fn test_sub_agent_parallel() {
struct SlowProvider {
delay_ms: u64,
text: String,
}
#[async_trait::async_trait]
impl yoagent::provider::StreamProvider for SlowProvider {
async fn stream(
&self,
_config: yoagent::provider::StreamConfig,
tx: tokio::sync::mpsc::UnboundedSender<yoagent::provider::StreamEvent>,
cancel: tokio_util::sync::CancellationToken,
) -> Result<Message, yoagent::provider::ProviderError> {
if cancel.is_cancelled() {
return Err(yoagent::provider::ProviderError::Cancelled);
}
tokio::time::sleep(std::time::Duration::from_millis(self.delay_ms)).await;
let _ = tx.send(yoagent::provider::StreamEvent::Start);
let _ = tx.send(yoagent::provider::StreamEvent::TextDelta {
content_index: 0,
delta: self.text.clone(),
});
let msg = Message::Assistant {
content: vec![Content::Text {
text: self.text.clone(),
}],
stop_reason: StopReason::Stop,
model: "slow".into(),
provider: "slow".into(),
usage: Usage::default(),
timestamp: yoagent::now_ms(),
error_message: None,
};
let _ = tx.send(yoagent::provider::StreamEvent::Done {
message: msg.clone(),
});
Ok(msg)
}
}
let sub_a = SubAgentTool::new(
"agent_a",
Arc::new(SlowProvider {
delay_ms: 50,
text: "Result A".into(),
}),
)
.with_model("slow")
.with_api_key("test");
let sub_b = SubAgentTool::new(
"agent_b",
Arc::new(SlowProvider {
delay_ms: 50,
text: "Result B".into(),
}),
)
.with_model("slow")
.with_api_key("test");
let parent_provider = MockProvider::new(vec![
MockResponse::ToolCalls(vec![
MockToolCall {
provider_metadata: None,
name: "agent_a".into(),
arguments: serde_json::json!({"task": "Do A"}),
},
MockToolCall {
provider_metadata: None,
name: "agent_b".into(),
arguments: serde_json::json!({"task": "Do B"}),
},
]),
MockResponse::Text("Both sub-agents completed.".into()),
]);
let config = make_config(parent_provider);
let mut context = AgentContext {
system_prompt: "You are a coordinator.".into(),
messages: Vec::new(),
tools: vec![Box::new(sub_a), Box::new(sub_b)],
};
let prompt = AgentMessage::Llm(Message::user("Run both agents"));
let (tx, rx) = mpsc::unbounded_channel();
let cancel = CancellationToken::new();
let start = std::time::Instant::now();
let new_messages = agent_loop(vec![prompt], &mut context, &config, tx, cancel).await;
let elapsed = start.elapsed();
let _events = collect_events(rx);
let tool_results: Vec<_> = new_messages
.iter()
.filter(|m| m.role() == "toolResult")
.collect();
assert_eq!(tool_results.len(), 2);
assert!(
elapsed.as_millis() < 130,
"Parallel sub-agents took {}ms, expected <130ms",
elapsed.as_millis()
);
}
/// Updates produced while the sub-agent streams must reach the caller's
/// `on_update` callback. At least one update must carry the streamed text.
#[tokio::test]
async fn test_sub_agent_event_forwarding() {
    let agent = SubAgentTool::new(
        "streaming_agent",
        Arc::new(MockProvider::text("Sub-agent done")),
    )
    .with_model("mock")
    .with_api_key("test");

    // Records the first text item of every update the sub-agent forwards.
    let seen: Arc<std::sync::Mutex<Vec<String>>> = Arc::default();
    let sink = Arc::clone(&seen);
    let on_update: ToolUpdateFn = Arc::new(move |update: ToolResult| {
        if let Some(Content::Text { text }) = update.content.first() {
            sink.lock().unwrap().push(text.clone());
        }
    });

    let ctx = ToolContext {
        tool_call_id: "tc-1".into(),
        tool_name: "streaming_agent".into(),
        cancel: CancellationToken::new(),
        on_update: Some(on_update),
        on_progress: None,
    };
    let output = agent
        .execute(serde_json::json!({"task": "Do work"}), ctx)
        .await
        .expect("sub-agent should succeed");

    match &output.content[0] {
        Content::Text { text } => assert_eq!(text.as_str(), "Sub-agent done"),
        _ => panic!("Expected text content"),
    }

    let collected = seen.lock().unwrap();
    assert!(
        !collected.is_empty(),
        "Expected on_update to receive streaming events"
    );
    assert!(
        collected.iter().any(|t| t.contains("Sub-agent done")),
        "Expected text delta in updates, got: {:?}",
        *collected
    );
}
/// Calling a sub-agent without a `task` argument is rejected with
/// `ToolError::InvalidArgs`, and the error message mentions `task`.
#[tokio::test]
async fn test_sub_agent_missing_task_parameter() {
    let agent = SubAgentTool::new("test_agent", Arc::new(MockProvider::text("Should not run")))
        .with_model("mock")
        .with_api_key("test");

    let ctx = ToolContext {
        tool_call_id: "tc-1".into(),
        tool_name: "test_agent".into(),
        cancel: CancellationToken::new(),
        on_update: None,
        on_progress: None,
    };
    let outcome = agent.execute(serde_json::json!({}), ctx).await;

    match outcome {
        Err(ToolError::InvalidArgs(msg)) => assert!(msg.contains("task")),
        Err(other) => panic!("Expected InvalidArgs, got: {:?}", other),
        Ok(_) => panic!("expected an error when the `task` parameter is missing"),
    }
}
/// Test provider that records the system prompt of each request.
///
/// Every call overwrites `captured` and then answers with a single assistant
/// message containing `"done"`.
struct CapturingProvider {
    // Holds a copy of the `system_prompt` from the most recent `stream` call.
    captured: Arc<std::sync::Mutex<String>>,
}

#[async_trait::async_trait]
impl yoagent::provider::StreamProvider for CapturingProvider {
    async fn stream(
        &self,
        config: yoagent::provider::StreamConfig,
        tx: mpsc::UnboundedSender<yoagent::provider::StreamEvent>,
        _cancel: CancellationToken,
    ) -> Result<Message, yoagent::provider::ProviderError> {
        {
            // Scope the guard so the lock is released before any events are sent.
            let mut slot = self.captured.lock().unwrap();
            slot.clone_from(&config.system_prompt);
        }

        let _ = tx.send(yoagent::provider::StreamEvent::Start);
        let reply = Message::Assistant {
            content: vec![Content::Text {
                text: "done".into(),
            }],
            stop_reason: StopReason::Stop,
            model: "mock".into(),
            provider: "mock".into(),
            usage: Usage::default(),
            timestamp: yoagent::now_ms(),
            error_message: None,
        };
        let _ = tx.send(yoagent::provider::StreamEvent::Done {
            message: reply.clone(),
        });
        Ok(reply)
    }
}
/// Temporary skills directory for a single test.
///
/// The directory is removed (best-effort) when the value is dropped.
struct SkillsDir(std::path::PathBuf);

impl SkillsDir {
    /// Creates `<tmp>/yoagent-test-skills-<pid>-<unique>/<name>/SKILL.md`.
    ///
    /// The file has front matter declaring `name` and `description`.
    ///
    /// The process id is part of the path so that concurrent test processes
    /// sharing one temp dir cannot clobber or delete each other's fixtures.
    /// `unique` keeps tests within one process apart.
    ///
    /// # Panics
    /// Panics if the fixture directory or file cannot be written.
    fn with_one_skill(unique: &str, name: &str, description: &str) -> Self {
        let dir = std::env::temp_dir().join(format!(
            "yoagent-test-skills-{}-{unique}",
            std::process::id()
        ));
        // Best-effort removal of leftovers from an earlier run at the same path.
        let _ = std::fs::remove_dir_all(&dir);
        // Build the guard first so a failure below still cleans up on drop.
        let fixture = Self(dir);
        let skill_dir = fixture.0.join(name);
        std::fs::create_dir_all(&skill_dir).expect("create skill fixture directory");
        std::fs::write(
            skill_dir.join("SKILL.md"),
            format!("---\nname: {name}\ndescription: {description}\n---\n\nBody.\n"),
        )
        .expect("write SKILL.md fixture");
        fixture
    }

    /// Loads the fixture directory through `SkillSet::load`.
    ///
    /// # Panics
    /// Panics if loading fails.
    fn load(&self) -> yoagent::skills::SkillSet {
        yoagent::skills::SkillSet::load(&[self.0.to_string_lossy().into_owned()])
            .expect("fixture skills directory should load")
    }
}

impl Drop for SkillsDir {
    fn drop(&mut self) {
        // Best-effort cleanup; a leftover temp dir must not fail the test.
        let _ = std::fs::remove_dir_all(&self.0);
    }
}
/// Runs the sub-agent returned by `build` and returns the system prompt its
/// provider received.
///
/// `build` is handed a fresh `CapturingProvider`. The sub-agent is executed
/// once with task `"do work"`, and the execution must succeed.
async fn capture_system_prompt(
    build: impl FnOnce(Arc<CapturingProvider>) -> SubAgentTool,
) -> String {
    let sink = Arc::new(std::sync::Mutex::new(String::new()));
    let tool = build(Arc::new(CapturingProvider {
        captured: Arc::clone(&sink),
    }));

    let ctx = ToolContext {
        tool_call_id: "tc-1".into(),
        tool_name: "sub".into(),
        cancel: CancellationToken::new(),
        on_update: None,
        on_progress: None,
    };
    tool.execute(serde_json::json!({"task": "do work"}), ctx)
        .await
        .expect("sub-agent should succeed");

    // Bind the guard to a local so it is dropped before `sink`.
    let guard = sink.lock().unwrap();
    guard.clone()
}
/// Skills attached with `with_skills` are indexed in the sub-agent's system
/// prompt, and the configured base prompt is kept as well.
#[tokio::test]
async fn test_sub_agent_with_skills() {
    let fixture = SkillsDir::with_one_skill(
        "with-skills",
        "research",
        "How to call the search and read APIs",
    );
    let skill_set = fixture.load();
    assert_eq!(skill_set.len(), 1, "expected the research skill to load");

    let prompt = capture_system_prompt(move |provider| {
        SubAgentTool::new("researcher", provider)
            .with_system_prompt("You are a research assistant.")
            .with_model("mock")
            .with_api_key("test")
            .with_skills(skill_set)
    })
    .await;

    let has_base = prompt.contains("You are a research assistant.");
    assert!(has_base, "base system prompt missing, got: {prompt}");
    let has_index =
        prompt.contains("<available_skills>") && prompt.contains("<name>research</name>");
    assert!(
        has_index,
        "skills index not injected into sub-agent system prompt, got: {prompt}"
    );
}
/// With no base system prompt, the sub-agent's system prompt must equal the
/// output of `SkillSet::format_for_prompt` exactly.
#[tokio::test]
async fn test_sub_agent_with_skills_empty_base_prompt() {
    let fixture = SkillsDir::with_one_skill("empty-base", "research", "desc");
    let skill_set = fixture.load();
    let index_text = skill_set.format_for_prompt();
    assert!(!index_text.is_empty());

    let prompt = capture_system_prompt(move |provider| {
        SubAgentTool::new("researcher", provider)
            .with_model("mock")
            .with_api_key("test")
            .with_skills(skill_set)
    })
    .await;

    assert_eq!(
        prompt, index_text,
        "with empty base prompt, the skills index should be the whole prompt verbatim"
    );
}
#[tokio::test]
async fn test_sub_agent_with_empty_skillset_is_noop() {
let prompt = capture_system_prompt(|provider| {
SubAgentTool::new("researcher", provider)
.with_system_prompt("Base prompt.")
.with_model("mock")
.with_api_key("test")
.with_skills(yoagent::skills::SkillSet::empty())
})
.await;
assert_eq!(prompt, "Base prompt.", "empty SkillSet should be a no-op");
}
/// When both skills and shared state are configured, the skills index must
/// appear earlier in the system prompt than the `## Shared State` block.
#[tokio::test]
async fn test_sub_agent_skills_before_shared_state() {
    let fixture = SkillsDir::with_one_skill("ordering", "research", "desc");
    let skill_set = fixture.load();
    let shared = SharedState::new();

    let prompt = capture_system_prompt(move |provider| {
        SubAgentTool::new("researcher", provider)
            .with_system_prompt("Base prompt.")
            .with_model("mock")
            .with_api_key("test")
            .with_skills(skill_set)
            .with_shared_state(shared)
    })
    .await;

    let (skills_at, shared_at) = (
        prompt
            .find("<available_skills>")
            .expect("skills index present"),
        prompt
            .find("## Shared State")
            .expect("shared-state block present"),
    );
    assert!(
        skills_at < shared_at,
        "skills index should precede the shared-state block, got: {prompt}"
    );
}
/// End-to-end: a parent `agent_loop` calls a sub-agent as a tool.
///
/// Checks:
/// - the message sequence is user → assistant → toolResult → assistant;
/// - the tool result carries the sub-agent's text;
/// - start and end tool-execution events are emitted for the sub-agent.
#[tokio::test]
async fn test_sub_agent_in_parent_loop() {
    let calculator = SubAgentTool::new("calculator", Arc::new(MockProvider::text("42 is the answer")))
        .with_description("Calculates things")
        .with_model("mock")
        .with_api_key("test");

    let config = make_config(MockProvider::new(vec![
        MockResponse::ToolCalls(vec![MockToolCall {
            provider_metadata: None,
            name: "calculator".into(),
            arguments: serde_json::json!({"task": "What is 6*7?"}),
        }]),
        MockResponse::Text("The calculator says: 42 is the answer".into()),
    ]));
    let mut context = AgentContext {
        system_prompt: "You are a coordinator.".into(),
        messages: Vec::new(),
        tools: vec![Box::new(calculator)],
    };

    let (tx, rx) = mpsc::unbounded_channel();
    let new_messages = agent_loop(
        vec![AgentMessage::Llm(Message::user("What is 6*7?"))],
        &mut context,
        &config,
        tx,
        CancellationToken::new(),
    )
    .await;
    let events = collect_events(rx);

    let roles: Vec<_> = new_messages.iter().map(|m| m.role()).collect();
    assert_eq!(roles, ["user", "assistant", "toolResult", "assistant"]);

    match &new_messages[2] {
        AgentMessage::Llm(Message::ToolResult { content, .. }) => match &content[0] {
            Content::Text { text } => assert_eq!(text.as_str(), "42 is the answer"),
            _ => panic!("Expected text content"),
        },
        _ => panic!("Expected tool result message"),
    }

    let started = events.iter().any(|e| {
        matches!(e, AgentEvent::ToolExecutionStart { tool_name, .. } if tool_name == "calculator")
    });
    let finished = events.iter().any(|e| {
        matches!(e, AgentEvent::ToolExecutionEnd { tool_name, .. } if tool_name == "calculator")
    });
    assert!(started);
    assert!(finished);
}