use crate::budget::BudgetManager;
use crate::types::AgentId;
use anyhow::Result;
use chrono::{DateTime, Utc};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::collections::{BinaryHeap, HashMap};
use std::sync::Arc;
use uuid::Uuid;
/// Scheduling priority of a task; the scheduler dispatches higher variants first.
///
/// Comparison follows declaration order (`Low < Normal < High < Critical`), and the
/// explicit discriminants fix the numeric value seen through `as` casts.
/// [`Priority::Normal`] is the default.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Priority {
    /// Dispatched after every other priority.
    Low = 0,
    /// The default priority.
    #[default]
    Normal = 1,
    /// Dispatched before `Normal` and `Low`.
    High = 2,
    /// Dispatched before everything else.
    Critical = 3,
}
/// Lifecycle state of a [`ScheduledTask`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskStatus {
    /// Waiting in the scheduler queue.
    Queued,
    /// Dispatched and not yet finished.
    Running,
    /// Finished successfully.
    Completed,
    /// Finished with an error (including tasks reaped as zombies).
    Failed,
    /// Cancelled. NOTE(review): not assigned anywhere in this module.
    Cancelled,
}
/// A unit of work tracked by the scheduler.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ScheduledTask {
    /// Task identifier (a random v4 UUID when built by the constructors).
    pub id: Uuid,
    /// Owning agent, if any; budget checks are applied only when this is set.
    pub agent_id: Option<AgentId>,
    /// Human-readable description of the work.
    pub description: String,
    /// Scheduling priority; higher priorities are dispatched first.
    pub priority: Priority,
    /// Creation time; among equal priorities the older task is dispatched first.
    pub created_at: DateTime<Utc>,
    /// Current lifecycle state.
    pub status: TaskStatus,
    /// Failure message, if any; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
/// Delegates to [`Ord`] so the two orderings can never disagree.
impl PartialOrd for ScheduledTask {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Scheduling order used by the scheduler's max-heap: the greatest task is popped first.
///
/// 1. Higher `priority` ranks greater.
/// 2. Among equal priorities, the *earlier* `created_at` ranks greater (FIFO).
/// 3. Remaining ties are broken by `id`. Without this, two distinct tasks created
///    within the same clock tick would compare `Equal` while being `!=`. That violates
///    the `Ord`/`Eq` consistency contract and leaves the heap order unspecified.
///
/// NOTE(review): two values with the same `id` but different other fields still
/// compare `Equal`; ids are assumed unique per task.
impl Ord for ScheduledTask {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.priority
            .cmp(&other.priority)
            .then_with(|| other.created_at.cmp(&self.created_at))
            .then_with(|| self.id.cmp(&other.id))
    }
}
impl ScheduledTask {
pub fn new(description: String, priority: Priority) -> Self {
Self {
id: Uuid::new_v4(),
agent_id: None,
description,
priority,
created_at: Utc::now(),
status: TaskStatus::Queued,
error: None,
}
}
pub fn for_agent(agent_id: AgentId, description: String, priority: Priority) -> Self {
Self {
id: Uuid::new_v4(),
agent_id: Some(agent_id),
description,
priority,
created_at: Utc::now(),
status: TaskStatus::Queued,
error: None,
}
}
}
/// Snapshot of scheduler counters and limits.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SchedulerStats {
    /// Tasks waiting in the queue.
    pub queued: usize,
    /// Tasks currently running.
    pub running: usize,
    /// Completed-task count.
    pub completed: usize,
    /// Failed-task count.
    pub failed: usize,
    /// Concurrency limit.
    pub max_concurrent: usize,
    /// Dispatch limit per minute.
    pub rate_limit_per_minute: u32,
}

impl Default for SchedulerStats {
    /// All counters zero; limits of 5 concurrent tasks and 60 dispatches per minute.
    fn default() -> Self {
        SchedulerStats {
            max_concurrent: 5,
            rate_limit_per_minute: 60,
            queued: 0,
            running: 0,
            completed: 0,
            failed: 0,
        }
    }
}
/// Sliding-window limiter: admits at most `max_requests` events per window.
///
/// Time is measured with the monotonic [`std::time::Instant`] clock. Wall-clock
/// adjustments (NTP steps, manual changes) therefore cannot stall the limiter or
/// reset its window early.
#[derive(Debug, Clone)]
struct RateLimiter {
    /// Instants of recorded admissions, oldest first. Entries past the window are
    /// pruned lazily by `allow`.
    window: std::collections::VecDeque<std::time::Instant>,
    /// Length of the sliding window.
    window_len: std::time::Duration,
    /// Maximum admissions allowed within one window.
    max_requests: u32,
}

impl RateLimiter {
    /// Creates a limiter allowing `max_requests` admissions per `window_secs` seconds.
    fn new(window_secs: u64, max_requests: u32) -> Self {
        Self {
            window: std::collections::VecDeque::new(),
            window_len: std::time::Duration::from_secs(window_secs),
            max_requests,
        }
    }

    /// Whether an admission recorded at `at` has left the window as of `now`.
    fn is_expired(&self, at: std::time::Instant, now: std::time::Instant) -> bool {
        now.saturating_duration_since(at) >= self.window_len
    }

    /// Records one admission and returns `true` if the window has room.
    ///
    /// Otherwise it records nothing and returns `false`.
    fn allow(&mut self) -> bool {
        let now = std::time::Instant::now();
        // Admissions are appended in monotonic order, so expired ones sit at the front.
        while self
            .window
            .front()
            .map_or(false, |&oldest| self.is_expired(oldest, now))
        {
            self.window.pop_front();
        }
        if self.window.len() >= self.max_requests as usize {
            return false;
        }
        self.window.push_back(now);
        true
    }

    /// Admissions still available in the current window; does not modify state.
    fn remaining(&self) -> u32 {
        let now = std::time::Instant::now();
        let active = self
            .window
            .iter()
            .filter(|&&at| !self.is_expired(at, now))
            .count();
        self.max_requests
            .saturating_sub(u32::try_from(active).unwrap_or(u32::MAX))
    }
}
pub struct AgentScheduler {
queue: Arc<Mutex<BinaryHeap<ScheduledTask>>>,
running: Arc<Mutex<HashMap<Uuid, ScheduledTask>>>,
max_concurrent: usize,
rate_limiter: Arc<Mutex<RateLimiter>>,
zombie_timeout_secs: u64,
task_start_times: Arc<Mutex<HashMap<Uuid, DateTime<Utc>>>>,
budget_manager: Option<Arc<BudgetManager>>,
}
impl AgentScheduler {
pub fn new(
max_concurrent: usize,
rate_limit_per_minute: u32,
zombie_timeout_secs: u64,
) -> Self {
Self {
queue: Arc::new(Mutex::new(BinaryHeap::new())),
running: Arc::new(Mutex::new(HashMap::new())),
max_concurrent,
rate_limiter: Arc::new(Mutex::new(RateLimiter::new(60, rate_limit_per_minute))),
zombie_timeout_secs,
task_start_times: Arc::new(Mutex::new(HashMap::new())),
budget_manager: None,
}
}
pub fn set_budget_manager(&mut self, bm: Arc<BudgetManager>) {
self.budget_manager = Some(bm);
}
pub fn submit(&self, mut task: ScheduledTask) -> Result<Uuid> {
task.status = TaskStatus::Queued;
let id = task.id;
let mut queue = self.queue.lock();
queue.push(task);
tracing::debug!(
task_id = %id,
queue_len = queue.len(),
"Task submitted to scheduler"
);
Ok(id)
}
pub fn next_task(&self) -> Option<ScheduledTask> {
{
let running = self.running.lock();
if running.len() >= self.max_concurrent {
tracing::debug!(
running = running.len(),
max = self.max_concurrent,
"Max concurrent limit reached"
);
return None;
}
}
{
let mut limiter = self.rate_limiter.lock();
if !limiter.allow() {
tracing::debug!(remaining = limiter.remaining(), "Rate limit exceeded");
return None;
}
}
let mut discarded: usize = 0;
let mut task = loop {
let task_opt = {
let mut queue = self.queue.lock();
queue.pop() };
match task_opt {
Some(t) => {
if let (Some(ref bm), Some(ref agent_id)) = (&self.budget_manager, &t.agent_id)
{
if !bm.can_schedule(agent_id) {
tracing::warn!(
agent_id = %agent_id,
"Agent budget exhausted, skipping task"
);
discarded += 1;
continue; }
}
break t;
}
None => {
if discarded > 0 {
tracing::info!(discarded, "All queued tasks had exhausted budgets");
}
return None;
}
}
};
if discarded > 0 {
tracing::info!(discarded, "Skipped tasks with exhausted budgets");
}
task.status = TaskStatus::Running;
{
let mut start_times = self.task_start_times.lock();
start_times.insert(task.id, Utc::now());
}
{
let mut running = self.running.lock();
running.insert(task.id, task.clone());
}
tracing::info!(
task_id = %task.id,
priority = ?task.priority,
running = self.running.lock().len(),
"Task started by scheduler"
);
if let (Some(ref bm), Some(ref agent_id)) = (&self.budget_manager, &task.agent_id) {
if let Err(e) = bm.track_call(agent_id) {
tracing::warn!(
agent_id = %agent_id,
error = %e,
"Budget exceeded during task track_call"
);
}
}
Some(task)
}
pub fn complete_task(&self, task_id: Uuid) -> Result<()> {
let task = {
let mut running = self.running.lock();
running.remove(&task_id)
};
match task {
Some(mut t) => {
t.status = TaskStatus::Completed;
{
let mut start_times = self.task_start_times.lock();
start_times.remove(&task_id);
}
tracing::info!(task_id = %task_id, "Task completed");
Ok(())
}
None => {
tracing::warn!(task_id = %task_id, "Attempted to complete unknown task");
Err(anyhow::anyhow!("task not found"))
}
}
}
pub fn fail_task(&self, task_id: Uuid, error: &str) -> Result<()> {
let task = {
let mut running = self.running.lock();
running.remove(&task_id)
};
match task {
Some(mut t) => {
t.status = TaskStatus::Failed;
t.error = Some(error.to_string());
{
let mut start_times = self.task_start_times.lock();
start_times.remove(&task_id);
}
tracing::warn!(task_id = %task_id, error = %error, "Task failed");
Ok(())
}
None => {
tracing::warn!(task_id = %task_id, "Attempted to fail unknown task");
Err(anyhow::anyhow!("task not found"))
}
}
}
pub fn reap_zombies(&self) -> Vec<Uuid> {
let now = Utc::now();
let timeout = chrono::Duration::seconds(self.zombie_timeout_secs as i64);
let mut start_times = self.task_start_times.lock();
let mut running = self.running.lock();
let mut reaped = Vec::new();
let zombie_ids: Vec<Uuid> = start_times
.iter()
.filter(|(_, start)| now - **start > timeout)
.map(|(id, _)| *id)
.collect();
for id in zombie_ids {
if let Some(mut task) = running.remove(&id) {
task.status = TaskStatus::Failed;
task.error = Some(format!(
"zombie: ran for >{} seconds",
self.zombie_timeout_secs
));
reaped.push(id);
tracing::warn!(
task_id = %id,
duration_secs = self.zombie_timeout_secs,
"Zombie task reaped"
);
}
start_times.remove(&id);
}
reaped
}
pub fn start_task(&self, task_id: Uuid) -> Result<()> {
let task = {
let mut queue = self.queue.lock();
let all: Vec<ScheduledTask> = queue.drain().collect();
let mut found: Option<ScheduledTask> = None;
let remaining: Vec<ScheduledTask> = all
.into_iter()
.filter(|t| {
if t.id == task_id {
found = Some(t.clone());
false
} else {
true
}
})
.collect();
*queue = remaining.into_iter().collect();
found
};
match task {
Some(mut task) => {
task.status = TaskStatus::Running;
let mut start_times = self.task_start_times.lock();
start_times.insert(task.id, Utc::now());
let mut running = self.running.lock();
running.insert(task.id, task);
Ok(())
}
None => Err(anyhow::anyhow!("task {} not found in queue", task_id)),
}
}
pub fn cancel_task(&self, task_id: Uuid) -> Result<()> {
let mut queue = self.queue.lock();
let all: Vec<ScheduledTask> = queue.drain().collect();
let mut found = false;
let remaining: Vec<ScheduledTask> = all
.into_iter()
.filter(|t| {
if t.id == task_id && t.status == TaskStatus::Queued {
found = true;
false
} else {
true
}
})
.collect();
*queue = remaining.into_iter().collect();
if found {
tracing::info!(task_id = %task_id, "Task cancelled from queue");
Ok(())
} else {
tracing::warn!(task_id = %task_id, "Task not found in queue for cancellation");
Err(anyhow::anyhow!("task not found in queue"))
}
}
pub fn stats(&self) -> SchedulerStats {
let queue = self.queue.lock();
let running = self.running.lock();
let rate_limiter = self.rate_limiter.lock();
let _completed = 0usize;
let _failed = 0usize;
SchedulerStats {
queued: queue.len(),
running: running.len(),
completed: _completed,
failed: _failed,
max_concurrent: self.max_concurrent,
rate_limit_per_minute: rate_limiter.max_requests,
}
}
pub fn rate_limit_remaining(&self) -> u32 {
self.rate_limiter.lock().remaining()
}
pub fn queued_tasks(&self) -> Vec<ScheduledTask> {
let heap = self.queue.lock();
let mut tasks: Vec<ScheduledTask> = heap.iter().cloned().collect();
tasks.sort_by_key(|a| a.priority);
tasks
}
pub fn running_tasks(&self) -> Vec<ScheduledTask> {
self.running.lock().values().cloned().collect()
}
}
impl Default for AgentScheduler {
fn default() -> Self {
Self::new(5, 60, 300)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Every priority, lowest first.
    const ASCENDING: [Priority; 4] = [
        Priority::Low,
        Priority::Normal,
        Priority::High,
        Priority::Critical,
    ];

    /// Submits an agent-less task to `scheduler` and returns its id.
    fn submit(scheduler: &AgentScheduler, description: &str, priority: Priority) -> Uuid {
        scheduler
            .submit(ScheduledTask::new(description.to_string(), priority))
            .unwrap()
    }

    /// Submits a `Normal`-priority task owned by `agent_id` and returns its id.
    fn submit_for(scheduler: &AgentScheduler, agent_id: AgentId, description: &str) -> Uuid {
        scheduler
            .submit(ScheduledTask::for_agent(
                agent_id,
                description.to_string(),
                Priority::Normal,
            ))
            .unwrap()
    }

    #[test]
    fn test_task_creation() {
        let task = ScheduledTask::new("Test task".into(), Priority::Normal);
        assert_eq!(task.status, TaskStatus::Queued);
        assert!(task.agent_id.is_none());
        assert!(task.error.is_none());
    }

    #[test]
    fn test_task_creation_for_agent() {
        let agent_id = AgentId::new_v4();
        let task = ScheduledTask::for_agent(agent_id, "Agent task".into(), Priority::High);
        assert_eq!(task.agent_id, Some(agent_id));
        assert_eq!(task.priority, Priority::High);
    }

    #[test]
    fn test_priority_ordering() {
        for (i, lower) in ASCENDING.iter().enumerate() {
            for higher in &ASCENDING[i + 1..] {
                assert!(higher > lower, "{:?} should outrank {:?}", higher, lower);
            }
        }
    }

    #[test]
    fn test_priority_ordering_eq() {
        for priority in ASCENDING {
            let copy = priority;
            assert_eq!(copy, priority);
        }
    }

    #[test]
    fn test_submit_and_next_high_priority_first() {
        let scheduler = AgentScheduler::new(10, 10_000, 60);
        submit(&scheduler, "Low priority", Priority::Low);
        submit(&scheduler, "High priority", Priority::High);
        submit(&scheduler, "Normal priority", Priority::Normal);
        for expected in [Priority::High, Priority::Normal, Priority::Low] {
            assert_eq!(scheduler.next_task().unwrap().priority, expected);
        }
    }

    #[test]
    fn test_submit_and_next_critical_first() {
        let scheduler = AgentScheduler::new(10, 10_000, 60);
        for priority in ASCENDING {
            submit(&scheduler, &format!("{:?}", priority), priority);
        }
        for expected in ASCENDING.iter().rev() {
            assert_eq!(scheduler.next_task().unwrap().priority, *expected);
        }
    }

    #[test]
    fn test_submit_multiple_same_priority() {
        let scheduler = AgentScheduler::new(10, 10_000, 60);
        for description in ["First", "Second", "Third"] {
            submit(&scheduler, description, Priority::Normal);
        }
        let mut descriptions: Vec<String> = (0..3)
            .map(|_| {
                let next = scheduler.next_task().unwrap();
                assert_eq!(next.priority, Priority::Normal);
                next.description
            })
            .collect();
        descriptions.sort();
        assert_eq!(descriptions, ["First", "Second", "Third"]);
    }

    #[test]
    fn test_max_concurrent_blocks() {
        let scheduler = AgentScheduler::new(2, 10_000, 60);
        for description in ["Task 1", "Task 2", "Task 3"] {
            submit(&scheduler, description, Priority::Normal);
        }
        assert!(scheduler.next_task().is_some());
        assert!(scheduler.next_task().is_some());
        assert!(scheduler.next_task().is_none());
    }

    #[test]
    fn test_max_concurrent_allows_when_slot_frees() {
        let scheduler = AgentScheduler::new(2, 10_000, 60);
        submit(&scheduler, "Task 1", Priority::Normal);
        submit(&scheduler, "Task 2", Priority::Normal);
        let t1 = scheduler.next_task().unwrap();
        let t2 = scheduler.next_task().unwrap();
        assert!(scheduler.next_task().is_none());
        scheduler.complete_task(t1.id).unwrap();
        scheduler.complete_task(t2.id).unwrap();
        submit(&scheduler, "Task 3", Priority::Normal);
        let task = scheduler.next_task().unwrap();
        assert_eq!(task.description, "Task 3");
        scheduler.complete_task(task.id).unwrap();
    }

    #[test]
    fn test_complete_task_removes_from_running() {
        let scheduler = AgentScheduler::new(2, 10_000, 60);
        let id = submit(&scheduler, "Test", Priority::Normal);
        let _ = scheduler.next_task();
        scheduler.complete_task(id).unwrap();
        assert_eq!(scheduler.stats().running, 0);
    }

    #[test]
    fn test_complete_unknown_task_returns_error() {
        let scheduler = AgentScheduler::new(2, 10_000, 60);
        assert!(scheduler.complete_task(Uuid::new_v4()).is_err());
    }

    #[test]
    fn test_fail_task_sets_error() {
        let scheduler = AgentScheduler::new(2, 10_000, 60);
        let id = submit(&scheduler, "Test", Priority::Normal);
        let _ = scheduler.next_task();
        scheduler.fail_task(id, "Something went wrong").unwrap();
        assert!(!scheduler.running.lock().contains_key(&id));
    }

    #[test]
    fn test_cancel_queued_task() {
        let scheduler = AgentScheduler::new(2, 10_000, 60);
        let id = submit(&scheduler, "To cancel", Priority::Normal);
        scheduler.cancel_task(id).unwrap();
        assert!(scheduler.next_task().is_none());
    }

    #[test]
    fn test_cancel_running_task_fails() {
        let scheduler = AgentScheduler::new(2, 10_000, 60);
        let id = submit(&scheduler, "Running", Priority::Normal);
        let _ = scheduler.next_task();
        assert!(scheduler.cancel_task(id).is_err());
    }

    #[test]
    fn test_cancel_unknown_task_fails() {
        let scheduler = AgentScheduler::new(2, 10_000, 60);
        assert!(scheduler.cancel_task(Uuid::new_v4()).is_err());
    }

    #[test]
    fn test_stats_tracking() {
        let scheduler = AgentScheduler::new(2, 60, 60);
        let id1 = submit(&scheduler, "Queued", Priority::Normal);
        submit(&scheduler, "Queued 2", Priority::Low);
        let started = scheduler.next_task().unwrap();
        assert_eq!(started.id, id1);
        let stats = scheduler.stats();
        assert_eq!(stats.queued, 1);
        assert_eq!(stats.running, 1);
        assert_eq!(stats.max_concurrent, 2);
        assert_eq!(stats.rate_limit_per_minute, 60);
    }

    #[test]
    fn test_reap_zombies() {
        // A zero-second timeout makes any measurable run time a zombie, so the test
        // only needs a short sleep.
        let scheduler = AgentScheduler::new(2, 10_000, 0);
        let id = submit(&scheduler, "Zombie", Priority::Normal);
        let _ = scheduler.next_task();
        thread::sleep(Duration::from_millis(20));
        let reaped = scheduler.reap_zombies();
        assert!(reaped.contains(&id));
        assert!(!scheduler.running.lock().contains_key(&id));
    }

    #[test]
    fn test_reap_zombies_no_zombies() {
        let scheduler = AgentScheduler::new(2, 10_000, 60);
        let id = submit(&scheduler, "Normal", Priority::Normal);
        let _ = scheduler.next_task();
        assert!(scheduler.reap_zombies().is_empty());
        assert!(scheduler.running.lock().contains_key(&id));
    }

    #[test]
    fn test_rate_limiter_basic() {
        let mut limiter = RateLimiter::new(60, 3);
        assert!(limiter.allow());
        assert!(limiter.allow());
        assert!(limiter.allow());
        assert!(!limiter.allow());
    }

    #[test]
    fn test_rate_limiter_remaining() {
        let fresh = RateLimiter::new(60, 3);
        assert_eq!(fresh.remaining(), 3);
        let mut used = RateLimiter::new(60, 3);
        used.allow();
        used.allow();
        assert_eq!(used.remaining(), 1);
    }

    #[test]
    fn test_rate_limiter_tracks_per_scheduler() {
        let scheduler = AgentScheduler::new(10, 5, 60);
        for i in 0..5 {
            submit(&scheduler, &format!("T{}", i), Priority::Normal);
            let _ = scheduler.next_task();
        }
        assert!(scheduler.next_task().is_none());
        assert_eq!(scheduler.rate_limit_remaining(), 0);
    }

    #[test]
    fn test_queued_tasks_inspection() {
        let scheduler = AgentScheduler::new(2, 10_000, 60);
        submit(&scheduler, "A", Priority::Low);
        submit(&scheduler, "B", Priority::High);
        submit(&scheduler, "C", Priority::Normal);
        let queued = scheduler.queued_tasks();
        assert_eq!(queued.len(), 3);
        assert_eq!(queued.last().unwrap().description, "B");
    }

    #[test]
    fn test_running_tasks_inspection() {
        let scheduler = AgentScheduler::new(2, 10_000, 60);
        submit(&scheduler, "R1", Priority::Normal);
        submit(&scheduler, "R2", Priority::Normal);
        let _ = scheduler.next_task();
        let _ = scheduler.next_task();
        assert_eq!(scheduler.running_tasks().len(), 2);
    }

    #[test]
    fn test_default_scheduler() {
        let stats = AgentScheduler::default().stats();
        assert_eq!(stats.max_concurrent, 5);
        assert_eq!(stats.rate_limit_per_minute, 60);
    }

    #[test]
    fn test_budget_manager_integration_skips_exhausted_agent() {
        use crate::budget::BudgetLimit;
        let mut scheduler = AgentScheduler::new(2, 10_000, 60);
        let budget_manager = Arc::new(BudgetManager::new());
        let agent_id = AgentId::new_v4();
        budget_manager.set_budget(BudgetLimit {
            agent_id,
            token_budget: 1000,
            calls_budget: 1,
            window_secs: 60,
        });
        scheduler.set_budget_manager(Arc::clone(&budget_manager));
        submit_for(&scheduler, agent_id, "Task 1");
        submit_for(&scheduler, agent_id, "Task 2");
        let first = scheduler.next_task().expect("first task fits within the budget");
        scheduler.complete_task(first.id).unwrap();
        assert!(scheduler.next_task().is_none());
    }

    #[test]
    fn test_budget_manager_allows_different_agents() {
        use crate::budget::BudgetLimit;
        let mut scheduler = AgentScheduler::new(2, 10_000, 60);
        let budget_manager = Arc::new(BudgetManager::new());
        let agent1 = AgentId::new_v4();
        let agent2 = AgentId::new_v4();
        for agent_id in [agent1, agent2] {
            budget_manager.set_budget(BudgetLimit {
                agent_id,
                token_budget: 1000,
                calls_budget: 3,
                window_secs: 60,
            });
        }
        scheduler.set_budget_manager(Arc::clone(&budget_manager));
        submit_for(&scheduler, agent1, "A1");
        submit_for(&scheduler, agent2, "B1");
        let t1 = scheduler.next_task().unwrap();
        let t2 = scheduler.next_task().unwrap();
        assert_ne!(t1.description, t2.description);
    }

    #[test]
    fn test_budget_manager_task_without_agent_id() {
        let mut scheduler = AgentScheduler::new(2, 10_000, 60);
        let budget_manager = Arc::new(BudgetManager::new());
        scheduler.set_budget_manager(Arc::clone(&budget_manager));
        submit(&scheduler, "No agent", Priority::Normal);
        assert!(scheduler.next_task().is_some());
    }

    #[test]
    fn test_budget_manager_not_set_skips_check() {
        let scheduler = AgentScheduler::new(2, 10_000, 60);
        submit(&scheduler, "Any task", Priority::Normal);
        assert!(scheduler.next_task().is_some());
    }
}