#![allow(clippy::expect_used, clippy::unwrap_used)]
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use capo_agent::tools::{ToolCtx, ToolProgressChunk};
use capo_agent::{AppBuilder, Config, UiEvent, UserMessage};
use futures::StreamExt;
use motosan_agent_loop::{ChatOutput, LlmClient, LlmResponse, Message};
use motosan_agent_tool::{Tool, ToolContext, ToolDef, ToolResult};
use tempfile::tempdir;
/// Test-only tool that pushes two progress chunks through the `ToolCtx`
/// progress channel (a status line, then a stdout fragment) and then
/// resolves to the text result `"done"`.
struct ChattyTool {
    /// Context handed over by `build_with_custom_tools`; only its
    /// `progress_tx` sender is used.
    ctx: Arc<ToolCtx>,
}

impl Tool for ChattyTool {
    fn def(&self) -> ToolDef {
        let input_schema = serde_json::json!({"type":"object"});
        ToolDef {
            name: "chatty".into(),
            description: "emits progress chunks".into(),
            input_schema,
        }
    }

    /// Sends the scripted progress chunks in order, then returns `"done"`.
    ///
    /// # Panics
    /// Panics with `"progress send N"` if sending chunk N on the progress
    /// channel returns an error, so a broken channel fails the test loudly.
    fn call(
        &self,
        _args: serde_json::Value,
        _ctx: &ToolContext,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ToolResult> + Send + '_>> {
        // Clone the sender up front; the async block takes ownership of it.
        let sender = self.ctx.progress_tx.clone();
        Box::pin(async move {
            // Each chunk is paired with the panic message used if its send fails.
            let script = [
                (ToolProgressChunk::Status("starting".into()), "progress send 1"),
                (ToolProgressChunk::Stdout(b"hello\n".to_vec()), "progress send 2"),
            ];
            for (chunk, failure_msg) in script {
                sender.send(chunk).await.expect(failure_msg);
            }
            ToolResult::text("done")
        })
    }
}
/// Deterministic `LlmClient` stand-in for the test. The first `chat` call
/// asks for the `chatty` tool. Every later call replies with the plain
/// message `"ok"`.
struct ScriptedLlm {
    /// Number of `chat` calls made so far; acts as the script cursor.
    turn: AtomicUsize,
}

#[async_trait]
impl LlmClient for ScriptedLlm {
    async fn chat(
        &self,
        _messages: &[Message],
        _tools: &[ToolDef],
    ) -> motosan_agent_loop::Result<ChatOutput> {
        // `fetch_add` yields the value *before* the increment, so only the
        // very first call observes 0.
        let is_first_turn = self.turn.fetch_add(1, Ordering::SeqCst) == 0;
        let response = if is_first_turn {
            let call = motosan_agent_loop::ToolCallItem {
                id: "c1".into(),
                name: "chatty".into(),
                args: serde_json::json!({}),
            };
            LlmResponse::ToolCalls(vec![call])
        } else {
            LlmResponse::Message("ok".into())
        };
        Ok(ChatOutput::new(response))
    }
}
/// End-to-end check that progress chunks pushed by a tool show up as
/// `UiEvent::ToolCallProgress` items in the stream returned by
/// `send_user_message`.
#[tokio::test]
async fn progress_chunks_flow_into_ui_stream() {
    // The temp dir serves as the app's cwd and stays alive until the test ends.
    let workdir = tempdir().unwrap();

    // Placeholder key; the scripted LLM below never contacts a real provider.
    let mut config = Config::default();
    config.anthropic.api_key = Some("sk-unused".into());

    let llm = Arc::new(ScriptedLlm {
        turn: AtomicUsize::new(0),
    });
    let app = AppBuilder::new()
        .with_config(config)
        .with_cwd(workdir.path())
        .with_llm(llm)
        .build_with_custom_tools(|tool_ctx| {
            let chatty: Arc<dyn Tool> = Arc::new(ChattyTool {
                ctx: Arc::new(tool_ctx),
            });
            vec![chatty]
        })
        .await
        .expect("build");

    // Drain the whole turn so every emitted UI event can be inspected.
    let events = app
        .send_user_message(UserMessage::text("go"))
        .collect::<Vec<UiEvent>>()
        .await;

    let progress_count = events
        .iter()
        .filter(|event| matches!(event, UiEvent::ToolCallProgress { .. }))
        .count();
    // ChattyTool sends two chunks; at least that many progress events must surface.
    assert!(
        progress_count >= 2,
        "expected >=2 ToolCallProgress events, got {progress_count} in {events:?}"
    );
}