Skip to main content

kojin_core/
memory_result_backend.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Mutex;
4use std::time::Duration;
5
6use crate::error::{KojinError, TaskResult};
7use crate::result_backend::ResultBackend;
8use crate::task_id::TaskId;
9
10/// In-memory result backend for development and testing.
11///
12/// Stores task results in a `HashMap` protected by a `std::sync::Mutex`.
13/// This is cheap to construct and requires no external services, but results
14/// are lost when the process exits and the `Mutex` may become a bottleneck
15/// under very high concurrency. For production use, prefer
16/// `RedisResultBackend` (from `kojin-redis`) or
17/// `PostgresResultBackend` (from `kojin-postgres`).
18#[derive(Debug, Default)]
19pub struct MemoryResultBackend {
20    results: Mutex<HashMap<String, serde_json::Value>>,
21    groups: Mutex<HashMap<String, GroupState>>,
22}
23
24#[derive(Debug, Clone)]
25struct GroupState {
26    #[allow(dead_code)]
27    total: u32,
28    completed: u32,
29    results: Vec<serde_json::Value>,
30}
31
32impl MemoryResultBackend {
33    /// Create a new, empty in-memory result backend.
34    pub fn new() -> Self {
35        Self::default()
36    }
37}
38
39#[async_trait]
40impl ResultBackend for MemoryResultBackend {
41    async fn store(&self, id: &TaskId, result: &serde_json::Value) -> TaskResult<()> {
42        self.results
43            .lock()
44            .unwrap()
45            .insert(id.to_string(), result.clone());
46        Ok(())
47    }
48
49    async fn get(&self, id: &TaskId) -> TaskResult<Option<serde_json::Value>> {
50        Ok(self.results.lock().unwrap().get(&id.to_string()).cloned())
51    }
52
53    async fn wait(&self, id: &TaskId, timeout: Duration) -> TaskResult<serde_json::Value> {
54        let deadline = tokio::time::Instant::now() + timeout;
55        loop {
56            if let Some(result) = self.get(id).await? {
57                return Ok(result);
58            }
59            if tokio::time::Instant::now() >= deadline {
60                return Err(KojinError::Timeout(timeout));
61            }
62            tokio::time::sleep(Duration::from_millis(50)).await;
63        }
64    }
65
66    async fn delete(&self, id: &TaskId) -> TaskResult<()> {
67        self.results.lock().unwrap().remove(&id.to_string());
68        Ok(())
69    }
70
71    async fn init_group(&self, group_id: &str, total: u32) -> TaskResult<()> {
72        self.groups.lock().unwrap().insert(
73            group_id.to_string(),
74            GroupState {
75                total,
76                completed: 0,
77                results: Vec::new(),
78            },
79        );
80        Ok(())
81    }
82
83    async fn complete_group_member(
84        &self,
85        group_id: &str,
86        _task_id: &TaskId,
87        result: &serde_json::Value,
88    ) -> TaskResult<u32> {
89        let mut groups = self.groups.lock().unwrap();
90        let state = groups
91            .get_mut(group_id)
92            .ok_or_else(|| KojinError::ResultBackend(format!("group not found: {group_id}")))?;
93        state.completed += 1;
94        state.results.push(result.clone());
95        Ok(state.completed)
96    }
97
98    async fn get_group_results(&self, group_id: &str) -> TaskResult<Vec<serde_json::Value>> {
99        let groups = self.groups.lock().unwrap();
100        let state = groups
101            .get(group_id)
102            .ok_or_else(|| KojinError::ResultBackend(format!("group not found: {group_id}")))?;
103        Ok(state.results.clone())
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn store_and_get() {
        let backend = MemoryResultBackend::new();
        let id = TaskId::new();
        let value = serde_json::json!({"result": 42});

        backend.store(&id, &value).await.unwrap();
        let got = backend.get(&id).await.unwrap();
        assert_eq!(got, Some(value));
    }

    #[tokio::test]
    async fn store_overwrites_previous_result() {
        let backend = MemoryResultBackend::new();
        let id = TaskId::new();

        backend.store(&id, &serde_json::json!(1)).await.unwrap();
        backend.store(&id, &serde_json::json!(2)).await.unwrap();
        assert_eq!(backend.get(&id).await.unwrap(), Some(serde_json::json!(2)));
    }

    #[tokio::test]
    async fn get_missing() {
        let backend = MemoryResultBackend::new();
        let id = TaskId::new();
        assert_eq!(backend.get(&id).await.unwrap(), None);
    }

    #[tokio::test]
    async fn delete_result() {
        let backend = MemoryResultBackend::new();
        let id = TaskId::new();
        backend.store(&id, &serde_json::json!(1)).await.unwrap();
        backend.delete(&id).await.unwrap();
        assert_eq!(backend.get(&id).await.unwrap(), None);
    }

    #[tokio::test]
    async fn wait_for_result() {
        let backend = std::sync::Arc::new(MemoryResultBackend::new());
        let id = TaskId::new();

        let b = backend.clone();
        let id2 = id;
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            b.store(&id2, &serde_json::json!("done")).await.unwrap();
        });

        let result = backend.wait(&id, Duration::from_secs(2)).await.unwrap();
        assert_eq!(result, serde_json::json!("done"));
    }

    #[tokio::test]
    async fn wait_timeout() {
        let backend = MemoryResultBackend::new();
        let id = TaskId::new();
        let result = backend.wait(&id, Duration::from_millis(100)).await;
        assert!(matches!(result, Err(KojinError::Timeout(_))));
    }

    #[tokio::test]
    async fn group_lifecycle() {
        let backend = MemoryResultBackend::new();
        backend.init_group("g1", 3).await.unwrap();

        let id1 = TaskId::new();
        let id2 = TaskId::new();
        let id3 = TaskId::new();

        let c1 = backend
            .complete_group_member("g1", &id1, &serde_json::json!(1))
            .await
            .unwrap();
        assert_eq!(c1, 1);

        let c2 = backend
            .complete_group_member("g1", &id2, &serde_json::json!(2))
            .await
            .unwrap();
        assert_eq!(c2, 2);

        let c3 = backend
            .complete_group_member("g1", &id3, &serde_json::json!(3))
            .await
            .unwrap();
        assert_eq!(c3, 3);

        // Results are returned in the order members completed.
        let results = backend.get_group_results("g1").await.unwrap();
        assert_eq!(
            results,
            vec![
                serde_json::json!(1),
                serde_json::json!(2),
                serde_json::json!(3)
            ]
        );
    }

    #[tokio::test]
    async fn init_group_resets_existing_state() {
        let backend = MemoryResultBackend::new();
        backend.init_group("g1", 2).await.unwrap();
        backend
            .complete_group_member("g1", &TaskId::new(), &serde_json::json!(1))
            .await
            .unwrap();

        backend.init_group("g1", 2).await.unwrap();
        assert!(backend.get_group_results("g1").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn unknown_group_is_an_error() {
        let backend = MemoryResultBackend::new();
        let id = TaskId::new();

        let completed = backend
            .complete_group_member("missing", &id, &serde_json::json!(1))
            .await;
        assert!(matches!(completed, Err(KojinError::ResultBackend(_))));

        let results = backend.get_group_results("missing").await;
        assert!(matches!(results, Err(KojinError::ResultBackend(_))));
    }
}