use std::sync::Arc;
use serde_json::Value;
use crate::cancellation::CancellationToken;
use crate::state::{Snapshot, StateKey};
use awaken_contract::StateError;
use awaken_contract::contract::identity::RunIdentity;
use awaken_contract::contract::inference::LLMResponse;
use awaken_contract::contract::message::Message;
use awaken_contract::contract::suspension::ToolCallResume;
use awaken_contract::contract::tool::ToolResult;
use awaken_contract::contract::tool_intercept::{
AdapterKind, RunMode, ToolKind, ToolPolicyContext,
};
use awaken_contract::model::Phase;
use awaken_contract::registry_spec::{AgentSpec, PluginConfigKey};
/// Data bundle handed around for a single [`Phase`].
///
/// Built with [`PhaseContext::new`] and then refined through the consuming
/// `with_*` builder methods. All fields are public, so they can also be read
/// (or set) directly.
#[derive(Clone)]
pub struct PhaseContext {
/// Phase this context was constructed for.
pub phase: Phase,
/// State snapshot that `state::<K>()` reads from.
pub snapshot: Snapshot,
/// Agent spec that `config::<K>()` reads from; `new` installs `AgentSpec::default()`.
pub agent_spec: Arc<AgentSpec>,
/// Run identity; supplies the thread, run and dispatch ids used by
/// `tool_policy_context`. `new` installs `RunIdentity::default()`.
pub run_identity: RunIdentity,
/// Messages attached to this context; empty after `new`.
pub messages: Arc<[Arc<Message>]>,
/// Tool name, set by `with_tool_info`; `None` after `new`.
pub tool_name: Option<String>,
/// Tool call id, set by `with_tool_info`; `None` after `new`.
pub tool_call_id: Option<String>,
/// Tool arguments, set by `with_tool_info`; `None` after `new`.
pub tool_args: Option<Value>,
/// Tool result, set by `with_tool_result`; `None` after `new`.
pub tool_result: Option<ToolResult>,
/// Run mode; `RunMode::Foreground` after `new`, overwritten by
/// `with_run_identity` and `with_run_mode`.
pub run_mode: RunMode,
/// Adapter kind; `AdapterKind::Internal` after `new`, overwritten by
/// `with_run_identity` and `with_adapter`.
pub adapter: AdapterKind,
/// Tool kind; `ToolKind::Other` after `new`, overwritten by
/// `with_tool_info` (inferred from the tool name) and `with_tool_kind`.
pub tool_kind: ToolKind,
/// LLM response, set by `with_llm_response`; `None` after `new`.
pub llm_response: Option<LLMResponse>,
/// Resume input, set by `with_resume_input`; `None` after `new`.
pub resume_input: Option<ToolCallResume>,
/// Suspension id, set by `with_suspension`; `None` after `new`.
pub suspension_id: Option<String>,
/// Suspension reason, set by `with_suspension`; `None` after `new`.
pub suspension_reason: Option<String>,
/// Cancellation token, set by `with_cancellation_token`; `None` after `new`.
pub cancellation_token: Option<CancellationToken>,
/// Profile access handle exposed through `profile()`; `None` after `new`.
pub profile_access: Option<Arc<crate::profile::ProfileAccess>>,
}
impl PhaseContext {
    /// Creates a context for `phase` over `snapshot`.
    ///
    /// Every other field starts at a neutral value: a default [`AgentSpec`]
    /// and [`RunIdentity`], no messages, `RunMode::Foreground`,
    /// `AdapterKind::Internal`, `ToolKind::Other`, and `None` for every
    /// optional field.
    pub fn new(phase: Phase, snapshot: Snapshot) -> Self {
        Self {
            phase,
            snapshot,
            agent_spec: Arc::new(AgentSpec::default()),
            run_identity: RunIdentity::default(),
            messages: Arc::from([]),
            tool_name: None,
            tool_call_id: None,
            tool_args: None,
            tool_result: None,
            run_mode: RunMode::Foreground,
            adapter: AdapterKind::Internal,
            tool_kind: ToolKind::Other,
            llm_response: None,
            resume_input: None,
            suspension_id: None,
            suspension_reason: None,
            cancellation_token: None,
            profile_access: None,
        }
    }

    /// Looks up the value for state key `K` in the snapshot.
    ///
    /// Returns whatever `Snapshot::get::<K>()` returns for this snapshot.
    pub fn state<K: StateKey>(&self) -> Option<&K::Value> {
        self.snapshot.get::<K>()
    }

    /// Reads the plugin configuration for key `K` from the agent spec.
    ///
    /// # Errors
    ///
    /// Propagates the [`StateError`] produced by `AgentSpec::config::<K>()`.
    pub fn config<K: PluginConfigKey>(&self) -> Result<K::Config, StateError> {
        self.agent_spec.config::<K>()
    }

    /// Replaces the snapshot.
    #[must_use]
    pub fn with_snapshot(mut self, snapshot: Snapshot) -> Self {
        self.snapshot = snapshot;
        self
    }

    /// Replaces the agent spec.
    #[must_use]
    pub fn with_agent_spec(mut self, spec: Arc<AgentSpec>) -> Self {
        self.agent_spec = spec;
        self
    }

    /// Replaces the run identity.
    ///
    /// Also overwrites `run_mode` and `adapter` with the values reported by
    /// `identity`, so call [`Self::with_run_mode`] / [`Self::with_adapter`]
    /// *after* this method if they should win.
    #[must_use]
    pub fn with_run_identity(mut self, identity: RunIdentity) -> Self {
        self.run_mode = identity.run_mode();
        self.adapter = identity.adapter();
        self.run_identity = identity;
        self
    }

    /// Replaces the message list with `messages`.
    #[must_use]
    pub fn with_messages(mut self, messages: Vec<Arc<Message>>) -> Self {
        self.messages = Arc::from(messages);
        self
    }

    /// Sets the tool name, call id and arguments.
    ///
    /// `tool_kind` is re-derived from `name` via `ToolKind::infer_name`,
    /// replacing any kind set earlier; call [`Self::with_tool_kind`] *after*
    /// this method to override the inferred kind. Passing `None` for `args`
    /// clears any previously stored arguments.
    #[must_use]
    pub fn with_tool_info(
        mut self,
        name: impl Into<String>,
        call_id: impl Into<String>,
        args: Option<Value>,
    ) -> Self {
        let name = name.into();
        self.tool_kind = ToolKind::infer_name(&name);
        self.tool_name = Some(name);
        self.tool_call_id = Some(call_id.into());
        self.tool_args = args;
        self
    }

    /// Overrides the run mode.
    #[must_use]
    pub fn with_run_mode(mut self, mode: RunMode) -> Self {
        self.run_mode = mode;
        self
    }

    /// Overrides the adapter kind.
    #[must_use]
    pub fn with_adapter(mut self, adapter: AdapterKind) -> Self {
        self.adapter = adapter;
        self
    }

    /// Overrides the tool kind.
    #[must_use]
    pub fn with_tool_kind(mut self, kind: ToolKind) -> Self {
        self.tool_kind = kind;
        self
    }

    /// Builds a [`ToolPolicyContext`] from the current tool and run fields.
    ///
    /// Returns `None` when no tool name has been set. Missing tool arguments
    /// are reported as `Value::Null`.
    pub fn tool_policy_context(&self) -> Option<ToolPolicyContext> {
        // Struct-literal fields are evaluated in written order, so resolve the
        // only fallible field first: the `None` path then clones nothing.
        let tool_name = self.tool_name.clone()?;
        Some(ToolPolicyContext {
            thread_id: self.run_identity.thread_id.clone(),
            run_id: self.run_identity.run_id.clone(),
            dispatch_id: self.run_identity.trace.dispatch_id.clone(),
            run_mode: self.run_mode,
            adapter: self.adapter,
            tool_name,
            tool_kind: self.tool_kind,
            arguments: self.tool_args.clone().unwrap_or(Value::Null),
        })
    }

    /// Stores a tool result.
    #[must_use]
    pub fn with_tool_result(mut self, result: ToolResult) -> Self {
        self.tool_result = Some(result);
        self
    }

    /// Stores an LLM response.
    #[must_use]
    pub fn with_llm_response(mut self, response: LLMResponse) -> Self {
        self.llm_response = Some(response);
        self
    }

    /// Stores the resume input.
    #[must_use]
    pub fn with_resume_input(mut self, resume: ToolCallResume) -> Self {
        self.resume_input = Some(resume);
        self
    }

    /// Sets both suspension fields at once.
    ///
    /// Each argument replaces the stored value verbatim, so passing `None`
    /// clears the corresponding field.
    #[must_use]
    pub fn with_suspension(
        mut self,
        suspension_id: Option<String>,
        suspension_reason: Option<String>,
    ) -> Self {
        self.suspension_id = suspension_id;
        self.suspension_reason = suspension_reason;
        self
    }

    /// Stores a cancellation token.
    #[must_use]
    pub fn with_cancellation_token(mut self, token: CancellationToken) -> Self {
        self.cancellation_token = Some(token);
        self
    }

    /// Borrows the profile access handle, if one has been attached.
    pub fn profile(&self) -> Option<&crate::profile::ProfileAccess> {
        self.profile_access.as_deref()
    }

    /// Attaches a profile access handle.
    #[must_use]
    pub fn with_profile_access(mut self, access: Arc<crate::profile::ProfileAccess>) -> Self {
        self.profile_access = Some(access);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::StateMap;
    use awaken_contract::contract::content::ContentBlock;
    use awaken_contract::contract::identity::RunOrigin;
    use awaken_contract::contract::inference::{StopReason, StreamResult, TokenUsage};
    use awaken_contract::contract::tool::ToolResult;

    /// Builds a `Snapshot` from `0` and a default `StateMap`.
    fn empty_snapshot() -> Snapshot {
        Snapshot::new(0, Arc::new(StateMap::default()))
    }

    /// Shorthand for a fresh context over an empty snapshot.
    fn context_for(phase: Phase) -> PhaseContext {
        PhaseContext::new(phase, empty_snapshot())
    }

    // A freshly constructed context carries only the phase; everything else is unset.
    #[test]
    fn phase_context_new_has_defaults() {
        let fresh = context_for(Phase::BeforeInference);

        assert_eq!(fresh.phase, Phase::BeforeInference);
        assert!(fresh.messages.is_empty());
        assert!(fresh.tool_name.is_none());
        assert!(fresh.llm_response.is_none());
        assert_eq!(fresh.agent_spec.id, "");
    }

    // The supplied agent spec replaces the default one.
    #[test]
    fn phase_context_with_agent_spec() {
        let reviewer = AgentSpec::new("reviewer")
            .with_model_id("opus")
            .with_hook_filter("perm");

        let built = context_for(Phase::RunStart).with_agent_spec(Arc::new(reviewer));

        let stored = &built.agent_spec;
        assert_eq!(stored.id, "reviewer");
        assert_eq!(stored.model_id, "opus");
        assert!(stored.active_hook_filter.contains("perm"));
    }

    // The supplied run identity is stored on the context.
    #[test]
    fn phase_context_with_run_identity() {
        let who = RunIdentity::new(
            "t1".into(),
            None,
            "r1".into(),
            None,
            "a".into(),
            RunOrigin::User,
        );

        let built = context_for(Phase::RunStart).with_run_identity(who);

        assert_eq!(built.run_identity.thread_id, "t1");
    }

    // Messages passed as a Vec end up in the shared slice.
    #[test]
    fn phase_context_with_messages() {
        let history = vec![
            Arc::new(Message::user("hello")),
            Arc::new(Message::assistant("hi")),
        ];

        let built = context_for(Phase::BeforeInference).with_messages(history);

        assert_eq!(built.messages.len(), 2);
    }

    // Tool info populates name, call id, inferred kind and the policy context.
    #[test]
    fn phase_context_with_tool_info() {
        let query = serde_json::json!({"q": "rust"});

        let built =
            context_for(Phase::BeforeToolExecute).with_tool_info("search", "c1", Some(query));

        assert_eq!(built.tool_name.as_deref(), Some("search"));
        assert_eq!(built.tool_call_id.as_deref(), Some("c1"));
        assert_eq!(built.tool_kind, ToolKind::Search);

        let policy = built.tool_policy_context().expect("policy context");
        assert_eq!(policy.tool_name, "search");
        assert_eq!(policy.tool_kind, ToolKind::Search);
        assert_eq!(policy.arguments["q"], "rust");
    }

    // The policy context mirrors ids, dispatch id, run mode and adapter of the identity.
    #[test]
    fn phase_context_tool_policy_context_carries_trace_and_mode() {
        let base = RunIdentity::new(
            "t1".into(),
            None,
            "r1".into(),
            None,
            "agent".into(),
            RunOrigin::User,
        );
        let scheduled = base
            .with_dispatch_id("dispatch-1")
            .with_run_mode(RunMode::Scheduled)
            .with_adapter(AdapterKind::Acp);
        let command = serde_json::json!({"cmd": "echo ok"});

        let built = context_for(Phase::ToolGate)
            .with_run_identity(scheduled)
            .with_tool_info("bash", "call-1", Some(command));

        let policy = built.tool_policy_context().expect("policy context");
        assert_eq!(policy.thread_id, "t1");
        assert_eq!(policy.run_id, "r1");
        assert_eq!(policy.dispatch_id.as_deref(), Some("dispatch-1"));
        assert_eq!(policy.run_mode, RunMode::Scheduled);
        assert_eq!(policy.adapter, AdapterKind::Acp);
        assert_eq!(policy.tool_kind, ToolKind::Execute);
    }

    // A stored tool result is retrievable and still reports success.
    #[test]
    fn phase_context_with_tool_result() {
        let outcome = ToolResult::success("search", serde_json::json!({"hits": 5}));

        let built = context_for(Phase::AfterToolExecute).with_tool_result(outcome);

        let stored = built.tool_result.as_ref().unwrap();
        assert!(stored.is_success());
    }

    // A stored LLM response is retrievable with an `Ok` outcome.
    #[test]
    fn phase_context_with_llm_response() {
        let stream = StreamResult {
            content: vec![ContentBlock::text("hello")],
            tool_calls: vec![],
            usage: Some(TokenUsage::default()),
            stop_reason: Some(StopReason::EndTurn),
            has_incomplete_tool_calls: false,
        };

        let built =
            context_for(Phase::AfterInference).with_llm_response(LLMResponse::success(stream));

        let stored = built.llm_response.as_ref().unwrap();
        assert!(stored.outcome.is_ok());
    }

    // Builder calls compose without clobbering each other's fields.
    #[test]
    fn phase_context_builder_chains() {
        let built = context_for(Phase::AfterToolExecute)
            .with_run_identity(RunIdentity::for_thread("t1"))
            .with_messages(vec![Arc::new(Message::user("hi"))])
            .with_tool_info("calc", "c1", None)
            .with_tool_result(ToolResult::success("calc", serde_json::json!(42)));

        assert_eq!(built.run_identity.thread_id, "t1");
        assert_eq!(built.messages.len(), 1);
        assert_eq!(built.tool_name.as_deref(), Some("calc"));
        assert!(built.tool_result.is_some());
    }

    // No profile access is attached unless explicitly provided.
    #[test]
    fn phase_context_profile_none_by_default() {
        let fresh = context_for(Phase::RunStart);

        assert!(fresh.profile().is_none());
    }
}