1use std::sync::Arc;
2
3use serde_json::Value;
4
5use crate::cancellation::CancellationToken;
6use crate::state::{Snapshot, StateKey};
7use awaken_contract::StateError;
8use awaken_contract::contract::identity::RunIdentity;
9use awaken_contract::contract::inference::LLMResponse;
10use awaken_contract::contract::message::Message;
11use awaken_contract::contract::suspension::ToolCallResume;
12use awaken_contract::contract::tool::ToolResult;
13use awaken_contract::contract::tool_intercept::{
14 AdapterKind, RunMode, ToolKind, ToolPolicyContext,
15};
16use awaken_contract::model::Phase;
17use awaken_contract::registry_spec::{AgentSpec, PluginConfigKey};
18
19#[derive(Clone)]
26pub struct PhaseContext {
27 pub phase: Phase,
28 pub snapshot: Snapshot,
29
30 pub agent_spec: Arc<AgentSpec>,
32
33 pub run_identity: RunIdentity,
35
36 pub messages: Arc<[Arc<Message>]>,
38
39 pub tool_name: Option<String>,
41 pub tool_call_id: Option<String>,
42 pub tool_args: Option<Value>,
43 pub tool_result: Option<ToolResult>,
44 pub run_mode: RunMode,
45 pub adapter: AdapterKind,
46 pub tool_kind: ToolKind,
47
48 pub llm_response: Option<LLMResponse>,
50
51 pub resume_input: Option<ToolCallResume>,
53 pub suspension_id: Option<String>,
54 pub suspension_reason: Option<String>,
55
56 pub cancellation_token: Option<CancellationToken>,
58
59 pub profile_access: Option<Arc<crate::profile::ProfileAccess>>,
61}
62
63impl PhaseContext {
64 pub fn new(phase: Phase, snapshot: Snapshot) -> Self {
66 Self {
67 phase,
68 snapshot,
69 agent_spec: Arc::new(AgentSpec::default()),
70 run_identity: RunIdentity::default(),
71 messages: Arc::from([]),
72 tool_name: None,
73 tool_call_id: None,
74 tool_args: None,
75 tool_result: None,
76 run_mode: RunMode::Foreground,
77 adapter: AdapterKind::Internal,
78 tool_kind: ToolKind::Other,
79 llm_response: None,
80 resume_input: None,
81 suspension_id: None,
82 suspension_reason: None,
83 cancellation_token: None,
84 profile_access: None,
85 }
86 }
87
88 pub fn state<K: StateKey>(&self) -> Option<&K::Value> {
90 self.snapshot.get::<K>()
91 }
92
93 pub fn config<K: PluginConfigKey>(&self) -> Result<K::Config, StateError> {
96 self.agent_spec.config::<K>()
97 }
98
99 #[must_use]
102 pub fn with_snapshot(mut self, snapshot: Snapshot) -> Self {
103 self.snapshot = snapshot;
104 self
105 }
106
107 #[must_use]
108 pub fn with_agent_spec(mut self, spec: Arc<AgentSpec>) -> Self {
109 self.agent_spec = spec;
110 self
111 }
112
113 #[must_use]
114 pub fn with_run_identity(mut self, identity: RunIdentity) -> Self {
115 self.run_mode = identity.run_mode();
116 self.adapter = identity.adapter();
117 self.run_identity = identity;
118 self
119 }
120
121 #[must_use]
122 pub fn with_messages(mut self, messages: Vec<Arc<Message>>) -> Self {
123 self.messages = Arc::from(messages);
124 self
125 }
126
127 #[must_use]
128 pub fn with_tool_info(
129 mut self,
130 name: impl Into<String>,
131 call_id: impl Into<String>,
132 args: Option<Value>,
133 ) -> Self {
134 let name = name.into();
135 self.tool_kind = ToolKind::infer_name(&name);
136 self.tool_name = Some(name);
137 self.tool_call_id = Some(call_id.into());
138 self.tool_args = args;
139 self
140 }
141
142 #[must_use]
143 pub fn with_run_mode(mut self, mode: RunMode) -> Self {
144 self.run_mode = mode;
145 self
146 }
147
148 #[must_use]
149 pub fn with_adapter(mut self, adapter: AdapterKind) -> Self {
150 self.adapter = adapter;
151 self
152 }
153
154 #[must_use]
155 pub fn with_tool_kind(mut self, kind: ToolKind) -> Self {
156 self.tool_kind = kind;
157 self
158 }
159
160 pub fn tool_policy_context(&self) -> Option<ToolPolicyContext> {
162 Some(ToolPolicyContext {
163 thread_id: self.run_identity.thread_id.clone(),
164 run_id: self.run_identity.run_id.clone(),
165 dispatch_id: self.run_identity.trace.dispatch_id.clone(),
166 run_mode: self.run_mode,
167 adapter: self.adapter,
168 tool_name: self.tool_name.clone()?,
169 tool_kind: self.tool_kind,
170 arguments: self.tool_args.clone().unwrap_or(Value::Null),
171 })
172 }
173
174 #[must_use]
175 pub fn with_tool_result(mut self, result: ToolResult) -> Self {
176 self.tool_result = Some(result);
177 self
178 }
179
180 #[must_use]
181 pub fn with_llm_response(mut self, response: LLMResponse) -> Self {
182 self.llm_response = Some(response);
183 self
184 }
185
186 #[must_use]
187 pub fn with_resume_input(mut self, resume: ToolCallResume) -> Self {
188 self.resume_input = Some(resume);
189 self
190 }
191
192 #[must_use]
193 pub fn with_suspension(
194 mut self,
195 suspension_id: Option<String>,
196 suspension_reason: Option<String>,
197 ) -> Self {
198 self.suspension_id = suspension_id;
199 self.suspension_reason = suspension_reason;
200 self
201 }
202
203 #[must_use]
204 pub fn with_cancellation_token(mut self, token: CancellationToken) -> Self {
205 self.cancellation_token = Some(token);
206 self
207 }
208
209 pub fn profile(&self) -> Option<&crate::profile::ProfileAccess> {
211 self.profile_access.as_deref()
212 }
213
214 #[must_use]
215 pub fn with_profile_access(mut self, access: Arc<crate::profile::ProfileAccess>) -> Self {
216 self.profile_access = Some(access);
217 self
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::StateMap;
    use awaken_contract::contract::content::ContentBlock;
    use awaken_contract::contract::identity::RunOrigin;
    use awaken_contract::contract::inference::{StopReason, StreamResult, TokenUsage};
    use awaken_contract::contract::tool::ToolResult;

    /// Snapshot built from `0` and an empty default `StateMap`.
    fn empty_snapshot() -> Snapshot {
        Snapshot::new(0, Arc::new(StateMap::default()))
    }

    /// Fresh context for `phase` over an empty snapshot.
    fn context(phase: Phase) -> PhaseContext {
        PhaseContext::new(phase, empty_snapshot())
    }

    /// `new` leaves every optional field unset and installs the default spec.
    #[test]
    fn phase_context_new_has_defaults() {
        let fresh = context(Phase::BeforeInference);
        assert_eq!(fresh.phase, Phase::BeforeInference);
        assert!(fresh.messages.is_empty());
        assert!(fresh.tool_name.is_none());
        assert!(fresh.llm_response.is_none());
        assert_eq!(fresh.agent_spec.id, "");
    }

    /// `with_agent_spec` swaps in the supplied spec.
    #[test]
    fn phase_context_with_agent_spec() {
        let reviewer = AgentSpec::new("reviewer")
            .with_model_id("opus")
            .with_hook_filter("perm");
        let built = context(Phase::RunStart).with_agent_spec(Arc::new(reviewer));
        assert_eq!(built.agent_spec.id, "reviewer");
        assert_eq!(built.agent_spec.model_id, "opus");
        assert!(built.agent_spec.active_hook_filter.contains("perm"));
    }

    /// `with_run_identity` stores the identity it was given.
    #[test]
    fn phase_context_with_run_identity() {
        let identity = RunIdentity::new(
            "t1".into(),
            None,
            "r1".into(),
            None,
            "a".into(),
            RunOrigin::User,
        );
        let built = context(Phase::RunStart).with_run_identity(identity);
        assert_eq!(built.run_identity.thread_id, "t1");
    }

    /// `with_messages` keeps every message it was handed.
    #[test]
    fn phase_context_with_messages() {
        let history = vec![
            Arc::new(Message::user("hello")),
            Arc::new(Message::assistant("hi")),
        ];
        let built = context(Phase::BeforeInference).with_messages(history);
        assert_eq!(built.messages.len(), 2);
    }

    /// `with_tool_info` records name, call id and args, and infers the kind.
    #[test]
    fn phase_context_with_tool_info() {
        let args = serde_json::json!({"q": "rust"});
        let built = context(Phase::BeforeToolExecute).with_tool_info("search", "c1", Some(args));
        assert_eq!(built.tool_name.as_deref(), Some("search"));
        assert_eq!(built.tool_call_id.as_deref(), Some("c1"));
        assert_eq!(built.tool_kind, ToolKind::Search);

        let policy = built.tool_policy_context().expect("policy context");
        assert_eq!(policy.tool_name, "search");
        assert_eq!(policy.tool_kind, ToolKind::Search);
        assert_eq!(policy.arguments["q"], "rust");
    }

    /// The policy context mirrors identity, dispatch id, run mode and adapter.
    #[test]
    fn phase_context_tool_policy_context_carries_trace_and_mode() {
        let scheduled_identity = RunIdentity::new(
            "t1".into(),
            None,
            "r1".into(),
            None,
            "agent".into(),
            RunOrigin::User,
        )
        .with_dispatch_id("dispatch-1")
        .with_run_mode(RunMode::Scheduled)
        .with_adapter(AdapterKind::Acp);
        let args = serde_json::json!({"cmd": "echo ok"});
        let built = context(Phase::ToolGate)
            .with_run_identity(scheduled_identity)
            .with_tool_info("bash", "call-1", Some(args));

        let policy = built.tool_policy_context().expect("policy context");
        assert_eq!(policy.thread_id, "t1");
        assert_eq!(policy.run_id, "r1");
        assert_eq!(policy.dispatch_id.as_deref(), Some("dispatch-1"));
        assert_eq!(policy.run_mode, RunMode::Scheduled);
        assert_eq!(policy.adapter, AdapterKind::Acp);
        assert_eq!(policy.tool_kind, ToolKind::Execute);
    }

    /// `with_tool_result` stores the result it was given.
    #[test]
    fn phase_context_with_tool_result() {
        let outcome = ToolResult::success("search", serde_json::json!({"hits": 5}));
        let built = context(Phase::AfterToolExecute).with_tool_result(outcome);
        assert!(built.tool_result.as_ref().unwrap().is_success());
    }

    /// `with_llm_response` stores the response it was given.
    #[test]
    fn phase_context_with_llm_response() {
        let stream = StreamResult {
            content: vec![ContentBlock::text("hello")],
            tool_calls: vec![],
            usage: Some(TokenUsage::default()),
            stop_reason: Some(StopReason::EndTurn),
            has_incomplete_tool_calls: false,
        };
        let built = context(Phase::AfterInference).with_llm_response(LLMResponse::success(stream));
        assert!(built.llm_response.as_ref().unwrap().outcome.is_ok());
    }

    /// Builders compose: each one leaves the others' fields intact.
    #[test]
    fn phase_context_builder_chains() {
        let identity = RunIdentity::for_thread("t1");
        let history = vec![Arc::new(Message::user("hi"))];
        let outcome = ToolResult::success("calc", serde_json::json!(42));
        let built = context(Phase::AfterToolExecute)
            .with_run_identity(identity)
            .with_messages(history)
            .with_tool_info("calc", "c1", None)
            .with_tool_result(outcome);

        assert_eq!(built.run_identity.thread_id, "t1");
        assert_eq!(built.messages.len(), 1);
        assert_eq!(built.tool_name.as_deref(), Some("calc"));
        assert!(built.tool_result.is_some());
    }

    /// No profile access is attached unless `with_profile_access` is called.
    #[test]
    fn phase_context_profile_none_by_default() {
        let fresh = context(Phase::RunStart);
        assert!(fresh.profile().is_none());
    }
}