apalis_core/backend/impls/json/
util.rs1use std::{cmp::Ordering, collections::BTreeMap, fmt::Debug};
2
3use futures_util::FutureExt;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7use crate::{
8 backend::impls::json::{meta::JsonMapMetadata, JsonStorage},
9 error::BoxDynError,
10 task::{
11 status::Status,
12 task_id::{RandomId, TaskId},
13 },
14 worker::ext::ack::Acknowledge,
15};
16
17#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
18#[derive(Debug, Clone)]
19pub struct TaskKey {
20 pub(super) task_id: TaskId,
21 pub(super) namespace: String,
22 pub(super) status: Status,
23}
24
25impl PartialEq for TaskKey {
26 fn eq(&self, other: &Self) -> bool {
27 self.task_id == other.task_id && self.namespace == other.namespace
28 }
29}
30
31impl Eq for TaskKey {}
32
33impl PartialOrd for TaskKey {
34 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
35 Some(self.cmp(other))
36 }
37}
38
39impl Ord for TaskKey {
40 fn cmp(&self, other: &Self) -> Ordering {
41 match self.task_id.cmp(&other.task_id) {
42 Ordering::Equal => self.namespace.cmp(&other.namespace),
43 ord => ord,
44 }
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct TaskWithMeta {
50 pub(super) args: Value,
51 pub(super) ctx: JsonMapMetadata,
52 pub(super) result: Option<Value>,
53}
54
55#[derive(Debug)]
56pub struct JsonAck<Args> {
57 pub(crate) inner: JsonStorage<Args>,
58}
59
60impl<Args> Clone for JsonAck<Args> {
61 fn clone(&self) -> Self {
62 Self {
63 inner: self.inner.clone(),
64 }
65 }
66}
67
68impl<Args: Send + 'static + Debug, Res: Serialize, Ctx: Sync> Acknowledge<Res, Ctx, RandomId>
69 for JsonAck<Args>
70{
71 type Error = serde_json::Error;
72
73 type Future = futures_core::future::BoxFuture<'static, Result<(), Self::Error>>;
74
75 fn ack(
76 &mut self,
77 res: &Result<Res, BoxDynError>,
78 ctx: &crate::task::Parts<Ctx, RandomId>,
79 ) -> Self::Future {
80 let store = self.inner.clone();
81 let val = serde_json::to_value(res.as_ref().map_err(|e| e.to_string())).unwrap();
82 let task_id = ctx.task_id.clone().unwrap();
83 async move {
84 let key = TaskKey {
85 task_id: task_id.clone(),
86 namespace: std::any::type_name::<Args>().to_owned(),
87 status: Status::Running,
88 };
89
90 let _ = store.update_result(&key, Status::Done, val).unwrap();
91
92 store.persist_to_disk().unwrap();
93
94 Ok(())
95 }
96 .boxed()
97 }
98}
99
#[cfg(feature = "sleep")]
impl<Res: 'static + serde::de::DeserializeOwned + Send, Compact: 'static + Sync>
    crate::backend::WaitForCompletion<Res, Compact> for JsonStorage<Compact>
where
    Compact: Send + serde::de::DeserializeOwned + 'static + Unpin,
{
    type ResultStream = futures_core::stream::BoxStream<
        'static,
        Result<crate::backend::TaskResult<Res>, futures_channel::mpsc::SendError>,
    >;

    /// Streams the result of each task in `task_ids` as it becomes available.
    ///
    /// The store is polled every 100 ms. A task is reported once its entry has a
    /// stored result. Entries that exist but have no result yet are polled again
    /// rather than unwrapped.
    ///
    /// Each distinct task id yields exactly one item. The stream ends once all
    /// of them have been reported, and stays pending while any task still lacks
    /// a result.
    ///
    /// A stored result that cannot be decoded as `Result<Res, String>` is
    /// reported with `Status::Failed`.
    fn wait_for(
        &self,
        task_ids: impl IntoIterator<Item = TaskId<Self::IdType>>,
    ) -> Self::ResultStream {
        use futures_util::StreamExt;
        use std::{collections::HashSet, time::Duration};

        // State threaded through `unfold`: the tasks still awaited and how to look them up.
        struct PollState<T, Compact> {
            vault: JsonStorage<Compact>,
            pending_tasks: HashSet<TaskId>,
            namespace: String,
            poll_interval: Duration,
            _phantom: std::marker::PhantomData<T>,
        }

        let state = PollState {
            vault: self.clone(),
            pending_tasks: task_ids.into_iter().collect(),
            namespace: std::any::type_name::<Compact>().to_owned(),
            poll_interval: Duration::from_millis(100),
            _phantom: std::marker::PhantomData,
        };

        futures_util::stream::unfold(state, |mut state: PollState<Res, Compact>| async move {
            if state.pending_tasks.is_empty() {
                return None;
            }

            loop {
                let completed = state.pending_tasks.iter().find_map(|task_id| {
                    let key = TaskKey {
                        task_id: task_id.clone(),
                        namespace: state.namespace.clone(),
                        status: Status::Pending,
                    };
                    // `TaskKey` equality ignores status, so this also finds tasks that
                    // are still running; those have no result yet and must not be unwrapped.
                    let stored = state.vault.get(&key)?.result?;
                    Some((task_id.clone(), stored))
                });

                if let Some((task_id, stored)) = completed {
                    state.pending_tasks.remove(&task_id);
                    return Some((Ok(decode_task_result::<Res>(task_id, stored)), state));
                }

                crate::timer::sleep(state.poll_interval).await;
            }
        })
        .boxed()
    }

    /// Returns the results that are already stored for `task_ids`.
    ///
    /// Duplicate ids are reported once. Tasks that are unknown, or known but
    /// without a stored result yet, are omitted from the returned list instead
    /// of causing a panic.
    ///
    /// A stored result that cannot be decoded as `Result<Res, String>` is
    /// reported with `Status::Failed`.
    async fn check_status(
        &self,
        task_ids: impl IntoIterator<Item = TaskId<Self::IdType>> + Send,
    ) -> Result<Vec<crate::backend::TaskResult<Res>>, Self::Error> {
        use std::collections::HashSet;

        let namespace = std::any::type_name::<Compact>();
        let task_ids: HashSet<_> = task_ids.into_iter().collect();
        let mut results = Vec::with_capacity(task_ids.len());
        for task_id in task_ids {
            let key = TaskKey {
                task_id: task_id.clone(),
                namespace: namespace.to_owned(),
                status: Status::Pending,
            };
            let Some(stored) = self.get(&key).and_then(|value| value.result) else {
                continue;
            };
            results.push(decode_task_result::<Res>(task_id, stored));
        }
        Ok(results)
    }
}

/// Decodes a stored task result into a [`crate::backend::TaskResult`].
///
/// `stored` is expected to hold the JSON form of `Result<Res, String>`, which is
/// what `JsonAck::ack` writes.
///
/// - On success the task is reported as `Status::Done` with the decoded result.
/// - If decoding fails, the task is reported as `Status::Failed` with a message
///   describing the deserialization error.
#[cfg(feature = "sleep")]
fn decode_task_result<Res: serde::de::DeserializeOwned>(
    task_id: TaskId,
    stored: Value,
) -> crate::backend::TaskResult<Res> {
    match serde_json::from_value::<Result<Res, String>>(stored) {
        Ok(result) => crate::backend::TaskResult {
            task_id,
            status: Status::Done,
            result,
        },
        Err(e) => crate::backend::TaskResult {
            task_id,
            status: Status::Failed,
            result: Err(format!("Deserialization error: {}", e)),
        },
    }
}
210
211pub(super) trait FindFirstWith<K, V> {
213 fn find_first_with<F>(&self, predicate: F) -> Option<(&K, &V)>
214 where
215 F: FnMut(&K, &V) -> bool;
216}
217
218impl<K, V> FindFirstWith<K, V> for BTreeMap<K, V>
219where
220 K: Ord + Clone,
221{
222 fn find_first_with<F>(&self, mut predicate: F) -> Option<(&K, &V)>
223 where
224 F: FnMut(&K, &V) -> bool,
225 {
226 if let Some(key) = self.iter().find(|(k, v)| predicate(k, v)).map(|(k, _)| k) {
227 self.get_key_value(key)
228 } else {
229 None
230 }
231 }
232}