use appam::agent::AgentBuilder;
use appam::llm::{ChatMessage, Role, ToolSpec};
use appam::tools::Tool;
use serde_json::json;
use std::sync::{Arc, Mutex};
#[derive(Clone)]
#[allow(dead_code)]
struct MockCompletionTool {
name: String,
calls: Arc<Mutex<Vec<serde_json::Value>>>,
}
#[allow(dead_code)]
impl MockCompletionTool {
fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
calls: Arc::new(Mutex::new(Vec::new())),
}
}
fn call_count(&self) -> usize {
self.calls.lock().unwrap().len()
}
}
impl Tool for MockCompletionTool {
fn name(&self) -> &str {
&self.name
}
fn spec(&self) -> anyhow::Result<ToolSpec> {
Ok(serde_json::from_value(json!({
"type": "function",
"name": self.name,
"description": "Mock completion tool for testing",
"parameters": {
"type": "object",
"properties": {
"message": {
"type": "string",
"description": "Message to return"
}
},
"required": ["message"]
}
}))?)
}
fn execute(&self, args: serde_json::Value) -> anyhow::Result<serde_json::Value> {
self.calls.lock().unwrap().push(args.clone());
Ok(json!({
"success": true,
"message": args.get("message").and_then(|v| v.as_str()).unwrap_or("completed")
}))
}
}
#[test]
fn test_continuation_config_fields() {
let tool1 = Arc::new(MockCompletionTool::new("tool1"));
let tool2 = Arc::new(MockCompletionTool::new("tool2"));
let agent = AgentBuilder::new("test-agent")
.system_prompt("Test prompt")
.require_completion_tools(vec![tool1 as Arc<dyn Tool>, tool2 as Arc<dyn Tool>])
.continuation_message("Custom continuation message")
.max_continuations(3)
.build()
.unwrap();
assert_eq!(
agent.required_completion_tools(),
Some(&vec!["tool1".to_string(), "tool2".to_string()])
);
assert_eq!(agent.max_continuations(), 3);
assert_eq!(
agent.continuation_message(),
Some("Custom continuation message")
);
}
#[test]
fn test_no_continuation_when_disabled() {
let agent = AgentBuilder::new("test-agent")
.system_prompt("Test prompt")
.build()
.unwrap();
assert_eq!(agent.required_completion_tools(), None);
assert_eq!(agent.max_continuations(), 2); }
#[test]
fn test_default_max_continuations() {
let tool1 = Arc::new(MockCompletionTool::new("tool1"));
let agent = AgentBuilder::new("test-agent")
.system_prompt("Test prompt")
.require_completion_tools(vec![tool1 as Arc<dyn Tool>])
.build()
.unwrap();
assert_eq!(agent.max_continuations(), 2);
assert_eq!(agent.continuation_message(), None); }
fn has_continuation_message(messages: &[ChatMessage]) -> bool {
messages.iter().any(|msg| {
msg.role == Role::User
&& msg
.content
.as_ref()
.map(|c| c.contains("Continue analysis"))
.unwrap_or(false)
})
}
fn count_continuation_messages(messages: &[ChatMessage]) -> usize {
messages
.iter()
.filter(|msg| {
msg.role == Role::User
&& msg
.content
.as_ref()
.map(|c| c.contains("Continue analysis"))
.unwrap_or(false)
})
.count()
}
fn has_tool_call(messages: &[ChatMessage], tool_name: &str) -> bool {
messages.iter().any(|msg| {
if let Some(tool_calls) = &msg.tool_calls {
tool_calls.iter().any(|tc| tc.function.name == tool_name)
} else {
false
}
})
}
#[test]
fn test_continuation_message_detection() {
let messages = vec![
ChatMessage {
role: Role::User,
content: Some("Initial prompt".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
reasoning: None,
raw_content_blocks: None,
tool_metadata: None,
timestamp: None,
id: None,
provider_response_id: None,
status: None,
},
ChatMessage {
role: Role::User,
content: Some("Continue analysis. You should call `no_vulnerabilities_found` or `store_vulnerability_finding` to complete this analysis session.".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
reasoning: None,
raw_content_blocks: None,
tool_metadata: None,
timestamp: None,
id: None,
provider_response_id: None,
status: None,
},
];
assert!(has_continuation_message(&messages));
assert_eq!(count_continuation_messages(&messages), 1);
}
#[test]
fn test_tool_call_detection() {
use appam::llm::{ToolCall, ToolCallFunction};
let messages = vec![ChatMessage {
role: Role::Assistant,
content: None,
name: None,
tool_call_id: None,
tool_calls: Some(vec![ToolCall {
id: "call_1".to_string(),
type_field: "function".to_string(),
function: ToolCallFunction {
name: "no_vulnerabilities_found".to_string(),
arguments: "{}".to_string(),
},
}]),
reasoning: None,
raw_content_blocks: None,
tool_metadata: None,
timestamp: None,
id: None,
provider_response_id: None,
status: None,
}];
assert!(has_tool_call(&messages, "no_vulnerabilities_found"));
assert!(!has_tool_call(&messages, "store_vulnerability_finding"));
}
#[test]
fn test_multiple_continuation_messages() {
let messages = vec![
ChatMessage {
role: Role::User,
content: Some("Continue analysis. You should call `no_vulnerabilities_found` or `store_vulnerability_finding` to complete this analysis session.".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
reasoning: None,
raw_content_blocks: None,
tool_metadata: None,
timestamp: None,
id: None,
provider_response_id: None,
status: None,
},
ChatMessage {
role: Role::Assistant,
content: Some("I'm still analyzing...".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
reasoning: None,
raw_content_blocks: None,
tool_metadata: None,
timestamp: None,
id: None,
provider_response_id: None,
status: None,
},
ChatMessage {
role: Role::User,
content: Some("Continue analysis. You should call `no_vulnerabilities_found` or `store_vulnerability_finding` to complete this analysis session.".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
reasoning: None,
raw_content_blocks: None,
tool_metadata: None,
timestamp: None,
id: None,
provider_response_id: None,
status: None,
},
];
assert_eq!(count_continuation_messages(&messages), 2);
}