use chrono::{DateTime, Utc};
use serdes_ai_core::{identifier::generate_run_id, ModelSettings, RunUsage};
use std::sync::Arc;
/// State carried through a single run: the shared dependencies plus run-level
/// bookkeeping (identifier, start time, retry counters, tool-call info, model
/// configuration, usage and metadata).
///
/// `Deps` is stored behind an [`Arc`], so cloning a context never clones the
/// dependencies themselves. `Clone` is implemented by hand (rather than derived)
/// so that it does not require `Deps: Clone`.
#[derive(Debug)]
pub struct RunContext<Deps = ()> {
    /// Shared, reference-counted dependencies.
    pub deps: Arc<Deps>,
    /// Identifier of the run this context belongs to.
    pub run_id: String,
    /// Wall-clock (UTC) timestamp used as the origin for elapsed-time queries.
    pub start_time: DateTime<Utc>,
    /// Number of retries recorded so far.
    pub retry_count: u32,
    /// Upper bound that `retry_count` is compared against.
    pub max_retries: u32,
    /// Name of the tool being invoked; `None` when not inside a tool call.
    pub tool_name: Option<String>,
    /// Identifier of the current tool call, if one was supplied.
    pub tool_call_id: Option<String>,
    /// Name of the model associated with the run.
    pub model_name: String,
    /// Model settings attached to the run.
    pub model_settings: ModelSettings,
    /// Usage figures attached to the run.
    pub usage: RunUsage,
    /// Optional free-form JSON metadata.
    pub metadata: Option<serde_json::Value>,
    /// Partial-output flag.
    // NOTE(review): presumably marks output that is still being streamed — confirm with callers.
    pub partial_output: bool,
}

impl<Deps> Clone for RunContext<Deps> {
    /// Clones every field; `deps` is cloned by bumping the `Arc` reference
    /// count, so no `Deps: Clone` bound is needed.
    fn clone(&self) -> Self {
        Self {
            deps: Arc::clone(&self.deps),
            run_id: self.run_id.clone(),
            start_time: self.start_time,
            retry_count: self.retry_count,
            max_retries: self.max_retries,
            tool_name: self.tool_name.clone(),
            tool_call_id: self.tool_call_id.clone(),
            model_name: self.model_name.clone(),
            model_settings: self.model_settings.clone(),
            usage: self.usage.clone(),
            metadata: self.metadata.clone(),
            partial_output: self.partial_output,
        }
    }
}
impl<Deps> RunContext<Deps> {
    /// Retry budget assigned by [`RunContext::new`] and [`RunContext::from_arc`].
    const DEFAULT_MAX_RETRIES: u32 = 3;

    /// Creates a context that takes ownership of `deps` (wrapping it in a new
    /// [`Arc`]) for the model called `model_name`.
    ///
    /// All other fields receive the same defaults as [`RunContext::from_arc`].
    #[must_use]
    pub fn new(deps: Deps, model_name: impl Into<String>) -> Self {
        Self::from_arc(Arc::new(deps), model_name)
    }

    /// Creates a context around already-shared dependencies.
    ///
    /// The run id comes from `generate_run_id()`, the start time is the current
    /// UTC time, the retry count is `0`, the retry budget is `3`, no tool
    /// context or metadata is set, and model settings / usage are their
    /// `Default` values.
    #[must_use]
    pub fn from_arc(deps: Arc<Deps>, model_name: impl Into<String>) -> Self {
        Self {
            deps,
            run_id: generate_run_id(),
            start_time: Utc::now(),
            retry_count: 0,
            max_retries: Self::DEFAULT_MAX_RETRIES,
            tool_name: None,
            tool_call_id: None,
            model_name: model_name.into(),
            model_settings: ModelSettings::default(),
            usage: RunUsage::default(),
            metadata: None,
            partial_output: false,
        }
    }

    /// Replaces the generated run id with `run_id`.
    #[must_use]
    pub fn with_run_id(mut self, run_id: impl Into<String>) -> Self {
        self.run_id = run_id.into();
        self
    }

    /// Sets the retry budget checked by [`RunContext::can_retry`].
    #[must_use]
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Replaces the model settings.
    #[must_use]
    pub fn with_model_settings(mut self, settings: ModelSettings) -> Self {
        self.model_settings = settings;
        self
    }

    /// Attaches JSON metadata, overwriting any previous value.
    #[must_use]
    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
        self.metadata = Some(metadata);
        self
    }

    /// Marks this context as being inside a call to `tool_name`.
    ///
    /// `tool_call_id` is stored as given, so passing `None` clears any
    /// previously stored id.
    #[must_use]
    pub fn with_tool_context(
        mut self,
        tool_name: impl Into<String>,
        tool_call_id: Option<String>,
    ) -> Self {
        self.tool_name = Some(tool_name.into());
        self.tool_call_id = tool_call_id;
        self
    }

    /// Sets the partial-output flag.
    #[must_use]
    pub fn with_partial_output(mut self, partial: bool) -> Self {
        self.partial_output = partial;
        self
    }

    /// Records one more retry.
    ///
    /// The counter saturates at `u32::MAX` instead of overflowing; a wrapped
    /// counter would fall back below `max_retries` and make
    /// [`RunContext::can_retry`] report `true` again.
    pub fn increment_retry(&mut self) {
        self.retry_count = self.retry_count.saturating_add(1);
    }

    /// Returns `true` while fewer than `max_retries` retries have been recorded.
    #[must_use]
    pub fn can_retry(&self) -> bool {
        self.retry_count < self.max_retries
    }

    /// Time between `start_time` and the current UTC time.
    ///
    /// Both ends are wall-clock readings, so the result is signed.
    #[must_use]
    pub fn elapsed(&self) -> chrono::Duration {
        Utc::now() - self.start_time
    }

    /// [`RunContext::elapsed`] expressed in seconds, derived from the whole
    /// number of elapsed milliseconds.
    #[must_use]
    pub fn elapsed_secs(&self) -> f64 {
        self.elapsed().num_milliseconds() as f64 / 1000.0
    }

    /// Returns `true` when a tool name is set.
    #[must_use]
    pub fn in_tool_call(&self) -> bool {
        self.tool_name.is_some()
    }

    /// Derives a context for invoking `tool_name` without consuming `self`.
    ///
    /// The new context shares `deps` (the `Arc` is cloned, not the
    /// dependencies), copies the run id, start time, retry budget, model name,
    /// model settings, usage, metadata and partial-output flag, and starts with
    /// a retry count of `0`.
    #[must_use]
    pub fn for_tool(&self, tool_name: impl Into<String>, tool_call_id: Option<String>) -> Self {
        Self {
            deps: Arc::clone(&self.deps),
            run_id: self.run_id.clone(),
            start_time: self.start_time,
            // Each tool invocation gets a fresh retry counter.
            retry_count: 0,
            max_retries: self.max_retries,
            tool_name: Some(tool_name.into()),
            tool_call_id,
            model_name: self.model_name.clone(),
            model_settings: self.model_settings.clone(),
            usage: self.usage.clone(),
            metadata: self.metadata.clone(),
            partial_output: self.partial_output,
        }
    }

    /// Replaces the usage figures.
    #[must_use]
    pub fn with_usage(mut self, usage: RunUsage) -> Self {
        self.usage = usage;
        self
    }

    /// Converts this context into one over a different dependency type.
    ///
    /// `new_deps` is wrapped in a fresh [`Arc`]; every other field, including
    /// the retry count and any tool context, is moved over unchanged.
    #[must_use]
    pub fn with_deps<NewDeps>(self, new_deps: NewDeps) -> RunContext<NewDeps> {
        RunContext {
            deps: Arc::new(new_deps),
            run_id: self.run_id,
            start_time: self.start_time,
            retry_count: self.retry_count,
            max_retries: self.max_retries,
            tool_name: self.tool_name,
            tool_call_id: self.tool_call_id,
            model_name: self.model_name,
            model_settings: self.model_settings,
            usage: self.usage,
            metadata: self.metadata,
            partial_output: self.partial_output,
        }
    }
}
impl<Deps: Default> Default for RunContext<Deps> {
fn default() -> Self {
Self::new(Deps::default(), "default")
}
}
impl RunContext<()> {
#[must_use]
pub fn minimal(model_name: impl Into<String>) -> Self {
Self::new((), model_name)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small dependency type used to observe which `deps` a context carries.
    #[derive(Debug, Clone, Default)]
    struct TestDeps {
        value: i32,
    }

    #[test]
    fn test_run_context_new() {
        let ctx = RunContext::new(TestDeps { value: 42 }, "gpt-4");
        assert_eq!(ctx.deps.value, 42);
        assert_eq!(ctx.model_name, "gpt-4");
        assert!(ctx.run_id.starts_with("run_"));
        assert_eq!(ctx.retry_count, 0);
        // Remaining constructor defaults.
        assert_eq!(ctx.max_retries, 3);
        assert!(ctx.tool_name.is_none());
        assert!(ctx.tool_call_id.is_none());
        assert!(ctx.metadata.is_none());
        assert!(!ctx.partial_output);
    }

    #[test]
    fn test_from_arc_shares_deps() {
        let shared = Arc::new(TestDeps { value: 7 });
        let ctx = RunContext::from_arc(Arc::clone(&shared), "gpt-4");
        // The context must hold the very same allocation, not a copy.
        assert!(Arc::ptr_eq(&ctx.deps, &shared));
        assert_eq!(ctx.model_name, "gpt-4");
        assert_eq!(ctx.retry_count, 0);
        assert_eq!(ctx.max_retries, 3);
    }

    #[test]
    fn test_run_context_minimal() {
        let ctx = RunContext::minimal("claude-3");
        assert_eq!(ctx.model_name, "claude-3");
        assert!(!ctx.in_tool_call());
    }

    #[test]
    fn test_builder_setters() {
        let ctx = RunContext::minimal("gpt-4")
            .with_run_id("run_custom")
            .with_max_retries(5)
            .with_metadata(serde_json::json!({ "k": "v" }))
            .with_partial_output(true);
        assert_eq!(ctx.run_id, "run_custom");
        assert_eq!(ctx.max_retries, 5);
        assert_eq!(ctx.metadata, Some(serde_json::json!({ "k": "v" })));
        assert!(ctx.partial_output);
    }

    #[test]
    fn test_run_context_with_tool_context() {
        let ctx =
            RunContext::minimal("gpt-4").with_tool_context("my_tool", Some("call_123".to_string()));
        assert_eq!(ctx.tool_name, Some("my_tool".to_string()));
        assert_eq!(ctx.tool_call_id, Some("call_123".to_string()));
        assert!(ctx.in_tool_call());
    }

    #[test]
    fn test_increment_retry() {
        let mut ctx = RunContext::minimal("gpt-4").with_max_retries(3);
        assert!(ctx.can_retry());
        ctx.increment_retry();
        ctx.increment_retry();
        // One attempt is still left after two retries out of three.
        assert_eq!(ctx.retry_count, 2);
        assert!(ctx.can_retry());
        ctx.increment_retry();
        assert_eq!(ctx.retry_count, 3);
        assert!(!ctx.can_retry());
    }

    #[test]
    fn test_for_tool() {
        let mut ctx = RunContext::new(TestDeps { value: 10 }, "gpt-4");
        // Start from a non-zero count so the reset is actually observable.
        ctx.increment_retry();
        let tool_ctx = ctx.for_tool("test_tool", Some("id1".to_string()));
        assert_eq!(tool_ctx.deps.value, 10);
        assert!(Arc::ptr_eq(&tool_ctx.deps, &ctx.deps));
        assert_eq!(tool_ctx.run_id, ctx.run_id);
        assert_eq!(tool_ctx.start_time, ctx.start_time);
        assert_eq!(tool_ctx.max_retries, ctx.max_retries);
        assert_eq!(tool_ctx.tool_name, Some("test_tool".to_string()));
        assert_eq!(tool_ctx.tool_call_id, Some("id1".to_string()));
        assert_eq!(tool_ctx.retry_count, 0);
        // The parent context is left untouched.
        assert_eq!(ctx.retry_count, 1);
        assert!(!ctx.in_tool_call());
    }

    #[test]
    fn test_with_deps() {
        let ctx = RunContext::minimal("gpt-4")
            .with_max_retries(7)
            .with_tool_context("t", None);
        let run_id = ctx.run_id.clone();
        let start_time = ctx.start_time;
        let new_ctx = ctx.with_deps(TestDeps { value: 99 });
        assert_eq!(new_ctx.deps.value, 99);
        // Everything except the dependencies is carried over.
        assert_eq!(new_ctx.run_id, run_id);
        assert_eq!(new_ctx.start_time, start_time);
        assert_eq!(new_ctx.max_retries, 7);
        assert_eq!(new_ctx.model_name, "gpt-4");
        assert_eq!(new_ctx.tool_name.as_deref(), Some("t"));
        assert!(new_ctx.tool_call_id.is_none());
    }

    #[test]
    fn test_elapsed() {
        let ctx = RunContext::minimal("gpt-4");
        // NOTE(review): relies on the wall clock advancing during the sleep;
        // a clock adjustment mid-test could make this fail.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let elapsed = ctx.elapsed_secs();
        assert!(elapsed >= 0.01);
    }

    #[test]
    fn test_default() {
        let ctx: RunContext<TestDeps> = RunContext::default();
        assert_eq!(ctx.deps.value, 0);
        assert_eq!(ctx.model_name, "default");
        assert_eq!(ctx.retry_count, 0);
        assert_eq!(ctx.max_retries, 3);
    }
}