awaken_runtime/agent/state/
loop_actions.rs1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use crate::state::{MergeStrategy, StateKey};
6use awaken_contract::contract::context_message::ContextMessage;
7use awaken_contract::contract::inference::InferenceOverride;
8
9pub struct AddContextMessage;
19
20impl awaken_contract::model::ScheduledActionSpec for AddContextMessage {
21 const KEY: &'static str = "runtime.add_context_message";
22 const PHASE: awaken_contract::model::Phase = awaken_contract::model::Phase::BeforeInference;
23 type Payload = ContextMessage;
24}
25
26pub struct SetInferenceOverride;
31
32impl awaken_contract::model::ScheduledActionSpec for SetInferenceOverride {
33 const KEY: &'static str = "runtime.set_inference_override";
34 const PHASE: awaken_contract::model::Phase = awaken_contract::model::Phase::BeforeInference;
35 type Payload = InferenceOverride;
36}
37
38pub struct ExcludeTool;
43
44impl awaken_contract::model::ScheduledActionSpec for ExcludeTool {
45 const KEY: &'static str = "runtime.exclude_tool";
46 const PHASE: awaken_contract::model::Phase = awaken_contract::model::Phase::BeforeInference;
47 type Payload = String;
48}
49
50pub struct IncludeOnlyTools;
55
56impl awaken_contract::model::ScheduledActionSpec for IncludeOnlyTools {
57 const KEY: &'static str = "runtime.include_only_tools";
58 const PHASE: awaken_contract::model::Phase = awaken_contract::model::Phase::BeforeInference;
59 type Payload = Vec<String>;
60}
61
/// State-key marker for the context message store.
///
/// Its `StateKey` implementation in this file uses the key
/// `__runtime.context_message_store`, holds a `ContextMessageStoreValue`, and
/// is updated through `ContextMessageAction`.
pub struct ContextMessageStore;
73
74#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
76pub struct ContextMessageStoreValue {
77 pub messages: HashMap<String, ContextMessage>,
78}
79
80impl ContextMessageStoreValue {
81 pub fn sorted_messages(&self) -> Vec<&ContextMessage> {
83 let mut sorted: Vec<&ContextMessage> = self.messages.values().collect();
84 sorted.sort_by(|a, b| {
85 a.target
86 .cmp(&b.target)
87 .then(a.priority.cmp(&b.priority))
88 .then(a.key.cmp(&b.key))
89 });
90 sorted
91 }
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
96pub enum ContextMessageAction {
97 Upsert(ContextMessage),
99 Remove(String),
101 RemoveByPrefix(String),
103 RemoveEphemeral,
105 ConsumeAfterEmit,
107 Clear,
109}
110
111impl StateKey for ContextMessageStore {
112 const KEY: &'static str = "__runtime.context_message_store";
113 const MERGE: MergeStrategy = MergeStrategy::Commutative;
114
115 type Value = ContextMessageStoreValue;
116 type Update = ContextMessageAction;
117
118 fn apply(value: &mut Self::Value, update: Self::Update) {
119 match update {
120 ContextMessageAction::Upsert(msg) => {
121 value.messages.insert(msg.key.clone(), msg);
122 }
123 ContextMessageAction::Remove(key) => {
124 value.messages.remove(&key);
125 }
126 ContextMessageAction::RemoveByPrefix(prefix) => {
127 value.messages.retain(|k, _| !k.starts_with(&prefix));
128 }
129 ContextMessageAction::RemoveEphemeral => {
130 value.messages.retain(|_, m| m.persistent);
131 }
132 ContextMessageAction::ConsumeAfterEmit => {
133 value.messages.retain(|_, m| !m.consume_after_emit);
134 }
135 ContextMessageAction::Clear => {
136 value.messages.clear();
137 }
138 }
139 }
140}
141
/// State-key marker for the tool filter state.
///
/// Its `StateKey` implementation in this file uses the key
/// `__runtime.tool_filter_state`, holds a `ToolFilterStateValue`, and is
/// updated through `ToolFilterStateAction`.
pub struct ToolFilterState;
150
151#[derive(Debug, Clone, Default, Serialize, Deserialize)]
152pub struct ToolFilterStateValue {
153 pub excluded: Vec<String>,
154 pub include_only: Vec<Vec<String>>,
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub enum ToolFilterStateAction {
159 Exclude(String),
160 IncludeOnly(Vec<String>),
161 Clear,
162}
163
164impl StateKey for ToolFilterState {
165 const KEY: &'static str = "__runtime.tool_filter_state";
166 const MERGE: MergeStrategy = MergeStrategy::Commutative;
167 type Value = ToolFilterStateValue;
168 type Update = ToolFilterStateAction;
169
170 fn apply(value: &mut Self::Value, update: Self::Update) {
171 match update {
172 ToolFilterStateAction::Exclude(id) => value.excluded.push(id),
173 ToolFilterStateAction::IncludeOnly(ids) => value.include_only.push(ids),
174 ToolFilterStateAction::Clear => {
175 value.excluded.clear();
176 value.include_only.clear();
177 }
178 }
179 }
180}
181
/// State-key marker for the inference override state.
///
/// Its `StateKey` implementation in this file uses the key
/// `__runtime.inference_override_state`, holds an
/// `InferenceOverrideStateValue`, and is updated through
/// `InferenceOverrideStateAction`.
pub struct InferenceOverrideState;
185
186#[derive(Debug, Clone, Default, Serialize, Deserialize)]
187pub struct InferenceOverrideStateValue {
188 pub overrides: Option<InferenceOverride>,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub enum InferenceOverrideStateAction {
193 Merge(InferenceOverride),
194 Clear,
195}
196
197impl StateKey for InferenceOverrideState {
198 const KEY: &'static str = "__runtime.inference_override_state";
199 const MERGE: MergeStrategy = MergeStrategy::Commutative;
200 type Value = InferenceOverrideStateValue;
201 type Update = InferenceOverrideStateAction;
202
203 fn apply(value: &mut Self::Value, update: Self::Update) {
204 match update {
205 InferenceOverrideStateAction::Merge(ovr) => {
206 if let Some(existing) = value.overrides.as_mut() {
207 existing.merge(ovr);
208 } else {
209 value.overrides = Some(ovr);
210 }
211 }
212 InferenceOverrideStateAction::Clear => {
213 value.overrides = None;
214 }
215 }
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use awaken_contract::contract::context_message::ContextMessage as ContractContextMessage;

    /// Applies an `Upsert` of `msg` to `val` through the store's reducer.
    fn upsert(val: &mut ContextMessageStoreValue, msg: ContractContextMessage) {
        ContextMessageStore::apply(val, ContextMessageAction::Upsert(msg));
    }

    // ---- ContextMessageStore ----

    #[test]
    fn context_message_store_upsert() {
        let mut val = ContextMessageStoreValue::default();
        upsert(&mut val, ContractContextMessage::system("k1", "msg1"));
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("k1"));
    }

    #[test]
    fn context_message_store_upsert_replaces() {
        let mut val = ContextMessageStoreValue::default();
        upsert(&mut val, ContractContextMessage::system("k1", "msg1"));
        upsert(&mut val, ContractContextMessage::system("k1", "updated"));
        assert_eq!(val.messages.len(), 1);
        assert_eq!(
            val.messages["k1"].content[0],
            awaken_contract::contract::content::ContentBlock::text("updated")
        );
    }

    #[test]
    fn context_message_store_upsert_multiple() {
        let mut val = ContextMessageStoreValue::default();
        for i in 0..5 {
            upsert(
                &mut val,
                ContractContextMessage::system(format!("k{i}"), format!("msg{i}")),
            );
        }
        assert_eq!(val.messages.len(), 5);
    }

    #[test]
    fn context_message_store_remove() {
        let mut val = ContextMessageStoreValue::default();
        upsert(&mut val, ContractContextMessage::system("k1", "msg1"));
        upsert(&mut val, ContractContextMessage::system("k2", "msg2"));
        ContextMessageStore::apply(&mut val, ContextMessageAction::Remove("k1".into()));
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("k2"));
    }

    #[test]
    fn context_message_store_remove_by_prefix() {
        let mut val = ContextMessageStoreValue::default();
        upsert(&mut val, ContractContextMessage::system("mcp:tool1", "t1"));
        upsert(&mut val, ContractContextMessage::system("mcp:tool2", "t2"));
        upsert(&mut val, ContractContextMessage::system("skill:a", "s1"));
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::RemoveByPrefix("mcp:".into()),
        );
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("skill:a"));
    }

    #[test]
    fn context_message_store_remove_ephemeral() {
        let mut val = ContextMessageStoreValue::default();
        upsert(&mut val, ContractContextMessage::system("eph", "ephemeral"));
        upsert(
            &mut val,
            ContractContextMessage::system_persistent("pers", "persistent"),
        );
        ContextMessageStore::apply(&mut val, ContextMessageAction::RemoveEphemeral);
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("pers"));
    }

    #[test]
    fn context_message_store_consume_after_emit() {
        let mut val = ContextMessageStoreValue::default();
        upsert(
            &mut val,
            ContractContextMessage::emit_once(
                "once",
                "once",
                awaken_contract::contract::context_message::ContextMessageTarget::System,
            ),
        );
        upsert(
            &mut val,
            ContractContextMessage::system_persistent("keep", "keep"),
        );
        ContextMessageStore::apply(&mut val, ContextMessageAction::ConsumeAfterEmit);
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("keep"));
    }

    #[test]
    fn context_message_store_clear() {
        let mut val = ContextMessageStoreValue::default();
        upsert(&mut val, ContractContextMessage::system("k1", "msg1"));
        ContextMessageStore::apply(&mut val, ContextMessageAction::Clear);
        assert!(val.messages.is_empty());
    }

    #[test]
    fn context_message_store_sorted_messages() {
        let mut val = ContextMessageStoreValue::default();
        upsert(
            &mut val,
            ContractContextMessage::suffix_system("z_suffix", "last").with_priority(0),
        );
        upsert(
            &mut val,
            ContractContextMessage::system("a_sys", "first").with_priority(0),
        );
        upsert(
            &mut val,
            ContractContextMessage::system("b_sys", "second").with_priority(10),
        );
        let sorted = val.sorted_messages();
        assert_eq!(sorted[0].key, "a_sys");
        assert_eq!(sorted[1].key, "b_sys");
        assert_eq!(sorted[2].key, "z_suffix");
    }

    // ---- ToolFilterState ----

    #[test]
    fn tool_filter_state_exclude_accumulates() {
        let mut val = ToolFilterStateValue::default();
        ToolFilterState::apply(&mut val, ToolFilterStateAction::Exclude("a".into()));
        ToolFilterState::apply(&mut val, ToolFilterStateAction::Exclude("b".into()));
        assert_eq!(val.excluded, ["a", "b"]);
        assert!(val.include_only.is_empty());
    }

    #[test]
    fn tool_filter_state_include_only_keeps_each_list() {
        let mut val = ToolFilterStateValue::default();
        ToolFilterState::apply(
            &mut val,
            ToolFilterStateAction::IncludeOnly(vec!["a".into(), "b".into()]),
        );
        ToolFilterState::apply(
            &mut val,
            ToolFilterStateAction::IncludeOnly(vec!["c".into()]),
        );
        assert_eq!(val.include_only.len(), 2);
        assert_eq!(val.include_only[0], ["a", "b"]);
        assert_eq!(val.include_only[1], ["c"]);
        assert!(val.excluded.is_empty());
    }

    #[test]
    fn tool_filter_state_clear() {
        let mut val = ToolFilterStateValue::default();
        ToolFilterState::apply(&mut val, ToolFilterStateAction::Exclude("a".into()));
        ToolFilterState::apply(
            &mut val,
            ToolFilterStateAction::IncludeOnly(vec!["b".into()]),
        );
        ToolFilterState::apply(&mut val, ToolFilterStateAction::Clear);
        assert!(val.excluded.is_empty());
        assert!(val.include_only.is_empty());
    }
}