use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::state::{MergeStrategy, StateKey};
use awaken_contract::contract::context_message::ContextMessage;
use awaken_contract::contract::inference::InferenceOverride;
/// Scheduled action whose payload is a [`ContextMessage`].
///
/// Registered under `"runtime.add_context_message"` and scheduled for the
/// `BeforeInference` phase.
pub struct AddContextMessage;

impl awaken_contract::model::ScheduledActionSpec for AddContextMessage {
    type Payload = ContextMessage;

    const KEY: &'static str = "runtime.add_context_message";
    const PHASE: awaken_contract::model::Phase = awaken_contract::model::Phase::BeforeInference;
}
/// Scheduled action whose payload is an [`InferenceOverride`].
///
/// Registered under `"runtime.set_inference_override"` and scheduled for the
/// `BeforeInference` phase.
pub struct SetInferenceOverride;

impl awaken_contract::model::ScheduledActionSpec for SetInferenceOverride {
    type Payload = InferenceOverride;

    const KEY: &'static str = "runtime.set_inference_override";
    const PHASE: awaken_contract::model::Phase = awaken_contract::model::Phase::BeforeInference;
}
/// Scheduled action whose payload is a single `String`.
///
/// Presumably the id of a tool to exclude — confirm against the handler.
/// Registered under `"runtime.exclude_tool"` and scheduled for the
/// `BeforeInference` phase.
pub struct ExcludeTool;

impl awaken_contract::model::ScheduledActionSpec for ExcludeTool {
    type Payload = String;

    const KEY: &'static str = "runtime.exclude_tool";
    const PHASE: awaken_contract::model::Phase = awaken_contract::model::Phase::BeforeInference;
}
/// Scheduled action whose payload is a list of `String`s.
///
/// Presumably an allow-list of tool ids — confirm against the handler.
/// Registered under `"runtime.include_only_tools"` and scheduled for the
/// `BeforeInference` phase.
pub struct IncludeOnlyTools;

impl awaken_contract::model::ScheduledActionSpec for IncludeOnlyTools {
    type Payload = Vec<String>;

    const KEY: &'static str = "runtime.include_only_tools";
    const PHASE: awaken_contract::model::Phase = awaken_contract::model::Phase::BeforeInference;
}
/// State key for the keyed collection of context messages.
pub struct ContextMessageStore;

/// Value stored under [`ContextMessageStore`].
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
pub struct ContextMessageStoreValue {
    /// Stored messages, indexed by key.
    ///
    /// `Upsert` uses each message's own `key` as the map key. Iteration order is
    /// unspecified; use [`ContextMessageStoreValue::sorted_messages`] for a
    /// deterministic order.
    pub messages: HashMap<String, ContextMessage>,
}
impl ContextMessageStoreValue {
    /// Returns references to every stored message in a deterministic order.
    ///
    /// Messages are ordered by:
    /// 1. `target`,
    /// 2. ascending `priority`,
    /// 3. `key`.
    ///
    /// The result therefore does not depend on `HashMap` iteration order, as long as
    /// map keys match message keys (which `Upsert` maintains).
    pub fn sorted_messages(&self) -> Vec<&ContextMessage> {
        let mut sorted: Vec<&ContextMessage> = self.messages.values().collect();
        // `then_with` only runs the later comparisons when the earlier ones tie.
        // Eager `then` would compute priority and key comparisons for every pair.
        sorted.sort_by(|a, b| {
            a.target
                .cmp(&b.target)
                .then_with(|| a.priority.cmp(&b.priority))
                .then_with(|| a.key.cmp(&b.key))
        });
        sorted
    }
}
/// Update operations accepted by [`ContextMessageStore`].
//
// Variant order is intentionally unchanged: index-based serde formats depend on it.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum ContextMessageAction {
    /// Insert a message, replacing any existing entry stored under the same key.
    Upsert(ContextMessage),
    /// Remove the entry stored under exactly this key (no-op if absent).
    Remove(String),
    /// Remove every entry whose key starts with this prefix.
    RemoveByPrefix(String),
    /// Remove every message whose `persistent` flag is false.
    RemoveEphemeral,
    /// Remove every message whose `consume_after_emit` flag is true.
    ConsumeAfterEmit,
    /// Remove all messages.
    Clear,
}
/// Applies [`ContextMessageAction`] updates to the keyed message map.
///
/// NOTE(review): this key is declared `MergeStrategy::Commutative`. However, `Upsert`,
/// `Remove`, and `Clear` touching the same key do not commute (the last writer wins).
/// Confirm the merge strategy is intended.
impl StateKey for ContextMessageStore {
    const KEY: &'static str = "__runtime.context_message_store";
    const MERGE: MergeStrategy = MergeStrategy::Commutative;
    type Value = ContextMessageStoreValue;
    type Update = ContextMessageAction;

    fn apply(value: &mut Self::Value, update: Self::Update) {
        let store = &mut value.messages;
        match update {
            ContextMessageAction::Upsert(message) => {
                // The message's own key becomes the map key, so re-upserting replaces it.
                let slot_key = message.key.clone();
                store.insert(slot_key, message);
            }
            ContextMessageAction::Remove(target_key) => {
                store.remove(target_key.as_str());
            }
            ContextMessageAction::RemoveByPrefix(prefix) => {
                store.retain(|stored_key, _| !stored_key.starts_with(prefix.as_str()));
            }
            ContextMessageAction::RemoveEphemeral => {
                store.retain(|_, message| message.persistent);
            }
            ContextMessageAction::ConsumeAfterEmit => {
                store.retain(|_, message| !message.consume_after_emit);
            }
            ContextMessageAction::Clear => store.clear(),
        }
    }
}
/// State key holding tool exclusions and include-only allow-lists.
pub struct ToolFilterState;

/// Value stored under [`ToolFilterState`].
///
/// Derives `PartialEq`/`Eq` so state snapshots can be compared directly. This is
/// consistent with [`ContextMessageStoreValue`], which also derives `PartialEq`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolFilterStateValue {
    /// Excluded tool ids, in the order they were applied.
    ///
    /// `Exclude` appends without de-duplication.
    pub excluded: Vec<String>,
    /// Allow-lists, one entry per `IncludeOnly` update.
    ///
    /// How multiple lists combine is decided by the consumer — presumably an
    /// intersection; confirm.
    pub include_only: Vec<Vec<String>>,
}
/// Update operations accepted by [`ToolFilterState`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum ToolFilterStateAction {
    /// Append one tool id to the exclusion list.
    Exclude(String),
    /// Append one allow-list of tool ids.
    IncludeOnly(Vec<String>),
    /// Empty both the exclusion list and the allow-lists.
    Clear,
}
impl StateKey for ToolFilterState {
const KEY: &'static str = "__runtime.tool_filter_state";
const MERGE: MergeStrategy = MergeStrategy::Commutative;
type Value = ToolFilterStateValue;
type Update = ToolFilterStateAction;
fn apply(value: &mut Self::Value, update: Self::Update) {
match update {
ToolFilterStateAction::Exclude(id) => value.excluded.push(id),
ToolFilterStateAction::IncludeOnly(ids) => value.include_only.push(ids),
ToolFilterStateAction::Clear => {
value.excluded.clear();
value.include_only.clear();
}
}
}
}
/// State key holding the accumulated [`InferenceOverride`].
pub struct InferenceOverrideState;

/// Value stored under [`InferenceOverrideState`].
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct InferenceOverrideStateValue {
    /// The merged override, or `None` when nothing has been merged since the last clear.
    ///
    /// `None` is also the default.
    pub overrides: Option<InferenceOverride>,
}
/// Update operations accepted by [`InferenceOverrideState`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum InferenceOverrideStateAction {
    /// Merge into the current override via `InferenceOverride::merge`.
    ///
    /// If no override is present yet, this one is installed as-is.
    Merge(InferenceOverride),
    /// Drop the current override.
    Clear,
}
/// Applies [`InferenceOverrideStateAction`] updates to the optional merged override.
///
/// NOTE(review): declared `Commutative`. Whether successive `Merge`s commute depends on
/// `InferenceOverride::merge` (not visible here), and `Clear` never does. Confirm this
/// is intended.
impl StateKey for InferenceOverrideState {
    const KEY: &'static str = "__runtime.inference_override_state";
    const MERGE: MergeStrategy = MergeStrategy::Commutative;
    type Value = InferenceOverrideStateValue;
    type Update = InferenceOverrideStateAction;

    fn apply(value: &mut Self::Value, update: Self::Update) {
        match update {
            InferenceOverrideStateAction::Clear => {
                value.overrides = None;
            }
            InferenceOverrideStateAction::Merge(incoming) => match value.overrides.as_mut() {
                Some(current) => {
                    current.merge(incoming);
                }
                // First override since the last clear: install it unchanged.
                None => {
                    value.overrides = Some(incoming);
                }
            },
        }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the runtime state keys' `apply` reducers.
    use super::*;
    use awaken_contract::contract::context_message::ContextMessage as ContractContextMessage;

    #[test]
    fn context_message_store_upsert() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("k1", "msg1")),
        );
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("k1"));
    }

    #[test]
    fn context_message_store_upsert_replaces() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("k1", "msg1")),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("k1", "updated")),
        );
        assert_eq!(val.messages.len(), 1);
        assert_eq!(
            val.messages["k1"].content[0],
            awaken_contract::contract::content::ContentBlock::text("updated")
        );
    }

    #[test]
    fn context_message_store_upsert_multiple() {
        let mut val = ContextMessageStoreValue::default();
        for i in 0..5 {
            ContextMessageStore::apply(
                &mut val,
                ContextMessageAction::Upsert(ContractContextMessage::system(
                    format!("k{i}"),
                    format!("msg{i}"),
                )),
            );
        }
        assert_eq!(val.messages.len(), 5);
    }

    #[test]
    fn context_message_store_remove() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("k1", "msg1")),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("k2", "msg2")),
        );
        ContextMessageStore::apply(&mut val, ContextMessageAction::Remove("k1".into()));
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("k2"));
    }

    #[test]
    fn context_message_store_remove_missing_key_is_noop() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("k1", "msg1")),
        );
        ContextMessageStore::apply(&mut val, ContextMessageAction::Remove("absent".into()));
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("k1"));
    }

    #[test]
    fn context_message_store_remove_by_prefix() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("mcp:tool1", "t1")),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("mcp:tool2", "t2")),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("skill:a", "s1")),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::RemoveByPrefix("mcp:".into()),
        );
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("skill:a"));
    }

    #[test]
    fn context_message_store_remove_ephemeral() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("eph", "ephemeral")),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system_persistent(
                "pers",
                "persistent",
            )),
        );
        ContextMessageStore::apply(&mut val, ContextMessageAction::RemoveEphemeral);
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("pers"));
    }

    #[test]
    fn context_message_store_consume_after_emit() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::emit_once(
                "once",
                "once",
                awaken_contract::contract::context_message::ContextMessageTarget::System,
            )),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system_persistent("keep", "keep")),
        );
        ContextMessageStore::apply(&mut val, ContextMessageAction::ConsumeAfterEmit);
        assert_eq!(val.messages.len(), 1);
        assert!(val.messages.contains_key("keep"));
    }

    #[test]
    fn context_message_store_clear() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(ContractContextMessage::system("k1", "msg1")),
        );
        ContextMessageStore::apply(&mut val, ContextMessageAction::Clear);
        assert!(val.messages.is_empty());
    }

    #[test]
    fn context_message_store_sorted_messages() {
        let mut val = ContextMessageStoreValue::default();
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(
                ContractContextMessage::suffix_system("z_suffix", "last").with_priority(0),
            ),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(
                ContractContextMessage::system("a_sys", "first").with_priority(0),
            ),
        );
        ContextMessageStore::apply(
            &mut val,
            ContextMessageAction::Upsert(
                ContractContextMessage::system("b_sys", "second").with_priority(10),
            ),
        );
        let sorted = val.sorted_messages();
        assert_eq!(sorted[0].key, "a_sys");
        assert_eq!(sorted[1].key, "b_sys");
        assert_eq!(sorted[2].key, "z_suffix");
    }

    #[test]
    fn tool_filter_state_exclude_and_include_only_accumulate() {
        let mut val = ToolFilterStateValue::default();
        ToolFilterState::apply(&mut val, ToolFilterStateAction::Exclude("a".into()));
        ToolFilterState::apply(&mut val, ToolFilterStateAction::Exclude("b".into()));
        ToolFilterState::apply(
            &mut val,
            ToolFilterStateAction::IncludeOnly(vec!["x".into(), "y".into()]),
        );
        assert_eq!(val.excluded, vec!["a".to_string(), "b".to_string()]);
        assert_eq!(
            val.include_only,
            vec![vec!["x".to_string(), "y".to_string()]]
        );
    }

    #[test]
    fn tool_filter_state_clear_resets_both_lists() {
        let mut val = ToolFilterStateValue::default();
        ToolFilterState::apply(&mut val, ToolFilterStateAction::Exclude("a".into()));
        ToolFilterState::apply(
            &mut val,
            ToolFilterStateAction::IncludeOnly(vec!["x".into()]),
        );
        ToolFilterState::apply(&mut val, ToolFilterStateAction::Clear);
        assert!(val.excluded.is_empty());
        assert!(val.include_only.is_empty());
    }

    #[test]
    fn inference_override_state_clear_on_empty_stays_none() {
        let mut val = InferenceOverrideStateValue::default();
        assert!(val.overrides.is_none());
        InferenceOverrideState::apply(&mut val, InferenceOverrideStateAction::Clear);
        assert!(val.overrides.is_none());
    }
}