Skip to main content

kojin_core/middleware/
tracing_mw.rs

1use async_trait::async_trait;
2
3use super::Middleware;
4use crate::error::KojinError;
5use crate::message::TaskMessage;
6
/// Middleware that logs task lifecycle events through the `tracing` crate.
///
/// Each hook emits a single `tracing` *event* (not a span):
/// - `before`: an `INFO` event with message `"Task starting"`;
/// - `after`: an `INFO` event with message `"Task completed"`;
/// - `on_error`: an `ERROR` event with message `"Task failed"`.
///
/// The middleware is stateless, and none of its hooks ever return an error.
#[derive(Debug, Default)]
pub struct TracingMiddleware;
10
11#[async_trait]
12impl Middleware for TracingMiddleware {
13    async fn before(&self, message: &TaskMessage) -> Result<(), KojinError> {
14        tracing::info!(
15            task_id = %message.id,
16            task_name = %message.task_name,
17            queue = %message.queue,
18            retries = message.retries,
19            "Task starting"
20        );
21        Ok(())
22    }
23
24    async fn after(
25        &self,
26        message: &TaskMessage,
27        _result: &serde_json::Value,
28    ) -> Result<(), KojinError> {
29        tracing::info!(
30            task_id = %message.id,
31            task_name = %message.task_name,
32            "Task completed"
33        );
34        Ok(())
35    }
36
37    async fn on_error(&self, message: &TaskMessage, error: &KojinError) -> Result<(), KojinError> {
38        tracing::error!(
39            task_id = %message.id,
40            task_name = %message.task_name,
41            error = %error,
42            retries = message.retries,
43            max_retries = message.max_retries,
44            "Task failed"
45        );
46        Ok(())
47    }
48}
49
#[cfg(test)]
mod tests {
    use super::*;

    /// All three hooks only log, so each must report success for an
    /// ordinary message, an arbitrary result, and an arbitrary failure.
    #[tokio::test]
    async fn tracing_middleware_does_not_error() {
        let middleware = TracingMiddleware;
        let message = TaskMessage::new("test_task", "default", serde_json::json!({}));
        let outcome = serde_json::json!("ok");
        let failure = KojinError::TaskFailed("err".into());

        let before_result = middleware.before(&message).await;
        assert!(before_result.is_ok());

        let after_result = middleware.after(&message, &outcome).await;
        assert!(after_result.is_ok());

        let on_error_result = middleware.on_error(&message, &failure).await;
        assert!(on_error_result.is_ok());
    }
}