use std::sync::Arc;
use super::*;
use crate::hooks::PhaseContext;
use crate::phase::ExecutionEnv;
use crate::plugins::{Plugin, PluginDescriptor, PluginRegistrar};
use crate::state::StateStore;
use awaken_contract::contract::content::ContentBlock;
use awaken_contract::contract::context_message::ContextMessage;
use awaken_contract::contract::event::AgentEvent;
use awaken_contract::contract::event_sink::VecEventSink;
use awaken_contract::contract::message::{Message, Role};
use awaken_contract::contract::suspension::{SuspendTicket, ToolCallOutcome, ToolCallStatus};
use awaken_contract::contract::tool::{
Tool, ToolCallContext, ToolDescriptor, ToolError, ToolOutput, ToolResult,
};
use awaken_contract::model::{
PendingScheduledActions, Phase, ScheduledAction, ScheduledActionEnvelope,
ScheduledActionQueueUpdate, ScheduledActionSpec,
};
use super::actions::{
LoopActionHandlersPlugin, apply_context_messages, apply_tool_filter_payloads,
merge_override_payloads, resolve_intercept_payloads, take_context_messages,
};
use crate::agent::state::{
AddContextMessage, InferenceOverrideState, InferenceOverrideStateAction,
InferenceOverrideStateValue, RunLifecycle, RunLifecycleUpdate, ToolCallStates, ToolFilterState,
ToolFilterStateAction, ToolFilterStateValue,
};
use crate::phase::{PhaseRuntime, ToolGateHook};
use crate::state::{StateCommand, StateKey, StateKeyOptions};
fn test_runtime() -> (PhaseRuntime, ExecutionEnv) {
let store = StateStore::new();
store
.install_plugin(LoopStatePlugin)
.expect("install LoopStatePlugin");
store
.install_plugin(LoopActionHandlersPlugin)
.expect("install LoopActionHandlersPlugin");
let mut patch = crate::state::MutationBatch::new();
patch.update::<RunLifecycle>(RunLifecycleUpdate::Start {
run_id: "test".into(),
updated_at: 0,
});
store.commit(patch).expect("init lifecycle");
let runtime = PhaseRuntime::new(store).expect("create runtime");
let env =
ExecutionEnv::from_plugins(&[Arc::new(LoopActionHandlersPlugin)], &Default::default())
.expect("build env");
(runtime, env)
}
fn test_runtime_with_plugins(plugins: Vec<Arc<dyn Plugin>>) -> (PhaseRuntime, ExecutionEnv) {
let store = StateStore::new();
store
.install_plugin(LoopStatePlugin)
.expect("install LoopStatePlugin");
store
.install_plugin(LoopActionHandlersPlugin)
.expect("install LoopActionHandlersPlugin");
let mut patch = crate::state::MutationBatch::new();
patch.update::<RunLifecycle>(RunLifecycleUpdate::Start {
run_id: "test".into(),
updated_at: 0,
});
store.commit(patch).expect("init lifecycle");
let runtime = PhaseRuntime::new(store).expect("create runtime");
let mut all_plugins: Vec<Arc<dyn Plugin>> = vec![Arc::new(LoopActionHandlersPlugin)];
all_plugins.extend(plugins);
let env = ExecutionEnv::from_plugins(&all_plugins, &Default::default()).expect("build env");
(runtime, env)
}
fn enqueue_context_message(store: &StateStore, id: u64, msg: ContextMessage) {
let payload = AddContextMessage::encode_payload(&msg).expect("encode payload");
let mut batch = crate::state::MutationBatch::new();
batch.update::<PendingScheduledActions>(ScheduledActionQueueUpdate::Push(
ScheduledActionEnvelope {
id,
action: ScheduledAction::new(AddContextMessage::PHASE, AddContextMessage::KEY, payload),
},
));
store.commit(batch).expect("commit enqueue");
}
fn text_of(msg: &Message) -> String {
msg.content
.iter()
.filter_map(|b| match b {
ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("")
}
#[test]
fn context_message_injected_at_system_target() {
let mut messages = vec![
Message::system("base system prompt"),
Message::user("hello"),
];
let ctx = vec![ContextMessage::system("reminder", "remember the rules")];
apply_context_messages(&mut messages, ctx, true);
assert_eq!(messages.len(), 3);
assert_eq!(messages[0].role, Role::System);
assert_eq!(text_of(&messages[0]), "base system prompt");
assert_eq!(messages[1].role, Role::System);
assert_eq!(text_of(&messages[1]), "remember the rules");
assert_eq!(messages[2].role, Role::User);
}
#[test]
fn context_message_injected_at_suffix_target() {
let mut messages = vec![
Message::system("system"),
Message::user("hello"),
Message::system(""), ];
let original_len = messages.len();
let ctx = vec![ContextMessage::suffix_system(
"suffix.key",
"final reminder",
)];
apply_context_messages(&mut messages, ctx, true);
assert_eq!(messages.len(), original_len + 1);
let last = messages.last().unwrap();
assert_eq!(last.role, Role::System);
assert_eq!(text_of(last), "final reminder");
}
#[test]
fn multiple_context_messages_sorted_by_target() {
let mut messages = vec![Message::system("system prompt"), Message::user("user msg")];
let ctx = vec![
ContextMessage::suffix_system("s1", "suffix text"),
ContextMessage::system("sys1", "after-system text"),
ContextMessage::conversation("conv1", Role::User, "conversation text"),
];
apply_context_messages(&mut messages, ctx, true);
assert_eq!(messages.len(), 5);
assert_eq!(text_of(&messages[0]), "system prompt");
assert_eq!(text_of(&messages[1]), "after-system text");
assert_eq!(text_of(&messages[2]), "conversation text");
assert_eq!(text_of(&messages[3]), "user msg");
assert_eq!(text_of(&messages[4]), "suffix text");
}
#[tokio::test]
async fn throttle_zero_cooldown_always_injects() {
let (runtime, env) = test_runtime();
let store = runtime.store();
for step in 0..5u32 {
enqueue_context_message(
store,
step as u64,
ContextMessage::system("always", "inject me").with_cooldown(0),
);
if step > 0 {
let mut patch = crate::state::MutationBatch::new();
patch.update::<RunLifecycle>(RunLifecycleUpdate::StepCompleted { updated_at: 0 });
store.commit(patch).expect("step completed");
}
let ctx = PhaseContext::new(Phase::BeforeInference, store.snapshot());
runtime
.run_phase_with_context(&env, ctx)
.await
.expect("run phase");
let accepted = take_context_messages(store).expect("take");
assert_eq!(
accepted.len(),
1,
"cooldown=0 should inject at every step, failed at step {step}"
);
}
}
#[tokio::test]
async fn throttle_skips_within_cooldown() {
let (runtime, env) = test_runtime();
let store = runtime.store();
enqueue_context_message(
store,
1,
ContextMessage::system("throttled", "content").with_cooldown(3),
);
let ctx = PhaseContext::new(Phase::BeforeInference, store.snapshot());
runtime
.run_phase_with_context(&env, ctx)
.await
.expect("step 1");
let accepted = take_context_messages(store).expect("take step 1");
assert_eq!(accepted.len(), 1, "first injection should pass");
for step in 2..=3u32 {
let mut patch = crate::state::MutationBatch::new();
patch.update::<RunLifecycle>(RunLifecycleUpdate::StepCompleted { updated_at: 0 });
store.commit(patch).expect("step completed");
enqueue_context_message(
store,
10 + step as u64,
ContextMessage::system("throttled", "content").with_cooldown(3),
);
let ctx = PhaseContext::new(Phase::BeforeInference, store.snapshot());
runtime
.run_phase_with_context(&env, ctx)
.await
.unwrap_or_else(|_| panic!("step {step}"));
let accepted = take_context_messages(store).unwrap_or_else(|e| panic!("step {step}: {e}"));
assert_eq!(
accepted.len(),
0,
"should be throttled at step {step} (cooldown=3, last_step=1)"
);
}
let mut patch = crate::state::MutationBatch::new();
patch.update::<RunLifecycle>(RunLifecycleUpdate::StepCompleted { updated_at: 0 });
store.commit(patch).expect("step completed");
enqueue_context_message(
store,
20,
ContextMessage::system("throttled", "content").with_cooldown(3),
);
let ctx = PhaseContext::new(Phase::BeforeInference, store.snapshot());
runtime
.run_phase_with_context(&env, ctx)
.await
.expect("step 4");
let accepted = take_context_messages(store).expect("take step 4");
assert_eq!(
accepted.len(),
1,
"cooldown expired at step 4, should inject"
);
}
#[tokio::test]
async fn throttle_bypassed_on_content_change() {
let (runtime, env) = test_runtime();
let store = runtime.store();
enqueue_context_message(
store,
1,
ContextMessage::system("changing", "original content").with_cooldown(10),
);
let ctx = PhaseContext::new(Phase::BeforeInference, store.snapshot());
runtime
.run_phase_with_context(&env, ctx)
.await
.expect("step 1");
let accepted = take_context_messages(store).expect("take step 1");
assert_eq!(accepted.len(), 1);
let mut patch = crate::state::MutationBatch::new();
patch.update::<RunLifecycle>(RunLifecycleUpdate::StepCompleted { updated_at: 0 });
store.commit(patch).expect("step completed");
enqueue_context_message(
store,
2,
ContextMessage::system("changing", "original content").with_cooldown(10),
);
let ctx = PhaseContext::new(Phase::BeforeInference, store.snapshot());
runtime
.run_phase_with_context(&env, ctx)
.await
.expect("step 2");
let accepted = take_context_messages(store).expect("take step 2");
assert_eq!(
accepted.len(),
0,
"same content within cooldown should be throttled"
);
let mut patch = crate::state::MutationBatch::new();
patch.update::<RunLifecycle>(RunLifecycleUpdate::StepCompleted { updated_at: 0 });
store.commit(patch).expect("step completed");
enqueue_context_message(
store,
3,
ContextMessage::system("changing", "updated content").with_cooldown(10),
);
let ctx = PhaseContext::new(Phase::BeforeInference, store.snapshot());
runtime
.run_phase_with_context(&env, ctx)
.await
.expect("step 3");
let accepted = take_context_messages(store).expect("take step 3");
assert_eq!(
accepted.len(),
1,
"different content should bypass cooldown"
);
assert_eq!(text_of_ctx(&accepted[0]), "updated content");
}
#[test]
fn tracing_does_not_panic_without_subscriber() {
use crate::context::ContextTransform;
use awaken_contract::contract::inference::ContextWindowPolicy;
use awaken_contract::contract::transform::InferenceRequestTransform;
let policy = ContextWindowPolicy {
max_context_tokens: 40,
max_output_tokens: 0,
min_recent_messages: 1,
enable_prompt_cache: false,
autocompact_threshold: None,
compaction_mode: Default::default(),
compaction_raw_suffix_messages: 2,
};
let transform = ContextTransform::new(policy);
let mut msgs = vec![Message::system("sys")];
for i in 0..10 {
msgs.push(Message::user(format!("msg {i}")));
msgs.push(Message::assistant(format!("reply {i}")));
}
let _output = transform.transform(msgs, &[]);
tracing::info!(step = 1u64, "step_start");
tracing::info!(
model = "test-model",
input_tokens = 100u64,
output_tokens = 50u64,
duration_ms = 42u64,
"inference_complete"
);
tracing::info!(
tool_name = "calculator",
call_id = "c1",
outcome = "Succeeded",
"tool_call_done"
);
tracing::warn!(reason = "NaturalEnd", "run_terminated");
tracing::debug!(phase = "StepStart", hooks = 3usize, "gather_start");
tracing::debug!(phase = "StepStart", actions = 2usize, "execute_start");
tracing::warn!(phase = "StepStart", "exclusive_conflict_fallback");
tracing::info!(
pre_tokens = 2000usize,
post_tokens = 500usize,
boundary = 10usize,
"compaction_complete"
);
tracing::debug!(dropped = 5usize, kept = 8usize, "truncation_applied");
}
fn text_of_ctx(msg: &ContextMessage) -> String {
msg.content
.iter()
.filter_map(|b| match b {
ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("")
}
fn tool(id: &str) -> ToolDescriptor {
ToolDescriptor::new(id, id, format!("{id} tool"))
}
#[test]
fn tool_filter_state_accumulates_excludes() {
let mut val = ToolFilterStateValue::default();
ToolFilterState::apply(&mut val, ToolFilterStateAction::Exclude("search".into()));
ToolFilterState::apply(&mut val, ToolFilterStateAction::Exclude("browser".into()));
assert_eq!(val.excluded, vec!["search", "browser"]);
}
#[test]
fn tool_filter_state_accumulates_include_only() {
let mut val = ToolFilterStateValue::default();
ToolFilterState::apply(
&mut val,
ToolFilterStateAction::IncludeOnly(vec!["a".into(), "b".into()]),
);
ToolFilterState::apply(
&mut val,
ToolFilterStateAction::IncludeOnly(vec!["c".into()]),
);
assert_eq!(val.include_only.len(), 2);
}
#[test]
fn tool_filter_state_clear_resets() {
let mut val = ToolFilterStateValue::default();
ToolFilterState::apply(&mut val, ToolFilterStateAction::Exclude("x".into()));
ToolFilterState::apply(
&mut val,
ToolFilterStateAction::IncludeOnly(vec!["y".into()]),
);
ToolFilterState::apply(&mut val, ToolFilterStateAction::Clear);
assert!(val.excluded.is_empty());
assert!(val.include_only.is_empty());
}
#[test]
fn exclude_tool_removes_from_request() {
let mut tools = vec![tool("search"), tool("calculator"), tool("browser")];
apply_tool_filter_payloads(&mut tools, vec!["search".into()], vec![]);
let ids: Vec<_> = tools.iter().map(|t| t.id.as_str()).collect();
assert!(!ids.contains(&"search"), "search should be excluded");
assert!(ids.contains(&"calculator"));
assert!(ids.contains(&"browser"));
assert_eq!(tools.len(), 2);
}
#[test]
fn include_only_tools_filters_to_subset() {
let mut tools = vec![
tool("search"),
tool("calculator"),
tool("browser"),
tool("code_exec"),
];
apply_tool_filter_payloads(
&mut tools,
vec![],
vec![vec!["calculator".into(), "browser".into()]],
);
let ids: Vec<_> = tools.iter().map(|t| t.id.as_str()).collect();
assert_eq!(ids.len(), 2);
assert!(ids.contains(&"calculator"));
assert!(ids.contains(&"browser"));
}
#[test]
fn exclude_and_include_only_combined() {
let mut tools = vec![tool("search"), tool("calculator"), tool("browser")];
apply_tool_filter_payloads(
&mut tools,
vec!["search".into()],
vec![vec!["search".into(), "calculator".into()]],
);
let ids: Vec<_> = tools.iter().map(|t| t.id.as_str()).collect();
assert_eq!(ids, vec!["calculator"]);
}
#[test]
fn multiple_exclude_tool_actions() {
let mut tools = vec![tool("a"), tool("b"), tool("c"), tool("d")];
apply_tool_filter_payloads(&mut tools, vec!["a".into(), "c".into()], vec![]);
let ids: Vec<_> = tools.iter().map(|t| t.id.as_str()).collect();
assert_eq!(ids, vec!["b", "d"]);
}
#[test]
fn no_filter_actions_leaves_tools_unchanged() {
let mut tools = vec![tool("search"), tool("calculator")];
apply_tool_filter_payloads(&mut tools, vec![], vec![]);
assert_eq!(tools.len(), 2);
}
#[test]
fn multiple_include_only_actions_union() {
let mut tools = vec![tool("a"), tool("b"), tool("c"), tool("d")];
apply_tool_filter_payloads(&mut tools, vec![], vec![vec!["a".into()], vec!["c".into()]]);
let ids: Vec<_> = tools.iter().map(|t| t.id.as_str()).collect();
assert_eq!(ids, vec!["a", "c"]);
}
#[test]
fn duplicate_exclude_from_permission_and_deferred_is_idempotent() {
let mut tools = vec![
tool("allowed"),
tool("denied_and_deferred"),
tool("ToolSearch"),
];
apply_tool_filter_payloads(
&mut tools,
vec!["denied_and_deferred".into(), "denied_and_deferred".into()],
vec![],
);
let ids: Vec<_> = tools.iter().map(|t| t.id.as_str()).collect();
assert_eq!(ids, vec!["allowed", "ToolSearch"]);
}
#[test]
fn permission_exclude_wins_over_deferred_promote() {
let mut tools = vec![
tool("safe_tool"),
tool("promoted_but_denied"),
tool("ToolSearch"),
];
apply_tool_filter_payloads(&mut tools, vec!["promoted_but_denied".into()], vec![]);
let ids: Vec<_> = tools.iter().map(|t| t.id.as_str()).collect();
assert_eq!(ids, vec!["safe_tool", "ToolSearch"]);
}
#[test]
fn deferred_excludes_and_permission_excludes_are_additive() {
let mut tools = vec![
tool("rm"),
tool("calculator"),
tool("browser"),
tool("search"),
tool("ToolSearch"),
];
apply_tool_filter_payloads(
&mut tools,
vec!["rm".into(), "calculator".into(), "browser".into()],
vec![],
);
let ids: Vec<_> = tools.iter().map(|t| t.id.as_str()).collect();
assert_eq!(ids, vec!["search", "ToolSearch"]);
}
#[test]
fn exclude_nonexistent_tool_is_harmless() {
let mut tools = vec![tool("a"), tool("b")];
apply_tool_filter_payloads(
&mut tools,
vec!["nonexistent".into(), "also_missing".into()],
vec![],
);
let ids: Vec<_> = tools.iter().map(|t| t.id.as_str()).collect();
assert_eq!(ids, vec!["a", "b"]);
}
#[test]
fn inference_override_state_merges_correctly() {
let ovr1 = awaken_contract::contract::inference::InferenceOverride {
upstream_model: Some("gpt-4".into()),
temperature: Some(0.7),
..Default::default()
};
let ovr2 = awaken_contract::contract::inference::InferenceOverride {
temperature: Some(0.9),
max_tokens: Some(1000),
..Default::default()
};
let mut val = InferenceOverrideStateValue::default();
InferenceOverrideState::apply(&mut val, InferenceOverrideStateAction::Merge(ovr1));
InferenceOverrideState::apply(&mut val, InferenceOverrideStateAction::Merge(ovr2));
let ovr = val.overrides.expect("should have overrides");
assert_eq!(ovr.upstream_model.as_deref(), Some("gpt-4"));
assert_eq!(ovr.temperature, Some(0.9)); assert_eq!(ovr.max_tokens, Some(1000));
}
#[test]
fn inference_override_state_clear_resets() {
let mut val = InferenceOverrideStateValue::default();
InferenceOverrideState::apply(
&mut val,
InferenceOverrideStateAction::Merge(
awaken_contract::contract::inference::InferenceOverride {
upstream_model: Some("gpt-4".into()),
..Default::default()
},
),
);
InferenceOverrideState::apply(&mut val, InferenceOverrideStateAction::Clear);
assert!(val.overrides.is_none());
}
#[test]
fn inference_override_merge_helper_works() {
let ovr1 = awaken_contract::contract::inference::InferenceOverride {
upstream_model: Some("gpt-4".into()),
temperature: Some(0.7),
..Default::default()
};
let ovr2 = awaken_contract::contract::inference::InferenceOverride {
temperature: Some(0.9),
max_tokens: Some(1000),
..Default::default()
};
let mut overrides = None;
merge_override_payloads(&mut overrides, vec![ovr1, ovr2]);
let ovr = overrides.expect("should have overrides");
assert_eq!(ovr.upstream_model.as_deref(), Some("gpt-4"));
assert_eq!(ovr.temperature, Some(0.9)); assert_eq!(ovr.max_tokens, Some(1000));
}
#[test]
fn intercept_block_wins_over_suspend() {
use awaken_contract::contract::suspension::{
PendingToolCall, SuspendTicket, Suspension, ToolCallResumeMode,
};
use awaken_contract::contract::tool_intercept::ToolInterceptPayload;
let payloads = vec![
ToolInterceptPayload::Suspend(SuspendTicket {
suspension: Suspension::default(),
pending: PendingToolCall::default(),
resume_mode: ToolCallResumeMode::default(),
}),
ToolInterceptPayload::Block {
reason: "blocked".into(),
},
];
let winner = resolve_intercept_payloads(payloads);
assert!(matches!(winner, Some(ToolInterceptPayload::Block { .. })));
}
#[test]
fn intercept_same_priority_keeps_first() {
use awaken_contract::contract::tool::ToolResult;
use awaken_contract::contract::tool_intercept::ToolInterceptPayload;
let payloads = vec![
ToolInterceptPayload::SetResult(ToolResult::success("first", serde_json::json!({}))),
ToolInterceptPayload::SetResult(ToolResult::success("second", serde_json::json!({}))),
];
let winner = resolve_intercept_payloads(payloads);
match winner {
Some(ToolInterceptPayload::SetResult(r)) => assert_eq!(r.tool_name, "first"),
other => panic!("expected SetResult, got {other:?}"),
}
}
#[test]
fn intercept_empty_returns_none() {
let winner = resolve_intercept_payloads(vec![]);
assert!(winner.is_none());
}
#[derive(Default)]
struct DummyLlm;
#[async_trait::async_trait]
impl awaken_contract::contract::executor::LlmExecutor for DummyLlm {
async fn execute(
&self,
_request: awaken_contract::contract::executor::InferenceRequest,
) -> Result<
awaken_contract::contract::inference::StreamResult,
awaken_contract::contract::executor::InferenceExecutionError,
> {
panic!("dummy llm should not execute in tool-only tests");
}
fn name(&self) -> &str {
"dummy"
}
}
#[derive(Default)]
struct RecordingExecutor {
batches: std::sync::Mutex<Vec<Vec<String>>>,
}
impl RecordingExecutor {
fn take(&self) -> Vec<Vec<String>> {
std::mem::take(&mut *self.batches.lock().expect("lock poisoned"))
}
}
#[async_trait::async_trait]
impl crate::execution::ToolExecutor for RecordingExecutor {
async fn execute(
&self,
_tools: &std::collections::HashMap<
String,
std::sync::Arc<dyn awaken_contract::contract::tool::Tool>,
>,
calls: &[awaken_contract::contract::message::ToolCall],
_base_ctx: &awaken_contract::contract::tool::ToolCallContext,
) -> Result<Vec<crate::execution::ToolExecutionResult>, crate::execution::ToolExecutorError>
{
self.batches.lock().expect("lock poisoned").push(
calls
.iter()
.map(|call| call.id.clone())
.collect::<Vec<String>>(),
);
Ok(calls
.iter()
.map(|call| crate::execution::ToolExecutionResult {
call: call.clone(),
result: awaken_contract::contract::tool::ToolResult::success(
&call.name,
serde_json::json!({"id": call.id}),
),
outcome: awaken_contract::contract::suspension::ToolCallOutcome::Succeeded,
command: crate::state::StateCommand::new(),
})
.collect())
}
fn name(&self) -> &'static str {
"recording"
}
}
struct ToolGateUnlockKey;
impl StateKey for ToolGateUnlockKey {
const KEY: &'static str = "test.tool_gate.unlock";
type Value = bool;
type Update = bool;
fn apply(value: &mut Self::Value, update: Self::Update) {
*value = update;
}
}
struct UnlockTool;
#[async_trait::async_trait]
impl Tool for UnlockTool {
fn descriptor(&self) -> ToolDescriptor {
ToolDescriptor::new("unlock", "unlock", "unlock state")
}
async fn execute(
&self,
_args: serde_json::Value,
_ctx: &ToolCallContext,
) -> Result<ToolOutput, ToolError> {
let mut command = StateCommand::new();
command.update::<ToolGateUnlockKey>(true);
Ok(ToolOutput::with_command(
ToolResult::success("unlock", serde_json::json!({"unlocked": true})),
command,
))
}
}
struct SuspendedUnlockTool;
#[async_trait::async_trait]
impl Tool for SuspendedUnlockTool {
fn descriptor(&self) -> ToolDescriptor {
ToolDescriptor::new("unlock", "unlock", "suspending unlock state")
}
async fn execute(
&self,
_args: serde_json::Value,
_ctx: &ToolCallContext,
) -> Result<ToolOutput, ToolError> {
let ticket = SuspendTicket::new(
awaken_contract::contract::suspension::Suspension {
id: "suspend_unlock".into(),
action: "tool:UnlockConfirm".into(),
message: "unlock requires approval".into(),
parameters: serde_json::json!({}),
response_schema: None,
},
awaken_contract::contract::suspension::PendingToolCall::new(
"call_unlock",
"unlock",
serde_json::json!({}),
),
awaken_contract::contract::suspension::ToolCallResumeMode::ReplayToolCall,
);
Ok(ToolResult::suspended_with("unlock", "unlock suspended", ticket).into())
}
}
struct GuardedTool;
#[async_trait::async_trait]
impl Tool for GuardedTool {
fn descriptor(&self) -> ToolDescriptor {
ToolDescriptor::new("guarded", "guarded", "guarded")
}
async fn execute(
&self,
_args: serde_json::Value,
_ctx: &ToolCallContext,
) -> Result<ToolOutput, ToolError> {
Ok(ToolResult::success("guarded", serde_json::json!({"ok": true})).into())
}
}
struct UnlockingToolGateHook;
#[async_trait::async_trait]
impl ToolGateHook for UnlockingToolGateHook {
async fn run(
&self,
ctx: &PhaseContext,
) -> Result<
Option<awaken_contract::contract::tool_intercept::ToolInterceptPayload>,
awaken_contract::StateError,
> {
if ctx.tool_name.as_deref() != Some("guarded") {
return Ok(None);
}
if ctx.state::<ToolGateUnlockKey>().copied().unwrap_or(false) {
Ok(None)
} else {
Ok(Some(
awaken_contract::contract::tool_intercept::ToolInterceptPayload::Block {
reason: "guarded locked".into(),
},
))
}
}
}
struct GuardedBeforeHook {
calls: Arc<std::sync::Mutex<Vec<String>>>,
}
#[async_trait::async_trait]
impl crate::hooks::PhaseHook for GuardedBeforeHook {
async fn run(&self, ctx: &PhaseContext) -> Result<StateCommand, awaken_contract::StateError> {
if ctx.tool_name.as_deref() == Some("guarded")
&& let Some(call_id) = ctx.tool_call_id.clone()
{
self.calls.lock().expect("lock poisoned").push(call_id);
}
Ok(StateCommand::new())
}
}
struct ToolGateTestPlugin {
before_calls: Arc<std::sync::Mutex<Vec<String>>>,
}
impl Plugin for ToolGateTestPlugin {
fn descriptor(&self) -> PluginDescriptor {
PluginDescriptor {
name: "tool-gate-test",
}
}
fn register(&self, registrar: &mut PluginRegistrar) -> Result<(), awaken_contract::StateError> {
registrar.register_key::<ToolGateUnlockKey>(StateKeyOptions::default())?;
registrar.register_tool_gate_hook("tool-gate-test", UnlockingToolGateHook)?;
registrar.register_phase_hook(
"tool-gate-test",
Phase::BeforeToolExecute,
GuardedBeforeHook {
calls: Arc::clone(&self.before_calls),
},
)?;
Ok(())
}
}
#[tokio::test]
async fn resumed_calls_do_not_serialize_neighboring_fresh_batches() {
let (runtime, _env) = test_runtime();
let store = runtime.store();
let mut patch = crate::state::MutationBatch::new();
patch.update::<ToolCallStates>(crate::agent::state::ToolCallStatesUpdate::put(
crate::agent::state::ToolCallState::new(
"c3",
"gamma",
serde_json::json!({"resumed": true}),
awaken_contract::contract::suspension::ToolCallStatus::Resuming,
1,
)
.with_resume_mode(awaken_contract::contract::suspension::ToolCallResumeMode::ReplayToolCall)
.with_suspension(
Some("perm_c3".into()),
Some("tool:PermissionConfirm".into()),
)
.with_resume_input(Some(
awaken_contract::contract::suspension::ToolCallResume {
decision_id: "decision-1".into(),
action: awaken_contract::contract::suspension::ResumeDecisionAction::Resume,
result: serde_json::json!({"approved": true}),
reason: None,
updated_at: 2,
},
)),
));
store.commit(patch).expect("seed resume state");
let executor = std::sync::Arc::new(RecordingExecutor::default());
let mut agent = crate::registry::ResolvedAgent::new(
"agent",
"model",
"system",
std::sync::Arc::new(DummyLlm),
)
.with_tool_executor(executor.clone());
let sink: std::sync::Arc<dyn awaken_contract::contract::event_sink::EventSink> =
std::sync::Arc::new(awaken_contract::contract::event_sink::NullEventSink);
let mut messages = vec![std::sync::Arc::new(Message::user("go"))];
let run_identity = awaken_contract::contract::identity::RunIdentity::default();
let run_overrides = None;
let mut total_input_tokens = 0;
let mut total_output_tokens = 0;
let mut truncation_state = crate::context::TruncationState::new();
let input_message_count = messages.len();
let mut ctx = super::step::StepContext {
agent: &mut agent,
messages: &mut messages,
runtime: &runtime,
sink,
checkpoint_store: None,
run_identity: &run_identity,
input_message_count,
cancellation_token: None,
run_overrides: &run_overrides,
total_input_tokens: &mut total_input_tokens,
total_output_tokens: &mut total_output_tokens,
truncation_state: &mut truncation_state,
run_created_at: 0,
thread_ctx: None,
};
let calls = vec![
awaken_contract::contract::message::ToolCall::new("c1", "alpha", serde_json::json!({})),
awaken_contract::contract::message::ToolCall::new("c2", "beta", serde_json::json!({})),
awaken_contract::contract::message::ToolCall::new("c3", "gamma", serde_json::json!({})),
awaken_contract::contract::message::ToolCall::new("c4", "delta", serde_json::json!({})),
awaken_contract::contract::message::ToolCall::new("c5", "epsilon", serde_json::json!({})),
];
let mut transcript = super::step::ToolBatchTranscript::for_resume();
let (_blocked, suspended) =
super::step::execute_tools_with_interception(&mut ctx, &mut transcript, &calls)
.await
.expect("tool execution should succeed");
drop(ctx);
transcript.commit_into(&mut messages);
assert!(!suspended, "recording executor never suspends");
assert_eq!(
executor.take(),
vec![
vec![String::from("c1"), String::from("c2")],
vec![String::from("c3")],
vec![String::from("c4"), String::from("c5")],
]
);
assert!(
messages.iter().all(|message| message.tool_calls.is_none()),
"resume replay should not append a fresh assistant tool-call message"
);
let states = store.read::<ToolCallStates>().expect("tool call states");
let resumed = states.calls.get("c3").expect("resumed call state");
assert_eq!(
resumed.status,
awaken_contract::contract::suspension::ToolCallStatus::Succeeded
);
assert!(
resumed.resume_input.is_none(),
"terminal tool state should clear consumed resume input"
);
assert!(
resumed.suspension_id.is_none() && resumed.suspension_reason.is_none(),
"terminal tool state should not retain an active suspension context"
);
}
#[tokio::test]
async fn cancelled_resume_is_emitted_once_even_when_other_calls_replay() {
let (runtime, _env) = test_runtime();
let store = runtime.store();
let mut patch = crate::state::MutationBatch::new();
patch.update::<ToolCallStates>(crate::agent::state::ToolCallStatesUpdate::put(
crate::agent::state::ToolCallState::new(
"cancel_a",
"alpha",
serde_json::json!({"cancelled": true}),
awaken_contract::contract::suspension::ToolCallStatus::Cancelled,
1,
)
.with_resume_mode(awaken_contract::contract::suspension::ToolCallResumeMode::ReplayToolCall)
.with_suspension(
Some("perm_cancel_a".into()),
Some("tool:PermissionConfirm".into()),
)
.with_resume_input(Some(
awaken_contract::contract::suspension::ToolCallResume {
decision_id: "decision-cancel".into(),
action: awaken_contract::contract::suspension::ResumeDecisionAction::Cancel,
result: serde_json::json!({"kind": "permission_decision", "approved": false}),
reason: None,
updated_at: 2,
},
)),
));
patch.update::<ToolCallStates>(crate::agent::state::ToolCallStatesUpdate::put(
crate::agent::state::ToolCallState::new(
"resume_b",
"beta",
serde_json::json!({"resumed": true}),
awaken_contract::contract::suspension::ToolCallStatus::Resuming,
1,
)
.with_resume_mode(awaken_contract::contract::suspension::ToolCallResumeMode::ReplayToolCall)
.with_suspension(
Some("perm_resume_b".into()),
Some("tool:PermissionConfirm".into()),
)
.with_resume_input(Some(
awaken_contract::contract::suspension::ToolCallResume {
decision_id: "decision-resume".into(),
action: awaken_contract::contract::suspension::ResumeDecisionAction::Resume,
result: serde_json::json!({"kind": "permission_decision", "approved": true}),
reason: None,
updated_at: 2,
},
)),
));
store
.commit(patch)
.expect("seed cancelled and resumed states");
let executor = std::sync::Arc::new(RecordingExecutor::default());
let agent = crate::registry::ResolvedAgent::new(
"agent",
"model",
"system",
std::sync::Arc::new(DummyLlm),
)
.with_tool_executor(executor);
let run_identity = awaken_contract::contract::identity::RunIdentity::default();
let sink = std::sync::Arc::new(VecEventSink::new());
let mut messages = vec![std::sync::Arc::new(Message::user("go"))];
super::resume::detect_and_replay_resume(
&agent,
&runtime,
&run_identity,
&mut messages,
sink.clone(),
)
.await
.expect("first replay should succeed");
super::resume::detect_and_replay_resume(
&agent,
&runtime,
&run_identity,
&mut messages,
sink.clone(),
)
.await
.expect("second replay should not re-emit cancelled resume");
let cancelled_events: Vec<_> = sink
.events()
.into_iter()
.filter(|event| {
matches!(
event,
AgentEvent::ToolCallResumed { target_id, .. } if target_id == "cancel_a"
)
})
.collect();
assert_eq!(
cancelled_events.len(),
1,
"cancelled resume should only emit one ToolCallResumed event"
);
let cancelled_messages: Vec<_> = messages
.iter()
.filter(|message| {
message.role == Role::Tool && message.tool_call_id.as_deref() == Some("cancel_a")
})
.collect();
assert_eq!(
cancelled_messages.len(),
1,
"cancelled resume should only append one tool message"
);
let states = store.read::<ToolCallStates>().expect("tool call states");
let cancelled = states.calls.get("cancel_a").expect("cancelled call state");
assert_eq!(
cancelled.status,
awaken_contract::contract::suspension::ToolCallStatus::Cancelled
);
assert!(
cancelled.resume_input.is_none(),
"cancelled terminal state should clear consumed resume input"
);
assert!(
cancelled.suspension_id.is_none() && cancelled.suspension_reason.is_none(),
"cancelled terminal state should not retain an active suspension context"
);
}
#[tokio::test]
async fn tool_gate_recheck_executes_before_tool_hook_once() {
let before_calls = Arc::new(std::sync::Mutex::new(Vec::new()));
let (runtime, env) = test_runtime_with_plugins(vec![Arc::new(ToolGateTestPlugin {
before_calls: Arc::clone(&before_calls),
})]);
runtime
.store()
.install_plugin(ToolGateTestPlugin {
before_calls: Arc::clone(&before_calls),
})
.expect("install tool gate test plugin keys");
let store = runtime.store();
let sink = Arc::new(VecEventSink::new());
let sink_dyn: Arc<dyn awaken_contract::contract::event_sink::EventSink> = sink.clone();
let mut messages = vec![Arc::new(Message::user("go"))];
let run_identity = awaken_contract::contract::identity::RunIdentity::default();
let run_overrides = None;
let mut total_input_tokens = 0;
let mut total_output_tokens = 0;
let mut truncation_state = crate::context::TruncationState::new();
let mut agent =
crate::registry::ResolvedAgent::new("agent", "model", "system", Arc::new(DummyLlm))
.with_tools(vec![Arc::new(UnlockTool), Arc::new(GuardedTool)]);
agent.env = env;
let input_message_count = messages.len();
let mut ctx = super::step::StepContext {
agent: &mut agent,
messages: &mut messages,
runtime: &runtime,
sink: sink_dyn,
checkpoint_store: None,
run_identity: &run_identity,
input_message_count,
cancellation_token: None,
run_overrides: &run_overrides,
total_input_tokens: &mut total_input_tokens,
total_output_tokens: &mut total_output_tokens,
truncation_state: &mut truncation_state,
run_created_at: 0,
thread_ctx: None,
};
let calls = vec![
awaken_contract::contract::message::ToolCall::new("c1", "unlock", serde_json::json!({})),
awaken_contract::contract::message::ToolCall::new("c2", "guarded", serde_json::json!({})),
];
let mut transcript =
super::step::ToolBatchTranscript::for_inference(String::new(), calls.clone());
let (blocked, suspended) =
super::step::execute_tools_with_interception(&mut ctx, &mut transcript, &calls)
.await
.expect("tool execution should succeed");
drop(ctx);
transcript.commit_into(&mut messages);
assert!(blocked.is_none(), "tool gate should re-open after unlock");
assert!(!suspended, "unlock + guarded should complete");
assert_eq!(
before_calls.lock().expect("lock poisoned").as_slice(),
&["c2".to_string()],
"BeforeToolExecute should run once for the guarded call after ToolGate recheck",
);
assert_eq!(store.read::<ToolGateUnlockKey>(), Some(true));
let states = store.read::<ToolCallStates>().expect("tool call states");
assert_eq!(
states.calls.get("c1").expect("unlock state").status,
ToolCallStatus::Succeeded
);
assert_eq!(
states.calls.get("c2").expect("guarded state").status,
ToolCallStatus::Succeeded
);
}
#[tokio::test]
async fn tool_gate_flush_suspension_backfills_rechecked_and_later_calls() {
let before_calls = Arc::new(std::sync::Mutex::new(Vec::new()));
let (runtime, env) = test_runtime_with_plugins(vec![Arc::new(ToolGateTestPlugin {
before_calls: Arc::clone(&before_calls),
})]);
runtime
.store()
.install_plugin(ToolGateTestPlugin {
before_calls: Arc::clone(&before_calls),
})
.expect("install tool gate test plugin keys");
let store = runtime.store();
let sink = Arc::new(VecEventSink::new());
let sink_dyn: Arc<dyn awaken_contract::contract::event_sink::EventSink> = sink.clone();
let mut messages = vec![Arc::new(Message::user("go"))];
let run_identity = awaken_contract::contract::identity::RunIdentity::default();
let run_overrides = None;
let mut total_input_tokens = 0;
let mut total_output_tokens = 0;
let mut truncation_state = crate::context::TruncationState::new();
let mut agent =
crate::registry::ResolvedAgent::new("agent", "model", "system", Arc::new(DummyLlm))
.with_tools(vec![Arc::new(SuspendedUnlockTool), Arc::new(GuardedTool)]);
agent.env = env;
let input_message_count = messages.len();
let mut ctx = super::step::StepContext {
agent: &mut agent,
messages: &mut messages,
runtime: &runtime,
sink: sink_dyn,
checkpoint_store: None,
run_identity: &run_identity,
input_message_count,
cancellation_token: None,
run_overrides: &run_overrides,
total_input_tokens: &mut total_input_tokens,
total_output_tokens: &mut total_output_tokens,
truncation_state: &mut truncation_state,
run_created_at: 0,
thread_ctx: None,
};
let calls = vec![
awaken_contract::contract::message::ToolCall::new("c1", "unlock", serde_json::json!({})),
awaken_contract::contract::message::ToolCall::new("c2", "guarded", serde_json::json!({})),
awaken_contract::contract::message::ToolCall::new("c3", "guarded", serde_json::json!({})),
];
let mut transcript =
super::step::ToolBatchTranscript::for_inference(String::new(), calls.clone());
let (blocked, suspended) =
super::step::execute_tools_with_interception(&mut ctx, &mut transcript, &calls)
.await
.expect("tool execution should succeed");
drop(ctx);
transcript.commit_into(&mut messages);
assert!(blocked.is_none());
assert!(suspended, "suspending prefix call should suspend the batch");
assert!(
before_calls.lock().expect("lock poisoned").is_empty(),
"guarded calls should never reach BeforeToolExecute after prefix suspension",
);
let states = store.read::<ToolCallStates>().expect("tool call states");
assert_eq!(
states.calls.get("c1").expect("unlock state").status,
ToolCallStatus::Suspended
);
assert_eq!(
states.calls.get("c2").expect("guarded state").status,
ToolCallStatus::Failed
);
assert_eq!(
states.calls.get("c3").expect("later guarded state").status,
ToolCallStatus::Failed
);
let interrupted: Vec<_> = messages
.iter()
.filter(|message| {
message.role == Role::Tool
&& matches!(message.tool_call_id.as_deref(), Some("c2" | "c3"))
})
.map(|message| text_of(message))
.collect();
assert_eq!(
interrupted,
vec![
"[Tool execution was interrupted]".to_string(),
"[Tool execution was interrupted]".to_string(),
]
);
let failed_events: Vec<_> = sink
.events()
.into_iter()
.filter_map(|event| match event {
AgentEvent::ToolCallDone { id, outcome, .. }
if matches!(
(id.as_str(), outcome),
("c2", ToolCallOutcome::Failed) | ("c3", ToolCallOutcome::Failed)
) =>
{
Some(id)
}
_ => None,
})
.collect();
assert_eq!(failed_events, vec!["c2".to_string(), "c3".to_string()]);
}