awaken_runtime/loop_runner/
actions.rs1use async_trait::async_trait;
10use std::collections::HashSet;
11use std::collections::hash_map::DefaultHasher;
12use std::hash::{Hash, Hasher};
13
14use crate::hooks::{PhaseContext, TypedScheduledActionHandler};
15use crate::state::StateCommand;
16use awaken_contract::StateError;
17use awaken_contract::contract::context_message::ContextMessage;
18use awaken_contract::contract::inference::InferenceOverride;
19use awaken_contract::contract::message::{Message, Role};
20
21use crate::agent::state::{
22 AddContextMessage, ContextMessageAction, ContextMessageStore, ContextThrottleState,
23 ContextThrottleUpdate, ExcludeTool, IncludeOnlyTools, InferenceOverrideState,
24 InferenceOverrideStateAction, RunLifecycle, SetInferenceOverride, ToolFilterState,
25 ToolFilterStateAction,
26};
27
28pub(super) fn merge_override_payloads(
30 base: &mut Option<awaken_contract::contract::inference::InferenceOverride>,
31 payloads: Vec<awaken_contract::contract::inference::InferenceOverride>,
32) {
33 for ovr in payloads {
34 if let Some(existing) = base.as_mut() {
35 existing.merge(ovr);
36 } else {
37 *base = Some(ovr);
38 }
39 }
40}
41
42pub(super) fn apply_tool_filter_payloads(
48 tools: &mut Vec<awaken_contract::contract::tool::ToolDescriptor>,
49 exclusion_payloads: Vec<String>,
50 inclusion_payloads: Vec<Vec<String>>,
51) {
52 if !inclusion_payloads.is_empty() {
54 let allowed: HashSet<String> = inclusion_payloads.into_iter().flatten().collect();
55 tools.retain(|t| allowed.contains(&t.id));
56 }
57
58 if !exclusion_payloads.is_empty() {
60 let excluded: HashSet<String> = exclusion_payloads.into_iter().collect();
61 tools.retain(|t| !excluded.contains(&t.id));
62 }
63}
64
65pub(super) fn resolve_intercept_payloads(
68 payloads: Vec<awaken_contract::contract::tool_intercept::ToolInterceptPayload>,
69) -> Option<awaken_contract::contract::tool_intercept::ToolInterceptPayload> {
70 use awaken_contract::contract::tool_intercept::ToolInterceptPayload;
71
72 fn priority(p: &ToolInterceptPayload) -> u8 {
73 match p {
74 ToolInterceptPayload::Block { .. } => 3,
75 ToolInterceptPayload::Suspend(_) => 2,
76 ToolInterceptPayload::SetResult(_) => 1,
77 }
78 }
79
80 let mut winner: Option<ToolInterceptPayload> = None;
81 for payload in payloads {
82 match winner.as_ref() {
83 None => {
84 winner = Some(payload);
85 }
86 Some(existing) if priority(&payload) > priority(existing) => {
87 winner = Some(payload);
88 }
89 Some(existing) if priority(&payload) == priority(existing) => {
90 tracing::error!(
91 existing = ?existing,
92 incoming = ?payload,
93 "tool intercept conflict: two plugins scheduled same-priority intercepts"
94 );
95 }
97 _ => {
98 }
100 }
101 }
102 winner
103}
104
105pub(super) struct ExcludeToolHandler;
111
112#[async_trait]
113impl TypedScheduledActionHandler<ExcludeTool> for ExcludeToolHandler {
114 async fn handle_typed(
115 &self,
116 _ctx: &PhaseContext,
117 payload: String,
118 ) -> Result<StateCommand, StateError> {
119 let mut cmd = StateCommand::new();
120 cmd.update::<ToolFilterState>(ToolFilterStateAction::Exclude(payload));
121 Ok(cmd)
122 }
123}
124
125pub(super) struct IncludeOnlyToolsHandler;
127
128#[async_trait]
129impl TypedScheduledActionHandler<IncludeOnlyTools> for IncludeOnlyToolsHandler {
130 async fn handle_typed(
131 &self,
132 _ctx: &PhaseContext,
133 payload: Vec<String>,
134 ) -> Result<StateCommand, StateError> {
135 let mut cmd = StateCommand::new();
136 cmd.update::<ToolFilterState>(ToolFilterStateAction::IncludeOnly(payload));
137 Ok(cmd)
138 }
139}
140
141pub(super) struct SetInferenceOverrideHandler;
143
144#[async_trait]
145impl TypedScheduledActionHandler<SetInferenceOverride> for SetInferenceOverrideHandler {
146 async fn handle_typed(
147 &self,
148 _ctx: &PhaseContext,
149 payload: InferenceOverride,
150 ) -> Result<StateCommand, StateError> {
151 let mut cmd = StateCommand::new();
152 cmd.update::<InferenceOverrideState>(InferenceOverrideStateAction::Merge(payload));
153 Ok(cmd)
154 }
155}
156
157pub(super) struct ContextMessageHandler;
160
161#[async_trait]
162impl TypedScheduledActionHandler<AddContextMessage> for ContextMessageHandler {
163 async fn handle_typed(
164 &self,
165 ctx: &PhaseContext,
166 payload: ContextMessage,
167 ) -> Result<StateCommand, StateError> {
168 let mut cmd = StateCommand::new();
169
170 let current_step = ctx
173 .snapshot
174 .get::<RunLifecycle>()
175 .map(|s| s.step_count as usize + 1)
176 .unwrap_or(1);
177
178 let content_hash = {
179 let mut hasher = DefaultHasher::new();
180 if let Ok(json) = serde_json::to_string(&payload.content) {
181 json.hash(&mut hasher);
182 }
183 hasher.finish()
184 };
185
186 let should_inject = if payload.cooldown_turns == 0 {
187 true
188 } else {
189 let throttle_state = ctx
190 .snapshot
191 .get::<ContextThrottleState>()
192 .cloned()
193 .unwrap_or_default();
194 match throttle_state.entries.get(&payload.key) {
195 None => true,
196 Some(entry) => {
197 entry.content_hash != content_hash
198 || current_step.saturating_sub(entry.last_step)
199 >= payload.cooldown_turns as usize
200 }
201 }
202 };
203
204 if should_inject {
205 cmd.update::<ContextThrottleState>(ContextThrottleUpdate::Injected {
206 key: payload.key.clone(),
207 step: current_step,
208 content_hash,
209 });
210 cmd.update::<ContextMessageStore>(ContextMessageAction::Upsert(payload));
211 }
212
213 Ok(cmd)
214 }
215}
216
217pub struct LoopActionHandlersPlugin;
229
230impl crate::plugins::Plugin for LoopActionHandlersPlugin {
231 fn descriptor(&self) -> crate::plugins::PluginDescriptor {
232 crate::plugins::PluginDescriptor {
233 name: "__loop_action_handlers",
234 }
235 }
236
237 fn register(
238 &self,
239 r: &mut crate::plugins::PluginRegistrar,
240 ) -> Result<(), awaken_contract::StateError> {
241 use crate::state::StateKeyOptions;
242
243 r.register_key::<ToolFilterState>(StateKeyOptions::default())?;
245 r.register_key::<InferenceOverrideState>(StateKeyOptions::default())?;
246 r.register_scheduled_action::<AddContextMessage, _>(ContextMessageHandler)?;
248 r.register_scheduled_action::<ExcludeTool, _>(ExcludeToolHandler)?;
249 r.register_scheduled_action::<IncludeOnlyTools, _>(IncludeOnlyToolsHandler)?;
250 r.register_scheduled_action::<SetInferenceOverride, _>(SetInferenceOverrideHandler)?;
251 Ok(())
252 }
253}
254
255pub(super) fn take_context_messages(
266 store: &crate::state::StateStore,
267) -> Result<Vec<ContextMessage>, StateError> {
268 let store_value = store.read::<ContextMessageStore>().unwrap_or_default();
269
270 if store_value.messages.is_empty() {
271 return Ok(Vec::new());
272 }
273
274 let result: Vec<ContextMessage> = store_value.sorted_messages().into_iter().cloned().collect();
276
277 let mut patch = crate::state::MutationBatch::new();
279 patch.update::<ContextMessageStore>(ContextMessageAction::RemoveEphemeral);
280 patch.update::<ContextMessageStore>(ContextMessageAction::ConsumeAfterEmit);
281 store.commit(patch)?;
282
283 Ok(result)
284}
285
286pub(super) fn apply_context_messages(
292 messages: &mut Vec<Message>,
293 context_messages: Vec<ContextMessage>,
294 has_system_prompt: bool,
295) {
296 use awaken_contract::contract::context_message::ContextMessageTarget;
297
298 let mut system = Vec::new();
299 let mut session = Vec::new();
300 let mut conversation = Vec::new();
301 let mut suffix = Vec::new();
302
303 for entry in context_messages {
304 let msg = Message {
305 id: Some(awaken_contract::contract::message::gen_message_id()),
306 role: entry.role,
307 content: entry.content,
308 tool_calls: None,
309 tool_call_id: None,
310 visibility: entry.visibility,
311 metadata: None,
312 };
313 match entry.target {
314 ContextMessageTarget::System => system.push(msg),
315 ContextMessageTarget::Session => session.push(msg),
316 ContextMessageTarget::Conversation => conversation.push(msg),
317 ContextMessageTarget::SuffixSystem => suffix.push(msg),
318 }
319 }
320
321 let system_insert_pos = usize::from(has_system_prompt);
323 for (offset, msg) in system.into_iter().enumerate() {
324 messages.insert(system_insert_pos + offset, msg);
325 }
326
327 let session_insert_pos = messages
329 .iter()
330 .take_while(|m| m.role == Role::System)
331 .count();
332 for (offset, msg) in session.into_iter().enumerate() {
333 messages.insert(session_insert_pos + offset, msg);
334 }
335
336 let conversation_insert_pos = messages
338 .iter()
339 .take_while(|m| m.role == Role::System)
340 .count();
341 for (offset, msg) in conversation.into_iter().enumerate() {
342 messages.insert(conversation_insert_pos + offset, msg);
343 }
344
345 messages.extend(suffix);
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use awaken_contract::contract::context_message::ContextMessage;

    #[test]
    fn apply_context_messages_empty_input() {
        let mut messages = vec![Message::system("sys prompt"), Message::user("hello")];
        apply_context_messages(&mut messages, vec![], true);
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].text(), "sys prompt");
        assert_eq!(messages[1].text(), "hello");
    }

    #[test]
    fn apply_context_messages_system_target() {
        let mut messages = vec![
            Message::system("base system"),
            Message::user("hello"),
            Message::assistant("hi"),
        ];
        let ctx_msgs = vec![ContextMessage::system("test.key", "injected system")];
        apply_context_messages(&mut messages, ctx_msgs, true);

        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0].text(), "base system");
        assert_eq!(messages[1].text(), "injected system");
        assert_eq!(messages[1].role, Role::System);
        assert_eq!(messages[2].text(), "hello");
    }

    #[test]
    fn apply_context_messages_system_target_no_system_prompt() {
        let mut messages = vec![Message::user("hello"), Message::assistant("hi")];
        let ctx_msgs = vec![ContextMessage::system("test.key", "injected")];
        apply_context_messages(&mut messages, ctx_msgs, false);

        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].text(), "injected");
        assert_eq!(messages[1].text(), "hello");
    }

    #[test]
    fn apply_context_messages_suffix_target() {
        let mut messages = vec![
            Message::system("sys"),
            Message::user("hello"),
            Message::assistant("hi"),
        ];
        let ctx_msgs = vec![ContextMessage::suffix_system(
            "suffix.key",
            "suffix content",
        )];
        apply_context_messages(&mut messages, ctx_msgs, true);

        assert_eq!(messages.len(), 4);
        assert_eq!(messages[3].text(), "suffix content");
    }

    #[test]
    fn apply_context_messages_session_target() {
        let mut messages = vec![Message::system("sys"), Message::user("hello")];
        let ctx_msgs = vec![ContextMessage::session(
            "session.key",
            Role::System,
            "session context",
        )];
        apply_context_messages(&mut messages, ctx_msgs, true);

        assert_eq!(messages.len(), 3);
        // Session context sits after the leading system prompt, before history.
        assert_eq!(messages[0].text(), "sys");
        assert_eq!(messages[1].text(), "session context");
        assert_eq!(messages[1].role, Role::System);
        assert_eq!(messages[2].text(), "hello");
    }

    #[test]
    fn apply_context_messages_conversation_target() {
        let mut messages = vec![
            Message::system("sys"),
            Message::user("hello"),
            Message::assistant("hi"),
        ];
        let ctx_msgs = vec![ContextMessage::conversation(
            "conv.key",
            Role::User,
            "conversation context",
        )];
        apply_context_messages(&mut messages, ctx_msgs, true);

        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0].role, Role::System);
        // Conversation context is placed ahead of the existing history.
        assert_eq!(messages[1].text(), "conversation context");
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(messages[2].text(), "hello");
        assert_eq!(messages[3].text(), "hi");
    }

    #[test]
    fn apply_context_messages_multiple_targets() {
        let mut messages = vec![
            Message::system("sys"),
            Message::user("hello"),
            Message::assistant("hi"),
        ];
        let ctx_msgs = vec![
            ContextMessage::system("sys.key", "system inject"),
            ContextMessage::suffix_system("suffix.key", "suffix inject"),
        ];
        apply_context_messages(&mut messages, ctx_msgs, true);

        assert_eq!(messages.len(), 5);
        assert_eq!(messages[1].text(), "system inject");
        assert_eq!(messages[2].text(), "hello");
        assert_eq!(messages[3].text(), "hi");
        assert_eq!(messages[4].text(), "suffix inject");
    }

    #[test]
    fn apply_context_messages_all_targets_ordering() {
        let mut messages = vec![Message::system("sys"), Message::user("hello")];
        // Supplied in reverse target order: placement must depend on the target,
        // not on input order.
        let ctx_msgs = vec![
            ContextMessage::suffix_system("d", "suffix inject"),
            ContextMessage::conversation("c", Role::User, "conversation inject"),
            ContextMessage::session("b", Role::System, "session inject"),
            ContextMessage::system("a", "system inject"),
        ];
        apply_context_messages(&mut messages, ctx_msgs, true);

        let texts: Vec<_> = messages.iter().map(|m| m.text()).collect();
        assert_eq!(
            texts,
            vec![
                "sys",
                "system inject",
                "session inject",
                "conversation inject",
                "hello",
                "suffix inject",
            ]
        );
    }

    #[test]
    fn apply_context_messages_ordering_preserved_within_target() {
        let mut messages = vec![Message::system("sys"), Message::user("hello")];
        let ctx_msgs = vec![
            ContextMessage::system("a", "first system"),
            ContextMessage::system("b", "second system"),
        ];
        apply_context_messages(&mut messages, ctx_msgs, true);

        assert_eq!(messages[1].text(), "first system");
        assert_eq!(messages[2].text(), "second system");
    }

    #[test]
    fn apply_context_messages_empty_messages_list() {
        let mut messages: Vec<Message> = vec![];
        let ctx_msgs = vec![ContextMessage::system("key", "inject")];
        apply_context_messages(&mut messages, ctx_msgs, false);

        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].text(), "inject");
    }

    #[test]
    fn apply_context_messages_suffix_with_empty_messages() {
        let mut messages: Vec<Message> = vec![];
        let ctx_msgs = vec![ContextMessage::suffix_system("key", "suffix")];
        apply_context_messages(&mut messages, ctx_msgs, false);

        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].text(), "suffix");
    }
}