1#![doc = include_str!("../README.md")]
2use apalis_codec::json::JsonCodec;
4use apalis_core::{
5 backend::{Backend, BackendExt, TaskStream, codec::Codec, queue::Queue},
6 features_table,
7 layers::Stack,
8 task::Task,
9 worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
10};
11use apalis_sql::context::SqlContext;
12use futures::{
13 FutureExt, Stream, StreamExt, TryStreamExt,
14 stream::{self, BoxStream},
15};
16pub use sqlx::{
17 Connection, MySql, MySqlConnection, MySqlPool, Pool,
18 error::Error as SqlxError,
19 mysql::MySqlConnectOptions,
20 pool::{PoolConnection, PoolOptions},
21};
22use std::{fmt, marker::PhantomData};
23use ulid::Ulid;
24
25use crate::{
26 ack::{LockTaskLayer, MySqlAck},
27 fetcher::{MySqlFetcher, MySqlPollFetcher},
28 queries::{
29 keep_alive::{initial_heartbeat, keep_alive, keep_alive_stream},
30 reenqueue_orphaned::reenqueue_orphaned_stream,
31 },
32 sink::MySqlSink,
33};
34
35mod ack;
36pub mod fetcher;
38mod from_row;
39pub mod queries;
41mod shared;
42pub mod sink;
44
/// A task handled by this backend: arguments of type `Args`, carrying a
/// [`MySqlContext`] and identified by a [`Ulid`].
pub type MySqlTask<Args> = Task<Args, MySqlContext, Ulid>;
pub use apalis_sql::config::Config;
pub use apalis_sql::ext::TaskBuilderExt;
pub use shared::{SharedMySqlError, SharedMySqlStorage};

/// Task identifier type used by this backend (a [`Ulid`]-based task id).
pub type MySqlTaskId = apalis_core::task::task_id::TaskId<Ulid>;
/// Per-task context type, parameterised over the MySQL connection pool.
pub type MySqlContext = SqlContext<MySqlPool>;

/// Encoded ("compact") form of task arguments: raw bytes produced and consumed
/// by the codec (e.g. `JsonCodec<CompactType>` in the constructors).
pub type CompactType = Vec<u8>;
56
/// MySQL-backed task storage for task arguments of type `T`.
///
/// `C` is the codec that encodes/decodes arguments to and from [`CompactType`];
/// `Fetcher` is the fetcher value stored with the storage. The constructors in
/// this crate use [`fetcher::MySqlFetcher`], which is the fetcher `Backend` is
/// implemented for.
#[doc = features_table! {
    setup = r#"
    # {
    # use apalis_mysql::MySqlStorage;
    # use sqlx::MySqlPool;
    # let pool = MySqlPool::connect(&std::env::var("DATABASE_URL").unwrap()).await.unwrap();
    # MySqlStorage::setup(&pool).await.unwrap();
    # MySqlStorage::new(&pool)
    # };
    "#,

    Backend => supported("Supports storage and retrieval of tasks", true),
    TaskSink => supported("Ability to push new tasks", true),
    Serialization => supported("Serialization support for arguments", true),
    Workflow => supported("Flexible enough to support workflows", true),
    WebUI => supported("Expose a web interface for monitoring tasks", true),
    FetchById => supported("Allow fetching a task by its ID", false),
    RegisterWorker => supported("Allow registering a worker with the backend", false),
    MakeShared => supported("Share one connection across multiple workers via [`SharedMySqlStorage`]", false),
    WaitForCompletion => supported("Wait for tasks to complete without blocking", true),
    ResumeById => supported("Resume a task by its ID", false),
    ResumeAbandoned => supported("Resume abandoned tasks", false),
    ListWorkers => supported("List all workers registered with the backend", false),
    ListTasks => supported("List all tasks in the backend", false),
}]
#[pin_project::pin_project]
pub struct MySqlStorage<T, C, Fetcher> {
    /// Connection pool; cloned into the heartbeat, middleware and polling streams.
    pool: Pool<MySql>,
    /// Marker for the argument type `T`; no `T` value is stored.
    job_type: PhantomData<T>,
    /// Queue/polling configuration; also used to build the sink.
    config: Config,
    /// Marker for the codec type `C`; no codec value is stored.
    codec: PhantomData<C>,
    /// Sink used to push new tasks of type `T`, encoded with `C`.
    #[pin]
    sink: MySqlSink<T, CompactType, C>,
    /// Fetcher value (see the `Fetcher` type parameter).
    #[pin]
    fetcher: Fetcher,
}
97
98impl<T, C, F> fmt::Debug for MySqlStorage<T, C, F> {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 f.debug_struct("MySqlStorage")
101 .field("pool", &self.pool)
102 .field("job_type", &"PhantomData<T>")
103 .field("config", &self.config)
104 .field("codec", &std::any::type_name::<C>())
105 .finish()
106 }
107}
108
109impl<T, C, F: Clone> Clone for MySqlStorage<T, C, F> {
110 fn clone(&self) -> Self {
111 Self {
112 sink: self.sink.clone(),
113 pool: self.pool.clone(),
114 job_type: PhantomData,
115 config: self.config.clone(),
116 codec: self.codec,
117 fetcher: self.fetcher.clone(),
118 }
119 }
120}
121
122impl MySqlStorage<(), (), ()> {
123 #[cfg(feature = "migrate")]
125 #[must_use]
126 pub fn migrations() -> sqlx::migrate::Migrator {
127 sqlx::migrate!("./migrations")
128 }
129
130 #[cfg(feature = "migrate")]
132 pub async fn setup(pool: &Pool<MySql>) -> Result<(), sqlx::Error> {
133 Self::migrations().run(pool).await?;
134 Ok(())
135 }
136}
137
138impl<T> MySqlStorage<T, (), ()> {
139 #[must_use]
141 pub fn new(
142 pool: &Pool<MySql>,
143 ) -> MySqlStorage<T, JsonCodec<CompactType>, fetcher::MySqlFetcher> {
144 let config = Config::new(std::any::type_name::<T>());
145 MySqlStorage {
146 pool: pool.clone(),
147 job_type: PhantomData,
148 sink: MySqlSink::new(pool, &config),
149 config,
150 codec: PhantomData,
151 fetcher: fetcher::MySqlFetcher,
152 }
153 }
154
155 #[must_use]
157 pub fn new_in_queue(
158 pool: &Pool<MySql>,
159 queue: &str,
160 ) -> MySqlStorage<T, JsonCodec<CompactType>, fetcher::MySqlFetcher> {
161 let config = Config::new(queue);
162 MySqlStorage {
163 pool: pool.clone(),
164 job_type: PhantomData,
165 sink: MySqlSink::new(pool, &config),
166 config,
167 codec: PhantomData,
168 fetcher: fetcher::MySqlFetcher,
169 }
170 }
171
172 #[must_use]
174 pub fn new_with_config(
175 pool: &Pool<MySql>,
176 config: &Config,
177 ) -> MySqlStorage<T, JsonCodec<CompactType>, fetcher::MySqlFetcher> {
178 MySqlStorage {
179 pool: pool.clone(),
180 job_type: PhantomData,
181 config: config.clone(),
182 codec: PhantomData,
183 sink: MySqlSink::new(pool, config),
184 fetcher: fetcher::MySqlFetcher,
185 }
186 }
187}
188
189impl<T, C, F> MySqlStorage<T, C, F> {
190 pub fn with_codec<D>(self) -> MySqlStorage<T, D, F> {
192 MySqlStorage {
193 sink: MySqlSink::new(&self.pool, &self.config),
194 pool: self.pool,
195 job_type: PhantomData,
196 config: self.config,
197 codec: PhantomData,
198 fetcher: self.fetcher,
199 }
200 }
201
202 pub fn config(&self) -> &Config {
204 &self.config
205 }
206
207 pub fn pool(&self) -> &Pool<MySql> {
209 &self.pool
210 }
211}
212
213impl<Args, Decode> Backend for MySqlStorage<Args, Decode, MySqlFetcher>
214where
215 Args: Send + 'static + Unpin,
216 Decode: Codec<Args, Compact = CompactType> + 'static + Send,
217 Decode::Error: std::error::Error + Send + Sync + 'static,
218{
219 type Args = Args;
220 type IdType = Ulid;
221
222 type Context = MySqlContext;
223
224 type Error = sqlx::Error;
225
226 type Stream = TaskStream<MySqlTask<Args>, sqlx::Error>;
227
228 type Beat = BoxStream<'static, Result<(), sqlx::Error>>;
229
230 type Layer = Stack<LockTaskLayer, AcknowledgeLayer<MySqlAck>>;
231
232 fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
233 let pool = self.pool.clone();
234 let config = self.config.clone();
235 let worker = worker.clone();
236 let keep_alive = keep_alive_stream(pool, config, worker);
237 let reenqueue = reenqueue_orphaned_stream(
238 self.pool.clone(),
239 self.config.clone(),
240 *self.config.keep_alive(),
241 )
242 .map_ok(|_| ());
243 futures::stream::select(keep_alive, reenqueue).boxed()
244 }
245
246 fn middleware(&self) -> Self::Layer {
247 let lock = LockTaskLayer::new(self.pool.clone());
248 let ack = AcknowledgeLayer::new(MySqlAck::new(self.pool.clone()));
249 Stack::new(lock, ack)
250 }
251
252 fn poll(self, worker: &WorkerContext) -> Self::Stream {
253 self.poll_default(worker)
254 .map(|a| match a {
255 Ok(Some(task)) => Ok(Some(
256 task.try_map(|t| Decode::decode(&t))
257 .map_err(|e| sqlx::Error::Decode(e.into()))?,
258 )),
259 Ok(None) => Ok(None),
260 Err(e) => Err(e),
261 })
262 .boxed()
263 }
264}
265
266impl<Args, Decode: Send + 'static> BackendExt for MySqlStorage<Args, Decode, MySqlFetcher>
267where
268 Self: Backend<Args = Args, IdType = Ulid, Context = MySqlContext, Error = sqlx::Error>,
269 Decode: Codec<Args, Compact = CompactType> + Send + 'static,
270 Decode::Error: std::error::Error + Send + Sync + 'static,
271 Args: Send + 'static + Unpin,
272{
273 type Codec = Decode;
274 type Compact = CompactType;
275 type CompactStream = TaskStream<MySqlTask<Self::Compact>, sqlx::Error>;
276
277 fn get_queue(&self) -> Queue {
278 self.config.queue().clone()
279 }
280
281 fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
282 self.poll_default(worker).boxed()
283 }
284}
285
286impl<Args, Decode: Send + 'static, F> MySqlStorage<Args, Decode, F> {
287 fn poll_default(
288 self,
289 worker: &WorkerContext,
290 ) -> impl Stream<Item = Result<Option<MySqlTask<CompactType>>, sqlx::Error>> + Send + 'static
291 {
292 let fut = initial_heartbeat(
293 self.pool.clone(),
294 self.config().clone(),
295 worker.clone(),
296 "MySqlStorage",
297 );
298 let register = stream::once(fut.map(|_| Ok(None)));
299 register.chain(MySqlPollFetcher::<CompactType, Decode>::new(
300 &self.pool,
301 &self.config,
302 worker,
303 ))
304 }
305}
306
#[cfg(test)]
mod tests {
    //! Integration tests against the MySQL server named by `DATABASE_URL`.
    //!
    //! Each test uses its own worker name. `cargo test` runs tests concurrently,
    //! and registering the same worker name from two tests at once may conflict.

    use std::time::Duration;

    use apalis::prelude::*;
    use apalis_workflow::*;
    use serde::{Deserialize, Serialize};
    use sqlx::MySqlPool;

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        const ITEMS: usize = 10;
        let pool = MySqlPool::connect(&std::env::var("DATABASE_URL").unwrap())
            .await
            .unwrap();
        MySqlStorage::setup(&pool).await.unwrap();

        let mut backend = MySqlStorage::new(&pool);

        let mut start = 0;

        let mut items = stream::repeat_with(move || {
            start += 1;
            start
        })
        .take(ITEMS);
        backend.push_stream(&mut items).await.unwrap();

        // Stops the worker once the last pushed item (== ITEMS) is handled.
        async fn send_reminder(item: usize, wrk: WorkerContext) -> Result<(), BoxDynError> {
            if ITEMS == item {
                wrk.stop().unwrap();
            }
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_workflow() {
        let workflow = Workflow::new("odd-numbers-workflow")
            .and_then(|a: usize| async move { Ok::<_, BoxDynError>((0..=a).collect::<Vec<_>>()) })
            .filter_map(|x| async move { if x % 2 != 0 { Some(x) } else { None } })
            .filter_map(|x| async move { if x % 3 != 0 { Some(x) } else { None } })
            .filter_map(|x| async move { if x % 5 != 0 { Some(x) } else { None } })
            .delay_for(Duration::from_millis(1000))
            .and_then(|a: Vec<usize>| async move {
                println!("Sum: {}", a.iter().sum::<usize>());
                Err::<(), BoxDynError>("Intentional Error".into())
            });

        let pool = MySqlPool::connect(&std::env::var("DATABASE_URL").unwrap())
            .await
            .unwrap();
        let mut mysql = MySqlStorage::new_with_config(
            &pool,
            &Config::new("workflow-queue").with_poll_interval(
                StrategyBuilder::new()
                    .apply(IntervalStrategy::new(Duration::from_millis(100)))
                    .build(),
            ),
        );

        MySqlStorage::setup(&pool).await.unwrap();

        mysql.push_start(100usize).await.unwrap();

        // The last step fails on purpose; the first error event stops the worker.
        let worker = WorkerBuilder::new("rango-tango-2")
            .backend(mysql)
            .on_event(|ctx, ev| {
                println!("On Event = {:?}", ev);
                if matches!(ev, Event::Error(_)) {
                    ctx.stop().unwrap();
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_workflow_complete() {
        #[derive(Debug, Serialize, Deserialize, Clone)]
        struct PipelineConfig {
            min_confidence: f32,
            enable_sentiment: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct UserInput {
            text: String,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Classified {
            text: String,
            label: String,
            confidence: f32,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Summary {
            text: String,
            sentiment: Option<String>,
        }

        let workflow = Workflow::new("text-pipeline")
            .and_then(|input: UserInput, mut worker: WorkerContext| async move {
                worker.emit(&Event::custom(format!(
                    "Preprocessing input: {}",
                    input.text
                )));
                let processed = input.text.to_lowercase();
                Ok::<_, BoxDynError>(processed)
            })
            .and_then(|text: String| async move {
                let confidence = 0.85;
                let items = text.split_whitespace().collect::<Vec<_>>();
                let results = items
                    .into_iter()
                    .map(|x| Classified {
                        text: x.to_string(),
                        label: if x.contains("rust") {
                            "Tech"
                        } else {
                            "General"
                        }
                        .to_string(),
                        confidence,
                    })
                    .collect::<Vec<_>>();
                Ok::<_, BoxDynError>(results)
            })
            // Use the injected pipeline threshold instead of a hard-coded one.
            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
                let min_confidence = config.min_confidence;
                async move {
                    if c.confidence >= min_confidence {
                        Some(c)
                    } else {
                        None
                    }
                }
            })
            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
                let cfg = config.enable_sentiment;
                async move {
                    if !cfg {
                        return Some(Summary {
                            text: c.text,
                            sentiment: None,
                        });
                    }

                    let sentiment = if c.text.contains("delightful") {
                        "positive"
                    } else {
                        "neutral"
                    };
                    Some(Summary {
                        text: c.text,
                        sentiment: Some(sentiment.to_string()),
                    })
                }
            })
            .and_then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
                worker.emit(&Event::Custom(Box::new(format!(
                    "Generated {} summaries",
                    a.len()
                ))));
                worker.stop()
            });

        let pool = MySqlPool::connect(&std::env::var("DATABASE_URL").unwrap())
            .await
            .unwrap();
        let mut mysql = MySqlStorage::new_with_config(&pool, &Config::new("text-pipeline"));

        MySqlStorage::setup(&pool).await.unwrap();

        let input = UserInput {
            text: "Rust makes systems programming delightful!".to_string(),
        };
        mysql.push_start(input).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango-3")
            .backend(mysql)
            .data(PipelineConfig {
                min_confidence: 0.8,
                enable_sentiment: true,
            })
            .on_event(|ctx, ev| match ev {
                Event::Custom(msg) => {
                    if let Some(m) = msg.downcast_ref::<String>() {
                        println!("Custom Message: {}", m);
                    }
                }
                Event::Error(_) => {
                    println!("On Error = {:?}", ev);
                    ctx.stop().unwrap();
                }
                _ => {
                    println!("On Event = {:?}", ev);
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }
}