// Source: awaken_runtime/agent/state/loop_actions.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use crate::state::{MergeStrategy, StateKey};
6use awaken_contract::contract::context_message::ContextMessage;
7use awaken_contract::contract::inference::InferenceOverride;
8
// ---------------------------------------------------------------------------
// Action specs
// ---------------------------------------------------------------------------
12
13/// Action spec for injecting a context message into the prompt.
14///
15/// Scheduled by `BeforeInference` hooks via `cmd.schedule_action::<AddContextMessage>(...)`.
16/// Handled during EXECUTE by `ContextMessageHandler` which applies throttle logic
17/// and writes accepted messages to [`ContextMessageStore`].
18pub struct AddContextMessage;
19
20impl awaken_contract::model::ScheduledActionSpec for AddContextMessage {
21    const KEY: &'static str = "runtime.add_context_message";
22    const PHASE: awaken_contract::model::Phase = awaken_contract::model::Phase::BeforeInference;
23    type Payload = ContextMessage;
24}
25
26/// Action spec for per-inference parameter overrides.
27///
28/// Scheduled by `BeforeInference` hooks via `cmd.schedule_action::<SetInferenceOverride>(...)`.
29/// Handled during EXECUTE by `SetInferenceOverrideHandler` which writes to [`InferenceOverrideState`].
30pub struct SetInferenceOverride;
31
32impl awaken_contract::model::ScheduledActionSpec for SetInferenceOverride {
33    const KEY: &'static str = "runtime.set_inference_override";
34    const PHASE: awaken_contract::model::Phase = awaken_contract::model::Phase::BeforeInference;
35    type Payload = InferenceOverride;
36}
37
38/// Action spec for excluding a specific tool from the current inference step.
39///
40/// Scheduled by `BeforeInference` hooks via `cmd.schedule_action::<ExcludeTool>(...)`.
41/// Handled during EXECUTE by `ExcludeToolHandler` which writes to [`ToolFilterState`].
42pub struct ExcludeTool;
43
44impl awaken_contract::model::ScheduledActionSpec for ExcludeTool {
45    const KEY: &'static str = "runtime.exclude_tool";
46    const PHASE: awaken_contract::model::Phase = awaken_contract::model::Phase::BeforeInference;
47    type Payload = String;
48}
49
50/// Action spec for restricting tools to an explicit allow-list for the current inference step.
51///
52/// Scheduled by `BeforeInference` hooks via `cmd.schedule_action::<IncludeOnlyTools>(...)`.
53/// Handled during EXECUTE by `IncludeOnlyToolsHandler` which writes to [`ToolFilterState`].
54pub struct IncludeOnlyTools;
55
56impl awaken_contract::model::ScheduledActionSpec for IncludeOnlyTools {
57    const KEY: &'static str = "runtime.include_only_tools";
58    const PHASE: awaken_contract::model::Phase = awaken_contract::model::Phase::BeforeInference;
59    type Payload = Vec<String>;
60}
61
// ---------------------------------------------------------------------------
// Persistent state keys (not accumulators)
// ---------------------------------------------------------------------------
65
/// State key holding context messages that persist across the agent loop.
///
/// Entries are indexed by each message's `key` field, so writing a message whose
/// key already exists replaces the old entry (upsert). The `AddContextMessage`
/// handler upserts every message that passes its throttle logic. The
/// orchestrator reads the stored messages and injects them. It then applies the
/// lifecycle rules that drop ephemeral and consume-after-emit entries.
pub struct ContextMessageStore;
73
74/// Durable map of context messages keyed by message key.
75#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
76pub struct ContextMessageStoreValue {
77    pub messages: HashMap<String, ContextMessage>,
78}
79
80impl ContextMessageStoreValue {
81    /// Return all messages sorted by (target, priority, key) for deterministic ordering.
82    pub fn sorted_messages(&self) -> Vec<&ContextMessage> {
83        let mut sorted: Vec<&ContextMessage> = self.messages.values().collect();
84        sorted.sort_by(|a, b| {
85            a.target
86                .cmp(&b.target)
87                .then(a.priority.cmp(&b.priority))
88                .then(a.key.cmp(&b.key))
89        });
90        sorted
91    }
92}
93
94/// Update actions for [`ContextMessageStore`].
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub enum ContextMessageAction {
97    /// Add or update a context message by key.
98    Upsert(ContextMessage),
99    /// Remove a context message by key.
100    Remove(String),
101    /// Remove all messages with keys starting with prefix.
102    RemoveByPrefix(String),
103    /// Remove all non-persistent messages (ephemeral lifecycle cleanup).
104    RemoveEphemeral,
105    /// Remove all messages flagged `consume_after_emit`.
106    ConsumeAfterEmit,
107    /// Clear all messages.
108    Clear,
109}
110
111impl StateKey for ContextMessageStore {
112    const KEY: &'static str = "__runtime.context_message_store";
113    const MERGE: MergeStrategy = MergeStrategy::Commutative;
114
115    type Value = ContextMessageStoreValue;
116    type Update = ContextMessageAction;
117
118    fn apply(value: &mut Self::Value, update: Self::Update) {
119        match update {
120            ContextMessageAction::Upsert(msg) => {
121                value.messages.insert(msg.key.clone(), msg);
122            }
123            ContextMessageAction::Remove(key) => {
124                value.messages.remove(&key);
125            }
126            ContextMessageAction::RemoveByPrefix(prefix) => {
127                value.messages.retain(|k, _| !k.starts_with(&prefix));
128            }
129            ContextMessageAction::RemoveEphemeral => {
130                value.messages.retain(|_, m| m.persistent);
131            }
132            ContextMessageAction::ConsumeAfterEmit => {
133                value.messages.retain(|_, m| !m.consume_after_emit);
134            }
135            ContextMessageAction::Clear => {
136                value.messages.clear();
137            }
138        }
139    }
140}
141
// ---------------------------------------------------------------------------
// Accumulator state keys (written by handlers, read/cleared by orchestrator)
// ---------------------------------------------------------------------------
145
/// State key for the tool filters gathered during the current inference step.
///
/// The `ExcludeTool` and `IncludeOnlyTools` handlers write to it. Once the
/// EXECUTE loop finishes, the orchestrator reads the filters and clears them.
pub struct ToolFilterState;
150
151#[derive(Debug, Clone, Default, Serialize, Deserialize)]
152pub struct ToolFilterStateValue {
153    pub excluded: Vec<String>,
154    pub include_only: Vec<Vec<String>>,
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub enum ToolFilterStateAction {
159    Exclude(String),
160    IncludeOnly(Vec<String>),
161    Clear,
162}
163
164impl StateKey for ToolFilterState {
165    const KEY: &'static str = "__runtime.tool_filter_state";
166    const MERGE: MergeStrategy = MergeStrategy::Commutative;
167    type Value = ToolFilterStateValue;
168    type Update = ToolFilterStateAction;
169
170    fn apply(value: &mut Self::Value, update: Self::Update) {
171        match update {
172            ToolFilterStateAction::Exclude(id) => value.excluded.push(id),
173            ToolFilterStateAction::IncludeOnly(ids) => value.include_only.push(ids),
174            ToolFilterStateAction::Clear => {
175                value.excluded.clear();
176                value.include_only.clear();
177            }
178        }
179    }
180}
181
/// State key for the inference override gathered during the current step.
///
/// The `SetInferenceOverride` handler writes to it; the orchestrator reads it
/// and then clears it.
pub struct InferenceOverrideState;
185
186#[derive(Debug, Clone, Default, Serialize, Deserialize)]
187pub struct InferenceOverrideStateValue {
188    pub overrides: Option<InferenceOverride>,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub enum InferenceOverrideStateAction {
193    Merge(InferenceOverride),
194    Clear,
195}
196
197impl StateKey for InferenceOverrideState {
198    const KEY: &'static str = "__runtime.inference_override_state";
199    const MERGE: MergeStrategy = MergeStrategy::Commutative;
200    type Value = InferenceOverrideStateValue;
201    type Update = InferenceOverrideStateAction;
202
203    fn apply(value: &mut Self::Value, update: Self::Update) {
204        match update {
205            InferenceOverrideStateAction::Merge(ovr) => {
206                if let Some(existing) = value.overrides.as_mut() {
207                    existing.merge(ovr);
208                } else {
209                    value.overrides = Some(ovr);
210                }
211            }
212            InferenceOverrideStateAction::Clear => {
213                value.overrides = None;
214            }
215        }
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use awaken_contract::contract::context_message::ContextMessage as ContractContextMessage;

    // -----------------------------------------------------------------------
    // ContextMessageStore tests
    // -----------------------------------------------------------------------

    #[test]
    fn context_message_store_upsert() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("k1", "msg1")),
        );
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("k1"));
    }

    #[test]
    fn context_message_store_upsert_replaces() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("k1", "msg1")),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("k1", "updated")),
        );
        assert_eq!(val.messages.len(), 1);
        assert_eq!(
            val.messages["k1"].content[0],
            awaken_contract::contract::content::ContentBlock::text("updated")
        );
    }

    #[test]
    fn context_message_store_upsert_multiple() {
        let mut val = ContextMessageStoreValue::default();
        for i in 0..5 {
            ContextMessageStore::apply(
                &mut val,
                ContextMessageAction::Upsert(ContractContextMessage::system(
                    format!("k{i}"),
                    format!("msg{i}"),
                )),
            );
        }
        assert_eq!(val.messages.len(), 5);
    }

    #[test]
    fn context_message_store_remove() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("k1", "msg1")),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("k2", "msg2")),
        );
        ContextMessageStore::apply(&mut val, ContextMessageAction::Remove("k1".into()));
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("k2"));
    }

    #[test]
    fn context_message_store_remove_missing_key_is_noop() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("k1", "msg1")),
        );
        ContextMessageStore::apply(&mut val, ContextMessageAction::Remove("absent".into()));
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("k1"));
    }

    #[test]
    fn context_message_store_remove_by_prefix() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("mcp:tool1", "t1")),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("mcp:tool2", "t2")),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("skill:a", "s1")),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::RemoveByPrefix("mcp:".into()),
        );
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("skill:a"));
    }

    #[test]
    fn context_message_store_remove_ephemeral() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("eph", "ephemeral")),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system_persistent(
                "pers",
                "persistent",
            )),
        );
        ContextMessageStore::apply(&mut val, ContextMessageAction::RemoveEphemeral);
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("pers"));
    }

    #[test]
    fn context_message_store_consume_after_emit() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::emit_once(
                "once",
                "once",
                awaken_contract::contract::context_message::ContextMessageTarget::System,
            )),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system_persistent("keep", "keep")),
        );
        ContextMessageStore::apply(&mut val, ContextMessageAction::ConsumeAfterEmit);
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("keep"));
    }

    #[test]
    fn context_message_store_clear() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("k1", "msg1")),
        );
        ContextMessageStore::apply(&mut val, ContextMessageAction::Clear);
        assert!(val.messages.is_empty());
    }

    #[test]
    fn context_message_store_sorted_messages() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(
                ContractContextMessage::suffix_system("z_suffix", "last").with_priority(0),
            ),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(
                ContractContextMessage::system("a_sys", "first").with_priority(0),
            ),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(
                ContractContextMessage::system("b_sys", "second").with_priority(10),
            ),
        );
        let sorted = val.sorted_messages();
        assert_eq!(sorted[0].key, "a_sys");
        assert_eq!(sorted[1].key, "b_sys");
        assert_eq!(sorted[2].key, "z_suffix");
    }

    #[test]
    fn context_message_store_sorted_messages_breaks_ties_by_key() {
        // Same target and priority: ordering must fall back to the key.
        let mut val = ContextMessageStoreValue::default();
        for key in ["c", "a", "b"] {
            ContextMessageStore::apply(
                &mut val,
                ContextMessageAction::Upsert(
                    ContractContextMessage::system(key, key).with_priority(5),
                ),
            );
        }
        let sorted = val.sorted_messages();
        let keys: Vec<&str> = sorted.iter().map(|m| m.key.as_str()).collect();
        assert_eq!(keys, ["a", "b", "c"]);
    }

    // -----------------------------------------------------------------------
    // ToolFilterState tests
    // -----------------------------------------------------------------------

    #[test]
    fn tool_filter_state_accumulates_in_request_order() {
        let mut val = ToolFilterStateValue::default();
        ToolFilterState::apply(&mut val, ToolFilterStateAction::Exclude("b".into()));
        ToolFilterState::apply(&mut val, ToolFilterStateAction::Exclude("a".into()));
        ToolFilterState::apply(
            &mut val,
            ToolFilterStateAction::IncludeOnly(vec!["x".into(), "y".into()]),
        );
        ToolFilterState::apply(&mut val, ToolFilterStateAction::IncludeOnly(vec!["z".into()]));
        assert_eq!(val.excluded, vec!["b".to_string(), "a".to_string()]);
        assert_eq!(
            val.include_only,
            vec![
                vec!["x".to_string(), "y".to_string()],
                vec!["z".to_string()],
            ]
        );
    }

    #[test]
    fn tool_filter_state_clear_empties_both_lists() {
        let mut val = ToolFilterStateValue::default();
        ToolFilterState::apply(&mut val, ToolFilterStateAction::Exclude("a".into()));
        ToolFilterState::apply(&mut val, ToolFilterStateAction::IncludeOnly(vec!["x".into()]));
        ToolFilterState::apply(&mut val, ToolFilterStateAction::Clear);
        assert!(val.excluded.is_empty());
        assert!(val.include_only.is_empty());
    }
}