Skip to main content

cognee_core/
task_context.rs

1// expect() guards a documented invariant: pipeline() must only be called from
2// tasks running inside a pipeline executor.
3#![allow(
4    clippy::expect_used,
5    reason = "caller is documented to only invoke this within a pipeline executor"
6)]
7
8use std::collections::HashSet;
9use std::sync::{Arc, Mutex};
10
11use cognee_database::DatabaseConnection;
12use cognee_graph::GraphDBTrait;
13use cognee_vector::VectorDB;
14use uuid::Uuid;
15
16use crate::{
17    cancellation::{CancellationHandle, CancellationToken, cancellation_pair},
18    error::CoreError,
19    exec_status::{ExecStatusManager, NoopExecStatusManager},
20    pipeline::PipelineWatcher,
21    progress::ProgressToken,
22    task::Value,
23    thread_pool::CpuPool,
24};
/// Identity of the running pipeline and the data item being processed.
///
/// Tasks that need attribution metadata (user, dataset, current data item)
/// read this from [`TaskContext::pipeline_ctx`].
///
/// Cloning is shallow for the `Arc` fields (`current_data`,
/// `provenance_visited`), so every clone of one run's context shares the
/// same provenance set; the `String` fields are copied.
#[derive(Clone)]
pub struct PipelineContext {
    /// Unique ID of this pipeline run (matches `Pipeline::id`).
    pub pipeline_id: Uuid,
    /// Human-readable pipeline name.
    pub pipeline_name: String,
    /// User executing the pipeline (the run's owner). Tenant membership is
    /// tracked separately in `tenant_id`.
    pub user_id: Option<Uuid>,
    /// Tenant the pipeline run belongs to. `None` for single-user
    /// deployments — telemetry emitters substitute the literal
    /// `"Single User Tenant"` to match Python's behaviour.
    pub tenant_id: Option<Uuid>,
    /// Dataset being processed.
    pub dataset_id: Option<Uuid>,
    /// The data item currently being processed.
    /// Set per-item by the executor before calling a task
    /// (see [`TaskContext::with_current_data`]).
    pub current_data: Option<Arc<dyn Value>>,
    /// Random per-invocation run id. Set by [`crate::pipeline::execute`] when
    /// it creates `PipelineRunInfo`. Used by tasks (via
    /// [`TaskContext::publish_payload_field`]) to attribute payload events.
    /// `None` when the task is not running inside `execute()`.
    pub run_id: Option<Uuid>,
    /// Email of the user running the pipeline, if known. Used by the
    /// provenance-stamping algorithm to populate
    /// `DataPoint.source_user`. Mirrors Python's `user.email`.
    /// Resolution priority is captured by [`PipelineContext::user_label`].
    pub user_email: Option<String>,
    /// DataPoints already stamped during this pipeline run, keyed on
    /// their UUID. Shared across all tasks via the per-run
    /// `PipelineContext` (clones share the same `Arc`), so a DataPoint that
    /// survives multiple tasks is stamped exactly once — with the **first**
    /// task's name. Mirrors Python's `PipelineContext._provenance_visited`.
    pub provenance_visited: Arc<Mutex<HashSet<Uuid>>>,
}
63
64impl PipelineContext {
65    /// Resolved label used as `DataPoint.source_user` by the
66    /// provenance-stamping algorithm.
67    ///
68    /// Priority order (matches Python's `user.email or str(user.id)`,
69    /// locked decision 4):
70    ///
71    /// 1. `user_email` if set.
72    /// 2. Else `user_id.to_string()` if set.
73    /// 3. Else `None` (the DP keeps its own value, or stays unstamped).
74    pub fn user_label(&self) -> Option<String> {
75        self.user_email
76            .clone()
77            .or_else(|| self.user_id.map(|id| id.to_string()))
78    }
79}
/// Runtime dependencies and control tokens for a single pipeline task.
///
/// Build via [`TaskContextBuilder`]. The type is not `Clone`; derived
/// variants are produced by [`TaskContext::with_progress`],
/// [`TaskContext::with_current_data`], [`TaskContext::with_user_email`] and
/// [`TaskContext::with_run_id`], which share the original's `Arc`
/// dependencies.
pub struct TaskContext {
    /// CPU-bound work executor (for example a Rayon-backed pool). Required
    /// by the builder; there is no default.
    pub thread_pool: Arc<dyn CpuPool>,
    /// Relational / metadata database connection.
    pub database: Arc<DatabaseConnection>,
    /// Graph database.
    pub graph_db: Arc<dyn GraphDBTrait>,
    /// Vector database.
    pub vector_db: Arc<dyn VectorDB>,
    /// Token the task checks to detect cancellation requests. The matching
    /// `CancellationHandle` is returned by [`TaskContextBuilder::build`].
    pub cancellation: CancellationToken,
    /// Token the task uses to report progress. The builder falls back to
    /// `ProgressToken::default()` when none is supplied.
    pub progress: ProgressToken,
    /// Pipeline run identity and current data item context. `None` when the
    /// builder was not given one; [`TaskContext::pipeline`] panics in that case.
    pub pipeline_ctx: Option<PipelineContext>,
    /// Per-item incremental status tracker (deduplication / resume).
    /// The builder falls back to `NoopExecStatusManager` when none is supplied.
    pub exec_status: Arc<dyn ExecStatusManager>,
    /// Optional pipeline watcher injected by the registry.
    ///
    /// When set, the pipeline executor routes lifecycle events here in addition
    /// to (or instead of) any watcher passed directly to `execute()`. Set by
    /// `PipelineRunRegistry` so library functions can publish events without
    /// knowing about the registry. [`TaskContext::publish_payload_field`]
    /// sends payload fields through this watcher.
    pub pipeline_watcher: Option<Arc<dyn PipelineWatcher>>,
}
108
109impl TaskContext {
110    /// Convenience accessor for the pipeline context.
111    ///
112    /// Panics if the context was not set — only call this from tasks that are
113    /// known to run inside a pipeline executor.
114    pub fn pipeline(&self) -> &PipelineContext {
115        self.pipeline_ctx
116            .as_ref()
117            .expect("PipelineContext not set — task is not running inside a pipeline executor")
118    }
119
120    /// Create a new `Arc<TaskContext>` with a different progress token.
121    /// All other fields are shallow-cloned.
122    pub fn with_progress(self: &Arc<Self>, progress: ProgressToken) -> Arc<Self> {
123        Arc::new(TaskContext {
124            thread_pool: Arc::clone(&self.thread_pool),
125            database: Arc::clone(&self.database),
126            graph_db: Arc::clone(&self.graph_db),
127            vector_db: Arc::clone(&self.vector_db),
128            cancellation: self.cancellation.clone(),
129            progress,
130            pipeline_ctx: self.pipeline_ctx.clone(),
131            exec_status: Arc::clone(&self.exec_status),
132            pipeline_watcher: self.pipeline_watcher.clone(),
133        })
134    }
135
136    /// Create a new `Arc<TaskContext>` with `current_data` set on the pipeline
137    /// context. All `Arc` fields are shallow-cloned (cheap reference bumps).
138    ///
139    /// Returns the original `Arc` unchanged if no `pipeline_ctx` is present.
140    pub fn with_current_data(self: &Arc<Self>, data: Arc<dyn Value>) -> Arc<Self> {
141        let mut pipeline_ctx = match &self.pipeline_ctx {
142            Some(ctx) => ctx.clone(),
143            None => return Arc::clone(self),
144        };
145        pipeline_ctx.current_data = Some(data);
146        Arc::new(TaskContext {
147            thread_pool: Arc::clone(&self.thread_pool),
148            database: Arc::clone(&self.database),
149            graph_db: Arc::clone(&self.graph_db),
150            vector_db: Arc::clone(&self.vector_db),
151            cancellation: self.cancellation.clone(),
152            progress: self.progress.clone(),
153            pipeline_ctx: Some(pipeline_ctx),
154            exec_status: Arc::clone(&self.exec_status),
155            pipeline_watcher: self.pipeline_watcher.clone(),
156        })
157    }
158
159    /// Create a new `Arc<TaskContext>` with `user_email` set on the pipeline
160    /// context. All `Arc` fields are shallow-cloned.
161    ///
162    /// Returns the original `Arc` unchanged if no `pipeline_ctx` is present.
163    pub fn with_user_email(self: &Arc<Self>, email: String) -> Arc<Self> {
164        let mut pipeline_ctx = match &self.pipeline_ctx {
165            Some(ctx) => ctx.clone(),
166            None => return Arc::clone(self),
167        };
168        pipeline_ctx.user_email = Some(email);
169        Arc::new(TaskContext {
170            thread_pool: Arc::clone(&self.thread_pool),
171            database: Arc::clone(&self.database),
172            graph_db: Arc::clone(&self.graph_db),
173            vector_db: Arc::clone(&self.vector_db),
174            cancellation: self.cancellation.clone(),
175            progress: self.progress.clone(),
176            pipeline_ctx: Some(pipeline_ctx),
177            exec_status: Arc::clone(&self.exec_status),
178            pipeline_watcher: self.pipeline_watcher.clone(),
179        })
180    }
181
182    /// Create a new `Arc<TaskContext>` with `run_id` set on the pipeline
183    /// context. All other fields are shallow-cloned.
184    ///
185    /// Returns the original `Arc` unchanged if no `pipeline_ctx` is present.
186    pub fn with_run_id(self: &Arc<Self>, run_id: Uuid) -> Arc<Self> {
187        let mut pipeline_ctx = match &self.pipeline_ctx {
188            Some(ctx) => ctx.clone(),
189            None => return Arc::clone(self),
190        };
191        pipeline_ctx.run_id = Some(run_id);
192        Arc::new(TaskContext {
193            thread_pool: Arc::clone(&self.thread_pool),
194            database: Arc::clone(&self.database),
195            graph_db: Arc::clone(&self.graph_db),
196            vector_db: Arc::clone(&self.vector_db),
197            cancellation: self.cancellation.clone(),
198            progress: self.progress.clone(),
199            pipeline_ctx: Some(pipeline_ctx),
200            exec_status: Arc::clone(&self.exec_status),
201            pipeline_watcher: self.pipeline_watcher.clone(),
202        })
203    }
204
205    /// Publish a run-scoped payload field. Tasks running inside
206    /// [`crate::pipeline::execute`] call this to attach metadata that downstream
207    /// observers read via the registry's payload accumulator.
208    ///
209    /// Silently no-ops if no `pipeline_watcher` is attached or if
210    /// `pipeline_ctx.run_id` was never set (i.e. the task is not running
211    /// inside `execute()`).
212    pub async fn publish_payload_field(&self, key: &str, value: serde_json::Value) {
213        let Some(w) = self.pipeline_watcher.as_ref() else {
214            return;
215        };
216        let Some(pctx) = self.pipeline_ctx.as_ref() else {
217            return;
218        };
219        let Some(run_id) = pctx.run_id else {
220            return;
221        };
222        w.on_payload_field(run_id, key, value).await;
223    }
224}
/// Fluent builder for [`TaskContext`].
///
/// `thread_pool`, `database`, `graph_db` and `vector_db` are required:
/// [`TaskContextBuilder::build`] fails with `CoreError::MissingContextField`
/// naming the first one that is missing. Every other field has a fallback.
///
/// ```rust,ignore
/// let (handle, ctx) = TaskContextBuilder::new()
///     .thread_pool(Arc::new(RayonThreadPool::with_default_threads()?))
///     .database(db)
///     .graph_db(graph)
///     .vector_db(vectors)
///     .progress(ProgressToken::new())
///     .build()?;
/// ```
#[derive(Default)]
pub struct TaskContextBuilder {
    thread_pool: Option<Arc<dyn CpuPool>>,
    database: Option<Arc<DatabaseConnection>>,
    graph_db: Option<Arc<dyn GraphDBTrait>>,
    vector_db: Option<Arc<dyn VectorDB>>,
    /// Pre-built cancellation pair. When `Some`, `build()` returns this
    /// handle and stores this token in the context instead of creating a
    /// fresh pair via `cancellation_pair()`.
    cancellation: Option<(CancellationHandle, CancellationToken)>,
    progress: Option<ProgressToken>,
    pipeline_ctx: Option<PipelineContext>,
    exec_status: Option<Arc<dyn ExecStatusManager>>,
    pipeline_watcher: Option<Arc<dyn PipelineWatcher>>,
}
249
250impl TaskContextBuilder {
251    pub fn new() -> Self {
252        Self::default()
253    }
254
255    pub fn thread_pool(mut self, pool: Arc<dyn CpuPool>) -> Self {
256        self.thread_pool = Some(pool);
257        self
258    }
259
260    pub fn database(mut self, db: Arc<DatabaseConnection>) -> Self {
261        self.database = Some(db);
262        self
263    }
264
265    pub fn graph_db(mut self, graph: Arc<dyn GraphDBTrait>) -> Self {
266        self.graph_db = Some(graph);
267        self
268    }
269
270    pub fn vector_db(mut self, vectors: Arc<dyn VectorDB>) -> Self {
271        self.vector_db = Some(vectors);
272        self
273    }
274
275    /// Set a pre-built progress token. Defaults to a fresh root token.
276    pub fn progress(mut self, token: ProgressToken) -> Self {
277        self.progress = Some(token);
278        self
279    }
280
281    /// Set pipeline run identity context.
282    pub fn pipeline_context(mut self, ctx: PipelineContext) -> Self {
283        self.pipeline_ctx = Some(ctx);
284        self
285    }
286
287    /// Set the per-item status manager for incremental deduplication.
288    /// Defaults to [`NoopExecStatusManager`] if not set.
289    pub fn exec_status(mut self, mgr: Arc<dyn ExecStatusManager>) -> Self {
290        self.exec_status = Some(mgr);
291        self
292    }
293
294    /// Inject a pipeline watcher into the context.
295    ///
296    /// When set, the registry's `ScopedRunWatcher` is stored here so that
297    /// library functions can publish lifecycle events without needing to know
298    /// about the registry. Defaults to `None` (no watcher).
299    pub fn pipeline_watcher(mut self, w: Arc<dyn PipelineWatcher>) -> Self {
300        self.pipeline_watcher = Some(w);
301        self
302    }
303
304    /// Build the context. Returns `(CancellationHandle, TaskContext)` so the
305    /// caller keeps the handle while the task receives the token.
306    pub fn build(self) -> Result<(CancellationHandle, TaskContext), CoreError> {
307        let thread_pool = self.thread_pool.ok_or(CoreError::MissingContextField {
308            field: "thread_pool",
309        })?;
310        let database = self
311            .database
312            .ok_or(CoreError::MissingContextField { field: "database" })?;
313        let graph_db = self
314            .graph_db
315            .ok_or(CoreError::MissingContextField { field: "graph_db" })?;
316        let vector_db = self
317            .vector_db
318            .ok_or(CoreError::MissingContextField { field: "vector_db" })?;
319
320        let (handle, token) = self.cancellation.unwrap_or_else(cancellation_pair);
321
322        let ctx = TaskContext {
323            thread_pool,
324            database,
325            graph_db,
326            vector_db,
327            cancellation: token,
328            progress: self.progress.unwrap_or_default(),
329            pipeline_ctx: self.pipeline_ctx,
330            exec_status: self
331                .exec_status
332                .unwrap_or_else(|| Arc::new(NoopExecStatusManager)),
333            pipeline_watcher: self.pipeline_watcher,
334        };
335
336        Ok((handle, ctx))
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `PipelineContext` in which only the user identity varies;
    /// everything else is a throwaway value irrelevant to `user_label`.
    fn ctx_with_user(user_id: Option<Uuid>, user_email: Option<&str>) -> PipelineContext {
        PipelineContext {
            pipeline_id: Uuid::new_v4(),
            pipeline_name: "test".into(),
            user_id,
            tenant_id: None,
            dataset_id: None,
            current_data: None,
            run_id: None,
            user_email: user_email.map(str::to_owned),
            provenance_visited: Arc::new(Mutex::new(HashSet::new())),
        }
    }

    #[test]
    fn user_label_prefers_email() {
        let ctx = ctx_with_user(Some(Uuid::new_v4()), Some("alice@example.com"));
        assert_eq!(ctx.user_label().as_deref(), Some("alice@example.com"));
    }

    #[test]
    fn user_label_falls_back_to_user_id() {
        let uid = Uuid::new_v4();
        let expected = uid.to_string();
        assert_eq!(ctx_with_user(Some(uid), None).user_label(), Some(expected));
    }

    #[test]
    fn user_label_is_none_when_neither_set() {
        assert_eq!(ctx_with_user(None, None).user_label(), None);
    }
}