use std::sync::Arc;
use async_trait::async_trait;
use cognis_core::schemars::{self, JsonSchema};
use serde::Deserialize;
use tokio::sync::Mutex;
use cognis_core::{CognisError, Result};
use cognis_llm::tools::{Tool, ToolInput, ToolOutput};
use crate::agent::Agent;
// Deserialised arguments for `SubAgentTool`.
// NOTE: the `///` doc comments on this type are emitted by `schemars` as
// `description` fields in the JSON schema returned from `args_schema`, i.e.
// they are shown to the calling model. Keep them model-facing and put
// developer notes in plain `//` comments like this one.
/// Input for delegating a task to a sub-agent.
#[derive(Debug, Deserialize, JsonSchema)]
struct SubAgentInput {
    /// A self-contained description of the task for the sub-agent to perform.
    /// The sub-agent receives only this text, so include any context it needs.
    task: String,
}
/// A `Tool` that delegates a task to a nested `Agent`.
///
/// Each invocation forwards the `task` argument to the wrapped agent's `run`
/// method and returns the content of the agent's response as text.
pub struct SubAgentTool {
/// Tool name returned by `Tool::name`.
name: String,
/// Tool description returned by `Tool::description`.
description: String,
/// The wrapped agent, shared behind an async (tokio) mutex. The guard is held
/// for the whole `run` call, so concurrent invocations of this tool are
/// serialised. `run` is called through a mutable guard, so it presumably
/// needs `&mut Agent` — confirm against `Agent::run`.
agent: Arc<Mutex<Agent>>,
}
impl SubAgentTool {
    /// Builds a sub-agent tool.
    ///
    /// * `name` – identifier reported by `Tool::name`.
    /// * `description` – text reported by `Tool::description`.
    /// * `agent` – the agent that will handle every delegated task; it is moved
    ///   into an `Arc<Mutex<_>>` owned by the returned tool.
    pub fn new(name: impl Into<String>, description: impl Into<String>, agent: Agent) -> Self {
        let name = name.into();
        let description = description.into();
        let agent = Arc::new(Mutex::new(agent));
        Self {
            name,
            description,
            agent,
        }
    }
}
#[async_trait]
impl Tool for SubAgentTool {
    fn name(&self) -> &str {
        &self.name
    }

    fn description(&self) -> &str {
        &self.description
    }

    /// JSON schema for the tool arguments, generated from `SubAgentInput`.
    ///
    /// Returns `None` (instead of `Some(Value::Null)`) if the schema cannot be
    /// converted to JSON, so a `null` schema is never advertised to the model.
    fn args_schema(&self) -> Option<serde_json::Value> {
        serde_json::to_value(schemars::schema_for!(SubAgentInput)).ok()
    }

    /// Parses the `task` argument from `input`, runs the wrapped agent on it and
    /// returns the agent's response content as `ToolOutput::Text`.
    ///
    /// # Errors
    ///
    /// - `CognisError::ToolValidationError` if `input` does not deserialise into
    ///   `SubAgentInput`, or if the task is empty or whitespace-only.
    /// - Any error returned by the wrapped agent's `run`, propagated unchanged.
    async fn _run(&self, input: ToolInput) -> Result<ToolOutput> {
        let parsed: SubAgentInput = serde_json::from_value(input.into_json()).map_err(|e| {
            CognisError::ToolValidationError(format!("subagent `{}`: {e}", self.name))
        })?;

        // A blank task cannot describe any work; reject it here rather than
        // forwarding it to a full agent run.
        if parsed.task.trim().is_empty() {
            return Err(CognisError::ToolValidationError(format!(
                "subagent `{}`: `task` must not be empty",
                self.name
            )));
        }

        // The guard is held across the entire run, so concurrent invocations of
        // this tool execute one at a time against the same agent.
        let mut agent = self.agent.lock().await;
        let resp = agent.run(parsed.task).await?;
        Ok(ToolOutput::Text(resp.content))
    }
}
#[cfg(test)]
mod tests {
// Unit tests for `SubAgentTool`, driven by an in-process stub LLM provider so
// no network access is needed.
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex as StdMutex;
use async_trait::async_trait;
use serde_json::json;
use cognis_core::{Message, Result, RunnableStream};
use cognis_llm::chat::{ChatOptions, ChatResponse, HealthStatus, StreamChunk, Usage};
use cognis_llm::provider::{LLMProvider, Provider};
use cognis_llm::Client;
use crate::agent::default_react_graph;
// Stub provider: every `chat_completion` replies with `content` as an AI
// message, increments `calls`, and records the messages it received in `seen`.
struct Constant {
content: String,
calls: AtomicUsize,
seen: StdMutex<Vec<Vec<Message>>>,
}
#[async_trait]
impl LLMProvider for Constant {
fn name(&self) -> &str {
"constant"
}
fn provider_type(&self) -> Provider {
// Arbitrary variant; the stub does not talk to any real backend.
Provider::Ollama
}
async fn chat_completion(
&self,
messages: Vec<Message>,
_opts: ChatOptions,
) -> Result<ChatResponse> {
self.calls.fetch_add(1, Ordering::SeqCst);
self.seen.lock().unwrap().push(messages);
Ok(ChatResponse {
message: Message::ai(&self.content),
usage: Some(Usage::default()),
finish_reason: "stop".into(),
model: "constant".into(),
})
}
async fn chat_completion_stream(
&self,
_: Vec<Message>,
_: ChatOptions,
) -> Result<RunnableStream<StreamChunk>> {
// Streaming is not exercised by these tests; panics if called.
unimplemented!()
}
async fn health_check(&self) -> Result<HealthStatus> {
Ok(HealthStatus::Healthy { latency_ms: 0 })
}
}
// Wraps a ReAct agent built on the stub provider in `SubAgentTool`, invokes
// the tool with a structured `task`, and checks three things:
// - the provider's reply is returned as the tool output;
// - the provider was called exactly once;
// - the task text reached the provider inside a human message.
#[tokio::test]
async fn delegates_task_to_inner_agent() {
let provider = Arc::new(Constant {
content: "sub-agent result".into(),
calls: AtomicUsize::new(0),
seen: StdMutex::new(Vec::new()),
});
let client = Client::new(provider.clone());
// NOTE(review): `Vec::new()` is presumably the (empty) tool list and `5` an
// iteration cap — confirm against `default_react_graph`.
let graph = default_react_graph(client, Vec::new(), 5).unwrap();
let inner = Agent::wrap(graph);
let tool = SubAgentTool::new("research", "Researches a topic", inner);
let mut a = std::collections::HashMap::new();
a.insert("task".into(), json!("look up rust"));
let out = tool._run(ToolInput::Structured(a)).await.unwrap();
assert_eq!(out.as_string(), "sub-agent result");
// Exactly one LLM round-trip, presumably because the stub's reply contains
// no tool calls and the agent stops after it.
assert_eq!(provider.calls.load(Ordering::SeqCst), 1);
let received = provider.seen.lock().unwrap();
// Indexing `[0]` is safe: the assertion above guarantees one recorded call.
assert!(received[0]
.iter()
.any(|m| matches!(m, Message::Human(h) if h.content.contains("look up rust"))));
}
}