1#![allow(
5 clippy::unwrap_used,
6 clippy::expect_used,
7 reason = "lock poisoning is unrecoverable; expect() calls guard construction-time invariants"
8)]
9
10use std::marker::PhantomData;
11use std::mem;
12use std::sync::Arc;
13use std::time::Duration;
14
15use async_trait::async_trait;
16use futures::StreamExt;
17use futures::future::BoxFuture;
18use thiserror::Error;
19use tokio::time::sleep;
20use uuid::Uuid;
21
22use crate::progress::ProgressToken;
23use crate::rate_limiter::RateLimiter;
24use crate::task::{
25 TaggedMeta, Task, TaskCall, TaskError, TaskInfo, TypedTask, Value, ValueIter, ValueStream,
26};
27use crate::task_context::TaskContext;
28
/// Policy controlling how many times a failing task call is attempted.
#[derive(Debug, Clone)]
pub enum RetryPolicy {
    /// Attempt the task exactly once; the first failure is final.
    NoRetry,
    /// Attempt the task up to `max_attempts` times in total, waiting between attempts.
    Limited {
        /// Total number of attempts, the first one included. Cannot be zero.
        max_attempts: std::num::NonZeroU32,
        /// Schedule used to compute the wait before each retry.
        delay: RetryDelay,
    },
}

/// Wait-time schedule applied between retry attempts.
#[derive(Debug, Clone)]
pub enum RetryDelay {
    /// Wait the same duration before every retry.
    Constant(Duration),
    /// Wait `base * factor^retry_index`, where `retry_index` is zero-based.
    Exponential { base: Duration, factor: u32 },
}

impl RetryDelay {
    /// Builds an exponential schedule starting at `base` and doubling on every retry.
    pub fn exponential(base: Duration) -> Self {
        RetryDelay::Exponential { base, factor: 2 }
    }
}

impl RetryPolicy {
    /// Total number of attempts allowed by this policy (`1` for [`RetryPolicy::NoRetry`]).
    fn max_attempts(&self) -> u32 {
        match self {
            RetryPolicy::NoRetry => 1,
            RetryPolicy::Limited { max_attempts, .. } => max_attempts.get(),
        }
    }

    /// Wait time before the retry with zero-based index `retry_index`.
    ///
    /// Returns `None` when the policy never retries.
    fn delay(&self, retry_index: u32) -> Option<Duration> {
        match self {
            RetryPolicy::NoRetry => None,
            RetryPolicy::Limited { delay, .. } => Some(delay.compute(retry_index)),
        }
    }
}

impl RetryDelay {
    /// Computes the wait time for the retry with zero-based index `retry_index`.
    ///
    /// Both the power and the final multiplication saturate, so an oversized
    /// `base`, `factor` or `retry_index` yields a very long delay (at most
    /// `Duration::MAX`) instead of panicking.
    fn compute(&self, retry_index: u32) -> Duration {
        match self {
            RetryDelay::Constant(d) => *d,
            RetryDelay::Exponential { base, factor } => {
                // `factor^retry_index` is clamped to `u32::MAX` when it overflows.
                let multiplier = factor.checked_pow(retry_index).unwrap_or(u32::MAX);
                // `Duration * u32` panics on overflow; saturate instead.
                base.saturating_mul(multiplier)
            }
        }
    }
}
/// Shared, thread-safe callback that maps an input value to an optional
/// string identifier (`None` when the value has no identifier).
pub type DataIdFn = Arc<dyn Fn(Arc<dyn Value>) -> Option<String> + Send + Sync>;
/// An ordered list of tasks together with the settings used to execute them.
pub struct Pipeline {
    /// Identifier of this pipeline definition; `execute` falls back to it when
    /// no deterministic id can be derived from the pipeline name.
    pub id: Uuid,
    /// Optional pipeline name.
    pub name: Option<String>,
    /// Free-form description.
    pub description: String,
    /// Tasks in execution order.
    pub tasks: Vec<TaskInfo>,
    /// Retry behaviour applied to task calls.
    pub retry_policy: RetryPolicy,
    /// Default batch size used when the receiving task does not set its own.
    pub batch_size: usize,
    /// Optional callback deriving a data id from each input.
    pub data_id_fn: Option<DataIdFn>,
    /// Number of inputs processed concurrently; values `<= 1` run sequentially.
    pub concurrency: usize,
    /// Extra properties copied into pipeline-level telemetry events.
    pub telemetry_settings: Option<serde_json::Map<String, serde_json::Value>>,
    /// Pipeline-wide rate limiter, used for tasks that have no limiter of their own.
    pub rate_limiter: Option<Arc<dyn RateLimiter>>,
}
136
137impl Pipeline {
138 pub fn new(description: impl Into<String>) -> Self {
139 Self {
140 id: Uuid::new_v4(),
141 name: None,
142 description: description.into(),
143 tasks: Vec::new(),
144 retry_policy: RetryPolicy::NoRetry,
145 batch_size: 32,
146 data_id_fn: None,
147 concurrency: 1,
148 telemetry_settings: None,
149 rate_limiter: None,
150 }
151 }
152
153 pub fn with_name(mut self, name: impl Into<String>) -> Self {
154 self.name = Some(name.into());
155 self
156 }
157
158 pub fn with_task(mut self, task: impl Into<TaskInfo>) -> Self {
159 self.tasks.push(task.into());
160 self
161 }
162
163 pub fn with_retry(mut self, policy: RetryPolicy) -> Self {
164 self.retry_policy = policy;
165 self
166 }
167
168 pub fn with_batch_size(mut self, size: usize) -> Self {
169 assert!(size > 0, "batch_size must be > 0");
170 self.batch_size = size;
171 self
172 }
173
174 pub fn with_data_id(mut self, f: DataIdFn) -> Self {
177 self.data_id_fn = Some(f);
178 self
179 }
180
181 pub fn with_concurrency(mut self, n: usize) -> Self {
187 assert!(n > 0, "concurrency must be > 0");
188 self.concurrency = n;
189 self
190 }
191
192 pub fn with_rate_limiter(mut self, rl: Arc<dyn RateLimiter>) -> Self {
204 self.rate_limiter = Some(rl);
205 self
206 }
207
208 pub fn with_telemetry_settings(
212 mut self,
213 settings: serde_json::Map<String, serde_json::Value>,
214 ) -> Self {
215 self.telemetry_settings = Some(settings);
216 self
217 }
218}
219
/// Typed builder for a [`Pipeline`]: `I` is the input type of the first task
/// and `O` is the output type of the most recently added task.
pub struct PipelineBuilder<I: Value, O: Value> {
    /// Description copied into the built pipeline.
    description: String,
    /// Optional pipeline name.
    name: Option<String>,
    /// Tasks added so far, in order.
    tasks: Vec<TaskInfo>,
    /// Retry policy for the built pipeline.
    retry_policy: RetryPolicy,
    /// Default batch size for the built pipeline.
    batch_size: usize,
    /// Optional data-id callback for the built pipeline.
    data_id_fn: Option<DataIdFn>,
    /// Concurrency for the built pipeline.
    concurrency: usize,
    /// Carries the `I`/`O` type parameters without storing values of those types.
    _marker: PhantomData<fn(I) -> O>,
}
251
252impl<I: Value, O: Value> PipelineBuilder<I, O> {
253 pub fn new_with_task(
257 description: impl Into<String>,
258 first_task: TypedTask<I, O>,
259 ) -> PipelineBuilder<I, O> {
260 PipelineBuilder {
261 description: description.into(),
262 name: None,
263 tasks: vec![first_task.into()],
264 retry_policy: RetryPolicy::NoRetry,
265 batch_size: 32,
266 data_id_fn: None,
267 concurrency: 1,
268 _marker: PhantomData,
269 }
270 }
271
272 pub fn add_task<O2: Value>(mut self, task: TypedTask<O, O2>) -> PipelineBuilder<I, O2> {
278 self.tasks.push(task.into());
279 PipelineBuilder {
280 description: self.description,
281 name: self.name,
282 tasks: self.tasks,
283 retry_policy: self.retry_policy,
284 batch_size: self.batch_size,
285 data_id_fn: self.data_id_fn,
286 concurrency: self.concurrency,
287 _marker: PhantomData,
288 }
289 }
290
291 pub fn add_task_named<O2: Value>(
298 mut self,
299 task: TypedTask<O, O2>,
300 name: impl Into<String>,
301 ) -> PipelineBuilder<I, O2> {
302 self.tasks.push(TaskInfo::from(task).with_name(name));
303 PipelineBuilder {
304 description: self.description,
305 name: self.name,
306 tasks: self.tasks,
307 retry_policy: self.retry_policy,
308 batch_size: self.batch_size,
309 data_id_fn: self.data_id_fn,
310 concurrency: self.concurrency,
311 _marker: PhantomData,
312 }
313 }
314
315 pub fn with_first_task_name(mut self, name: impl Into<String>) -> Self {
327 let first = self
328 .tasks
329 .first_mut()
330 .expect("PipelineBuilder always has at least the seed task from new_with_task");
331 first.name = Some(name.into());
332 self
333 }
334
335 pub fn with_name(mut self, name: impl Into<String>) -> Self {
337 self.name = Some(name.into());
338 self
339 }
340
341 pub fn with_retry(mut self, policy: RetryPolicy) -> Self {
343 self.retry_policy = policy;
344 self
345 }
346
347 pub fn with_batch_size(mut self, size: usize) -> Self {
349 assert!(size > 0, "batch_size must be > 0");
350 self.batch_size = size;
351 self
352 }
353
354 pub fn with_concurrency(mut self, n: usize) -> Self {
356 assert!(n > 0, "concurrency must be > 0");
357 self.concurrency = n;
358 self
359 }
360
361 pub fn with_data_id(mut self, f: DataIdFn) -> Self {
363 self.data_id_fn = Some(f);
364 self
365 }
366
367 pub fn build(self) -> Pipeline {
372 Pipeline {
373 id: Uuid::new_v4(),
374 name: self.name,
375 description: self.description,
376 tasks: self.tasks,
377 retry_policy: self.retry_policy,
378 batch_size: self.batch_size,
379 data_id_fn: self.data_id_fn,
380 concurrency: self.concurrency,
381 telemetry_settings: None,
382 rate_limiter: None,
383 }
384 }
385}
386
/// Snapshot describing one execution of a pipeline, as passed to watcher hooks.
#[derive(Debug, Clone)]
pub struct PipelineRunInfo {
    /// Identifier generated for this run.
    pub run_id: Uuid,
    /// Pipeline identifier: derived from name/user/dataset when the pipeline
    /// has a non-empty name, otherwise the pipeline's own `id`.
    pub pipeline_id: Uuid,
    /// Pipeline name, or an empty string when the pipeline is unnamed.
    pub pipeline_name: String,
    /// User taken from the pipeline context, if any.
    pub user_id: Option<Uuid>,
    /// Tenant taken from the pipeline context, if any.
    pub tenant_id: Option<Uuid>,
    /// Dataset taken from the pipeline context, if any.
    pub dataset_id: Option<Uuid>,
    /// Data ids of the inputs that the data-id callback produced and that parse as UUIDs.
    pub data_ids: Vec<Uuid>,
    /// Current lifecycle status of the run.
    pub status: PipelineRunStatus,
    /// Time at which the run record was created.
    pub started_at: chrono::DateTime<chrono::Utc>,
    /// Time at which the run finished (successfully or not); `None` while running.
    pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
}
418
419impl PipelineRunInfo {
420 pub fn elapsed_seconds(&self) -> Option<f64> {
424 let end = self.completed_at?;
425 let dur_ms = (end - self.started_at).num_milliseconds();
426 Some(dur_ms as f64 / 1000.0)
427 }
428}
429
/// Lifecycle state of a pipeline run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PipelineRunStatus {
    /// The run record exists but execution has not been announced as started.
    Initiated,
    /// Execution is in progress.
    Started,
    /// Execution finished successfully.
    Completed,
    /// Execution ended with an error (this includes cancellation).
    Errored,
}

impl std::fmt::Display for PipelineRunStatus {
    /// Writes the upper-case status label (`INITIATED`, `STARTED`, ...).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Initiated => "INITIATED",
            Self::Started => "STARTED",
            Self::Completed => "COMPLETED",
            Self::Errored => "ERRORED",
        };
        f.write_str(label)
    }
}
449
450fn deterministic_pipeline_id(
462 name: Option<&str>,
463 user_id: Option<Uuid>,
464 dataset_id: Option<Uuid>,
465) -> Option<Uuid> {
466 let name = name.filter(|n| !n.is_empty())?;
467 let key = format!(
468 "{}{}{}",
469 user_id.map(|u| u.to_string()).unwrap_or_default(),
470 name,
471 dataset_id.map(|d| d.to_string()).unwrap_or_default(),
472 );
473 Some(Uuid::new_v5(&Uuid::NAMESPACE_OID, key.as_bytes()))
474}
/// Per-task status reported through `PipelineWatcher::on_task`.
#[derive(Debug)]
pub enum TaskStatus {
    /// The task is about to be called.
    Started,
    /// Attempt number `attempt` (1-based) failed with `error` and another attempt will follow.
    Retrying { attempt: u32, error: String },
    /// The task call succeeded.
    Succeeded,
    /// Every allowed attempt failed; `error` is the last error.
    Failed { attempts: u32, error: String },
}
482
/// Pipeline-level status reported through `PipelineWatcher::on_pipeline`.
#[derive(Debug)]
pub enum PipelineStatus {
    /// Execution is starting.
    Started {
        /// Number of tasks in the pipeline.
        task_count: usize,
    },
    /// Execution finished successfully.
    Succeeded {
        /// Number of output values returned by the run.
        output_count: usize,
    },
    /// Execution failed.
    Failed {
        /// Index of the failing task (`0` when the error is not tied to a task).
        task_index: usize,
        /// Rendered error message.
        error: String,
    },
    /// Execution was cancelled.
    Cancelled,
    /// An input was skipped because it is already recorded as completed.
    ItemSkipped {
        /// Data id of the skipped input.
        data_id: String,
    },
}
502
/// Observer notified about pipeline and task progress during execution.
///
/// `on_pipeline` and `on_task` are required; every other hook has an empty
/// default implementation.
#[async_trait]
pub trait PipelineWatcher: Send + Sync {
    /// Pipeline-level status change (started, succeeded, failed, cancelled, item skipped).
    async fn on_pipeline(&self, pipeline_id: Uuid, status: PipelineStatus);
    /// Task-level status change for the task at `task_index` out of `total_tasks`.
    async fn on_task(
        &self,
        pipeline_id: Uuid,
        task_index: usize,
        task_name: Option<&str>,
        total_tasks: usize,
        status: TaskStatus,
    );

    /// Called once per run, before the run is announced as started.
    async fn on_pipeline_run_initiated(&self, _run: &PipelineRunInfo) {}

    /// Called once per run after its status has been set to `Started`.
    async fn on_pipeline_run_started(&self, _run: &PipelineRunInfo) {}

    /// Called when the run finished successfully, with the number of outputs.
    async fn on_pipeline_run_completed(&self, _run: &PipelineRunInfo, _output_count: usize) {}

    /// Called when the run failed or was cancelled, with a rendered error message.
    async fn on_pipeline_run_errored(&self, _run: &PipelineRunInfo, _error: &str) {}

    /// Called right before a task is invoked.
    async fn on_task_started(&self, _run: &PipelineRunInfo, _task_name: &str, _task_index: usize) {}

    /// Called after a task call succeeded.
    // NOTE(review): the caller visible in this file always passes `1` as `_result_count`.
    async fn on_task_completed(
        &self,
        _run: &PipelineRunInfo,
        _task_name: &str,
        _result_count: usize,
    ) {
    }

    /// Called after a task exhausted all of its attempts.
    async fn on_task_errored(&self, _run: &PipelineRunInfo, _task_name: &str, _error: &str) {}

    /// Hook receiving a key/value payload field for a run.
    // NOTE(review): no caller of this hook is visible in this part of the file.
    async fn on_payload_field(&self, _run_id: Uuid, _key: &str, _value: serde_json::Value) {}
}
571
572pub struct NoopWatcher;
573
574#[async_trait]
575impl PipelineWatcher for NoopWatcher {
576 async fn on_pipeline(&self, _: Uuid, _: PipelineStatus) {}
577 async fn on_task(&self, _: Uuid, _: usize, _: Option<&str>, _: usize, _: TaskStatus) {}
578}
/// Errors returned by the pipeline execution entry points.
#[derive(Debug, Error)]
pub enum ExecutionError {
    /// A task (or a step performed on its behalf) failed.
    ///
    /// `attempts` is `0` when the failure did not come from calling a task,
    /// e.g. a completion-status lookup, a join error or a runtime build error.
    #[error("task {task_index} failed after {attempts} attempt(s): {source}")]
    TaskFailed {
        task_index: usize,
        attempts: u32,
        #[source]
        source: TaskError,
    },

    /// Execution was cancelled.
    #[error("pipeline was cancelled")]
    Cancelled,

    /// The pipeline contains no tasks.
    #[error("pipeline has no tasks")]
    NoTasks,

    /// The pipeline settings are invalid; `reason` explains why.
    #[error("invalid pipeline configuration: {reason}")]
    InvalidConfig { reason: String },
}
/// Sends a pipeline-level telemetry event.
///
/// The event properties start from a copy of `settings` (if any); the
/// `pipeline_name`, `cognee_version` and `tenant_id` keys are then written,
/// replacing entries with the same key.
#[cfg(feature = "telemetry")]
fn emit_pipeline_event(
    event_name: &str,
    user_id: Option<Uuid>,
    pipeline_name: &str,
    tenant_id: Option<Uuid>,
    settings: Option<&serde_json::Map<String, serde_json::Value>>,
) {
    use serde_json::{Map, Value};

    let mut props: Map<String, Value> = match settings {
        Some(base) => base.clone(),
        None => Map::new(),
    };

    let overrides = [
        ("pipeline_name", pipeline_name.to_owned()),
        (
            "cognee_version",
            cognee_telemetry::cognee_version().to_string(),
        ),
        (
            "tenant_id",
            cognee_telemetry::tenant_id_for_telemetry(tenant_id),
        ),
    ];
    for (key, text) in overrides {
        props.insert(key.to_owned(), Value::String(text));
    }

    let payload = Value::Object(props);
    cognee_telemetry::send_telemetry(event_name, user_id, Some(payload));
}
637
/// No-op stand-in for `emit_pipeline_event` used when the `telemetry`
/// feature is disabled; all arguments are ignored.
#[cfg(not(feature = "telemetry"))]
#[inline]
fn emit_pipeline_event(
    _event_name: &str,
    _user_id: Option<Uuid>,
    _pipeline_name: &str,
    _tenant_id: Option<Uuid>,
    _settings: Option<&serde_json::Map<String, serde_json::Value>>,
) {
}
650
/// Sends a task-level telemetry event named `"<task type> Task <stage>"`.
///
/// The properties carry the task name (`"unknown"` when unnamed), the cognee
/// version and the telemetry form of the tenant id.
#[cfg(feature = "telemetry")]
fn emit_task_event(
    stage: &'static str,
    task: &Task,
    task_name: Option<&str>,
    user_id: Option<Uuid>,
    tenant_id: Option<Uuid>,
) {
    let kind = task.python_task_type();
    let event_name = format!("{kind} Task {stage}");

    let label = task_name.unwrap_or("unknown");
    let version = cognee_telemetry::cognee_version();
    let tenant = cognee_telemetry::tenant_id_for_telemetry(tenant_id);
    let props = serde_json::json!({
        "task_name": label,
        "cognee_version": version,
        "tenant_id": tenant,
    });

    cognee_telemetry::send_telemetry(&event_name, user_id, Some(props));
}
684
/// No-op stand-in for `emit_task_event` used when the `telemetry` feature is
/// disabled; all arguments are ignored.
#[cfg(not(feature = "telemetry"))]
#[inline]
fn emit_task_event(
    _stage: &'static str,
    _task: &Task,
    _task_name: Option<&str>,
    _user_id: Option<Uuid>,
    _tenant_id: Option<Uuid>,
) {
}
696
/// Runs `pipeline` over `inputs` and returns the outputs of all inputs, in order.
///
/// The function validates the configuration, builds the run record, notifies
/// `watcher` (and telemetry) that the run started, processes the inputs either
/// sequentially or concurrently depending on `pipeline.concurrency`, and
/// finally reports the outcome to `watcher` before returning it.
///
/// # Errors
/// - [`ExecutionError::NoTasks`] when the pipeline has no tasks.
/// - [`ExecutionError::InvalidConfig`] when `batch_size` or `concurrency` is
///   zero, or when the progress token rejects the task weights.
/// - Any error produced while executing the items (task failure, cancellation).
///
/// # Panics
/// Panics if the `provenance_visited` mutex of the pipeline context is poisoned.
pub async fn execute(
    pipeline: &Pipeline,
    inputs: Vec<Arc<dyn Value>>,
    ctx: Arc<TaskContext>,
    watcher: &dyn PipelineWatcher,
) -> Result<Vec<Arc<dyn Value>>, ExecutionError> {
    // Configuration checks. The fields are public, so the builder assertions
    // alone do not guarantee non-zero values here.
    if pipeline.tasks.is_empty() {
        return Err(ExecutionError::NoTasks);
    }
    if pipeline.batch_size == 0 {
        return Err(ExecutionError::InvalidConfig {
            reason: "batch_size must be > 0".into(),
        });
    }
    if pipeline.concurrency == 0 {
        return Err(ExecutionError::InvalidConfig {
            reason: "concurrency must be > 0".into(),
        });
    }

    let run_id = Uuid::new_v4();
    let task_count = pipeline.tasks.len();

    // Identity information comes from the optional pipeline context.
    let user_id = ctx.pipeline_ctx.as_ref().and_then(|p| p.user_id);
    let tenant_id = ctx.pipeline_ctx.as_ref().and_then(|p| p.tenant_id);
    let dataset_id = ctx.pipeline_ctx.as_ref().and_then(|p| p.dataset_id);
    // Prefer the name-derived id; fall back to the pipeline's own id.
    let pipeline_id = deterministic_pipeline_id(pipeline.name.as_deref(), user_id, dataset_id)
        .unwrap_or(pipeline.id);

    // Data ids that do not parse as UUIDs are silently left out of the run record.
    let data_ids: Vec<Uuid> = if let Some(id_fn) = pipeline.data_id_fn.as_ref() {
        inputs
            .iter()
            .filter_map(|x| id_fn(Arc::clone(x)))
            .filter_map(|s| Uuid::parse_str(&s).ok())
            .collect()
    } else {
        Vec::new()
    };

    let mut run_info = PipelineRunInfo {
        run_id,
        pipeline_id,
        pipeline_name: pipeline.name.clone().unwrap_or_default(),
        user_id,
        tenant_id,
        dataset_id,
        data_ids,
        status: PipelineRunStatus::Initiated,
        started_at: chrono::Utc::now(),
        completed_at: None,
    };

    // From here on `ctx` is the context returned by `with_run_id` for this run.
    let ctx = ctx.with_run_id(run_id);

    // Start the run with an empty provenance "visited" set.
    if let Some(pctx) = ctx.pipeline_ctx.as_ref() {
        pctx.provenance_visited.lock().unwrap().clear();
    }

    watcher.on_pipeline_run_initiated(&run_info).await;

    run_info.status = PipelineRunStatus::Started;
    watcher
        .on_pipeline(pipeline_id, PipelineStatus::Started { task_count })
        .await;
    watcher.on_pipeline_run_started(&run_info).await;

    emit_pipeline_event(
        "Pipeline Run Started",
        user_id,
        &run_info.pipeline_name,
        tenant_id,
        pipeline.telemetry_settings.as_ref(),
    );

    // One progress sub-token per task, sized by the task weights.
    // NOTE(review): a failure here returns after the "started" notifications
    // were sent, without a matching failed/errored notification.
    let weights: Vec<u32> = pipeline.tasks.iter().map(|t| t.weight).collect();
    let task_subtokens =
        ctx.progress
            .split(&weights)
            .map_err(|e| ExecutionError::InvalidConfig {
                reason: e.to_string(),
            })?;

    // Shared, read-only execution environment handed to every task invocation.
    let env = ExecEnv {
        policy: &pipeline.retry_policy,
        default_batch_size: pipeline.batch_size,
        pipeline_id,
        pipeline_name: pipeline.name.as_deref(),
        total_tasks: task_count,
        ctx: &ctx,
        watcher,
        data_id_fn: &pipeline.data_id_fn,
        run_info: &run_info,
        task_subtokens: &task_subtokens,
        rate_limiter: pipeline.rate_limiter.as_ref(),
    };

    let result = if pipeline.concurrency <= 1 {
        execute_items_seq(inputs, pipeline, &ctx, &env).await
    } else {
        execute_items_par(inputs, pipeline, &ctx, &env).await
    };

    // Report the outcome. `env` is no longer used, so `run_info` can be mutated again.
    match &result {
        Ok(outputs) => {
            run_info.status = PipelineRunStatus::Completed;
            run_info.completed_at = Some(chrono::Utc::now());
            watcher
                .on_pipeline(
                    pipeline_id,
                    PipelineStatus::Succeeded {
                        output_count: outputs.len(),
                    },
                )
                .await;
            watcher
                .on_pipeline_run_completed(&run_info, outputs.len())
                .await;

            emit_pipeline_event(
                "Pipeline Run Completed",
                user_id,
                &run_info.pipeline_name,
                tenant_id,
                pipeline.telemetry_settings.as_ref(),
            );
        }
        // Cancellation is recorded as an errored run but reported with its own status.
        Err(ExecutionError::Cancelled) => {
            run_info.status = PipelineRunStatus::Errored;
            run_info.completed_at = Some(chrono::Utc::now());
            watcher
                .on_pipeline(pipeline_id, PipelineStatus::Cancelled)
                .await;
            watcher
                .on_pipeline_run_errored(&run_info, "pipeline was cancelled")
                .await;

            emit_pipeline_event(
                "Pipeline Run Errored",
                user_id,
                &run_info.pipeline_name,
                tenant_id,
                pipeline.telemetry_settings.as_ref(),
            );
        }
        Err(e) => {
            run_info.status = PipelineRunStatus::Errored;
            run_info.completed_at = Some(chrono::Utc::now());
            // Errors that are not tied to a task are reported against index 0.
            let task_index = match e {
                ExecutionError::TaskFailed { task_index, .. } => *task_index,
                _ => 0,
            };
            watcher
                .on_pipeline(
                    pipeline_id,
                    PipelineStatus::Failed {
                        task_index,
                        error: e.to_string(),
                    },
                )
                .await;
            watcher
                .on_pipeline_run_errored(&run_info, &e.to_string())
                .await;

            emit_pipeline_event(
                "Pipeline Run Errored",
                user_id,
                &run_info.pipeline_name,
                tenant_id,
                pipeline.telemetry_settings.as_ref(),
            );
        }
    }

    result
}
914async fn execute_one_item<'a>(
917 input: Arc<dyn Value>,
918 pipeline: &'a Pipeline,
919 ctx: &'a Arc<TaskContext>,
920 env: &'a ExecEnv<'a>,
921) -> Result<Vec<Arc<dyn Value>>, ExecutionError> {
922 let data_id = pipeline
923 .data_id_fn
924 .as_ref()
925 .and_then(|f| f(Arc::clone(&input)));
926
927 let result = execute_from(&pipeline.tasks, input, 0, env).await;
928
929 if let Some(data_id) = &data_id {
931 let pipeline_name = pipeline.name.as_deref().unwrap_or("");
932 let dataset_id = ctx.pipeline_ctx.as_ref().and_then(|p| p.dataset_id);
933 match &result {
934 Ok(_) => {
935 let _ = ctx
936 .exec_status
937 .mark_completed(data_id, pipeline_name, dataset_id)
938 .await;
939 }
940 Err(ExecutionError::TaskFailed { source, .. }) => {
941 let _ = ctx
942 .exec_status
943 .mark_failed(data_id, pipeline_name, dataset_id, &source.to_string())
944 .await;
945 }
946 Err(_) => {}
947 }
948 }
949
950 result
951}
952
953async fn execute_items_seq<'a>(
955 inputs: Vec<Arc<dyn Value>>,
956 pipeline: &'a Pipeline,
957 ctx: &'a Arc<TaskContext>,
958 env: &'a ExecEnv<'a>,
959) -> Result<Vec<Arc<dyn Value>>, ExecutionError> {
960 let mut all_outputs = Vec::new();
961 for input in inputs {
962 all_outputs.append(&mut execute_one_item(input, pipeline, ctx, env).await?);
963 }
964 Ok(all_outputs)
965}
966
967async fn execute_items_par<'a>(
974 inputs: Vec<Arc<dyn Value>>,
975 pipeline: &'a Pipeline,
976 ctx: &'a Arc<TaskContext>,
977 env: &'a ExecEnv<'a>,
978) -> Result<Vec<Arc<dyn Value>>, ExecutionError> {
979 let mut all_outputs = Vec::new();
980 for chunk in inputs.chunks(pipeline.concurrency) {
981 let futures: Vec<_> = chunk
982 .iter()
983 .map(|input| execute_one_item(Arc::clone(input), pipeline, ctx, env))
984 .collect();
985 let results = futures::future::join_all(futures).await;
986 for result in results {
987 all_outputs.append(&mut result?);
988 }
989 }
990 Ok(all_outputs)
991}
/// Result of a task call after any future has been awaited.
enum Resolved {
    /// The task produced exactly one value.
    Single(Arc<dyn Value>),
    /// The task produced a synchronous iterator of values.
    Iter(ValueIter),
    /// The task produced an asynchronous stream of values.
    Stream(ValueStream),
}
998
/// Owned/borrowed data from which a `ProvenanceContext` is built for one task invocation.
#[derive(Clone)]
struct ProvenanceInputs<'a> {
    /// Pipeline name (empty string when the pipeline is unnamed).
    pipeline_name: &'a str,
    /// Task name (empty string when the task is unnamed).
    task_name: &'a str,
    /// User label obtained from the pipeline context, if any.
    user_label: Option<String>,
    /// Node set extracted from the task's input value, if any.
    input_node_set: Option<String>,
    /// Content hash extracted from the task's input value, if any.
    input_content_hash: Option<String>,
}
1014
1015impl<'a> ProvenanceInputs<'a> {
1016 fn ctx(&'a self) -> crate::provenance::ProvenanceContext<'a> {
1017 crate::provenance::ProvenanceContext {
1018 pipeline_name: self.pipeline_name,
1019 task_name: self.task_name,
1020 user_label: self.user_label.as_deref(),
1021 node_set: self.input_node_set.as_deref(),
1022 content_hash: self.input_content_hash.as_deref(),
1023 }
1024 }
1025}
/// Read-only environment shared by every task invocation of one run.
struct ExecEnv<'a> {
    /// Retry policy of the pipeline.
    policy: &'a RetryPolicy,
    /// Batch size used when the receiving task does not specify one.
    default_batch_size: usize,
    /// Pipeline id reported to the watcher.
    pipeline_id: Uuid,
    /// Pipeline name, if any.
    pipeline_name: Option<&'a str>,
    /// Number of tasks in the pipeline.
    total_tasks: usize,
    /// Task context of the run.
    ctx: &'a Arc<TaskContext>,
    /// Watcher receiving progress notifications.
    watcher: &'a dyn PipelineWatcher,
    /// Optional data-id callback of the pipeline.
    data_id_fn: &'a Option<DataIdFn>,
    /// Run record passed to the run-aware watcher hooks.
    run_info: &'a PipelineRunInfo,
    /// One progress token per task, indexed by task index.
    task_subtokens: &'a [ProgressToken],
    /// Pipeline-wide rate limiter, used when a task has none of its own.
    rate_limiter: Option<&'a Arc<dyn RateLimiter>>,
}
/// Runs `tasks` on `input`, starting with the first task of the slice, and
/// returns the values that come out of the last task.
///
/// `first_index` is the pipeline-wide index of `tasks[0]`. The function is
/// recursive (directly and through `process_iter` / `process_stream` /
/// `dispatch_batch`), hence the boxed future.
///
/// # Errors
/// Returns [`ExecutionError::Cancelled`] when cancellation was requested, and
/// [`ExecutionError::TaskFailed`] when the completion lookup fails, a task
/// exhausts its attempts, or a task returns the passthrough sentinel although
/// it is not marked as enriching.
fn execute_from<'a>(
    tasks: &'a [TaskInfo],
    input: Arc<dyn Value>,
    first_index: usize,
    env: &'a ExecEnv<'a>,
) -> BoxFuture<'a, Result<Vec<Arc<dyn Value>>, ExecutionError>> {
    Box::pin(async move {
        // End of the chain: the input is the final output.
        let Some((info, rest)) = tasks.split_first() else {
            return Ok(vec![input]);
        };

        if env.ctx.cancellation.is_cancelled() {
            return Err(ExecutionError::Cancelled);
        }

        // Only at the first task: skip inputs already recorded as completed
        // for this pipeline/dataset.
        if first_index == 0
            && let Some(id_fn) = env.data_id_fn
            && let Some(data_id) = id_fn(Arc::clone(&input))
        {
            let pipeline_name = env.pipeline_name.unwrap_or("");
            let dataset_id = env.ctx.pipeline_ctx.as_ref().and_then(|p| p.dataset_id);
            let completed = env
                .ctx
                .exec_status
                .is_completed(&data_id, pipeline_name, dataset_id)
                .await
                .map_err(|source| ExecutionError::TaskFailed {
                    task_index: 0,
                    attempts: 0,
                    source,
                })?;
            if completed {
                env.watcher
                    .on_pipeline(env.pipeline_id, PipelineStatus::ItemSkipped { data_id })
                    .await;
                return Ok(vec![]);
            }
        }

        let task_name = info.name.as_deref();
        // Run-aware hooks take a plain `&str`; unnamed tasks are reported as "".
        let task_label = task_name.unwrap_or("");

        env.watcher
            .on_task(
                env.pipeline_id,
                first_index,
                task_name,
                env.total_tasks,
                TaskStatus::Started,
            )
            .await;
        env.watcher
            .on_task_started(env.run_info, task_label, first_index)
            .await;

        // Computed for every task (not only the first) from the task's own input.
        let data_id = env.data_id_fn.as_ref().and_then(|f| f(Arc::clone(&input)));

        let user_label_owned = env.ctx.pipeline_ctx.as_ref().and_then(|p| p.user_label());
        let prov_inputs = ProvenanceInputs {
            pipeline_name: env.pipeline_name.unwrap_or(""),
            task_name: task_label,
            user_label: user_label_owned,
            input_node_set: crate::provenance::extract_node_set_from_value(input.as_ref()),
            input_content_hash: crate::provenance::extract_content_hash_from_value(input.as_ref()),
        };

        // Keep a handle to the input only when the task may pass it through unchanged.
        let input_passthrough = info.enriches.then(|| Arc::clone(&input));

        // A task-level rate limiter takes precedence over the pipeline-wide one.
        let effective_rl = info.rate_limiter.as_ref().or(env.rate_limiter);

        let resolved = call_with_retry(
            &info.task,
            input,
            first_index,
            task_name,
            data_id.as_deref(),
            info.summary_template.as_deref(),
            &prov_inputs,
            effective_rl,
            env,
        )
        .await?;

        env.watcher
            .on_task(
                env.pipeline_id,
                first_index,
                task_name,
                env.total_tasks,
                TaskStatus::Succeeded,
            )
            .await;
        env.watcher
            .on_task_completed(env.run_info, task_label, 1)
            .await;

        // Mark this task's share of the progress as done.
        env.task_subtokens[first_index].set(1.0);

        // Batch size for the NEXT task: its own setting, else the pipeline default.
        let batch_size = rest
            .first()
            .and_then(|next| next.batch_size)
            .unwrap_or(env.default_batch_size);

        match resolved {
            Resolved::Single(v) => {
                if crate::sentinels::is_passthrough(v.as_ref()) {
                    // Passthrough: continue with the original input, which is
                    // only available when the task is marked as enriching.
                    match input_passthrough {
                        Some(orig) => return execute_from(rest, orig, first_index + 1, env).await,
                        None => {
                            return Err(ExecutionError::TaskFailed {
                                task_index: first_index,
                                attempts: 1,
                                source: "task returned PassthroughSentinel but enriches=false"
                                    .into(),
                            });
                        }
                    }
                }
                if crate::sentinels::is_dropped(v.as_ref()) {
                    // Dropped: this branch produces no output.
                    return Ok(vec![]);
                }
                execute_from(rest, v, first_index + 1, env).await
            }
            Resolved::Iter(iter) => {
                process_iter(iter, rest, batch_size, first_index + 1, &prov_inputs, env).await
            }
            Resolved::Stream(stream) => {
                process_stream(stream, rest, batch_size, first_index + 1, &prov_inputs, env).await
            }
        }
    })
}
1208
1209fn dispatch_batch<'a>(
1221 batch: Vec<Box<dyn Value>>,
1222 tail: &'a [TaskInfo],
1223 first_index: usize,
1224 prov_inputs: &'a ProvenanceInputs<'a>,
1225 env: &'a ExecEnv<'a>,
1226) -> BoxFuture<'a, Result<Vec<Arc<dyn Value>>, ExecutionError>> {
1227 Box::pin(async move {
1228 let Some((next_info, _)) = tail.split_first() else {
1229 return Ok(batch
1231 .into_iter()
1232 .map(|item| Arc::from(item) as Arc<dyn Value>)
1233 .collect());
1234 };
1235
1236 if next_info.task.is_batch() {
1237 let call = next_info.task.call_batch(&batch, env.ctx.clone());
1247 let resolved =
1248 resolve_call(call)
1249 .await
1250 .map_err(|source| ExecutionError::TaskFailed {
1251 task_index: first_index,
1252 attempts: 1,
1253 source,
1254 })?;
1255 let rest = &tail[1..];
1257 match resolved {
1258 Resolved::Single(v) => execute_from(rest, v, first_index + 1, env).await,
1259 Resolved::Iter(iter) => {
1260 let batch_size = rest
1261 .first()
1262 .and_then(|t| t.batch_size)
1263 .unwrap_or(env.default_batch_size);
1264 process_iter(iter, rest, batch_size, first_index + 1, prov_inputs, env).await
1265 }
1266 Resolved::Stream(stream) => {
1267 let batch_size = rest
1268 .first()
1269 .and_then(|t| t.batch_size)
1270 .unwrap_or(env.default_batch_size);
1271 process_stream(stream, rest, batch_size, first_index + 1, prov_inputs, env)
1272 .await
1273 }
1274 }
1275 } else {
1276 let mut all_outputs = Vec::new();
1279 for item in batch {
1280 let input = Arc::from(item) as Arc<dyn Value>;
1281 all_outputs.append(&mut execute_from(tail, input, first_index, env).await?);
1282 }
1283 Ok(all_outputs)
1284 }
1285 })
1286}
1287
1288async fn process_iter(
1296 iter: ValueIter,
1297 tail: &[TaskInfo],
1298 batch_size: usize,
1299 first_index: usize,
1300 prov_inputs: &ProvenanceInputs<'_>,
1301 env: &ExecEnv<'_>,
1302) -> Result<Vec<Arc<dyn Value>>, ExecutionError> {
1303 let mut outputs = Vec::new();
1304 let mut batch: Vec<Box<dyn Value>> = Vec::with_capacity(batch_size);
1305
1306 for mut item in iter {
1307 if crate::sentinels::is_dropped(item.as_ref()) {
1309 continue;
1310 }
1311 stamp_boxed_item(&mut item, prov_inputs, env);
1312 batch.push(item);
1313 if batch.len() >= batch_size {
1314 outputs.append(
1315 &mut dispatch_batch(mem::take(&mut batch), tail, first_index, prov_inputs, env)
1316 .await?,
1317 );
1318 }
1319 }
1320
1321 if !batch.is_empty() {
1322 outputs.append(&mut dispatch_batch(batch, tail, first_index, prov_inputs, env).await?);
1323 }
1324
1325 Ok(outputs)
1326}
1327
1328async fn process_stream(
1336 mut stream: ValueStream,
1337 tail: &[TaskInfo],
1338 batch_size: usize,
1339 first_index: usize,
1340 prov_inputs: &ProvenanceInputs<'_>,
1341 env: &ExecEnv<'_>,
1342) -> Result<Vec<Arc<dyn Value>>, ExecutionError> {
1343 let mut outputs = Vec::new();
1344 let mut batch: Vec<Box<dyn Value>> = Vec::with_capacity(batch_size);
1345
1346 while let Some(mut item) = stream.next().await {
1347 if crate::sentinels::is_dropped(item.as_ref()) {
1349 continue;
1350 }
1351 stamp_boxed_item(&mut item, prov_inputs, env);
1352 batch.push(item);
1353 if batch.len() >= batch_size {
1354 outputs.append(
1355 &mut dispatch_batch(mem::take(&mut batch), tail, first_index, prov_inputs, env)
1356 .await?,
1357 );
1358 }
1359 }
1360
1361 if !batch.is_empty() {
1362 outputs.append(&mut dispatch_batch(batch, tail, first_index, prov_inputs, env).await?);
1363 }
1364
1365 Ok(outputs)
1366}
1367
1368fn stamp_boxed_item(
1372 item: &mut Box<dyn Value>,
1373 prov_inputs: &ProvenanceInputs<'_>,
1374 env: &ExecEnv<'_>,
1375) {
1376 if let Some(pctx) = env.ctx.pipeline_ctx.as_ref() {
1377 let mut visited = pctx.provenance_visited.lock().unwrap();
1379 let prov_ctx = prov_inputs.ctx();
1380 let _ = crate::provenance::stamp_tree_dyn(item.as_mut(), &prov_ctx, &mut visited);
1381 }
1382}
/// Calls `task` on `input`, retrying according to the pipeline's retry policy.
///
/// Before every attempt the rate limiter (if any) is acquired. On success the
/// single-value result is provenance-stamped when possible, provenance is
/// recorded for `data_id` (if any), and the resolved result is returned. On
/// failure the watcher is told about the retry and the policy's delay is
/// slept before the next attempt.
///
/// # Errors
/// Returns [`ExecutionError::TaskFailed`] carrying the last error once all
/// attempts are used up; the watcher is notified before returning.
///
/// # Panics
/// Panics if the `provenance_visited` mutex is poisoned, or if `task_index`
/// is out of range for `env.task_subtokens`.
#[allow(clippy::too_many_arguments)]
async fn call_with_retry(
    task: &Task,
    input: Arc<dyn Value>,
    task_index: usize,
    task_name: Option<&str>,
    data_id: Option<&str>,
    // Only read when the `telemetry` feature is enabled.
    #[allow(unused_variables)] summary_template: Option<&str>,
    prov_inputs: &ProvenanceInputs<'_>,
    rate_limiter: Option<&Arc<dyn RateLimiter>>,
    env: &ExecEnv<'_>,
) -> Result<Resolved, ExecutionError> {
    // The span is only used to record fields below; it is never entered here.
    #[cfg(feature = "telemetry")]
    let span = tracing::info_span!(
        "cognee.pipeline.task",
        task.name = task_name.unwrap_or("unknown"),
        task.index = task_index,
        task.result_count = tracing::field::Empty,
        task.result_summary = tracing::field::Empty,
        task.error = tracing::field::Empty,
    );

    let max_attempts = env.policy.max_attempts();
    let mut last_error: Option<TaskError> = None;

    // The task sees a context scoped to its own progress token and current input.
    let subtoken = env.task_subtokens[task_index].clone();
    let scoped_ctx = env.ctx.with_progress(subtoken);
    let task_ctx = scoped_ctx.with_current_data(input.clone());

    let user_id = env.ctx.pipeline_ctx.as_ref().and_then(|p| p.user_id);
    let tenant_id = env.ctx.pipeline_ctx.as_ref().and_then(|p| p.tenant_id);

    emit_task_event("Started", task, task_name, user_id, tenant_id);

    for attempt in 1..=max_attempts {
        // Rate limiting applies to every attempt, retries included.
        if let Some(rl) = rate_limiter {
            rl.acquire().await;
        }
        let call = task.call(input.clone(), Arc::clone(&task_ctx));
        match resolve_call(call).await {
            Ok(mut resolved) => {
                #[cfg(feature = "telemetry")]
                {
                    // Iterators and streams are not consumed here, so they count as 1.
                    let result_count: usize = match &resolved {
                        Resolved::Single(_) => 1,
                        Resolved::Iter(_) | Resolved::Stream(_) => 1,
                    };
                    span.record("task.result_count", result_count);
                    if let Some(template) = summary_template {
                        let summary = template.replace("{n}", &result_count.to_string());
                        span.record("task.result_summary", summary.as_str());
                    }
                }

                // Stamp single values in place; iterator/stream items are
                // stamped later, when they are pulled.
                if let Resolved::Single(ref mut v) = resolved
                    && let Some(pctx) = env.ctx.pipeline_ctx.as_ref()
                {
                    let prov_ctx = prov_inputs.ctx();
                    let mut visited = pctx.provenance_visited.lock().unwrap();
                    // Mutable access requires that nobody else holds this `Arc`.
                    if let Some(inner) = Arc::get_mut(v) {
                        let _ = crate::provenance::stamp_tree_dyn(inner, &prov_ctx, &mut visited);
                    } else {
                        tracing::warn!(
                            "skipping provenance stamping: shared Arc<dyn Value> for task '{}'",
                            prov_inputs.task_name
                        );
                    }
                }

                if let Some(data_id) = data_id {
                    let pipeline_name = env.pipeline_name.unwrap_or("");
                    let user_id = env.ctx.pipeline_ctx.as_ref().and_then(|p| p.user_id);

                    // A node set is only available when the result is a `TaggedMeta`.
                    let node_set = match &resolved {
                        Resolved::Single(v) => (**v)
                            .as_any()
                            .downcast_ref::<TaggedMeta>()
                            .and_then(|m| m.node_set.clone()),
                        _ => None,
                    };

                    // Best effort: a failure to record provenance does not fail the task.
                    let _ = env
                        .ctx
                        .exec_status
                        .stamp_provenance(
                            data_id,
                            pipeline_name,
                            task_name.unwrap_or(""),
                            user_id,
                            node_set.as_deref(),
                        )
                        .await;
                }

                emit_task_event("Completed", task, task_name, user_id, tenant_id);
                return Ok(resolved);
            }
            Err(e) => {
                let error_str = e.to_string();

                #[cfg(feature = "telemetry")]
                span.record("task.error", error_str.as_str());

                last_error = Some(e);
                // Only announce a retry (and wait) when another attempt follows.
                if attempt < max_attempts {
                    env.watcher
                        .on_task(
                            env.pipeline_id,
                            task_index,
                            task_name,
                            env.total_tasks,
                            TaskStatus::Retrying {
                                attempt,
                                error: error_str,
                            },
                        )
                        .await;
                    // Attempts are 1-based; the delay schedule is indexed from 0.
                    let retry_index = attempt - 1;
                    if let Some(delay) = env.policy.delay(retry_index) {
                        sleep(delay).await;
                    }
                }
            }
        }
    }

    // `max_attempts` is at least 1, so the loop stored an error before falling through.
    let source = last_error.expect("loop ran at least once");
    let error_str = source.to_string();

    #[cfg(feature = "telemetry")]
    span.record("task.error", error_str.as_str());

    emit_task_event("Errored", task, task_name, user_id, tenant_id);

    env.watcher
        .on_task(
            env.pipeline_id,
            task_index,
            task_name,
            env.total_tasks,
            TaskStatus::Failed {
                attempts: max_attempts,
                error: error_str.clone(),
            },
        )
        .await;
    env.watcher
        .on_task_errored(env.run_info, task_name.unwrap_or(""), &error_str)
        .await;

    Err(ExecutionError::TaskFailed {
        task_index,
        attempts: max_attempts,
        source,
    })
}
1563
1564async fn resolve_call(call: TaskCall) -> Result<Resolved, TaskError> {
1567 match call {
1568 TaskCall::Sync(r) => r.map(Resolved::Single),
1569 TaskCall::Async(fut) => fut.await.map(Resolved::Single),
1570 TaskCall::SyncIter(r) => r.map(Resolved::Iter),
1571 TaskCall::AsyncStream(r) => r.map(Resolved::Stream),
1572 }
1573}
/// Outcome of a pipeline run, as returned by [`execute_blocking`] and
/// [`PipelineRunHandle::wait`].
pub struct PipelineRunResult {
    /// Identifier reported for the run. Both `execute_blocking` and
    /// `execute_in_background` fill this with the pipeline's `id`.
    pub run_id: Uuid,
    /// Values returned by `execute` for this run.
    pub outputs: Vec<Arc<dyn Value>>,
}
/// Handle to a pipeline run spawned as a Tokio task by [`execute_in_background`].
pub struct PipelineRunHandle {
    /// Identifier reported for the run; `execute_in_background` copies it from
    /// the pipeline's `id`.
    pub run_id: Uuid,
    /// Join handle of the spawned task that drives the pipeline.
    inner: tokio::task::JoinHandle<Result<PipelineRunResult, ExecutionError>>,
}
1592
1593impl PipelineRunHandle {
1594 pub async fn wait(self) -> Result<PipelineRunResult, ExecutionError> {
1596 match self.inner.await {
1597 Ok(result) => result,
1598 Err(join_err) => {
1599 if join_err.is_cancelled() {
1600 Err(ExecutionError::Cancelled)
1601 } else {
1602 Err(ExecutionError::TaskFailed {
1604 task_index: 0,
1605 attempts: 0,
1606 source: join_err.to_string().into(),
1607 })
1608 }
1609 }
1610 }
1611 }
1612
1613 pub fn abort(&self) {
1617 self.inner.abort();
1618 }
1619
1620 pub fn is_finished(&self) -> bool {
1622 self.inner.is_finished()
1623 }
1624}
1625pub fn execute_in_background(
1643 pipeline: Arc<Pipeline>,
1644 inputs: Vec<Arc<dyn Value>>,
1645 ctx: Arc<TaskContext>,
1646 watcher: Arc<dyn PipelineWatcher>,
1647) -> PipelineRunHandle {
1648 let run_id = pipeline.id;
1649 let fut = async move {
1652 let outputs = execute(&pipeline, inputs, ctx, watcher.as_ref()).await?;
1653 Ok(PipelineRunResult { run_id, outputs })
1654 };
1655 let fut: std::pin::Pin<Box<dyn std::future::Future<Output = _> + Send>> = Box::pin(fut);
1656 let inner = tokio::spawn(fut);
1657 PipelineRunHandle { run_id, inner }
1658}
1659
1660pub fn execute_blocking(
1672 pipeline: &Pipeline,
1673 inputs: Vec<Arc<dyn Value>>,
1674 ctx: Arc<TaskContext>,
1675 watcher: &dyn PipelineWatcher,
1676) -> Result<PipelineRunResult, ExecutionError> {
1677 let run_id = pipeline.id;
1678 let rt = tokio::runtime::Builder::new_current_thread()
1679 .enable_all()
1680 .build()
1681 .map_err(|e| ExecutionError::TaskFailed {
1682 task_index: 0,
1683 attempts: 0,
1684 source: e.into(),
1685 })?;
1686 let outputs = rt.block_on(execute(pipeline, inputs, ctx, watcher))?;
1687 Ok(PipelineRunResult { run_id, outputs })
1688}
1689
1690#[cfg(test)]
1691mod tests {
1692 use super::*;
1693 use std::future::Future;
1694 use std::pin::Pin;
1695
1696 use crate::cancellation::cancellation_pair;
1697 use crate::exec_status::NoopExecStatusManager;
1698 use crate::progress::ProgressToken;
1699 use crate::task::{Task, TaskError, Value};
1700 use crate::task_context::TaskContext;
1701 use crate::thread_pool::CpuPool;
1702
    /// Minimal [`CpuPool`] implementation used to build test contexts.
    struct StubPool;
    impl CpuPool for StubPool {
        /// Returns a future that resolves to `Ok(())`. The submitted closure is
        /// never invoked; it is simply dropped.
        fn spawn_raw(
            &self,
            _task: Box<dyn FnOnce() + Send + 'static>,
        ) -> Pin<Box<dyn Future<Output = Result<(), crate::error::CoreError>> + Send + 'static>>
        {
            Box::pin(async { Ok(()) })
        }
    }
1713
1714 async fn stub_ctx() -> Arc<TaskContext> {
1715 let db = cognee_database::connect("sqlite::memory:").await.unwrap();
1716 cognee_database::initialize(&db).await.unwrap();
1717 let (_handle, token) = cancellation_pair();
1718 Arc::new(TaskContext {
1719 thread_pool: Arc::new(StubPool),
1720 database: Arc::new(db),
1721 graph_db: Arc::new(cognee_graph::MockGraphDB::new()),
1722 vector_db: Arc::new(cognee_vector::MockVectorDB::new()),
1723 cancellation: token,
1724 progress: ProgressToken::new(),
1725 pipeline_ctx: None,
1726 exec_status: Arc::new(NoopExecStatusManager),
1727 pipeline_watcher: None,
1728 })
1729 }
1730
1731 #[test]
1732 fn pipeline_run_info_elapsed_seconds_returns_none_before_completion() {
1733 let info = PipelineRunInfo {
1734 run_id: Uuid::new_v4(),
1735 pipeline_id: Uuid::new_v4(),
1736 pipeline_name: "test".to_string(),
1737 user_id: None,
1738 tenant_id: None,
1739 dataset_id: None,
1740 data_ids: Vec::new(),
1741 status: PipelineRunStatus::Started,
1742 started_at: chrono::Utc::now(),
1743 completed_at: None,
1744 };
1745 assert_eq!(info.elapsed_seconds(), None);
1746 }
1747
1748 #[test]
1749 fn pipeline_run_info_elapsed_seconds_returns_positive_after_completion() {
1750 let now = chrono::Utc::now();
1751 let started_at = now - chrono::Duration::milliseconds(100);
1752 let info = PipelineRunInfo {
1753 run_id: Uuid::new_v4(),
1754 pipeline_id: Uuid::new_v4(),
1755 pipeline_name: "test".to_string(),
1756 user_id: None,
1757 tenant_id: None,
1758 dataset_id: None,
1759 data_ids: Vec::new(),
1760 status: PipelineRunStatus::Completed,
1761 started_at,
1762 completed_at: Some(now),
1763 };
1764 let elapsed = info
1765 .elapsed_seconds()
1766 .expect("elapsed_seconds should be Some when completed_at is set");
1767 assert!(elapsed > 0.0, "elapsed should be positive, got {elapsed}");
1768 assert!(elapsed < 1.0, "elapsed should be < 1s, got {elapsed}");
1769 }
1770
1771 #[tokio::test]
1772 async fn test_execute_retry_on_failure() {
1773 use std::sync::atomic::{AtomicU32, Ordering};
1774
1775 let counter = Arc::new(AtomicU32::new(0));
1776 let counter_clone = Arc::clone(&counter);
1777
1778 let task = Task::Sync(Arc::new(move |input, _ctx| {
1779 let prev = counter_clone.fetch_add(1, Ordering::SeqCst);
1780 if prev < 2 {
1781 return Err("not yet".into());
1783 }
1784 let val = (*input).as_any().downcast_ref::<i32>().unwrap();
1786 Ok(Arc::new(*val * 2) as Arc<dyn Value>)
1787 }));
1788
1789 let policy = RetryPolicy::Limited {
1790 max_attempts: std::num::NonZeroU32::new(3).unwrap(),
1791 delay: RetryDelay::Constant(Duration::from_millis(1)),
1792 };
1793
1794 let pipeline = Pipeline::new("retry pipeline")
1795 .with_retry(policy)
1796 .with_task(task);
1797
1798 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(21_i32)];
1799 let ctx = stub_ctx().await;
1800 let watcher = NoopWatcher;
1801
1802 let outputs = execute(&pipeline, inputs, ctx, &watcher).await.unwrap();
1803
1804 assert_eq!(outputs.len(), 1);
1805 let result = (*outputs[0]).as_any().downcast_ref::<i32>().unwrap();
1806 assert_eq!(*result, 42);
1807 assert_eq!(counter.load(Ordering::SeqCst), 3);
1808 }
1809
1810 #[tokio::test]
1811 async fn test_execute_single_task_pipeline() {
1812 let double = Task::sync_typed(|x: &i32, _ctx| -> Result<Box<i32>, TaskError> {
1813 Ok(Box::new(*x * 2))
1814 });
1815
1816 let pipeline = Pipeline::new("double pipeline").with_task(double);
1817
1818 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(5_i32)];
1819 let ctx = stub_ctx().await;
1820 let watcher = NoopWatcher;
1821
1822 let outputs = execute(&pipeline, inputs, ctx, &watcher).await.unwrap();
1823
1824 assert_eq!(outputs.len(), 1);
1825 let result = (*outputs[0]).as_any().downcast_ref::<i32>().unwrap();
1826 assert_eq!(*result, 10);
1827 }
1828
1829 #[tokio::test]
1830 async fn test_execute_chained_tasks() {
1831 let double = Task::sync_typed(|x: &i32, _ctx| Ok(Box::new(*x * 2)));
1833 let add_one = Task::sync_typed(|x: &i32, _ctx| Ok(Box::new(*x + 1)));
1834
1835 let pipeline = Pipeline::new("chained test")
1836 .with_task(double)
1837 .with_task(add_one);
1838
1839 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(3_i32)];
1840 let ctx = stub_ctx().await;
1841 let watcher = NoopWatcher;
1842
1843 let outputs = execute(&pipeline, inputs, ctx, &watcher).await.unwrap();
1844
1845 assert_eq!(outputs.len(), 1);
1846 let result = (*outputs[0]).as_any().downcast_ref::<i32>().unwrap();
1847 assert_eq!(*result, 7);
1849 }
1850
1851 #[tokio::test]
1852 async fn test_execute_iter_task_batching() {
1853 let iter_task = Task::SyncIter(Arc::new(move |_input, _ctx| {
1855 let iter = (1..=5).map(|i| Box::new(i) as Box<dyn Value>);
1856 Ok(Box::new(iter) as Box<dyn Iterator<Item = Box<dyn Value>> + Send>)
1857 }));
1858
1859 let double_task = Task::sync_typed(|x: &i32, _ctx| Ok(Box::new(*x * 2)));
1861
1862 let pipeline = Pipeline::new("iter batching test")
1863 .with_batch_size(2)
1864 .with_task(iter_task)
1865 .with_task(double_task);
1866
1867 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
1868 let ctx = stub_ctx().await;
1869 let watcher = NoopWatcher;
1870
1871 let outputs = execute(&pipeline, inputs, ctx, &watcher).await.unwrap();
1872
1873 assert_eq!(outputs.len(), 5);
1875 let values: Vec<i32> = outputs
1876 .iter()
1877 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
1878 .collect();
1879 assert_eq!(values, vec![2, 4, 6, 8, 10]);
1880 }
1881
1882 #[tokio::test]
1883 async fn test_cancellation_stops_pipeline() {
1884 use std::sync::atomic::{AtomicU32, Ordering};
1885
1886 let call_count = Arc::new(AtomicU32::new(0));
1887 let call_count_clone = Arc::clone(&call_count);
1888
1889 let task1 = Task::Async(Arc::new(move |input, ctx| {
1891 let cc = Arc::clone(&call_count_clone);
1892 Box::pin(async move {
1893 cc.fetch_add(1, Ordering::SeqCst);
1894 ctx.cancellation.cancelled().await; Ok(input)
1896 })
1897 }));
1898
1899 let call_count_clone2 = Arc::clone(&call_count);
1901 let task2 = Task::Sync(Arc::new(move |input, _ctx| {
1902 call_count_clone2.fetch_add(1, Ordering::SeqCst);
1903 Ok(input)
1904 }));
1905
1906 let pipeline = Pipeline::new("cancel test")
1907 .with_task(task1)
1908 .with_task(task2);
1909
1910 let db = cognee_database::connect("sqlite::memory:").await.unwrap();
1911 cognee_database::initialize(&db).await.unwrap();
1912 let (handle, token) = cancellation_pair();
1913 let ctx = Arc::new(TaskContext {
1914 thread_pool: Arc::new(StubPool),
1915 database: Arc::new(db),
1916 graph_db: Arc::new(cognee_graph::MockGraphDB::new()),
1917 vector_db: Arc::new(cognee_vector::MockVectorDB::new()),
1918 cancellation: token,
1919 progress: ProgressToken::new(),
1920 pipeline_ctx: None,
1921 exec_status: Arc::new(NoopExecStatusManager),
1922 pipeline_watcher: None,
1923 });
1924
1925 handle.cancel();
1927
1928 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(1_i32)];
1929 let result = execute(&pipeline, inputs, ctx, &NoopWatcher).await;
1930
1931 assert!(
1932 matches!(result, Err(ExecutionError::Cancelled)),
1933 "expected Cancelled error"
1934 );
1935 }
1936
1937 #[tokio::test]
1938 async fn test_sync_terminal() {
1939 let double = Task::sync_typed(|x: &i32, _ctx| -> Result<Box<i32>, TaskError> {
1940 Ok(Box::new(*x * 2))
1941 });
1942
1943 let pipeline = Pipeline::new("sync terminal").with_task(double);
1944
1945 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(5_i32)];
1946 let ctx = stub_ctx().await;
1947
1948 let outputs = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
1949
1950 assert_eq!(outputs.len(), 1);
1951 let result = (*outputs[0]).as_any().downcast_ref::<i32>().unwrap();
1952 assert_eq!(*result, 10);
1953 }
1954
1955 #[tokio::test]
1956 async fn test_async_terminal() {
1957 let triple = Task::async_fn_typed(|x: &i32, _ctx| {
1958 let val = *x;
1959 Box::pin(async move { Ok(Box::new(val * 3)) })
1960 });
1961
1962 let pipeline = Pipeline::new("async terminal").with_task(triple);
1963
1964 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(4_i32)];
1965 let ctx = stub_ctx().await;
1966
1967 let outputs = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
1968
1969 assert_eq!(outputs.len(), 1);
1970 let result = (*outputs[0]).as_any().downcast_ref::<i32>().unwrap();
1971 assert_eq!(*result, 12);
1972 }
1973
1974 #[tokio::test]
1975 async fn test_sync_iter_terminal() {
1976 use crate::task::ValueIter;
1977
1978 let iter_task = Task::SyncIter(Arc::new(|_input, _ctx| {
1979 let vec = vec![10_i32, 20, 30];
1980 Ok(Box::new(vec.into_iter().map(|i| Box::new(i) as Box<dyn Value>)) as ValueIter)
1981 }));
1982
1983 let pipeline = Pipeline::new("sync iter terminal").with_task(iter_task);
1984
1985 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
1986 let ctx = stub_ctx().await;
1987
1988 let outputs = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
1989
1990 assert_eq!(outputs.len(), 3);
1991 let values: Vec<i32> = outputs
1992 .iter()
1993 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
1994 .collect();
1995 assert_eq!(values, vec![10, 20, 30]);
1996 }
1997
1998 #[tokio::test]
1999 async fn test_sync_iter_then_sync() {
2000 use crate::task::ValueIter;
2001
2002 let iter_task = Task::SyncIter(Arc::new(|_input, _ctx| {
2003 let vec = vec![1_i32, 2, 3, 4];
2004 Ok(Box::new(vec.into_iter().map(|i| Box::new(i) as Box<dyn Value>)) as ValueIter)
2005 }));
2006
2007 let double_task = Task::sync_typed(|x: &i32, _ctx| Ok(Box::new(*x * 2)));
2009
2010 let pipeline = Pipeline::new("sync iter then sync")
2011 .with_batch_size(2)
2012 .with_task(iter_task)
2013 .with_task(double_task);
2014
2015 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2016 let ctx = stub_ctx().await;
2017
2018 let outputs = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
2019
2020 assert_eq!(outputs.len(), 4);
2021 let values: Vec<i32> = outputs
2022 .iter()
2023 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2024 .collect();
2025 assert_eq!(values, vec![2, 4, 6, 8]);
2026 }
2027
2028 #[tokio::test]
2029 async fn test_sync_iter_then_async() {
2030 use crate::task::ValueIter;
2031
2032 let iter_task = Task::SyncIter(Arc::new(|_input, _ctx| {
2033 let vec = vec![1_i32, 2, 3];
2034 Ok(Box::new(vec.into_iter().map(|i| Box::new(i) as Box<dyn Value>)) as ValueIter)
2035 }));
2036
2037 let add_ten = Task::async_fn_typed(|x: &i32, _ctx| {
2039 let v = *x + 10;
2040 Box::pin(async move { Ok(Box::new(v)) })
2041 });
2042
2043 let pipeline = Pipeline::new("sync iter then async")
2044 .with_batch_size(3)
2045 .with_task(iter_task)
2046 .with_task(add_ten);
2047
2048 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2049 let ctx = stub_ctx().await;
2050
2051 let outputs = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
2052
2053 assert_eq!(outputs.len(), 3);
2054 let values: Vec<i32> = outputs
2055 .iter()
2056 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2057 .collect();
2058 assert_eq!(values, vec![11, 12, 13]);
2059 }
2060
2061 #[tokio::test]
2062 async fn test_async_stream_terminal() {
2063 let stream_task = Task::AsyncStream(Arc::new(|_input, _ctx| {
2064 let items = vec![100_i32, 200, 300];
2065 Ok(
2066 Box::pin(futures::stream::iter(items).map(|i| Box::new(i) as Box<dyn Value>))
2067 as ValueStream,
2068 )
2069 }));
2070
2071 let pipeline = Pipeline::new("async stream terminal").with_task(stream_task);
2072
2073 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2074 let ctx = stub_ctx().await;
2075
2076 let outputs = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
2077
2078 assert_eq!(outputs.len(), 3);
2079 let values: Vec<i32> = outputs
2080 .iter()
2081 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2082 .collect();
2083 assert_eq!(values, vec![100, 200, 300]);
2084 }
2085
2086 #[tokio::test]
2087 async fn test_async_stream_then_sync() {
2088 let stream_task = Task::AsyncStream(Arc::new(|_input, _ctx| {
2089 let items = vec![10_i32, 20, 30, 40];
2090 Ok(
2091 Box::pin(futures::stream::iter(items).map(|i| Box::new(i) as Box<dyn Value>))
2092 as ValueStream,
2093 )
2094 }));
2095
2096 let triple = Task::sync_typed(|x: &i32, _ctx| Ok(Box::new(*x * 3)));
2098
2099 let pipeline = Pipeline::new("async stream then sync")
2100 .with_batch_size(2)
2101 .with_task(stream_task)
2102 .with_task(triple);
2103
2104 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2105 let ctx = stub_ctx().await;
2106
2107 let outputs = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
2108
2109 assert_eq!(outputs.len(), 4);
2110 let values: Vec<i32> = outputs
2111 .iter()
2112 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2113 .collect();
2114 assert_eq!(values, vec![30, 60, 90, 120]);
2115 }
2116
2117 #[tokio::test]
2118 async fn test_async_stream_then_async() {
2119 let stream_task = Task::AsyncStream(Arc::new(|_input, _ctx| {
2120 let items = vec![5_i32, 15];
2121 Ok(
2122 Box::pin(futures::stream::iter(items).map(|i| Box::new(i) as Box<dyn Value>))
2123 as ValueStream,
2124 )
2125 }));
2126
2127 let add_one = Task::async_fn_typed(|x: &i32, _ctx| {
2129 let v = *x + 1;
2130 Box::pin(async move { Ok(Box::new(v)) })
2131 });
2132
2133 let pipeline = Pipeline::new("async stream then async")
2134 .with_batch_size(10)
2135 .with_task(stream_task)
2136 .with_task(add_one);
2137
2138 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2139 let ctx = stub_ctx().await;
2140
2141 let outputs = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
2142
2143 assert_eq!(outputs.len(), 2);
2144 let values: Vec<i32> = outputs
2145 .iter()
2146 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2147 .collect();
2148 assert_eq!(values, vec![6, 16]);
2149 }
2150
2151 #[tokio::test]
2152 async fn test_sync_then_sync() {
2153 let double = Task::sync_typed(|x: &i32, _ctx| -> Result<Box<i32>, TaskError> {
2154 Ok(Box::new(*x * 2))
2155 });
2156 let add_one = Task::sync_typed(|x: &i32, _ctx| -> Result<Box<i32>, TaskError> {
2157 Ok(Box::new(*x + 1))
2158 });
2159
2160 let pipeline = Pipeline::new("sync then sync")
2161 .with_task(double)
2162 .with_task(add_one);
2163
2164 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(3_i32)];
2165 let outputs = execute(&pipeline, inputs, stub_ctx().await, &NoopWatcher)
2166 .await
2167 .unwrap();
2168
2169 assert_eq!(outputs.len(), 1);
2170 let result = (*outputs[0]).as_any().downcast_ref::<i32>().unwrap();
2171 assert_eq!(*result, 7); }
2173
2174 #[tokio::test]
2175 async fn test_sync_then_async() {
2176 let double = Task::sync_typed(|x: &i32, _ctx| -> Result<Box<i32>, TaskError> {
2177 Ok(Box::new(*x * 2))
2178 });
2179 let add_ten = Task::async_fn_typed(|x: &i32, _ctx| {
2180 let v = *x;
2181 Box::pin(async move { Ok(Box::new(v + 10)) })
2182 });
2183
2184 let pipeline = Pipeline::new("sync then async")
2185 .with_task(double)
2186 .with_task(add_ten);
2187
2188 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(5_i32)];
2189 let outputs = execute(&pipeline, inputs, stub_ctx().await, &NoopWatcher)
2190 .await
2191 .unwrap();
2192
2193 assert_eq!(outputs.len(), 1);
2194 let result = (*outputs[0]).as_any().downcast_ref::<i32>().unwrap();
2195 assert_eq!(*result, 20); }
2197
2198 #[tokio::test]
2199 async fn test_async_then_sync() {
2200 let add_hundred = Task::async_fn_typed(|x: &i32, _ctx| {
2201 let v = *x;
2202 Box::pin(async move { Ok(Box::new(v + 100)) })
2203 });
2204 let double = Task::sync_typed(|x: &i32, _ctx| -> Result<Box<i32>, TaskError> {
2205 Ok(Box::new(*x * 2))
2206 });
2207
2208 let pipeline = Pipeline::new("async then sync")
2209 .with_task(add_hundred)
2210 .with_task(double);
2211
2212 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(3_i32)];
2213 let outputs = execute(&pipeline, inputs, stub_ctx().await, &NoopWatcher)
2214 .await
2215 .unwrap();
2216
2217 assert_eq!(outputs.len(), 1);
2218 let result = (*outputs[0]).as_any().downcast_ref::<i32>().unwrap();
2219 assert_eq!(*result, 206); }
2221
2222 #[tokio::test]
2223 async fn test_async_then_async() {
2224 let triple = Task::async_fn_typed(|x: &i32, _ctx| {
2225 let v = *x;
2226 Box::pin(async move { Ok(Box::new(v * 3)) })
2227 });
2228 let add_one = Task::async_fn_typed(|x: &i32, _ctx| {
2229 let v = *x;
2230 Box::pin(async move { Ok(Box::new(v + 1)) })
2231 });
2232
2233 let pipeline = Pipeline::new("async then async")
2234 .with_task(triple)
2235 .with_task(add_one);
2236
2237 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(10_i32)];
2238 let outputs = execute(&pipeline, inputs, stub_ctx().await, &NoopWatcher)
2239 .await
2240 .unwrap();
2241
2242 assert_eq!(outputs.len(), 1);
2243 let result = (*outputs[0]).as_any().downcast_ref::<i32>().unwrap();
2244 assert_eq!(*result, 31); }
2246
2247 #[tokio::test]
2248 async fn test_sync_iter_then_sync_batch() {
2249 use crate::task::ValueIter;
2250
2251 let iter_task = Task::SyncIter(Arc::new(|_input, _ctx| {
2253 let vec = vec![1_i32, 2, 3, 4, 5];
2254 Ok(Box::new(vec.into_iter().map(|i| Box::new(i) as Box<dyn Value>)) as ValueIter)
2255 }));
2256
2257 let sum_batch = Task::SyncBatch(Arc::new(|items: &[Box<dyn Value>], _ctx| {
2259 let sum: i32 = items
2260 .iter()
2261 .map(|item| *(**item).as_any().downcast_ref::<i32>().unwrap())
2262 .sum();
2263 Ok(Arc::new(sum) as Arc<dyn Value>)
2264 }));
2265
2266 let pipeline = Pipeline::new("sync iter then sync batch")
2267 .with_batch_size(2)
2268 .with_task(iter_task)
2269 .with_task(sum_batch);
2270
2271 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2272 let ctx = stub_ctx().await;
2273
2274 let outputs = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
2275
2276 assert_eq!(outputs.len(), 3, "expected 3 batches: [1,2], [3,4], [5]");
2277 let sums: Vec<i32> = outputs
2278 .iter()
2279 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2280 .collect();
2281 assert_eq!(sums, vec![3, 7, 5]);
2282 }
2283
2284 #[tokio::test]
2285 async fn test_sync_iter_then_async_batch() {
2286 use crate::task::ValueIter;
2287
2288 let iter_task = Task::SyncIter(Arc::new(|_input, _ctx| {
2290 let vec = vec![10_i32, 20, 30];
2291 Ok(Box::new(vec.into_iter().map(|i| Box::new(i) as Box<dyn Value>)) as ValueIter)
2292 }));
2293
2294 let count_batch = Task::AsyncBatch(Arc::new(|items: &[Box<dyn Value>], _ctx| {
2296 let count = items.len() as i32;
2297 Box::pin(async move { Ok(Arc::new(count) as Arc<dyn Value>) })
2298 }));
2299
2300 let pipeline = Pipeline::new("sync iter then async batch")
2301 .with_batch_size(2)
2302 .with_task(iter_task)
2303 .with_task(count_batch);
2304
2305 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2306 let ctx = stub_ctx().await;
2307
2308 let outputs = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
2309
2310 assert_eq!(outputs.len(), 2, "expected 2 batches: [10,20], [30]");
2311 let counts: Vec<i32> = outputs
2312 .iter()
2313 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2314 .collect();
2315 assert_eq!(counts, vec![2, 1]);
2316 }
2317
2318 #[tokio::test]
2319 async fn test_async_stream_then_sync_batch() {
2320 let stream_task = Task::AsyncStream(Arc::new(|_input, _ctx| {
2322 let stream = futures::stream::iter(vec![5_i32, 10, 15, 20])
2323 .map(|i| Box::new(i) as Box<dyn Value>);
2324 Ok(Box::pin(stream) as ValueStream)
2325 }));
2326
2327 let sum_batch = Task::SyncBatch(Arc::new(|items: &[Box<dyn Value>], _ctx| {
2329 let sum: i32 = items
2330 .iter()
2331 .map(|item| *(**item).as_any().downcast_ref::<i32>().unwrap())
2332 .sum();
2333 Ok(Arc::new(sum) as Arc<dyn Value>)
2334 }));
2335
2336 let pipeline = Pipeline::new("async stream then sync batch")
2337 .with_batch_size(4)
2338 .with_task(stream_task)
2339 .with_task(sum_batch);
2340
2341 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2342 let ctx = stub_ctx().await;
2343
2344 let outputs = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
2345
2346 assert_eq!(outputs.len(), 1, "expected 1 batch of all 4 items");
2347 let sum = *(*outputs[0]).as_any().downcast_ref::<i32>().unwrap();
2348 assert_eq!(sum, 50);
2349 }
2350
2351 #[tokio::test]
2352 async fn test_async_stream_then_async_batch() {
2353 let stream_task = Task::AsyncStream(Arc::new(|_input, _ctx| {
2355 let stream =
2356 futures::stream::iter(vec![1_i32, 2, 3]).map(|i| Box::new(i) as Box<dyn Value>);
2357 Ok(Box::pin(stream) as ValueStream)
2358 }));
2359
2360 let product_batch = Task::AsyncBatch(Arc::new(|items: &[Box<dyn Value>], _ctx| {
2362 let product: i32 = items
2363 .iter()
2364 .map(|item| *(**item).as_any().downcast_ref::<i32>().unwrap())
2365 .product();
2366 Box::pin(async move { Ok(Arc::new(product) as Arc<dyn Value>) })
2367 }));
2368
2369 let pipeline = Pipeline::new("async stream then async batch")
2370 .with_batch_size(3)
2371 .with_task(stream_task)
2372 .with_task(product_batch);
2373
2374 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2375 let ctx = stub_ctx().await;
2376
2377 let outputs = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
2378
2379 assert_eq!(outputs.len(), 1, "expected 1 batch of all 3 items");
2380 let product = *(*outputs[0]).as_any().downcast_ref::<i32>().unwrap();
2381 assert_eq!(product, 6);
2382 }
2383
2384 #[tokio::test]
2385 async fn test_sync_iter_then_sync_iter_batch() {
2386 use crate::task::ValueIter;
2387
2388 let iter_task = Task::SyncIter(Arc::new(|_input, _ctx| {
2390 let vec = vec![1_i32, 2, 3, 4];
2391 Ok(Box::new(vec.into_iter().map(|i| Box::new(i) as Box<dyn Value>)) as ValueIter)
2392 }));
2393
2394 let double_batch = Task::SyncIterBatch(Arc::new(|items: &[Box<dyn Value>], _ctx| {
2396 let doubled: Vec<Box<dyn Value>> = items
2397 .iter()
2398 .map(|item| {
2399 let val = *(**item).as_any().downcast_ref::<i32>().unwrap();
2400 Box::new(val * 2) as Box<dyn Value>
2401 })
2402 .collect();
2403 Ok(Box::new(doubled.into_iter()) as ValueIter)
2404 }));
2405
2406 let pipeline = Pipeline::new("sync iter then sync iter batch")
2407 .with_batch_size(2)
2408 .with_task(iter_task)
2409 .with_task(double_batch);
2410
2411 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2412 let ctx = stub_ctx().await;
2413
2414 let outputs = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
2415
2416 assert_eq!(outputs.len(), 4);
2417 let values: Vec<i32> = outputs
2418 .iter()
2419 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2420 .collect();
2421 assert_eq!(values, vec![2, 4, 6, 8]);
2422 }
2423
2424 #[tokio::test]
2425 async fn test_sync_iter_then_async_stream_batch() {
2426 use crate::task::ValueIter;
2427
2428 let iter_task = Task::SyncIter(Arc::new(|_input, _ctx| {
2430 let vec = vec![10_i32, 20, 30];
2431 Ok(Box::new(vec.into_iter().map(|i| Box::new(i) as Box<dyn Value>)) as ValueIter)
2432 }));
2433
2434 let add_one_batch = Task::AsyncStreamBatch(Arc::new(|items: &[Box<dyn Value>], _ctx| {
2436 let results: Vec<Box<dyn Value>> = items
2437 .iter()
2438 .map(|item| {
2439 let val = *(**item).as_any().downcast_ref::<i32>().unwrap();
2440 Box::new(val + 1) as Box<dyn Value>
2441 })
2442 .collect();
2443 Ok(Box::pin(futures::stream::iter(results)) as ValueStream)
2444 }));
2445
2446 let pipeline = Pipeline::new("sync iter then async stream batch")
2447 .with_batch_size(3)
2448 .with_task(iter_task)
2449 .with_task(add_one_batch);
2450
2451 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2452 let ctx = stub_ctx().await;
2453
2454 let outputs = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
2455
2456 assert_eq!(outputs.len(), 3);
2457 let values: Vec<i32> = outputs
2458 .iter()
2459 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2460 .collect();
2461 assert_eq!(values, vec![11, 21, 31]);
2462 }
2463
2464 #[tokio::test]
2465 async fn test_async_stream_then_sync_iter_batch() {
2466 use crate::task::ValueIter;
2467
2468 let stream_task = Task::AsyncStream(Arc::new(|_input, _ctx| {
2470 let stream =
2471 futures::stream::iter(vec![5_i32, 10]).map(|i| Box::new(i) as Box<dyn Value>);
2472 Ok(Box::pin(stream) as ValueStream)
2473 }));
2474
2475 let triple_batch = Task::SyncIterBatch(Arc::new(|items: &[Box<dyn Value>], _ctx| {
2477 let tripled: Vec<Box<dyn Value>> = items
2478 .iter()
2479 .map(|item| {
2480 let val = *(**item).as_any().downcast_ref::<i32>().unwrap();
2481 Box::new(val * 3) as Box<dyn Value>
2482 })
2483 .collect();
2484 Ok(Box::new(tripled.into_iter()) as ValueIter)
2485 }));
2486
2487 let pipeline = Pipeline::new("async stream then sync iter batch")
2488 .with_batch_size(2)
2489 .with_task(stream_task)
2490 .with_task(triple_batch);
2491
2492 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2493 let ctx = stub_ctx().await;
2494
2495 let outputs = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
2496
2497 assert_eq!(outputs.len(), 2);
2498 let values: Vec<i32> = outputs
2499 .iter()
2500 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2501 .collect();
2502 assert_eq!(values, vec![15, 30]);
2503 }
2504
2505 #[tokio::test]
2506 async fn test_async_stream_then_async_stream_batch() {
2507 let stream_task = Task::AsyncStream(Arc::new(|_input, _ctx| {
2509 let stream =
2510 futures::stream::iter(vec![1_i32, 2, 3]).map(|i| Box::new(i) as Box<dyn Value>);
2511 Ok(Box::pin(stream) as ValueStream)
2512 }));
2513
2514 let negate_batch = Task::AsyncStreamBatch(Arc::new(|items: &[Box<dyn Value>], _ctx| {
2516 let results: Vec<Box<dyn Value>> = items
2517 .iter()
2518 .map(|item| {
2519 let val = *(**item).as_any().downcast_ref::<i32>().unwrap();
2520 Box::new(-val) as Box<dyn Value>
2521 })
2522 .collect();
2523 Ok(Box::pin(futures::stream::iter(results)) as ValueStream)
2524 }));
2525
2526 let pipeline = Pipeline::new("async stream then async stream batch")
2527 .with_batch_size(2)
2528 .with_task(stream_task)
2529 .with_task(negate_batch);
2530
2531 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2532 let ctx = stub_ctx().await;
2533
2534 let outputs = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
2535
2536 assert_eq!(outputs.len(), 3);
2537 let values: Vec<i32> = outputs
2538 .iter()
2539 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2540 .collect();
2541 assert_eq!(values, vec![-1, -2, -3]);
2542 }
2543
2544 #[tokio::test]
2545 async fn test_sync_batch_terminal() {
2546 use crate::task::ValueIter;
2547
2548 let iter_task = Task::SyncIter(Arc::new(|_input, _ctx| {
2550 let vec = vec![1_i32, 2, 3];
2551 Ok(Box::new(vec.into_iter().map(|i| Box::new(i) as Box<dyn Value>)) as ValueIter)
2552 }));
2553
2554 let sum_batch = Task::SyncBatch(Arc::new(|items: &[Box<dyn Value>], _ctx| {
2556 let sum: i32 = items
2557 .iter()
2558 .map(|item| *(**item).as_any().downcast_ref::<i32>().unwrap())
2559 .sum();
2560 Ok(Arc::new(sum) as Arc<dyn Value>)
2561 }));
2562
2563 let pipeline = Pipeline::new("sync batch terminal")
2564 .with_task(iter_task)
2565 .with_task(TaskInfo::new(sum_batch).with_batch_size(3));
2566
2567 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2568 let outputs = execute(&pipeline, inputs, stub_ctx().await, &NoopWatcher)
2569 .await
2570 .unwrap();
2571
2572 assert_eq!(outputs.len(), 1);
2573 let result = (*outputs[0]).as_any().downcast_ref::<i32>().unwrap();
2574 assert_eq!(*result, 6);
2575 }
2576
2577 #[tokio::test]
2578 async fn test_async_batch_terminal() {
2579 use crate::task::ValueIter;
2580
2581 let iter_task = Task::SyncIter(Arc::new(|_input, _ctx| {
2583 let vec = vec![10_i32, 20, 30, 40];
2584 Ok(Box::new(vec.into_iter().map(|i| Box::new(i) as Box<dyn Value>)) as ValueIter)
2585 }));
2586
2587 let max_batch = Task::AsyncBatch(Arc::new(|items: &[Box<dyn Value>], _ctx| {
2589 let max_val: i32 = items
2590 .iter()
2591 .map(|item| *(**item).as_any().downcast_ref::<i32>().unwrap())
2592 .max()
2593 .unwrap();
2594 Box::pin(async move { Ok(Arc::new(max_val) as Arc<dyn Value>) })
2595 }));
2596
2597 let pipeline = Pipeline::new("async batch terminal")
2598 .with_task(iter_task)
2599 .with_task(TaskInfo::new(max_batch).with_batch_size(2));
2600
2601 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2602 let outputs = execute(&pipeline, inputs, stub_ctx().await, &NoopWatcher)
2603 .await
2604 .unwrap();
2605
2606 assert_eq!(outputs.len(), 2);
2607 let values: Vec<i32> = outputs
2608 .iter()
2609 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2610 .collect();
2611 assert_eq!(values, vec![20, 40]);
2612 }
2613
2614 #[tokio::test]
2615 async fn test_sync_iter_batch_terminal() {
2616 use crate::task::ValueIter;
2617
2618 let iter_task = Task::SyncIter(Arc::new(|_input, _ctx| {
2620 let vec = vec![1_i32, 2, 3];
2621 Ok(Box::new(vec.into_iter().map(|i| Box::new(i) as Box<dyn Value>)) as ValueIter)
2622 }));
2623
2624 let double_batch = Task::SyncIterBatch(Arc::new(|items: &[Box<dyn Value>], _ctx| {
2626 let doubled: Vec<Box<dyn Value>> = items
2627 .iter()
2628 .map(|item| {
2629 let val = *(**item).as_any().downcast_ref::<i32>().unwrap();
2630 Box::new(val * 2) as Box<dyn Value>
2631 })
2632 .collect();
2633 Ok(Box::new(doubled.into_iter()) as ValueIter)
2634 }));
2635
2636 let pipeline = Pipeline::new("sync iter batch terminal")
2637 .with_task(iter_task)
2638 .with_task(TaskInfo::new(double_batch).with_batch_size(3));
2639
2640 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2641 let outputs = execute(&pipeline, inputs, stub_ctx().await, &NoopWatcher)
2642 .await
2643 .unwrap();
2644
2645 assert_eq!(outputs.len(), 3);
2646 let values: Vec<i32> = outputs
2647 .iter()
2648 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2649 .collect();
2650 assert_eq!(values, vec![2, 4, 6]);
2651 }
2652
2653 #[tokio::test]
2654 async fn test_async_stream_batch_terminal() {
2655 use crate::task::ValueIter;
2656
2657 let iter_task = Task::SyncIter(Arc::new(|_input, _ctx| {
2659 let vec = vec![5_i32, 10];
2660 Ok(Box::new(vec.into_iter().map(|i| Box::new(i) as Box<dyn Value>)) as ValueIter)
2661 }));
2662
2663 let negate_batch = Task::AsyncStreamBatch(Arc::new(|items: &[Box<dyn Value>], _ctx| {
2665 let negated: Vec<i32> = items
2666 .iter()
2667 .map(|item| {
2668 let val = *(**item).as_any().downcast_ref::<i32>().unwrap();
2669 -val
2670 })
2671 .collect();
2672 Ok(
2673 Box::pin(futures::stream::iter(negated).map(|i| Box::new(i) as Box<dyn Value>))
2674 as ValueStream,
2675 )
2676 }));
2677
2678 let pipeline = Pipeline::new("async stream batch terminal")
2679 .with_task(iter_task)
2680 .with_task(TaskInfo::new(negate_batch).with_batch_size(2));
2681
2682 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2683 let outputs = execute(&pipeline, inputs, stub_ctx().await, &NoopWatcher)
2684 .await
2685 .unwrap();
2686
2687 assert_eq!(outputs.len(), 2);
2688 let values: Vec<i32> = outputs
2689 .iter()
2690 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2691 .collect();
2692 assert_eq!(values, vec![-5, -10]);
2693 }
2694
2695 #[tokio::test]
2696 async fn test_sync_then_sync_iter_then_sync_batch() {
2697 use crate::task::ValueIter;
2698
2699 let double = Task::sync_typed(|x: &i32, _ctx| -> Result<Box<i32>, TaskError> {
2701 Ok(Box::new(*x * 2))
2702 });
2703
2704 let expand = Task::SyncIter(Arc::new(|input, _ctx| {
2706 let val = *(*input).as_any().downcast_ref::<i32>().unwrap();
2707 let vec: Vec<i32> = vec![val, val + 1, val + 2];
2708 Ok(Box::new(vec.into_iter().map(|i| Box::new(i) as Box<dyn Value>)) as ValueIter)
2709 }));
2710
2711 let sum_batch = Task::SyncBatch(Arc::new(|items: &[Box<dyn Value>], _ctx| {
2713 let sum: i32 = items
2714 .iter()
2715 .map(|item| *(**item).as_any().downcast_ref::<i32>().unwrap())
2716 .sum();
2717 Ok(Arc::new(sum) as Arc<dyn Value>)
2718 }));
2719
2720 let pipeline = Pipeline::new("sync -> sync_iter -> sync_batch")
2721 .with_batch_size(2)
2722 .with_task(double)
2723 .with_task(expand)
2724 .with_task(sum_batch);
2725
2726 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(5_i32)];
2727 let outputs = execute(&pipeline, inputs, stub_ctx().await, &NoopWatcher)
2728 .await
2729 .unwrap();
2730
2731 assert_eq!(outputs.len(), 2);
2735 let values: Vec<i32> = outputs
2736 .iter()
2737 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2738 .collect();
2739 assert_eq!(values, vec![21, 12]);
2740 }
2741
2742 #[tokio::test]
2743 async fn test_sync_iter_then_sync_batch_then_sync() {
2744 use crate::task::ValueIter;
2745
2746 let iter_task = Task::SyncIter(Arc::new(|_input, _ctx| {
2748 let iter = (1..=4).map(|i| Box::new(i) as Box<dyn Value>);
2749 Ok(Box::new(iter) as ValueIter)
2750 }));
2751
2752 let sum_batch = Task::SyncBatch(Arc::new(|items: &[Box<dyn Value>], _ctx| {
2754 let sum: i32 = items
2755 .iter()
2756 .map(|item| *(**item).as_any().downcast_ref::<i32>().unwrap())
2757 .sum();
2758 Ok(Arc::new(sum) as Arc<dyn Value>)
2759 }));
2760
2761 let double = Task::sync_typed(|x: &i32, _ctx| -> Result<Box<i32>, TaskError> {
2763 Ok(Box::new(*x * 2))
2764 });
2765
2766 let pipeline = Pipeline::new("sync_iter -> sync_batch -> sync")
2767 .with_batch_size(2)
2768 .with_task(iter_task)
2769 .with_task(sum_batch)
2770 .with_task(double);
2771
2772 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2773 let outputs = execute(&pipeline, inputs, stub_ctx().await, &NoopWatcher)
2774 .await
2775 .unwrap();
2776
2777 assert_eq!(outputs.len(), 2);
2781 let values: Vec<i32> = outputs
2782 .iter()
2783 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2784 .collect();
2785 assert_eq!(values, vec![6, 14]);
2786 }
2787
2788 #[tokio::test]
2789 async fn test_sync_iter_then_sync_batch_then_sync_iter() {
2790 use crate::task::ValueIter;
2791
2792 let iter_task = Task::SyncIter(Arc::new(|_input, _ctx| {
2794 let iter = (1..=3).map(|i| Box::new(i) as Box<dyn Value>);
2795 Ok(Box::new(iter) as ValueIter)
2796 }));
2797
2798 let sum_batch = Task::SyncBatch(Arc::new(|items: &[Box<dyn Value>], _ctx| {
2800 let sum: i32 = items
2801 .iter()
2802 .map(|item| *(**item).as_any().downcast_ref::<i32>().unwrap())
2803 .sum();
2804 Ok(Arc::new(sum) as Arc<dyn Value>)
2805 }));
2806
2807 let re_expand = Task::SyncIter(Arc::new(|input, _ctx| {
2809 let val = *(*input).as_any().downcast_ref::<i32>().unwrap();
2810 let iter = (0..2).map(move |i| Box::new(val + i) as Box<dyn Value>);
2811 Ok(Box::new(iter) as ValueIter)
2812 }));
2813
2814 let pipeline = Pipeline::new("sync_iter -> sync_batch -> sync_iter")
2815 .with_batch_size(3)
2816 .with_task(iter_task)
2817 .with_task(sum_batch)
2818 .with_task(re_expand);
2819
2820 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(0_i32)];
2821 let outputs = execute(&pipeline, inputs, stub_ctx().await, &NoopWatcher)
2822 .await
2823 .unwrap();
2824
2825 assert_eq!(outputs.len(), 2);
2829 let values: Vec<i32> = outputs
2830 .iter()
2831 .map(|v| *(**v).as_any().downcast_ref::<i32>().unwrap())
2832 .collect();
2833 assert_eq!(values, vec![6, 7]);
2834 }
2835
2836 #[tokio::test]
2837 async fn test_pipeline_progress_with_weights() {
2838 use crate::progress::ProgressToken;
2839 use crate::task::TaskInfo;
2840
2841 let progress = ProgressToken::new();
2842 let (_handle, token) = cancellation_pair();
2843 let db = cognee_database::connect("sqlite::memory:").await.unwrap();
2844 cognee_database::initialize(&db).await.unwrap();
2845 let ctx = Arc::new(TaskContext {
2846 thread_pool: Arc::new(StubPool),
2847 database: Arc::new(db),
2848 graph_db: Arc::new(cognee_graph::MockGraphDB::new()),
2849 vector_db: Arc::new(cognee_vector::MockVectorDB::new()),
2850 cancellation: token,
2851 progress: progress.clone(),
2852 pipeline_ctx: None,
2853 exec_status: Arc::new(NoopExecStatusManager),
2854 pipeline_watcher: None,
2855 });
2856
2857 let task1 = TaskInfo::new(Task::sync_typed(|x: &i32, ctx| {
2859 ctx.progress.set(0.5);
2860 Ok(Box::new(*x))
2861 }))
2862 .with_weight(1);
2863
2864 let task2 =
2865 TaskInfo::new(Task::sync_typed(|x: &i32, _ctx| Ok(Box::new(*x)))).with_weight(3);
2866
2867 let pipeline = Pipeline::new("progress test")
2868 .with_task(task1)
2869 .with_task(task2);
2870
2871 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(42_i32)];
2872 let _ = execute(&pipeline, inputs, ctx, &NoopWatcher).await.unwrap();
2873
2874 assert!((progress.root_fraction() - 1.0).abs() < 0.01);
2876 }
2877
2878 #[tokio::test]
2879 async fn test_pipeline_builder_typed_chain() {
2880 let t1 = TypedTask::sync(|s: &String, _| Ok(Box::new(s.len())));
2882 let t2 = TypedTask::sync(|n: &usize, _| Ok(Box::new(format!("len={n}"))));
2883
2884 let pipeline = PipelineBuilder::new_with_task("typed chain", t1)
2885 .add_task(t2)
2886 .build();
2887
2888 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new("hello".to_string())];
2889 let outputs = execute(&pipeline, inputs, stub_ctx().await, &NoopWatcher)
2890 .await
2891 .unwrap();
2892
2893 assert_eq!(outputs.len(), 1);
2894 let s = (*outputs[0]).as_any().downcast_ref::<String>().unwrap();
2895 assert_eq!(s, "len=5");
2896 }
2897
2898 #[tokio::test]
2899 async fn test_pipeline_builder_config_forwarded() {
2900 let t1 = TypedTask::sync(|x: &i32, _| Ok(Box::new(*x * 2)));
2901 let pipeline = PipelineBuilder::new_with_task("cfg test", t1)
2902 .with_name("my pipeline")
2903 .with_batch_size(8)
2904 .with_concurrency(2)
2905 .build();
2906
2907 assert_eq!(pipeline.name.as_deref(), Some("my pipeline"));
2908 assert_eq!(pipeline.batch_size, 8);
2909 assert_eq!(pipeline.concurrency, 2);
2910 }
2911
2912 #[test]
2913 fn test_typed_task_into_task_info() {
2914 let typed: TypedTask<i32, i32> = TypedTask::sync(|x: &i32, _| Ok(Box::new(*x)));
2915 let info: TaskInfo = typed.into();
2916 assert!(info.name.is_none());
2918 assert!(info.batch_size.is_none());
2919 assert_eq!(info.weight, 1);
2920 }
2921
2922 #[tokio::test]
2923 async fn test_typed_task_into_untyped_pipeline() {
2924 let typed: TypedTask<i32, i32> = TypedTask::sync(|x: &i32, _| Ok(Box::new(*x + 10)));
2926 let pipeline = Pipeline::new("escape hatch").with_task(typed);
2927
2928 let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(5_i32)];
2929 let outputs = execute(&pipeline, inputs, stub_ctx().await, &NoopWatcher)
2930 .await
2931 .unwrap();
2932
2933 let v = (*outputs[0]).as_any().downcast_ref::<i32>().unwrap();
2934 assert_eq!(*v, 15);
2935 }
2936}