1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use serde::{Deserialize, Serialize};
6
7use crate::budget::{IterationBudget, MoneyBudget, TimeBudget, TokenBudget};
8use crate::ids::{AgentId, OrgId, TeamId};
9use crate::value::Value;
10
/// Heterogeneous storage holding at most one value per concrete type.
///
/// Values are keyed by `TypeId::of::<T>()` and kept behind an `Arc`, so
/// cloning an `Extensions` shares the stored values instead of deep-copying
/// them.
#[derive(Default, Clone)]
pub struct Extensions {
    /// Type-erased values, keyed by the `TypeId` of their concrete type.
    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}
34
35impl Extensions {
36 pub fn new() -> Self {
38 Self::default()
39 }
40
41 pub fn insert<T: Any + Send + Sync>(&mut self, value: T) {
43 self.map.insert(TypeId::of::<T>(), Arc::new(value));
44 }
45
46 pub fn get<T: Any + Send + Sync>(&self) -> Option<&T> {
48 self.map
49 .get(&TypeId::of::<T>())
50 .and_then(|v| v.downcast_ref::<T>())
51 }
52
53 pub fn contains<T: Any + Send + Sync>(&self) -> bool {
55 self.map.contains_key(&TypeId::of::<T>())
56 }
57
58 pub fn len(&self) -> usize {
60 self.map.len()
61 }
62
63 pub fn is_empty(&self) -> bool {
65 self.map.is_empty()
66 }
67}
68
69impl std::fmt::Debug for Extensions {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("Extensions")
74 .field("len", &self.map.len())
75 .finish()
76 }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct Message {
84 pub role: MessageRole,
85 pub content: String,
86}
87
88#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
89#[serde(rename_all = "snake_case")]
90pub enum MessageRole {
91 System,
92 User,
93 Assistant,
94 Tool,
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct TurnInput {
100 pub user: String,
101 #[serde(default)]
102 pub history: Vec<Message>,
103}
104
105#[derive(Debug, Clone)]
109pub struct AgentContext {
110 pub agent_id: AgentId,
111 pub team_id: Option<TeamId>,
112 pub org_id: Option<OrgId>,
113 pub turn: TurnInput,
114}
115
116impl AgentContext {
117 pub fn for_agent(agent_id: AgentId, turn: TurnInput) -> Self {
118 Self {
119 agent_id,
120 team_id: None,
121 org_id: None,
122 turn,
123 }
124 }
125}
126
127#[derive(Debug, Clone)]
132pub struct CallCtx {
133 pub agent_id: Option<AgentId>,
134 pub tokens: TokenBudget,
135 pub time: TimeBudget,
136 pub money: MoneyBudget,
137 pub iterations: IterationBudget,
138 pub trace: Vec<String>,
139 pub extensions: Extensions,
142}
143
144impl CallCtx {
145 pub fn new(
147 agent_id: Option<AgentId>,
148 tokens: TokenBudget,
149 time: TimeBudget,
150 money: MoneyBudget,
151 iterations: IterationBudget,
152 trace: Vec<String>,
153 ) -> Self {
154 Self {
155 agent_id,
156 tokens,
157 time,
158 money,
159 iterations,
160 trace,
161 extensions: Extensions::new(),
162 }
163 }
164
165 pub fn with_ext<T: Any + Send + Sync>(mut self, value: T) -> Self {
167 self.extensions.insert(value);
168 self
169 }
170
171 pub fn insert_ext<T: Any + Send + Sync>(&mut self, value: T) {
173 self.extensions.insert(value);
174 }
175
176 pub fn ext<T: Any + Send + Sync>(&self) -> Option<&T> {
178 self.extensions.get::<T>()
179 }
180}
181
182#[derive(Debug, Clone)]
184pub struct InvokeCtx {
185 pub call: CallCtx,
186 pub tool_call_id: String,
187 pub raw_args: Value,
188}
189
190impl InvokeCtx {
191 pub fn ext<T: Any + Send + Sync>(&self) -> Option<&T> {
194 self.call.ext::<T>()
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::budget::{IterationBudget, MoneyBudget, TimeBudget, TokenBudget};

    /// Marker value used as a typed extension in these tests.
    #[derive(Debug, PartialEq)]
    struct Clearance(&'static str);

    /// Builds a `CallCtx` with small, arbitrary budgets and no extensions.
    fn ctx() -> CallCtx {
        CallCtx::new(
            None,
            TokenBudget::new(100),
            TimeBudget::new(std::time::Duration::from_secs(1)),
            MoneyBudget::from_usd(1.0),
            IterationBudget::new(1),
            vec![],
        )
    }

    #[test]
    fn extensions_roundtrip_by_type() {
        let mut c = ctx();
        assert!(c.ext::<Clearance>().is_none());
        c.insert_ext(Clearance("mnpi"));
        assert_eq!(c.ext::<Clearance>(), Some(&Clearance("mnpi")));
        assert!(c.extensions.contains::<Clearance>());
        assert_eq!(c.extensions.len(), 1);
    }

    #[test]
    fn extensions_survive_clone_and_propagate_to_invoke_ctx() {
        let c = ctx().with_ext(Clearance("internal"));
        let cloned = c.clone();
        assert_eq!(cloned.ext::<Clearance>(), Some(&Clearance("internal")));

        let ictx = InvokeCtx {
            call: c,
            tool_call_id: "t".into(),
            raw_args: Value::Null,
        };
        assert_eq!(ictx.ext::<Clearance>(), Some(&Clearance("internal")));
    }

    #[test]
    fn inserting_same_type_replaces_previous_value() {
        let c = ctx()
            .with_ext(Clearance("internal"))
            .with_ext(Clearance("mnpi"));
        assert_eq!(c.ext::<Clearance>(), Some(&Clearance("mnpi")));
        assert_eq!(c.extensions.len(), 1);
    }

    #[test]
    fn distinct_types_are_stored_independently() {
        let mut c = ctx();
        c.insert_ext(Clearance("internal"));
        c.insert_ext(42u32);
        assert_eq!(c.ext::<Clearance>(), Some(&Clearance("internal")));
        assert_eq!(c.ext::<u32>(), Some(&42));
        // A different integer type is a different key.
        assert!(c.ext::<u64>().is_none());
        assert_eq!(c.extensions.len(), 2);
    }

    #[test]
    fn new_extensions_are_empty_and_debug_shows_len() {
        let mut e = Extensions::new();
        assert!(e.is_empty());
        assert_eq!(format!("{:?}", e), "Extensions { len: 0 }");
        e.insert(Clearance("x"));
        assert!(!e.is_empty());
        assert_eq!(format!("{:?}", e), "Extensions { len: 1 }");
    }
}