use crate::loop_::AgentEvent;
use crate::stream::{AssistantMessageEvent, StreamErrorKind, StreamFn, StreamOptions};
use crate::tool::permissive_object_schema;
use crate::tool::{AgentTool, AgentToolResult, ToolFuture};
use crate::types::{AgentContext, ModelSpec};
use crate::types::{
AgentMessage, AssistantMessage, ContentBlock, Cost, LlmMessage, StopReason, ToolResultMessage,
Usage, UserMessage,
};
use futures::Stream;
use serde_json::Value;
use std::pin::Pin;
use std::process::Command;
#[cfg(feature = "plugins")]
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::Duration;
use tokio_util::sync::CancellationToken;
/// Operating-system family that a test can be restricted to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TestOs {
    /// Apple macOS (`std::env::consts::OS == "macos"`).
    MacOs,
    /// Linux (`std::env::consts::OS == "linux"`).
    Linux,
    /// Microsoft Windows (`std::env::consts::OS == "windows"`).
    Windows,
    /// Any other value of `std::env::consts::OS`.
    Other,
}
impl TestOs {
    /// Classifies the OS this binary was compiled for.
    ///
    /// Any `std::env::consts::OS` value other than `macos`, `linux`, or `windows`
    /// maps to [`TestOs::Other`].
    #[must_use]
    pub fn current() -> Self {
        const KNOWN_FAMILIES: [(&str, TestOs); 3] = [
            ("macos", TestOs::MacOs),
            ("linux", TestOs::Linux),
            ("windows", TestOs::Windows),
        ];
        KNOWN_FAMILIES
            .iter()
            .find(|(os_name, _)| *os_name == std::env::consts::OS)
            .map_or(Self::Other, |&(_, family)| family)
    }
    /// Human-readable name used in skip/error messages.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Other => "other",
            Self::Windows => "Windows",
            Self::Linux => "Linux",
            Self::MacOs => "macOS",
        }
    }
}
/// GPU capability that a test may demand from the host.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TestGpu {
    /// No GPU requirement.
    None,
    /// Any GPU that the runtime probes detected.
    Any,
    /// An NVIDIA GPU.
    Nvidia,
    /// An Apple Metal-capable GPU.
    AppleMetal,
}
/// Host requirements declared by a test.
///
/// Start from [`TestRuntimeRequirements::new`] (no constraints) and narrow with the
/// `with_*` builders.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TestRuntimeRequirements {
    /// Required OS family; `None` accepts any OS.
    pub os: Option<TestOs>,
    /// Required GPU capability.
    pub gpu: TestGpu,
}
impl TestRuntimeRequirements {
    /// Requirements that every host satisfies: any OS, no GPU.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            gpu: TestGpu::None,
            os: None,
        }
    }
    /// Restricts the test to hosts running `os`.
    #[must_use]
    pub const fn with_os(self, os: TestOs) -> Self {
        Self {
            os: Some(os),
            ..self
        }
    }
    /// Requires the given GPU capability.
    #[must_use]
    pub const fn with_gpu(self, gpu: TestGpu) -> Self {
        Self { gpu, ..self }
    }
}
impl Default for TestRuntimeRequirements {
    /// Same as [`TestRuntimeRequirements::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Snapshot of the host that the test binary is running on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TestRuntime {
    /// Detected OS family.
    pub os: TestOs,
    /// Architecture string taken from `std::env::consts::ARCH`.
    pub arch: &'static str,
    /// True when the NVIDIA probe, the Apple Metal probe, or the OS-specific
    /// display-adapter probe found a GPU.
    pub has_any_gpu: bool,
    /// True when an NVIDIA GPU was detected.
    pub has_nvidia_gpu: bool,
    /// True when an Apple Metal-capable GPU was detected.
    pub has_apple_metal_gpu: bool,
}
/// Returns the host snapshot.
///
/// The host is probed on the first call, and the result is cached for the rest of the process.
#[must_use]
pub fn test_runtime() -> &'static TestRuntime {
    static CACHED_RUNTIME: OnceLock<TestRuntime> = OnceLock::new();
    CACHED_RUNTIME.get_or_init(detect_test_runtime)
}
/// Explains why a test with `requirements` must be skipped on this host.
///
/// Returns `None` when the host satisfies every requirement.
#[must_use]
pub fn test_runtime_skip_reason(requirements: TestRuntimeRequirements) -> Option<String> {
    let host = test_runtime();
    evaluate_test_runtime(host, requirements).err()
}
/// Returns whether a test with `requirements` can run on this host.
///
/// When it cannot, prints `skipping: <reason>` to stderr before returning `false`.
#[must_use]
pub fn should_run_test(requirements: TestRuntimeRequirements) -> bool {
    let skip_reason = test_runtime_skip_reason(requirements);
    if let Some(reason) = &skip_reason {
        eprintln!("skipping: {reason}");
    }
    skip_reason.is_none()
}
/// Checks `runtime` against `requirements`.
///
/// The OS requirement is checked before the GPU requirement, so an OS mismatch is the
/// reason reported when both fail.
///
/// # Errors
/// Returns a human-readable reason for the first unmet requirement.
fn evaluate_test_runtime(
    runtime: &TestRuntime,
    requirements: TestRuntimeRequirements,
) -> Result<(), String> {
    if let Some(expected_os) = requirements.os.filter(|&wanted| wanted != runtime.os) {
        return Err(format!(
            "requires {}, detected {}",
            expected_os.as_str(),
            runtime.os.as_str()
        ));
    }
    let (satisfied, missing_reason) = match requirements.gpu {
        TestGpu::None => return Ok(()),
        TestGpu::Any => (runtime.has_any_gpu, "requires a detected GPU on the host"),
        TestGpu::Nvidia => (runtime.has_nvidia_gpu, "requires an NVIDIA GPU on the host"),
        TestGpu::AppleMetal => (
            runtime.has_apple_metal_gpu,
            "requires an Apple Metal-capable GPU on the host",
        ),
    };
    if satisfied {
        Ok(())
    } else {
        Err(missing_reason.to_string())
    }
}
/// Probes the host for its OS, architecture, and GPUs.
///
/// Every probe shells out to a host tool (`nvidia-smi`, `system_profiler`, `lspci`,
/// PowerShell). A missing or failing tool counts as "not detected". On macOS, the
/// `system_profiler SPDisplaysDataType` report is collected once and shared by the Apple
/// Metal check and the generic GPU check, so the command is not run twice.
fn detect_test_runtime() -> TestRuntime {
    let os = TestOs::current();
    let arch = std::env::consts::ARCH;
    // Only macOS has `system_profiler`; capture its display report once for both checks.
    let macos_displays = if os == TestOs::MacOs {
        command_stdout("system_profiler", &["SPDisplaysDataType"])
    } else {
        None
    };
    let has_nvidia_gpu = detect_nvidia_gpu();
    let has_apple_metal_gpu = detect_apple_metal_gpu(macos_displays.as_deref());
    let has_any_gpu = has_nvidia_gpu
        || has_apple_metal_gpu
        || detect_generic_gpu(os, macos_displays.as_deref());
    TestRuntime {
        os,
        arch,
        has_any_gpu,
        has_nvidia_gpu,
        has_apple_metal_gpu,
    }
}
/// Reports whether `nvidia-smi -L` succeeds and prints at least one non-blank line.
fn detect_nvidia_gpu() -> bool {
    command_stdout("nvidia-smi", &["-L"])
        .is_some_and(|stdout| stdout.lines().any(|line| !line.trim().is_empty()))
}
/// Reports whether a macOS display report mentions Metal support or an Apple chipset.
///
/// `macos_displays` is the `system_profiler SPDisplaysDataType` output. `None` (a
/// non-macOS host, or the command failed) means no Metal GPU.
fn detect_apple_metal_gpu(macos_displays: Option<&str>) -> bool {
    macos_displays.is_some_and(|report| {
        report.contains("Metal Support:")
            || report.contains("Metal Family:")
            || report.contains("Chipset Model: Apple")
    })
}
/// OS-specific fallback probe for any display adapter.
///
/// On macOS it reuses the `system_profiler` report passed in as `macos_displays`.
fn detect_generic_gpu(os: TestOs, macos_displays: Option<&str>) -> bool {
    match os {
        TestOs::MacOs => macos_displays.is_some_and(|report| report.contains("Chipset Model:")),
        // NOTE(review): server BMC or virtual VGA adapters also report these PCI classes, so
        // this may be true on hosts without a usable GPU. Confirm whether that matters.
        TestOs::Linux => command_stdout("lspci", &[]).is_some_and(|stdout| {
            let lower = stdout.to_ascii_lowercase();
            lower.contains("vga compatible controller")
                || lower.contains("3d controller")
                || lower.contains("display controller")
        }),
        TestOs::Windows => windows_video_controller_present(),
        TestOs::Other => false,
    }
}
/// Reports whether PowerShell lists a video controller other than the
/// "Microsoft Basic Display Adapter".
///
/// Windows PowerShell (`powershell`) is tried first, then PowerShell Core (`pwsh`).
fn windows_video_controller_present() -> bool {
    const POWERSHELL_COMMAND: &str =
        "Get-CimInstance Win32_VideoController | Select-Object -ExpandProperty Name";
    command_stdout(
        "powershell",
        &["-NoProfile", "-Command", POWERSHELL_COMMAND],
    )
    .or_else(|| command_stdout("pwsh", &["-NoProfile", "-Command", POWERSHELL_COMMAND]))
    .is_some_and(|stdout| {
        stdout
            .lines()
            .map(str::trim)
            .any(|line| !line.is_empty() && line != "Microsoft Basic Display Adapter")
    })
}
/// Runs `command` with `args` and returns its stdout, decoded lossily as UTF-8.
///
/// Returns `None` when the process cannot be spawned or exits unsuccessfully.
fn command_stdout(command: &str, args: &[&str]) -> Option<String> {
    let output = Command::new(command).args(args).output().ok()?;
    if !output.status.success() {
        return None;
    }
    Some(String::from_utf8_lossy(&output.stdout).into_owned())
}
/// Scripted stream that replays `responses` in order, one per `stream` call.
///
/// Once the script is exhausted, it emits the "no more scripted responses" error
/// (it wraps [`ScriptedStreamFn::with_error_fallback`]).
pub struct MockStreamFn(ScriptedStreamFn);
impl MockStreamFn {
    /// Creates a mock that replays `responses`, then errors.
    #[must_use]
    pub const fn new(responses: Vec<Vec<AssistantMessageEvent>>) -> Self {
        let inner = ScriptedStreamFn::with_error_fallback(responses);
        Self(inner)
    }
}
impl StreamFn for MockStreamFn {
    fn stream<'a>(
        &'a self,
        model: &'a ModelSpec,
        context: &'a AgentContext,
        options: &'a StreamOptions,
        cancellation_token: CancellationToken,
    ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>> {
        let Self(inner) = self;
        inner.stream(model, context, options, cancellation_token)
    }
}
/// Stream that answers every request with the same text reply.
///
/// Each stored token becomes one `TextDelta`. The model, context, options, and
/// cancellation token are ignored.
pub struct SimpleMockStreamFn {
    tokens: Arc<Vec<String>>,
}
impl SimpleMockStreamFn {
    /// Creates a stream that emits `tokens` as successive text deltas.
    #[must_use]
    pub fn new(tokens: Vec<String>) -> Self {
        let tokens = Arc::new(tokens);
        Self { tokens }
    }
    /// Creates a stream that emits `text` as a single text delta.
    #[must_use]
    pub fn from_text(text: &str) -> Self {
        Self::new(vec![text.to_owned()])
    }
}
impl StreamFn for SimpleMockStreamFn {
    fn stream<'a>(
        &'a self,
        _model: &'a ModelSpec,
        _context: &'a AgentContext,
        _options: &'a StreamOptions,
        _cancellation_token: CancellationToken,
    ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>> {
        let reply_tokens = self.tokens.to_vec();
        Box::pin(futures::stream::iter(text_only_events_multi(reply_tokens)))
    }
}
pub struct ScriptedStreamFn {
responses: Mutex<Vec<Vec<AssistantMessageEvent>>>,
use_error_fallback: bool,
}
impl ScriptedStreamFn {
#[must_use]
pub const fn new(responses: Vec<Vec<AssistantMessageEvent>>) -> Self {
Self {
responses: Mutex::new(responses),
use_error_fallback: false,
}
}
#[must_use]
pub const fn with_error_fallback(responses: Vec<Vec<AssistantMessageEvent>>) -> Self {
Self {
responses: Mutex::new(responses),
use_error_fallback: true,
}
}
}
impl StreamFn for ScriptedStreamFn {
fn stream<'a>(
&'a self,
_model: &'a ModelSpec,
_context: &'a AgentContext,
_options: &'a StreamOptions,
_cancellation_token: CancellationToken,
) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>> {
let fallback = if self.use_error_fallback {
default_exhausted_fallback()
} else {
text_only_events_multi(vec!["default response".to_string()])
};
let events = next_response(&self.responses, fallback);
Box::pin(futures::stream::iter(events))
}
}
/// Stream that records that it was called and replays scripted responses.
///
/// Once the responses run out, it answers with a `"fallback"` text reply.
#[allow(dead_code)]
pub struct MockFlagStreamFn {
    /// Set to `true` by every `stream` call.
    pub called: AtomicBool,
    /// Scripted responses, consumed front-first.
    pub responses: Mutex<Vec<Vec<AssistantMessageEvent>>>,
}
impl StreamFn for MockFlagStreamFn {
    fn stream<'a>(
        &'a self,
        _model: &'a ModelSpec,
        _context: &'a AgentContext,
        _options: &'a StreamOptions,
        _cancellation_token: CancellationToken,
    ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>> {
        self.called.store(true, Ordering::SeqCst);
        Box::pin(futures::stream::iter(next_response(
            &self.responses,
            text_events("fallback"),
        )))
    }
}
/// Stream that records `context.messages.len()` for every call, then replays scripted
/// responses.
///
/// Once the responses run out, it emits the "no more scripted responses" error.
#[allow(dead_code)]
pub struct MockContextCapturingStreamFn {
    /// Scripted responses, consumed front-first.
    pub responses: Mutex<Vec<Vec<AssistantMessageEvent>>>,
    /// Message count seen by each call, in call order.
    pub captured_message_counts: Mutex<Vec<usize>>,
}
#[allow(dead_code)]
impl MockContextCapturingStreamFn {
    /// Creates a capturing stream that has not recorded any calls yet.
    pub const fn new(responses: Vec<Vec<AssistantMessageEvent>>) -> Self {
        let captured_message_counts = Mutex::new(Vec::new());
        Self {
            responses: Mutex::new(responses),
            captured_message_counts,
        }
    }
}
impl StreamFn for MockContextCapturingStreamFn {
    fn stream<'a>(
        &'a self,
        _model: &'a ModelSpec,
        context: &'a AgentContext,
        _options: &'a StreamOptions,
        _cancellation_token: CancellationToken,
    ) -> Pin<Box<dyn futures::Stream<Item = AssistantMessageEvent> + Send + 'a>> {
        let seen_messages = context.messages.len();
        self.captured_message_counts
            .lock()
            .unwrap()
            .push(seen_messages);
        let reply = next_response(&self.responses, default_exhausted_fallback());
        Box::pin(futures::stream::iter(reply))
    }
}
/// Stream that records `options.api_key` for every call, then replays scripted
/// responses.
///
/// Once the responses run out, it emits the "no more scripted responses" error.
#[allow(dead_code)]
pub struct MockApiKeyCapturingStreamFn {
    /// Scripted responses, consumed front-first.
    pub responses: Mutex<Vec<Vec<AssistantMessageEvent>>>,
    /// API key passed to each call (`None` when absent), in call order.
    pub captured_api_keys: Mutex<Vec<Option<String>>>,
}
#[allow(dead_code)]
impl MockApiKeyCapturingStreamFn {
    /// Creates a capturing stream that has not recorded any calls yet.
    pub const fn new(responses: Vec<Vec<AssistantMessageEvent>>) -> Self {
        let captured_api_keys = Mutex::new(Vec::new());
        Self {
            responses: Mutex::new(responses),
            captured_api_keys,
        }
    }
}
impl StreamFn for MockApiKeyCapturingStreamFn {
    fn stream<'a>(
        &'a self,
        _model: &'a ModelSpec,
        _context: &'a AgentContext,
        options: &'a StreamOptions,
        _cancellation_token: CancellationToken,
    ) -> Pin<Box<dyn futures::Stream<Item = AssistantMessageEvent> + Send + 'a>> {
        let api_key = options.api_key.clone();
        self.captured_api_keys.lock().unwrap().push(api_key);
        let reply = next_response(&self.responses, default_exhausted_fallback());
        Box::pin(futures::stream::iter(reply))
    }
}
/// Configurable [`AgentTool`] test double that counts its executions.
///
/// Defaults:
/// - accepts any object (`permissive_object_schema()`);
/// - returns the text `"mock result"`;
/// - has no delay;
/// - does not require approval.
pub struct MockTool {
    tool_name: String,
    schema: Value,
    /// Result cloned and returned by every `execute` call.
    result: Mutex<Option<AgentToolResult>>,
    /// Optional sleep before the result is returned. Cancellation during the sleep
    /// yields a `"cancelled"` result.
    delay: Option<Duration>,
    /// Set to `true` by every `execute` call.
    executed: Arc<AtomicBool>,
    /// Incremented by every `execute` call.
    execute_count: Arc<AtomicU32>,
    approval_required: bool,
}
impl MockTool {
    /// Creates a mock tool named `name` with the default configuration.
    #[must_use]
    pub fn new(name: &str) -> Self {
        Self {
            tool_name: name.to_string(),
            schema: permissive_object_schema(),
            result: Mutex::new(Some(AgentToolResult::text("mock result"))),
            delay: None,
            executed: Arc::new(AtomicBool::new(false)),
            execute_count: Arc::new(AtomicU32::new(0)),
            approval_required: false,
        }
    }
    /// Replaces the JSON schema reported by `parameters_schema`.
    #[must_use]
    pub fn with_schema(mut self, schema: Value) -> Self {
        self.schema = schema;
        self
    }
    /// Replaces the result returned by `execute`.
    #[must_use]
    pub fn with_result(mut self, result: AgentToolResult) -> Self {
        // `self` is owned here, so nothing else can observe the mutex. Replace it
        // outright instead of taking the lock, which could also panic on poison.
        self.result = Mutex::new(Some(result));
        self
    }
    /// Makes `execute` sleep for `delay` before returning (cancellable).
    #[must_use]
    pub const fn with_delay(mut self, delay: Duration) -> Self {
        self.delay = Some(delay);
        self
    }
    /// Sets the value reported by `requires_approval`.
    #[must_use]
    pub const fn with_requires_approval(mut self, required: bool) -> Self {
        self.approval_required = required;
        self
    }
    /// Returns whether `execute` has been called at least once.
    #[must_use]
    pub fn was_executed(&self) -> bool {
        self.executed.load(Ordering::SeqCst)
    }
    /// Returns how many times `execute` has been called.
    #[must_use]
    pub fn execution_count(&self) -> u32 {
        self.execute_count.load(Ordering::SeqCst)
    }
}
impl AgentTool for MockTool {
    fn name(&self) -> &str {
        self.tool_name.as_str()
    }
    /// The label is the same as the tool name.
    fn label(&self) -> &str {
        self.tool_name.as_str()
    }
    fn description(&self) -> &'static str {
        "A mock tool for testing"
    }
    fn parameters_schema(&self) -> &Value {
        &self.schema
    }
    fn requires_approval(&self) -> bool {
        self.approval_required
    }
    /// Records the call and returns a future that resolves to the configured result.
    ///
    /// The bookkeeping happens when `execute` is called, not when the future is polled.
    /// With a delay configured, cancellation during the sleep resolves to `"cancelled"`.
    fn execute(
        &self,
        _tool_call_id: &str,
        _params: Value,
        cancellation_token: CancellationToken,
        _on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
        _state: std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
        _credential: Option<crate::credential::ResolvedCredential>,
    ) -> ToolFuture<'_> {
        self.executed.store(true, Ordering::SeqCst);
        self.execute_count.fetch_add(1, Ordering::SeqCst);
        let configured = self.result.lock().unwrap().clone();
        let outcome = configured.unwrap_or_else(|| AgentToolResult::text("mock result"));
        let delay = self.delay;
        Box::pin(async move {
            let Some(pause) = delay else {
                return outcome;
            };
            tokio::select! {
                () = tokio::time::sleep(pause) => outcome,
                () = cancellation_token.cancelled() => AgentToolResult::text("cancelled"),
            }
        })
    }
}
/// Builds a complete single-chunk text reply whose only `TextDelta` carries `text`.
#[must_use]
pub fn text_only_events(text: &str) -> Vec<AssistantMessageEvent> {
    let single_token = vec![text.to_owned()];
    text_only_events_multi(single_token)
}
/// Alias of [`text_only_events`].
#[must_use]
pub fn text_events(text: &str) -> Vec<AssistantMessageEvent> {
    text_only_events_multi(vec![text.to_owned()])
}
/// Builds a complete text reply with one `TextDelta` per token.
///
/// The sequence is `Start`, `TextStart`, the deltas, `TextEnd`, then `Done` with
/// `StopReason::Stop`. All text events use content index 0.
#[must_use]
pub fn text_only_events_multi(tokens: Vec<String>) -> Vec<AssistantMessageEvent> {
    // Four framing events surround the deltas.
    let capacity = tokens.len() + 4;
    let mut events = Vec::with_capacity(capacity);
    events.extend([
        AssistantMessageEvent::Start,
        AssistantMessageEvent::TextStart { content_index: 0 },
    ]);
    events.extend(
        tokens
            .into_iter()
            .map(|delta| AssistantMessageEvent::TextDelta {
                content_index: 0,
                delta,
            }),
    );
    events.extend([
        AssistantMessageEvent::TextEnd { content_index: 0 },
        AssistantMessageEvent::Done {
            stop_reason: StopReason::Stop,
            usage: Usage::default(),
            cost: Cost::default(),
        },
    ]);
    events
}
#[allow(dead_code)]
#[must_use]
pub fn tool_call_events(id: &str, name: &str, args: &str) -> Vec<AssistantMessageEvent> {
vec![
AssistantMessageEvent::Start,
AssistantMessageEvent::ToolCallStart {
content_index: 0,
id: id.to_string(),
name: name.to_string(),
},
AssistantMessageEvent::ToolCallDelta {
content_index: 0,
delta: args.to_string(),
},
AssistantMessageEvent::ToolCallEnd { content_index: 0 },
AssistantMessageEvent::Done {
stop_reason: StopReason::ToolUse,
usage: Usage::default(),
cost: Cost::default(),
},
]
}
/// Builds a reply with one tool call per `(id, name, args)` triple.
///
/// Each call gets its position in `calls` as its content index and emits
/// `ToolCallStart`, `ToolCallDelta` (the raw `args`), then `ToolCallEnd`. The reply is
/// framed by `Start` and a final `Done` with `StopReason::ToolUse`.
#[allow(dead_code)]
#[must_use]
pub fn tool_call_events_multi(calls: &[(&str, &str, &str)]) -> Vec<AssistantMessageEvent> {
    let per_call = calls
        .iter()
        .enumerate()
        .flat_map(|(content_index, &(id, name, args))| {
            [
                AssistantMessageEvent::ToolCallStart {
                    content_index,
                    id: id.to_string(),
                    name: name.to_string(),
                },
                AssistantMessageEvent::ToolCallDelta {
                    content_index,
                    delta: args.to_string(),
                },
                AssistantMessageEvent::ToolCallEnd { content_index },
            ]
        });
    std::iter::once(AssistantMessageEvent::Start)
        .chain(per_call)
        .chain(std::iter::once(AssistantMessageEvent::Done {
            stop_reason: StopReason::ToolUse,
            usage: Usage::default(),
            cost: Cost::default(),
        }))
        .collect()
}
#[allow(dead_code)]
#[must_use]
pub fn error_events(
message: &str,
error_kind: Option<StreamErrorKind>,
) -> Vec<AssistantMessageEvent> {
vec![AssistantMessageEvent::Error {
stop_reason: StopReason::Error,
error_message: message.to_string(),
usage: None,
error_kind,
}]
}
#[allow(dead_code)]
#[must_use]
pub fn abort_events(message: &str) -> Vec<AssistantMessageEvent> {
vec![AssistantMessageEvent::Error {
stop_reason: StopReason::Aborted,
error_message: message.to_string(),
usage: None,
error_kind: None,
}]
}
pub fn user_msg(text: &str) -> AgentMessage {
AgentMessage::Llm(LlmMessage::User(UserMessage {
content: vec![ContentBlock::Text {
text: text.to_string(),
}],
timestamp: 0,
cache_hint: None,
}))
}
pub fn assistant_msg(text: &str) -> AgentMessage {
AgentMessage::Llm(LlmMessage::Assistant(AssistantMessage {
content: vec![ContentBlock::Text {
text: text.to_string(),
}],
provider: String::new(),
model_id: String::new(),
usage: Usage::default(),
cost: Cost::default(),
stop_reason: StopReason::Stop,
error_message: None,
error_kind: None,
timestamp: 0,
cache_hint: None,
}))
}
pub fn tool_result_msg(id: &str, text: &str) -> AgentMessage {
AgentMessage::Llm(LlmMessage::ToolResult(ToolResultMessage {
tool_call_id: id.to_string(),
content: vec![ContentBlock::Text {
text: text.to_string(),
}],
is_error: false,
timestamp: 0,
details: serde_json::Value::Null,
cache_hint: None,
}))
}
#[allow(dead_code)]
#[must_use]
pub fn default_model() -> ModelSpec {
ModelSpec::new("test", "test-model")
}
#[allow(dead_code)]
pub fn default_convert(msg: &AgentMessage) -> Option<LlmMessage> {
match msg {
AgentMessage::Llm(llm) => Some(llm.clone()),
AgentMessage::Custom(_) => None,
}
}
#[allow(dead_code)]
pub fn next_response(
responses: &Mutex<Vec<Vec<AssistantMessageEvent>>>,
fallback: Vec<AssistantMessageEvent>,
) -> Vec<AssistantMessageEvent> {
let mut guard = responses.lock().unwrap();
if guard.is_empty() {
fallback
} else {
guard.remove(0)
}
}
#[allow(dead_code)]
#[must_use]
pub fn default_exhausted_fallback() -> Vec<AssistantMessageEvent> {
vec![AssistantMessageEvent::Error {
stop_reason: StopReason::Error,
error_message: "no more scripted responses".to_string(),
usage: None,
error_kind: None,
}]
}
/// Thread-safe log of [`AgentEvent`] variant names, recorded in arrival order.
///
/// The log sits behind an `Arc`. Clones of the collector, and closures returned by
/// [`EventCollector::subscriber`], all read from and append to the same list.
#[allow(dead_code)]
#[derive(Clone)]
pub struct EventCollector {
    events: Arc<Mutex<Vec<String>>>,
}
#[allow(dead_code)]
impl Default for EventCollector {
    fn default() -> Self {
        Self::new()
    }
}
#[allow(dead_code)]
impl EventCollector {
    /// Creates a collector with an empty log.
    #[must_use]
    pub fn new() -> Self {
        Self {
            events: Arc::new(Mutex::new(Vec::new())),
        }
    }
    /// Returns a subscriber that appends `event_variant_name(event)` to the shared log.
    #[must_use]
    pub fn subscriber(&self) -> impl Fn(&AgentEvent) + Send + Sync + 'static {
        let events = Arc::clone(&self.events);
        move |event: &AgentEvent| {
            let name = event_variant_name(event);
            events.lock().unwrap().push(name);
        }
    }
    /// Returns a snapshot copy of every recorded name.
    #[must_use]
    pub fn events(&self) -> Vec<String> {
        self.events.lock().unwrap().clone()
    }
    /// Returns the number of recorded events.
    #[must_use]
    pub fn count(&self) -> usize {
        self.events.lock().unwrap().len()
    }
    /// Returns the index of the first recorded event named `name`, if any.
    #[must_use]
    pub fn position(&self, name: &str) -> Option<usize> {
        // Search under the lock rather than cloning the whole log through `events()`.
        self.events.lock().unwrap().iter().position(|n| n == name)
    }
}
#[cfg(feature = "plugins")]
use crate::plugin::Plugin;
#[cfg(feature = "plugins")]
use crate::policy::{
PolicyContext, PolicyVerdict, PostLoopPolicy, PostTurnPolicy, PreDispatchPolicy, PreTurnPolicy,
TurnPolicyContext,
};
/// Configurable [`Plugin`] test double.
///
/// It can expose tools and policies and records how it was initialized and notified.
#[cfg(feature = "plugins")]
pub struct MockPlugin {
    plugin_name: String,
    priority: i32,
    /// Names turned into fresh [`MockTool`]s on every `tools()` call.
    tool_names: Vec<String>,
    /// Pre-built tools returned ahead of the named mock tools.
    tools: Vec<Arc<dyn AgentTool>>,
    pre_turn_policies: Vec<Arc<dyn PreTurnPolicy>>,
    pre_dispatch_policies: Vec<Arc<dyn PreDispatchPolicy>>,
    post_turn_policies: Vec<Arc<dyn PostTurnPolicy>>,
    post_loop_policies: Vec<Arc<dyn PostLoopPolicy>>,
    /// Incremented on every `on_event` call, when set.
    event_counter: Option<Arc<AtomicUsize>>,
    /// When set, adds a [`RecordingPostTurnPolicy`] that raises this flag.
    post_turn_tracker: Option<Arc<AtomicBool>>,
    /// When set, adds an [`OrderRecordingPreTurnPolicy`] that stores its sequence number here.
    pre_turn_order: Option<Arc<AtomicUsize>>,
    /// When true, adds a [`StoppingPreTurnPolicy`].
    stopping_pre_turn: bool,
    /// Set by `on_init`.
    init_called: Arc<AtomicBool>,
    /// Incremented by `on_init`.
    init_count: Arc<AtomicUsize>,
    /// Incremented by `on_init`, when set.
    init_order: Option<Arc<AtomicUsize>>,
}
#[cfg(feature = "plugins")]
impl MockPlugin {
    /// Creates a plugin named `name` with priority 0 and no tools, policies, or trackers.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            plugin_name: name.into(),
            priority: 0,
            tool_names: Vec::new(),
            tools: Vec::new(),
            pre_turn_policies: Vec::new(),
            pre_dispatch_policies: Vec::new(),
            post_turn_policies: Vec::new(),
            post_loop_policies: Vec::new(),
            event_counter: None,
            post_turn_tracker: None,
            pre_turn_order: None,
            stopping_pre_turn: false,
            init_called: Arc::new(AtomicBool::new(false)),
            init_count: Arc::new(AtomicUsize::new(0)),
            init_order: None,
        }
    }
    /// Sets the plugin priority.
    #[must_use]
    pub const fn with_priority(mut self, priority: i32) -> Self {
        self.priority = priority;
        self
    }
    /// Replaces the list of named mock tools.
    #[must_use]
    pub fn with_tools(self, names: &[&str]) -> Self {
        let tool_names = names.iter().map(|name| (*name).to_owned()).collect();
        Self { tool_names, ..self }
    }
    /// Adds a pre-built tool.
    #[must_use]
    pub fn with_tool(mut self, tool: Arc<dyn AgentTool>) -> Self {
        self.tools.push(tool);
        self
    }
    /// Counts `on_event` calls in `counter`.
    #[must_use]
    pub fn with_event_counter(self, counter: Arc<AtomicUsize>) -> Self {
        Self {
            event_counter: Some(counter),
            ..self
        }
    }
    /// Raises `fired` whenever the added post-turn policy is evaluated.
    #[must_use]
    pub fn with_post_turn_tracker(self, fired: Arc<AtomicBool>) -> Self {
        Self {
            post_turn_tracker: Some(fired),
            ..self
        }
    }
    /// Records the added pre-turn policy's global evaluation sequence number in `order`.
    #[must_use]
    pub fn with_pre_turn_order(self, order: Arc<AtomicUsize>) -> Self {
        Self {
            pre_turn_order: Some(order),
            ..self
        }
    }
    /// Adds a pre-turn policy.
    #[must_use]
    pub fn with_pre_turn_policy(mut self, policy: Arc<dyn PreTurnPolicy>) -> Self {
        self.pre_turn_policies.push(policy);
        self
    }
    /// Adds a pre-dispatch policy, taking ownership of it.
    #[must_use]
    pub fn with_pre_dispatch_policy<P>(mut self, policy: P) -> Self
    where
        P: PreDispatchPolicy + 'static,
    {
        let shared: Arc<dyn PreDispatchPolicy> = Arc::new(policy);
        self.pre_dispatch_policies.push(shared);
        self
    }
    /// Adds a post-turn policy.
    #[must_use]
    pub fn with_post_turn_policy(mut self, policy: Arc<dyn PostTurnPolicy>) -> Self {
        self.post_turn_policies.push(policy);
        self
    }
    /// Adds a post-loop policy.
    #[must_use]
    pub fn with_post_loop_policy(mut self, policy: Arc<dyn PostLoopPolicy>) -> Self {
        self.post_loop_policies.push(policy);
        self
    }
    /// Adds a pre-turn policy that always stops the loop.
    #[must_use]
    pub const fn with_stopping_pre_turn(mut self) -> Self {
        self.stopping_pre_turn = true;
        self
    }
    /// Returns a handle to the flag that `on_init` sets.
    pub fn init_called(&self) -> Arc<AtomicBool> {
        Arc::clone(&self.init_called)
    }
    /// Returns how many times `on_init` has run.
    pub fn init_count(&self) -> usize {
        self.init_count.load(Ordering::SeqCst)
    }
    /// Increments `order` on every `on_init` call.
    #[must_use]
    pub fn with_init_order(self, order: Arc<AtomicUsize>) -> Self {
        Self {
            init_order: Some(order),
            ..self
        }
    }
}
#[cfg(feature = "plugins")]
impl Plugin for MockPlugin {
    fn name(&self) -> &str {
        self.plugin_name.as_str()
    }
    fn priority(&self) -> i32 {
        self.priority
    }
    /// Records the call in `init_called`, `init_count`, and (if configured) `init_order`.
    fn on_init(&self, _agent: &crate::Agent) {
        self.init_called.store(true, Ordering::SeqCst);
        self.init_count.fetch_add(1, Ordering::SeqCst);
        if let Some(order) = self.init_order.as_deref() {
            order.fetch_add(1, Ordering::SeqCst);
        }
    }
    /// Returns the configured policies, then the stopping policy, then the
    /// order-recording policy (the last two only when enabled).
    fn pre_turn_policies(&self) -> Vec<Arc<dyn PreTurnPolicy>> {
        let stopping = self.stopping_pre_turn.then(|| {
            Arc::new(StoppingPreTurnPolicy {
                label: format!("{}-stopping", self.plugin_name),
            }) as Arc<dyn PreTurnPolicy>
        });
        let ordering = self.pre_turn_order.as_ref().map(|order| {
            Arc::new(OrderRecordingPreTurnPolicy {
                label: format!("{}-pre-turn", self.plugin_name),
                order: Arc::clone(order),
            }) as Arc<dyn PreTurnPolicy>
        });
        self.pre_turn_policies
            .iter()
            .cloned()
            .chain(stopping)
            .chain(ordering)
            .collect()
    }
    fn pre_dispatch_policies(&self) -> Vec<Arc<dyn PreDispatchPolicy>> {
        self.pre_dispatch_policies.to_vec()
    }
    /// Returns the configured policies, followed by a recording policy when a tracker is set.
    fn post_turn_policies(&self) -> Vec<Arc<dyn PostTurnPolicy>> {
        let tracker = self.post_turn_tracker.as_ref().map(|fired| {
            Arc::new(RecordingPostTurnPolicy {
                fired: Arc::clone(fired),
            }) as Arc<dyn PostTurnPolicy>
        });
        self.post_turn_policies
            .iter()
            .cloned()
            .chain(tracker)
            .collect()
    }
    fn post_loop_policies(&self) -> Vec<Arc<dyn PostLoopPolicy>> {
        self.post_loop_policies.to_vec()
    }
    /// Bumps the event counter, if one is configured.
    fn on_event(&self, _event: &crate::AgentEvent) {
        if let Some(counter) = self.event_counter.as_deref() {
            counter.fetch_add(1, Ordering::SeqCst);
        }
    }
    /// Returns the pre-built tools, followed by a fresh [`MockTool`] for every configured name.
    fn tools(&self) -> Vec<Arc<dyn crate::tool::AgentTool>> {
        let named = self
            .tool_names
            .iter()
            .map(|n| -> Arc<dyn crate::tool::AgentTool> { Arc::new(MockTool::new(n)) });
        self.tools.iter().cloned().chain(named).collect()
    }
}
/// Post-turn policy that raises `fired` each time it is evaluated and always continues.
#[cfg(feature = "plugins")]
pub struct RecordingPostTurnPolicy {
    /// Set to `true` on every evaluation.
    pub fired: Arc<AtomicBool>,
}
#[cfg(feature = "plugins")]
impl PostTurnPolicy for RecordingPostTurnPolicy {
    fn name(&self) -> &'static str {
        "recording-post-turn"
    }
    fn evaluate(&self, _ctx: &PolicyContext<'_>, _turn: &TurnPolicyContext<'_>) -> PolicyVerdict {
        let flag = &self.fired;
        flag.store(true, Ordering::SeqCst);
        PolicyVerdict::Continue
    }
}
/// Process-wide sequence counter shared by every [`OrderRecordingPreTurnPolicy`].
///
/// Nothing in this module resets it.
#[cfg(feature = "plugins")]
pub static MOCK_PLUGIN_GLOBAL_ORDER: AtomicUsize = AtomicUsize::new(0);
/// Pre-turn policy that stores its global evaluation sequence number in `order`.
///
/// The stored value is the counter reading taken before the increment. The policy
/// always continues.
#[cfg(feature = "plugins")]
pub struct OrderRecordingPreTurnPolicy {
    /// Name reported by the policy.
    pub label: String,
    /// Receives the sequence number of the latest evaluation.
    pub order: Arc<AtomicUsize>,
}
#[cfg(feature = "plugins")]
impl PreTurnPolicy for OrderRecordingPreTurnPolicy {
    fn name(&self) -> &str {
        self.label.as_str()
    }
    fn evaluate(&self, _ctx: &PolicyContext<'_>) -> PolicyVerdict {
        self.order.store(
            MOCK_PLUGIN_GLOBAL_ORDER.fetch_add(1, Ordering::SeqCst),
            Ordering::SeqCst,
        );
        PolicyVerdict::Continue
    }
}
/// Pre-turn policy that always stops with the reason `"stopped by policy"`.
#[cfg(feature = "plugins")]
pub struct StoppingPreTurnPolicy {
    /// Name reported by the policy.
    pub label: String,
}
#[cfg(feature = "plugins")]
impl PreTurnPolicy for StoppingPreTurnPolicy {
    fn name(&self) -> &str {
        self.label.as_str()
    }
    fn evaluate(&self, _ctx: &PolicyContext<'_>) -> PolicyVerdict {
        let reason = "stopped by policy";
        PolicyVerdict::Stop(reason.into())
    }
}
/// Returns the variant name of `event` (for example `"TurnStart"`).
///
/// `EventCollector` uses this to record events. Custom events are rendered as
/// `Custom(<emission name>)`.
#[allow(dead_code)]
pub fn event_variant_name(event: &AgentEvent) -> String {
    let label = match event {
        AgentEvent::Custom(emission) => return format!("Custom({})", emission.name),
        AgentEvent::AgentStart => "AgentStart",
        AgentEvent::AgentEnd { .. } => "AgentEnd",
        AgentEvent::TurnStart => "TurnStart",
        AgentEvent::TurnEnd { .. } => "TurnEnd",
        AgentEvent::MessageStart => "MessageStart",
        AgentEvent::MessageUpdate { .. } => "MessageUpdate",
        AgentEvent::MessageEnd { .. } => "MessageEnd",
        AgentEvent::ToolExecutionStart { .. } => "ToolExecutionStart",
        AgentEvent::ToolExecutionUpdate { .. } => "ToolExecutionUpdate",
        AgentEvent::ToolExecutionEnd { .. } => "ToolExecutionEnd",
        AgentEvent::ToolApprovalRequested { .. } => "ToolApprovalRequested",
        AgentEvent::ToolApprovalResolved { .. } => "ToolApprovalResolved",
        AgentEvent::BeforeLlmCall { .. } => "BeforeLlmCall",
        AgentEvent::ContextCompacted { .. } => "ContextCompacted",
        AgentEvent::ModelFallback { .. } => "ModelFallback",
        AgentEvent::ModelCycled { .. } => "ModelCycled",
        AgentEvent::StateChanged { .. } => "StateChanged",
        AgentEvent::CacheAction { .. } => "CacheAction",
        AgentEvent::McpServerConnected { .. } => "McpServerConnected",
        AgentEvent::McpServerDisconnected { .. } => "McpServerDisconnected",
        AgentEvent::McpToolsDiscovered { .. } => "McpToolsDiscovered",
        AgentEvent::McpToolCallStarted { .. } => "McpToolCallStarted",
        AgentEvent::McpToolCallCompleted { .. } => "McpToolCallCompleted",
        #[cfg(feature = "artifact-store")]
        AgentEvent::ArtifactSaved { .. } => "ArtifactSaved",
        AgentEvent::TransferInitiated { .. } => "TransferInitiated",
    };
    label.to_owned()
}
#[cfg(test)]
mod runtime_tests {
    use super::{TestGpu, TestOs, TestRuntime, TestRuntimeRequirements, evaluate_test_runtime};

    /// Synthetic x86_64 Linux host with no GPUs; tests adjust it with struct-update syntax.
    fn linux_host() -> TestRuntime {
        TestRuntime {
            os: TestOs::Linux,
            arch: "x86_64",
            has_any_gpu: false,
            has_nvidia_gpu: false,
            has_apple_metal_gpu: false,
        }
    }

    /// Synthetic Linux host with an NVIDIA GPU.
    fn nvidia_host() -> TestRuntime {
        TestRuntime {
            has_any_gpu: true,
            has_nvidia_gpu: true,
            ..linux_host()
        }
    }

    #[test]
    fn runtime_accepts_unconstrained_requirements() {
        assert!(evaluate_test_runtime(&linux_host(), TestRuntimeRequirements::default()).is_ok());
    }

    #[test]
    fn runtime_accepts_matching_os() {
        let result = evaluate_test_runtime(
            &linux_host(),
            TestRuntimeRequirements::new().with_os(TestOs::Linux),
        );
        assert!(result.is_ok());
    }

    #[test]
    fn runtime_rejects_os_mismatch() {
        let reason = evaluate_test_runtime(
            &linux_host(),
            TestRuntimeRequirements::new().with_os(TestOs::MacOs),
        )
        .expect_err("linux host should not satisfy macOS-only requirement");
        assert!(reason.contains("requires macOS"));
    }

    #[test]
    fn runtime_reports_os_mismatch_before_gpu() {
        let reason = evaluate_test_runtime(
            &linux_host(),
            TestRuntimeRequirements::new()
                .with_os(TestOs::MacOs)
                .with_gpu(TestGpu::AppleMetal),
        )
        .expect_err("linux host should not satisfy macOS + Metal requirement");
        assert_eq!(reason, "requires macOS, detected Linux");
    }

    #[test]
    fn runtime_rejects_missing_gpu() {
        let reason = evaluate_test_runtime(
            &linux_host(),
            TestRuntimeRequirements::new().with_gpu(TestGpu::Any),
        )
        .expect_err("gpu-less host should not satisfy gpu requirement");
        assert!(reason.contains("requires a detected GPU"));
    }

    #[test]
    fn runtime_accepts_nvidia_gpu_requirement() {
        let result = evaluate_test_runtime(
            &nvidia_host(),
            TestRuntimeRequirements::new().with_gpu(TestGpu::Nvidia),
        );
        assert!(result.is_ok());
    }

    #[test]
    fn runtime_rejects_apple_metal_on_nvidia_host() {
        let reason = evaluate_test_runtime(
            &nvidia_host(),
            TestRuntimeRequirements::new().with_gpu(TestGpu::AppleMetal),
        )
        .expect_err("nvidia-only host should not satisfy Apple Metal requirement");
        assert!(reason.contains("Apple Metal-capable"));
    }
}