1use futures_core::stream::BoxStream;
92use futures_sink::Sink;
93use futures_util::SinkExt;
94use futures_util::{Stream, StreamExt};
95use std::pin::Pin;
96use std::sync::Arc;
97use std::task::{Context, Poll};
98use std::{fmt, marker::PhantomData};
99use thiserror::Error;
100use tower_layer::Identity;
101
102use crate::backend::TaskStream;
103use crate::backend::codec::Codec;
104use crate::backend::codec::IdentityCodec;
105use crate::error::BoxDynError;
106use crate::features_table;
107use crate::{backend::Backend, task::Task, worker::context::WorkerContext};
108
/// A backend assembled from user-supplied closures instead of a dedicated storage driver.
///
/// Construct one with [`BackendBuilder`]. At runtime:
///
/// - [`Backend::poll`] calls the `fetcher` closure with the database handle, the config and
///   the polling worker's [`WorkerContext`], then decodes each fetched task's compact
///   arguments with the `Codec` type parameter.
/// - The [`Sink`](futures_sink::Sink) implementation forwards every call to `current_sink`,
///   which was produced by the `sinker` closure.
/// - `Clone` clones the database handle and the config and calls `sinker` again, so each
///   clone owns a freshly built sink while sharing the same closures.
#[doc = features_table! {
    setup = "{ unreachable!() }",
    TaskSink => supported("Ability to push new tasks", false),
    Serialization => supported("Serialization support for arguments", false),
    FetchById => not_supported("Allow fetching a task by its ID"),
    RegisterWorker => not_implemented("Allow registering a worker with the backend"),
    PipeExt => limited("Allow other backends to pipe to this backend", false),
    MakeShared => not_implemented("Share the same [`CustomBackend`] across multiple workers", false),
    Workflow => not_implemented("Flexible enough to support workflows"),
    WaitForCompletion => not_implemented("Wait for tasks to complete without blocking"),
    ResumeById => not_supported("Resume a task by its ID"),
    ResumeAbandoned => not_supported("Resume abandoned tasks"),
    ListWorkers => not_implemented("List all workers registered with the backend"),
    ListTasks => not_implemented("List all tasks in the backend"),
}]
#[pin_project::pin_project]
#[must_use = "Custom backends must be polled or used as a sink"]
pub struct CustomBackend<Args, DB, Fetch, Sink, IdType, Codec = IdentityCodec, Config = ()> {
    /// Zero-sized marker that ties the otherwise unused `Args`, `IdType` and `Codec`
    /// parameters to this type.
    _marker: PhantomData<(Args, IdType, Codec)>,
    /// Storage handle passed mutably to both the fetcher and the sinker closures.
    db: DB,
    /// Produces the raw task stream consumed by `Backend::poll`; shared (via `Arc`) by clones.
    fetcher: Arc<Box<dyn Fn(&mut DB, &Config, &WorkerContext) -> Fetch + Send + Sync>>,
    /// Produces `current_sink`; shared by clones, each of which calls it to get its own sink.
    sinker: Arc<Box<dyn Fn(&mut DB, &Config) -> Sink + Send + Sync>>,
    /// Sink that every `Sink` method on this backend forwards to (structurally pinned).
    #[pin]
    current_sink: Sink,
    /// User configuration handed by reference to both closures.
    config: Config,
}
149
150impl<Args, DB, Fetch, Sink, IdType, Codec, Config> Clone
151 for CustomBackend<Args, DB, Fetch, Sink, IdType, Codec, Config>
152where
153 DB: Clone,
154 Config: Clone,
155{
156 fn clone(&self) -> Self {
157 let mut db = self.db.clone();
158 let current_sink = (self.sinker)(&mut db, &self.config);
159 Self {
160 _marker: PhantomData,
161 db: db,
162 fetcher: Arc::clone(&self.fetcher),
163 sinker: Arc::clone(&self.sinker),
164 current_sink,
165 config: self.config.clone(),
166 }
167 }
168}
169
170impl<Args, DB, Fetch, Sink, IdType, Codec, Config> fmt::Debug
171 for CustomBackend<Args, DB, Fetch, Sink, IdType, Codec, Config>
172where
173 DB: fmt::Debug,
174 Config: fmt::Debug,
175{
176 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177 f.debug_struct("CustomBackend")
178 .field(
179 "_marker",
180 &format_args!(
181 "PhantomData<({}, {})>",
182 std::any::type_name::<Args>(),
183 std::any::type_name::<IdType>()
184 ),
185 )
186 .field("db", &self.db)
187 .field("fetcher", &"Fn(&mut DB, &Config, &WorkerContext) -> Fetch")
188 .field("sink", &"Fn(&mut DB, &Config) -> Sink")
189 .field("config", &self.config)
190 .finish()
191 }
192}
193
194pub struct BackendBuilder<Args, DB, Fetch, Sink, IdType, Codec = IdentityCodec, Config = ()> {
198 _marker: PhantomData<(Args, IdType, Codec)>,
199 database: Option<DB>,
200 fetcher: Option<Box<dyn Fn(&mut DB, &Config, &WorkerContext) -> Fetch + Send + Sync + 'static>>,
201 sink: Option<Box<dyn Fn(&mut DB, &Config) -> Sink + Send + Sync + 'static>>,
202 config: Option<Config>,
203}
204
205impl<Args, DB, Fetch, Sink, IdType, Codec, Config> fmt::Debug
206 for BackendBuilder<Args, DB, Fetch, Sink, IdType, Codec, Config>
207where
208 DB: fmt::Debug,
209 Config: fmt::Debug,
210{
211 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
212 f.debug_struct("BackendBuilder")
213 .field(
214 "_marker",
215 &format_args!(
216 "PhantomData<({}, {})>",
217 std::any::type_name::<Args>(),
218 std::any::type_name::<IdType>()
219 ),
220 )
221 .field("database", &self.database)
222 .field("fetcher", &self.fetcher.as_ref().map(|_| "Some(fn)"))
223 .field("sink", &self.sink.as_ref().map(|_| "Some(fn)"))
224 .field("config", &self.config)
225 .finish()
226 }
227}
228
229impl<Args, DB, Fetch, Sink, IdType, Codec, Config> Default
230 for BackendBuilder<Args, DB, Fetch, Sink, IdType, Codec, Config>
231{
232 fn default() -> Self {
233 Self {
234 _marker: PhantomData,
235 database: None,
236 fetcher: None,
237 sink: None,
238 config: None,
239 }
240 }
241}
242
243impl<Args, DB, Fetch, Sink, IdType>
244 BackendBuilder<Args, DB, Fetch, Sink, IdType, IdentityCodec, ()>
245{
246 pub fn new() -> Self {
248 Self::new_with_cfg(())
249 }
250
251 pub fn new_with_cfg<Config>(
253 config: Config,
254 ) -> BackendBuilder<Args, DB, Fetch, Sink, IdType, IdentityCodec, Config> {
255 BackendBuilder {
256 config: Some(config),
257 ..Default::default()
258 }
259 }
260}
261
262impl<Args, DB, Fetch, Sink, IdType, Codec, Config>
263 BackendBuilder<Args, DB, Fetch, Sink, IdType, Codec, Config>
264{
265 pub fn with_codec<NewCodec>(
267 self,
268 ) -> BackendBuilder<Args, DB, Fetch, Sink, IdType, NewCodec, Config> {
269 BackendBuilder {
270 _marker: PhantomData,
271 database: self.database,
272 fetcher: self.fetcher,
273 sink: self.sink,
274 config: self.config,
275 }
276 }
277
278 pub fn database(mut self, db: DB) -> Self {
280 self.database = Some(db);
281 self
282 }
283
284 pub fn fetcher<F: Fn(&mut DB, &Config, &WorkerContext) -> Fetch + Send + Sync + 'static>(
286 mut self,
287 fetcher: F,
288 ) -> Self {
289 self.fetcher = Some(Box::new(fetcher));
290 self
291 }
292
293 pub fn sink<F: Fn(&mut DB, &Config) -> Sink + Send + Sync + 'static>(
295 mut self,
296 sink: F,
297 ) -> Self {
298 self.sink = Some(Box::new(sink));
299 self
300 }
301
302 pub fn build(
304 self,
305 ) -> Result<CustomBackend<Args, DB, Fetch, Sink, IdType, Codec, Config>, BuildError> {
306 let mut db = self.database.ok_or(BuildError::MissingPool)?;
307 let config = self.config.ok_or(BuildError::MissingConfig)?;
308 let sink_fn = self.sink.ok_or(BuildError::MissingSink)?;
309 let sink = sink_fn(&mut db, &config);
310
311 Ok(CustomBackend {
312 _marker: PhantomData,
313 db: db,
314 fetcher: self
315 .fetcher
316 .map(Arc::new)
317 .ok_or(BuildError::MissingFetcher)?,
318 current_sink: sink,
319 sinker: Arc::new(sink_fn),
320 config,
321 })
322 }
323}
324
325#[derive(Debug, Error)]
327pub enum BuildError {
328 #[error("Database db is required")]
330 MissingPool,
331 #[error("Fetcher is required")]
333 MissingFetcher,
334 #[error("Sink is required")]
336 MissingSink,
337 #[error("Config is required")]
339 MissingConfig,
340}
341
342impl<Args, DB, Fetch, Sink, IdType: Clone, E, Ctx: Default, Encode, Config> Backend
343 for CustomBackend<Args, DB, Fetch, Sink, IdType, Encode, Config>
344where
345 Fetch: Stream<Item = Result<Option<Task<Encode::Compact, Ctx, IdType>>, E>> + Send + 'static,
346 Encode: Codec<Args> + Send + 'static,
347 Encode::Error: Into<BoxDynError>,
348 E: Into<BoxDynError>,
349{
350 type Args = Args;
351 type IdType = IdType;
352
353 type Context = Ctx;
354
355 type Error = BoxDynError;
356
357 type Stream = TaskStream<Task<Args, Ctx, IdType>, BoxDynError>;
358
359 type Codec = Encode;
360
361 type Compact = Encode::Compact;
362
363 type Beat = BoxStream<'static, Result<(), Self::Error>>;
364
365 type Layer = Identity;
366
367 fn heartbeat(&self, _: &WorkerContext) -> Self::Beat {
368 futures_util::stream::once(async { Ok(()) }).boxed()
369 }
370
371 fn middleware(&self) -> Self::Layer {
372 Identity::new()
373 }
374
375 fn poll(mut self, worker: &WorkerContext) -> Self::Stream {
376 (self.fetcher)(&mut self.db, &self.config, worker)
377 .map(|task| match task {
378 Ok(Some(t)) => Ok(Some(
379 t.try_map(|args| Encode::decode(&args))
380 .map_err(|e| e.into())?,
381 )),
382 Ok(None) => Ok(None),
383 Err(e) => Err(e.into()),
384 })
385 .boxed()
386 }
387}
388
389impl<Args, Ctx, IdType, DB, Fetch, S, Cdc, Config> Sink<Task<Cdc::Compact, Ctx, IdType>>
390 for CustomBackend<Args, DB, Fetch, S, IdType, Cdc, Config>
391where
392 S: Sink<Task<Cdc::Compact, Ctx, IdType>>,
393 Cdc: Codec<Args> + Send + 'static,
394{
395 type Error = S::Error;
396
397 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
398 self.project().current_sink.poll_ready_unpin(cx)
399 }
400
401 fn start_send(
402 self: Pin<&mut Self>,
403 item: Task<Cdc::Compact, Ctx, IdType>,
404 ) -> Result<(), Self::Error> {
405 self.project().current_sink.start_send_unpin(item)
406 }
407
408 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
409 self.project().current_sink.poll_flush_unpin(cx)
410 }
411
412 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
413 self.project().current_sink.poll_close_unpin(cx)
414 }
415}
416
#[cfg(test)]
mod tests {
    use std::{collections::VecDeque, time::Duration};

    use futures_util::{FutureExt, lock::Mutex, sink, stream};

    use crate::{
        backend::TaskSink,
        error::BoxDynError,
        worker::{builder::WorkerBuilder, ext::event_listener::EventListenerExt},
    };

    use super::*;

    const ITEMS: u32 = 10;

    /// End-to-end check: push `ITEMS` tasks through the custom sink into an in-memory queue,
    /// then run a worker over the custom fetcher until the last task stops it.
    #[tokio::test]
    async fn basic_custom_backend() {
        let memory: Arc<Mutex<VecDeque<Task<u32, ()>>>> = Arc::new(Mutex::new(VecDeque::new()));

        let mut backend = BackendBuilder::new()
            .database(memory)
            .fetcher(|db, _, _| {
                // Poll the queue every 100ms forever; an empty queue yields `Ok(None)`.
                stream::unfold(db.clone(), |p| async move {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    let mut db = p.lock().await;
                    let item = db.pop_front();
                    drop(db);
                    Some((Ok::<_, BoxDynError>(item), p))
                })
                .boxed()
            })
            .sink(|db, _| {
                sink::unfold(db.clone(), move |p, item| {
                    async move {
                        let mut db = p.lock().await;
                        db.push_back(item);
                        drop(db);
                        Ok::<_, BoxDynError>(p)
                    }
                    .boxed()
                })
            })
            .build()
            .unwrap();

        for i in 0..ITEMS {
            TaskSink::push(&mut backend, i).await.unwrap();
        }

        async fn task(task: u32, ctx: WorkerContext) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(1)).await;
            if task == ITEMS - 1 {
                // The last task shuts the worker down so `run` returns.
                ctx.stop().unwrap();
                return Err("Worker stopped!".into());
            }
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango")
            .backend(backend)
            .on_event(|ctx, ev| {
                println!("On Event = {:?} from {}", ev, ctx.name());
            })
            .build(task);
        worker.run().await.unwrap();
    }
}