// apalis_core/backend/impls/json/util.rs

1use std::{cmp::Ordering, collections::BTreeMap, fmt::Debug};
2
3use futures_util::FutureExt;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7use crate::{
8    backend::impls::json::{meta::JsonMapMetadata, JsonStorage},
9    error::BoxDynError,
10    task::{
11        status::Status,
12        task_id::{RandomId, TaskId},
13    },
14    worker::ext::ack::Acknowledge,
15};
16
17#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
18#[derive(Debug, Clone)]
19pub struct TaskKey {
20    pub(super) task_id: TaskId,
21    pub(super) namespace: String,
22    pub(super) status: Status,
23}
24
25impl PartialEq for TaskKey {
26    fn eq(&self, other: &Self) -> bool {
27        self.task_id == other.task_id && self.namespace == other.namespace
28    }
29}
30
31impl Eq for TaskKey {}
32
33impl PartialOrd for TaskKey {
34    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
35        Some(self.cmp(other))
36    }
37}
38
39impl Ord for TaskKey {
40    fn cmp(&self, other: &Self) -> Ordering {
41        match self.task_id.cmp(&other.task_id) {
42            Ordering::Equal => self.namespace.cmp(&other.namespace),
43            ord => ord,
44        }
45    }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct TaskWithMeta {
50    pub(super) args: Value,
51    pub(super) ctx: JsonMapMetadata,
52    pub(super) result: Option<Value>,
53}
54
55#[derive(Debug)]
56pub struct JsonAck<Args> {
57    pub(crate) inner: JsonStorage<Args>,
58}
59
60impl<Args> Clone for JsonAck<Args> {
61    fn clone(&self) -> Self {
62        Self {
63            inner: self.inner.clone(),
64        }
65    }
66}
67
68impl<Args: Send + 'static + Debug, Res: Serialize, Ctx: Sync> Acknowledge<Res, Ctx, RandomId>
69    for JsonAck<Args>
70{
71    type Error = serde_json::Error;
72
73    type Future = futures_core::future::BoxFuture<'static, Result<(), Self::Error>>;
74
75    fn ack(
76        &mut self,
77        res: &Result<Res, BoxDynError>,
78        ctx: &crate::task::Parts<Ctx, RandomId>,
79    ) -> Self::Future {
80        let store = self.inner.clone();
81        let val = serde_json::to_value(res.as_ref().map_err(|e| e.to_string())).unwrap();
82        let task_id = ctx.task_id.clone().unwrap();
83        async move {
84            let key = TaskKey {
85                task_id: task_id.clone(),
86                namespace: std::any::type_name::<Args>().to_owned(),
87                status: Status::Running,
88            };
89
90            let _ = store.update_result(&key, Status::Done, val).unwrap();
91
92            store.persist_to_disk().unwrap();
93
94            Ok(())
95        }
96        .boxed()
97    }
98}
99
#[cfg(feature = "sleep")]
impl<Res: 'static + serde::de::DeserializeOwned + Send, Compact: 'static + Sync>
    crate::backend::WaitForCompletion<Res, Compact> for JsonStorage<Compact>
where
    Compact: Send + serde::de::DeserializeOwned + 'static + Unpin,
{
    type ResultStream = futures_core::stream::BoxStream<
        'static,
        Result<crate::backend::TaskResult<Res>, futures_channel::mpsc::SendError>,
    >;

    /// Streams the result of each task in `task_ids` as it becomes available.
    ///
    /// The store is polled every 100ms. A task that is missing from the store, or present
    /// without a recorded result, is simply polled again. Each requested task yields
    /// exactly one item, and the stream ends once all of them have been reported.
    ///
    /// A task id that never gets a result keeps the stream pending indefinitely. A stored
    /// result that cannot be decoded as `Result<Res, String>` is reported with
    /// [`Status::Failed`] instead of panicking.
    fn wait_for(
        &self,
        task_ids: impl IntoIterator<Item = TaskId<Self::IdType>>,
    ) -> Self::ResultStream {
        use futures_util::StreamExt;
        use std::{collections::HashSet, time::Duration};

        /// State threaded through successive polls of the stream.
        struct PollState<T, Compact> {
            vault: JsonStorage<Compact>,
            pending_tasks: HashSet<TaskId>,
            namespace: String,
            poll_interval: Duration,
            _phantom: std::marker::PhantomData<T>,
        }

        let state = PollState {
            vault: self.clone(),
            pending_tasks: task_ids.into_iter().collect(),
            namespace: std::any::type_name::<Compact>().to_owned(),
            poll_interval: Duration::from_millis(100),
            _phantom: std::marker::PhantomData,
        };

        futures_util::stream::unfold(state, |mut state: PollState<Res, Compact>| async move {
            // Every requested task has been reported: end the stream.
            if state.pending_tasks.is_empty() {
                return None;
            }

            loop {
                let vault = &state.vault;
                let namespace = &state.namespace;
                let completed = state.pending_tasks.iter().find_map(|task_id| {
                    let key = TaskKey {
                        task_id: task_id.clone(),
                        namespace: namespace.clone(),
                        status: Status::Pending,
                    };
                    // A stored task without a result has not finished yet; skip it for now
                    // instead of unwrapping the missing result.
                    let raw = vault.get(&key)?.result?;
                    Some((task_id.clone(), raw))
                });

                if let Some((task_id, raw)) = completed {
                    state.pending_tasks.remove(&task_id);
                    return Some((Ok(decode_task_result::<Res>(task_id, raw)), state));
                }

                // Nothing finished yet; wait before polling the store again.
                crate::timer::sleep(state.poll_interval).await;
            }
        })
        .boxed()
    }

    /// Returns the results already recorded for `task_ids`, in no particular order.
    ///
    /// Duplicate ids are collapsed. Tasks that are unknown to the store, or have no
    /// recorded result yet, are left out of the output. A stored result that cannot be
    /// decoded as `Result<Res, String>` is reported with [`Status::Failed`].
    async fn check_status(
        &self,
        task_ids: impl IntoIterator<Item = TaskId<Self::IdType>> + Send,
    ) -> Result<Vec<crate::backend::TaskResult<Res>>, Self::Error> {
        use std::collections::HashSet;

        let namespace = std::any::type_name::<Compact>();
        let unique_ids: HashSet<_> = task_ids.into_iter().collect();
        let results: Vec<_> = unique_ids
            .into_iter()
            .filter_map(|task_id| {
                let key = TaskKey {
                    task_id: task_id.clone(),
                    namespace: namespace.to_owned(),
                    status: Status::Pending,
                };
                // Tasks without a recorded result are skipped rather than unwrapped.
                let raw = self.get(&key)?.result?;
                Some(decode_task_result::<Res>(task_id, raw))
            })
            .collect();
        Ok(results)
    }
}

/// Decodes a stored task outcome into a [`crate::backend::TaskResult`].
///
/// `raw` should be a serialized `Result<Res, String>`. If it decodes, the task is reported
/// as [`Status::Done`] with that result. Otherwise it is reported as [`Status::Failed`],
/// with the deserialization error as its message.
#[cfg(feature = "sleep")]
fn decode_task_result<Res: serde::de::DeserializeOwned>(
    task_id: TaskId,
    raw: Value,
) -> crate::backend::TaskResult<Res> {
    match serde_json::from_value::<Result<Res, String>>(raw) {
        Ok(result) => crate::backend::TaskResult {
            task_id,
            status: Status::Done,
            result,
        },
        Err(e) => crate::backend::TaskResult {
            task_id,
            status: Status::Failed,
            result: Err(format!("Deserialization error: {}", e)),
        },
    }
}
210
211/// Find the first item that meets the requirements
212pub(super) trait FindFirstWith<K, V> {
213    fn find_first_with<F>(&self, predicate: F) -> Option<(&K, &V)>
214    where
215        F: FnMut(&K, &V) -> bool;
216}
217
218impl<K, V> FindFirstWith<K, V> for BTreeMap<K, V>
219where
220    K: Ord + Clone,
221{
222    fn find_first_with<F>(&self, mut predicate: F) -> Option<(&K, &V)>
223    where
224        F: FnMut(&K, &V) -> bool,
225    {
226        if let Some(key) = self.iter().find(|(k, v)| predicate(k, v)).map(|(k, _)| k) {
227            self.get_key_value(key)
228        } else {
229            None
230        }
231    }
232}