#![allow(clippy::expect_used, reason = "invariants are upheld by construction")]
use std::any::Any;
use std::sync::Arc;
use crate::rate_limiter::RateLimiter;
use futures::future::BoxFuture;
use futures::stream::{BoxStream, Stream, StreamExt};
use crate::task_context::TaskContext;
/// Object-safe marker for anything that can flow between tasks.
///
/// Every `Any + Send + Sync + 'static` type implements it through the blanket
/// impl below. That includes smart pointers to values: `Box<dyn Value>` and
/// `Arc<dyn Value>` are themselves `Value`s. Consequently, calling `as_any`
/// *through* such a wrapper (`boxed.as_any()`, or passing `&boxed` where a
/// `&dyn Value` is expected) resolves to the wrapper's own impl and reports the
/// wrapper's type. Dereference to `dyn Value` first (`(*boxed).as_any()`,
/// `&**boxed`) to reach the inner value.
pub trait Value: Any + Send + Sync + 'static {
    /// Borrows `self` as `&dyn Any`, for `downcast_ref`.
    fn as_any(&self) -> &dyn Any;
    /// Mutably borrows `self` as `&mut dyn Any`, for `downcast_mut`.
    fn as_any_mut(&mut self) -> &mut dyn Any;
    /// Converts a boxed value into `Box<dyn Any>`, for `Box::downcast`.
    fn into_any(self: Box<Self>) -> Box<dyn Any>;
}

impl<T: Any + Send + Sync + 'static> Value for T {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}

/// Attempts to recover the concrete `T` stored in a boxed [`Value`].
///
/// # Errors
/// Returns the original box unchanged when the stored value is not a `T`.
pub fn downcast_value<T: Any>(value: Box<dyn Value>) -> Result<Box<T>, Box<dyn Value>> {
    // Deref through the box on purpose: `value.as_any()` would pick the blanket
    // `Value` impl for `Box<dyn Value>` itself, so the check would compare `T`
    // against `Box<dyn Value>` and never succeed for a real payload type.
    if (*value).as_any().is::<T>() {
        Ok(value
            .into_any()
            .downcast::<T>()
            .expect("downcast can't fail after is::<T>() check"))
    } else {
        Err(value)
    }
}
/// A value of type `T` annotated with free-form string metadata.
///
/// Metadata keys are unique: inserting an existing key overwrites its value.
/// No trait bound is required on `T`; wrapping a `Value` type still yields a
/// `Value` through the blanket impl.
pub struct Tagged<T> {
    inner: T,
    metadata: std::collections::HashMap<String, String>,
}

impl<T> Tagged<T> {
    /// Wraps `inner` with an empty metadata map.
    pub fn new(inner: T) -> Self {
        Self {
            inner,
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Builder-style insert of `key = value`, replacing any previous value for `key`.
    #[must_use]
    pub fn with_meta(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Borrows the wrapped value.
    pub fn inner(&self) -> &T {
        &self.inner
    }

    /// Consumes the wrapper and returns the wrapped value, discarding metadata.
    pub fn into_inner(self) -> T {
        self.inner
    }

    /// Looks up a single metadata entry.
    pub fn meta(&self, key: &str) -> Option<&str> {
        self.metadata.get(key).map(String::as_str)
    }

    /// Borrows the full metadata map.
    pub fn metadata(&self) -> &std::collections::HashMap<String, String> {
        &self.metadata
    }
}
/// Returns the node-set label when `value` is a [`TaggedMeta`] that carries one.
///
/// Any other concrete type, or a `TaggedMeta` without a node set, yields `None`.
pub fn extract_node_set(value: &dyn Value) -> Option<&str> {
    let meta = value.as_any().downcast_ref::<TaggedMeta>()?;
    meta.node_set.as_deref()
}
/// A shared [`Value`] paired with an optional node-set label.
pub struct TaggedMeta {
    /// The wrapped payload.
    pub value: Arc<dyn Value>,
    /// Node-set label; `None` unless set via [`TaggedMeta::with_node_set`].
    pub node_set: Option<String>,
}

impl TaggedMeta {
    /// Wraps `value` without a node-set label.
    pub fn new(value: Arc<dyn Value>) -> Self {
        TaggedMeta {
            node_set: None,
            value,
        }
    }

    /// Builder-style setter: returns `self` with the label set to `node_set`.
    pub fn with_node_set(self, node_set: impl Into<String>) -> Self {
        TaggedMeta {
            node_set: Some(node_set.into()),
            ..self
        }
    }
}
/// Error type shared by every task callable.
pub type TaskError = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Boxed, sendable iterator of values produced by generator-style tasks.
pub type ValueIter = Box<dyn Iterator<Item = Box<dyn Value>> + Send + 'static>;

/// Boxed, sendable stream of values produced by async-generator-style tasks.
pub type ValueStream = BoxStream<'static, Box<dyn Value>>;

/// Single input, single output, computed synchronously.
pub type SyncFn = Arc<
    dyn Fn(Arc<dyn Value>, Arc<TaskContext>) -> Result<Arc<dyn Value>, TaskError> + Send + Sync,
>;

/// Single input, single output delivered through a `'static` future.
pub type AsyncFn = Arc<
    dyn Fn(Arc<dyn Value>, Arc<TaskContext>) -> BoxFuture<'static, Result<Arc<dyn Value>, TaskError>>
        + Send
        + Sync,
>;

/// Single input, many outputs pulled from an iterator.
pub type SyncIterFn = Arc<
    dyn Fn(Arc<dyn Value>, Arc<TaskContext>) -> Result<ValueIter, TaskError> + Send + Sync,
>;

/// Single input, many outputs polled from a stream.
pub type AsyncStreamFn = Arc<
    dyn Fn(Arc<dyn Value>, Arc<TaskContext>) -> Result<ValueStream, TaskError> + Send + Sync,
>;

// Batch shapes. The slice borrow is an elided lifetime in `Fn` sugar, which is
// higher-ranked: the callable accepts a slice borrowed for any duration.

/// Batch input, single output, computed synchronously.
pub type SyncBatchFn = Arc<
    dyn Fn(&[Box<dyn Value>], Arc<TaskContext>) -> Result<Arc<dyn Value>, TaskError>
        + Send
        + Sync,
>;

/// Batch input, single output delivered through a `'static` future.
pub type AsyncBatchFn = Arc<
    dyn Fn(&[Box<dyn Value>], Arc<TaskContext>) -> BoxFuture<'static, Result<Arc<dyn Value>, TaskError>>
        + Send
        + Sync,
>;

/// Batch input, many outputs pulled from an iterator.
pub type SyncIterBatchFn = Arc<
    dyn Fn(&[Box<dyn Value>], Arc<TaskContext>) -> Result<ValueIter, TaskError> + Send + Sync,
>;

/// Batch input, many outputs polled from a stream.
pub type AsyncStreamBatchFn = Arc<
    dyn Fn(&[Box<dyn Value>], Arc<TaskContext>) -> Result<ValueStream, TaskError> + Send + Sync,
>;
/// A type-erased pipeline step.
///
/// Variants differ along two axes: how output is delivered (a value, a future,
/// an iterator, or a stream) and whether the callable takes one input or a
/// slice of inputs (`*Batch`). Single-input variants are invoked through
/// [`Task::call`], batch variants through [`Task::call_batch`].
pub enum Task {
    /// Single input; output computed synchronously.
    Sync(SyncFn),
    /// Single input; output delivered by a future.
    Async(AsyncFn),
    /// Single input; outputs pulled from an iterator.
    SyncIter(SyncIterFn),
    /// Single input; outputs polled from a stream.
    AsyncStream(AsyncStreamFn),
    /// Batch input; output computed synchronously.
    SyncBatch(SyncBatchFn),
    /// Batch input; output delivered by a future.
    AsyncBatch(AsyncBatchFn),
    /// Batch input; outputs pulled from an iterator.
    SyncIterBatch(SyncIterBatchFn),
    /// Batch input; outputs polled from a stream.
    AsyncStreamBatch(AsyncStreamBatchFn),
}

impl Task {
    /// `true` for the four `*Batch` variants, i.e. the ones that must be driven
    /// through [`Task::call_batch`] instead of [`Task::call`].
    pub fn is_batch(&self) -> bool {
        match self {
            Task::Sync(_) | Task::Async(_) | Task::SyncIter(_) | Task::AsyncStream(_) => false,
            Task::SyncBatch(_)
            | Task::AsyncBatch(_)
            | Task::SyncIterBatch(_)
            | Task::AsyncStreamBatch(_) => true,
        }
    }

    /// Name of the equivalent Python callable kind. Batch and single-input
    /// variants of the same delivery style share a label.
    pub fn python_task_type(&self) -> &'static str {
        match self {
            Task::AsyncStream(_) | Task::AsyncStreamBatch(_) => "Async Generator",
            Task::SyncIter(_) | Task::SyncIterBatch(_) => "Generator",
            Task::Async(_) | Task::AsyncBatch(_) => "Coroutine",
            Task::Sync(_) | Task::SyncBatch(_) => "Function",
        }
    }
}
impl Task {
    /// Wraps a synchronous single-input function.
    pub fn sync<F>(f: F) -> Self
    where
        F: Fn(Arc<dyn Value>, Arc<TaskContext>) -> Result<Arc<dyn Value>, TaskError>
            + Send
            + Sync
            + 'static,
    {
        Task::Sync(Arc::new(f))
    }

    /// Wraps a single-input function that returns a boxed future.
    pub fn async_fn<F>(f: F) -> Self
    where
        F: Fn(
                Arc<dyn Value>,
                Arc<TaskContext>,
            ) -> BoxFuture<'static, Result<Arc<dyn Value>, TaskError>>
            + Send
            + Sync
            + 'static,
    {
        Task::Async(Arc::new(f))
    }

    /// Wraps a single-input function that yields an iterator of outputs.
    pub fn sync_iter<F>(f: F) -> Self
    where
        F: Fn(Arc<dyn Value>, Arc<TaskContext>) -> Result<ValueIter, TaskError>
            + Send
            + Sync
            + 'static,
    {
        Task::SyncIter(Arc::new(f))
    }

    /// Wraps a single-input function that yields a stream of outputs.
    pub fn async_stream<F>(f: F) -> Self
    where
        F: Fn(Arc<dyn Value>, Arc<TaskContext>) -> Result<ValueStream, TaskError>
            + Send
            + Sync
            + 'static,
    {
        Task::AsyncStream(Arc::new(f))
    }

    /// Wraps a synchronous batch function producing one output per batch.
    pub fn sync_batch<F>(f: F) -> Self
    where
        F: for<'a> Fn(&'a [Box<dyn Value>], Arc<TaskContext>) -> Result<Arc<dyn Value>, TaskError>
            + Send
            + Sync
            + 'static,
    {
        Task::SyncBatch(Arc::new(f))
    }

    /// Wraps a batch function that returns a boxed future.
    pub fn async_batch<F>(f: F) -> Self
    where
        F: for<'a> Fn(
                &'a [Box<dyn Value>],
                Arc<TaskContext>,
            ) -> BoxFuture<'static, Result<Arc<dyn Value>, TaskError>>
            + Send
            + Sync
            + 'static,
    {
        Task::AsyncBatch(Arc::new(f))
    }

    /// Wraps a batch function that yields an iterator of outputs.
    pub fn sync_iter_batch<F>(f: F) -> Self
    where
        F: for<'a> Fn(&'a [Box<dyn Value>], Arc<TaskContext>) -> Result<ValueIter, TaskError>
            + Send
            + Sync
            + 'static,
    {
        Task::SyncIterBatch(Arc::new(f))
    }

    /// Wraps a batch function that yields a stream of outputs.
    pub fn async_stream_batch<F>(f: F) -> Self
    where
        F: for<'a> Fn(&'a [Box<dyn Value>], Arc<TaskContext>) -> Result<ValueStream, TaskError>
            + Send
            + Sync
            + 'static,
    {
        Task::AsyncStreamBatch(Arc::new(f))
    }

    /// Typed form of [`Task::sync`]: the input is downcast to `&I` and the
    /// returned `O` is moved into an `Arc<dyn Value>`.
    ///
    /// # Panics
    /// The resulting task panics when invoked with an input that is not an `I`.
    pub fn sync_typed<I, O, F>(f: F) -> Self
    where
        I: Value,
        O: Value,
        F: Fn(&I, Arc<TaskContext>) -> Result<Box<O>, TaskError> + Send + Sync + 'static,
    {
        Task::Sync(Arc::new(move |input: Arc<dyn Value>, ctx| {
            let typed = Self::borrow_input::<I>(&input);
            f(typed, ctx).map(|v| Arc::new(*v) as Arc<dyn Value>)
        }))
    }

    /// Typed form of [`Task::async_fn`].
    ///
    /// # Panics
    /// The resulting task panics when invoked with an input that is not an `I`.
    pub fn async_fn_typed<I, O, F>(f: F) -> Self
    where
        I: Value,
        O: Value,
        F: Fn(&I, Arc<TaskContext>) -> BoxFuture<'static, Result<Box<O>, TaskError>>
            + Send
            + Sync
            + 'static,
    {
        Task::Async(Arc::new(move |input: Arc<dyn Value>, ctx| {
            let typed = Self::borrow_input::<I>(&input);
            let fut = f(typed, ctx);
            Box::pin(async move { fut.await.map(|v| Arc::new(*v) as Arc<dyn Value>) })
        }))
    }

    /// Typed form of [`Task::sync_iter`].
    ///
    /// # Panics
    /// The resulting task panics when invoked with an input that is not an `I`.
    pub fn sync_iter_typed<I, O, F, Iter>(f: F) -> Self
    where
        I: Value,
        O: Value,
        F: Fn(&I, Arc<TaskContext>) -> Result<Iter, TaskError> + Send + Sync + 'static,
        Iter: Iterator<Item = Box<O>> + Send + 'static,
    {
        Task::SyncIter(Arc::new(move |input: Arc<dyn Value>, ctx| {
            let typed = Self::borrow_input::<I>(&input);
            f(typed, ctx).map(|iter| Box::new(iter.map(|v| v as Box<dyn Value>)) as ValueIter)
        }))
    }

    /// Typed form of [`Task::async_stream`].
    ///
    /// # Panics
    /// The resulting task panics when invoked with an input that is not an `I`.
    pub fn async_stream_typed<I, O, F, S>(f: F) -> Self
    where
        I: Value,
        O: Value,
        F: Fn(&I, Arc<TaskContext>) -> Result<S, TaskError> + Send + Sync + 'static,
        S: Stream<Item = Box<O>> + Send + 'static,
    {
        Task::AsyncStream(Arc::new(move |input: Arc<dyn Value>, ctx| {
            let typed = Self::borrow_input::<I>(&input);
            f(typed, ctx).map(|s| Box::pin(s.map(|v| v as Box<dyn Value>)) as ValueStream)
        }))
    }

    /// Typed form of [`Task::sync_batch`]: every item is downcast to `&I`.
    ///
    /// # Panics
    /// The resulting task panics when any batch item is not an `I`.
    pub fn sync_batch_typed<I, O, F>(f: F) -> Self
    where
        I: Value,
        O: Value,
        F: for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> Result<Box<O>, TaskError>
            + Send
            + Sync
            + 'static,
    {
        Task::SyncBatch(Arc::new(move |items: &[Box<dyn Value>], ctx| {
            let typed = Self::borrow_items::<I>(items);
            f(&typed, ctx).map(|v| Arc::new(*v) as Arc<dyn Value>)
        }))
    }

    /// Typed form of [`Task::async_batch`].
    ///
    /// # Panics
    /// The resulting task panics when any batch item is not an `I`.
    pub fn async_batch_typed<I, O, F>(f: F) -> Self
    where
        I: Value,
        O: Value,
        F: for<'a> Fn(
                &'a [&'a I],
                Arc<TaskContext>,
            ) -> BoxFuture<'static, Result<Box<O>, TaskError>>
            + Send
            + Sync
            + 'static,
    {
        Task::AsyncBatch(Arc::new(move |items: &[Box<dyn Value>], ctx| {
            let typed = Self::borrow_items::<I>(items);
            let fut = f(&typed, ctx);
            Box::pin(async move { fut.await.map(|v| Arc::new(*v) as Arc<dyn Value>) })
        }))
    }

    /// Typed form of [`Task::sync_iter_batch`].
    ///
    /// # Panics
    /// The resulting task panics when any batch item is not an `I`.
    pub fn sync_iter_batch_typed<I, O, F, Iter>(f: F) -> Self
    where
        I: Value,
        O: Value,
        F: for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> Result<Iter, TaskError>
            + Send
            + Sync
            + 'static,
        Iter: Iterator<Item = Box<O>> + Send + 'static,
    {
        Task::SyncIterBatch(Arc::new(move |items: &[Box<dyn Value>], ctx| {
            let typed = Self::borrow_items::<I>(items);
            f(&typed, ctx).map(|iter| Box::new(iter.map(|v| v as Box<dyn Value>)) as ValueIter)
        }))
    }

    /// Typed form of [`Task::async_stream_batch`].
    ///
    /// # Panics
    /// The resulting task panics when any batch item is not an `I`.
    pub fn async_stream_batch_typed<I, O, F, S>(f: F) -> Self
    where
        I: Value,
        O: Value,
        F: for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> Result<S, TaskError>
            + Send
            + Sync
            + 'static,
        S: Stream<Item = Box<O>> + Send + 'static,
    {
        Task::AsyncStreamBatch(Arc::new(move |items: &[Box<dyn Value>], ctx| {
            let typed = Self::borrow_items::<I>(items);
            f(&typed, ctx).map(|s| Box::pin(s.map(|v| v as Box<dyn Value>)) as ValueStream)
        }))
    }

    /// Downcasts a single-task input to `&I`, panicking on a type mismatch.
    fn borrow_input<I: Value>(input: &Arc<dyn Value>) -> &I {
        let type_name = std::any::type_name::<I>();
        // `(**input)` reaches the `dyn Value` inside the Arc; calling `as_any` on
        // the Arc itself would resolve to the blanket impl for `Arc<dyn Value>`.
        (**input)
            .as_any()
            .downcast_ref::<I>()
            .unwrap_or_else(|| panic!("Task input type mismatch: expected {type_name}"))
    }

    /// Downcasts one batch item (already dereferenced to `dyn Value`) to `&I`,
    /// panicking on a type mismatch.
    fn borrow_item<I: Value>(item: &dyn Value) -> &I {
        let type_name = std::any::type_name::<I>();
        item.as_any()
            .downcast_ref::<I>()
            .unwrap_or_else(|| panic!("Batch item type mismatch: expected {type_name}"))
    }

    /// Downcasts every item of a batch to `&I`, panicking on the first mismatch.
    fn borrow_items<I: Value>(items: &[Box<dyn Value>]) -> Vec<&I> {
        items
            .iter()
            // `&**item` is deliberate. Passing `item` (a `&Box<dyn Value>`) straight
            // to a `&dyn Value` parameter unsize-coerces the *Box itself* into the
            // trait object (the blanket `Value` impl covers `Box<dyn Value>`), so
            // the downcast would always see `Box<dyn Value>` and panic.
            .map(|item| Self::borrow_item::<I>(&**item))
            .collect()
    }

    /// Invokes a single-input task on `input`.
    ///
    /// `Sync` callables run immediately; for the other variants the callable's
    /// returned future/iterator/stream is handed back without being driven.
    ///
    /// # Panics
    /// Panics on any `*Batch` variant — use [`Task::call_batch`] for those.
    pub fn call(&self, input: Arc<dyn Value>, ctx: Arc<TaskContext>) -> TaskCall {
        match self {
            Task::Sync(f) => TaskCall::Sync(f(input, ctx)),
            Task::Async(f) => TaskCall::Async(f(input, ctx)),
            Task::SyncIter(f) => TaskCall::SyncIter(f(input, ctx)),
            Task::AsyncStream(f) => TaskCall::AsyncStream(f(input, ctx)),
            Task::SyncBatch(_)
            | Task::AsyncBatch(_)
            | Task::SyncIterBatch(_)
            | Task::AsyncStreamBatch(_) => {
                panic!("call() used on a batch task variant — use call_batch() instead")
            }
        }
    }

    /// Combines `tasks` into one async task that feeds the same input to every
    /// member and waits for all of them.
    ///
    /// * An empty list yields a task that returns its input unchanged.
    /// * `Sync` members run inline, one after another, when the combined future
    ///   is first polled; `Async` members are awaited together via `join_all`.
    /// * On success the output of the **last** member is returned; the other
    ///   outputs are discarded.
    /// * Errors surface only after every member has finished; the first error in
    ///   `tasks` order is returned.
    /// * Iterator/stream members are invoked, but their output is dropped and
    ///   replaced by an error. Batch members are not invoked at all and yield an
    ///   error (calling them through [`Task::call`] would panic).
    pub fn parallel(tasks: Vec<Task>) -> Self {
        let tasks = Arc::new(tasks);
        Task::Async(Arc::new(move |input, ctx| {
            let tasks = Arc::clone(&tasks);
            Box::pin(async move {
                if tasks.is_empty() {
                    return Ok(input);
                }
                let futs: Vec<_> = tasks
                    .iter()
                    .map(|t| {
                        // Reject batch members up front instead of letting `call()`
                        // panic inside the combined future.
                        let call = if t.is_batch() {
                            None
                        } else {
                            Some(t.call(Arc::clone(&input), Arc::clone(&ctx)))
                        };
                        async move {
                            match call {
                                Some(TaskCall::Sync(result)) => result,
                                Some(TaskCall::Async(fut)) => fut.await,
                                Some(TaskCall::SyncIter(_) | TaskCall::AsyncStream(_)) => {
                                    Err("iter/stream tasks are not supported inside Task::parallel"
                                        .into())
                                }
                                None => {
                                    Err("batch tasks are not supported inside Task::parallel"
                                        .into())
                                }
                            }
                        }
                    })
                    .collect();
                let results = futures::future::join_all(futs).await;
                let mut last_ok: Option<Arc<dyn Value>> = None;
                for r in results {
                    match r {
                        Err(e) => return Err(e),
                        Ok(v) => last_ok = Some(v),
                    }
                }
                Ok(last_ok.expect("non-empty tasks guaranteed above"))
            })
        }))
    }

    /// Invokes a batch task on `items`.
    ///
    /// `SyncBatch` callables run immediately; for the other batch variants the
    /// callable's returned future/iterator/stream is handed back undriven.
    ///
    /// # Panics
    /// Panics on any single-input variant — use [`Task::call`] for those.
    pub fn call_batch(&self, items: &[Box<dyn Value>], ctx: Arc<TaskContext>) -> TaskCall {
        match self {
            Task::SyncBatch(f) => TaskCall::Sync(f(items, ctx)),
            Task::AsyncBatch(f) => TaskCall::Async(f(items, ctx)),
            Task::SyncIterBatch(f) => TaskCall::SyncIter(f(items, ctx)),
            Task::AsyncStreamBatch(f) => TaskCall::AsyncStream(f(items, ctx)),
            Task::Sync(_) | Task::Async(_) | Task::SyncIter(_) | Task::AsyncStream(_) => {
                panic!("call_batch() used on a single-value task variant — use call() instead")
            }
        }
    }
}
/// A [`Task`] together with the metadata the pipeline attaches to it.
pub struct TaskInfo {
    /// The callable to run.
    pub task: Task,
    /// Optional display name; [`TaskInfo::parallel`] substitutes `task_{index}`
    /// for members without one.
    pub name: Option<String>,
    /// Optional batch size. [`TaskInfo::with_batch_size`] rejects 0, but the
    /// field itself is public and unchecked.
    pub batch_size: Option<usize>,
    /// Optional summary template. NOTE(review): its format is interpreted by the
    /// consumer of `TaskInfo`, not in this module — confirm.
    pub summary_template: Option<String>,
    /// Weight, 1 by default. NOTE(review): meaning is defined by the consumer.
    pub weight: u32,
    /// Flag set by [`TaskInfo::with_enriches`]; `false` by default.
    pub enriches: bool,
    /// Optional rate limiter attached to this task.
    pub rate_limiter: Option<Arc<dyn RateLimiter>>,
}

impl TaskInfo {
    /// Wraps `task` with default metadata: no name, no batch size, no summary
    /// template, weight 1, `enriches == false`, no rate limiter.
    pub fn new(task: Task) -> Self {
        Self {
            task,
            name: None,
            batch_size: None,
            summary_template: None,
            weight: 1,
            enriches: false,
            rate_limiter: None,
        }
    }

    /// Sets the display name.
    #[must_use]
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Sets the batch size.
    ///
    /// # Panics
    /// Panics if `size` is 0.
    #[must_use]
    pub fn with_batch_size(mut self, size: usize) -> Self {
        assert!(size > 0, "batch_size must be > 0");
        self.batch_size = Some(size);
        self
    }

    /// Sets the summary template.
    #[must_use]
    pub fn with_summary(mut self, template: impl Into<String>) -> Self {
        self.summary_template = Some(template.into());
        self
    }

    /// Overrides the default weight of 1.
    #[must_use]
    pub fn with_weight(mut self, weight: u32) -> Self {
        self.weight = weight;
        self
    }

    /// Sets `enriches` to `true`.
    #[must_use]
    pub fn with_enriches(mut self) -> Self {
        self.enriches = true;
        self
    }

    /// Attaches a rate limiter.
    #[must_use]
    pub fn with_rate_limiter(mut self, rl: Arc<dyn RateLimiter>) -> Self {
        self.rate_limiter = Some(rl);
        self
    }

    /// Combines several tasks via [`Task::parallel`] and names the result
    /// `parallel([name_0, name_1, ...])`, where unnamed members appear as
    /// `task_{index}`.
    ///
    /// Only each member's `task` and `name` are used; the result gets the
    /// defaults of [`TaskInfo::new`]. NOTE(review): members' batch sizes,
    /// weights and rate limiters are therefore not carried over.
    pub fn parallel(infos: Vec<TaskInfo>) -> Self {
        let (names, tasks): (Vec<String>, Vec<Task>) = infos
            .into_iter()
            .enumerate()
            .map(|(i, info)| {
                let name = info.name.unwrap_or_else(|| format!("task_{i}"));
                (name, info.task)
            })
            .unzip();
        TaskInfo::new(Task::parallel(tasks))
            .with_name(format!("parallel([{}])", names.join(", ")))
    }
}

impl From<Task> for TaskInfo {
    /// Equivalent to [`TaskInfo::new`].
    fn from(task: Task) -> Self {
        TaskInfo::new(task)
    }
}
type TypedSyncFn<I, O> = dyn Fn(&I, Arc<TaskContext>) -> Result<Box<O>, TaskError> + Send + Sync;
type TypedAsyncFn<I, O> =
dyn Fn(&I, Arc<TaskContext>) -> BoxFuture<'static, Result<Box<O>, TaskError>> + Send + Sync;
type TypedSyncIterFn<I, O> = dyn Fn(&I, Arc<TaskContext>) -> Result<Box<dyn Iterator<Item = Box<O>> + Send + 'static>, TaskError>
+ Send
+ Sync;
type TypedAsyncStreamFn<I, O> =
dyn Fn(&I, Arc<TaskContext>) -> Result<BoxStream<'static, Box<O>>, TaskError> + Send + Sync;
type TypedSyncBatchFn<I, O> =
dyn for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> Result<Box<O>, TaskError> + Send + Sync;
type TypedAsyncBatchFn<I, O> = dyn for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> BoxFuture<'static, Result<Box<O>, TaskError>>
+ Send
+ Sync;
type TypedSyncIterBatchFn<I, O> = dyn for<'a> Fn(
&'a [&'a I],
Arc<TaskContext>,
) -> Result<Box<dyn Iterator<Item = Box<O>> + Send + 'static>, TaskError>
+ Send
+ Sync;
type TypedAsyncStreamBatchFn<I, O> = dyn for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> Result<BoxStream<'static, Box<O>>, TaskError>
+ Send
+ Sync;
/// Statically typed counterpart of [`Task`]: callables receive `&I` (or a slice
/// of `&I`) and produce boxed `O` values.
///
/// Convert into a [`Task`] or [`TaskInfo`] with `From`/`Into`; the conversion
/// inserts the runtime downcast of the inputs.
pub enum TypedTask<I: Value, O: Value> {
    /// Single input; output computed synchronously.
    Sync(Arc<TypedSyncFn<I, O>>),
    /// Single input; output delivered by a future.
    Async(Arc<TypedAsyncFn<I, O>>),
    /// Single input; outputs pulled from an iterator.
    SyncIter(Arc<TypedSyncIterFn<I, O>>),
    /// Single input; outputs polled from a stream.
    AsyncStream(Arc<TypedAsyncStreamFn<I, O>>),
    /// Batch input; output computed synchronously.
    SyncBatch(Arc<TypedSyncBatchFn<I, O>>),
    /// Batch input; output delivered by a future.
    AsyncBatch(Arc<TypedAsyncBatchFn<I, O>>),
    /// Batch input; outputs pulled from an iterator.
    SyncIterBatch(Arc<TypedSyncIterBatchFn<I, O>>),
    /// Batch input; outputs polled from a stream.
    AsyncStreamBatch(Arc<TypedAsyncStreamBatchFn<I, O>>),
}

impl<I: Value, O: Value> TypedTask<I, O> {
    /// Builds a [`TypedTask::Sync`].
    pub fn sync<F>(f: F) -> Self
    where
        F: Fn(&I, Arc<TaskContext>) -> Result<Box<O>, TaskError> + Send + Sync + 'static,
    {
        Self::Sync(Arc::new(f))
    }

    /// Builds a [`TypedTask::Async`].
    pub fn async_fn<F>(f: F) -> Self
    where
        F: Fn(&I, Arc<TaskContext>) -> BoxFuture<'static, Result<Box<O>, TaskError>>
            + Send
            + Sync
            + 'static,
    {
        Self::Async(Arc::new(f))
    }

    /// Builds a [`TypedTask::SyncIter`], boxing whatever iterator `f` returns.
    pub fn sync_iter<F, Iter>(f: F) -> Self
    where
        F: Fn(&I, Arc<TaskContext>) -> Result<Iter, TaskError> + Send + Sync + 'static,
        Iter: Iterator<Item = Box<O>> + Send + 'static,
    {
        Self::SyncIter(Arc::new(move |input, ctx| {
            f(input, ctx).map(|items| -> Box<dyn Iterator<Item = Box<O>> + Send + 'static> {
                Box::new(items)
            })
        }))
    }

    /// Builds a [`TypedTask::AsyncStream`], pinning whatever stream `f` returns.
    pub fn async_stream<F, S>(f: F) -> Self
    where
        F: Fn(&I, Arc<TaskContext>) -> Result<S, TaskError> + Send + Sync + 'static,
        S: Stream<Item = Box<O>> + Send + 'static,
    {
        Self::AsyncStream(Arc::new(move |input, ctx| {
            f(input, ctx).map(|stream| -> BoxStream<'static, Box<O>> { Box::pin(stream) })
        }))
    }

    /// Builds a [`TypedTask::SyncBatch`].
    pub fn sync_batch<F>(f: F) -> Self
    where
        F: for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> Result<Box<O>, TaskError>
            + Send
            + Sync
            + 'static,
    {
        Self::SyncBatch(Arc::new(f))
    }

    /// Builds a [`TypedTask::AsyncBatch`].
    pub fn async_batch<F>(f: F) -> Self
    where
        F: for<'a> Fn(
                &'a [&'a I],
                Arc<TaskContext>,
            ) -> BoxFuture<'static, Result<Box<O>, TaskError>>
            + Send
            + Sync
            + 'static,
    {
        Self::AsyncBatch(Arc::new(f))
    }

    /// Builds a [`TypedTask::SyncIterBatch`], boxing whatever iterator `f` returns.
    pub fn sync_iter_batch<F, Iter>(f: F) -> Self
    where
        F: for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> Result<Iter, TaskError>
            + Send
            + Sync
            + 'static,
        Iter: Iterator<Item = Box<O>> + Send + 'static,
    {
        Self::SyncIterBatch(Arc::new(move |batch, ctx| {
            f(batch, ctx).map(|items| -> Box<dyn Iterator<Item = Box<O>> + Send + 'static> {
                Box::new(items)
            })
        }))
    }

    /// Builds a [`TypedTask::AsyncStreamBatch`], pinning whatever stream `f` returns.
    pub fn async_stream_batch<F, S>(f: F) -> Self
    where
        F: for<'a> Fn(&'a [&'a I], Arc<TaskContext>) -> Result<S, TaskError>
            + Send
            + Sync
            + 'static,
        S: Stream<Item = Box<O>> + Send + 'static,
    {
        Self::AsyncStreamBatch(Arc::new(move |batch, ctx| {
            f(batch, ctx).map(|stream| -> BoxStream<'static, Box<O>> { Box::pin(stream) })
        }))
    }
}
impl<I: Value, O: Value> From<TypedTask<I, O>> for Task {
    /// Type-erases `typed` by routing each variant through the matching
    /// `Task::*_typed` constructor, which adds the runtime input downcast.
    fn from(typed: TypedTask<I, O>) -> Self {
        match typed {
            TypedTask::Sync(inner) => Task::sync_typed(move |input: &I, ctx| inner(input, ctx)),
            TypedTask::Async(inner) => {
                Task::async_fn_typed(move |input: &I, ctx| inner(input, ctx))
            }
            TypedTask::SyncIter(inner) => {
                Task::sync_iter_typed(move |input: &I, ctx| inner(input, ctx))
            }
            TypedTask::AsyncStream(inner) => {
                Task::async_stream_typed(move |input: &I, ctx| inner(input, ctx))
            }
            TypedTask::SyncBatch(inner) => {
                Task::sync_batch_typed(move |batch: &[&I], ctx| inner(batch, ctx))
            }
            TypedTask::AsyncBatch(inner) => {
                Task::async_batch_typed(move |batch: &[&I], ctx| inner(batch, ctx))
            }
            TypedTask::SyncIterBatch(inner) => {
                Task::sync_iter_batch_typed(move |batch: &[&I], ctx| inner(batch, ctx))
            }
            TypedTask::AsyncStreamBatch(inner) => {
                Task::async_stream_batch_typed(move |batch: &[&I], ctx| inner(batch, ctx))
            }
        }
    }
}

impl<I: Value, O: Value> From<TypedTask<I, O>> for TaskInfo {
    /// Type-erases `t` into a [`Task`] and wraps it with default metadata.
    fn from(t: TypedTask<I, O>) -> TaskInfo {
        let erased: Task = t.into();
        TaskInfo::new(erased)
    }
}
/// Outcome of [`Task::call`] / [`Task::call_batch`], tagged by how the output is
/// delivered.
pub enum TaskCall {
    /// The callable already ran; this is its result.
    Sync(Result<Arc<dyn Value>, TaskError>),
    /// Await the future to obtain the result.
    Async(BoxFuture<'static, Result<Arc<dyn Value>, TaskError>>),
    /// Pull outputs from the iterator, or handle the up-front error.
    SyncIter(Result<ValueIter, TaskError>),
    /// Poll outputs from the stream, or handle the up-front error.
    AsyncStream(Result<ValueStream, TaskError>),
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;

    use crate::cancellation::cancellation_pair;
    use crate::exec_status::NoopExecStatusManager;
    use crate::progress::ProgressToken;
    use crate::task_context::TaskContext;
    use crate::thread_pool::CpuPool;

    /// Stub pool: accepts submitted closures and drops them without running them.
    struct StubPool;

    impl CpuPool for StubPool {
        fn spawn_raw(
            &self,
            _task: Box<dyn FnOnce() + Send + 'static>,
        ) -> Pin<Box<dyn Future<Output = Result<(), crate::error::CoreError>> + Send + 'static>>
        {
            Box::pin(async { Ok(()) })
        }
    }

    /// Builds a `TaskContext` backed by an in-memory SQLite database and mock
    /// graph/vector stores.
    async fn stub_ctx() -> Arc<TaskContext> {
        let db = cognee_database::connect("sqlite::memory:").await.unwrap();
        cognee_database::initialize(&db).await.unwrap();
        let (_handle, token) = cancellation_pair();
        Arc::new(TaskContext {
            thread_pool: Arc::new(StubPool),
            database: Arc::new(db),
            graph_db: Arc::new(cognee_graph::MockGraphDB::new()),
            vector_db: Arc::new(cognee_vector::MockVectorDB::new()),
            cancellation: token,
            progress: ProgressToken::new(),
            pipeline_ctx: None,
            exec_status: Arc::new(NoopExecStatusManager),
            pipeline_watcher: None,
        })
    }

    /// Sync members run inline; the combined task returns the last member's output.
    #[tokio::test]
    async fn parallel_returns_last_sync_task_output() {
        let double = Task::sync_typed(|x: &i32, _ctx| Ok(Box::new(*x * 2)));
        let triple = Task::sync_typed(|x: &i32, _ctx| Ok(Box::new(*x * 3)));
        let par = Task::parallel(vec![double, triple]);
        let input: Arc<dyn Value> = Arc::new(5_i32);
        let ctx = stub_ctx().await;
        let call = par.call(input, ctx);
        let result = match call {
            TaskCall::Async(fut) => fut.await.unwrap(),
            _ => panic!("parallel should produce Async variant"),
        };
        assert_eq!(*(*result).as_any().downcast_ref::<i32>().unwrap(), 15);
    }

    #[tokio::test]
    async fn parallel_runs_async_tasks() {
        let add_ten = Task::async_fn_typed(|x: &i32, _ctx| {
            let v = *x + 10;
            Box::pin(async move { Ok(Box::new(v)) })
        });
        let add_twenty = Task::async_fn_typed(|x: &i32, _ctx| {
            let v = *x + 20;
            Box::pin(async move { Ok(Box::new(v)) })
        });
        let par = Task::parallel(vec![add_ten, add_twenty]);
        let input: Arc<dyn Value> = Arc::new(100_i32);
        let ctx = stub_ctx().await;
        let result = match par.call(input, ctx) {
            TaskCall::Async(fut) => fut.await.unwrap(),
            _ => panic!("expected Async"),
        };
        assert_eq!(*(*result).as_any().downcast_ref::<i32>().unwrap(), 120);
    }

    #[tokio::test]
    async fn parallel_propagates_first_error() {
        let ok_task = Task::sync_typed(|x: &i32, _ctx| Ok(Box::new(*x)));
        let err_task = Task::Sync(Arc::new(|_input, _ctx| Err("boom".into())));
        let par = Task::parallel(vec![ok_task, err_task]);
        let input: Arc<dyn Value> = Arc::new(42_i32);
        let ctx = stub_ctx().await;
        let result = match par.call(input, ctx) {
            TaskCall::Async(fut) => fut.await,
            _ => panic!("expected Async"),
        };
        let err = result.err().expect("should be an error");
        assert!(err.to_string().contains("boom"));
    }

    #[tokio::test]
    async fn parallel_empty_returns_input() {
        let par = Task::parallel(vec![]);
        let input: Arc<dyn Value> = Arc::new(99_i32);
        let ctx = stub_ctx().await;
        let result = match par.call(Arc::clone(&input), ctx) {
            TaskCall::Async(fut) => fut.await.unwrap(),
            _ => panic!("expected Async"),
        };
        assert_eq!(*(*result).as_any().downcast_ref::<i32>().unwrap(), 99);
    }

    #[tokio::test]
    async fn test_typed_task_panics_on_type_mismatch() {
        use std::panic::{AssertUnwindSafe, catch_unwind};
        let task = Task::sync_typed(|_x: &String, _ctx| Ok(Box::new("ok".to_string())));
        let input: Arc<dyn Value> = Arc::new(42_i32);
        let ctx = stub_ctx().await;
        let result = catch_unwind(AssertUnwindSafe(|| task.call(input, ctx)));
        let err = match result {
            Err(e) => e,
            Ok(_) => panic!("should have panicked on type mismatch"),
        };
        let msg = err
            .downcast_ref::<String>()
            .map(|s| s.as_str())
            .or_else(|| err.downcast_ref::<&str>().copied())
            .expect("panic payload should be a string");
        assert!(
            msg.contains("type mismatch"),
            "expected 'type mismatch' in panic message, got: {msg}"
        );
    }

    #[test]
    fn test_taskinfo_weight_default() {
        let info = TaskInfo::new(Task::sync_typed(|_: &i32, _| Ok(Box::new(0_i32))));
        assert_eq!(info.weight, 1);
    }

    #[test]
    fn test_taskinfo_with_weight() {
        let info = TaskInfo::new(Task::sync_typed(|_: &i32, _| Ok(Box::new(0_i32)))).with_weight(5);
        assert_eq!(info.weight, 5);
    }

    #[test]
    #[should_panic(expected = "batch_size must be > 0")]
    fn taskinfo_rejects_zero_batch_size() {
        let _ = TaskInfo::new(Task::sync_typed(|_: &i32, _| Ok(Box::new(0_i32))))
            .with_batch_size(0);
    }

    #[test]
    fn is_batch_distinguishes_batch_variants() {
        assert!(!Task::sync_typed(|_: &i32, _| Ok(Box::new(0_i32))).is_batch());
        assert!(Task::sync_batch_typed(|_: &[&i32], _| Ok(Box::new(0_i32))).is_batch());
    }

    #[test]
    fn task_info_parallel_generates_name() {
        let t1 =
            TaskInfo::new(Task::sync_typed(|_: &i32, _| Ok(Box::new(0_i32)))).with_name("classify");
        let t2 =
            TaskInfo::new(Task::sync_typed(|_: &i32, _| Ok(Box::new(0_i32)))).with_name("embed");
        let t3 = TaskInfo::new(Task::sync_typed(|_: &i32, _| Ok(Box::new(0_i32))));
        let par = TaskInfo::parallel(vec![t1, t2, t3]);
        assert_eq!(
            par.name.as_deref(),
            Some("parallel([classify, embed, task_2])")
        );
    }

    mod python_task_type {
        use super::*;
        use futures::stream;

        #[test]
        fn sync_variant_maps_to_function() {
            let t = Task::sync_typed(|_: &i32, _| Ok(Box::new(0_i32)));
            assert_eq!(t.python_task_type(), "Function");
        }

        #[test]
        fn sync_batch_variant_maps_to_function() {
            let t = Task::sync_batch_typed(|_: &[&i32], _| Ok(Box::new(0_i32)));
            assert_eq!(t.python_task_type(), "Function");
        }

        #[test]
        fn async_variant_maps_to_coroutine() {
            let t = Task::async_fn_typed(|_: &i32, _| Box::pin(async move { Ok(Box::new(0_i32)) }));
            assert_eq!(t.python_task_type(), "Coroutine");
        }

        #[test]
        fn async_batch_variant_maps_to_coroutine() {
            let t = Task::async_batch_typed(|_: &[&i32], _| {
                Box::pin(async move { Ok(Box::new(0_i32)) })
            });
            assert_eq!(t.python_task_type(), "Coroutine");
        }

        #[test]
        fn sync_iter_variant_maps_to_generator() {
            let t = Task::sync_iter_typed(|_: &i32, _| Ok(std::iter::empty::<Box<i32>>()));
            assert_eq!(t.python_task_type(), "Generator");
        }

        #[test]
        fn sync_iter_batch_variant_maps_to_generator() {
            let t = Task::sync_iter_batch_typed(|_: &[&i32], _| Ok(std::iter::empty::<Box<i32>>()));
            assert_eq!(t.python_task_type(), "Generator");
        }

        #[test]
        fn async_stream_variant_maps_to_async_generator() {
            let t = Task::async_stream_typed(|_: &i32, _| Ok(stream::empty::<Box<i32>>()));
            assert_eq!(t.python_task_type(), "Async Generator");
        }

        #[test]
        fn async_stream_batch_variant_maps_to_async_generator() {
            let t = Task::async_stream_batch_typed(|_: &[&i32], _| Ok(stream::empty::<Box<i32>>()));
            assert_eq!(t.python_task_type(), "Async Generator");
        }

        #[test]
        fn covers_all_eight_variants_with_four_distinct_labels() {
            let labels: std::collections::HashSet<&'static str> = [
                Task::sync_typed(|_: &i32, _| Ok(Box::new(0_i32))).python_task_type(),
                Task::sync_batch_typed(|_: &[&i32], _| Ok(Box::new(0_i32))).python_task_type(),
                Task::async_fn_typed(|_: &i32, _| Box::pin(async move { Ok(Box::new(0_i32)) }))
                    .python_task_type(),
                Task::async_batch_typed(|_: &[&i32], _| {
                    Box::pin(async move { Ok(Box::new(0_i32)) })
                })
                .python_task_type(),
                Task::sync_iter_typed(|_: &i32, _| Ok(std::iter::empty::<Box<i32>>()))
                    .python_task_type(),
                Task::sync_iter_batch_typed(|_: &[&i32], _| Ok(std::iter::empty::<Box<i32>>()))
                    .python_task_type(),
                Task::async_stream_typed(|_: &i32, _| Ok(stream::empty::<Box<i32>>()))
                    .python_task_type(),
                Task::async_stream_batch_typed(|_: &[&i32], _| Ok(stream::empty::<Box<i32>>()))
                    .python_task_type(),
            ]
            .into_iter()
            .collect();
            assert_eq!(
                labels.len(),
                4,
                "expected exactly 4 distinct Python task-type labels, got {labels:?}"
            );
            assert!(labels.contains("Function"));
            assert!(labels.contains("Coroutine"));
            assert!(labels.contains("Generator"));
            assert!(labels.contains("Async Generator"));
        }
    }
}