1use futures_core::stream::BoxStream;
92use futures_sink::Sink;
93use futures_util::SinkExt;
94use futures_util::{Stream, StreamExt};
95use std::pin::Pin;
96use std::sync::Arc;
97use std::task::{Context, Poll};
98use std::{fmt, marker::PhantomData};
99use thiserror::Error;
100use tower_layer::Identity;
101
102use crate::backend::codec::Codec;
103use crate::backend::codec::IdentityCodec;
104use crate::backend::TaskStream;
105use crate::error::BoxDynError;
106use crate::features_table;
107use crate::{backend::Backend, task::Task, worker::context::WorkerContext};
108
/// A backend whose task fetching and task pushing are supplied as user closures.
///
/// Built with [`BackendBuilder`]. When polled, the stored fetcher closure is
/// called with the database handle, the config and the worker context to obtain
/// the task stream. Pushing goes through the sink that the stored sink factory
/// produced for this instance.
#[doc = features_table! {
    setup = unreachable!();,
    TaskSink => supported("Ability to push new tasks", false),
    Serialization => supported("Serialization support for arguments", false),
    FetchById => not_supported("Allow fetching a task by its ID"),
    RegisterWorker => not_implemented("Allow registering a worker with the backend"),
    PipeExt => limited("Allow other backends to pipe to this backend", false),
    MakeShared => not_implemented("Share the same [`CustomBackend`] across multiple workers", false),
    Workflow => not_implemented("Flexible enough to support workflows"),
    WaitForCompletion => not_implemented("Wait for tasks to complete without blocking"),
    ResumeById => not_supported("Resume a task by its ID"),
    ResumeAbandoned => not_supported("Resume abandoned tasks"),
    ListWorkers => not_implemented("List all workers registered with the backend"),
    ListTasks => not_implemented("List all tasks in the backend"),
}]
#[pin_project::pin_project]
#[must_use = "Custom backends must be polled or used as a sink"]
pub struct CustomBackend<Args, DB, Fetch, Sink, IdType, Codec = IdentityCodec, Config = ()> {
    /// Ties the `Args`, `IdType` and `Codec` type parameters to the struct
    /// without storing values of those types.
    _marker: PhantomData<(Args, IdType, Codec)>,
    /// Database handle; passed by `&mut` to both the fetcher and the sink factory.
    db: DB,
    /// Produces the task stream; called from `Backend::poll` with the polling
    /// worker's context.
    fetcher: Arc<Box<dyn Fn(&mut DB, &Config, &WorkerContext) -> Fetch + Send + Sync>>,
    /// Produces a sink; called once by `BackendBuilder::build` and again on
    /// every `clone` of the backend.
    sinker: Arc<Box<dyn Fn(&mut DB, &Config) -> Sink + Send + Sync>>,
    /// The sink that every `Sink` trait call on this backend is forwarded to.
    #[pin]
    current_sink: Sink,
    /// User configuration handed to both the fetcher and the sink factory.
    config: Config,
}
149
150impl<Args, DB, Fetch, Sink, IdType, Codec, Config> Clone
151 for CustomBackend<Args, DB, Fetch, Sink, IdType, Codec, Config>
152where
153 DB: Clone,
154 Config: Clone,
155{
156 fn clone(&self) -> Self {
157 let mut db = self.db.clone();
158 let current_sink = (self.sinker)(&mut db, &self.config);
159 Self {
160 _marker: PhantomData,
161 db: db,
162 fetcher: Arc::clone(&self.fetcher),
163 sinker: Arc::clone(&self.sinker),
164 current_sink,
165 config: self.config.clone(),
166 }
167 }
168}
169
170impl<Args, DB, Fetch, Sink, IdType, Codec, Config> fmt::Debug
171 for CustomBackend<Args, DB, Fetch, Sink, IdType, Codec, Config>
172where
173 DB: fmt::Debug,
174 Config: fmt::Debug,
175{
176 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177 f.debug_struct("CustomBackend")
178 .field(
179 "_marker",
180 &format_args!(
181 "PhantomData<({}, {})>",
182 std::any::type_name::<Args>(),
183 std::any::type_name::<IdType>()
184 ),
185 )
186 .field("db", &self.db)
187 .field("fetcher", &"Fn(&mut DB, &Config, &WorkerContext) -> Fetch")
188 .field("sink", &"Fn(&mut DB, &Config) -> Sink")
189 .field("config", &self.config)
190 .finish()
191 }
192}
193
/// Step-by-step constructor for a [`CustomBackend`].
///
/// The database handle, fetcher, sink factory and config are all required;
/// `build` reports the first missing one as a [`BuildError`].
pub struct BackendBuilder<Args, DB, Fetch, Sink, IdType, Codec = IdentityCodec, Config = ()> {
    /// Carries the `Args`, `IdType` and `Codec` type parameters.
    _marker: PhantomData<(Args, IdType, Codec)>,
    /// Database handle, set via `database`.
    database: Option<DB>,
    /// Closure producing the task stream, set via `fetcher`.
    fetcher: Option<Box<dyn Fn(&mut DB, &Config, &WorkerContext) -> Fetch + Send + Sync + 'static>>,
    /// Closure producing the push sink, set via `sink`.
    sink: Option<Box<dyn Fn(&mut DB, &Config) -> Sink + Send + Sync + 'static>>,
    /// Configuration supplied by `new_with_cfg` (or `new`, which passes `()`);
    /// `None` for a builder obtained from `Default::default()`.
    config: Option<Config>,
}
204
205impl<Args, DB, Fetch, Sink, IdType, Codec, Config> fmt::Debug
206 for BackendBuilder<Args, DB, Fetch, Sink, IdType, Codec, Config>
207where
208 DB: fmt::Debug,
209 Config: fmt::Debug,
210{
211 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
212 f.debug_struct("BackendBuilder")
213 .field(
214 "_marker",
215 &format_args!(
216 "PhantomData<({}, {})>",
217 std::any::type_name::<Args>(),
218 std::any::type_name::<IdType>()
219 ),
220 )
221 .field("database", &self.database)
222 .field("fetcher", &self.fetcher.as_ref().map(|_| "Some(fn)"))
223 .field("sink", &self.sink.as_ref().map(|_| "Some(fn)"))
224 .field("config", &self.config)
225 .finish()
226 }
227}
228
229impl<Args, DB, Fetch, Sink, IdType, Codec, Config> Default
230 for BackendBuilder<Args, DB, Fetch, Sink, IdType, Codec, Config>
231{
232 fn default() -> Self {
233 Self {
234 _marker: PhantomData,
235 database: None,
236 fetcher: None,
237 sink: None,
238 config: None,
239 }
240 }
241}
242
243impl<Args, DB, Fetch, Sink, IdType>
244 BackendBuilder<Args, DB, Fetch, Sink, IdType, IdentityCodec, ()>
245{
246 pub fn new() -> Self {
248 Self::new_with_cfg(())
249 }
250}
251
252impl<Args, DB, Fetch, Sink, IdType, Codec, Config>
253 BackendBuilder<Args, DB, Fetch, Sink, IdType, Codec, Config>
254{
255 pub fn new_with_cfg(config: Config) -> Self {
257 Self {
258 config: Some(config),
259 ..Default::default()
260 }
261 }
262
263 pub fn with_codec<NewCodec>(
265 self,
266 ) -> BackendBuilder<Args, DB, Fetch, Sink, IdType, NewCodec, Config> {
267 BackendBuilder {
268 _marker: PhantomData,
269 database: self.database,
270 fetcher: self.fetcher,
271 sink: self.sink,
272 config: self.config,
273 }
274 }
275
276 pub fn database(mut self, db: DB) -> Self {
278 self.database = Some(db);
279 self
280 }
281
282 pub fn fetcher<F: Fn(&mut DB, &Config, &WorkerContext) -> Fetch + Send + Sync + 'static>(
284 mut self,
285 fetcher: F,
286 ) -> Self {
287 self.fetcher = Some(Box::new(fetcher));
288 self
289 }
290
291 pub fn sink<F: Fn(&mut DB, &Config) -> Sink + Send + Sync + 'static>(
293 mut self,
294 sink: F,
295 ) -> Self {
296 self.sink = Some(Box::new(sink));
297 self
298 }
299
300 pub fn build(
302 self,
303 ) -> Result<CustomBackend<Args, DB, Fetch, Sink, IdType, Codec, Config>, BuildError> {
304 let mut db = self.database.ok_or(BuildError::MissingPool)?;
305 let config = self.config.ok_or(BuildError::MissingConfig)?;
306 let sink_fn = self.sink.ok_or(BuildError::MissingSink)?;
307 let sink = sink_fn(&mut db, &config);
308
309 Ok(CustomBackend {
310 _marker: PhantomData,
311 db: db,
312 fetcher: self
313 .fetcher
314 .map(Arc::new)
315 .ok_or(BuildError::MissingFetcher)?,
316 current_sink: sink,
317 sinker: Arc::new(sink_fn),
318 config,
319 })
320 }
321}
322
/// Error returned by [`BackendBuilder::build`] when a required component was
/// never supplied.
#[derive(Debug, Error)]
pub enum BuildError {
    /// No database handle was set via `BackendBuilder::database`.
    #[error("Database db is required")]
    MissingPool,
    /// No fetcher closure was set via `BackendBuilder::fetcher`.
    #[error("Fetcher is required")]
    MissingFetcher,
    /// No sink factory was set via `BackendBuilder::sink`.
    #[error("Sink is required")]
    MissingSink,
    /// No config is present, e.g. for a builder created with `Default::default()`
    /// (`new` and `new_with_cfg` always set one).
    #[error("Config is required")]
    MissingConfig,
}
339
340impl<Args, DB, Fetch, Sink, IdType: Clone, E, Ctx: Default, Encode, Config> Backend<Args>
341 for CustomBackend<Args, DB, Fetch, Sink, IdType, Encode, Config>
342where
343 Fetch: Stream<Item = Result<Option<Task<Encode::Compact, Ctx, IdType>>, E>> + Send + 'static,
344 Encode: Codec<Args> + Send + 'static,
345 Encode::Error: Into<BoxDynError>,
346 E: Into<BoxDynError>,
347{
348 type IdType = IdType;
349
350 type Context = Ctx;
351
352 type Error = BoxDynError;
353
354 type Stream = TaskStream<Task<Args, Ctx, IdType>, BoxDynError>;
355
356 type Codec = Encode;
357
358 type Beat = BoxStream<'static, Result<(), Self::Error>>;
359
360 type Layer = Identity;
361
362 fn heartbeat(&self, _: &WorkerContext) -> Self::Beat {
363 futures_util::stream::once(async { Ok(()) }).boxed()
364 }
365
366 fn middleware(&self) -> Self::Layer {
367 Identity::new()
368 }
369
370 fn poll(mut self, worker: &WorkerContext) -> Self::Stream {
371 (self.fetcher)(&mut self.db, &self.config, worker)
372 .map(|task| match task {
373 Ok(Some(t)) => Ok(Some(
374 t.try_map(|args| Encode::decode(&args))
375 .map_err(|e| e.into())?,
376 )),
377 Ok(None) => Ok(None),
378 Err(e) => Err(e.into()),
379 })
380 .boxed()
381 }
382}
383
384impl<Args, Ctx, IdType, DB, Fetch, S, Codec, Config> Sink<Task<Args, Ctx, IdType>>
385 for CustomBackend<Args, DB, Fetch, S, IdType, Codec, Config>
386where
387 S: Sink<Task<Args, Ctx, IdType>>,
388{
389 type Error = S::Error;
390
391 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
392 self.project().current_sink.poll_ready_unpin(cx)
393 }
394
395 fn start_send(self: Pin<&mut Self>, item: Task<Args, Ctx, IdType>) -> Result<(), Self::Error> {
396 self.project().current_sink.start_send_unpin(item)
397 }
398
399 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
400 self.project().current_sink.poll_flush_unpin(cx)
401 }
402
403 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
404 self.project().current_sink.poll_close_unpin(cx)
405 }
406}
407
#[cfg(test)]
mod tests {
    use std::{collections::VecDeque, time::Duration};

    use futures_util::{lock::Mutex, sink, stream, FutureExt};

    use crate::{
        backend::TaskSink,
        error::BoxDynError,
        worker::{builder::WorkerBuilder, ext::event_listener::EventListenerExt},
    };

    use super::*;

    /// Number of tasks pushed before the worker starts.
    const ITEMS: u32 = 10;

    /// End-to-end check: an in-memory queue wrapped as a `CustomBackend`
    /// accepts pushed tasks and feeds them to a worker, which stops itself
    /// when it processes the last item.
    #[tokio::test]
    async fn basic_custom_backend() {
        let memory: Arc<Mutex<VecDeque<Task<u32, ()>>>> = Arc::new(Mutex::new(VecDeque::new()));

        let mut backend = BackendBuilder::new()
            .database(memory)
            .fetcher(|db, _, _| {
                stream::unfold(db.clone(), |p| async move {
                    // Wait 100ms before each look at the queue.
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    let mut db = p.lock().await;
                    let item = db.pop_front();
                    drop(db);
                    // The stream never ends: an empty queue yields `Ok(None)`.
                    Some((Ok::<_, BoxDynError>(item), p))
                })
                .boxed()
            })
            .sink(|db, _| {
                sink::unfold(db.clone(), move |p, item| {
                    async move {
                        let mut db = p.lock().await;
                        db.push_back(item);
                        drop(db);
                        Ok::<_, BoxDynError>(p)
                    }
                    .boxed()
                })
            })
            .build()
            .unwrap();

        for i in 0..ITEMS {
            TaskSink::push(&mut backend, i).await.unwrap();
        }

        async fn task(task: u32, ctx: WorkerContext) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(1)).await;
            if task == ITEMS - 1 {
                ctx.stop().unwrap();
                return Err("Worker stopped!".into());
            }
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango")
            .backend(backend)
            .on_event(|ctx, ev| {
                println!("On Event = {:?} from {}", ev, ctx.name());
            })
            .build(task);
        worker.run().await.unwrap();
    }
}