Skip to main content

atomr_agents_core/
context.rs

1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use serde::{Deserialize, Serialize};
6
7use crate::budget::{IterationBudget, MoneyBudget, TimeBudget, TokenBudget};
8use crate::ids::{AgentId, OrgId, TeamId};
9use crate::value::Value;
10
/// Typed, host-only extension map carried on [`CallCtx`].
///
/// This is the substrate-level channel through which a host attaches
/// caller identity / clearance / mandate (e.g. a `ClearanceContext`
/// from `atomr-agents-security`) so it flows unmodified to
/// `Tool::invoke` via [`InvokeCtx`]. Two deliberate guarantees:
///
/// * **Never persisted.** [`CallCtx`] is not `Serialize`, and the
///   extension map holds `dyn Any` values that cannot be serialized,
///   so secrets/clearance attached here never land in a checkpoint,
///   telemetry record, or prompt.
/// * **Not LLM-writable.** Values are inserted only by Rust host /
///   runtime code via [`Extensions::insert`], [`CallCtx::insert_ext`]
///   or [`CallCtx::with_ext`]; there is no path from an LLM tool
///   argument (`raw_args`) into this map, so model output cannot forge
///   or alter an entry.
///
/// This isolation is structural, not cryptographic: nothing here
/// *detects* tampering, and any Rust code holding a `&mut CallCtx`
/// (its `extensions` field is public) can still insert or replace
/// values.
///
/// Entries are keyed by the value's [`TypeId`], so at most one value
/// per concrete type is stored; inserting another value of the same
/// type replaces the previous one.
///
/// Stored values are `Arc`-wrapped so cloning a [`CallCtx`] (which the
/// runtime does per tool dispatch) shares — rather than deep-copies —
/// the extensions.
#[derive(Clone, Default)]
pub struct Extensions {
    /// One type-erased, shared value per concrete type `T`, keyed by
    /// `TypeId::of::<T>()`.
    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}
34
35impl Extensions {
36    /// An empty extension map.
37    pub fn new() -> Self {
38        Self::default()
39    }
40
41    /// Insert (or replace) the extension of type `T`.
42    pub fn insert<T: Any + Send + Sync>(&mut self, value: T) {
43        self.map.insert(TypeId::of::<T>(), Arc::new(value));
44    }
45
46    /// Borrow the extension of type `T`, if present.
47    pub fn get<T: Any + Send + Sync>(&self) -> Option<&T> {
48        self.map
49            .get(&TypeId::of::<T>())
50            .and_then(|v| v.downcast_ref::<T>())
51    }
52
53    /// Whether an extension of type `T` is present.
54    pub fn contains<T: Any + Send + Sync>(&self) -> bool {
55        self.map.contains_key(&TypeId::of::<T>())
56    }
57
58    /// Number of distinct extension types stored.
59    pub fn len(&self) -> usize {
60        self.map.len()
61    }
62
63    /// Whether the map is empty.
64    pub fn is_empty(&self) -> bool {
65        self.map.is_empty()
66    }
67}
68
69impl std::fmt::Debug for Extensions {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        // Values are `dyn Any` and not `Debug`; report only the count so
72        // CallCtx can still derive `Debug` without leaking contents.
73        f.debug_struct("Extensions")
74            .field("len", &self.map.len())
75            .finish()
76    }
77}
78
79/// Conversation message — mirrors `atomr_infer_core::batch::Message`
80/// but lives at this layer so strategies can construct turns without
81/// pulling in the full inference crate.
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct Message {
84    pub role: MessageRole,
85    pub content: String,
86}
87
88#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
89#[serde(rename_all = "snake_case")]
90pub enum MessageRole {
91    System,
92    User,
93    Assistant,
94    Tool,
95}
96
97/// What a single agent turn consumes.
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct TurnInput {
100    pub user: String,
101    #[serde(default)]
102    pub history: Vec<Message>,
103}
104
105/// State the per-turn pipeline reads from. Strategies receive a
106/// reference to this; they do not mutate it directly (they return
107/// fragments which the `ContextAssembler` merges).
108#[derive(Debug, Clone)]
109pub struct AgentContext {
110    pub agent_id: AgentId,
111    pub team_id: Option<TeamId>,
112    pub org_id: Option<OrgId>,
113    pub turn: TurnInput,
114}
115
116impl AgentContext {
117    pub fn for_agent(agent_id: AgentId, turn: TurnInput) -> Self {
118        Self {
119            agent_id,
120            team_id: None,
121            org_id: None,
122            turn,
123        }
124    }
125}
126
127/// Context passed to `Callable::call`. Carries the budgets so a
128/// callable can refuse work it can't afford, plus a typed
129/// [`Extensions`] map for substrate-level caller context (clearance,
130/// mandate, decision keys) that flows through to `Tool::invoke`.
131#[derive(Debug, Clone)]
132pub struct CallCtx {
133    pub agent_id: Option<AgentId>,
134    pub tokens: TokenBudget,
135    pub time: TimeBudget,
136    pub money: MoneyBudget,
137    pub iterations: IterationBudget,
138    pub trace: Vec<String>,
139    /// Typed extension map (clearance, mandate, …). Never serialized,
140    /// never LLM-writable. See [`Extensions`].
141    pub extensions: Extensions,
142}
143
144impl CallCtx {
145    /// Construct a [`CallCtx`] with an empty extension map.
146    pub fn new(
147        agent_id: Option<AgentId>,
148        tokens: TokenBudget,
149        time: TimeBudget,
150        money: MoneyBudget,
151        iterations: IterationBudget,
152        trace: Vec<String>,
153    ) -> Self {
154        Self {
155            agent_id,
156            tokens,
157            time,
158            money,
159            iterations,
160            trace,
161            extensions: Extensions::new(),
162        }
163    }
164
165    /// Attach a typed extension, returning `self` for chaining.
166    pub fn with_ext<T: Any + Send + Sync>(mut self, value: T) -> Self {
167        self.extensions.insert(value);
168        self
169    }
170
171    /// Attach a typed extension in place.
172    pub fn insert_ext<T: Any + Send + Sync>(&mut self, value: T) {
173        self.extensions.insert(value);
174    }
175
176    /// Borrow a typed extension of type `T`, if present.
177    pub fn ext<T: Any + Send + Sync>(&self) -> Option<&T> {
178        self.extensions.get::<T>()
179    }
180}
181
182/// Context passed to `Tool::invoke`.
183#[derive(Debug, Clone)]
184pub struct InvokeCtx {
185    pub call: CallCtx,
186    pub tool_call_id: String,
187    pub raw_args: Value,
188}
189
190impl InvokeCtx {
191    /// Borrow a typed extension propagated from the enclosing
192    /// [`CallCtx`], if present.
193    pub fn ext<T: Any + Send + Sync>(&self) -> Option<&T> {
194        self.call.ext::<T>()
195    }
196}
197
#[cfg(test)]
mod tests {
    // The budget types come into scope through `super::*`, which also
    // re-exposes the parent module's (private) imports.
    use super::*;

    #[derive(Debug, PartialEq)]
    struct Clearance(&'static str);

    /// A minimal `CallCtx` with small budgets and no extensions.
    fn ctx() -> CallCtx {
        CallCtx::new(
            None,
            TokenBudget::new(100),
            TimeBudget::new(std::time::Duration::from_secs(1)),
            MoneyBudget::from_usd(1.0),
            IterationBudget::new(1),
            vec![],
        )
    }

    #[test]
    fn extensions_roundtrip_by_type() {
        let mut c = ctx();
        assert!(c.ext::<Clearance>().is_none());
        c.insert_ext(Clearance("mnpi"));
        assert_eq!(c.ext::<Clearance>(), Some(&Clearance("mnpi")));
        assert!(c.extensions.contains::<Clearance>());
        assert_eq!(c.extensions.len(), 1);
    }

    #[test]
    fn inserting_same_type_replaces_previous_value() {
        let mut c = ctx();
        c.insert_ext(Clearance("internal"));
        c.insert_ext(Clearance("mnpi"));
        assert_eq!(c.ext::<Clearance>(), Some(&Clearance("mnpi")));
        assert_eq!(c.extensions.len(), 1);
    }

    #[test]
    fn extensions_survive_clone_and_propagate_to_invoke_ctx() {
        let c = ctx().with_ext(Clearance("internal"));
        let cloned = c.clone();
        assert_eq!(cloned.ext::<Clearance>(), Some(&Clearance("internal")));

        let ictx = InvokeCtx {
            call: c,
            tool_call_id: "t".into(),
            raw_args: Value::Null,
        };
        assert_eq!(ictx.ext::<Clearance>(), Some(&Clearance("internal")));
    }

    #[test]
    fn clone_shares_rather_than_copies_extension_values() {
        let c = ctx().with_ext(Clearance("internal"));
        let cloned = c.clone();
        let key = TypeId::of::<Clearance>();
        assert!(Arc::ptr_eq(
            &c.extensions.map[&key],
            &cloned.extensions.map[&key]
        ));
    }

    #[test]
    fn debug_output_does_not_leak_extension_values() {
        let c = ctx().with_ext(Clearance("top-secret"));
        let rendered = format!("{:?}", c);
        assert!(!rendered.contains("top-secret"));
        assert!(rendered.contains("Extensions { len: 1 }"));
    }
}