use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::llm::recording::HttpInterceptor;
/// Error returned by [`JobContext::add_tokens`] when a job's running token total
/// goes past its configured, non-zero `max_tokens` limit.
///
/// Displays as `Token budget exceeded: used {used} of {limit} allowed tokens`.
#[derive(Debug, thiserror::Error)]
#[error("Token budget exceeded: used {used} of {limit} allowed tokens")]
pub struct TokenBudgetExceeded {
/// Total tokens recorded, including the call that crossed the limit.
pub used: u64,
/// The job's `max_tokens` value that was exceeded.
pub limit: u64,
}
/// Lifecycle state of a job.
///
/// Permitted moves between states are defined by [`JobState::can_transition_to`].
/// With serde, each variant is (de)serialized as its snake_case name, e.g.
/// `InProgress` <-> `"in_progress"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum JobState {
/// Initial state of every newly constructed [`JobContext`].
Pending,
/// Work is underway; the first entry records `JobContext::started_at`.
InProgress,
/// Work finished; may be re-entered idempotently, then moved to `Submitted` or `Failed`.
Completed,
/// Reachable only from `Completed`; moves on to `Accepted` or `Failed`.
Submitted,
/// Terminal state, reachable only from `Submitted`.
Accepted,
/// Terminal state.
Failed,
/// Work stalled; may return to `InProgress` (see `JobContext::attempt_recovery`),
/// or move to `Failed` or `Cancelled`.
Stuck,
/// Terminal state.
Cancelled,
}
impl JobState {
    /// Reports whether the state machine allows moving from `self` to `target`.
    ///
    /// Each arm below lists the permitted successors of one source state.
    /// `Completed -> Completed` is the only allowed self-transition, so duplicate
    /// completion signals can be absorbed. `Accepted`, `Failed` and `Cancelled`
    /// have no successors at all.
    pub fn can_transition_to(&self, target: JobState) -> bool {
        use JobState::*;
        match *self {
            Pending => matches!(target, InProgress | Cancelled),
            InProgress => matches!(target, Completed | Failed | Stuck | Cancelled),
            Completed => matches!(target, Completed | Submitted | Failed),
            Submitted => matches!(target, Accepted | Failed),
            Stuck => matches!(target, InProgress | Failed | Cancelled),
            Accepted | Failed | Cancelled => false,
        }
    }

    /// Returns `true` for the states with no outgoing transitions:
    /// `Accepted`, `Failed` and `Cancelled`.
    pub fn is_terminal(&self) -> bool {
        match self {
            Self::Accepted | Self::Failed | Self::Cancelled => true,
            Self::Pending
            | Self::InProgress
            | Self::Completed
            | Self::Submitted
            | Self::Stuck => false,
        }
    }

    /// Returns `true` for every state that is not terminal.
    pub fn is_active(&self) -> bool {
        !self.is_terminal()
    }

    /// Returns `true` for `Pending`, `InProgress` and `Stuck`.
    ///
    /// NOTE(review): presumably these are the states in which a job should hold
    /// up parallel work; confirm against callers.
    pub fn is_parallel_blocking(&self) -> bool {
        match self {
            Self::Pending | Self::InProgress | Self::Stuck => true,
            Self::Completed
            | Self::Submitted
            | Self::Accepted
            | Self::Failed
            | Self::Cancelled => false,
        }
    }
}
impl std::fmt::Display for JobState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
Self::Pending => "pending",
Self::InProgress => "in_progress",
Self::Completed => "completed",
Self::Submitted => "submitted",
Self::Accepted => "accepted",
Self::Failed => "failed",
Self::Stuck => "stuck",
Self::Cancelled => "cancelled",
};
write!(f, "{}", s)
}
}
/// One entry in a job's state-change history, appended by
/// [`JobContext::transition_to`] for every non-idempotent transition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateTransition {
/// State the job was in before the change.
pub from: JobState,
/// State the job moved to.
pub to: JobState,
/// When the change was recorded (`Utc::now()` at the time of the call).
pub timestamp: DateTime<Utc>,
/// Optional free-form explanation supplied by the caller.
pub reason: Option<String>,
}
/// State, history and resource accounting for a single job.
///
/// Implements `Serialize`. The fields marked `#[serde(skip)]` (`extra_env`,
/// `http_interceptor`, `tool_output_stash`) are omitted from serialized output.
/// `Clone` copies the `Arc` handles, so clones share those three fields.
#[derive(Debug, Clone, Serialize)]
pub struct JobContext {
/// Random (UUID v4) identifier assigned at construction.
pub job_id: Uuid,
/// Current lifecycle state. Change it through `transition_to` so history and
/// timestamps stay consistent.
pub state: JobState,
/// Owning user; `"default"` when built with `JobContext::new`.
pub user_id: String,
/// Optional requester id, set via `with_requester_id`; omitted from serialized
/// output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub requester_id: Option<String>,
/// Optional conversation id; `None` by default.
pub conversation_id: Option<Uuid>,
/// Title supplied at construction.
pub title: String,
/// Description supplied at construction.
pub description: String,
/// Optional category label; `None` by default.
pub category: Option<String>,
/// Optional cost ceiling that `budget_exceeded` compares against `actual_cost`.
pub budget: Option<Decimal>,
/// NOTE(review): not read in this file; purpose not visible here. `None` by default.
pub budget_token: Option<String>,
/// Not read in this file; `None` by default.
pub bid_amount: Option<Decimal>,
/// Not read in this file; `None` by default.
pub estimated_cost: Option<Decimal>,
/// Not read in this file; `None` by default.
pub estimated_duration: Option<Duration>,
/// Accumulated cost, increased by `add_cost`; starts at zero.
pub actual_cost: Decimal,
/// Running token total maintained by `add_tokens`.
pub total_tokens_used: u64,
/// Token limit enforced by `add_tokens`; `0` means no limit.
pub max_tokens: u64,
/// Construction time.
pub created_at: DateTime<Utc>,
/// Set on the first transition into `InProgress`; later recoveries do not reset it.
pub started_at: Option<DateTime<Utc>>,
/// Set on each transition into `Completed`, `Accepted`, `Failed` or `Cancelled`;
/// a later such transition overwrites an earlier value.
pub completed_at: Option<DateTime<Utc>>,
/// Number of successful recoveries performed by `attempt_recovery`.
pub repair_attempts: u32,
/// State-change history, trimmed to the most recent 200 entries by `transition_to`.
pub transitions: Vec<StateTransition>,
/// Free-form JSON; `Null` by default.
pub metadata: serde_json::Value,
/// Runtime-only string map, not serialized.
/// NOTE(review): name suggests extra environment variables for the job — confirm.
#[serde(skip)]
pub extra_env: Arc<HashMap<String, String>>,
/// Optional runtime-only HTTP interceptor, not serialized; `None` by default.
#[serde(skip)]
pub http_interceptor: Option<Arc<dyn HttpInterceptor>>,
/// Shared string map behind an async (tokio) `RwLock`, not serialized.
/// NOTE(review): presumably holds tool outputs keyed by some tool/call id — confirm.
#[serde(skip)]
pub tool_output_stash: Arc<tokio::sync::RwLock<HashMap<String, String>>>,
/// Timezone string for the user; `"UTC"` by default (see `with_timezone`).
pub user_timezone: String,
}
impl JobContext {
    /// Creates a pending job owned by the `"default"` user.
    ///
    /// Shorthand for [`JobContext::with_user`] with `user_id = "default"`.
    pub fn new(title: impl Into<String>, description: impl Into<String>) -> Self {
        Self::with_user("default", title, description)
    }

    /// Creates a job in [`JobState::Pending`] owned by `user_id`.
    ///
    /// The job gets a fresh random `job_id` and a `created_at` of now. Cost and token
    /// counters start at zero, and `max_tokens` is `0` (no token limit). All optional
    /// fields are `None`, the history is empty, `metadata` is JSON `null`, and
    /// `user_timezone` is `"UTC"`.
    pub fn with_user(
        user_id: impl Into<String>,
        title: impl Into<String>,
        description: impl Into<String>,
    ) -> Self {
        Self {
            job_id: Uuid::new_v4(),
            state: JobState::Pending,
            user_id: user_id.into(),
            requester_id: None,
            conversation_id: None,
            title: title.into(),
            description: description.into(),
            category: None,
            budget: None,
            budget_token: None,
            bid_amount: None,
            estimated_cost: None,
            estimated_duration: None,
            actual_cost: Decimal::ZERO,
            total_tokens_used: 0,
            max_tokens: 0,
            created_at: Utc::now(),
            started_at: None,
            completed_at: None,
            repair_attempts: 0,
            transitions: Vec::new(),
            extra_env: Arc::new(HashMap::new()),
            http_interceptor: None,
            metadata: serde_json::Value::Null,
            tool_output_stash: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            user_timezone: "UTC".to_string(),
        }
    }

    /// Builder-style setter for `user_timezone`.
    pub fn with_timezone(mut self, tz: impl Into<String>) -> Self {
        self.user_timezone = tz.into();
        self
    }

    /// Builder-style setter for `requester_id`.
    pub fn with_requester_id(mut self, requester_id: impl Into<String>) -> Self {
        self.requester_id = Some(requester_id.into());
        self
    }

    /// Moves the job to `new_state` and records the change in `transitions`.
    ///
    /// * `Completed -> Completed` is accepted as a no-op. Nothing is recorded and no
    ///   timestamps change.
    /// * Only the most recent 200 transitions are kept.
    /// * The first entry into `InProgress` sets `started_at`.
    /// * Entering `Completed`, `Accepted`, `Failed` or `Cancelled` sets (or
    ///   overwrites) `completed_at`.
    ///
    /// # Errors
    ///
    /// Returns a `"Cannot transition from X to Y"` message when
    /// [`JobState::can_transition_to`] rejects the move. The context is left unchanged.
    pub fn transition_to(
        &mut self,
        new_state: JobState,
        reason: Option<String>,
    ) -> Result<(), String> {
        // Bounds the retained history so jobs that repeatedly get stuck and
        // recover cannot grow `transitions` without limit.
        const MAX_TRANSITIONS: usize = 200;

        if !self.state.can_transition_to(new_state) {
            return Err(format!(
                "Cannot transition from {} to {}",
                self.state, new_state
            ));
        }

        // The only self-transition `can_transition_to` admits is Completed -> Completed.
        // Treat it as a no-op so duplicate completion signals leave history untouched.
        if self.state == new_state {
            tracing::debug!(
                job_id = %self.job_id,
                state = %self.state,
                "idempotent state transition (already in target state), skipping"
            );
            return Ok(());
        }

        self.transitions.push(StateTransition {
            from: self.state,
            to: new_state,
            timestamp: Utc::now(),
            reason,
        });
        if self.transitions.len() > MAX_TRANSITIONS {
            let excess = self.transitions.len() - MAX_TRANSITIONS;
            self.transitions.drain(..excess);
        }

        self.state = new_state;
        match new_state {
            // Only the first start is recorded; recovering from Stuck keeps it.
            JobState::InProgress if self.started_at.is_none() => {
                self.started_at = Some(Utc::now());
            }
            JobState::Completed | JobState::Accepted | JobState::Failed | JobState::Cancelled => {
                self.completed_at = Some(Utc::now());
            }
            _ => {}
        }
        Ok(())
    }

    /// Adds `cost` to `actual_cost`.
    pub fn add_cost(&mut self, cost: Decimal) {
        self.actual_cost += cost;
    }

    /// Adds `tokens` to `total_tokens_used` and checks the result against `max_tokens`.
    ///
    /// The total is updated even when an overrun is reported, so it always reflects
    /// real consumption. The addition saturates at `u64::MAX`. A plain `+=` would
    /// panic in debug builds and wrap in release builds. Wrapping could make an
    /// exhausted budget look almost unused and slip past the limit.
    ///
    /// # Errors
    ///
    /// Returns [`TokenBudgetExceeded`] when a limit is set (`max_tokens > 0`) and the
    /// new total is above it.
    pub fn add_tokens(&mut self, tokens: u64) -> Result<(), TokenBudgetExceeded> {
        self.total_tokens_used = self.total_tokens_used.saturating_add(tokens);
        let limit = self.max_tokens;
        if limit > 0 && self.total_tokens_used > limit {
            return Err(TokenBudgetExceeded {
                used: self.total_tokens_used,
                limit,
            });
        }
        Ok(())
    }

    /// Returns `true` when a `budget` is set and `actual_cost` is strictly above it.
    pub fn budget_exceeded(&self) -> bool {
        match self.budget {
            Some(budget) => self.actual_cost > budget,
            None => false,
        }
    }

    /// Wall-clock time since `started_at`, up to `completed_at` if set, else up to now.
    ///
    /// Returns `None` if the job never started. The result is truncated to whole
    /// seconds, and a negative span (e.g. clock skew) is clamped to zero.
    pub fn elapsed(&self) -> Option<Duration> {
        self.started_at.map(|start| {
            let end = self.completed_at.unwrap_or_else(Utc::now);
            let duration = end.signed_duration_since(start);
            // `max(0)` makes the value non-negative, so the cast cannot wrap.
            Duration::from_secs(duration.num_seconds().max(0) as u64)
        })
    }

    /// Moves the job to [`JobState::Stuck`], recording `reason` in the history.
    ///
    /// # Errors
    ///
    /// Returns the [`JobContext::transition_to`] error if the current state may not
    /// move to `Stuck`.
    pub fn mark_stuck(&mut self, reason: impl Into<String>) -> Result<(), String> {
        self.transition_to(JobState::Stuck, Some(reason.into()))
    }

    /// Moves a stuck job back to [`JobState::InProgress`] and counts the attempt in
    /// `repair_attempts`.
    ///
    /// # Errors
    ///
    /// Returns `"Job is not stuck"` if the job is not in `Stuck`. It also propagates
    /// any transition error. In neither case is `repair_attempts` changed.
    pub fn attempt_recovery(&mut self) -> Result<(), String> {
        if self.state != JobState::Stuck {
            return Err("Job is not stuck".to_string());
        }
        // Count the attempt only once the state change has actually happened.
        self.transition_to(JobState::InProgress, Some("Recovery attempt".to_string()))?;
        self.repair_attempts += 1;
        Ok(())
    }
}
impl Default for JobContext {
    // A pending job titled "Untitled" with description "No description". `new`
    // assigns it to the "default" user, exactly as `with_user("default", ..)` would.
    fn default() -> Self {
        Self::new("Untitled", "No description")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const ALL_STATES: [JobState; 8] = [
        JobState::Pending,
        JobState::InProgress,
        JobState::Completed,
        JobState::Submitted,
        JobState::Accepted,
        JobState::Failed,
        JobState::Stuck,
        JobState::Cancelled,
    ];

    #[test]
    fn test_state_transitions() {
        assert!(JobState::Pending.can_transition_to(JobState::InProgress));
        assert!(JobState::InProgress.can_transition_to(JobState::Completed));
        assert!(!JobState::Completed.can_transition_to(JobState::Pending));
        assert!(!JobState::Accepted.can_transition_to(JobState::InProgress));
    }

    #[test]
    fn test_completed_to_completed_is_idempotent() {
        let mut ctx = JobContext::new("Test", "Idempotent completion test");
        ctx.transition_to(JobState::InProgress, None).unwrap();
        ctx.transition_to(JobState::Completed, Some("first".into()))
            .unwrap();
        assert_eq!(ctx.state, JobState::Completed);
        let transitions_before = ctx.transitions.len();
        let completed_at_before = ctx.completed_at;

        let result = ctx.transition_to(JobState::Completed, Some("duplicate".into()));

        assert!(result.is_ok(), "Completed -> Completed should be idempotent");
        assert_eq!(ctx.state, JobState::Completed);
        assert_eq!(
            ctx.transitions.len(),
            transitions_before,
            "idempotent transition should not record a new history entry"
        );
        assert_eq!(
            ctx.completed_at, completed_at_before,
            "idempotent transition should not move completed_at"
        );
    }

    #[test]
    fn test_other_self_transitions_still_rejected() {
        assert!(!JobState::Pending.can_transition_to(JobState::Pending));
        assert!(!JobState::InProgress.can_transition_to(JobState::InProgress));
        assert!(!JobState::Failed.can_transition_to(JobState::Failed));
        assert!(!JobState::Stuck.can_transition_to(JobState::Stuck));
        assert!(!JobState::Submitted.can_transition_to(JobState::Submitted));
        assert!(!JobState::Accepted.can_transition_to(JobState::Accepted));
        assert!(!JobState::Cancelled.can_transition_to(JobState::Cancelled));
    }

    #[test]
    fn test_terminal_states() {
        assert!(JobState::Accepted.is_terminal());
        assert!(JobState::Failed.is_terminal());
        assert!(JobState::Cancelled.is_terminal());
        assert!(!JobState::InProgress.is_terminal());
        assert!(!JobState::Completed.is_terminal());
        for state in ALL_STATES.iter().copied() {
            assert_eq!(state.is_active(), !state.is_terminal(), "{:?}", state);
        }
    }

    #[test]
    fn test_parallel_blocking_states() {
        assert!(JobState::Pending.is_parallel_blocking());
        assert!(JobState::InProgress.is_parallel_blocking());
        assert!(JobState::Stuck.is_parallel_blocking());
        assert!(!JobState::Completed.is_parallel_blocking());
        assert!(!JobState::Submitted.is_parallel_blocking());
        assert!(!JobState::Accepted.is_parallel_blocking());
        assert!(!JobState::Failed.is_parallel_blocking());
        assert!(!JobState::Cancelled.is_parallel_blocking());
    }

    #[test]
    fn test_display_matches_serde_name() {
        for state in ALL_STATES.iter().copied() {
            assert_eq!(
                serde_json::to_value(state).unwrap(),
                serde_json::Value::String(state.to_string()),
                "Display and serde disagree for {:?}",
                state
            );
        }
    }

    #[test]
    fn test_job_context_transitions() {
        let mut ctx = JobContext::new("Test", "Test job");
        assert_eq!(ctx.state, JobState::Pending);
        ctx.transition_to(JobState::InProgress, None).unwrap();
        assert_eq!(ctx.state, JobState::InProgress);
        assert!(ctx.started_at.is_some());
        ctx.transition_to(JobState::Completed, Some("Done".to_string()))
            .unwrap();
        assert_eq!(ctx.state, JobState::Completed);
        assert!(ctx.completed_at.is_some());
    }

    #[test]
    fn test_rejected_transition_leaves_context_unchanged() {
        let mut ctx = JobContext::new("Test", "Rejected transition");
        let err = ctx.transition_to(JobState::Completed, None).unwrap_err();
        assert_eq!(err, "Cannot transition from pending to completed");
        assert_eq!(ctx.state, JobState::Pending);
        assert!(ctx.transitions.is_empty());
        assert!(ctx.started_at.is_none());
        assert!(ctx.completed_at.is_none());
    }

    #[test]
    fn test_cancel_before_start_sets_completed_at_only() {
        let mut ctx = JobContext::new("Test", "Cancelled early");
        ctx.transition_to(JobState::Cancelled, Some("user abort".into()))
            .unwrap();
        assert!(ctx.state.is_terminal());
        assert!(ctx.completed_at.is_some());
        assert!(ctx.started_at.is_none());
        assert!(ctx.elapsed().is_none());
    }

    #[test]
    fn test_transition_history_capped() {
        let mut ctx = JobContext::new("Test", "Transition cap test");
        ctx.transition_to(JobState::InProgress, None).unwrap();
        for i in 0..250 {
            ctx.mark_stuck(format!("stuck {}", i)).unwrap();
            ctx.attempt_recovery().unwrap();
        }
        // 1 + 2 * 250 = 501 transitions were recorded; only the newest 200 remain.
        assert_eq!(
            ctx.transitions.len(),
            200,
            "transitions should be capped at 200, got {}",
            ctx.transitions.len()
        );
        let last = ctx.transitions.last().unwrap();
        assert_eq!((last.from, last.to), (JobState::Stuck, JobState::InProgress));
        assert_eq!(ctx.repair_attempts, 250);
    }

    #[test]
    fn test_add_tokens_enforces_budget() {
        let mut ctx = JobContext::new("Test", "Budget test");
        ctx.max_tokens = 1000;
        assert!(ctx.add_tokens(500).is_ok());
        assert_eq!(ctx.total_tokens_used, 500);

        let err = ctx.add_tokens(600).unwrap_err();

        // The overshoot is still recorded so callers can report true usage.
        assert_eq!(ctx.total_tokens_used, 1100);
        assert_eq!((err.used, err.limit), (1100, 1000));
    }

    #[test]
    fn test_add_tokens_unlimited() {
        let mut ctx = JobContext::new("Test", "No budget");
        assert!(ctx.add_tokens(1_000_000).is_ok());
    }

    #[test]
    fn test_budget_exceeded() {
        let mut ctx = JobContext::new("Test", "Money test");
        ctx.budget = Some(Decimal::new(100, 0));
        assert!(!ctx.budget_exceeded());
        ctx.add_cost(Decimal::new(50, 0));
        assert!(!ctx.budget_exceeded());
        ctx.add_cost(Decimal::new(60, 0));
        assert!(ctx.budget_exceeded());
    }

    #[test]
    fn test_budget_exceeded_none() {
        let ctx = JobContext::new("Test", "No budget");
        assert!(!ctx.budget_exceeded());
    }

    #[test]
    fn test_stuck_recovery() {
        let mut ctx = JobContext::new("Test", "Test job");
        ctx.transition_to(JobState::InProgress, None).unwrap();
        ctx.mark_stuck("Timed out").unwrap();
        assert_eq!(ctx.state, JobState::Stuck);
        ctx.attempt_recovery().unwrap();
        assert_eq!(ctx.state, JobState::InProgress);
        assert_eq!(ctx.repair_attempts, 1);
    }

    #[test]
    fn test_attempt_recovery_requires_stuck() {
        let mut ctx = JobContext::new("Test", "Not stuck");
        assert!(ctx.attempt_recovery().is_err());
        assert_eq!(ctx.repair_attempts, 0);
        assert_eq!(ctx.state, JobState::Pending);
    }

    #[test]
    fn test_default_context() {
        let ctx = JobContext::default();
        assert_eq!(ctx.title, "Untitled");
        assert_eq!(ctx.description, "No description");
        assert_eq!(ctx.user_id, "default");
        assert_eq!(ctx.user_timezone, "UTC");
        assert_eq!(ctx.state, JobState::Pending);
    }
}