Skip to main content

awaken_runtime/hooks/
context.rs

1use std::sync::Arc;
2
3use serde_json::Value;
4
5use crate::cancellation::CancellationToken;
6use crate::state::{Snapshot, StateKey};
7use awaken_contract::StateError;
8use awaken_contract::contract::identity::RunIdentity;
9use awaken_contract::contract::inference::LLMResponse;
10use awaken_contract::contract::message::Message;
11use awaken_contract::contract::suspension::ToolCallResume;
12use awaken_contract::contract::tool::ToolResult;
13use awaken_contract::contract::tool_intercept::{
14    AdapterKind, RunMode, ToolKind, ToolPolicyContext,
15};
16use awaken_contract::model::Phase;
17use awaken_contract::registry_spec::{AgentSpec, PluginConfigKey};
18
19/// Execution context passed to phase hooks and action handlers.
20///
21/// Three input sources per ADR-0009:
22/// - `agent_spec`: immutable agent configuration (model, active_hook_filter, sections)
23/// - `snapshot`: shared runtime state (StateKeys)
24/// - `run_identity`: per-run identity (thread_id, run_id, etc.)
25#[derive(Clone)]
26pub struct PhaseContext {
27    pub phase: Phase,
28    pub snapshot: Snapshot,
29
30    /// Active agent spec (resolved from registry at each phase boundary).
31    pub agent_spec: Arc<AgentSpec>,
32
33    /// Per-run identity (thread_id, run_id, etc.). Immutable for the run.
34    pub run_identity: RunIdentity,
35
36    /// Messages accumulated in the current run.
37    pub messages: Arc<[Arc<Message>]>,
38
39    // Tool-call context (set during BeforeToolExecute / AfterToolExecute)
40    pub tool_name: Option<String>,
41    pub tool_call_id: Option<String>,
42    pub tool_args: Option<Value>,
43    pub tool_result: Option<ToolResult>,
44    pub run_mode: RunMode,
45    pub adapter: AdapterKind,
46    pub tool_kind: ToolKind,
47
48    // LLM response (set during AfterInference)
49    pub llm_response: Option<LLMResponse>,
50
51    // Resume decision (set during BeforeToolExecute when resuming a suspended tool call)
52    pub resume_input: Option<ToolCallResume>,
53    pub suspension_id: Option<String>,
54    pub suspension_reason: Option<String>,
55
56    /// Optional cancellation token for cooperative cancellation at phase boundaries.
57    pub cancellation_token: Option<CancellationToken>,
58
59    /// Optional profile access for cross-run persistence.
60    pub profile_access: Option<Arc<crate::profile::ProfileAccess>>,
61}
62
63impl PhaseContext {
64    /// Create a minimal context (for testing or phases without extra data).
65    pub fn new(phase: Phase, snapshot: Snapshot) -> Self {
66        Self {
67            phase,
68            snapshot,
69            agent_spec: Arc::new(AgentSpec::default()),
70            run_identity: RunIdentity::default(),
71            messages: Arc::from([]),
72            tool_name: None,
73            tool_call_id: None,
74            tool_args: None,
75            tool_result: None,
76            run_mode: RunMode::Foreground,
77            adapter: AdapterKind::Internal,
78            tool_kind: ToolKind::Other,
79            llm_response: None,
80            resume_input: None,
81            suspension_id: None,
82            suspension_reason: None,
83            cancellation_token: None,
84            profile_access: None,
85        }
86    }
87
88    /// Read a state key from the snapshot.
89    pub fn state<K: StateKey>(&self) -> Option<&K::Value> {
90        self.snapshot.get::<K>()
91    }
92
93    /// Read a typed plugin config from the active agent spec.
94    /// Returns `Config::default()` if the section is missing.
95    pub fn config<K: PluginConfigKey>(&self) -> Result<K::Config, StateError> {
96        self.agent_spec.config::<K>()
97    }
98
99    // -- Builder methods --
100
101    #[must_use]
102    pub fn with_snapshot(mut self, snapshot: Snapshot) -> Self {
103        self.snapshot = snapshot;
104        self
105    }
106
107    #[must_use]
108    pub fn with_agent_spec(mut self, spec: Arc<AgentSpec>) -> Self {
109        self.agent_spec = spec;
110        self
111    }
112
113    #[must_use]
114    pub fn with_run_identity(mut self, identity: RunIdentity) -> Self {
115        self.run_mode = identity.run_mode();
116        self.adapter = identity.adapter();
117        self.run_identity = identity;
118        self
119    }
120
121    #[must_use]
122    pub fn with_messages(mut self, messages: Vec<Arc<Message>>) -> Self {
123        self.messages = Arc::from(messages);
124        self
125    }
126
127    #[must_use]
128    pub fn with_tool_info(
129        mut self,
130        name: impl Into<String>,
131        call_id: impl Into<String>,
132        args: Option<Value>,
133    ) -> Self {
134        let name = name.into();
135        self.tool_kind = ToolKind::infer_name(&name);
136        self.tool_name = Some(name);
137        self.tool_call_id = Some(call_id.into());
138        self.tool_args = args;
139        self
140    }
141
142    #[must_use]
143    pub fn with_run_mode(mut self, mode: RunMode) -> Self {
144        self.run_mode = mode;
145        self
146    }
147
148    #[must_use]
149    pub fn with_adapter(mut self, adapter: AdapterKind) -> Self {
150        self.adapter = adapter;
151        self
152    }
153
154    #[must_use]
155    pub fn with_tool_kind(mut self, kind: ToolKind) -> Self {
156        self.tool_kind = kind;
157        self
158    }
159
160    /// Build typed policy context for ToolGate/ToolPolicy hooks.
161    pub fn tool_policy_context(&self) -> Option<ToolPolicyContext> {
162        Some(ToolPolicyContext {
163            thread_id: self.run_identity.thread_id.clone(),
164            run_id: self.run_identity.run_id.clone(),
165            dispatch_id: self.run_identity.trace.dispatch_id.clone(),
166            run_mode: self.run_mode,
167            adapter: self.adapter,
168            tool_name: self.tool_name.clone()?,
169            tool_kind: self.tool_kind,
170            arguments: self.tool_args.clone().unwrap_or(Value::Null),
171        })
172    }
173
174    #[must_use]
175    pub fn with_tool_result(mut self, result: ToolResult) -> Self {
176        self.tool_result = Some(result);
177        self
178    }
179
180    #[must_use]
181    pub fn with_llm_response(mut self, response: LLMResponse) -> Self {
182        self.llm_response = Some(response);
183        self
184    }
185
186    #[must_use]
187    pub fn with_resume_input(mut self, resume: ToolCallResume) -> Self {
188        self.resume_input = Some(resume);
189        self
190    }
191
192    #[must_use]
193    pub fn with_suspension(
194        mut self,
195        suspension_id: Option<String>,
196        suspension_reason: Option<String>,
197    ) -> Self {
198        self.suspension_id = suspension_id;
199        self.suspension_reason = suspension_reason;
200        self
201    }
202
203    #[must_use]
204    pub fn with_cancellation_token(mut self, token: CancellationToken) -> Self {
205        self.cancellation_token = Some(token);
206        self
207    }
208
209    /// Get profile access, if configured.
210    pub fn profile(&self) -> Option<&crate::profile::ProfileAccess> {
211        self.profile_access.as_deref()
212    }
213
214    #[must_use]
215    pub fn with_profile_access(mut self, access: Arc<crate::profile::ProfileAccess>) -> Self {
216        self.profile_access = Some(access);
217        self
218    }
219}
220
#[cfg(test)]
mod tests {
    // `super::*` already brings `Arc`, `ToolResult`, `RunMode`, `AdapterKind`,
    // `ToolKind`, `Message`, etc. into scope via the parent module's imports.
    use super::*;
    use crate::state::StateMap;
    use awaken_contract::contract::content::ContentBlock;
    use awaken_contract::contract::identity::RunOrigin;
    use awaken_contract::contract::inference::{StopReason, StreamResult, TokenUsage};

    fn empty_snapshot() -> Snapshot {
        Snapshot::new(0, Arc::new(StateMap::default()))
    }

    #[test]
    fn phase_context_new_has_defaults() {
        let ctx = PhaseContext::new(Phase::BeforeInference, empty_snapshot());
        assert_eq!(ctx.phase, Phase::BeforeInference);
        assert!(ctx.messages.is_empty());
        assert!(ctx.tool_name.is_none());
        assert!(ctx.llm_response.is_none());
        assert_eq!(ctx.agent_spec.id, "");
    }

    #[test]
    fn phase_context_new_has_routing_defaults() {
        let ctx = PhaseContext::new(Phase::BeforeInference, empty_snapshot());
        assert_eq!(ctx.run_mode, RunMode::Foreground);
        assert_eq!(ctx.adapter, AdapterKind::Internal);
        assert_eq!(ctx.tool_kind, ToolKind::Other);
        assert!(ctx.cancellation_token.is_none());
        assert!(ctx.suspension_id.is_none());
        assert!(ctx.suspension_reason.is_none());
    }

    #[test]
    fn phase_context_with_agent_spec() {
        let spec = Arc::new(
            AgentSpec::new("reviewer")
                .with_model_id("opus")
                .with_hook_filter("perm"),
        );
        let ctx = PhaseContext::new(Phase::RunStart, empty_snapshot()).with_agent_spec(spec);
        assert_eq!(ctx.agent_spec.id, "reviewer");
        assert_eq!(ctx.agent_spec.model_id, "opus");
        assert!(ctx.agent_spec.active_hook_filter.contains("perm"));
    }

    #[test]
    fn phase_context_with_run_identity() {
        let ctx = PhaseContext::new(Phase::RunStart, empty_snapshot()).with_run_identity(
            RunIdentity::new(
                "t1".into(),
                None,
                "r1".into(),
                None,
                "a".into(),
                RunOrigin::User,
            ),
        );
        assert_eq!(ctx.run_identity.thread_id, "t1");
    }

    #[test]
    fn phase_context_with_messages() {
        let msgs = vec![
            Arc::new(Message::user("hello")),
            Arc::new(Message::assistant("hi")),
        ];
        let ctx = PhaseContext::new(Phase::BeforeInference, empty_snapshot()).with_messages(msgs);
        assert_eq!(ctx.messages.len(), 2);
    }

    #[test]
    fn phase_context_with_tool_info() {
        let ctx = PhaseContext::new(Phase::BeforeToolExecute, empty_snapshot()).with_tool_info(
            "search",
            "c1",
            Some(serde_json::json!({"q": "rust"})),
        );
        assert_eq!(ctx.tool_name.as_deref(), Some("search"));
        assert_eq!(ctx.tool_call_id.as_deref(), Some("c1"));
        assert_eq!(ctx.tool_kind, ToolKind::Search);
        let policy = ctx.tool_policy_context().expect("policy context");
        assert_eq!(policy.tool_name, "search");
        assert_eq!(policy.tool_kind, ToolKind::Search);
        assert_eq!(policy.arguments["q"], "rust");
    }

    #[test]
    fn phase_context_tool_policy_context_requires_tool_name() {
        let ctx = PhaseContext::new(Phase::ToolGate, empty_snapshot());
        assert!(ctx.tool_policy_context().is_none());
    }

    #[test]
    fn phase_context_tool_policy_context_defaults_arguments_to_null() {
        let ctx = PhaseContext::new(Phase::ToolGate, empty_snapshot())
            .with_tool_info("calc", "c1", None);
        let policy = ctx.tool_policy_context().expect("policy context");
        assert_eq!(policy.tool_name, "calc");
        assert!(policy.arguments.is_null());
    }

    #[test]
    fn phase_context_tool_policy_context_carries_trace_and_mode() {
        let identity = RunIdentity::new(
            "t1".into(),
            None,
            "r1".into(),
            None,
            "agent".into(),
            RunOrigin::User,
        )
        .with_dispatch_id("dispatch-1")
        .with_run_mode(RunMode::Scheduled)
        .with_adapter(AdapterKind::Acp);
        let ctx = PhaseContext::new(Phase::ToolGate, empty_snapshot())
            .with_run_identity(identity)
            .with_tool_info(
                "bash",
                "call-1",
                Some(serde_json::json!({"cmd": "echo ok"})),
            );

        let policy = ctx.tool_policy_context().expect("policy context");
        assert_eq!(policy.thread_id, "t1");
        assert_eq!(policy.run_id, "r1");
        assert_eq!(policy.dispatch_id.as_deref(), Some("dispatch-1"));
        assert_eq!(policy.run_mode, RunMode::Scheduled);
        assert_eq!(policy.adapter, AdapterKind::Acp);
        assert_eq!(policy.tool_kind, ToolKind::Execute);
    }

    #[test]
    fn phase_context_explicit_overrides_win_when_applied_last() {
        let ctx = PhaseContext::new(Phase::ToolGate, empty_snapshot())
            .with_tool_info("search", "c1", None)
            .with_tool_kind(ToolKind::Execute)
            .with_run_mode(RunMode::Scheduled)
            .with_adapter(AdapterKind::Acp);
        assert_eq!(ctx.tool_kind, ToolKind::Execute);
        assert_eq!(ctx.run_mode, RunMode::Scheduled);
        assert_eq!(ctx.adapter, AdapterKind::Acp);

        let policy = ctx.tool_policy_context().expect("policy context");
        assert_eq!(policy.tool_kind, ToolKind::Execute);
        assert_eq!(policy.run_mode, RunMode::Scheduled);
        assert_eq!(policy.adapter, AdapterKind::Acp);
    }

    #[test]
    fn phase_context_with_suspension_sets_and_clears() {
        let ctx = PhaseContext::new(Phase::BeforeToolExecute, empty_snapshot())
            .with_suspension(Some("s1".into()), Some("needs approval".into()));
        assert_eq!(ctx.suspension_id.as_deref(), Some("s1"));
        assert_eq!(ctx.suspension_reason.as_deref(), Some("needs approval"));

        let ctx = ctx.with_suspension(None, None);
        assert!(ctx.suspension_id.is_none());
        assert!(ctx.suspension_reason.is_none());
    }

    #[test]
    fn phase_context_with_tool_result() {
        let ctx = PhaseContext::new(Phase::AfterToolExecute, empty_snapshot()).with_tool_result(
            ToolResult::success("search", serde_json::json!({"hits": 5})),
        );
        assert!(ctx.tool_result.as_ref().unwrap().is_success());
    }

    #[test]
    fn phase_context_with_llm_response() {
        let response = LLMResponse::success(StreamResult {
            content: vec![ContentBlock::text("hello")],
            tool_calls: vec![],
            usage: Some(TokenUsage::default()),
            stop_reason: Some(StopReason::EndTurn),
            has_incomplete_tool_calls: false,
        });
        let ctx =
            PhaseContext::new(Phase::AfterInference, empty_snapshot()).with_llm_response(response);
        assert!(ctx.llm_response.as_ref().unwrap().outcome.is_ok());
    }

    #[test]
    fn phase_context_builder_chains() {
        let ctx = PhaseContext::new(Phase::AfterToolExecute, empty_snapshot())
            .with_run_identity(RunIdentity::for_thread("t1"))
            .with_messages(vec![Arc::new(Message::user("hi"))])
            .with_tool_info("calc", "c1", None)
            .with_tool_result(ToolResult::success("calc", serde_json::json!(42)));

        assert_eq!(ctx.run_identity.thread_id, "t1");
        assert_eq!(ctx.messages.len(), 1);
        assert_eq!(ctx.tool_name.as_deref(), Some("calc"));
        assert!(ctx.tool_result.is_some());
    }

    #[test]
    fn phase_context_profile_none_by_default() {
        let ctx = PhaseContext::new(Phase::RunStart, empty_snapshot());
        assert!(ctx.profile().is_none());
    }
}