use chrono::{DateTime, Utc};
use serde_json::Value as JsonValue;
use serdes_ai_core::ModelSettings;
use std::sync::Arc;

/// Context available to tools and handlers for the duration of a single run.
///
/// Dependencies are stored behind an [`Arc`] so the context can be cloned
/// cheaply when entering tools or retrying requests.
#[derive(Debug)]
pub struct RunContext<Deps> {
    /// User-supplied dependencies shared across all clones of this context.
    pub deps: Arc<Deps>,
    /// Unique identifier for this run.
    pub run_id: String,
    /// Time at which the run started.
    pub start_time: DateTime<Utc>,
    /// Name of the model handling the run.
    pub model_name: String,
    /// Settings applied to model requests for this run.
    pub model_settings: ModelSettings,
    /// Name of the tool currently executing, if any.
    pub tool_name: Option<String>,
    /// Identifier of the current tool call, if any.
    pub tool_call_id: Option<String>,
    /// Number of retries performed so far.
    pub retry_count: u32,
    /// Arbitrary JSON metadata attached to the run.
    pub metadata: Option<JsonValue>,
}

impl<Deps> RunContext<Deps> {
    /// Creates a new context, taking ownership of `deps` and wrapping it in an `Arc`.
    pub fn new(deps: Deps, model_name: impl Into<String>) -> Self {
        Self {
            deps: Arc::new(deps),
            run_id: generate_run_id(),
            start_time: Utc::now(),
            model_name: model_name.into(),
            model_settings: ModelSettings::default(),
            tool_name: None,
            tool_call_id: None,
            retry_count: 0,
            metadata: None,
        }
    }

    /// Creates a new context from dependencies that are already shared.
    pub fn with_shared_deps(deps: Arc<Deps>, model_name: impl Into<String>) -> Self {
        Self {
            deps,
            run_id: generate_run_id(),
            start_time: Utc::now(),
            model_name: model_name.into(),
            model_settings: ModelSettings::default(),
            tool_name: None,
            tool_call_id: None,
            retry_count: 0,
            metadata: None,
        }
    }

    /// Returns a reference to the shared dependencies.
    pub fn deps(&self) -> &Deps {
        &self.deps
    }

    /// Returns the time elapsed since the run started.
    pub fn elapsed(&self) -> chrono::Duration {
        Utc::now() - self.start_time
    }

    /// Returns the elapsed time in whole seconds.
    pub fn elapsed_seconds(&self) -> i64 {
        self.elapsed().num_seconds()
    }

    /// Returns `true` if this context represents a retry attempt.
    pub fn is_retry(&self) -> bool {
        self.retry_count > 0
    }

    /// Returns `true` if this context is scoped to a tool call.
    pub fn in_tool(&self) -> bool {
        self.tool_name.is_some()
    }

    /// Inserts or replaces a metadata entry under `key`.
    ///
    /// Values that fail to serialize are silently ignored.
    pub fn set_metadata(&mut self, key: &str, value: impl serde::Serialize) {
        let meta = self
            .metadata
            .get_or_insert_with(|| JsonValue::Object(Default::default()));
        if let JsonValue::Object(ref mut map) = meta {
            if let Ok(v) = serde_json::to_value(value) {
                map.insert(key.to_string(), v);
            }
        }
    }

    /// Retrieves a metadata entry, returning `None` if the key is missing or
    /// the stored value cannot be deserialized into `T`.
    pub fn get_metadata<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
        self.metadata
            .as_ref()
            .and_then(|m| m.get(key))
            .and_then(|v| serde_json::from_value(v.clone()).ok())
    }

    /// Derives a context scoped to a tool call, sharing the same run id,
    /// dependencies, and metadata.
    pub fn for_tool(&self, tool_name: impl Into<String>, tool_call_id: Option<String>) -> Self {
        Self {
            deps: self.deps.clone(),
            run_id: self.run_id.clone(),
            start_time: self.start_time,
            model_name: self.model_name.clone(),
            model_settings: self.model_settings.clone(),
            tool_name: Some(tool_name.into()),
            tool_call_id,
            retry_count: 0,
            metadata: self.metadata.clone(),
        }
    }

    /// Derives a context for a retry attempt, incrementing the retry count.
    pub fn for_retry(&self) -> Self {
        Self {
            deps: self.deps.clone(),
            run_id: self.run_id.clone(),
            start_time: self.start_time,
            model_name: self.model_name.clone(),
            model_settings: self.model_settings.clone(),
            tool_name: self.tool_name.clone(),
            tool_call_id: self.tool_call_id.clone(),
            retry_count: self.retry_count + 1,
            metadata: self.metadata.clone(),
        }
    }
}

impl<Deps: Default> Default for RunContext<Deps> {
    fn default() -> Self {
        Self::new(Deps::default(), "unknown")
    }
}

// Manual `Clone` implementation: deriving would add a `Deps: Clone` bound,
// which is unnecessary since the dependencies are shared behind an `Arc`.
impl<Deps> Clone for RunContext<Deps> {
    fn clone(&self) -> Self {
        Self {
            deps: self.deps.clone(),
            run_id: self.run_id.clone(),
            start_time: self.start_time,
            model_name: self.model_name.clone(),
            model_settings: self.model_settings.clone(),
            tool_name: self.tool_name.clone(),
            tool_call_id: self.tool_call_id.clone(),
            retry_count: self.retry_count,
            metadata: self.metadata.clone(),
        }
    }
}

/// Generates a unique identifier for a run (a random UUID v4).
pub fn generate_run_id() -> String {
    uuid::Uuid::new_v4().to_string()
}

/// Accumulated token and request usage for a run.
#[derive(Debug, Clone, Default)]
pub struct RunUsage {
    /// Total tokens sent in requests.
    pub request_tokens: u64,
    /// Total tokens received in responses.
    pub response_tokens: u64,
    /// Combined request and response tokens.
    pub total_tokens: u64,
    /// Number of model requests made.
    pub request_count: u32,
    /// Number of tool calls executed.
    pub tool_call_count: u32,
    /// Tokens spent creating prompt cache entries, if reported by the provider.
    pub cache_creation_tokens: Option<u64>,
    /// Tokens read from the prompt cache, if reported by the provider.
    pub cache_read_tokens: Option<u64>,
}

impl RunUsage {
    /// Creates an empty usage tracker.
    pub fn new() -> Self {
        Self::default()
    }

    /// Accumulates the usage reported for a single model request.
    ///
    /// If the provider does not report a total, it is recomputed from the
    /// accumulated request and response tokens.
    pub fn add_request(&mut self, usage: serdes_ai_core::RequestUsage) {
        if let Some(req) = usage.request_tokens {
            self.request_tokens += req;
        }
        if let Some(resp) = usage.response_tokens {
            self.response_tokens += resp;
        }
        if let Some(total) = usage.total_tokens {
            self.total_tokens += total;
        } else {
            self.total_tokens = self.request_tokens + self.response_tokens;
        }
        if let Some(cache) = usage.cache_creation_tokens {
            *self.cache_creation_tokens.get_or_insert(0) += cache;
        }
        if let Some(cache) = usage.cache_read_tokens {
            *self.cache_read_tokens.get_or_insert(0) += cache;
        }
        self.request_count += 1;
    }

    /// Records a single tool call.
    pub fn record_tool_call(&mut self) {
        self.tool_call_count += 1;
    }
}

/// Limits on resource usage for a run. All limits are optional.
#[derive(Debug, Clone, Default)]
pub struct UsageLimits {
    /// Maximum request tokens allowed across the run.
    pub max_request_tokens: Option<u64>,
    /// Maximum response tokens allowed across the run.
    pub max_response_tokens: Option<u64>,
    /// Maximum total tokens allowed across the run.
    pub max_total_tokens: Option<u64>,
    /// Maximum number of model requests.
    pub max_requests: Option<u32>,
    /// Maximum number of tool calls.
    pub max_tool_calls: Option<u32>,
    /// Maximum wall-clock duration of the run, in seconds.
    pub max_time_seconds: Option<u64>,
}

impl UsageLimits {
    /// Creates a set of limits with nothing restricted.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum number of request tokens.
    pub fn request_tokens(mut self, limit: u64) -> Self {
        self.max_request_tokens = Some(limit);
        self
    }

    /// Sets the maximum number of response tokens.
    pub fn response_tokens(mut self, limit: u64) -> Self {
        self.max_response_tokens = Some(limit);
        self
    }

    /// Sets the maximum number of total tokens.
    pub fn total_tokens(mut self, limit: u64) -> Self {
        self.max_total_tokens = Some(limit);
        self
    }

    /// Sets the maximum number of model requests.
    pub fn requests(mut self, limit: u32) -> Self {
        self.max_requests = Some(limit);
        self
    }

    /// Sets the maximum number of tool calls.
    pub fn tool_calls(mut self, limit: u32) -> Self {
        self.max_tool_calls = Some(limit);
        self
    }

    /// Sets the maximum run duration in seconds.
    pub fn time_seconds(mut self, limit: u64) -> Self {
        self.max_time_seconds = Some(limit);
        self
    }

    /// Checks accumulated usage against the configured limits, returning an
    /// error for the first limit that has been exceeded.
    pub fn check(&self, usage: &RunUsage) -> Result<(), crate::errors::UsageLimitError> {
        use crate::errors::UsageLimitError;

        if let Some(limit) = self.max_request_tokens {
            if usage.request_tokens > limit {
                return Err(UsageLimitError::RequestTokens {
                    used: usage.request_tokens,
                    limit,
                });
            }
        }

        if let Some(limit) = self.max_response_tokens {
            if usage.response_tokens > limit {
                return Err(UsageLimitError::ResponseTokens {
                    used: usage.response_tokens,
                    limit,
                });
            }
        }

        if let Some(limit) = self.max_total_tokens {
            if usage.total_tokens > limit {
                return Err(UsageLimitError::TotalTokens {
                    used: usage.total_tokens,
                    limit,
                });
            }
        }

        if let Some(limit) = self.max_requests {
            if usage.request_count > limit {
                return Err(UsageLimitError::RequestCount {
                    count: usage.request_count,
                    limit,
                });
            }
        }

        if let Some(limit) = self.max_tool_calls {
            if usage.tool_call_count > limit {
                return Err(UsageLimitError::ToolCalls {
                    count: usage.tool_call_count,
                    limit,
                });
            }
        }

        Ok(())
    }

    /// Checks elapsed wall-clock time against the configured time limit.
    pub fn check_time(&self, elapsed_seconds: u64) -> Result<(), crate::errors::UsageLimitError> {
        if let Some(limit) = self.max_time_seconds {
            if elapsed_seconds > limit {
                return Err(crate::errors::UsageLimitError::TimeLimit {
                    elapsed_seconds,
                    limit_seconds: limit,
                });
            }
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_run_context_new() {
        let ctx = RunContext::new((), "gpt-4o");
        assert_eq!(ctx.model_name, "gpt-4o");
        assert!(!ctx.run_id.is_empty());
    }
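
    // Additional usage sketch (not in the original tests): `with_shared_deps`
    // reuses an existing `Arc`, so cloning the context shares the same
    // dependency allocation rather than copying it.
    #[test]
    fn test_run_context_shared_deps() {
        use std::sync::Arc;

        let deps = Arc::new(42u32);
        let ctx = RunContext::with_shared_deps(deps.clone(), "gpt-4o");
        assert_eq!(*ctx.deps(), 42);

        let cloned = ctx.clone();
        // The local handle, `ctx`, and `cloned` all point at the same allocation.
        assert_eq!(Arc::strong_count(&deps), 3);
        assert_eq!(cloned.run_id, ctx.run_id);
    }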

    #[test]
    fn test_run_context_metadata() {
        let mut ctx = RunContext::new((), "gpt-4o");
        ctx.set_metadata("user_id", "12345");

        let user_id: Option<String> = ctx.get_metadata("user_id");
        assert_eq!(user_id, Some("12345".to_string()));
    }
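
    // Additional sketch: `get_metadata` returns `None` when the stored value
    // does not deserialize into the requested type, and `set_metadata`
    // overwrites existing keys.
    #[test]
    fn test_run_context_metadata_type_mismatch() {
        let mut ctx = RunContext::new((), "gpt-4o");
        ctx.set_metadata("attempt", 1u32);

        // Stored as a JSON number, so asking for a String yields None.
        let as_string: Option<String> = ctx.get_metadata("attempt");
        assert_eq!(as_string, None);

        // Setting the same key again replaces the previous value.
        ctx.set_metadata("attempt", 2u32);
        assert_eq!(ctx.get_metadata::<u32>("attempt"), Some(2));
    }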

    #[test]
    fn test_run_context_for_tool() {
        let ctx = RunContext::new((), "gpt-4o");
        let tool_ctx = ctx.for_tool("search", Some("call-123".to_string()));

        assert_eq!(tool_ctx.tool_name, Some("search".to_string()));
        assert_eq!(tool_ctx.tool_call_id, Some("call-123".to_string()));
        assert!(tool_ctx.in_tool());
    }
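
    // Additional sketch: `for_retry` keeps the run identity but bumps the
    // retry counter.
    #[test]
    fn test_run_context_for_retry() {
        let ctx = RunContext::new((), "gpt-4o");
        assert!(!ctx.is_retry());

        let retry_ctx = ctx.for_retry();
        assert_eq!(retry_ctx.retry_count, 1);
        assert!(retry_ctx.is_retry());
        assert_eq!(retry_ctx.run_id, ctx.run_id);
    }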

    #[test]
    fn test_run_usage() {
        let mut usage = RunUsage::new();
        usage.add_request(serdes_ai_core::RequestUsage {
            request_tokens: Some(100),
            response_tokens: Some(50),
            total_tokens: Some(150),
            cache_creation_tokens: None,
            cache_read_tokens: None,
            details: None,
        });

        assert_eq!(usage.request_tokens, 100);
        assert_eq!(usage.response_tokens, 50);
        assert_eq!(usage.request_count, 1);
    }
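
    // Additional sketch, mirroring the `RequestUsage` construction above: when
    // a provider omits `total_tokens`, `add_request` recomputes the total from
    // the accumulated request and response tokens, and cache token counts
    // accumulate once they are reported.
    #[test]
    fn test_run_usage_total_fallback() {
        let mut usage = RunUsage::new();
        usage.add_request(serdes_ai_core::RequestUsage {
            request_tokens: Some(100),
            response_tokens: Some(40),
            total_tokens: None,
            cache_creation_tokens: Some(10),
            cache_read_tokens: None,
            details: None,
        });

        assert_eq!(usage.total_tokens, 140);
        assert_eq!(usage.cache_creation_tokens, Some(10));
        assert_eq!(usage.cache_read_tokens, None);
        assert_eq!(usage.request_count, 1);

        usage.record_tool_call();
        assert_eq!(usage.tool_call_count, 1);
    }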

    #[test]
    fn test_usage_limits() {
        let limits = UsageLimits::new().total_tokens(1000).requests(10);

        let mut usage = RunUsage::new();
        usage.total_tokens = 500;
        usage.request_count = 5;

        assert!(limits.check(&usage).is_ok());

        usage.total_tokens = 1500;
        assert!(limits.check(&usage).is_err());
    }
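
    // Additional sketch: limits on tool calls and wall-clock time are enforced
    // by `check` and `check_time` respectively; both are strictly-greater-than
    // comparisons, so hitting the limit exactly is still allowed.
    #[test]
    fn test_usage_limits_tool_calls_and_time() {
        let limits = UsageLimits::new().tool_calls(2).time_seconds(30);

        let mut usage = RunUsage::new();
        usage.record_tool_call();
        usage.record_tool_call();
        assert!(limits.check(&usage).is_ok());

        // A third tool call exceeds the limit of 2.
        usage.record_tool_call();
        assert!(limits.check(&usage).is_err());

        // Elapsed time is only rejected once it exceeds the limit.
        assert!(limits.check_time(30).is_ok());
        assert!(limits.check_time(31).is_err());
    }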
}