1#![allow(clippy::expect_used, reason = "invariants are upheld by construction")]
4
5use std::any::Any;
6use std::sync::Arc;
7
8use crate::rate_limiter::RateLimiter;
9
10use futures::future::BoxFuture;
11use futures::stream::{BoxStream, Stream, StreamExt};
12
13use crate::task_context::TaskContext;
/// A type-erased value that can flow between pipeline tasks.
///
/// Every `'static` type that is `Send + Sync` implements `Value` through the
/// blanket impl below. Any such type can therefore be stored as
/// `Box<dyn Value>` / `Arc<dyn Value>` and later recovered through the
/// `Any`-based accessors.
///
/// # Pitfall
///
/// `Box<dyn Value>` and `Arc<dyn Value>` are themselves `Send + Sync + 'static`,
/// so they implement `Value` too. Both of the following select the
/// *container's* impl and describe the `Box`/`Arc` rather than the value
/// inside it:
///
/// * calling `as_any()` directly on such a container;
/// * coercing `&Box<dyn Value>` to `&dyn Value`.
///
/// Dereference to the trait object first, e.g. `(*boxed).as_any()` or
/// `(**arc_ref).as_any()`.
pub trait Value: Any + Send + Sync + 'static {
    /// Borrows `self` as `&dyn Any` for shared downcasting.
    fn as_any(&self) -> &dyn Any;
    /// Borrows `self` as `&mut dyn Any` for mutable downcasting.
    fn as_any_mut(&mut self) -> &mut dyn Any;
    /// Converts an owned, boxed value into `Box<dyn Any>` for by-value downcasting.
    fn into_any(self: Box<Self>) -> Box<dyn Any>;
}

impl<T> Value for T
where
    T: Any + Send + Sync + 'static,
{
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self as &mut dyn Any
    }

    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self as Box<dyn Any>
    }
}
41
42pub fn downcast_value<T: Any>(value: Box<dyn Value>) -> Result<Box<T>, Box<dyn Value>> {
47 if value.as_any().is::<T>() {
48 Ok(value
50 .into_any()
51 .downcast::<T>()
52 .expect("downcast can't fail after is::<T>() check"))
53 } else {
54 Err(value)
55 }
56}
57
58pub struct Tagged<T: Value> {
71 inner: T,
72 metadata: std::collections::HashMap<String, String>,
73}
74
75impl<T: Value> Tagged<T> {
76 pub fn new(inner: T) -> Self {
77 Self {
78 inner,
79 metadata: std::collections::HashMap::new(),
80 }
81 }
82
83 pub fn with_meta(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
84 self.metadata.insert(key.into(), value.into());
85 self
86 }
87
88 pub fn inner(&self) -> &T {
89 &self.inner
90 }
91
92 pub fn into_inner(self) -> T {
93 self.inner
94 }
95
96 pub fn meta(&self, key: &str) -> Option<&str> {
97 self.metadata.get(key).map(|s| s.as_str())
98 }
99
100 pub fn metadata(&self) -> &std::collections::HashMap<String, String> {
101 &self.metadata
102 }
103}
104
105pub fn extract_node_set(value: &dyn Value) -> Option<&str> {
117 value
118 .as_any()
119 .downcast_ref::<TaggedMeta>()
120 .and_then(|m| m.node_set.as_deref())
121}
122
123pub struct TaggedMeta {
128 pub value: Arc<dyn Value>,
130 pub node_set: Option<String>,
132}
133
134impl TaggedMeta {
135 pub fn new(value: Arc<dyn Value>) -> Self {
136 Self {
137 value,
138 node_set: None,
139 }
140 }
141
142 pub fn with_node_set(mut self, node_set: impl Into<String>) -> Self {
143 self.node_set = Some(node_set.into());
144 self
145 }
146}
/// Error type produced by every task callable: any boxed error that is
/// `Send + Sync + 'static`.
pub type TaskError = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Boxed `Send + 'static` iterator of boxed [`Value`]s, returned by the
/// generator-style (`SyncIter*`) task variants.
pub type ValueIter = Box<dyn Iterator<Item = Box<dyn Value>> + Send + 'static>;

/// Boxed `'static` stream of boxed [`Value`]s, returned by the
/// async-generator-style (`AsyncStream*`) task variants.
pub type ValueStream = BoxStream<'static, Box<dyn Value>>;

/// Single-input synchronous task: one shared input to one shared output.
pub type SyncFn = Arc<
    dyn Fn(Arc<dyn Value>, Arc<TaskContext>) -> Result<Arc<dyn Value>, TaskError> + Send + Sync,
>;

/// Single-input asynchronous task.
///
/// Returns a boxed `'static` future that resolves to one shared output.
pub type AsyncFn = Arc<
    dyn Fn(
            Arc<dyn Value>,
            Arc<TaskContext>,
        ) -> BoxFuture<'static, Result<Arc<dyn Value>, TaskError>>
        + Send
        + Sync,
>;

/// Single-input generator task.
///
/// Returns an iterator of outputs, or an error raised before any item is
/// produced.
pub type SyncIterFn =
    Arc<dyn Fn(Arc<dyn Value>, Arc<TaskContext>) -> Result<ValueIter, TaskError> + Send + Sync>;

/// Single-input async-generator task.
///
/// Returns a stream of outputs, or an error raised before the stream is
/// created.
pub type AsyncStreamFn =
    Arc<dyn Fn(Arc<dyn Value>, Arc<TaskContext>) -> Result<ValueStream, TaskError> + Send + Sync>;

/// Batch synchronous task.
///
/// Borrows a slice of boxed inputs for any lifetime and produces one shared
/// output.
pub type SyncBatchFn = Arc<
    dyn for<'a> Fn(&'a [Box<dyn Value>], Arc<TaskContext>) -> Result<Arc<dyn Value>, TaskError>
        + Send
        + Sync,
>;

/// Batch asynchronous task.
///
/// The returned future is `'static`, so it cannot keep borrowing the input
/// slice.
pub type AsyncBatchFn = Arc<
    dyn for<'a> Fn(
            &'a [Box<dyn Value>],
            Arc<TaskContext>,
        ) -> BoxFuture<'static, Result<Arc<dyn Value>, TaskError>>
        + Send
        + Sync,
>;

/// Batch generator task: borrows a slice of inputs and returns an iterator of outputs.
pub type SyncIterBatchFn = Arc<
    dyn for<'a> Fn(&'a [Box<dyn Value>], Arc<TaskContext>) -> Result<ValueIter, TaskError>
        + Send
        + Sync,
>;

/// Batch async-generator task: borrows a slice of inputs and returns a stream of outputs.
pub type AsyncStreamBatchFn = Arc<
    dyn for<'a> Fn(&'a [Box<dyn Value>], Arc<TaskContext>) -> Result<ValueStream, TaskError>
        + Send
        + Sync,
>;
226pub enum Task {
244 Sync(SyncFn),
245 Async(AsyncFn),
246 SyncIter(SyncIterFn),
247 AsyncStream(AsyncStreamFn),
248 SyncBatch(SyncBatchFn),
249 AsyncBatch(AsyncBatchFn),
250 SyncIterBatch(SyncIterBatchFn),
251 AsyncStreamBatch(AsyncStreamBatchFn),
252}
253
254impl Task {
255 pub fn is_batch(&self) -> bool {
257 matches!(
258 self,
259 Task::SyncBatch(_)
260 | Task::AsyncBatch(_)
261 | Task::SyncIterBatch(_)
262 | Task::AsyncStreamBatch(_)
263 )
264 }
265
266 pub fn python_task_type(&self) -> &'static str {
284 match self {
285 Task::Sync(_) | Task::SyncBatch(_) => "Function",
286 Task::Async(_) | Task::AsyncBatch(_) => "Coroutine",
287 Task::SyncIter(_) | Task::SyncIterBatch(_) => "Generator",
288 Task::AsyncStream(_) | Task::AsyncStreamBatch(_) => "Async Generator",
289 }
290 }
291}
292
293impl Task {
294 pub fn sync<F>(f: F) -> Self
298 where
299 F: Fn(Arc<dyn Value>, Arc<TaskContext>) -> Result<Arc<dyn Value>, TaskError>
300 + Send
301 + Sync
302 + 'static,
303 {
304 Task::Sync(Arc::new(f))
305 }
306
307 pub fn async_fn<F>(f: F) -> Self
309 where
310 F: Fn(
311 Arc<dyn Value>,
312 Arc<TaskContext>,
313 ) -> BoxFuture<'static, Result<Arc<dyn Value>, TaskError>>
314 + Send
315 + Sync
316 + 'static,
317 {
318 Task::Async(Arc::new(f))
319 }
320
321 pub fn sync_iter<F>(f: F) -> Self
323 where
324 F: Fn(Arc<dyn Value>, Arc<TaskContext>) -> Result<ValueIter, TaskError>
325 + Send
326 + Sync
327 + 'static,
328 {
329 Task::SyncIter(Arc::new(f))
330 }
331
332 pub fn async_stream<F>(f: F) -> Self
335 where
336 F: Fn(Arc<dyn Value>, Arc<TaskContext>) -> Result<ValueStream, TaskError>
337 + Send
338 + Sync
339 + 'static,
340 {
341 Task::AsyncStream(Arc::new(f))
342 }
343
344 pub fn sync_batch<F>(f: F) -> Self
348 where
349 F: for<'a> Fn(&'a [Box<dyn Value>], Arc<TaskContext>) -> Result<Arc<dyn Value>, TaskError>
350 + Send
351 + Sync
352 + 'static,
353 {
354 Task::SyncBatch(Arc::new(f))
355 }
356
357 pub fn async_batch<F>(f: F) -> Self
359 where
360 F: for<'a> Fn(
361 &'a [Box<dyn Value>],
362 Arc<TaskContext>,
363 ) -> BoxFuture<'static, Result<Arc<dyn Value>, TaskError>>
364 + Send
365 + Sync
366 + 'static,
367 {
368 Task::AsyncBatch(Arc::new(f))
369 }
370
371 pub fn sync_iter_batch<F>(f: F) -> Self
373 where
374 F: for<'a> Fn(&'a [Box<dyn Value>], Arc<TaskContext>) -> Result<ValueIter, TaskError>
375 + Send
376 + Sync
377 + 'static,
378 {
379 Task::SyncIterBatch(Arc::new(f))
380 }
381
382 pub fn async_stream_batch<F>(f: F) -> Self
384 where
385 F: for<'a> Fn(&'a [Box<dyn Value>], Arc<TaskContext>) -> Result<ValueStream, TaskError>
386 + Send
387 + Sync
388 + 'static,
389 {
390 Task::AsyncStreamBatch(Arc::new(f))
391 }
392
393 pub fn sync_typed<I, O, F>(f: F) -> Self
418 where
419 I: Value,
420 O: Value,
421 F: Fn(&I, Arc<TaskContext>) -> Result<Box<O>, TaskError> + Send + Sync + 'static,
422 {
423 Task::Sync(Arc::new(move |input: Arc<dyn Value>, ctx| {
424 let typed = Self::borrow_input::<I>(&input);
425 f(typed, ctx).map(|v| Arc::new(*v) as Arc<dyn Value>)
426 }))
427 }
428
429 pub fn async_fn_typed<I, O, F>(f: F) -> Self
443 where
444 I: Value,
445 O: Value,
446 F: Fn(&I, Arc<TaskContext>) -> BoxFuture<'static, Result<Box<O>, TaskError>>
447 + Send
448 + Sync
449 + 'static,
450 {
451 Task::Async(Arc::new(move |input: Arc<dyn Value>, ctx| {
452 let typed = Self::borrow_input::<I>(&input);
453 let fut = f(typed, ctx);
456 Box::pin(async move { fut.await.map(|v| Arc::new(*v) as Arc<dyn Value>) })
457 }))
458 }
459
460 pub fn sync_iter_typed<I, O, F, Iter>(f: F) -> Self
470 where
471 I: Value,
472 O: Value,
473 F: Fn(&I, Arc<TaskContext>) -> Result<Iter, TaskError> + Send + Sync + 'static,
474 Iter: Iterator<Item = Box<O>> + Send + 'static,
475 {
476 Task::SyncIter(Arc::new(move |input: Arc<dyn Value>, ctx| {
477 let typed = Self::borrow_input::<I>(&input);
478 f(typed, ctx).map(|iter| Box::new(iter.map(|v| v as Box<dyn Value>)) as ValueIter)
479 }))
480 }
481
482 pub fn async_stream_typed<I, O, F, S>(f: F) -> Self
492 where
493 I: Value,
494 O: Value,
495 F: Fn(&I, Arc<TaskContext>) -> Result<S, TaskError> + Send + Sync + 'static,
496 S: Stream<Item = Box<O>> + Send + 'static,
497 {
498 Task::AsyncStream(Arc::new(move |input: Arc<dyn Value>, ctx| {
499 let typed = Self::borrow_input::<I>(&input);
500 f(typed, ctx).map(|s| Box::pin(s.map(|v| v as Box<dyn Value>)) as ValueStream)
501 }))
502 }
503
504 pub fn sync_batch_typed<I, O, F>(f: F) -> Self
519 where
520 I: Value,
521 O: Value,
522 F: for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> Result<Box<O>, TaskError>
523 + Send
524 + Sync
525 + 'static,
526 {
527 Task::SyncBatch(Arc::new(move |items: &[Box<dyn Value>], ctx| {
528 let typed: Vec<&I> = items.iter().map(|v| Self::borrow_item::<I>(v)).collect();
529 f(&typed, ctx).map(|v| Arc::new(*v) as Arc<dyn Value>)
530 }))
531 }
532
533 pub fn async_batch_typed<I, O, F>(f: F) -> Self
546 where
547 I: Value,
548 O: Value,
549 F: for<'a> Fn(
550 &'a [&'a I],
551 Arc<TaskContext>,
552 ) -> BoxFuture<'static, Result<Box<O>, TaskError>>
553 + Send
554 + Sync
555 + 'static,
556 {
557 Task::AsyncBatch(Arc::new(move |items: &[Box<dyn Value>], ctx| {
558 let typed: Vec<&I> = items.iter().map(|v| Self::borrow_item::<I>(v)).collect();
559 let fut = f(&typed, ctx);
560 Box::pin(async move { fut.await.map(|v| Arc::new(*v) as Arc<dyn Value>) })
561 }))
562 }
563
564 pub fn sync_iter_batch_typed<I, O, F, Iter>(f: F) -> Self
566 where
567 I: Value,
568 O: Value,
569 F: for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> Result<Iter, TaskError>
570 + Send
571 + Sync
572 + 'static,
573 Iter: Iterator<Item = Box<O>> + Send + 'static,
574 {
575 Task::SyncIterBatch(Arc::new(move |items: &[Box<dyn Value>], ctx| {
576 let typed: Vec<&I> = items.iter().map(|v| Self::borrow_item::<I>(v)).collect();
577 f(&typed, ctx).map(|iter| Box::new(iter.map(|v| v as Box<dyn Value>)) as ValueIter)
578 }))
579 }
580
581 pub fn async_stream_batch_typed<I, O, F, S>(f: F) -> Self
583 where
584 I: Value,
585 O: Value,
586 F: for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> Result<S, TaskError>
587 + Send
588 + Sync
589 + 'static,
590 S: Stream<Item = Box<O>> + Send + 'static,
591 {
592 Task::AsyncStreamBatch(Arc::new(move |items: &[Box<dyn Value>], ctx| {
593 let typed: Vec<&I> = items.iter().map(|v| Self::borrow_item::<I>(v)).collect();
594 f(&typed, ctx).map(|s| Box::pin(s.map(|v| v as Box<dyn Value>)) as ValueStream)
595 }))
596 }
597
598 fn borrow_input<I: Value>(input: &Arc<dyn Value>) -> &I {
605 let type_name = std::any::type_name::<I>();
606 (**input)
611 .as_any()
612 .downcast_ref::<I>()
613 .unwrap_or_else(|| panic!("Task input type mismatch: expected {type_name}"))
614 }
615
616 fn borrow_item<I: Value>(item: &dyn Value) -> &I {
620 let type_name = std::any::type_name::<I>();
621 item.as_any()
622 .downcast_ref::<I>()
623 .unwrap_or_else(|| panic!("Batch item type mismatch: expected {type_name}"))
624 }
625
626 pub fn call(&self, input: Arc<dyn Value>, ctx: Arc<TaskContext>) -> TaskCall {
632 match self {
633 Task::Sync(f) => TaskCall::Sync(f(input, ctx)),
634 Task::Async(f) => TaskCall::Async(f(input, ctx)),
635 Task::SyncIter(f) => TaskCall::SyncIter(f(input, ctx)),
636 Task::AsyncStream(f) => TaskCall::AsyncStream(f(input, ctx)),
637 Task::SyncBatch(_)
638 | Task::AsyncBatch(_)
639 | Task::SyncIterBatch(_)
640 | Task::AsyncStreamBatch(_) => {
641 panic!("call() used on a batch task variant — use call_batch() instead")
642 }
643 }
644 }
645
646 pub fn parallel(tasks: Vec<Task>) -> Self {
658 let tasks = Arc::new(tasks);
659 Task::Async(Arc::new(move |input, ctx| {
660 let tasks = Arc::clone(&tasks);
661 Box::pin(async move {
662 if tasks.is_empty() {
663 return Ok(input);
664 }
665
666 let futs: Vec<_> = tasks
667 .iter()
668 .map(|t| {
669 let call = t.call(Arc::clone(&input), Arc::clone(&ctx));
670 async move {
671 match call {
672 TaskCall::Sync(result) => result,
673 TaskCall::Async(fut) => fut.await,
674 TaskCall::SyncIter(_) | TaskCall::AsyncStream(_) => {
675 Err("iter/stream tasks are not supported inside Task::parallel"
676 .into())
677 }
678 }
679 }
680 })
681 .collect();
682
683 let results = futures::future::join_all(futs).await;
684
685 let mut last_ok: Option<Arc<dyn Value>> = None;
688 for r in results {
689 match r {
690 Err(e) => return Err(e),
691 Ok(v) => last_ok = Some(v),
692 }
693 }
694
695 Ok(last_ok.expect("non-empty tasks guaranteed above"))
696 })
697 }))
698 }
699
700 pub fn call_batch(&self, items: &[Box<dyn Value>], ctx: Arc<TaskContext>) -> TaskCall {
704 match self {
705 Task::SyncBatch(f) => TaskCall::Sync(f(items, ctx)),
706 Task::AsyncBatch(f) => TaskCall::Async(f(items, ctx)),
707 Task::SyncIterBatch(f) => TaskCall::SyncIter(f(items, ctx)),
708 Task::AsyncStreamBatch(f) => TaskCall::AsyncStream(f(items, ctx)),
709 Task::Sync(_) | Task::Async(_) | Task::SyncIter(_) | Task::AsyncStream(_) => {
710 panic!("call_batch() used on a single-value task variant — use call() instead")
711 }
712 }
713 }
714}
715pub struct TaskInfo {
726 pub task: Task,
727 pub name: Option<String>,
729 pub batch_size: Option<usize>,
732 pub summary_template: Option<String>,
738 pub weight: u32,
742 pub enriches: bool,
750 pub rate_limiter: Option<Arc<dyn RateLimiter>>,
754}
755
756impl TaskInfo {
757 pub fn new(task: Task) -> Self {
758 Self {
759 task,
760 name: None,
761 batch_size: None,
762 summary_template: None,
763 weight: 1,
764 enriches: false,
765 rate_limiter: None,
766 }
767 }
768
769 pub fn with_name(mut self, name: impl Into<String>) -> Self {
770 self.name = Some(name.into());
771 self
772 }
773
774 pub fn with_batch_size(mut self, size: usize) -> Self {
775 assert!(size > 0, "batch_size must be > 0");
776 self.batch_size = Some(size);
777 self
778 }
779
780 pub fn with_summary(mut self, template: impl Into<String>) -> Self {
785 self.summary_template = Some(template.into());
786 self
787 }
788
789 pub fn with_weight(mut self, weight: u32) -> Self {
790 self.weight = weight;
791 self
792 }
793
794 pub fn with_enriches(mut self) -> Self {
804 self.enriches = true;
805 self
806 }
807
808 pub fn with_rate_limiter(mut self, rl: Arc<dyn RateLimiter>) -> Self {
814 self.rate_limiter = Some(rl);
815 self
816 }
817
818 pub fn parallel(infos: Vec<TaskInfo>) -> Self {
823 let names: Vec<String> = infos
824 .iter()
825 .enumerate()
826 .map(|(i, ti)| ti.name.clone().unwrap_or_else(|| format!("task_{i}")))
827 .collect();
828
829 let tasks: Vec<Task> = infos.into_iter().map(|ti| ti.task).collect();
830
831 TaskInfo {
832 task: Task::parallel(tasks),
833 name: Some(format!("parallel([{}])", names.join(", "))),
834 batch_size: None,
835 summary_template: None,
836 weight: 1,
837 enriches: false,
838 rate_limiter: None,
839 }
840 }
841}
842
843impl From<Task> for TaskInfo {
844 fn from(task: Task) -> Self {
845 TaskInfo::new(task)
846 }
847}
848
/// Typed counterpart of `SyncFn`: borrows an `&I`, returns a boxed `O`.
type TypedSyncFn<I, O> = dyn Fn(&I, Arc<TaskContext>) -> Result<Box<O>, TaskError> + Send + Sync;
/// Typed counterpart of `AsyncFn`: borrows an `&I`, returns a `'static` future of a boxed `O`.
type TypedAsyncFn<I, O> =
    dyn Fn(&I, Arc<TaskContext>) -> BoxFuture<'static, Result<Box<O>, TaskError>> + Send + Sync;
/// Typed counterpart of `SyncIterFn`: borrows an `&I`, returns a boxed iterator of boxed `O`s.
type TypedSyncIterFn<I, O> = dyn Fn(&I, Arc<TaskContext>) -> Result<Box<dyn Iterator<Item = Box<O>> + Send + 'static>, TaskError>
    + Send
    + Sync;
/// Typed counterpart of `AsyncStreamFn`: borrows an `&I`, returns a boxed stream of boxed `O`s.
type TypedAsyncStreamFn<I, O> =
    dyn Fn(&I, Arc<TaskContext>) -> Result<BoxStream<'static, Box<O>>, TaskError> + Send + Sync;
/// Typed counterpart of `SyncBatchFn`: borrows a slice of `&I`, returns a boxed `O`.
type TypedSyncBatchFn<I, O> =
    dyn for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> Result<Box<O>, TaskError> + Send + Sync;
/// Typed counterpart of `AsyncBatchFn`: borrows a slice of `&I`, returns a `'static` future.
type TypedAsyncBatchFn<I, O> = dyn for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> BoxFuture<'static, Result<Box<O>, TaskError>>
    + Send
    + Sync;
/// Typed counterpart of `SyncIterBatchFn`: borrows a slice of `&I`, returns a boxed iterator.
type TypedSyncIterBatchFn<I, O> = dyn for<'a> Fn(
        &'a [&'a I],
        Arc<TaskContext>,
    ) -> Result<Box<dyn Iterator<Item = Box<O>> + Send + 'static>, TaskError>
    + Send
    + Sync;
/// Typed counterpart of `AsyncStreamBatchFn`: borrows a slice of `&I`, returns a boxed stream.
type TypedAsyncStreamBatchFn<I, O> = dyn for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> Result<BoxStream<'static, Box<O>>, TaskError>
    + Send
    + Sync;

/// Statically typed mirror of [`Task`], parameterised by input type `I` and
/// output type `O`.
///
/// Converting into [`Task`] (via `From`) erases `I`/`O`. The erased task
/// downcasts its inputs back to `I` at call time.
pub enum TypedTask<I: Value, O: Value> {
    /// One `&I` in, one `O` out, synchronously.
    Sync(Arc<TypedSyncFn<I, O>>),
    /// One `&I` in, a future of one `O`.
    Async(Arc<TypedAsyncFn<I, O>>),
    /// One `&I` in, an iterator of `O`s.
    SyncIter(Arc<TypedSyncIterFn<I, O>>),
    /// One `&I` in, a stream of `O`s.
    AsyncStream(Arc<TypedAsyncStreamFn<I, O>>),
    /// A slice of `&I` in, one `O` out, synchronously.
    SyncBatch(Arc<TypedSyncBatchFn<I, O>>),
    /// A slice of `&I` in, a future of one `O`.
    AsyncBatch(Arc<TypedAsyncBatchFn<I, O>>),
    /// A slice of `&I` in, an iterator of `O`s.
    SyncIterBatch(Arc<TypedSyncIterBatchFn<I, O>>),
    /// A slice of `&I` in, a stream of `O`s.
    AsyncStreamBatch(Arc<TypedAsyncStreamBatchFn<I, O>>),
}
921
922impl<I: Value, O: Value> TypedTask<I, O> {
923 pub fn sync<F>(f: F) -> Self
925 where
926 F: Fn(&I, Arc<TaskContext>) -> Result<Box<O>, TaskError> + Send + Sync + 'static,
927 {
928 TypedTask::Sync(Arc::new(f))
929 }
930
931 pub fn async_fn<F>(f: F) -> Self
936 where
937 F: Fn(&I, Arc<TaskContext>) -> BoxFuture<'static, Result<Box<O>, TaskError>>
938 + Send
939 + Sync
940 + 'static,
941 {
942 TypedTask::Async(Arc::new(f))
943 }
944
945 pub fn sync_iter<F, Iter>(f: F) -> Self
949 where
950 F: Fn(&I, Arc<TaskContext>) -> Result<Iter, TaskError> + Send + Sync + 'static,
951 Iter: Iterator<Item = Box<O>> + Send + 'static,
952 {
953 TypedTask::SyncIter(Arc::new(move |i, ctx| {
954 f(i, ctx)
955 .map(|iter| Box::new(iter) as Box<dyn Iterator<Item = Box<O>> + Send + 'static>)
956 }))
957 }
958
959 pub fn async_stream<F, S>(f: F) -> Self
963 where
964 F: Fn(&I, Arc<TaskContext>) -> Result<S, TaskError> + Send + Sync + 'static,
965 S: Stream<Item = Box<O>> + Send + 'static,
966 {
967 TypedTask::AsyncStream(Arc::new(move |i, ctx| {
968 f(i, ctx).map(|s| Box::pin(s) as BoxStream<'static, Box<O>>)
969 }))
970 }
971
972 pub fn sync_batch<F>(f: F) -> Self
974 where
975 F: for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> Result<Box<O>, TaskError>
976 + Send
977 + Sync
978 + 'static,
979 {
980 TypedTask::SyncBatch(Arc::new(f))
981 }
982
983 pub fn async_batch<F>(f: F) -> Self
985 where
986 F: for<'a> Fn(
987 &'a [&'a I],
988 Arc<TaskContext>,
989 ) -> BoxFuture<'static, Result<Box<O>, TaskError>>
990 + Send
991 + Sync
992 + 'static,
993 {
994 TypedTask::AsyncBatch(Arc::new(f))
995 }
996
997 pub fn sync_iter_batch<F, Iter>(f: F) -> Self
999 where
1000 F: for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> Result<Iter, TaskError>
1001 + Send
1002 + Sync
1003 + 'static,
1004 Iter: Iterator<Item = Box<O>> + Send + 'static,
1005 {
1006 TypedTask::SyncIterBatch(Arc::new(move |items, ctx| {
1007 f(items, ctx)
1008 .map(|iter| Box::new(iter) as Box<dyn Iterator<Item = Box<O>> + Send + 'static>)
1009 }))
1010 }
1011
1012 pub fn async_stream_batch<F, S>(f: F) -> Self
1014 where
1015 F: for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> Result<S, TaskError>
1016 + Send
1017 + Sync
1018 + 'static,
1019 S: Stream<Item = Box<O>> + Send + 'static,
1020 {
1021 TypedTask::AsyncStreamBatch(Arc::new(move |items, ctx| {
1022 f(items, ctx).map(|s| Box::pin(s) as BoxStream<'static, Box<O>>)
1023 }))
1024 }
1025}
1026
1027impl<I: Value, O: Value> From<TypedTask<I, O>> for Task {
1028 fn from(typed: TypedTask<I, O>) -> Self {
1033 match typed {
1034 TypedTask::Sync(f) => Task::sync_typed(move |i: &I, ctx| f(i, ctx)),
1035 TypedTask::Async(f) => Task::async_fn_typed(move |i: &I, ctx| f(i, ctx)),
1036 TypedTask::SyncIter(f) => Task::sync_iter_typed(move |i: &I, ctx| f(i, ctx)),
1037 TypedTask::AsyncStream(f) => Task::async_stream_typed(move |i: &I, ctx| f(i, ctx)),
1038 TypedTask::SyncBatch(f) => {
1039 Task::sync_batch_typed(move |items: &[&I], ctx| f(items, ctx))
1040 }
1041 TypedTask::AsyncBatch(f) => {
1042 Task::async_batch_typed(move |items: &[&I], ctx| f(items, ctx))
1043 }
1044 TypedTask::SyncIterBatch(f) => {
1045 Task::sync_iter_batch_typed(move |items: &[&I], ctx| f(items, ctx))
1046 }
1047 TypedTask::AsyncStreamBatch(f) => {
1048 Task::async_stream_batch_typed(move |items: &[&I], ctx| f(items, ctx))
1049 }
1050 }
1051 }
1052}
1053
1054impl<I: Value, O: Value> From<TypedTask<I, O>> for TaskInfo {
1055 fn from(t: TypedTask<I, O>) -> TaskInfo {
1056 TaskInfo::new(Task::from(t))
1057 }
1058}
1059
/// Immediate outcome of `Task::call` / `Task::call_batch`, tagged by the
/// task's execution shape.
pub enum TaskCall {
    /// Result of a synchronous task; its closure has already run.
    Sync(Result<Arc<dyn Value>, TaskError>),

    /// Future returned by an asynchronous task; not yet awaited.
    Async(BoxFuture<'static, Result<Arc<dyn Value>, TaskError>>),

    /// Iterator (or setup error) returned by a generator-style task.
    SyncIter(Result<ValueIter, TaskError>),

    /// Stream (or setup error) returned by an async-generator-style task.
    AsyncStream(Result<ValueStream, TaskError>),
}
1074
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;

    use crate::cancellation::cancellation_pair;
    use crate::exec_status::NoopExecStatusManager;
    use crate::progress::ProgressToken;
    use crate::task_context::TaskContext;
    use crate::thread_pool::CpuPool;

    /// Thread-pool stub.
    ///
    /// `spawn_raw` drops the submitted closure without running it and resolves
    /// to `Ok(())` immediately. Only suitable for tests that never rely on work
    /// actually executing on the pool.
    struct StubPool;

    impl CpuPool for StubPool {
        fn spawn_raw(
            &self,
            _task: Box<dyn FnOnce() + Send + 'static>,
        ) -> Pin<Box<dyn Future<Output = Result<(), crate::error::CoreError>> + Send + 'static>>
        {
            Box::pin(async { Ok(()) })
        }
    }

    /// Builds a `TaskContext` for tests.
    ///
    /// It is backed by an initialized in-memory SQLite database, mock
    /// graph/vector stores and [`StubPool`].
    ///
    /// The cancellation handle goes out of scope when this function returns.
    /// NOTE(review): confirm that dropping the handle leaves the token
    /// un-cancelled.
    async fn stub_ctx() -> Arc<TaskContext> {
        let db = cognee_database::connect("sqlite::memory:").await.unwrap();
        cognee_database::initialize(&db).await.unwrap();
        let (_handle, token) = cancellation_pair();
        Arc::new(TaskContext {
            thread_pool: Arc::new(StubPool),
            database: Arc::new(db),
            graph_db: Arc::new(cognee_graph::MockGraphDB::new()),
            vector_db: Arc::new(cognee_vector::MockVectorDB::new()),
            cancellation: token,
            progress: ProgressToken::new(),
            pipeline_ctx: None,
            exec_status: Arc::new(NoopExecStatusManager),
            pipeline_watcher: None,
        })
    }

    /// `Task::parallel` returns the output of the last task (5 * 3 = 15).
    ///
    /// NOTE(review): despite the name, sync tasks are executed one after
    /// another while `Task::parallel` builds its futures. This test only checks
    /// the combined result.
    #[tokio::test]
    async fn parallel_runs_sync_tasks_concurrently() {
        let double = Task::sync_typed(|x: &i32, _ctx| Ok(Box::new(*x * 2)));
        let triple = Task::sync_typed(|x: &i32, _ctx| Ok(Box::new(*x * 3)));

        let par = Task::parallel(vec![double, triple]);
        let input: Arc<dyn Value> = Arc::new(5_i32);
        let ctx = stub_ctx().await;

        let call = par.call(input, ctx);
        let result = match call {
            TaskCall::Async(fut) => fut.await.unwrap(),
            _ => panic!("parallel should produce Async variant"),
        };

        assert_eq!(*(*result).as_any().downcast_ref::<i32>().unwrap(), 15);
    }

    /// Async tasks are awaited; the last task's output (100 + 20) wins.
    #[tokio::test]
    async fn parallel_runs_async_tasks() {
        let add_ten = Task::async_fn_typed(|x: &i32, _ctx| {
            let v = *x + 10;
            Box::pin(async move { Ok(Box::new(v)) })
        });
        let add_twenty = Task::async_fn_typed(|x: &i32, _ctx| {
            let v = *x + 20;
            Box::pin(async move { Ok(Box::new(v)) })
        });

        let par = Task::parallel(vec![add_ten, add_twenty]);
        let input: Arc<dyn Value> = Arc::new(100_i32);
        let ctx = stub_ctx().await;

        let result = match par.call(input, ctx) {
            TaskCall::Async(fut) => fut.await.unwrap(),
            _ => panic!("expected Async"),
        };

        assert_eq!(*(*result).as_any().downcast_ref::<i32>().unwrap(), 120);
    }

    /// An error from any task is surfaced as the combined result.
    #[tokio::test]
    async fn parallel_propagates_first_error() {
        let ok_task = Task::sync_typed(|x: &i32, _ctx| Ok(Box::new(*x)));
        let err_task = Task::Sync(Arc::new(|_input, _ctx| Err("boom".into())));

        let par = Task::parallel(vec![ok_task, err_task]);
        let input: Arc<dyn Value> = Arc::new(42_i32);
        let ctx = stub_ctx().await;

        let result = match par.call(input, ctx) {
            TaskCall::Async(fut) => fut.await,
            _ => panic!("expected Async"),
        };

        let err = result.err().expect("should be an error");
        assert!(err.to_string().contains("boom"));
    }

    /// With no tasks, the combined task echoes its input.
    #[tokio::test]
    async fn parallel_empty_returns_input() {
        let par = Task::parallel(vec![]);
        let input: Arc<dyn Value> = Arc::new(99_i32);
        let ctx = stub_ctx().await;

        let result = match par.call(Arc::clone(&input), ctx) {
            TaskCall::Async(fut) => fut.await.unwrap(),
            _ => panic!("expected Async"),
        };

        assert_eq!(*(*result).as_any().downcast_ref::<i32>().unwrap(), 99);
    }

    /// A typed task invoked with the wrong input type panics with a
    /// "type mismatch" message.
    #[tokio::test]
    async fn test_typed_task_panics_on_type_mismatch() {
        use std::panic::{AssertUnwindSafe, catch_unwind};

        let task = Task::sync_typed(|_x: &String, _ctx| Ok(Box::new("ok".to_string())));
        let input: Arc<dyn Value> = Arc::new(42_i32);
        let ctx = stub_ctx().await;

        let result = catch_unwind(AssertUnwindSafe(|| task.call(input, ctx)));

        let err = match result {
            Err(e) => e,
            Ok(_) => panic!("should have panicked on type mismatch"),
        };
        // Formatted panic messages carry a `String` payload; literal ones a `&str`.
        let msg = err
            .downcast_ref::<String>()
            .map(|s| s.as_str())
            .or_else(|| err.downcast_ref::<&str>().copied())
            .expect("panic payload should be a string");
        assert!(
            msg.contains("type mismatch"),
            "expected 'type mismatch' in panic message, got: {msg}"
        );
    }

    /// `TaskInfo::new` defaults the weight to 1.
    #[test]
    fn test_taskinfo_weight_default() {
        let info = TaskInfo::new(Task::sync_typed(|_: &i32, _| Ok(Box::new(0_i32))));
        assert_eq!(info.weight, 1);
    }

    /// `with_weight` overrides the default weight.
    #[test]
    fn test_taskinfo_with_weight() {
        let info = TaskInfo::new(Task::sync_typed(|_: &i32, _| Ok(Box::new(0_i32)))).with_weight(5);
        assert_eq!(info.weight, 5);
    }

    /// Unnamed inputs are labelled `task_<index>` in the combined name.
    #[test]
    fn task_info_parallel_generates_name() {
        let t1 =
            TaskInfo::new(Task::sync_typed(|_: &i32, _| Ok(Box::new(0_i32)))).with_name("classify");
        let t2 =
            TaskInfo::new(Task::sync_typed(|_: &i32, _| Ok(Box::new(0_i32)))).with_name("embed");
        let t3 = TaskInfo::new(Task::sync_typed(|_: &i32, _| Ok(Box::new(0_i32))));

        let par = TaskInfo::parallel(vec![t1, t2, t3]);
        assert_eq!(
            par.name.as_deref(),
            Some("parallel([classify, embed, task_2])")
        );
    }

    /// Checks the Python task-type label reported for each of the eight
    /// `Task` variants.
    mod python_task_type {
        use super::*;
        use futures::stream;

        #[test]
        fn sync_variant_maps_to_function() {
            let t = Task::sync_typed(|_: &i32, _| Ok(Box::new(0_i32)));
            assert_eq!(t.python_task_type(), "Function");
        }

        #[test]
        fn sync_batch_variant_maps_to_function() {
            let t = Task::sync_batch_typed(|_: &[&i32], _| Ok(Box::new(0_i32)));
            assert_eq!(t.python_task_type(), "Function");
        }

        #[test]
        fn async_variant_maps_to_coroutine() {
            let t = Task::async_fn_typed(|_: &i32, _| Box::pin(async move { Ok(Box::new(0_i32)) }));
            assert_eq!(t.python_task_type(), "Coroutine");
        }

        #[test]
        fn async_batch_variant_maps_to_coroutine() {
            let t = Task::async_batch_typed(|_: &[&i32], _| {
                Box::pin(async move { Ok(Box::new(0_i32)) })
            });
            assert_eq!(t.python_task_type(), "Coroutine");
        }

        #[test]
        fn sync_iter_variant_maps_to_generator() {
            let t = Task::sync_iter_typed(|_: &i32, _| Ok(std::iter::empty::<Box<i32>>()));
            assert_eq!(t.python_task_type(), "Generator");
        }

        #[test]
        fn sync_iter_batch_variant_maps_to_generator() {
            let t = Task::sync_iter_batch_typed(|_: &[&i32], _| Ok(std::iter::empty::<Box<i32>>()));
            assert_eq!(t.python_task_type(), "Generator");
        }

        #[test]
        fn async_stream_variant_maps_to_async_generator() {
            let t = Task::async_stream_typed(|_: &i32, _| Ok(stream::empty::<Box<i32>>()));
            assert_eq!(t.python_task_type(), "Async Generator");
        }

        #[test]
        fn async_stream_batch_variant_maps_to_async_generator() {
            let t = Task::async_stream_batch_typed(|_: &[&i32], _| Ok(stream::empty::<Box<i32>>()));
            assert_eq!(t.python_task_type(), "Async Generator");
        }

        /// All eight variants together produce exactly the four expected labels.
        #[test]
        fn covers_all_eight_variants_with_four_distinct_labels() {
            let labels: std::collections::HashSet<&'static str> = [
                Task::sync_typed(|_: &i32, _| Ok(Box::new(0_i32))).python_task_type(),
                Task::sync_batch_typed(|_: &[&i32], _| Ok(Box::new(0_i32))).python_task_type(),
                Task::async_fn_typed(|_: &i32, _| Box::pin(async move { Ok(Box::new(0_i32)) }))
                    .python_task_type(),
                Task::async_batch_typed(|_: &[&i32], _| {
                    Box::pin(async move { Ok(Box::new(0_i32)) })
                })
                .python_task_type(),
                Task::sync_iter_typed(|_: &i32, _| Ok(std::iter::empty::<Box<i32>>()))
                    .python_task_type(),
                Task::sync_iter_batch_typed(|_: &[&i32], _| Ok(std::iter::empty::<Box<i32>>()))
                    .python_task_type(),
                Task::async_stream_typed(|_: &i32, _| Ok(stream::empty::<Box<i32>>()))
                    .python_task_type(),
                Task::async_stream_batch_typed(|_: &[&i32], _| Ok(stream::empty::<Box<i32>>()))
                    .python_task_type(),
            ]
            .into_iter()
            .collect();

            assert_eq!(
                labels.len(),
                4,
                "expected exactly 4 distinct Python task-type labels, got {labels:?}"
            );
            assert!(labels.contains("Function"));
            assert!(labels.contains("Coroutine"));
            assert!(labels.contains("Generator"));
            assert!(labels.contains("Async Generator"));
        }
    }
}