1#![allow(
4 clippy::expect_used,
5 reason = "caller is documented to only invoke this within a pipeline executor"
6)]
7
8use std::collections::HashSet;
9use std::sync::{Arc, Mutex};
10
11use cognee_database::DatabaseConnection;
12use cognee_graph::GraphDBTrait;
13use cognee_vector::VectorDB;
14use uuid::Uuid;
15
16use crate::{
17 cancellation::{CancellationHandle, CancellationToken, cancellation_pair},
18 error::CoreError,
19 exec_status::{ExecStatusManager, NoopExecStatusManager},
20 pipeline::PipelineWatcher,
21 progress::ProgressToken,
22 task::Value,
23 thread_pool::CpuPool,
24};
25#[derive(Clone)]
30pub struct PipelineContext {
31 pub pipeline_id: Uuid,
33 pub pipeline_name: String,
35 pub user_id: Option<Uuid>,
37 pub tenant_id: Option<Uuid>,
41 pub dataset_id: Option<Uuid>,
43 pub current_data: Option<Arc<dyn Value>>,
46 pub run_id: Option<Uuid>,
51 pub user_email: Option<String>,
56 pub provenance_visited: Arc<Mutex<HashSet<Uuid>>>,
62}
63
64impl PipelineContext {
65 pub fn user_label(&self) -> Option<String> {
75 self.user_email
76 .clone()
77 .or_else(|| self.user_id.map(|id| id.to_string()))
78 }
79}
80pub struct TaskContext {
84 pub thread_pool: Arc<dyn CpuPool>,
86 pub database: Arc<DatabaseConnection>,
88 pub graph_db: Arc<dyn GraphDBTrait>,
90 pub vector_db: Arc<dyn VectorDB>,
92 pub cancellation: CancellationToken,
94 pub progress: ProgressToken,
96 pub pipeline_ctx: Option<PipelineContext>,
98 pub exec_status: Arc<dyn ExecStatusManager>,
100 pub pipeline_watcher: Option<Arc<dyn PipelineWatcher>>,
107}
108
109impl TaskContext {
110 pub fn pipeline(&self) -> &PipelineContext {
115 self.pipeline_ctx
116 .as_ref()
117 .expect("PipelineContext not set — task is not running inside a pipeline executor")
118 }
119
120 pub fn with_progress(self: &Arc<Self>, progress: ProgressToken) -> Arc<Self> {
123 Arc::new(TaskContext {
124 thread_pool: Arc::clone(&self.thread_pool),
125 database: Arc::clone(&self.database),
126 graph_db: Arc::clone(&self.graph_db),
127 vector_db: Arc::clone(&self.vector_db),
128 cancellation: self.cancellation.clone(),
129 progress,
130 pipeline_ctx: self.pipeline_ctx.clone(),
131 exec_status: Arc::clone(&self.exec_status),
132 pipeline_watcher: self.pipeline_watcher.clone(),
133 })
134 }
135
136 pub fn with_current_data(self: &Arc<Self>, data: Arc<dyn Value>) -> Arc<Self> {
141 let mut pipeline_ctx = match &self.pipeline_ctx {
142 Some(ctx) => ctx.clone(),
143 None => return Arc::clone(self),
144 };
145 pipeline_ctx.current_data = Some(data);
146 Arc::new(TaskContext {
147 thread_pool: Arc::clone(&self.thread_pool),
148 database: Arc::clone(&self.database),
149 graph_db: Arc::clone(&self.graph_db),
150 vector_db: Arc::clone(&self.vector_db),
151 cancellation: self.cancellation.clone(),
152 progress: self.progress.clone(),
153 pipeline_ctx: Some(pipeline_ctx),
154 exec_status: Arc::clone(&self.exec_status),
155 pipeline_watcher: self.pipeline_watcher.clone(),
156 })
157 }
158
159 pub fn with_user_email(self: &Arc<Self>, email: String) -> Arc<Self> {
164 let mut pipeline_ctx = match &self.pipeline_ctx {
165 Some(ctx) => ctx.clone(),
166 None => return Arc::clone(self),
167 };
168 pipeline_ctx.user_email = Some(email);
169 Arc::new(TaskContext {
170 thread_pool: Arc::clone(&self.thread_pool),
171 database: Arc::clone(&self.database),
172 graph_db: Arc::clone(&self.graph_db),
173 vector_db: Arc::clone(&self.vector_db),
174 cancellation: self.cancellation.clone(),
175 progress: self.progress.clone(),
176 pipeline_ctx: Some(pipeline_ctx),
177 exec_status: Arc::clone(&self.exec_status),
178 pipeline_watcher: self.pipeline_watcher.clone(),
179 })
180 }
181
182 pub fn with_run_id(self: &Arc<Self>, run_id: Uuid) -> Arc<Self> {
187 let mut pipeline_ctx = match &self.pipeline_ctx {
188 Some(ctx) => ctx.clone(),
189 None => return Arc::clone(self),
190 };
191 pipeline_ctx.run_id = Some(run_id);
192 Arc::new(TaskContext {
193 thread_pool: Arc::clone(&self.thread_pool),
194 database: Arc::clone(&self.database),
195 graph_db: Arc::clone(&self.graph_db),
196 vector_db: Arc::clone(&self.vector_db),
197 cancellation: self.cancellation.clone(),
198 progress: self.progress.clone(),
199 pipeline_ctx: Some(pipeline_ctx),
200 exec_status: Arc::clone(&self.exec_status),
201 pipeline_watcher: self.pipeline_watcher.clone(),
202 })
203 }
204
205 pub async fn publish_payload_field(&self, key: &str, value: serde_json::Value) {
213 let Some(w) = self.pipeline_watcher.as_ref() else {
214 return;
215 };
216 let Some(pctx) = self.pipeline_ctx.as_ref() else {
217 return;
218 };
219 let Some(run_id) = pctx.run_id else {
220 return;
221 };
222 w.on_payload_field(run_id, key, value).await;
223 }
224}
225#[derive(Default)]
237pub struct TaskContextBuilder {
238 thread_pool: Option<Arc<dyn CpuPool>>,
239 database: Option<Arc<DatabaseConnection>>,
240 graph_db: Option<Arc<dyn GraphDBTrait>>,
241 vector_db: Option<Arc<dyn VectorDB>>,
242 cancellation: Option<(CancellationHandle, CancellationToken)>,
244 progress: Option<ProgressToken>,
245 pipeline_ctx: Option<PipelineContext>,
246 exec_status: Option<Arc<dyn ExecStatusManager>>,
247 pipeline_watcher: Option<Arc<dyn PipelineWatcher>>,
248}
249
250impl TaskContextBuilder {
251 pub fn new() -> Self {
252 Self::default()
253 }
254
255 pub fn thread_pool(mut self, pool: Arc<dyn CpuPool>) -> Self {
256 self.thread_pool = Some(pool);
257 self
258 }
259
260 pub fn database(mut self, db: Arc<DatabaseConnection>) -> Self {
261 self.database = Some(db);
262 self
263 }
264
265 pub fn graph_db(mut self, graph: Arc<dyn GraphDBTrait>) -> Self {
266 self.graph_db = Some(graph);
267 self
268 }
269
270 pub fn vector_db(mut self, vectors: Arc<dyn VectorDB>) -> Self {
271 self.vector_db = Some(vectors);
272 self
273 }
274
275 pub fn progress(mut self, token: ProgressToken) -> Self {
277 self.progress = Some(token);
278 self
279 }
280
281 pub fn pipeline_context(mut self, ctx: PipelineContext) -> Self {
283 self.pipeline_ctx = Some(ctx);
284 self
285 }
286
287 pub fn exec_status(mut self, mgr: Arc<dyn ExecStatusManager>) -> Self {
290 self.exec_status = Some(mgr);
291 self
292 }
293
294 pub fn pipeline_watcher(mut self, w: Arc<dyn PipelineWatcher>) -> Self {
300 self.pipeline_watcher = Some(w);
301 self
302 }
303
304 pub fn build(self) -> Result<(CancellationHandle, TaskContext), CoreError> {
307 let thread_pool = self.thread_pool.ok_or(CoreError::MissingContextField {
308 field: "thread_pool",
309 })?;
310 let database = self
311 .database
312 .ok_or(CoreError::MissingContextField { field: "database" })?;
313 let graph_db = self
314 .graph_db
315 .ok_or(CoreError::MissingContextField { field: "graph_db" })?;
316 let vector_db = self
317 .vector_db
318 .ok_or(CoreError::MissingContextField { field: "vector_db" })?;
319
320 let (handle, token) = self.cancellation.unwrap_or_else(cancellation_pair);
321
322 let ctx = TaskContext {
323 thread_pool,
324 database,
325 graph_db,
326 vector_db,
327 cancellation: token,
328 progress: self.progress.unwrap_or_default(),
329 pipeline_ctx: self.pipeline_ctx,
330 exec_status: self
331 .exec_status
332 .unwrap_or_else(|| Arc::new(NoopExecStatusManager)),
333 pipeline_watcher: self.pipeline_watcher,
334 };
335
336 Ok((handle, ctx))
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `PipelineContext` in which only the two fields read by
    /// `user_label` vary; everything else is a fixed placeholder.
    fn context_for(user_id: Option<Uuid>, user_email: Option<&str>) -> PipelineContext {
        PipelineContext {
            pipeline_id: Uuid::new_v4(),
            pipeline_name: String::from("test"),
            user_id,
            tenant_id: None,
            dataset_id: None,
            current_data: None,
            run_id: None,
            user_email: user_email.map(str::to_owned),
            provenance_visited: Arc::new(Mutex::new(HashSet::new())),
        }
    }

    #[test]
    fn user_label_prefers_email() {
        // Both identifiers are present; the email must win.
        let ctx = context_for(Some(Uuid::new_v4()), Some("alice@example.com"));
        assert_eq!(ctx.user_label().as_deref(), Some("alice@example.com"));
    }

    #[test]
    fn user_label_falls_back_to_user_id() {
        let uid = Uuid::new_v4();
        let ctx = context_for(Some(uid), None);
        assert_eq!(ctx.user_label(), Some(uid.to_string()));
    }

    #[test]
    fn user_label_is_none_when_neither_set() {
        let ctx = context_for(None, None);
        assert!(ctx.user_label().is_none());
    }
}