use async_trait::async_trait;
use log::{debug, error, info, warn};
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::mpsc::UnboundedSender;
use tokio::time::{self, Duration, Instant};
use crate::backend::{ResultBackend, TaskMeta};
use crate::error::{ProtocolError, TaskError, TraceError};
use crate::protocol::Message;
use crate::task::{Request, ResultValue, Task, TaskEvent, TaskOptions, TaskStatus};
use crate::Celery;
/// Wraps a single task instance together with the channel used to report its
/// status changes and the (optional) result backend its state is written to.
pub(super) struct Tracer<T: Task> {
    /// The task instance being traced.
    task: T,
    /// Sender for `TaskEvent`s (status changes) emitted while tracing.
    event_tx: UnboundedSender<TaskEvent>,
    /// Result backend for task metadata; `None` means nothing is persisted.
    backend: Option<Arc<dyn ResultBackend>>,
}
impl<T> Tracer<T>
where
    T: Task,
{
    /// Creates a tracer for `task`, logging that the task was received
    /// (including its ETA when the request carries one).
    fn new(
        task: T,
        event_tx: UnboundedSender<TaskEvent>,
        backend: Option<Arc<dyn ResultBackend>>,
    ) -> Self {
        if let Some(eta) = task.request().eta {
            info!(
                "Task {}[{}] received, ETA: {}",
                task.name(),
                task.request().id,
                eta
            );
        } else {
            info!("Task {}[{}] received", task.name(), task.request().id);
        }
        Self {
            task,
            event_tx,
            backend,
        }
    }

    /// The id of the traced task, taken from its request.
    fn task_id(&self) -> &str {
        &self.task.request().id
    }

    /// Stores `meta` in the result backend, if one is configured.
    ///
    /// Storage is best-effort: a backend error is logged and otherwise ignored,
    /// so a failing backend never changes the outcome of the trace.
    async fn persist_meta(&self, meta: TaskMeta) {
        if let Some(backend) = &self.backend {
            if let Err(err) = backend.store_task_meta(meta).await {
                error!("Failed storing result for task {}: {}", self.task_id(), err);
            }
        }
    }

    /// Builds the success meta for `returned`.
    ///
    /// If `returned` cannot be serialized, the error is logged and the value's
    /// `Debug` representation is stored as a JSON string instead. The task still
    /// counts as a success; only the stored payload is degraded.
    fn build_success_meta(&self, returned: &T::Returns) -> TaskMeta
    where
        T::Returns: ResultValue,
    {
        match TaskMeta::success(self.task_id(), returned) {
            Ok(meta) => meta,
            Err(err) => {
                error!(
                    "Failed to serialize result for task {}: {}",
                    self.task_id(),
                    err
                );
                // Start from a success meta built with a trivially-serializable
                // payload, then swap in the textual representation of the value.
                let mut fallback = TaskMeta::success(self.task_id(), &())
                    .expect("serializing the unit value `()` as a task result cannot fail");
                fallback.result = Some(Value::String(format!("{:?}", returned)));
                fallback
            }
        }
    }
}
#[async_trait]
impl<T> TracerTrait for Tracer<T>
where
    T: Task,
{
    /// Runs the task once, emitting status events and persisting its state.
    ///
    /// 1. Discards the task if its request has expired.
    /// 2. Emits a `Pending` status event and stores a "started" meta.
    /// 3. Runs the task, enforcing `time_limit()` (in seconds) when set.
    /// 4. On success: calls `on_success`, emits `Finished`, stores the result.
    /// 5. On error: calls `on_failure`, emits `Finished`, decides whether the
    ///    task will be retried, then stores a retry *or* failure meta that
    ///    matches that decision.
    ///
    /// # Errors
    ///
    /// * `TraceError::ExpirationError` if the request expired before running.
    /// * `TraceError::TaskError` if the task failed and will not be retried
    ///   (non-retryable error, or `max_retries` exhausted).
    /// * `TraceError::Retry` carrying the effective ETA when the task should be
    ///   retried.
    async fn trace(&mut self) -> Result<(), TraceError> {
        if self.is_expired() {
            warn!(
                "Task {}[{}] expired, discarding",
                self.task.name(),
                &self.task.request().id,
            );
            return Err(TraceError::ExpirationError);
        }

        self.event_tx
            .send(TaskEvent::StatusChange(TaskStatus::Pending))
            .unwrap_or_else(|_| {
                error!("Failed sending task event");
            });
        self.persist_meta(TaskMeta::started(self.task_id())).await;

        let start = Instant::now();
        let result = match self.task.time_limit() {
            Some(secs) => {
                debug!("Executing task with {} second time limit", secs);
                let duration = Duration::from_secs(secs as u64);
                time::timeout(duration, self.task.run(self.task.request().params.clone()))
                    .await
                    .unwrap_or(Err(TaskError::TimeoutError))
            }
            None => self.task.run(self.task.request().params.clone()).await,
        };
        let duration = start.elapsed();

        match result {
            Ok(returned) => {
                info!(
                    "Task {}[{}] succeeded in {}s: {:?}",
                    self.task.name(),
                    &self.task.request().id,
                    duration.as_secs_f32(),
                    returned
                );
                self.task.on_success(&returned).await;
                self.event_tx
                    .send(TaskEvent::StatusChange(TaskStatus::Finished))
                    .unwrap_or_else(|_| {
                        error!("Failed sending task event");
                    });
                let meta = self.build_success_meta(&returned);
                self.persist_meta(meta).await;
                Ok(())
            }
            Err(e) => {
                // Whether this kind of error is eligible for a retry at all, and
                // any ETA explicitly requested by `TaskError::Retry`.
                let (should_retry, retry_eta) = match e {
                    TaskError::ExpectedError(ref reason) => {
                        warn!(
                            "Task {}[{}] failed with expected error: {}",
                            self.task.name(),
                            &self.task.request().id,
                            reason
                        );
                        (true, None)
                    }
                    TaskError::UnexpectedError(ref reason) => {
                        error!(
                            "Task {}[{}] failed with unexpected error: {}",
                            self.task.name(),
                            &self.task.request().id,
                            reason
                        );
                        (self.task.retry_for_unexpected(), None)
                    }
                    TaskError::TimeoutError => {
                        error!(
                            "Task {}[{}] timed out after {}s",
                            self.task.name(),
                            &self.task.request().id,
                            duration.as_secs_f32(),
                        );
                        (true, None)
                    }
                    TaskError::Retry(eta) => {
                        error!(
                            "Task {}[{}] triggered retry",
                            self.task.name(),
                            &self.task.request().id,
                        );
                        (true, eta)
                    }
                };

                self.task.on_failure(&e).await;
                self.event_tx
                    .send(TaskEvent::StatusChange(TaskStatus::Finished))
                    .unwrap_or_else(|_| {
                        error!("Failed sending task event");
                    });

                // Decide the final outcome *before* writing to the backend so the
                // stored state matches what the worker actually does. A task that
                // will be retried must not be recorded as a failure, and a task
                // whose retries are exhausted must not be left in a retry state.
                let retries = self.task.request().retries;
                let will_retry = if should_retry {
                    match self.task.max_retries() {
                        Some(max_retries) if retries >= max_retries => {
                            warn!(
                                "Task {}[{}] retries exceeded",
                                self.task.name(),
                                &self.task.request().id,
                            );
                            false
                        }
                        Some(max_retries) => {
                            info!(
                                "Task {}[{}] retrying ({} / {})",
                                self.task.name(),
                                &self.task.request().id,
                                retries + 1,
                                max_retries,
                            );
                            true
                        }
                        None => {
                            info!(
                                "Task {}[{}] retrying ({} / inf)",
                                self.task.name(),
                                &self.task.request().id,
                                retries + 1,
                            );
                            true
                        }
                    }
                } else {
                    false
                };

                if !will_retry {
                    self.persist_meta(TaskMeta::failure(self.task_id(), &e))
                        .await;
                    return Err(TraceError::TaskError(e));
                }

                // An ETA requested via `TaskError::Retry` wins over the task's
                // configured retry ETA; store the same ETA the worker will use.
                let eta = retry_eta.or_else(|| self.task.retry_eta());
                self.persist_meta(TaskMeta::retry(self.task_id(), &e, eta, retries + 1))
                    .await;
                Err(TraceError::Retry(eta))
            }
        }
    }

    /// Sleeps for the request's countdown, if it has one.
    async fn wait(&self) {
        if let Some(countdown) = self.task.request().countdown() {
            time::sleep(countdown).await;
        }
    }

    /// Delegates to the request's own `is_delayed` check.
    fn is_delayed(&self) -> bool {
        self.task.request().is_delayed()
    }

    /// Delegates to the request's own `is_expired` check.
    fn is_expired(&self) -> bool {
        self.task.request().is_expired()
    }

    /// Returns the task's `acks_late` setting.
    fn acks_late(&self) -> bool {
        self.task.acks_late()
    }
}
/// Type-erased interface over a traced task, so callers can hold tracers for
/// different task types behind `Box<dyn TracerTrait>`.
#[async_trait]
pub(super) trait TracerTrait: Send + Sync {
    /// Whether the underlying request reports itself as expired.
    fn is_expired(&self) -> bool;

    /// Whether the underlying request reports itself as delayed.
    fn is_delayed(&self) -> bool;

    /// Waits until the task is due to run.
    async fn wait(&self);

    /// Executes the task; an `Err` describes why it did not complete
    /// successfully (expired, failed, or should be retried).
    async fn trace(&mut self) -> Result<(), TraceError>;

    /// Whether the task's `acks_late` option is enabled.
    fn acks_late(&self) -> bool;
}
/// Result of building a tracer: either a boxed, type-erased tracer, or the
/// `ProtocolError` hit while turning the incoming message into a request.
pub(super) type TraceBuilderResult =
    Result<Box<dyn TracerTrait>, ProtocolError>;
/// Boxed, thread-safe factory that produces a tracer from a raw message.
///
/// Arguments, in order: the app, the incoming message, the task options, the
/// sender for task events, and a hostname string (`build_tracer` stores it on
/// the decoded request).
pub(super) type TraceBuilder = Box<
    dyn Fn(
            Arc<Celery>,
            Message,
            TaskOptions,
            UnboundedSender<TaskEvent>,
            String,
        ) -> TraceBuilderResult
        + Send
        + Sync
        + 'static,
>;
pub(super) fn build_tracer<T: Task + Send + 'static>(
app: Arc<Celery>,
message: Message,
mut options: TaskOptions,
event_tx: UnboundedSender<TaskEvent>,
hostname: String,
) -> TraceBuilderResult {
let mut request = Request::<T>::try_from_message(app.clone(), message)?;
request.hostname = Some(hostname);
T::DEFAULTS.override_other(&mut options);
let task = T::from_request(request, options);
Ok(Box::new(Tracer::<T>::new(
task,
event_tx,
app.result_backend(),
)))
}