1use serde::{Serialize, de::DeserializeOwned};
6
7use crate::delta::Reducer;
8use crate::state::{State, StateError};
9
10#[derive(Debug)]
12pub struct StateKey<T> {
13 name: &'static str,
14 reducer: Reducer,
15 _marker: std::marker::PhantomData<T>,
16}
17
18impl<T> StateKey<T> {
19 pub const fn new(name: &'static str, reducer: Reducer) -> Self {
20 Self {
21 name,
22 reducer,
23 _marker: std::marker::PhantomData,
24 }
25 }
26
27 pub const fn append(name: &'static str) -> Self {
28 Self::new(name, Reducer::Append)
29 }
30
31 pub const fn sum(name: &'static str) -> Self {
32 Self::new(name, Reducer::Sum)
33 }
34
35 pub const fn replace(name: &'static str) -> Self {
36 Self::new(name, Reducer::Replace)
37 }
38
39 pub const fn merge_object(name: &'static str) -> Self {
40 Self::new(name, Reducer::MergeObject)
41 }
42
43 pub const fn max(name: &'static str) -> Self {
44 Self::new(name, Reducer::Max)
45 }
46
47 pub const fn min(name: &'static str) -> Self {
48 Self::new(name, Reducer::Min)
49 }
50
51 pub const fn error(name: &'static str) -> Self {
52 Self::new(name, Reducer::Error)
53 }
54
55 pub fn name(&self) -> &str {
56 self.name
57 }
58
59 pub fn reducer(&self) -> &Reducer {
60 &self.reducer
61 }
62}
63
64pub trait StateKeyExt {
66 fn set_sk<T>(&mut self, key: &StateKey<T>, value: T)
67 where
68 T: Serialize;
69
70 fn get_sk<T>(&self, key: &StateKey<T>) -> Option<T>
71 where
72 T: DeserializeOwned;
73
74 fn require_sk<T>(&self, key: &StateKey<T>) -> Result<T, StateError>
75 where
76 T: DeserializeOwned;
77
78 fn contains_sk<T>(&self, key: &StateKey<T>) -> bool;
79
80 fn remove_sk<T>(&mut self, key: &StateKey<T>) -> Option<serde_json::Value>;
81}
82
83impl StateKeyExt for State {
84 fn set_sk<T>(&mut self, key: &StateKey<T>, value: T)
85 where
86 T: Serialize,
87 {
88 let key_str = key.name().to_string();
89 let json = match serde_json::to_value(value) {
90 Ok(v) => v,
91 Err(e) => {
92 tracing::warn!(key = %key_str, error = %e, "failed to serialize state value, storing null");
93 serde_json::Value::Null
94 }
95 };
96 self.insert(key_str, json);
97 }
98
99 fn get_sk<T>(&self, key: &StateKey<T>) -> Option<T>
100 where
101 T: DeserializeOwned,
102 {
103 self.get(key.name())
104 .and_then(|v| serde_json::from_value(v.clone()).ok())
105 }
106
107 fn require_sk<T>(&self, key: &StateKey<T>) -> Result<T, StateError>
108 where
109 T: DeserializeOwned,
110 {
111 let value = self
112 .get(key.name())
113 .ok_or_else(|| StateError::MissingKey(key.name().to_string()))?;
114 serde_json::from_value(value.clone())
115 .map_err(|e| StateError::Deserialize(key.name().to_string(), e.to_string()))
116 }
117
118 fn contains_sk<T>(&self, key: &StateKey<T>) -> bool {
119 self.contains_key(key.name())
120 }
121
122 fn remove_sk<T>(&mut self, key: &StateKey<T>) -> Option<serde_json::Value> {
123 self.remove(key.name())
124 }
125}
126
127pub static SK_MESSAGES: StateKey<Vec<serde_json::Value>> =
131 StateKey::new("messages", Reducer::Append);
132
133pub static SK_COUNT: StateKey<u64> = StateKey::new("count", Reducer::Sum);
135
136pub static SK_STEPS: StateKey<Vec<String>> = StateKey::new("steps", Reducer::Append);
138
139pub static SK_ITERATIONS: StateKey<u32> = StateKey::replace("iterations");
143
144pub static SK_PENDING_TOOL_CALLS: StateKey<Vec<serde_json::Value>> =
146 StateKey::replace("pending_tool_calls");
147
148pub static SK_TOTAL_TOOL_CALLS: StateKey<usize> = StateKey::sum("total_tool_calls");
150
151pub static SK_OUTPUT_TOKENS: StateKey<usize> = StateKey::sum("output_tokens");
153
154pub static SK_REASONING_TOKENS: StateKey<usize> = StateKey::sum("reasoning_tokens");