use serde::{Deserialize, Serialize};
use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Opaque, copyable identifier for a tracked task.
///
/// The wrapped integer is private; use `TaskId::as_u64` / `TaskId::from_u64` to convert.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub struct TaskId(u64);
impl TaskId {
pub fn new() -> Self {
static COUNTER: AtomicU64 = AtomicU64::new(1);
Self(COUNTER.fetch_add(1, Ordering::Relaxed))
}
#[must_use]
pub fn as_u64(&self) -> u64 {
self.0
}
#[must_use]
pub fn from_u64(id: u64) -> Self {
Self(id)
}
}
impl Default for TaskId {
fn default() -> Self {
Self::new()
}
}
impl fmt::Display for TaskId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "#{}", self.0)
}
}
/// Lifecycle state of a tracked task.
///
/// Variant order is significant for serde's index-based formats; keep it stable.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum TaskState {
    /// Initial state assigned to a freshly created `TaskInfo`.
    Pending,
    /// Task is running.
    Running,
    /// Task is blocked; `await_point` is a free-form label for where it is waiting.
    Blocked { await_point: String },
    /// Task has completed.
    Completed,
    /// Task has failed.
    Failed,
}
impl fmt::Display for TaskState {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Pending => write!(f, "PENDING"),
Self::Running => write!(f, "RUNNING"),
Self::Blocked { await_point } => write!(f, "BLOCKED({await_point})"),
Self::Completed => write!(f, "COMPLETED"),
Self::Failed => write!(f, "FAILED"),
}
}
}
/// Bookkeeping record for one tracked task: identity, lifecycle state and poll statistics.
#[derive(Debug, Clone)]
pub struct TaskInfo {
    /// Identifier allocated by `TaskId::new` when the record is created.
    pub id: TaskId,
    /// Human-readable task name.
    pub name: String,
    /// Current lifecycle state; starts as `TaskState::Pending`.
    pub state: TaskState,
    /// Instant the record was created; the basis for [`TaskInfo::age`].
    pub created_at: Instant,
    /// Instant of the most recent [`TaskInfo::update_state`] or [`TaskInfo::record_poll`].
    pub last_updated: Instant,
    /// Number of polls recorded via [`TaskInfo::record_poll`].
    pub poll_count: u64,
    /// Sum of all durations passed to [`TaskInfo::record_poll`].
    pub total_run_time: Duration,
    /// Parent task, if one was attached with [`TaskInfo::with_parent`].
    pub parent: Option<TaskId>,
    /// Optional free-form location label attached with [`TaskInfo::with_location`].
    pub location: Option<String>,
}

impl TaskInfo {
    /// Creates a `Pending` record named `name` with a freshly allocated id, zeroed
    /// statistics, and no parent or location.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, ...), matching
    /// the other string-taking builders in this module.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        // One clock read so `created_at` and `last_updated` start out identical.
        let now = Instant::now();
        Self {
            id: TaskId::new(),
            name: name.into(),
            state: TaskState::Pending,
            created_at: now,
            last_updated: now,
            poll_count: 0,
            total_run_time: Duration::ZERO,
            parent: None,
            location: None,
        }
    }

    /// Replaces the current state and refreshes `last_updated`.
    pub fn update_state(&mut self, new_state: TaskState) {
        self.state = new_state;
        self.last_updated = Instant::now();
    }

    /// Records one poll that took `duration`: bumps `poll_count`, adds to
    /// `total_run_time`, and refreshes `last_updated`.
    pub fn record_poll(&mut self, duration: Duration) {
        self.poll_count += 1;
        self.total_run_time += duration;
        self.last_updated = Instant::now();
    }

    /// Time elapsed since the record was created.
    #[must_use]
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }

    /// Time elapsed since the last state change or recorded poll.
    #[must_use]
    pub fn time_since_update(&self) -> Duration {
        self.last_updated.elapsed()
    }

    /// Builder: sets the parent task id.
    #[must_use]
    pub fn with_parent(mut self, parent: TaskId) -> Self {
        self.parent = Some(parent);
        self
    }

    /// Builder: attaches a location label; accepts `&str` or `String`.
    #[must_use]
    pub fn with_location(mut self, location: impl Into<String>) -> Self {
        self.location = Some(location.into());
        self
    }
}

impl fmt::Display for TaskInfo {
    /// One-line summary: id, name, state, poll count, run time and age (seconds, 2 dp).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Task {} [{}]: {} (polls: {}, runtime: {:.2}s, age: {:.2}s)",
            self.id,
            self.name,
            self.state,
            self.poll_count,
            self.total_run_time.as_secs_f64(),
            self.age().as_secs_f64()
        )
    }
}
#[derive(Debug, Clone, Default)]
pub struct TaskFilter {
pub state: Option<TaskState>,
pub name_pattern: Option<String>,
pub min_duration: Option<Duration>,
pub max_duration: Option<Duration>,
pub min_polls: Option<u64>,
pub max_polls: Option<u64>,
pub parent: Option<TaskId>,
pub root_only: bool,
}
impl TaskFilter {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_state(mut self, state: TaskState) -> Self {
self.state = Some(state);
self
}
#[must_use]
pub fn with_name_pattern(mut self, pattern: impl Into<String>) -> Self {
self.name_pattern = Some(pattern.into());
self
}
#[must_use]
pub fn with_min_duration(mut self, duration: Duration) -> Self {
self.min_duration = Some(duration);
self
}
#[must_use]
pub fn with_max_duration(mut self, duration: Duration) -> Self {
self.max_duration = Some(duration);
self
}
#[must_use]
pub fn with_min_polls(mut self, count: u64) -> Self {
self.min_polls = Some(count);
self
}
#[must_use]
pub fn with_max_polls(mut self, count: u64) -> Self {
self.max_polls = Some(count);
self
}
#[must_use]
pub fn with_parent(mut self, parent: TaskId) -> Self {
self.parent = Some(parent);
self
}
#[must_use]
pub fn root_only(mut self) -> Self {
self.root_only = true;
self
}
#[must_use]
pub fn matches(&self, task: &TaskInfo) -> bool {
if let Some(ref state) = self.state {
if !self.state_matches(&task.state, state) {
return false;
}
}
if let Some(ref pattern) = self.name_pattern {
if !task.name.to_lowercase().contains(&pattern.to_lowercase()) {
return false;
}
}
if let Some(min) = self.min_duration {
if task.age() < min {
return false;
}
}
if let Some(max) = self.max_duration {
if task.age() > max {
return false;
}
}
if let Some(min) = self.min_polls {
if task.poll_count < min {
return false;
}
}
if let Some(max) = self.max_polls {
if task.poll_count > max {
return false;
}
}
if let Some(parent) = self.parent {
if task.parent != Some(parent) {
return false;
}
}
if self.root_only && task.parent.is_some() {
return false;
}
true
}
fn state_matches(&self, task_state: &TaskState, filter_state: &TaskState) -> bool {
match (task_state, filter_state) {
(TaskState::Pending, TaskState::Pending) => true,
(TaskState::Running, TaskState::Running) => true,
(TaskState::Blocked { .. }, TaskState::Blocked { .. }) => true,
(TaskState::Completed, TaskState::Completed) => true,
(TaskState::Failed, TaskState::Failed) => true,
_ => false,
}
}
pub fn filter<'a>(&self, tasks: impl IntoIterator<Item = &'a TaskInfo>) -> Vec<&'a TaskInfo> {
tasks.into_iter().filter(|t| self.matches(t)).collect()
}
pub fn filter_cloned(&self, tasks: impl IntoIterator<Item = TaskInfo>) -> Vec<TaskInfo> {
tasks.into_iter().filter(|t| self.matches(t)).collect()
}
}
/// Key used by `sort_tasks` to order task records.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Default)]
pub enum TaskSortBy {
    /// Numeric task id (the default).
    #[default]
    Id,
    /// Task name, using plain `String` ordering (case-sensitive).
    Name,
    /// Task age, derived from `created_at`.
    Age,
    /// Number of recorded polls.
    Polls,
    /// Accumulated poll run time.
    RunTime,
    /// Lifecycle state, ranked by a fixed internal ordering.
    State,
}
/// Direction applied by `sort_tasks` after comparing by the chosen `TaskSortBy` key.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Default)]
pub enum SortDirection {
    /// Smallest key first (the default).
    #[default]
    Ascending,
    /// Largest key first.
    Descending,
}
/// Sorts `tasks` in place by the `sort_by` key, reversing the order for
/// `SortDirection::Descending`.
///
/// Uses the stable `slice::sort_by`, so tasks with equal keys keep their relative order.
/// For `TaskSortBy::Age`, "ascending" means smallest age (most recently created) first,
/// consistent with `TaskInfo::age`.
pub fn sort_tasks(tasks: &mut [TaskInfo], sort_by: TaskSortBy, direction: SortDirection) {
    tasks.sort_by(|a, b| {
        let cmp = match sort_by {
            TaskSortBy::Id => a.id.as_u64().cmp(&b.id.as_u64()),
            TaskSortBy::Name => a.name.cmp(&b.name),
            // Age grows as `created_at` gets earlier, so compare creation times in reverse:
            // a later `created_at` is a *smaller* age and must sort first when ascending.
            TaskSortBy::Age => b.created_at.cmp(&a.created_at),
            TaskSortBy::Polls => a.poll_count.cmp(&b.poll_count),
            TaskSortBy::RunTime => a.total_run_time.cmp(&b.total_run_time),
            TaskSortBy::State => state_order(&a.state).cmp(&state_order(&b.state)),
        };
        match direction {
            SortDirection::Ascending => cmp,
            SortDirection::Descending => cmp.reverse(),
        }
    });
}
/// Sort rank for `TaskSortBy::State`; lower ranks sort first when ascending.
///
/// Order: `Running` (0) < `Blocked` (1) < `Pending` (2) < `Completed` (3) < `Failed` (4).
fn state_order(state: &TaskState) -> u8 {
    match *state {
        TaskState::Pending => 2,
        TaskState::Completed => 3,
        TaskState::Failed => 4,
        TaskState::Blocked { .. } => 1,
        TaskState::Running => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_id_uniqueness() {
        let (first, second) = (TaskId::new(), TaskId::new());
        assert_ne!(first, second);
    }

    #[test]
    fn test_task_info_creation() {
        let info = TaskInfo::new(String::from("test_task"));
        assert_eq!(info.poll_count, 0);
        assert_eq!(info.state, TaskState::Pending);
        assert_eq!(info.name, "test_task");
    }

    #[test]
    fn test_task_state_update() {
        let mut info = TaskInfo::new(String::from("test"));
        info.update_state(TaskState::Running);
        assert_eq!(TaskState::Running, info.state);
    }

    #[test]
    fn test_task_poll_recording() {
        let poll = Duration::from_millis(100);
        let mut info = TaskInfo::new(String::from("test"));
        info.record_poll(poll);
        assert_eq!(info.total_run_time, poll);
        assert_eq!(info.poll_count, 1);
    }
}