use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::budget::{IterationBudget, MoneyBudget, TimeBudget, TokenBudget};
use crate::ids::{AgentId, OrgId, TeamId};
use crate::value::Value;
/// A type-keyed container holding at most one value per concrete type.
///
/// Values are stored behind [`Arc`], so cloning the container clones the
/// `Arc` handles rather than the stored values themselves.
#[derive(Clone, Default)]
pub struct Extensions {
    /// Stored values, keyed by the `TypeId` of their concrete type.
    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Creates a container with no entries.
    pub fn new() -> Self {
        Self {
            map: HashMap::new(),
        }
    }

    /// Stores `value` under its own type, replacing any value of the same
    /// type that was stored before.
    pub fn insert<T>(&mut self, value: T)
    where
        T: Any + Send + Sync,
    {
        let key = TypeId::of::<T>();
        let stored: Arc<dyn Any + Send + Sync> = Arc::new(value);
        self.map.insert(key, stored);
    }

    /// Returns a reference to the stored value of type `T`, or `None` when
    /// no value of that type has been inserted.
    pub fn get<T>(&self) -> Option<&T>
    where
        T: Any + Send + Sync,
    {
        let stored = self.map.get(&TypeId::of::<T>())?;
        stored.downcast_ref::<T>()
    }

    /// Reports whether a value of type `T` is present.
    pub fn contains<T>(&self) -> bool
    where
        T: Any + Send + Sync,
    {
        let key = TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Number of stored values (one per distinct type).
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// `true` when nothing has been stored.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }
}

impl std::fmt::Debug for Extensions {
    /// Formats only the entry count; the stored values are `dyn Any` and
    /// carry no `Debug` bound, so they cannot be printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.map.len();
        let mut out = f.debug_struct("Extensions");
        out.field("len", &count);
        out.finish()
    }
}
/// A message made of a role and string content.
///
/// Derives serde `Serialize`/`Deserialize` in addition to `Debug` and `Clone`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// The role associated with this message; see [`MessageRole`].
pub role: MessageRole,
/// The message body as a string.
pub content: String,
}
/// Role tag carried by a [`Message`].
///
/// With `rename_all = "snake_case"`, serde reads and writes the variants as
/// `system`, `user`, `assistant`, and `tool`.
// NOTE(review): what each role means to callers is not visible in this file;
// the variant names are the only indication — confirm before relying on them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
System,
User,
Assistant,
Tool,
}
/// Input carried by an [`AgentContext`]: a `user` string plus a list of messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurnInput {
/// The user-supplied string (presumably the latest user message — confirm
/// against callers).
pub user: String,
/// Message history (presumably earlier turns — confirm against callers).
/// When the field is missing during deserialization it defaults to an
/// empty list.
#[serde(default)]
pub history: Vec<Message>,
}
/// Identifies an agent (with optional team and organization) and carries the
/// turn input associated with it.
#[derive(Debug, Clone)]
pub struct AgentContext {
    /// The agent this context belongs to.
    pub agent_id: AgentId,
    /// The team id, if one is set.
    pub team_id: Option<TeamId>,
    /// The organization id, if one is set.
    pub org_id: Option<OrgId>,
    /// The turn input carried by this context.
    pub turn: TurnInput,
}

impl AgentContext {
    /// Builds a context for `agent_id` with the given `turn`, leaving both
    /// `team_id` and `org_id` unset.
    pub fn for_agent(agent_id: AgentId, turn: TurnInput) -> Self {
        let (team_id, org_id) = (None, None);
        Self {
            agent_id,
            turn,
            team_id,
            org_id,
        }
    }
}
/// Per-call context: an optional agent id, four budgets, a trace, and a set
/// of typed extensions.
#[derive(Debug, Clone)]
pub struct CallCtx {
    /// The agent associated with the call, if any.
    pub agent_id: Option<AgentId>,
    /// Token budget for the call.
    pub tokens: TokenBudget,
    /// Time budget for the call.
    pub time: TimeBudget,
    /// Money budget for the call.
    pub money: MoneyBudget,
    /// Iteration budget for the call.
    pub iterations: IterationBudget,
    /// Trace entries as strings.
    // NOTE(review): the meaning and format of trace entries is not visible here.
    pub trace: Vec<String>,
    /// Typed values attached to the call, looked up by their concrete type.
    pub extensions: Extensions,
}

impl CallCtx {
    /// Assembles a context from the supplied parts; the extension set starts
    /// out empty.
    pub fn new(
        agent_id: Option<AgentId>,
        tokens: TokenBudget,
        time: TimeBudget,
        money: MoneyBudget,
        iterations: IterationBudget,
        trace: Vec<String>,
    ) -> Self {
        let extensions = Extensions::new();
        Self {
            agent_id,
            tokens,
            time,
            money,
            iterations,
            trace,
            extensions,
        }
    }

    /// Consuming variant of [`CallCtx::insert_ext`]: attaches `value` and
    /// hands the context back so calls can be chained.
    pub fn with_ext<T>(mut self, value: T) -> Self
    where
        T: Any + Send + Sync,
    {
        self.insert_ext(value);
        self
    }

    /// Attaches `value` to the extension set under its own type.
    pub fn insert_ext<T>(&mut self, value: T)
    where
        T: Any + Send + Sync,
    {
        self.extensions.insert(value);
    }

    /// Looks up the attached value of type `T`, if one is present.
    pub fn ext<T>(&self) -> Option<&T>
    where
        T: Any + Send + Sync,
    {
        self.extensions.get()
    }
}
/// Context for one invocation: the enclosing [`CallCtx`] plus a tool-call id
/// and the raw arguments value.
#[derive(Debug, Clone)]
pub struct InvokeCtx {
    /// The call context this invocation carries.
    pub call: CallCtx,
    /// Identifier of the tool call, as a string.
    pub tool_call_id: String,
    /// The arguments as a raw [`Value`].
    pub raw_args: Value,
}

impl InvokeCtx {
    /// Looks up the extension of type `T` on the inner call context, if one
    /// is present.
    pub fn ext<T>(&self) -> Option<&T>
    where
        T: Any + Send + Sync,
    {
        let call = &self.call;
        call.extensions.get::<T>()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::budget::{IterationBudget, MoneyBudget, TimeBudget, TokenBudget};

    /// Marker type used as the extension payload in these tests.
    #[derive(Debug, PartialEq)]
    struct Clearance(&'static str);

    /// Builds a `CallCtx` with no agent id, fixed budgets, and an empty trace.
    fn ctx() -> CallCtx {
        let tokens = TokenBudget::new(100);
        let time = TimeBudget::new(std::time::Duration::from_secs(1));
        let money = MoneyBudget::from_usd(1.0);
        let iterations = IterationBudget::new(1);
        CallCtx::new(None, tokens, time, money, iterations, Vec::new())
    }

    /// A value inserted by type can be read back by that same type.
    #[test]
    fn extensions_roundtrip_by_type() {
        let mut call = ctx();
        assert_eq!(call.ext::<Clearance>(), None);

        call.insert_ext(Clearance("mnpi"));

        let expected = Clearance("mnpi");
        assert_eq!(call.ext::<Clearance>(), Some(&expected));
        assert!(call.extensions.contains::<Clearance>());
        assert_eq!(call.extensions.len(), 1);
    }

    /// Extensions are visible on a clone and through `InvokeCtx::ext`.
    #[test]
    fn extensions_survive_clone_and_propagate_to_invoke_ctx() {
        let expected = Clearance("internal");
        let original = ctx().with_ext(Clearance("internal"));

        let copy = original.clone();
        assert_eq!(copy.ext::<Clearance>(), Some(&expected));

        let invoke = InvokeCtx {
            call: original,
            tool_call_id: String::from("t"),
            raw_args: Value::Null,
        };
        assert_eq!(invoke.ext::<Clearance>(), Some(&expected));
    }
}