use adk_agent::LlmAgentBuilder;
use adk_core::{
Agent, CallbackContext, Content, FinishReason, InvocationContext, Llm, LlmRequest, LlmResponse,
LlmResponseStream, Part, Result, RunConfig, Session, State, Tool, ToolContext,
};
use async_trait::async_trait;
use futures::StreamExt;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::Arc;
struct MockModel {
response: LlmResponse,
}
impl MockModel {
fn new_function_call(name: &str, args: Value) -> Self {
let content = Content {
role: "model".to_string(),
parts: vec![Part::FunctionCall {
name: name.to_string(),
args,
id: Some(format!("call_{}", name)),
thought_signature: None,
}],
};
Self {
response: LlmResponse {
content: Some(content),
usage_metadata: None,
finish_reason: Some(FinishReason::Stop),
citation_metadata: None,
partial: false,
turn_complete: true,
interrupted: false,
error_code: None,
error_message: None,
provider_metadata: None,
},
}
}
}
#[async_trait]
impl Llm for MockModel {
fn name(&self) -> &str {
"mock-model"
}
async fn generate_content(&self, _req: LlmRequest, _stream: bool) -> Result<LlmResponseStream> {
let response = self.response.clone();
let s = async_stream::stream! {
yield Ok(response);
};
Ok(Box::pin(s))
}
}
struct MockTool;
#[async_trait]
impl Tool for MockTool {
fn name(&self) -> &str {
"test_tool"
}
fn description(&self) -> &str {
"Test tool"
}
fn parameters_schema(&self) -> Option<Value> {
None
}
fn response_schema(&self) -> Option<Value> {
None
}
async fn execute(&self, _ctx: Arc<dyn ToolContext>, _args: Value) -> Result<Value> {
Ok(json!({ "status": "ok" }))
}
}
struct MockSession;
impl Session for MockSession {
fn id(&self) -> &str {
"session-1"
}
fn app_name(&self) -> &str {
"test-app"
}
fn user_id(&self) -> &str {
"user-1"
}
fn state(&self) -> &dyn State {
&MockState
}
fn conversation_history(&self) -> Vec<Content> {
Vec::new()
}
}
struct MockState;
impl State for MockState {
fn get(&self, _key: &str) -> Option<Value> {
None
}
fn set(&mut self, _key: String, _value: Value) {}
fn all(&self) -> HashMap<String, Value> {
HashMap::new()
}
}
struct MockContext {
session: MockSession,
user_content: Content,
}
impl MockContext {
fn new() -> Self {
Self {
session: MockSession,
user_content: Content {
role: "user".to_string(),
parts: vec![Part::Text { text: "start".to_string() }],
},
}
}
}
#[async_trait]
impl adk_core::ReadonlyContext for MockContext {
fn invocation_id(&self) -> &str {
"inv-1"
}
fn agent_name(&self) -> &str {
"test-agent"
}
fn user_id(&self) -> &str {
"user-1"
}
fn app_name(&self) -> &str {
"test-app"
}
fn session_id(&self) -> &str {
"session-1"
}
fn branch(&self) -> &str {
"main"
}
fn user_content(&self) -> &Content {
&self.user_content
}
}
#[async_trait]
impl adk_core::CallbackContext for MockContext {
fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
None
}
}
#[async_trait]
impl InvocationContext for MockContext {
fn agent(&self) -> Arc<dyn Agent> {
unimplemented!()
}
fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
None
}
fn session(&self) -> &dyn Session {
&self.session
}
fn run_config(&self) -> &RunConfig {
static RUN_CONFIG: std::sync::OnceLock<RunConfig> = std::sync::OnceLock::new();
RUN_CONFIG.get_or_init(RunConfig::default)
}
fn end_invocation(&self) {}
fn ended(&self) -> bool {
false
}
}
#[tokio::test]
async fn test_callbacks_execution_order() {
let tool = Arc::new(MockTool);
let model = Arc::new(MockModel::new_function_call("test_tool", json!({})));
let agent = LlmAgentBuilder::new("test-agent").model(model).tool(tool).build().unwrap();
let ctx = Arc::new(MockContext::new());
let mut stream = agent.run(ctx).await.unwrap();
let mut event_count = 0;
while let Some(result) = stream.next().await {
match result {
Ok(_) => event_count += 1,
Err(e) => {
println!("Agent error: {:?}", e);
break;
}
}
}
assert!(event_count > 0);
}
#[tokio::test]
async fn test_callback_short_circuit() {
let model = Arc::new(MockModel::new_function_call("test_tool", json!({})));
let short_circuit_callback = |_ctx: Arc<dyn CallbackContext>| {
Box::pin(async move {
Ok(Some(Content {
role: "assistant".to_string(),
parts: vec![Part::Text { text: "Short-circuited!".to_string() }],
}))
})
as std::pin::Pin<Box<dyn std::future::Future<Output = Result<Option<Content>>> + Send>>
};
let agent = LlmAgentBuilder::new("test-agent")
.model(model)
.before_callback(Box::new(short_circuit_callback))
.build()
.unwrap();
let ctx = Arc::new(MockContext::new());
let mut stream = agent.run(ctx).await.unwrap();
let mut found_short_circuit = false;
while let Some(result) = stream.next().await {
let event = result.unwrap();
if let Some(content) = event.llm_response.content {
if let Some(Part::Text { text }) = content.parts.first() {
if text.contains("Short-circuited") {
found_short_circuit = true;
}
}
}
}
assert!(found_short_circuit, "Callback should have short-circuited execution");
}