Skip to main content

awaken_runtime/loop_runner/
actions.rs

//! Action handlers and helpers for loop-consumed actions.
//!
//! All scheduled actions flow through the handler queue via
//! `TypedScheduledActionHandler`. Accumulator actions (`ExcludeTool`,
//! `IncludeOnlyTools`, `SetInferenceOverride`) write to
//! transient state keys. The orchestrator reads from state after the EXECUTE
//! loop and clears the accumulators for the next step/call.
//! `AddContextMessage` is handled separately: accepted messages are throttled
//! per key and upserted into `ContextMessageStore`.
8
9use async_trait::async_trait;
10use std::collections::HashSet;
11use std::collections::hash_map::DefaultHasher;
12use std::hash::{Hash, Hasher};
13
14use crate::hooks::{PhaseContext, TypedScheduledActionHandler};
15use crate::state::StateCommand;
16use awaken_contract::StateError;
17use awaken_contract::contract::context_message::ContextMessage;
18use awaken_contract::contract::inference::InferenceOverride;
19use awaken_contract::contract::message::{Message, Role};
20
21use crate::agent::state::{
22    AddContextMessage, ContextMessageAction, ContextMessageStore, ContextThrottleState,
23    ContextThrottleUpdate, ExcludeTool, IncludeOnlyTools, InferenceOverrideState,
24    InferenceOverrideStateAction, RunLifecycle, SetInferenceOverride, ToolFilterState,
25    ToolFilterStateAction,
26};
27
28/// Merge multiple inference override payloads with last-wins-per-field semantics.
29pub(super) fn merge_override_payloads(
30    base: &mut Option<awaken_contract::contract::inference::InferenceOverride>,
31    payloads: Vec<awaken_contract::contract::inference::InferenceOverride>,
32) {
33    for ovr in payloads {
34        if let Some(existing) = base.as_mut() {
35            existing.merge(ovr);
36        } else {
37            *base = Some(ovr);
38        }
39    }
40}
41
42/// Apply tool filter payloads (exclusions and inclusions) to a tool list.
43///
44/// - If any `IncludeOnlyTools` payloads exist, only tools in the combined
45///   allow-list are kept.
46/// - Then any `ExcludeTool` tool IDs are removed.
47pub(super) fn apply_tool_filter_payloads(
48    tools: &mut Vec<awaken_contract::contract::tool::ToolDescriptor>,
49    exclusion_payloads: Vec<String>,
50    inclusion_payloads: Vec<Vec<String>>,
51) {
52    // Build combined allow-list from inclusion payloads
53    if !inclusion_payloads.is_empty() {
54        let allowed: HashSet<String> = inclusion_payloads.into_iter().flatten().collect();
55        tools.retain(|t| allowed.contains(&t.id));
56    }
57
58    // Apply exclusions
59    if !exclusion_payloads.is_empty() {
60        let excluded: HashSet<String> = exclusion_payloads.into_iter().collect();
61        tools.retain(|t| !excluded.contains(&t.id));
62    }
63}
64
65/// Resolve the winning intercept decision from multiple payloads using
66/// priority: Block > Suspend > SetResult.
67pub(super) fn resolve_intercept_payloads(
68    payloads: Vec<awaken_contract::contract::tool_intercept::ToolInterceptPayload>,
69) -> Option<awaken_contract::contract::tool_intercept::ToolInterceptPayload> {
70    use awaken_contract::contract::tool_intercept::ToolInterceptPayload;
71
72    fn priority(p: &ToolInterceptPayload) -> u8 {
73        match p {
74            ToolInterceptPayload::Block { .. } => 3,
75            ToolInterceptPayload::Suspend(_) => 2,
76            ToolInterceptPayload::SetResult(_) => 1,
77        }
78    }
79
80    let mut winner: Option<ToolInterceptPayload> = None;
81    for payload in payloads {
82        match winner.as_ref() {
83            None => {
84                winner = Some(payload);
85            }
86            Some(existing) if priority(&payload) > priority(existing) => {
87                winner = Some(payload);
88            }
89            Some(existing) if priority(&payload) == priority(existing) => {
90                tracing::error!(
91                    existing = ?existing,
92                    incoming = ?payload,
93                    "tool intercept conflict: two plugins scheduled same-priority intercepts"
94                );
95                // Keep first
96            }
97            _ => {
98                // Lower priority — ignore
99            }
100        }
101    }
102    winner
103}
104
105// ---------------------------------------------------------------------------
106// Action handlers
107// ---------------------------------------------------------------------------
108
109/// Handler for `ExcludeTool` — writes to [`ToolFilterState`].
110pub(super) struct ExcludeToolHandler;
111
112#[async_trait]
113impl TypedScheduledActionHandler<ExcludeTool> for ExcludeToolHandler {
114    async fn handle_typed(
115        &self,
116        _ctx: &PhaseContext,
117        payload: String,
118    ) -> Result<StateCommand, StateError> {
119        let mut cmd = StateCommand::new();
120        cmd.update::<ToolFilterState>(ToolFilterStateAction::Exclude(payload));
121        Ok(cmd)
122    }
123}
124
125/// Handler for `IncludeOnlyTools` — writes to [`ToolFilterState`].
126pub(super) struct IncludeOnlyToolsHandler;
127
128#[async_trait]
129impl TypedScheduledActionHandler<IncludeOnlyTools> for IncludeOnlyToolsHandler {
130    async fn handle_typed(
131        &self,
132        _ctx: &PhaseContext,
133        payload: Vec<String>,
134    ) -> Result<StateCommand, StateError> {
135        let mut cmd = StateCommand::new();
136        cmd.update::<ToolFilterState>(ToolFilterStateAction::IncludeOnly(payload));
137        Ok(cmd)
138    }
139}
140
141/// Handler for `SetInferenceOverride` — writes to [`InferenceOverrideState`].
142pub(super) struct SetInferenceOverrideHandler;
143
144#[async_trait]
145impl TypedScheduledActionHandler<SetInferenceOverride> for SetInferenceOverrideHandler {
146    async fn handle_typed(
147        &self,
148        _ctx: &PhaseContext,
149        payload: InferenceOverride,
150    ) -> Result<StateCommand, StateError> {
151        let mut cmd = StateCommand::new();
152        cmd.update::<InferenceOverrideState>(InferenceOverrideStateAction::Merge(payload));
153        Ok(cmd)
154    }
155}
156
157/// Handler for `AddContextMessage` — applies throttle logic, upserts accepted
158/// messages into [`ContextMessageStore`], updates [`ContextThrottleState`].
159pub(super) struct ContextMessageHandler;
160
161#[async_trait]
162impl TypedScheduledActionHandler<AddContextMessage> for ContextMessageHandler {
163    async fn handle_typed(
164        &self,
165        ctx: &PhaseContext,
166        payload: ContextMessage,
167    ) -> Result<StateCommand, StateError> {
168        let mut cmd = StateCommand::new();
169
170        // Determine current step from RunLifecycle.step_count + 1
171        // (step_count records completed steps; current step is one ahead)
172        let current_step = ctx
173            .snapshot
174            .get::<RunLifecycle>()
175            .map(|s| s.step_count as usize + 1)
176            .unwrap_or(1);
177
178        let content_hash = {
179            let mut hasher = DefaultHasher::new();
180            if let Ok(json) = serde_json::to_string(&payload.content) {
181                json.hash(&mut hasher);
182            }
183            hasher.finish()
184        };
185
186        let should_inject = if payload.cooldown_turns == 0 {
187            true
188        } else {
189            let throttle_state = ctx
190                .snapshot
191                .get::<ContextThrottleState>()
192                .cloned()
193                .unwrap_or_default();
194            match throttle_state.entries.get(&payload.key) {
195                None => true,
196                Some(entry) => {
197                    entry.content_hash != content_hash
198                        || current_step.saturating_sub(entry.last_step)
199                            >= payload.cooldown_turns as usize
200                }
201            }
202        };
203
204        if should_inject {
205            cmd.update::<ContextThrottleState>(ContextThrottleUpdate::Injected {
206                key: payload.key.clone(),
207                step: current_step,
208                content_hash,
209            });
210            cmd.update::<ContextMessageStore>(ContextMessageAction::Upsert(payload));
211        }
212
213        Ok(cmd)
214    }
215}
216
217// ---------------------------------------------------------------------------
218// Plugin for registering action handlers
219// ---------------------------------------------------------------------------
220
221/// Internal plugin that registers all loop action handlers and their
222/// accumulator state keys.
223/// Plugin that registers action handlers and their accumulator state keys.
224///
225/// Installed automatically by `inject_default_plugins` for the main runtime.
226/// External crates that build sub-runtimes (e.g. generative-ui) should also
227/// install this plugin alongside [`super::LoopStatePlugin`].
228pub struct LoopActionHandlersPlugin;
229
230impl crate::plugins::Plugin for LoopActionHandlersPlugin {
231    fn descriptor(&self) -> crate::plugins::PluginDescriptor {
232        crate::plugins::PluginDescriptor {
233            name: "__loop_action_handlers",
234        }
235    }
236
237    fn register(
238        &self,
239        r: &mut crate::plugins::PluginRegistrar,
240    ) -> Result<(), awaken_contract::StateError> {
241        use crate::state::StateKeyOptions;
242
243        // State keys for action accumulators
244        r.register_key::<ToolFilterState>(StateKeyOptions::default())?;
245        r.register_key::<InferenceOverrideState>(StateKeyOptions::default())?;
246        // Handlers
247        r.register_scheduled_action::<AddContextMessage, _>(ContextMessageHandler)?;
248        r.register_scheduled_action::<ExcludeTool, _>(ExcludeToolHandler)?;
249        r.register_scheduled_action::<IncludeOnlyTools, _>(IncludeOnlyToolsHandler)?;
250        r.register_scheduled_action::<SetInferenceOverride, _>(SetInferenceOverrideHandler)?;
251        Ok(())
252    }
253}
254
255// ---------------------------------------------------------------------------
256// Orchestrator helpers
257// ---------------------------------------------------------------------------
258
259/// Read context messages from the store, return sorted list, then apply lifecycle cleanup.
260///
261/// Lifecycle rules applied after injection:
262/// - Non-persistent (ephemeral) messages are removed.
263/// - Messages with `consume_after_emit` are removed.
264/// - Persistent messages remain for subsequent steps.
265pub(super) fn take_context_messages(
266    store: &crate::state::StateStore,
267) -> Result<Vec<ContextMessage>, StateError> {
268    let store_value = store.read::<ContextMessageStore>().unwrap_or_default();
269
270    if store_value.messages.is_empty() {
271        return Ok(Vec::new());
272    }
273
274    // Collect all messages sorted by (target, priority, key)
275    let result: Vec<ContextMessage> = store_value.sorted_messages().into_iter().cloned().collect();
276
277    // Apply lifecycle: remove ephemeral + consume-after-emit
278    let mut patch = crate::state::MutationBatch::new();
279    patch.update::<ContextMessageStore>(ContextMessageAction::RemoveEphemeral);
280    patch.update::<ContextMessageStore>(ContextMessageAction::ConsumeAfterEmit);
281    store.commit(patch)?;
282
283    Ok(result)
284}
285
286// ---------------------------------------------------------------------------
// Message placement
288// ---------------------------------------------------------------------------
289
290/// Insert context messages into the message list at their declared target positions.
291pub(super) fn apply_context_messages(
292    messages: &mut Vec<Message>,
293    context_messages: Vec<ContextMessage>,
294    has_system_prompt: bool,
295) {
296    use awaken_contract::contract::context_message::ContextMessageTarget;
297
298    let mut system = Vec::new();
299    let mut session = Vec::new();
300    let mut conversation = Vec::new();
301    let mut suffix = Vec::new();
302
303    for entry in context_messages {
304        let msg = Message {
305            id: Some(awaken_contract::contract::message::gen_message_id()),
306            role: entry.role,
307            content: entry.content,
308            tool_calls: None,
309            tool_call_id: None,
310            visibility: entry.visibility,
311            metadata: None,
312        };
313        match entry.target {
314            ContextMessageTarget::System => system.push(msg),
315            ContextMessageTarget::Session => session.push(msg),
316            ContextMessageTarget::Conversation => conversation.push(msg),
317            ContextMessageTarget::SuffixSystem => suffix.push(msg),
318        }
319    }
320
321    // System: insert after base system prompt
322    let system_insert_pos = usize::from(has_system_prompt);
323    for (offset, msg) in system.into_iter().enumerate() {
324        messages.insert(system_insert_pos + offset, msg);
325    }
326
327    // Session: insert after all system-role messages
328    let session_insert_pos = messages
329        .iter()
330        .take_while(|m| m.role == Role::System)
331        .count();
332    for (offset, msg) in session.into_iter().enumerate() {
333        messages.insert(session_insert_pos + offset, msg);
334    }
335
336    // Conversation: insert after system messages, before history
337    let conversation_insert_pos = messages
338        .iter()
339        .take_while(|m| m.role == Role::System)
340        .count();
341    for (offset, msg) in conversation.into_iter().enumerate() {
342        messages.insert(conversation_insert_pos + offset, msg);
343    }
344
345    // Suffix: append at end
346    messages.extend(suffix);
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use awaken_contract::contract::context_message::ContextMessage;

    // ---- apply_context_messages ----

    #[test]
    fn apply_context_messages_empty_input() {
        let mut messages = vec![Message::system("sys prompt"), Message::user("hello")];
        apply_context_messages(&mut messages, vec![], true);
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].text(), "sys prompt");
        assert_eq!(messages[1].text(), "hello");
    }

    #[test]
    fn apply_context_messages_system_target() {
        let mut messages = vec![
            Message::system("base system"),
            Message::user("hello"),
            Message::assistant("hi"),
        ];
        let ctx_msgs = vec![ContextMessage::system("test.key", "injected system")];
        apply_context_messages(&mut messages, ctx_msgs, true);

        // System context is inserted right after the base system prompt (index 1).
        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0].text(), "base system");
        assert_eq!(messages[1].text(), "injected system");
        assert_eq!(messages[1].role, Role::System);
        assert_eq!(messages[2].text(), "hello");
    }

    #[test]
    fn apply_context_messages_system_target_no_system_prompt() {
        let mut messages = vec![Message::user("hello"), Message::assistant("hi")];
        let ctx_msgs = vec![ContextMessage::system("test.key", "injected")];
        apply_context_messages(&mut messages, ctx_msgs, false);

        // Without a system prompt, system context goes to position 0.
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].text(), "injected");
        assert_eq!(messages[1].text(), "hello");
    }

    #[test]
    fn apply_context_messages_suffix_target() {
        let mut messages = vec![
            Message::system("sys"),
            Message::user("hello"),
            Message::assistant("hi"),
        ];
        let ctx_msgs = vec![ContextMessage::suffix_system(
            "suffix.key",
            "suffix content",
        )];
        apply_context_messages(&mut messages, ctx_msgs, true);

        assert_eq!(messages.len(), 4);
        assert_eq!(messages[3].text(), "suffix content");
    }

    #[test]
    fn apply_context_messages_session_target() {
        let mut messages = vec![Message::system("sys"), Message::user("hello")];
        let ctx_msgs = vec![ContextMessage::session(
            "session.key",
            Role::System,
            "session context",
        )];
        apply_context_messages(&mut messages, ctx_msgs, true);

        // Session messages go after the leading run of system-role messages:
        // right after the base system prompt and before the history.
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].text(), "sys");
        assert_eq!(messages[1].text(), "session context");
        assert_eq!(messages[1].role, Role::System);
        assert_eq!(messages[2].text(), "hello");
    }

    #[test]
    fn apply_context_messages_conversation_target() {
        let mut messages = vec![
            Message::system("sys"),
            Message::user("hello"),
            Message::assistant("hi"),
        ];
        let ctx_msgs = vec![ContextMessage::conversation(
            "conv.key",
            Role::User,
            "conversation context",
        )];
        apply_context_messages(&mut messages, ctx_msgs, true);

        // Conversation messages go after the system messages, before history.
        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[1].text(), "conversation context");
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(messages[2].text(), "hello");
        assert_eq!(messages[3].text(), "hi");
    }

    #[test]
    fn apply_context_messages_multiple_targets() {
        let mut messages = vec![
            Message::system("sys"),
            Message::user("hello"),
            Message::assistant("hi"),
        ];
        let ctx_msgs = vec![
            ContextMessage::system("sys.key", "system inject"),
            ContextMessage::suffix_system("suffix.key", "suffix inject"),
        ];
        apply_context_messages(&mut messages, ctx_msgs, true);

        assert_eq!(messages.len(), 5);
        // System inject sits directly after the base system prompt.
        assert_eq!(messages[1].text(), "system inject");
        // History is untouched between the two injections.
        assert_eq!(messages[2].text(), "hello");
        assert_eq!(messages[3].text(), "hi");
        // Suffix inject is appended at the end.
        assert_eq!(messages[4].text(), "suffix inject");
    }

    #[test]
    fn apply_context_messages_ordering_preserved_within_target() {
        let mut messages = vec![Message::system("sys"), Message::user("hello")];
        let ctx_msgs = vec![
            ContextMessage::system("a", "first system"),
            ContextMessage::system("b", "second system"),
        ];
        apply_context_messages(&mut messages, ctx_msgs, true);

        assert_eq!(messages[1].text(), "first system");
        assert_eq!(messages[2].text(), "second system");
    }

    #[test]
    fn apply_context_messages_empty_messages_list() {
        let mut messages: Vec<Message> = vec![];
        let ctx_msgs = vec![ContextMessage::system("key", "inject")];
        apply_context_messages(&mut messages, ctx_msgs, false);

        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].text(), "inject");
    }

    #[test]
    fn apply_context_messages_suffix_with_empty_messages() {
        let mut messages: Vec<Message> = vec![];
        let ctx_msgs = vec![ContextMessage::suffix_system("key", "suffix")];
        apply_context_messages(&mut messages, ctx_msgs, false);

        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].text(), "suffix");
    }
}