serdes_ai_agent/context.rs

//! Run context and state management.
//!
//! The context contains all information about the current agent run,
//! including dependencies, settings, and execution state.
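//!
//! # Example
//!
//! A minimal sketch of the intended flow, assuming a unit dependency type and
//! an illustrative model name (not compiled as a doctest):
//!
//! ```ignore
//! let ctx = RunContext::new((), "gpt-4o");
//! let mut usage = RunUsage::new();
//! let limits = UsageLimits::new().requests(10).total_tokens(4_000);
//!
//! usage.record_tool_call();
//! assert!(limits.check(&usage).is_ok());
//! assert!(limits.check_time(ctx.elapsed_seconds().max(0) as u64).is_ok());
//! ```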

use chrono::{DateTime, Utc};
use serde_json::Value as JsonValue;
use serdes_ai_core::ModelSettings;
use std::sync::Arc;

/// Context for an agent run.
///
/// The context is passed to tools, instruction functions, and validators.
/// It provides access to dependencies and run metadata.
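///
/// # Example
///
/// A minimal sketch with an illustrative dependency type (`MyDeps` is
/// hypothetical, not part of this crate):
///
/// ```ignore
/// // `MyDeps` is a hypothetical application-defined dependency type.
/// struct MyDeps {
///     api_key: String,
/// }
///
/// let deps = MyDeps { api_key: "secret".into() };
/// let ctx = RunContext::new(deps, "gpt-4o");
///
/// assert_eq!(ctx.deps().api_key, "secret");
/// assert!(!ctx.is_retry());
/// assert!(!ctx.in_tool());
/// ```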
#[derive(Debug)]
pub struct RunContext<Deps> {
    /// Shared dependencies.
    pub deps: Arc<Deps>,
    /// Unique run identifier.
    pub run_id: String,
    /// Run start time.
    pub start_time: DateTime<Utc>,
    /// Model being used.
    pub model_name: String,
    /// Model settings for this run.
    pub model_settings: ModelSettings,
    /// Current tool being executed (if any).
    pub tool_name: Option<String>,
    /// Current tool call ID (if any).
    pub tool_call_id: Option<String>,
    /// Current retry count.
    pub retry_count: u32,
    /// Custom metadata.
    pub metadata: Option<JsonValue>,
}

impl<Deps> RunContext<Deps> {
    /// Create a new run context.
    pub fn new(deps: Deps, model_name: impl Into<String>) -> Self {
        Self {
            deps: Arc::new(deps),
            run_id: generate_run_id(),
            start_time: Utc::now(),
            model_name: model_name.into(),
            model_settings: ModelSettings::default(),
            tool_name: None,
            tool_call_id: None,
            retry_count: 0,
            metadata: None,
        }
    }

    /// Create with shared dependencies.
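    ///
    /// # Example
    ///
    /// A sketch for sharing one dependency value across several contexts
    /// (`AppState` is an illustrative type, not part of this crate):
    ///
    /// ```ignore
    /// // `AppState` is a hypothetical shared dependency type.
    /// #[derive(Default)]
    /// struct AppState {
    ///     request_counter: std::sync::atomic::AtomicU64,
    /// }
    ///
    /// let deps = Arc::new(AppState::default());
    /// let ctx_a = RunContext::with_shared_deps(deps.clone(), "gpt-4o");
    /// let ctx_b = RunContext::with_shared_deps(deps, "gpt-4o-mini");
    /// ```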
    pub fn with_shared_deps(deps: Arc<Deps>, model_name: impl Into<String>) -> Self {
        Self {
            deps,
            run_id: generate_run_id(),
            start_time: Utc::now(),
            model_name: model_name.into(),
            model_settings: ModelSettings::default(),
            tool_name: None,
            tool_call_id: None,
            retry_count: 0,
            metadata: None,
        }
    }

    /// Get a reference to the dependencies.
    pub fn deps(&self) -> &Deps {
        &self.deps
    }

    /// Get elapsed time since run started.
    pub fn elapsed(&self) -> chrono::Duration {
        Utc::now() - self.start_time
    }

    /// Get elapsed time in seconds.
    pub fn elapsed_seconds(&self) -> i64 {
        self.elapsed().num_seconds()
    }

    /// Check if this is a retry.
    pub fn is_retry(&self) -> bool {
        self.retry_count > 0
    }

    /// Check if we're currently in a tool execution.
    pub fn in_tool(&self) -> bool {
        self.tool_name.is_some()
    }

    /// Set metadata value.
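    ///
    /// Values that fail to serialize are silently ignored.
    ///
    /// # Example
    ///
    /// A minimal sketch (keys and values are illustrative):
    ///
    /// ```ignore
    /// let mut ctx = RunContext::new((), "gpt-4o");
    /// ctx.set_metadata("user_id", "12345");
    /// ctx.set_metadata("attempt", 1);
    ///
    /// let user_id: Option<String> = ctx.get_metadata("user_id");
    /// assert_eq!(user_id, Some("12345".to_string()));
    /// ```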
    pub fn set_metadata(&mut self, key: &str, value: impl serde::Serialize) {
        let meta = self
            .metadata
            .get_or_insert_with(|| JsonValue::Object(Default::default()));
        if let JsonValue::Object(ref mut map) = meta {
            if let Ok(v) = serde_json::to_value(value) {
                map.insert(key.to_string(), v);
            }
        }
    }

    /// Get metadata value.
    pub fn get_metadata<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
        self.metadata
            .as_ref()
            .and_then(|m| m.get(key))
            .and_then(|v| serde_json::from_value(v.clone()).ok())
    }

    /// Clone with a new tool context.
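    ///
    /// The returned context shares the same dependencies and run ID but has a
    /// fresh `retry_count`.
    ///
    /// # Example
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let ctx = RunContext::new((), "gpt-4o");
    /// let tool_ctx = ctx.for_tool("search", Some("call-123".to_string()));
    ///
    /// assert!(tool_ctx.in_tool());
    /// assert_eq!(tool_ctx.tool_name.as_deref(), Some("search"));
    /// assert_eq!(tool_ctx.retry_count, 0);
    /// ```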
    pub fn for_tool(&self, tool_name: impl Into<String>, tool_call_id: Option<String>) -> Self {
        Self {
            deps: self.deps.clone(),
            run_id: self.run_id.clone(),
            start_time: self.start_time,
            model_name: self.model_name.clone(),
            model_settings: self.model_settings.clone(),
            tool_name: Some(tool_name.into()),
            tool_call_id,
            retry_count: 0,
            metadata: self.metadata.clone(),
        }
    }

    /// Clone for a retry.
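    ///
    /// The returned context is identical except that `retry_count` is
    /// incremented by one.
    ///
    /// # Example
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let ctx = RunContext::new((), "gpt-4o");
    /// let retry_ctx = ctx.for_retry();
    ///
    /// assert!(retry_ctx.is_retry());
    /// assert_eq!(retry_ctx.retry_count, 1);
    /// ```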
    pub fn for_retry(&self) -> Self {
        Self {
            deps: self.deps.clone(),
            run_id: self.run_id.clone(),
            start_time: self.start_time,
            model_name: self.model_name.clone(),
            model_settings: self.model_settings.clone(),
            tool_name: self.tool_name.clone(),
            tool_call_id: self.tool_call_id.clone(),
            retry_count: self.retry_count + 1,
            metadata: self.metadata.clone(),
        }
    }
}

impl<Deps: Default> Default for RunContext<Deps> {
    fn default() -> Self {
        Self::new(Deps::default(), "unknown")
    }
}

impl<Deps> Clone for RunContext<Deps> {
    fn clone(&self) -> Self {
        Self {
            deps: self.deps.clone(),
            run_id: self.run_id.clone(),
            start_time: self.start_time,
            model_name: self.model_name.clone(),
            model_settings: self.model_settings.clone(),
            tool_name: self.tool_name.clone(),
            tool_call_id: self.tool_call_id.clone(),
            retry_count: self.retry_count,
            metadata: self.metadata.clone(),
        }
    }
}

/// Generate a unique run ID.
pub fn generate_run_id() -> String {
    uuid::Uuid::new_v4().to_string()
}

/// Usage tracking for a run.
#[derive(Debug, Clone, Default)]
pub struct RunUsage {
    /// Total request tokens.
    pub request_tokens: u64,
    /// Total response tokens.
    pub response_tokens: u64,
    /// Total tokens.
    pub total_tokens: u64,
    /// Number of model requests.
    pub request_count: u32,
    /// Number of tool calls.
    pub tool_call_count: u32,
    /// Cache creation tokens.
    pub cache_creation_tokens: Option<u64>,
    /// Cache read tokens.
    pub cache_read_tokens: Option<u64>,
}

impl RunUsage {
    /// Create new empty usage.
    pub fn new() -> Self {
        Self::default()
    }

    /// Add usage from a model request.
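    ///
    /// # Example
    ///
    /// A minimal sketch (token counts are illustrative):
    ///
    /// ```ignore
    /// let mut usage = RunUsage::new();
    /// usage.add_request(serdes_ai_core::RequestUsage {
    ///     request_tokens: Some(100),
    ///     response_tokens: Some(50),
    ///     total_tokens: Some(150),
    ///     cache_creation_tokens: None,
    ///     cache_read_tokens: None,
    ///     details: None,
    /// });
    ///
    /// assert_eq!(usage.total_tokens, 150);
    /// assert_eq!(usage.request_count, 1);
    /// ```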
    pub fn add_request(&mut self, usage: serdes_ai_core::RequestUsage) {
        if let Some(req) = usage.request_tokens {
            self.request_tokens += req;
        }
        if let Some(resp) = usage.response_tokens {
            self.response_tokens += resp;
        }
        if let Some(total) = usage.total_tokens {
            self.total_tokens += total;
        } else {
            self.total_tokens = self.request_tokens + self.response_tokens;
        }
        if let Some(cache) = usage.cache_creation_tokens {
            *self.cache_creation_tokens.get_or_insert(0) += cache;
        }
        if let Some(cache) = usage.cache_read_tokens {
            *self.cache_read_tokens.get_or_insert(0) += cache;
        }
        self.request_count += 1;
    }

    /// Record a tool call.
    pub fn record_tool_call(&mut self) {
        self.tool_call_count += 1;
    }
}

/// Usage limits for a run.
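///
/// # Example
///
/// Limits are built with the fluent setters below; a minimal sketch with
/// illustrative values:
///
/// ```ignore
/// let limits = UsageLimits::new()
///     .total_tokens(4_000)
///     .requests(10)
///     .time_seconds(60);
/// ```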
#[derive(Debug, Clone, Default)]
pub struct UsageLimits {
    /// Maximum request tokens.
    pub max_request_tokens: Option<u64>,
    /// Maximum response tokens.
    pub max_response_tokens: Option<u64>,
    /// Maximum total tokens.
    pub max_total_tokens: Option<u64>,
    /// Maximum number of requests.
    pub max_requests: Option<u32>,
    /// Maximum number of tool calls.
    pub max_tool_calls: Option<u32>,
    /// Maximum run time in seconds.
    pub max_time_seconds: Option<u64>,
}

impl UsageLimits {
    /// Create new empty limits.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set max request tokens.
    pub fn request_tokens(mut self, limit: u64) -> Self {
        self.max_request_tokens = Some(limit);
        self
    }

    /// Set max response tokens.
    pub fn response_tokens(mut self, limit: u64) -> Self {
        self.max_response_tokens = Some(limit);
        self
    }

    /// Set max total tokens.
    pub fn total_tokens(mut self, limit: u64) -> Self {
        self.max_total_tokens = Some(limit);
        self
    }

    /// Set max requests.
    pub fn requests(mut self, limit: u32) -> Self {
        self.max_requests = Some(limit);
        self
    }

    /// Set max tool calls.
    pub fn tool_calls(mut self, limit: u32) -> Self {
        self.max_tool_calls = Some(limit);
        self
    }

    /// Set max time in seconds.
    pub fn time_seconds(mut self, limit: u64) -> Self {
        self.max_time_seconds = Some(limit);
        self
    }

    /// Check usage against limits.
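    ///
    /// Returns the first exceeded limit as an error; unset limits are skipped.
    /// The time limit is checked separately via [`Self::check_time`].
    ///
    /// # Example
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let limits = UsageLimits::new().total_tokens(1_000);
    ///
    /// let mut usage = RunUsage::new();
    /// usage.total_tokens = 1_500;
    ///
    /// assert!(limits.check(&usage).is_err());
    /// ```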
    pub fn check(&self, usage: &RunUsage) -> Result<(), crate::errors::UsageLimitError> {
        use crate::errors::UsageLimitError;

        if let Some(limit) = self.max_request_tokens {
            if usage.request_tokens > limit {
                return Err(UsageLimitError::RequestTokens {
                    used: usage.request_tokens,
                    limit,
                });
            }
        }

        if let Some(limit) = self.max_response_tokens {
            if usage.response_tokens > limit {
                return Err(UsageLimitError::ResponseTokens {
                    used: usage.response_tokens,
                    limit,
                });
            }
        }

        if let Some(limit) = self.max_total_tokens {
            if usage.total_tokens > limit {
                return Err(UsageLimitError::TotalTokens {
                    used: usage.total_tokens,
                    limit,
                });
            }
        }

        if let Some(limit) = self.max_requests {
            if usage.request_count > limit {
                return Err(UsageLimitError::RequestCount {
                    count: usage.request_count,
                    limit,
                });
            }
        }

        if let Some(limit) = self.max_tool_calls {
            if usage.tool_call_count > limit {
                return Err(UsageLimitError::ToolCalls {
                    count: usage.tool_call_count,
                    limit,
                });
            }
        }

        Ok(())
    }

    /// Check time limit.
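    ///
    /// # Example
    ///
    /// A sketch pairing this with [`RunContext::elapsed_seconds`]:
    ///
    /// ```ignore
    /// let ctx = RunContext::new((), "gpt-4o");
    /// let limits = UsageLimits::new().time_seconds(60);
    ///
    /// let elapsed = ctx.elapsed_seconds().max(0) as u64;
    /// assert!(limits.check_time(elapsed).is_ok());
    /// ```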
    pub fn check_time(&self, elapsed_seconds: u64) -> Result<(), crate::errors::UsageLimitError> {
        if let Some(limit) = self.max_time_seconds {
            if elapsed_seconds > limit {
                return Err(crate::errors::UsageLimitError::TimeLimit {
                    elapsed_seconds,
                    limit_seconds: limit,
                });
            }
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_run_context_new() {
        let ctx = RunContext::new((), "gpt-4o");
        assert_eq!(ctx.model_name, "gpt-4o");
        assert!(!ctx.run_id.is_empty());
    }

    #[test]
    fn test_run_context_metadata() {
        let mut ctx = RunContext::new((), "gpt-4o");
        ctx.set_metadata("user_id", "12345");

        let user_id: Option<String> = ctx.get_metadata("user_id");
        assert_eq!(user_id, Some("12345".to_string()));
    }

    #[test]
    fn test_run_context_for_tool() {
        let ctx = RunContext::new((), "gpt-4o");
        let tool_ctx = ctx.for_tool("search", Some("call-123".to_string()));

        assert_eq!(tool_ctx.tool_name, Some("search".to_string()));
        assert_eq!(tool_ctx.tool_call_id, Some("call-123".to_string()));
        assert!(tool_ctx.in_tool());
    }

    #[test]
    fn test_run_usage() {
        let mut usage = RunUsage::new();
        usage.add_request(serdes_ai_core::RequestUsage {
            request_tokens: Some(100),
            response_tokens: Some(50),
            total_tokens: Some(150),
            cache_creation_tokens: None,
            cache_read_tokens: None,
            details: None,
        });

        assert_eq!(usage.request_tokens, 100);
        assert_eq!(usage.response_tokens, 50);
        assert_eq!(usage.request_count, 1);
    }

    #[test]
    fn test_usage_limits() {
        let limits = UsageLimits::new().total_tokens(1000).requests(10);

        let mut usage = RunUsage::new();
        usage.total_tokens = 500;
        usage.request_count = 5;

        assert!(limits.check(&usage).is_ok());

        usage.total_tokens = 1500;
        assert!(limits.check(&usage).is_err());
    }
}