1#![crate_name = "apalis_cron"]
2#![warn(
3    missing_debug_implementations,
4    missing_docs,
5    rust_2018_idioms,
6    unreachable_pub
7)]
8#![cfg_attr(docsrs, feature(doc_cfg))]
9
10use apalis_core::backend::Backend;
75use apalis_core::codec::NoopCodec;
76use apalis_core::error::BoxDynError;
77use apalis_core::layers::Identity;
78use apalis_core::mq::MessageQueue;
79use apalis_core::poller::Poller;
80use apalis_core::request::RequestStream;
81use apalis_core::storage::Storage;
82use apalis_core::task::namespace::Namespace;
83use apalis_core::worker::{Context, Worker};
84use apalis_core::{error::Error, request::Request, service_fn::FromRequest};
85use chrono::{DateTime, OutOfRangeError, TimeZone, Utc};
86pub use cron::Schedule;
87use futures::StreamExt;
88use pipe::CronPipe;
89use std::fmt::{self, Debug};
90use std::marker::PhantomData;
91use std::sync::Arc;
92
93pub mod pipe;
95
/// A [`Backend`] that produces a job every time a cron [`Schedule`] fires.
///
/// `J` is the job type; a fresh `J::default()` is created for every tick.
/// `Tz` is the timezone in which the schedule is evaluated and in which each job's
/// [`CronContext`] timestamp is expressed.
#[derive(Clone, Debug)]
pub struct CronStream<J, Tz> {
    /// The cron expression that decides when ticks happen.
    schedule: Schedule,
    /// Timezone used to compute upcoming ticks and to timestamp produced jobs.
    timezone: Tz,
    /// No `J` is stored; jobs are built on demand, so `J` is tracked as a type marker only.
    _marker: PhantomData<J>,
}
103
104impl<J> CronStream<J, Utc> {
105    pub fn new(schedule: Schedule) -> Self {
107        Self {
108            schedule,
109            timezone: Utc,
110            _marker: PhantomData,
111        }
112    }
113}
114
115impl<J, Tz> CronStream<J, Tz>
116where
117    Tz: TimeZone + Send + Sync + 'static,
118{
119    pub fn new_with_timezone(schedule: Schedule, timezone: Tz) -> Self {
121        Self {
122            schedule,
123            timezone,
124            _marker: PhantomData,
125        }
126    }
127}
128
129fn build_stream<Tz: TimeZone, Req>(
130    timezone: &Tz,
131    schedule: &Schedule,
132) -> RequestStream<Request<Req, CronContext<Tz>>>
133where
134    Req: Default + Send + Sync + 'static,
135    Tz: TimeZone + Send + Sync + 'static,
136    Tz::Offset: Send + Sync,
137{
138    let timezone = timezone.clone();
139    let schedule = schedule.clone();
140    let mut queue_schedule = schedule.upcoming_owned(timezone.clone());
141    let stream = async_stream::stream! {
142        loop {
143            let next = queue_schedule.next();
144            match next {
145                Some(tick) => {
146                    let to_sleep = tick.clone() - timezone.from_utc_datetime(&Utc::now().naive_utc());
147                    let to_sleep_res = to_sleep.to_std();
148                    match to_sleep_res {
149                        Ok(to_sleep) => {
150                            apalis_core::sleep(to_sleep).await;
151                            let timestamp = timezone.from_utc_datetime(&Utc::now().naive_utc());
152                            let namespace = Namespace(format!("{}:{timestamp:?}", schedule));
153                            let mut req = Request::new_with_ctx(Default::default(), CronContext { timestamp });
154                            req.parts.namespace = Some(namespace);
155                            yield Ok(Some(req));
156                        },
157                        Err(e) => {
158                            yield Err(Error::SourceError(Arc::new(Box::new(CronStreamError::OutOfRangeError { inner: e, tick }))))
159                        },
160                    }
161
162
163                },
164                None => {
165                    yield Ok(None);
166                }
167            }
168        }
169    };
170    stream.boxed()
171}
172impl<Req, Tz> CronStream<Req, Tz>
173where
174    Req: Default + Send + Sync + 'static,
175    Tz: TimeZone + Send + Sync + 'static,
176    Tz::Offset: Send + Sync,
177{
178    fn into_stream(self) -> RequestStream<Request<Req, CronContext<Tz>>> {
180        build_stream(&self.timezone, &self.schedule)
181    }
182
183    fn into_stream_worker(
184        self,
185        worker: &Worker<Context>,
186    ) -> RequestStream<Request<Req, CronContext<Tz>>> {
187        let worker = worker.clone();
188        let mut poller = build_stream(&self.timezone, &self.schedule);
189        let stream = async_stream::stream! {
190            loop {
191                if worker.is_shutting_down() {
192                    break;
193                }
194                match poller.next().await {
195                    Some(res) => yield res,
196                    None => break,
197                }
198            }
199        };
200        Box::pin(stream)
201    }
202
203    pub fn pipe_to_storage<S, Ctx>(self, storage: S) -> CronPipe<S>
205    where
206        S: Storage<Job = Req, Context = Ctx> + Clone + Send + Sync + 'static,
207        S::Error: std::error::Error + Send + Sync + 'static,
208    {
209        let stream = self
210            .into_stream()
211            .then({
212                let storage = storage.clone();
213                move |res| {
214                    let mut storage = storage.clone();
215                    async move {
216                        match res {
217                            Ok(Some(req)) => storage
218                                .push(req.args)
219                                .await
220                                .map(|_| ())
221                                .map_err(|e| Box::new(e) as BoxDynError),
222                            _ => Ok(()),
223                        }
224                    }
225                }
226            })
227            .boxed();
228
229        CronPipe {
230            stream,
231            inner: storage,
232        }
233    }
234    pub fn pipe_to_mq<Mq>(self, mq: Mq) -> CronPipe<Mq>
236    where
237        Mq: MessageQueue<Req> + Clone + Send + Sync + 'static,
238        Mq::Error: std::error::Error + Send + Sync + 'static,
239    {
240        let stream = self
241            .into_stream()
242            .then({
243                let mq = mq.clone();
244                move |res| {
245                    let mut mq = mq.clone();
246                    async move {
247                        match res {
248                            Ok(Some(req)) => mq
249                                .enqueue(req.args)
250                                .await
251                                .map(|_| ())
252                                .map_err(|e| Box::new(e) as BoxDynError),
253                            _ => Ok(()),
254                        }
255                    }
256                }
257            })
258            .boxed();
259
260        CronPipe { stream, inner: mq }
261    }
262}
263
/// Context attached to every job produced by a [`CronStream`].
#[derive(Debug, Clone)]
pub struct CronContext<Tz: TimeZone> {
    /// For jobs produced by a [`CronStream`], this is the time read right after the
    /// stream finished sleeping for the tick, in the stream's timezone. It is not the
    /// scheduled tick itself.
    timestamp: DateTime<Tz>,
}
269
270impl<Tz: TimeZone> Default for CronContext<Tz>
271where
272    DateTime<Tz>: Default,
273{
274    fn default() -> Self {
275        Self {
276            timestamp: Default::default(),
277        }
278    }
279}
280
281impl<Tz: TimeZone> CronContext<Tz> {
282    pub fn new(timestamp: DateTime<Tz>) -> Self {
284        Self { timestamp }
285    }
286
287    pub fn get_timestamp(&self) -> &DateTime<Tz> {
289        &self.timestamp
290    }
291}
292
293impl<Req, Tz: TimeZone> FromRequest<Request<Req, CronContext<Tz>>> for CronContext<Tz> {
294    fn from_request(req: &Request<Req, CronContext<Tz>>) -> Result<Self, Error> {
295        Ok(req.parts.context.clone())
296    }
297}
298
299impl<Req, Tz> Backend<Request<Req, CronContext<Tz>>> for CronStream<Req, Tz>
300where
301    Req: Default + Send + Sync + 'static,
302    Tz: TimeZone + Send + Sync + 'static,
303    Tz::Offset: Send + Sync,
304{
305    type Stream = RequestStream<Request<Req, CronContext<Tz>>>;
306
307    type Layer = Identity;
308
309    type Codec = NoopCodec<Request<Req, CronContext<Tz>>>;
310
311    fn poll(self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
312        let stream = self.into_stream_worker(worker);
313        Poller::new(stream, futures::future::pending())
314    }
315}
316
/// Errors produced by a [`CronStream`] while driving its schedule.
///
/// `Debug` is implemented by hand so that `Tz` itself does not need to be `Debug`.
pub enum CronStreamError<Tz: TimeZone> {
    /// The wait until the next tick could not be converted into a
    /// `std::time::Duration`.
    ///
    /// This happens when the tick is already in the past by the time the stream
    /// reaches it: the remaining duration is negative, and chrono's `to_std` rejects
    /// negative durations.
    OutOfRangeError {
        /// The conversion error returned by chrono.
        inner: OutOfRangeError,
        /// The scheduled tick that could not be waited for.
        tick: DateTime<Tz>,
    },
}
329
330impl<Tz: TimeZone> fmt::Display for CronStreamError<Tz> {
331    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
332        match self {
333            CronStreamError::OutOfRangeError { inner, tick } => {
334                write!(
335                    f,
336                    "Cron tick {} is out of range: {}",
337                    tick.timestamp(),
338                    inner
339                )
340            }
341        }
342    }
343}
344
345impl<Tz: TimeZone> std::error::Error for CronStreamError<Tz> {
346    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
347        match self {
348            CronStreamError::OutOfRangeError { inner, .. } => Some(inner),
349        }
350    }
351}
352
353impl<Tz: TimeZone> fmt::Debug for CronStreamError<Tz> {
354    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
355        match self {
356            CronStreamError::OutOfRangeError { inner, tick } => f
357                .debug_struct("OutOfRangeError")
358                .field("tick", tick)
359                .field("inner", inner)
360                .finish(),
361        }
362    }
363}