apalis_mysql/
lib.rs

1#![doc = include_str!("../README.md")]
2//!
3use apalis_codec::json::JsonCodec;
4use apalis_core::{
5    backend::{Backend, BackendExt, TaskStream, codec::Codec, queue::Queue},
6    features_table,
7    layers::Stack,
8    task::Task,
9    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
10};
11use apalis_sql::context::SqlContext;
12use futures::{
13    FutureExt, Stream, StreamExt, TryStreamExt,
14    stream::{self, BoxStream},
15};
16pub use sqlx::{
17    Connection, MySql, MySqlConnection, MySqlPool, Pool,
18    error::Error as SqlxError,
19    mysql::MySqlConnectOptions,
20    pool::{PoolConnection, PoolOptions},
21};
22use std::{fmt, marker::PhantomData};
23use ulid::Ulid;
24
25use crate::{
26    ack::{LockTaskLayer, MySqlAck},
27    fetcher::{MySqlFetcher, MySqlPollFetcher},
28    queries::{
29        keep_alive::{initial_heartbeat, keep_alive, keep_alive_stream},
30        reenqueue_orphaned::reenqueue_orphaned_stream,
31    },
32    sink::MySqlSink,
33};
34
35mod ack;
36/// Fetcher module for retrieving tasks from mysql backend
37pub mod fetcher;
38mod from_row;
39/// Queries module for mysql backend
40pub mod queries;
41mod shared;
42/// Sink module for pushing tasks to mysql backend
43pub mod sink;
44
45/// Type alias for a task stored in mysql backend
46pub type MySqlTask<Args> = Task<Args, MySqlContext, Ulid>;
47pub use apalis_sql::config::Config;
48pub use apalis_sql::ext::TaskBuilderExt;
49pub use shared::{SharedMySqlError, SharedMySqlStorage};
50
51pub type MySqlTaskId = apalis_core::task::task_id::TaskId<Ulid>;
52pub type MySqlContext = SqlContext<MySqlPool>;
53
54/// CompactType is the type used for compact serialization in mysql backend
55pub type CompactType = Vec<u8>;
56
57/// MySqlStorage is a storage backend for apalis using mysql as the database.
58///
59/// It supports both standard polling and event-driven (hooked) storage mechanisms.
60///
61#[doc = features_table! {
62    setup = r#"
63        # {
64        #   use apalis_mysql::MySqlStorage;
65        #   use sqlx::MySqlPool;
66        #   let pool = MySqlPool::connect(&std::env::var("DATABASE_URL").unwrap()).await.unwrap();
67        #   MySqlStorage::setup(&pool).await.unwrap();
68        #   MySqlStorage::new(&pool)
69        # };
70    "#,
71
72    Backend => supported("Supports storage and retrieval of tasks", true),
73    TaskSink => supported("Ability to push new tasks", true),
74    Serialization => supported("Serialization support for arguments", true),
75    Workflow => supported("Flexible enough to support workflows", true),
76    WebUI => supported("Expose a web interface for monitoring tasks", true),
77    FetchById => supported("Allow fetching a task by its ID", false),
78    RegisterWorker => supported("Allow registering a worker with the backend", false),
79    MakeShared => supported("Share one connection across multiple workers via [`SharedMySqlStorage`]", false),
80    WaitForCompletion => supported("Wait for tasks to complete without blocking", true),
81    ResumeById => supported("Resume a task by its ID", false),
82    ResumeAbandoned => supported("Resume abandoned tasks", false),
83    ListWorkers => supported("List all workers registered with the backend", false),
84    ListTasks => supported("List all tasks in the backend", false),
85}]
86#[pin_project::pin_project]
87pub struct MySqlStorage<T, C, Fetcher> {
88    pool: Pool<MySql>,
89    job_type: PhantomData<T>,
90    config: Config,
91    codec: PhantomData<C>,
92    #[pin]
93    sink: MySqlSink<T, CompactType, C>,
94    #[pin]
95    fetcher: Fetcher,
96}
97
98impl<T, C, F> fmt::Debug for MySqlStorage<T, C, F> {
99    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100        f.debug_struct("MySqlStorage")
101            .field("pool", &self.pool)
102            .field("job_type", &"PhantomData<T>")
103            .field("config", &self.config)
104            .field("codec", &std::any::type_name::<C>())
105            .finish()
106    }
107}
108
109impl<T, C, F: Clone> Clone for MySqlStorage<T, C, F> {
110    fn clone(&self) -> Self {
111        Self {
112            sink: self.sink.clone(),
113            pool: self.pool.clone(),
114            job_type: PhantomData,
115            config: self.config.clone(),
116            codec: self.codec,
117            fetcher: self.fetcher.clone(),
118        }
119    }
120}
121
122impl MySqlStorage<(), (), ()> {
123    /// Get mysql migrations without running them
124    #[cfg(feature = "migrate")]
125    #[must_use]
126    pub fn migrations() -> sqlx::migrate::Migrator {
127        sqlx::migrate!("./migrations")
128    }
129
130    /// Do migrations for mysql
131    #[cfg(feature = "migrate")]
132    pub async fn setup(pool: &Pool<MySql>) -> Result<(), sqlx::Error> {
133        Self::migrations().run(pool).await?;
134        Ok(())
135    }
136}
137
138impl<T> MySqlStorage<T, (), ()> {
139    /// Create a new MySqlStorage
140    #[must_use]
141    pub fn new(
142        pool: &Pool<MySql>,
143    ) -> MySqlStorage<T, JsonCodec<CompactType>, fetcher::MySqlFetcher> {
144        let config = Config::new(std::any::type_name::<T>());
145        MySqlStorage {
146            pool: pool.clone(),
147            job_type: PhantomData,
148            sink: MySqlSink::new(pool, &config),
149            config,
150            codec: PhantomData,
151            fetcher: fetcher::MySqlFetcher,
152        }
153    }
154
155    /// Create a new MySqlStorage for a specific queue
156    #[must_use]
157    pub fn new_in_queue(
158        pool: &Pool<MySql>,
159        queue: &str,
160    ) -> MySqlStorage<T, JsonCodec<CompactType>, fetcher::MySqlFetcher> {
161        let config = Config::new(queue);
162        MySqlStorage {
163            pool: pool.clone(),
164            job_type: PhantomData,
165            sink: MySqlSink::new(pool, &config),
166            config,
167            codec: PhantomData,
168            fetcher: fetcher::MySqlFetcher,
169        }
170    }
171
172    /// Create a new MySqlStorage with config
173    #[must_use]
174    pub fn new_with_config(
175        pool: &Pool<MySql>,
176        config: &Config,
177    ) -> MySqlStorage<T, JsonCodec<CompactType>, fetcher::MySqlFetcher> {
178        MySqlStorage {
179            pool: pool.clone(),
180            job_type: PhantomData,
181            config: config.clone(),
182            codec: PhantomData,
183            sink: MySqlSink::new(pool, config),
184            fetcher: fetcher::MySqlFetcher,
185        }
186    }
187}
188
189impl<T, C, F> MySqlStorage<T, C, F> {
190    /// Change the codec used for serialization/deserialization
191    pub fn with_codec<D>(self) -> MySqlStorage<T, D, F> {
192        MySqlStorage {
193            sink: MySqlSink::new(&self.pool, &self.config),
194            pool: self.pool,
195            job_type: PhantomData,
196            config: self.config,
197            codec: PhantomData,
198            fetcher: self.fetcher,
199        }
200    }
201
202    /// Get the config used by the storage
203    pub fn config(&self) -> &Config {
204        &self.config
205    }
206
207    /// Get the connection pool used by the storage
208    pub fn pool(&self) -> &Pool<MySql> {
209        &self.pool
210    }
211}
212
213impl<Args, Decode> Backend for MySqlStorage<Args, Decode, MySqlFetcher>
214where
215    Args: Send + 'static + Unpin,
216    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
217    Decode::Error: std::error::Error + Send + Sync + 'static,
218{
219    type Args = Args;
220    type IdType = Ulid;
221
222    type Context = MySqlContext;
223
224    type Error = sqlx::Error;
225
226    type Stream = TaskStream<MySqlTask<Args>, sqlx::Error>;
227
228    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;
229
230    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<MySqlAck>>;
231
232    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
233        let pool = self.pool.clone();
234        let config = self.config.clone();
235        let worker = worker.clone();
236        let keep_alive = keep_alive_stream(pool, config, worker);
237        let reenqueue = reenqueue_orphaned_stream(
238            self.pool.clone(),
239            self.config.clone(),
240            *self.config.keep_alive(),
241        )
242        .map_ok(|_| ());
243        futures::stream::select(keep_alive, reenqueue).boxed()
244    }
245
246    fn middleware(&self) -> Self::Layer {
247        let lock = LockTaskLayer::new(self.pool.clone());
248        let ack = AcknowledgeLayer::new(MySqlAck::new(self.pool.clone()));
249        Stack::new(lock, ack)
250    }
251
252    fn poll(self, worker: &WorkerContext) -> Self::Stream {
253        self.poll_default(worker)
254            .map(|a| match a {
255                Ok(Some(task)) => Ok(Some(
256                    task.try_map(|t| Decode::decode(&t))
257                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
258                )),
259                Ok(None) => Ok(None),
260                Err(e) => Err(e),
261            })
262            .boxed()
263    }
264}
265
266impl<Args, Decode: Send + 'static> BackendExt for MySqlStorage<Args, Decode, MySqlFetcher>
267where
268    Self: Backend<Args = Args, IdType = Ulid, Context = MySqlContext, Error = sqlx::Error>,
269    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
270    Decode::Error: std::error::Error + Send + Sync + 'static,
271    Args: Send + 'static + Unpin,
272{
273    type Codec = Decode;
274    type Compact = CompactType;
275    type CompactStream = TaskStream<MySqlTask<Self::Compact>, sqlx::Error>;
276
277    fn get_queue(&self) -> Queue {
278        self.config.queue().clone()
279    }
280
281    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
282        self.poll_default(worker).boxed()
283    }
284}
285
286impl<Args, Decode: Send + 'static, F> MySqlStorage<Args, Decode, F> {
287    fn poll_default(
288        self,
289        worker: &WorkerContext,
290    ) -> impl Stream<Item = Result<Option<MySqlTask<CompactType>>, sqlx::Error>> + Send + 'static
291    {
292        let fut = initial_heartbeat(
293            self.pool.clone(),
294            self.config().clone(),
295            worker.clone(),
296            "MySqlStorage",
297        );
298        let register = stream::once(fut.map(|_| Ok(None)));
299        register.chain(MySqlPollFetcher::<CompactType, Decode>::new(
300            &self.pool,
301            &self.config,
302            worker,
303        ))
304    }
305}
306
307#[cfg(test)]
308mod tests {
309    use std::time::Duration;
310
311    use apalis::prelude::*;
312    use apalis_workflow::*;
313    use serde::{Deserialize, Serialize};
314    use sqlx::MySqlPool;
315
316    use super::*;
317
318    #[tokio::test]
319    async fn basic_worker() {
320        const ITEMS: usize = 10;
321        let pool = MySqlPool::connect(&std::env::var("DATABASE_URL").unwrap())
322            .await
323            .unwrap();
324        MySqlStorage::setup(&pool).await.unwrap();
325
326        let mut backend = MySqlStorage::new(&pool);
327
328        let mut start = 0;
329
330        let mut items = stream::repeat_with(move || {
331            start += 1;
332            start
333        })
334        .take(ITEMS);
335        backend.push_stream(&mut items).await.unwrap();
336
337        async fn send_reminder(item: usize, wrk: WorkerContext) -> Result<(), BoxDynError> {
338            if ITEMS == item {
339                wrk.stop().unwrap();
340            }
341            Ok(())
342        }
343
344        let worker = WorkerBuilder::new("rango-tango-1")
345            .backend(backend)
346            .build(send_reminder);
347        worker.run().await.unwrap();
348    }
349
350    #[tokio::test]
351    async fn test_workflow() {
352        let workflow = Workflow::new("odd-numbers-workflow")
353            .and_then(|a: usize| async move { Ok::<_, BoxDynError>((0..=a).collect::<Vec<_>>()) })
354            .filter_map(|x| async move { if x % 2 != 0 { Some(x) } else { None } })
355            .filter_map(|x| async move { if x % 3 != 0 { Some(x) } else { None } })
356            .filter_map(|x| async move { if x % 5 != 0 { Some(x) } else { None } })
357            .delay_for(Duration::from_millis(1000))
358            .and_then(|a: Vec<usize>| async move {
359                println!("Sum: {}", a.iter().sum::<usize>());
360                Err::<(), BoxDynError>("Intentional Error".into())
361            });
362
363        let pool = MySqlPool::connect(&std::env::var("DATABASE_URL").unwrap())
364            .await
365            .unwrap();
366        let mut mysql = MySqlStorage::new_with_config(
367            &pool,
368            &Config::new("workflow-queue").with_poll_interval(
369                StrategyBuilder::new()
370                    .apply(IntervalStrategy::new(Duration::from_millis(100)))
371                    .build(),
372            ),
373        );
374
375        MySqlStorage::setup(&pool).await.unwrap();
376
377        mysql.push_start(100usize).await.unwrap();
378
379        let worker = WorkerBuilder::new("rango-tango")
380            .backend(mysql)
381            .on_event(|ctx, ev| {
382                println!("On Event = {:?}", ev);
383                if matches!(ev, Event::Error(_)) {
384                    ctx.stop().unwrap();
385                }
386            })
387            .build(workflow);
388        worker.run().await.unwrap();
389    }
390
391    #[tokio::test]
392    async fn test_workflow_complete() {
393        #[derive(Debug, Serialize, Deserialize, Clone)]
394        struct PipelineConfig {
395            min_confidence: f32,
396            enable_sentiment: bool,
397        }
398
399        #[derive(Debug, Serialize, Deserialize)]
400        struct UserInput {
401            text: String,
402        }
403
404        #[derive(Debug, Serialize, Deserialize)]
405        struct Classified {
406            text: String,
407            label: String,
408            confidence: f32,
409        }
410
411        #[derive(Debug, Serialize, Deserialize)]
412        struct Summary {
413            text: String,
414            sentiment: Option<String>,
415        }
416
417        let workflow = Workflow::new("text-pipeline")
418            // Step 1: Preprocess input (e.g., tokenize, lowercase)
419            .and_then(|input: UserInput, mut worker: WorkerContext| async move {
420                worker.emit(&Event::custom(format!(
421                    "Preprocessing input: {}",
422                    input.text
423                )));
424                let processed = input.text.to_lowercase();
425                Ok::<_, BoxDynError>(processed)
426            })
427            // Step 2: Classify text
428            .and_then(|text: String| async move {
429                let confidence = 0.85; // pretend model confidence
430                let items = text.split_whitespace().collect::<Vec<_>>();
431                let results = items
432                    .into_iter()
433                    .map(|x| Classified {
434                        text: x.to_string(),
435                        label: if x.contains("rust") {
436                            "Tech"
437                        } else {
438                            "General"
439                        }
440                        .to_string(),
441                        confidence,
442                    })
443                    .collect::<Vec<_>>();
444                Ok::<_, BoxDynError>(results)
445            })
446            // Step 3: Filter out low-confidence predictions
447            .filter_map(
448                |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } },
449            )
450            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
451                let cfg = config.enable_sentiment;
452                async move {
453                    if !cfg {
454                        return Some(Summary {
455                            text: c.text,
456                            sentiment: None,
457                        });
458                    }
459
460                    // pretend we run a sentiment model
461                    let sentiment = if c.text.contains("delightful") {
462                        "positive"
463                    } else {
464                        "neutral"
465                    };
466                    Some(Summary {
467                        text: c.text,
468                        sentiment: Some(sentiment.to_string()),
469                    })
470                }
471            })
472            .and_then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
473                worker.emit(&Event::Custom(Box::new(format!(
474                    "Generated {} summaries",
475                    a.len()
476                ))));
477                worker.stop()
478            });
479
480        let pool = MySqlPool::connect(&std::env::var("DATABASE_URL").unwrap())
481            .await
482            .unwrap();
483        let mut mysql = MySqlStorage::new_with_config(&pool, &Config::new("text-pipeline"));
484
485        MySqlStorage::setup(&pool).await.unwrap();
486
487        let input = UserInput {
488            text: "Rust makes systems programming delightful!".to_string(),
489        };
490        mysql.push_start(input).await.unwrap();
491
492        let worker = WorkerBuilder::new("rango-tango")
493            .backend(mysql)
494            .data(PipelineConfig {
495                min_confidence: 0.8,
496                enable_sentiment: true,
497            })
498            .on_event(|ctx, ev| match ev {
499                Event::Custom(msg) => {
500                    if let Some(m) = msg.downcast_ref::<String>() {
501                        println!("Custom Message: {}", m);
502                    }
503                }
504                Event::Error(_) => {
505                    println!("On Error = {:?}", ev);
506                    ctx.stop().unwrap();
507                }
508                _ => {
509                    println!("On Event = {:?}", ev);
510                }
511            })
512            .build(workflow);
513        worker.run().await.unwrap();
514    }
515}