#![cfg(test)]
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use garudust_core::{
config::AgentConfig,
error::{AgentError, ToolError, TransportError},
memory::{MemoryCategory, MemoryContent, MemoryEntry, MemoryStore},
tool::{ApprovalDecision, CommandApprover, Tool, ToolContext},
transport::{ApiMode, ProviderTransport, StreamResult},
types::{
ContentPart, InferenceConfig, Message, Role, StopReason, TokenUsage, ToolCall, ToolResult,
ToolSchema, TransportResponse,
},
};
use garudust_tools::ToolRegistry;
use crate::{compressor::ContextCompressor, prompt_builder::build_system_prompt, Agent};
/// Test transport that answers every request with the same canned `reply`
/// and never asks for tool calls.
struct StaticTransport {
    /// Text emitted by both `chat` and `chat_stream`.
    reply: String,
}

#[async_trait]
impl ProviderTransport for StaticTransport {
    fn api_mode(&self) -> ApiMode {
        ApiMode::ChatCompletions
    }

    /// Non-streaming path: a single text part holding `reply`, no tool calls,
    /// default usage, and an `EndTurn` stop reason.
    async fn chat(
        &self,
        _messages: &[Message],
        _config: &InferenceConfig,
        _tools: &[ToolSchema],
    ) -> Result<TransportResponse, TransportError> {
        let text = ContentPart::Text(self.reply.clone());
        Ok(TransportResponse {
            stop_reason: StopReason::EndTurn,
            usage: TokenUsage::default(),
            tool_calls: Vec::new(),
            content: vec![text],
        })
    }

    /// Streaming path: the whole reply as one `TextDelta`, followed by `Done`.
    async fn chat_stream(
        &self,
        _messages: &[Message],
        _config: &InferenceConfig,
        _tools: &[ToolSchema],
    ) -> Result<StreamResult, TransportError> {
        use futures::stream;
        use garudust_core::types::StreamChunk;

        let delta = StreamChunk::TextDelta(self.reply.clone());
        let done = StreamChunk::Done {
            usage: TokenUsage::default(),
        };
        let chunks = vec![Ok(delta), Ok(done)];
        Ok(Box::pin(stream::iter(chunks)))
    }
}
/// Memory store that keeps nothing. Reads return empty (default) values, and
/// writes succeed without storing anything.
struct NopMemory;

#[async_trait]
impl MemoryStore for NopMemory {
    async fn read_memory(&self) -> Result<MemoryContent, AgentError> {
        Ok(Default::default())
    }

    async fn write_memory(&self, _content: &MemoryContent) -> Result<(), AgentError> {
        Ok(())
    }

    async fn read_user_profile(&self) -> Result<String, AgentError> {
        Ok(Default::default())
    }

    async fn write_user_profile(&self, _profile: &str) -> Result<(), AgentError> {
        Ok(())
    }
}
/// Command approver that answers `ApprovalDecision::Approved` to every
/// request, ignoring both string arguments.
struct AutoApprove;

#[async_trait]
impl CommandApprover for AutoApprove {
    async fn approve(&self, _: &str, _: &str) -> ApprovalDecision {
        // Unconditional: tests using this approver never exercise denial paths.
        ApprovalDecision::Approved
    }
}
/// Builds an agent backed by a [`StaticTransport`] that always replies with
/// `reply`, using the default [`AgentConfig`].
fn make_agent(reply: &str) -> Arc<Agent> {
    make_agent_with_config(reply, Arc::new(AgentConfig::default()))
}
/// Builds an agent wired to a [`StaticTransport`] returning `reply`, an empty
/// [`ToolRegistry`], a [`NopMemory`] store, and the caller-supplied `config`.
fn make_agent_with_config(reply: &str, config: Arc<AgentConfig>) -> Arc<Agent> {
    let agent = Agent::new(
        Arc::new(StaticTransport {
            reply: reply.to_owned(),
        }),
        Arc::new(ToolRegistry::new()),
        Arc::new(NopMemory),
        config,
    );
    Arc::new(agent)
}
/// A child spawned from an agent gets its own fresh iteration budget.
/// Spawning must not drain the parent, and consumption on the child must not
/// leak into the parent.
#[test]
fn spawn_child_has_independent_budget() {
    let config = AgentConfig {
        max_iterations: 5,
        ..AgentConfig::default()
    };
    let parent = make_agent_with_config("hi", Arc::new(config));

    // Spend one unit on the parent *before* spawning. A child that merely
    // copied the parent's counter would then start at 4 instead of 5.
    parent.consume_budget();
    let child = parent.spawn_child();
    assert_eq!(child.budget_remaining(), 5, "child starts with full budget");
    assert_eq!(
        parent.budget_remaining(),
        4,
        "parent budget unaffected by child creation"
    );

    child.consume_budget();
    // Confirm the child's consumption actually registered. Without this, the
    // parent check below would pass trivially even if consuming on a child
    // were a no-op.
    assert_eq!(
        child.budget_remaining(),
        4,
        "child budget decremented by its own consumption"
    );
    assert_eq!(
        parent.budget_remaining(),
        4,
        "parent unaffected by child consumption"
    );
}
/// A non-streaming run against a static transport finishes in one iteration
/// and surfaces the transport's reply at the start of the output.
#[tokio::test]
async fn run_returns_reply() {
    let approver = Arc::new(AutoApprove);
    let agent = make_agent("Hello, world!");

    let result = agent
        .run("say hi", approver, "test", None, None)
        .await
        .unwrap();

    assert!(
        result.output.starts_with("Hello, world!"),
        "unexpected output: {}",
        result.output
    );
    assert_eq!(result.iterations, 1);
}
/// A streaming run forwards text deltas over the supplied channel, and the
/// concatenated deltas equal the transport's reply.
#[tokio::test]
async fn run_streaming_emits_chunks() {
    let agent = make_agent("streamed response");
    let (chunk_tx, mut chunk_rx) = tokio::sync::mpsc::unbounded_channel::<String>();

    let result = agent
        .run_streaming(
            "say something",
            Arc::new(AutoApprove),
            "test",
            chunk_tx,
            None,
            None,
        )
        .await
        .unwrap();

    // Drain everything buffered so far. The sender was moved into
    // `run_streaming` above.
    let streamed: String = std::iter::from_fn(|| chunk_rx.try_recv().ok()).collect();
    assert_eq!(streamed, "streamed response");

    assert!(
        result.output.starts_with("streamed response"),
        "unexpected output: {}",
        result.output
    );
}
/// Test transport that replays a fixed queue of responses, one per `chat`
/// call, and counts how many calls it received.
///
/// Once the queue is exhausted, every further call yields an empty-text
/// `EndTurn` response. Streaming is not supported.
struct ScriptedTransport {
    /// Responses still to be returned, front first.
    responses: Mutex<std::collections::VecDeque<TransportResponse>>,
    /// Number of `chat` calls made so far. Shared with the test via `new`.
    calls: Arc<std::sync::atomic::AtomicUsize>,
}

impl ScriptedTransport {
    /// Creates a transport that returns `responses` in order. Also returns a
    /// handle to its call counter, so tests can assert how many LLM calls
    /// were made.
    fn new(responses: Vec<TransportResponse>) -> (Arc<Self>, Arc<std::sync::atomic::AtomicUsize>) {
        let calls = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let transport = Arc::new(Self {
            responses: Mutex::new(responses.into()),
            calls: Arc::clone(&calls),
        });
        (transport, calls)
    }

    /// Records one call and pops the next scripted response. Once the script
    /// has run out, it falls back to an empty `EndTurn` reply.
    fn next(&self) -> TransportResponse {
        self.calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        self.responses
            .lock()
            .expect("scripted response queue mutex poisoned")
            .pop_front()
            // Build the fallback lazily: it is only needed after the script is
            // exhausted, so don't allocate it on every call.
            .unwrap_or_else(|| TransportResponse {
                content: vec![ContentPart::Text(String::new())],
                tool_calls: vec![],
                usage: TokenUsage::default(),
                stop_reason: StopReason::EndTurn,
            })
    }
}

#[async_trait]
impl ProviderTransport for ScriptedTransport {
    fn api_mode(&self) -> ApiMode {
        ApiMode::ChatCompletions
    }

    /// Ignores its inputs and returns the next scripted response.
    async fn chat(
        &self,
        _messages: &[Message],
        _config: &InferenceConfig,
        _tools: &[ToolSchema],
    ) -> Result<TransportResponse, TransportError> {
        Ok(self.next())
    }

    /// Not supported; panics if called.
    async fn chat_stream(
        &self,
        _messages: &[Message],
        _config: &InferenceConfig,
        _tools: &[ToolSchema],
    ) -> Result<StreamResult, TransportError> {
        unimplemented!("scripted transport is non-streaming")
    }
}
/// Tool named `echo` that logs every parameter object it is invoked with and
/// replies `echoed: <text>`.
struct RecordingTool {
    /// Shared log of the raw `params` values, one entry per `execute` call.
    calls: Arc<Mutex<Vec<serde_json::Value>>>,
}

#[async_trait]
impl Tool for RecordingTool {
    fn name(&self) -> &'static str {
        "echo"
    }

    fn description(&self) -> &'static str {
        "Echo back the provided text"
    }

    /// JSON schema: an object with a required string property `text`.
    fn schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": { "text": { "type": "string" } },
            "required": ["text"]
        })
    }

    fn toolset(&self) -> &'static str {
        "test"
    }

    async fn execute(
        &self,
        params: serde_json::Value,
        _ctx: &ToolContext,
    ) -> Result<ToolResult, ToolError> {
        // A missing or non-string `text` degrades to an empty string.
        let text = match params.get("text") {
            Some(serde_json::Value::String(s)) => s.clone(),
            _ => String::new(),
        };
        self.calls.lock().unwrap().push(params);
        Ok(ToolResult::ok("", format!("echoed: {text}")))
    }
}
/// With a context limit of 100, a tiny conversation does not trigger
/// compression, but a single 2,000-character message does.
#[tokio::test]
async fn compressor_should_compress_respects_threshold() {
    let (transport, _) = ScriptedTransport::new(Vec::new());
    let compressor = ContextCompressor::new(transport, "m".into()).with_context_limit(100);

    let small = vec![Message::user("short")];
    assert!(!compressor.should_compress(&small));

    let big = vec![Message::user("x".repeat(2_000))];
    assert!(compressor.should_compress(&big));
}
#[tokio::test]
async fn compressor_short_conversation_left_unchanged() {
let (transport, calls) = ScriptedTransport::new(vec![]);
let compressor = ContextCompressor::new(transport, "m".into());
let mut msgs = vec![Message::system("sys")];
for i in 0..8 {
msgs.push(Message::user(format!("u{i}")));
}
let original_len = msgs.len();
let (out, _usage) = compressor.compress(msgs).await.unwrap();
assert_eq!(
out.len(),
original_len,
"short conversation must pass through"
);
assert_eq!(
calls.load(std::sync::atomic::Ordering::SeqCst),
0,
"no LLM call on the short-circuit path"
);
}
/// A long conversation is condensed with exactly one summarize call.
///
/// The output keeps the system prompt and the first task, adds the summary,
/// and keeps the last 12 turns. The middle turns are replaced by the summary.
#[tokio::test]
async fn compressor_long_conversation_summarized() {
    let summary = TransportResponse {
        content: vec![ContentPart::Text("CONDENSED".into())],
        tool_calls: vec![],
        usage: TokenUsage::default(),
        stop_reason: StopReason::EndTurn,
    };
    let (transport, calls) = ScriptedTransport::new(vec![summary]);
    let compressor = ContextCompressor::new(transport, "m".into());

    // Layout: system, first task, 18 middle assistant turns, 12 tail user turns.
    let mut msgs = vec![Message::system("SYSTEM"), Message::user("FIRST_TASK")];
    msgs.extend((0..18).map(|i| Message::assistant(format!("middle{i}"))));
    msgs.extend((0..12).map(|i| Message::user(format!("tail{i}"))));

    let (out, _usage) = compressor.compress(msgs).await.unwrap();

    assert_eq!(
        calls.load(std::sync::atomic::Ordering::SeqCst),
        1,
        "exactly one summarize call"
    );
    // system + first task + summary + 12 tail turns
    assert_eq!(out.len(), 1 + 1 + 1 + 12);
    assert_eq!(out[0].role, Role::System);

    // First text part of each message (messages without text are skipped).
    let texts: Vec<&str> = out
        .iter()
        .filter_map(|msg| {
            msg.content.iter().find_map(|part| {
                if let ContentPart::Text(t) = part {
                    Some(t.as_str())
                } else {
                    None
                }
            })
        })
        .collect();

    assert!(texts.contains(&"SYSTEM"));
    assert!(texts.iter().any(|t| t.contains("FIRST_TASK")));
    assert!(texts.iter().any(|t| t.contains("CONDENSED")));
    assert!(texts.iter().any(|t| t.contains("tail11")));
    assert!(
        !texts.iter().any(|t| t.contains("middle9")),
        "middle turns must be replaced by the summary"
    );
}
/// The identity line is always rendered. The memory and user-profile sections
/// appear only when supplied, and sections are joined by a `---` divider.
#[tokio::test]
async fn system_prompt_contains_identity_and_optional_sections() {
    let dir_name = format!("garudust-prompt-test-{}", std::process::id());
    let config = AgentConfig {
        home_dir: std::env::temp_dir().join(dir_name),
        ..AgentConfig::default()
    };

    // No memory, no profile: neither optional section may leak in.
    let bare = build_system_prompt(&config, None, None, "cli").await;
    assert!(
        bare.contains("You are Garudust"),
        "identity must always be present"
    );
    assert!(!bare.contains("user prefers tabs"));
    assert!(!bare.contains("Alice, engineer"));

    // With both supplied, each gets its own headed section.
    let mem = MemoryContent {
        entries: vec![MemoryEntry::new(MemoryCategory::Fact, "user prefers tabs".into())],
    };
    let full = build_system_prompt(&config, Some(&mem), Some("Alice, engineer"), "cli").await;
    assert!(full.contains("# Memory"));
    assert!(full.contains("user prefers tabs"));
    assert!(full.contains("# User Profile"));
    assert!(full.contains("Alice, engineer"));
    assert!(
        full.contains("\n\n---\n\n"),
        "sections joined by hr divider"
    );
}
/// End-to-end agent loop: the first model turn requests the `echo` tool, the
/// agent dispatches it once, and the second turn ends the run with "done".
#[tokio::test]
async fn run_loop_executes_tool_then_finishes() {
    // Turn 1: ask for the echo tool. Turn 2: plain text, end of turn.
    let tool_turn = TransportResponse {
        content: vec![],
        tool_calls: vec![ToolCall {
            id: "call_1".into(),
            name: "echo".into(),
            arguments: serde_json::json!({ "text": "hi" }),
        }],
        usage: TokenUsage::default(),
        stop_reason: StopReason::ToolUse,
    };
    let final_turn = TransportResponse {
        content: vec![ContentPart::Text("done".into())],
        tool_calls: vec![],
        usage: TokenUsage::default(),
        stop_reason: StopReason::EndTurn,
    };
    let (transport, _) = ScriptedTransport::new(vec![tool_turn, final_turn]);

    // Shared log the tool appends to, so we can inspect dispatches afterwards.
    let calls = Arc::new(Mutex::new(Vec::new()));
    let mut registry = ToolRegistry::new();
    registry.register(RecordingTool {
        calls: Arc::clone(&calls),
    });

    let agent = Arc::new(Agent::new(
        transport,
        Arc::new(registry),
        Arc::new(NopMemory),
        Arc::new(AgentConfig::default()),
    ));
    let result = agent
        .run("use the echo tool", Arc::new(AutoApprove), "test", None, None)
        .await
        .unwrap();

    assert_eq!(result.iterations, 2, "one tool turn + one completion turn");
    assert!(result.output.starts_with("done"), "got: {}", result.output);

    let recorded = calls.lock().unwrap();
    assert_eq!(
        recorded.len(),
        1,
        "echo tool must be dispatched exactly once"
    );
    assert_eq!(recorded[0]["text"], "hi");
}