#![allow(
clippy::unwrap_used,
clippy::expect_used,
reason = "test code — panics are acceptable failures"
)]
use std::future::Future;
use std::num::NonZeroU32;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use async_trait::async_trait;
use cognee_core::cancellation::cancellation_pair;
use cognee_core::rate_limiter::RateLimiter;
use cognee_core::{
CpuPool, NoopExecStatusManager, NoopWatcher, Pipeline, ProgressToken, RetryDelay, RetryPolicy,
Task, TaskContext, TaskError, TaskInfo, Value, execute,
};
/// Minimal [`CpuPool`] stand-in used by [`stub_ctx`].
///
/// `spawn_raw` discards the submitted closure without invoking it and hands back a
/// future that resolves to `Ok(())` straight away.
/// NOTE(review): this assumes the pipelines exercised in this file never depend on
/// work submitted through the CPU pool actually executing — confirm if a new test
/// starts routing tasks through `spawn_raw`.
struct StubPool;

impl CpuPool for StubPool {
    fn spawn_raw(
        &self,
        _task: Box<dyn FnOnce() + Send + 'static>,
    ) -> Pin<Box<dyn Future<Output = Result<(), cognee_core::CoreError>> + Send + 'static>> {
        let already_done = async { Ok(()) };
        Box::pin(already_done)
    }
}
/// Builds a [`TaskContext`] for driving `execute` in these tests.
///
/// The context has the following parts:
/// - an in-memory SQLite database with its schema initialized;
/// - the mock graph and vector stores;
/// - a [`StubPool`] as the thread pool;
/// - a fresh progress token and no-op execution-status tracking.
///
/// No pipeline context or watcher is attached.
///
/// # Panics
///
/// Panics if the in-memory database cannot be opened or initialized.
async fn stub_ctx() -> Arc<TaskContext> {
    let database = cognee_database::connect("sqlite::memory:")
        .await
        .expect("in-memory SQLite always connects");
    cognee_database::initialize(&database)
        .await
        .expect("in-memory schema init never fails");

    // The handle is bound to a named `_`-prefixed variable, so it survives until this
    // function returns and is dropped there.
    // NOTE(review): this assumes that dropping the handle does not cancel
    // `cancellation` — confirm against `cancellation_pair`.
    let (_handle, cancellation) = cancellation_pair();

    let context = TaskContext {
        thread_pool: Arc::new(StubPool),
        database: Arc::new(database),
        graph_db: Arc::new(cognee_graph::MockGraphDB::new()),
        vector_db: Arc::new(cognee_vector::MockVectorDB::new()),
        cancellation,
        progress: ProgressToken::new(),
        pipeline_ctx: None,
        exec_status: Arc::new(NoopExecStatusManager),
        pipeline_watcher: None,
    };
    Arc::new(context)
}
/// Test double for [`RateLimiter`] that records how many times `acquire` is called.
struct CountingLimiter {
    /// Bumped by one on every `acquire`; the same counter is handed back by `new`.
    count: Arc<AtomicUsize>,
}

impl CountingLimiter {
    /// Creates a limiter plus a shared handle to its call counter.
    ///
    /// The counter starts at zero. The test keeps the returned handle so it can
    /// still read the count after the limiter has been moved into a pipeline.
    fn new() -> (Self, Arc<AtomicUsize>) {
        let counter = Arc::new(AtomicUsize::new(0));
        let limiter = Self {
            count: Arc::clone(&counter),
        };
        (limiter, counter)
    }
}
#[async_trait]
impl RateLimiter for CountingLimiter {
    /// Never waits: it records the call by incrementing the shared counter and
    /// returns immediately.
    async fn acquire(&self) {
        let counter = &self.count;
        counter.fetch_add(1, Ordering::SeqCst);
    }
}
/// A pipeline-wide limiter is acquired exactly once for each input item passing
/// through a single pass-through task.
#[tokio::test]
async fn pipeline_limiter_called_once_per_item() {
    let (limiter, count) = CountingLimiter::new();

    let identity =
        Task::sync_typed(|n: &i32, _ctx| -> Result<Box<i32>, TaskError> { Ok(Box::new(*n)) });
    let pipeline = Pipeline::new("count test")
        .with_task(identity)
        .with_rate_limiter(Arc::new(limiter));

    let inputs = (1..=6_i32)
        .map(|n| Arc::new(n) as Arc<dyn Value>)
        .collect::<Vec<_>>();

    execute(&pipeline, inputs, stub_ctx().await, &NoopWatcher)
        .await
        .expect("pipeline must not fail");

    let acquisitions = count.load(Ordering::SeqCst);
    assert_eq!(
        acquisitions, 6,
        "pipeline limiter must be acquired once per item"
    );
}
/// A limiter attached to one task replaces the pipeline-wide limiter for that task.
/// Tasks without their own limiter keep using the pipeline one.
///
/// Four inputs flow through two tasks:
/// - the per-task limiter should see four acquisitions (the `* 10` task);
/// - the pipeline limiter should see four (the `+ 1` task), not eight — eight
///   would mean the first task was also charged against the pipeline limiter.
#[tokio::test]
async fn per_task_limiter_overrides_pipeline_limiter() {
    let (pipeline_limiter, pipeline_count) = CountingLimiter::new();
    let (task_limiter, task_count) = CountingLimiter::new();

    let times_ten =
        Task::sync_typed(|n: &i32, _ctx| -> Result<Box<i32>, TaskError> { Ok(Box::new(*n * 10)) });
    let limited_times_ten = TaskInfo::new(times_ten).with_rate_limiter(Arc::new(task_limiter));
    let plus_one =
        Task::sync_typed(|n: &i32, _ctx| -> Result<Box<i32>, TaskError> { Ok(Box::new(*n + 1)) });

    let pipeline = Pipeline::new("override test")
        .with_task(limited_times_ten)
        .with_task(plus_one)
        .with_rate_limiter(Arc::new(pipeline_limiter));

    let inputs = (1..=4_i32)
        .map(|n| Arc::new(n) as Arc<dyn Value>)
        .collect::<Vec<_>>();

    execute(&pipeline, inputs, stub_ctx().await, &NoopWatcher)
        .await
        .expect("pipeline must not fail");

    let per_task_hits = task_count.load(Ordering::SeqCst);
    assert_eq!(
        per_task_hits, 4,
        "per-task limiter must be acquired once per item through that task"
    );
    let pipeline_hits = pipeline_count.load(Ordering::SeqCst);
    assert_eq!(
        pipeline_hits, 4,
        "pipeline limiter must be called for tasks without a per-task limiter"
    );
}
/// A limiter is acquired before every execution attempt, including retries.
///
/// The task fails on its first two attempts and succeeds on the third. The retry
/// policy allows at most three attempts, so the pipeline must succeed with the task
/// having run exactly three times. The limiter must have been acquired once per run.
#[tokio::test]
async fn limiter_acquired_per_retry_attempt() {
    let (limiter, count) = CountingLimiter::new();
    // Shared with the task closure so the test can observe how often it really ran.
    let attempt_count = Arc::new(AtomicUsize::new(0));
    let attempt_count_clone = attempt_count.clone();
    let flaky = Task::Sync(Arc::new(move |input: Arc<dyn Value>, _ctx| {
        let attempt = attempt_count_clone.fetch_add(1, Ordering::SeqCst) + 1;
        if attempt < 3 {
            Err(Box::new(std::io::Error::other("transient failure")) as TaskError)
        } else {
            Ok(input)
        }
    }));
    let pipeline = Pipeline::new("retry count test")
        .with_task(flaky)
        .with_retry(RetryPolicy::Limited {
            max_attempts: NonZeroU32::new(3).expect("3 is nonzero"),
            // Zero delay between attempts keeps the test fast.
            delay: RetryDelay::Constant(Duration::ZERO),
        })
        .with_rate_limiter(Arc::new(limiter));
    let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(42_i32)];
    let ctx = stub_ctx().await;
    execute(&pipeline, inputs, ctx, &NoopWatcher)
        .await
        .expect("pipeline must succeed on 3rd attempt");
    // First pin how many attempts actually ran. That way the limiter assertion
    // below is checked against real executions rather than an assumed retry count.
    assert_eq!(
        attempt_count.load(Ordering::SeqCst),
        3,
        "task must run exactly 3 times (2 transient failures + 1 success)"
    );
    assert_eq!(
        count.load(Ordering::SeqCst),
        3,
        "limiter must be acquired once per retry attempt (3 total: 2 failures + 1 success)"
    );
}
/// With no rate limiter configured anywhere, a pipeline still processes every
/// item and yields one output per input.
#[tokio::test]
async fn no_limiter_runs_without_throttle() {
    let identity =
        Task::sync_typed(|n: &i32, _ctx| -> Result<Box<i32>, TaskError> { Ok(Box::new(*n)) });
    let pipeline = Pipeline::new("no limiter").with_task(identity);

    let inputs = (1..=5_i32)
        .map(|n| Arc::new(n) as Arc<dyn Value>)
        .collect::<Vec<_>>();

    let outputs = execute(&pipeline, inputs, stub_ctx().await, &NoopWatcher)
        .await
        .expect("pipeline without limiter must not fail");

    let produced = outputs.len();
    assert_eq!(produced, 5);
}