Skip to main content

kojin_core/middleware/
metrics.rs

1use async_trait::async_trait;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4
5use super::Middleware;
6use crate::error::KojinError;
7use crate::message::TaskMessage;
8
9/// Simple metrics middleware that tracks task counts.
10#[derive(Debug, Clone)]
11pub struct MetricsMiddleware {
12    inner: Arc<MetricsInner>,
13}
14
/// Counter storage behind `MetricsMiddleware`: one atomic tally per task outcome.
#[derive(Debug)]
struct MetricsInner {
    /// Incremented by the `before` hook.
    tasks_started: AtomicU64,
    /// Incremented by the `after` hook.
    tasks_succeeded: AtomicU64,
    /// Incremented by the `on_error` hook.
    tasks_failed: AtomicU64,
}
21
22impl MetricsMiddleware {
23    pub fn new() -> Self {
24        Self {
25            inner: Arc::new(MetricsInner {
26                tasks_started: AtomicU64::new(0),
27                tasks_succeeded: AtomicU64::new(0),
28                tasks_failed: AtomicU64::new(0),
29            }),
30        }
31    }
32
33    pub fn tasks_started(&self) -> u64 {
34        self.inner.tasks_started.load(Ordering::Relaxed)
35    }
36
37    pub fn tasks_succeeded(&self) -> u64 {
38        self.inner.tasks_succeeded.load(Ordering::Relaxed)
39    }
40
41    pub fn tasks_failed(&self) -> u64 {
42        self.inner.tasks_failed.load(Ordering::Relaxed)
43    }
44}
45
46impl Default for MetricsMiddleware {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52#[async_trait]
53impl Middleware for MetricsMiddleware {
54    async fn before(&self, _message: &TaskMessage) -> Result<(), KojinError> {
55        self.inner.tasks_started.fetch_add(1, Ordering::Relaxed);
56        Ok(())
57    }
58
59    async fn after(
60        &self,
61        _message: &TaskMessage,
62        _result: &serde_json::Value,
63    ) -> Result<(), KojinError> {
64        self.inner.tasks_succeeded.fetch_add(1, Ordering::Relaxed);
65        Ok(())
66    }
67
68    async fn on_error(
69        &self,
70        _message: &TaskMessage,
71        _error: &KojinError,
72    ) -> Result<(), KojinError> {
73        self.inner.tasks_failed.fetch_add(1, Ordering::Relaxed);
74        Ok(())
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives each hook and checks that every counter moves independently.
    #[tokio::test]
    async fn metrics_increments() {
        let mw = MetricsMiddleware::new();
        let msg = TaskMessage::new("test", "default", serde_json::json!({}));

        mw.before(&msg).await.unwrap();
        mw.before(&msg).await.unwrap();
        mw.after(&msg, &serde_json::json!("ok")).await.unwrap();
        mw.on_error(&msg, &KojinError::TaskFailed("err".into()))
            .await
            .unwrap();

        assert_eq!(mw.tasks_started(), 2);
        assert_eq!(mw.tasks_succeeded(), 1);
        assert_eq!(mw.tasks_failed(), 1);
    }

    /// A freshly constructed middleware reports zero for every counter.
    #[test]
    fn starts_at_zero() {
        let mw = MetricsMiddleware::default();

        assert_eq!(mw.tasks_started(), 0);
        assert_eq!(mw.tasks_succeeded(), 0);
        assert_eq!(mw.tasks_failed(), 0);
    }

    /// Clones share their counters because the state lives behind an `Arc`.
    #[tokio::test]
    async fn clones_share_counters() {
        let mw = MetricsMiddleware::new();
        let clone = mw.clone();
        let msg = TaskMessage::new("test", "default", serde_json::json!({}));

        clone.before(&msg).await.unwrap();
        clone.after(&msg, &serde_json::json!(null)).await.unwrap();

        assert_eq!(mw.tasks_started(), 1);
        assert_eq!(mw.tasks_succeeded(), 1);
        assert_eq!(mw.tasks_failed(), 0);
    }
}
98}