// kojin_core/memory_result_backend.rs
use std::collections::HashMap;
use std::sync::{Mutex, MutexGuard};
use std::time::Duration;

use async_trait::async_trait;

use crate::error::{KojinError, TaskResult};
use crate::result_backend::ResultBackend;
use crate::task_id::TaskId;

10#[derive(Debug, Default)]
19pub struct MemoryResultBackend {
20 results: Mutex<HashMap<String, serde_json::Value>>,
21 groups: Mutex<HashMap<String, GroupState>>,
22}
23
24#[derive(Debug, Clone)]
25struct GroupState {
26 #[allow(dead_code)]
27 total: u32,
28 completed: u32,
29 results: Vec<serde_json::Value>,
30}
31
32impl MemoryResultBackend {
33 pub fn new() -> Self {
35 Self::default()
36 }
37}
38
39#[async_trait]
40impl ResultBackend for MemoryResultBackend {
41 async fn store(&self, id: &TaskId, result: &serde_json::Value) -> TaskResult<()> {
42 self.results
43 .lock()
44 .unwrap()
45 .insert(id.to_string(), result.clone());
46 Ok(())
47 }
48
49 async fn get(&self, id: &TaskId) -> TaskResult<Option<serde_json::Value>> {
50 Ok(self.results.lock().unwrap().get(&id.to_string()).cloned())
51 }
52
53 async fn wait(&self, id: &TaskId, timeout: Duration) -> TaskResult<serde_json::Value> {
54 let deadline = tokio::time::Instant::now() + timeout;
55 loop {
56 if let Some(result) = self.get(id).await? {
57 return Ok(result);
58 }
59 if tokio::time::Instant::now() >= deadline {
60 return Err(KojinError::Timeout(timeout));
61 }
62 tokio::time::sleep(Duration::from_millis(50)).await;
63 }
64 }
65
66 async fn delete(&self, id: &TaskId) -> TaskResult<()> {
67 self.results.lock().unwrap().remove(&id.to_string());
68 Ok(())
69 }
70
71 async fn init_group(&self, group_id: &str, total: u32) -> TaskResult<()> {
72 self.groups.lock().unwrap().insert(
73 group_id.to_string(),
74 GroupState {
75 total,
76 completed: 0,
77 results: Vec::new(),
78 },
79 );
80 Ok(())
81 }
82
83 async fn complete_group_member(
84 &self,
85 group_id: &str,
86 _task_id: &TaskId,
87 result: &serde_json::Value,
88 ) -> TaskResult<u32> {
89 let mut groups = self.groups.lock().unwrap();
90 let state = groups
91 .get_mut(group_id)
92 .ok_or_else(|| KojinError::ResultBackend(format!("group not found: {group_id}")))?;
93 state.completed += 1;
94 state.results.push(result.clone());
95 Ok(state.completed)
96 }
97
98 async fn get_group_results(&self, group_id: &str) -> TaskResult<Vec<serde_json::Value>> {
99 let groups = self.groups.lock().unwrap();
100 let state = groups
101 .get(group_id)
102 .ok_or_else(|| KojinError::ResultBackend(format!("group not found: {group_id}")))?;
103 Ok(state.results.clone())
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn store_and_get() {
        let backend = MemoryResultBackend::new();
        let id = TaskId::new();
        let value = serde_json::json!({"result": 42});

        backend.store(&id, &value).await.unwrap();
        let got = backend.get(&id).await.unwrap();
        assert_eq!(got, Some(value));
    }

    #[tokio::test]
    async fn get_missing() {
        let backend = MemoryResultBackend::new();
        let id = TaskId::new();
        assert_eq!(backend.get(&id).await.unwrap(), None);
    }

    #[tokio::test]
    async fn delete_result() {
        let backend = MemoryResultBackend::new();
        let id = TaskId::new();
        backend.store(&id, &serde_json::json!(1)).await.unwrap();
        backend.delete(&id).await.unwrap();
        assert_eq!(backend.get(&id).await.unwrap(), None);
    }

    #[tokio::test]
    async fn wait_for_result() {
        let backend = std::sync::Arc::new(MemoryResultBackend::new());
        let id = TaskId::new();

        let writer_backend = std::sync::Arc::clone(&backend);
        let writer_id = id;
        let writer = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            writer_backend
                .store(&writer_id, &serde_json::json!("done"))
                .await
                .unwrap();
        });

        // Join the writer so a panic inside it is reported as such, not as a wait timeout.
        let (waited, written) = tokio::join!(backend.wait(&id, Duration::from_secs(2)), writer);
        written.expect("writer task panicked");
        assert_eq!(waited.unwrap(), serde_json::json!("done"));
    }

    #[tokio::test]
    async fn wait_timeout() {
        let backend = MemoryResultBackend::new();
        let id = TaskId::new();
        let timeout = Duration::from_millis(100);
        let result = backend.wait(&id, timeout).await;
        assert!(matches!(result, Err(KojinError::Timeout(d)) if d == timeout));
    }

    #[tokio::test]
    async fn group_lifecycle() {
        let backend = MemoryResultBackend::new();
        backend.init_group("g1", 3).await.unwrap();

        for n in 1..=3u32 {
            let completed = backend
                .complete_group_member("g1", &TaskId::new(), &serde_json::json!(n))
                .await
                .unwrap();
            assert_eq!(completed, n);
        }

        let results = backend.get_group_results("g1").await.unwrap();
        let expected: Vec<serde_json::Value> = (1..=3u32).map(|n| serde_json::json!(n)).collect();
        assert_eq!(results, expected);
    }

    #[tokio::test]
    async fn unknown_group_is_an_error() {
        let backend = MemoryResultBackend::new();

        let completed = backend
            .complete_group_member("missing", &TaskId::new(), &serde_json::json!(1))
            .await;
        assert!(matches!(completed, Err(KojinError::ResultBackend(_))));

        let results = backend.get_group_results("missing").await;
        assert!(matches!(results, Err(KojinError::ResultBackend(_))));
    }

    #[tokio::test]
    async fn init_group_resets_existing_state() {
        let backend = MemoryResultBackend::new();
        backend.init_group("g1", 2).await.unwrap();
        backend
            .complete_group_member("g1", &TaskId::new(), &serde_json::json!("stale"))
            .await
            .unwrap();

        backend.init_group("g1", 2).await.unwrap();
        assert!(backend.get_group_results("g1").await.unwrap().is_empty());
    }
}