use super::*;
use parking_lot::RwLock;
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::sync::Arc;
/// Interface shared by the task backends in this file: accept a task
/// envelope, then answer status and result queries by task id.
///
/// Implementors must be `Send + Sync`.
pub trait TaskBackend: Send + Sync {
/// Submits `envelope` to the backend.
///
/// Returns the id of the submitted task on success, or an error message
/// when the backend rejects the submission.
fn enqueue(&self, envelope: TaskEnvelope) -> Result<TaskId, String>;
/// Returns the status the backend currently holds for `id`, or `None`
/// when it has no status for that id.
fn get_status(&self, id: &TaskId) -> Option<TaskStatus>;
/// Returns the result the backend holds for `id`, or `None` when it has
/// no result for that id.
fn get_result(&self, id: &TaskId) -> Option<TaskResult>;
}
/// Drives `future` to completion from synchronous code and returns its output.
///
/// A fresh current-thread Tokio runtime is built for every call. When the
/// caller is already inside a Tokio runtime, that new runtime is built and run
/// on a dedicated OS thread and the caller blocks on the thread's `join`;
/// otherwise the new runtime runs directly on the calling thread.
///
/// # Panics
///
/// Panics if the runtime cannot be built. A panic raised on the helper thread
/// is re-raised on the calling thread with its original payload.
fn block_on_sync<F>(future: F) -> F::Output
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    /// Builds a throwaway current-thread runtime and runs `work` on it.
    fn drive<T: Future>(work: T) -> T::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("task backend runtime should build")
            .block_on(work)
    }

    // No ambient runtime: run right here on the calling thread.
    if tokio::runtime::Handle::try_current().is_err() {
        return drive(future);
    }

    // Already inside a runtime: run on a separate thread and wait for it.
    std::thread::spawn(move || drive(future))
        .join()
        .unwrap_or_else(|payload| std::panic::resume_unwind(payload))
}
/// Backend that runs each task inside `enqueue` and keeps the per-task
/// status and result in in-memory maps guarded by read/write locks.
pub struct ImmediateBackend {
// Registry used to look up a task implementation by the envelope's task name.
registry: Arc<TaskRegistry>,
// Result recorded for each task id.
results: RwLock<HashMap<TaskId, TaskResult>>,
// Most recent status recorded for each task id.
statuses: RwLock<HashMap<TaskId, TaskStatus>>,
}
impl ImmediateBackend {
    /// Creates a backend that resolves task names through `registry`.
    ///
    /// The status and result tables both start out empty.
    #[must_use]
    pub fn new(registry: Arc<TaskRegistry>) -> Self {
        let statuses = RwLock::new(HashMap::new());
        let results = RwLock::new(HashMap::new());
        Self {
            registry,
            results,
            statuses,
        }
    }
}
impl TaskBackend for ImmediateBackend {
fn enqueue(&self, mut envelope: TaskEnvelope) -> Result<TaskId, String> {
envelope.status = TaskStatus::Pending;
let id = envelope.id.clone();
BEFORE_TASK_PUBLISH.send(envelope.clone());
self.statuses
.write()
.insert(id.clone(), TaskStatus::Pending);
AFTER_TASK_PUBLISH.send(envelope.clone());
let Some(task) = self.registry.get(&envelope.task_name) else {
let error = format!("task '{}' is not registered", envelope.task_name);
let status = TaskStatus::Failed(error.clone());
self.statuses.write().insert(id.clone(), status.clone());
self.results.write().insert(id.clone(), Err(error.clone()));
return Err(error);
};
let max_retries = task.max_retries();
let mut attempt = 0;
loop {
attempt += 1;
let running = TaskStatus::Running;
self.statuses.write().insert(id.clone(), running.clone());
let mut running_envelope = envelope.clone();
running_envelope.status = running;
TASK_PRERUN.send(running_envelope);
match block_on_sync(task.run(envelope.args.clone())) {
Ok(value) => {
let status = TaskStatus::Completed;
self.results.write().insert(id.clone(), Ok(value));
self.statuses.write().insert(id.clone(), status.clone());
let mut completed_envelope = envelope.clone();
completed_envelope.status = status;
TASK_POSTRUN.send(completed_envelope);
return Ok(id);
}
Err(_error) if attempt <= max_retries => {
let status = TaskStatus::Retrying {
attempt,
max_retries,
};
self.statuses.write().insert(id.clone(), status.clone());
let mut retrying_envelope = envelope.clone();
retrying_envelope.status = status;
TASK_POSTRUN.send(retrying_envelope);
}
Err(error) => {
let status = TaskStatus::Failed(error.clone());
self.results.write().insert(id.clone(), Err(error));
self.statuses.write().insert(id.clone(), status.clone());
let mut failed_envelope = envelope.clone();
failed_envelope.status = status;
TASK_POSTRUN.send(failed_envelope);
return Ok(id);
}
}
}
}
fn get_status(&self, id: &TaskId) -> Option<TaskStatus> {
self.statuses.read().get(id).cloned()
}
fn get_result(&self, id: &TaskId) -> Option<TaskResult> {
self.results.read().get(id).cloned()
}
}
/// Backend that only records submitted envelopes in an in-memory queue;
/// nothing in this file ever runs the queued tasks.
pub struct DummyBackend {
// Submitted envelopes in submission order (new entries are pushed to the back).
queue: RwLock<VecDeque<TaskEnvelope>>,
}
impl DummyBackend {
    /// Creates a backend whose queue is empty.
    #[must_use]
    pub fn new() -> Self {
        let queue = RwLock::new(VecDeque::new());
        Self { queue }
    }

    /// Returns copies of all queued envelopes, front to back, leaving the
    /// queue itself untouched.
    #[must_use]
    pub fn queued_tasks(&self) -> Vec<TaskEnvelope> {
        let pending = self.queue.read();
        let mut snapshot = Vec::with_capacity(pending.len());
        snapshot.extend(pending.iter().cloned());
        snapshot
    }

    /// Removes every queued envelope and returns them front to back; the
    /// queue is empty afterwards.
    #[must_use]
    pub fn drain(&self) -> Vec<TaskEnvelope> {
        let mut pending = self.queue.write();
        pending.drain(..).collect()
    }
}
impl Default for DummyBackend {
fn default() -> Self {
Self::new()
}
}
impl TaskBackend for DummyBackend {
    /// Marks `envelope` as `Pending`, appends it to the queue and returns its
    /// id. The task itself is not run.
    ///
    /// `BEFORE_TASK_PUBLISH` is sent before the envelope is queued and
    /// `AFTER_TASK_PUBLISH` afterwards. This implementation always returns
    /// `Ok`.
    fn enqueue(&self, mut envelope: TaskEnvelope) -> Result<TaskId, String> {
        envelope.status = TaskStatus::Pending;
        let task_id = envelope.id.clone();

        BEFORE_TASK_PUBLISH.send(envelope.clone());
        {
            // Keep the write lock only for the push itself.
            let mut pending = self.queue.write();
            pending.push_back(envelope.clone());
        }
        AFTER_TASK_PUBLISH.send(envelope);

        Ok(task_id)
    }

    /// Returns the status of the first queued envelope whose id equals `id`,
    /// or `None` when no queued envelope has that id.
    fn get_status(&self, id: &TaskId) -> Option<TaskStatus> {
        let pending = self.queue.read();
        pending
            .iter()
            .find_map(|queued| (queued.id == *id).then(|| queued.status.clone()))
    }

    /// Always returns `None`: this backend never stores a result.
    fn get_result(&self, _id: &TaskId) -> Option<TaskResult> {
        None
    }
}
// Unit tests for `DummyBackend` and `ImmediateBackend`, using three small
// task fixtures (echo, flaky, always-failing) and an envelope builder.
#[cfg(test)]
mod tests {
use super::*;
use chrono::Utc;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
// Fixture task named "echo": returns the "message" argument, or "missing"
// when that argument is absent.
struct EchoTask;
impl Task for EchoTask {
fn name(&self) -> &str {
"echo"
}
fn run(&self, args: TaskArgs) -> Pin<Box<dyn Future<Output = TaskResult> + Send>> {
Box::pin(async move {
Ok(args
.data
.get("message")
.cloned()
.unwrap_or_else(|| "missing".to_string()))
})
}
}
// Fixture task named "flaky": its first run fails with "try again", every
// later run succeeds with "eventually". The shared counter records how many
// times it was run.
struct FlakyTask {
attempts: Arc<AtomicUsize>,
}
impl Task for FlakyTask {
fn name(&self) -> &str {
"flaky"
}
fn run(&self, _args: TaskArgs) -> Pin<Box<dyn Future<Output = TaskResult> + Send>> {
let attempts = Arc::clone(&self.attempts);
Box::pin(async move {
// `fetch_add` returns the previous value, so only the very first run sees 0.
if attempts.fetch_add(1, Ordering::SeqCst) == 0 {
Err("try again".to_string())
} else {
Ok("eventually".to_string())
}
})
}
// One retry is enough for the second run to succeed.
fn max_retries(&self) -> u32 {
1
}
}
// Fixture task named "fail": every run fails with "boom". It does not
// override `max_retries`, so whatever the `Task` trait provides applies
// (not visible in this file).
struct FailingTask;
impl Task for FailingTask {
fn name(&self) -> &str {
"fail"
}
fn run(&self, _args: TaskArgs) -> Pin<Box<dyn Future<Output = TaskResult> + Send>> {
Box::pin(async { Err("boom".to_string()) })
}
}
// Builds a pending, normal-priority envelope for `task_name` with default
// arguments and the id "<task_name>-1".
fn envelope(task_name: &str) -> TaskEnvelope {
TaskEnvelope {
id: format!("{task_name}-1"),
task_name: task_name.to_string(),
args: TaskArgs::default(),
status: TaskStatus::Pending,
priority: TaskPriority::Normal,
created_at: Utc::now(),
}
}
// Enqueueing on the dummy backend stores the envelope as pending and
// produces no result.
#[test]
fn test_dummy_backend_enqueues() {
let backend = DummyBackend::new();
let queued_id = backend
.enqueue(envelope("noop"))
.expect("dummy backend should enqueue tasks");
assert_eq!(queued_id, "noop-1");
assert_eq!(backend.queued_tasks().len(), 1);
assert_eq!(backend.get_status(&queued_id), Some(TaskStatus::Pending));
assert_eq!(backend.get_result(&queued_id), None);
}
// Draining returns everything that was queued and leaves the backend with
// no queued tasks and no status for the drained ids.
#[test]
fn test_dummy_backend_drain() {
let backend = DummyBackend::new();
backend
.enqueue(envelope("first"))
.expect("first task should enqueue");
backend
.enqueue(envelope("second"))
.expect("second task should enqueue");
let drained = backend.drain();
assert_eq!(drained.len(), 2);
assert!(backend.queued_tasks().is_empty());
assert_eq!(backend.get_status(&"first-1".to_string()), None);
}
// The immediate backend runs a registered task during `enqueue` and records
// both its completion status and its return value.
#[test]
fn test_immediate_backend_executes_task() {
let mut registry = TaskRegistry::new();
registry.register(EchoTask);
let backend = ImmediateBackend::new(Arc::new(registry));
let mut args = TaskArgs::default();
args.data
.insert("message".to_string(), "hello tasks".to_string());
let id = backend
.enqueue(TaskEnvelope {
args,
..envelope("echo")
})
.expect("immediate backend should execute tasks");
assert_eq!(backend.get_status(&id), Some(TaskStatus::Completed));
assert_eq!(backend.get_result(&id), Some(Ok("hello tasks".to_string())));
}
// A task queued on the dummy backend stays pending, while a failing task on
// the immediate backend still yields its id and ends up `Failed` with the
// task's error as both status payload and result.
#[test]
fn test_task_status_transitions() {
let mut registry = TaskRegistry::new();
registry.register(FailingTask);
let backend = ImmediateBackend::new(Arc::new(registry));
let dummy = DummyBackend::new();
let pending_id = dummy
.enqueue(envelope("queued"))
.expect("dummy backend should accept queued tasks");
let failed_id = backend
.enqueue(envelope("fail"))
.expect("failed executions should still return task ids");
assert_eq!(dummy.get_status(&pending_id), Some(TaskStatus::Pending));
assert_eq!(
backend.get_status(&failed_id),
Some(TaskStatus::Failed("boom".to_string()))
);
assert_eq!(
backend.get_result(&failed_id),
Some(Err("boom".to_string()))
);
}
// A task that fails once and allows one retry is run exactly twice and
// finishes as completed with the second run's value.
#[test]
fn test_immediate_backend_retries_before_completion() {
let attempts = Arc::new(AtomicUsize::new(0));
let mut registry = TaskRegistry::new();
registry.register(FlakyTask {
attempts: Arc::clone(&attempts),
});
let backend = ImmediateBackend::new(Arc::new(registry));
let id = backend
.enqueue(envelope("flaky"))
.expect("retrying task should complete");
assert_eq!(attempts.load(Ordering::SeqCst), 2);
assert_eq!(backend.get_status(&id), Some(TaskStatus::Completed));
assert_eq!(backend.get_result(&id), Some(Ok("eventually".to_string())));
}
}