1use futures_core::stream::BoxStream;
92use futures_sink::Sink;
93use futures_util::SinkExt;
94use futures_util::{Stream, StreamExt};
95use std::pin::Pin;
96use std::sync::Arc;
97use std::task::{Context, Poll};
98use std::{fmt, marker::PhantomData};
99use thiserror::Error;
100use tower_layer::Identity;
101
102use crate::backend::TaskStream;
103use crate::error::BoxDynError;
104use crate::features_table;
105use crate::{backend::Backend, task::Task, worker::context::WorkerContext};
106
/// Shared, type-erased fetcher closure stored in a built [`CustomBackend`].
///
/// Given the database handle, the config and the worker context, it produces the `Fetch`
/// stream. The `Arc` lets clones of a backend share the same closure instead of copying it.
type Fetcher<DB, Config, Fetch> =
    Arc<Box<dyn Fn(&mut DB, &Config, &WorkerContext) -> Fetch + Send + Sync>>;

/// Shared, type-erased sink factory stored in a built [`CustomBackend`].
///
/// Given the database handle and the config, it produces a new `Sink` value.
type Sinker<DB, Config, Sink> = Arc<Box<dyn Fn(&mut DB, &Config) -> Sink + Send + Sync>>;
111
/// A [`Backend`] assembled from user-supplied closures rather than a dedicated storage driver.
///
/// Tasks are read from the stream returned by the `fetcher` closure, and pushed through a
/// sink produced by the `sinker` closure. Values are constructed through [`BackendBuilder`].
#[doc = features_table! {
    setup = "{ unreachable!() }",
    TaskSink => supported("Ability to push new tasks", false),
    Serialization => supported("Serialization support for arguments", false),
    FetchById => not_supported("Allow fetching a task by its ID"),
    RegisterWorker => not_implemented("Allow registering a worker with the backend"),
    PipeExt => limited("Allow other backends to pipe to this backend", false),
    MakeShared => not_implemented("Share the same [`CustomBackend`] across multiple workers", false),
    Workflow => not_implemented("Flexible enough to support workflows"),
    WaitForCompletion => not_implemented("Wait for tasks to complete without blocking"),
    ResumeById => not_supported("Resume a task by its ID"),
    ResumeAbandoned => not_supported("Resume abandoned tasks"),
    ListWorkers => not_implemented("List all workers registered with the backend"),
    ListTasks => not_implemented("List all tasks in the backend"),
}]
#[pin_project::pin_project]
#[must_use = "Custom backends must be polled or used as a sink"]
pub struct CustomBackend<Args, DB, Fetch, Sink, IdType, Config = ()> {
    /// Ties `Args` and `IdType` to the type without storing values of either.
    _marker: PhantomData<(Args, IdType)>,
    /// Handle passed by `&mut` to both the fetcher and the sinker closures.
    db: DB,
    /// Shared closure that produces the task stream when the backend is polled.
    fetcher: Fetcher<DB, Config, Fetch>,
    /// Shared closure that builds a new sink from the database handle and config.
    sinker: Sinker<DB, Config, Sink>,
    /// The sink that this type's `Sink` implementation forwards to; structurally pinned.
    #[pin]
    current_sink: Sink,
    /// Configuration passed by reference to the fetcher and sinker closures.
    config: Config,
}
152
153impl<Args, DB, Fetch, Sink, IdType, Config> Clone
154 for CustomBackend<Args, DB, Fetch, Sink, IdType, Config>
155where
156 DB: Clone,
157 Config: Clone,
158{
159 fn clone(&self) -> Self {
160 let mut db = self.db.clone();
161 let current_sink = (self.sinker)(&mut db, &self.config);
162 Self {
163 _marker: PhantomData,
164 db,
165 fetcher: Arc::clone(&self.fetcher),
166 sinker: Arc::clone(&self.sinker),
167 current_sink,
168 config: self.config.clone(),
169 }
170 }
171}
172
173impl<Args, DB, Fetch, Sink, IdType, Config> fmt::Debug
174 for CustomBackend<Args, DB, Fetch, Sink, IdType, Config>
175where
176 DB: fmt::Debug,
177 Config: fmt::Debug,
178{
179 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180 f.debug_struct("CustomBackend")
181 .field(
182 "_marker",
183 &format_args!(
184 "PhantomData<({}, {})>",
185 std::any::type_name::<Args>(),
186 std::any::type_name::<IdType>()
187 ),
188 )
189 .field("db", &self.db)
190 .field("fetcher", &"Fn(&mut DB, &Config, &WorkerContext) -> Fetch")
191 .field("sink", &"Fn(&mut DB, &Config) -> Sink")
192 .field("config", &self.config)
193 .finish()
194 }
195}
196
/// Boxed fetcher closure as held by [`BackendBuilder`]; `build` wraps it in an `Arc`.
type FetcherBuilder<DB, Config, Fetch> =
    Box<dyn Fn(&mut DB, &Config, &WorkerContext) -> Fetch + Send + Sync + 'static>;

/// Boxed sink-factory closure as held by [`BackendBuilder`]; `build` wraps it in an `Arc`.
type SinkerBuilder<DB, Config, Sink> =
    Box<dyn Fn(&mut DB, &Config) -> Sink + Send + Sync + 'static>;
202
/// Builder for [`CustomBackend`].
///
/// A database handle, a fetcher closure, a sink closure and a config must all be present
/// when `build` is called; each missing part is reported as the matching [`BuildError`]
/// variant. `BackendBuilder::new` and `BackendBuilder::new_with_cfg` pre-fill the config,
/// whereas `Default::default` leaves every part, including the config, unset.
pub struct BackendBuilder<Args, DB, Fetch, Sink, IdType, Config = ()> {
    /// Ties `Args` and `IdType` to the builder without storing values of either.
    _marker: PhantomData<(Args, IdType)>,
    /// Database handle, set via `database`.
    database: Option<DB>,
    /// Task-stream factory, set via `fetcher`.
    fetcher: Option<FetcherBuilder<DB, Config, Fetch>>,
    /// Sink factory, set via `sink`.
    sink: Option<SinkerBuilder<DB, Config, Sink>>,
    /// Configuration handed to both closures; pre-filled by `new` / `new_with_cfg`.
    config: Option<Config>,
}
213
214impl<Args, DB, Fetch, Sink, IdType, Config> fmt::Debug
215 for BackendBuilder<Args, DB, Fetch, Sink, IdType, Config>
216where
217 DB: fmt::Debug,
218 Config: fmt::Debug,
219{
220 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221 f.debug_struct("BackendBuilder")
222 .field(
223 "_marker",
224 &format_args!(
225 "PhantomData<({}, {})>",
226 std::any::type_name::<Args>(),
227 std::any::type_name::<IdType>()
228 ),
229 )
230 .field("database", &self.database)
231 .field("fetcher", &self.fetcher.as_ref().map(|_| "Some(fn)"))
232 .field("sink", &self.sink.as_ref().map(|_| "Some(fn)"))
233 .field("config", &self.config)
234 .finish()
235 }
236}
237
238impl<Args, DB, Fetch, Sink, IdType, Config> Default
239 for BackendBuilder<Args, DB, Fetch, Sink, IdType, Config>
240{
241 fn default() -> Self {
242 Self {
243 _marker: PhantomData,
244 database: None,
245 fetcher: None,
246 sink: None,
247 config: None,
248 }
249 }
250}
251
252impl<Args, DB, Fetch, Sink, IdType> BackendBuilder<Args, DB, Fetch, Sink, IdType, ()> {
253 #[must_use]
255 pub fn new() -> Self {
256 Self::new_with_cfg(())
257 }
258
259 pub fn new_with_cfg<Config>(
261 config: Config,
262 ) -> BackendBuilder<Args, DB, Fetch, Sink, IdType, Config> {
263 BackendBuilder {
264 config: Some(config),
265 ..Default::default()
266 }
267 }
268}
269
270impl<Args, DB, Fetch, Sink, IdType, Config> BackendBuilder<Args, DB, Fetch, Sink, IdType, Config> {
271 #[must_use]
273 pub fn database(mut self, db: DB) -> Self {
274 self.database = Some(db);
275 self
276 }
277
278 #[must_use]
280 pub fn fetcher<F: Fn(&mut DB, &Config, &WorkerContext) -> Fetch + Send + Sync + 'static>(
281 mut self,
282 fetcher: F,
283 ) -> Self {
284 self.fetcher = Some(Box::new(fetcher));
285 self
286 }
287
288 #[must_use]
290 pub fn sink<F: Fn(&mut DB, &Config) -> Sink + Send + Sync + 'static>(
291 mut self,
292 sink: F,
293 ) -> Self {
294 self.sink = Some(Box::new(sink));
295 self
296 }
297
298 #[allow(clippy::type_complexity)]
299 pub fn build(self) -> Result<CustomBackend<Args, DB, Fetch, Sink, IdType, Config>, BuildError> {
301 let mut db = self.database.ok_or(BuildError::MissingDb)?;
302 let config = self.config.ok_or(BuildError::MissingConfig)?;
303 let sink_fn = self.sink.ok_or(BuildError::MissingSink)?;
304 let sink = sink_fn(&mut db, &config);
305
306 Ok(CustomBackend {
307 _marker: PhantomData,
308 db,
309 fetcher: self
310 .fetcher
311 .map(Arc::new)
312 .ok_or(BuildError::MissingFetcher)?,
313 current_sink: sink,
314 sinker: Arc::new(sink_fn),
315 config,
316 })
317 }
318}
319
320#[derive(Debug, Error)]
322pub enum BuildError {
323 #[error("Database db is required")]
325 MissingDb,
326 #[error("Fetcher is required")]
328 MissingFetcher,
329 #[error("Sink is required")]
331 MissingSink,
332 #[error("Config is required")]
334 MissingConfig,
335}
336
/// Error type used by [`CustomBackend`]'s `Backend` and `Sink` implementations.
#[derive(Debug, Error)]
pub enum CustomBackendError {
    /// An error from the user-supplied fetch stream or sink, converted into [`BoxDynError`].
    #[error("Inner error: {0}")]
    Inner(#[from] BoxDynError),
}
344
345impl<Args, DB, Fetch, Sink, IdType: Clone, E, Ctx: Default, Config> Backend
346 for CustomBackend<Args, DB, Fetch, Sink, IdType, Config>
347where
348 Fetch: Stream<Item = Result<Option<Task<Args, Ctx, IdType>>, E>> + Send + 'static,
349 E: Into<BoxDynError>,
350{
351 type Args = Args;
352 type IdType = IdType;
353
354 type Context = Ctx;
355
356 type Error = CustomBackendError;
357
358 type Stream = TaskStream<Task<Args, Ctx, IdType>, CustomBackendError>;
359
360 type Beat = BoxStream<'static, Result<(), Self::Error>>;
361
362 type Layer = Identity;
363
364 fn heartbeat(&self, _: &WorkerContext) -> Self::Beat {
365 futures_util::stream::once(async { Ok(()) }).boxed()
366 }
367
368 fn middleware(&self) -> Self::Layer {
369 Identity::new()
370 }
371
372 fn poll(mut self, worker: &WorkerContext) -> Self::Stream {
373 (self.fetcher)(&mut self.db, &self.config, worker)
374 .map(|task| match task {
375 Ok(Some(t)) => Ok(Some(t)),
376 Ok(None) => Ok(None),
377 Err(e) => Err(e.into().into()),
378 })
379 .boxed()
380 }
381}
382
383impl<Args, Ctx, IdType, DB, Fetch, S, Config> Sink<Task<Args, Ctx, IdType>>
384 for CustomBackend<Args, DB, Fetch, S, IdType, Config>
385where
386 S: Sink<Task<Args, Ctx, IdType>>,
387 S::Error: Into<BoxDynError>,
388{
389 type Error = CustomBackendError;
390
391 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
392 self.project()
393 .current_sink
394 .poll_ready_unpin(cx)
395 .map_err(|e| CustomBackendError::Inner(e.into()))
396 }
397
398 fn start_send(self: Pin<&mut Self>, item: Task<Args, Ctx, IdType>) -> Result<(), Self::Error> {
399 self.project()
400 .current_sink
401 .start_send_unpin(item)
402 .map_err(|e| CustomBackendError::Inner(e.into()))
403 }
404
405 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
406 self.project()
407 .current_sink
408 .poll_flush_unpin(cx)
409 .map_err(|e| CustomBackendError::Inner(e.into()))
410 }
411
412 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
413 self.project()
414 .current_sink
415 .poll_close_unpin(cx)
416 .map_err(|e| CustomBackendError::Inner(e.into()))
417 }
418}
419
#[cfg(test)]
mod tests {
    use std::{collections::VecDeque, time::Duration};

    use futures_util::{FutureExt, lock::Mutex, sink, stream};

    use crate::{
        error::BoxDynError,
        task::task_id::RandomId,
        worker::{builder::WorkerBuilder, ext::event_listener::EventListenerExt},
    };

    use super::*;

    /// Number of tasks pushed into the backend before the worker starts.
    const ITEMS: u32 = 10;

    /// Round-trips tasks through an in-memory queue.
    ///
    /// Pushes `ITEMS` tasks through the backend's `Sink`. It then runs a worker that fetches
    /// them via the custom fetcher and stops itself when it handles the last task.
    #[tokio::test]
    async fn basic_custom_backend() {
        let memory: Arc<Mutex<VecDeque<Task<u32, (), RandomId>>>> =
            Arc::new(Mutex::new(VecDeque::new()));

        let mut backend = BackendBuilder::new()
            .database(memory)
            .fetcher(|db, _, _| {
                // Poll the queue every 100ms. An empty queue yields `Ok(None)` instead of
                // ending the stream, so the worker keeps waiting for new tasks.
                stream::unfold(db.clone(), |p| async move {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    // The lock guard is a temporary, released at the end of this statement.
                    let item = p.lock().await.pop_front();
                    Some((Ok::<_, CustomBackendError>(item), p))
                })
                .boxed()
            })
            .sink(|db, _| {
                sink::unfold(db.clone(), move |p, item| {
                    async move {
                        p.lock().await.push_back(item);
                        Ok::<_, CustomBackendError>(p)
                    }
                    .boxed()
                })
            })
            .build()
            .unwrap();

        for i in 0..ITEMS {
            backend.send(Task::new(i)).await.unwrap();
        }

        async fn task(task: u32, ctx: WorkerContext) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(1)).await;
            if task == ITEMS - 1 {
                ctx.stop().unwrap();
                return Err("Worker stopped!".into());
            }
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango")
            .backend(backend)
            .on_event(|ctx, ev| {
                println!("On Event = {ev:?} from {}", ctx.name());
            })
            .build(task);
        worker.run().await.unwrap();
    }
}
491}