Skip to main content

cognee_cognify/
dataset_resolver.rs

//! Dataset resolution trait and the `cognify_datasets` /
//! `cognify_dataset_refs` entry points.
//!
//! Mirrors Python's `resolve_authorized_user_datasets` + per-dataset loop
//! in `cognee/modules/pipelines/operations/pipeline.py`.
//!
//! The [`DatasetResolver`] trait abstracts how dataset names are turned into
//! concrete [`Dataset`] and [`Data`] objects so the cognify pipeline stays
//! independent of any specific database backend.
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use chrono::Utc;
14use cognee_core::CpuPool;
15use cognee_database::ops::pipeline_runs::{create_pipeline_run, get_latest_pipeline_status};
16use cognee_database::{DatabaseConnection, PipelineRun, PipelineRunRepository, PipelineRunStatus};
17use cognee_embedding::engine::EmbeddingEngine;
18use cognee_graph::GraphDBTrait;
19use cognee_llm::Llm;
20use cognee_models::{Data, Dataset};
21use cognee_ontology::OntologyResolver;
22use cognee_storage::StorageTrait;
23use cognee_vector::VectorDB;
24use tracing::info;
25use uuid::Uuid;
26
27use crate::config::CognifyConfig;
28use crate::error::CognifyError;
29use crate::pipeline::CognifyResult;
30use crate::tasks::cognify;
31
/// Pipeline name used for cognify pipeline run records (matches Python convention).
///
/// In this module it is both written into every `PipelineRun` record created
/// after a successful run and used as the lookup key for the pipeline-cache
/// check, so the two must always agree.
const COGNIFY_PIPELINE_NAME: &str = "cognify_pipeline";
34
35// ---------------------------------------------------------------------------
36// DatasetRef — identify a dataset by name or UUID
37// ---------------------------------------------------------------------------
38
39/// Reference to a dataset, either by name or by UUID.
40///
41/// Mirrors Python's `Union[str, list[str], list[UUID]]` parameter on `cognify()`.
42#[derive(Debug, Clone)]
43pub enum DatasetRef {
44    /// Identify a dataset by its human-readable name.
45    ByName(String),
46    /// Identify a dataset by its UUID.
47    ById(Uuid),
48}
49
50// ---------------------------------------------------------------------------
51// Trait
52// ---------------------------------------------------------------------------
53
/// Resolve dataset names (or all datasets) to concrete [`Dataset`] and
/// [`Data`] objects.
///
/// Implementations are expected to enforce authorization (the `permission`
/// parameter mirrors Python's `get_authorized_existing_datasets`).
///
/// The `Send + Sync` supertraits allow a resolver to be shared as
/// `Arc<dyn DatasetResolver>`, which is how the entry points in this module
/// receive it.
#[async_trait]
pub trait DatasetResolver: Send + Sync {
    /// Resolve dataset names to [`Dataset`] objects for a given user.
    ///
    /// * If `datasets` is empty, implementations should return **all** datasets
    ///   the user has access to (matching Python behaviour when `datasets=None`).
    ///   Callers that merely ended up with an empty name list must therefore
    ///   not call this method unless "everything" is what they want.
    /// * `user_id` identifies the user on whose behalf the lookup is made.
    /// * `permission` is a hint for access control (e.g. `"read"`, `"write"`).
    ///   The entry points in this module pass `"read"`.
    ///
    /// # Errors
    ///
    /// Returns a [`CognifyError`] when the implementation cannot perform the
    /// lookup; the exact variants are implementation-defined.
    async fn resolve_datasets(
        &self,
        datasets: &[String],
        user_id: Uuid,
        permission: &str,
    ) -> Result<Vec<Dataset>, CognifyError>;

    /// Return all [`Data`] items attached to the given dataset.
    ///
    /// An empty vector is a valid answer; the entry points in this module
    /// skip datasets for which no items are returned.
    ///
    /// # Errors
    ///
    /// Returns a [`CognifyError`] when the implementation cannot fetch the
    /// items; the exact variants are implementation-defined.
    async fn get_dataset_data(&self, dataset_id: Uuid) -> Result<Vec<Data>, CognifyError>;

    /// Resolve a single dataset by its UUID.
    ///
    /// Default implementation returns `Ok(None)` (lookup by id not
    /// supported). Implementors backed by a real database should override.
    ///
    /// `cognify_dataset_refs` turns a `None` answer into a
    /// `CognifyError::DatasetResolutionError`, so with the default
    /// implementation every id-based reference is reported as "not found".
    async fn resolve_dataset_by_id(
        &self,
        _id: Uuid,
        _user_id: Uuid,
        _permission: &str,
    ) -> Result<Option<Dataset>, CognifyError> {
        Ok(None)
    }
}
89
90// ---------------------------------------------------------------------------
91// cognify_datasets
92// ---------------------------------------------------------------------------
93
94/// High-level entry point: resolve dataset names, then run [`cognify`] for
95/// each dataset.
96///
97/// This mirrors the Python `cognify(datasets, user, ...)` API which:
98/// 1. Resolves dataset names to `Dataset` objects via the database.
99/// 2. For each dataset, fetches its `Data` items.
100/// 3. Runs the full cognify pipeline per dataset.
101///
102/// Empty datasets (no data items) are silently skipped.
103#[allow(clippy::too_many_arguments)]
104pub async fn cognify_datasets(
105    dataset_names: Vec<String>,
106    user_id: Uuid,
107    tenant_id: Option<Uuid>,
108    resolver: Arc<dyn DatasetResolver>,
109    llm: Arc<dyn Llm>,
110    storage: Arc<dyn StorageTrait>,
111    graph_db: Arc<dyn GraphDBTrait>,
112    vector_db: Arc<dyn VectorDB>,
113    embedding_engine: Arc<dyn EmbeddingEngine>,
114    database: Arc<DatabaseConnection>,
115    pipeline_run_repo: Arc<dyn PipelineRunRepository>,
116    thread_pool: Arc<dyn CpuPool>,
117    ontology_resolver: Arc<dyn OntologyResolver>,
118    config: &CognifyConfig,
119) -> Result<Vec<CognifyResult>, CognifyError> {
120    let datasets = resolver
121        .resolve_datasets(&dataset_names, user_id, "read")
122        .await?;
123
124    info!(
125        dataset_count = datasets.len(),
126        "Resolved {} dataset(s) for cognify",
127        datasets.len()
128    );
129
130    let mut results = Vec::new();
131
132    for dataset in &datasets {
133        // --- Pipeline cache check ---
134        if config.use_pipeline_cache {
135            let status =
136                get_latest_pipeline_status(&database, COGNIFY_PIPELINE_NAME, dataset.id).await?;
137            if matches!(status, Some(PipelineRunStatus::Completed)) {
138                info!(
139                    dataset_name = %dataset.name,
140                    dataset_id = %dataset.id,
141                    "Skipping already-processed dataset (pipeline cache hit)"
142                );
143                continue;
144            }
145        }
146
147        let data_items = resolver.get_dataset_data(dataset.id).await?;
148
149        if data_items.is_empty() {
150            info!(
151                dataset_name = %dataset.name,
152                dataset_id = %dataset.id,
153                "Skipping empty dataset"
154            );
155            continue;
156        }
157
158        info!(
159            dataset_name = %dataset.name,
160            dataset_id = %dataset.id,
161            data_items = data_items.len(),
162            "Running cognify for dataset"
163        );
164
165        let result = cognify(
166            data_items,
167            dataset.id,
168            Some(user_id),
169            None,
170            tenant_id,
171            Arc::clone(&llm),
172            Arc::clone(&storage),
173            Arc::clone(&graph_db),
174            Arc::clone(&vector_db),
175            Arc::clone(&embedding_engine),
176            Arc::clone(&database),
177            Arc::clone(&pipeline_run_repo),
178            Arc::clone(&thread_pool),
179            Arc::clone(&ontology_resolver),
180            config,
181        )
182        .await?;
183
184        // --- Record successful pipeline run ---
185        let pipeline_run_id = Uuid::new_v4();
186        let run = PipelineRun {
187            id: Uuid::new_v4(),
188            created_at: Utc::now(),
189            status: PipelineRunStatus::Completed,
190            pipeline_run_id,
191            pipeline_name: COGNIFY_PIPELINE_NAME.to_string(),
192            pipeline_id: pipeline_run_id,
193            dataset_id: Some(dataset.id),
194            run_info: None,
195        };
196        create_pipeline_run(&database, run).await?;
197
198        results.push(result);
199    }
200
201    info!(
202        "cognify_datasets complete: {} dataset(s) processed",
203        results.len()
204    );
205    Ok(results)
206}
207
208/// Like [`cognify_datasets`], but accepts [`DatasetRef`] values (by name or
209/// by UUID).
210///
211/// UUID-based refs are resolved via [`DatasetResolver::resolve_dataset_by_id`].
212/// Name-based refs are collected and resolved via [`DatasetResolver::resolve_datasets`].
213#[allow(clippy::too_many_arguments)]
214pub async fn cognify_dataset_refs(
215    refs: Vec<DatasetRef>,
216    user_id: Uuid,
217    tenant_id: Option<Uuid>,
218    resolver: Arc<dyn DatasetResolver>,
219    llm: Arc<dyn Llm>,
220    storage: Arc<dyn StorageTrait>,
221    graph_db: Arc<dyn GraphDBTrait>,
222    vector_db: Arc<dyn VectorDB>,
223    embedding_engine: Arc<dyn EmbeddingEngine>,
224    database: Arc<DatabaseConnection>,
225    pipeline_run_repo: Arc<dyn PipelineRunRepository>,
226    thread_pool: Arc<dyn CpuPool>,
227    ontology_resolver: Arc<dyn OntologyResolver>,
228    config: &CognifyConfig,
229) -> Result<Vec<CognifyResult>, CognifyError> {
230    // Split refs into name-based and id-based.
231    let mut names = Vec::new();
232    let mut id_datasets = Vec::new();
233
234    for r in refs {
235        match r {
236            DatasetRef::ByName(n) => names.push(n),
237            DatasetRef::ById(id) => {
238                let ds = resolver
239                    .resolve_dataset_by_id(id, user_id, "read")
240                    .await?
241                    .ok_or_else(|| {
242                        CognifyError::DatasetResolutionError(format!(
243                            "Dataset with id {id} not found"
244                        ))
245                    })?;
246                id_datasets.push(ds);
247            }
248        }
249    }
250
251    // Resolve name-based refs.
252    let name_datasets = resolver.resolve_datasets(&names, user_id, "read").await?;
253
254    // Merge both sets and delegate to the core loop via cognify_datasets.
255    // To avoid duplicating the per-dataset loop, we just call cognify_datasets
256    // with a fake name list (empty) and handle both sets directly.
257    let mut all_datasets = name_datasets;
258    all_datasets.extend(id_datasets);
259
260    info!(
261        dataset_count = all_datasets.len(),
262        "Resolved {} dataset(s) for cognify (via refs)",
263        all_datasets.len()
264    );
265
266    let mut results = Vec::new();
267    for dataset in &all_datasets {
268        if config.use_pipeline_cache {
269            let status =
270                get_latest_pipeline_status(&database, COGNIFY_PIPELINE_NAME, dataset.id).await?;
271            if matches!(status, Some(PipelineRunStatus::Completed)) {
272                info!(
273                    dataset_name = %dataset.name,
274                    dataset_id = %dataset.id,
275                    "Skipping already-processed dataset (pipeline cache hit)"
276                );
277                continue;
278            }
279        }
280
281        let data_items = resolver.get_dataset_data(dataset.id).await?;
282        if data_items.is_empty() {
283            info!(
284                dataset_name = %dataset.name,
285                dataset_id = %dataset.id,
286                "Skipping empty dataset"
287            );
288            continue;
289        }
290
291        info!(
292            dataset_name = %dataset.name,
293            dataset_id = %dataset.id,
294            data_items = data_items.len(),
295            "Running cognify for dataset"
296        );
297
298        let result = cognify(
299            data_items,
300            dataset.id,
301            Some(user_id),
302            None,
303            tenant_id,
304            Arc::clone(&llm),
305            Arc::clone(&storage),
306            Arc::clone(&graph_db),
307            Arc::clone(&vector_db),
308            Arc::clone(&embedding_engine),
309            Arc::clone(&database),
310            Arc::clone(&pipeline_run_repo),
311            Arc::clone(&thread_pool),
312            Arc::clone(&ontology_resolver),
313            config,
314        )
315        .await?;
316
317        let pipeline_run_id = Uuid::new_v4();
318        let run = PipelineRun {
319            id: Uuid::new_v4(),
320            created_at: Utc::now(),
321            status: PipelineRunStatus::Completed,
322            pipeline_run_id,
323            pipeline_name: COGNIFY_PIPELINE_NAME.to_string(),
324            pipeline_id: pipeline_run_id,
325            dataset_id: Some(dataset.id),
326            run_info: None,
327        };
328        create_pipeline_run(&database, run).await?;
329
330        results.push(result);
331    }
332
333    info!(
334        "cognify_dataset_refs complete: {} dataset(s) processed",
335        results.len()
336    );
337    Ok(results)
338}
339
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Minimal in-memory [`DatasetResolver`] used to exercise the trait
    /// contract without a database.
    struct MockResolver {
        /// Every dataset the mock "knows" about.
        datasets: Vec<Dataset>,
        /// Data items keyed by the id of the dataset they belong to.
        data: HashMap<Uuid, Vec<Data>>,
    }

    #[async_trait]
    impl DatasetResolver for MockResolver {
        async fn resolve_datasets(
            &self,
            names: &[String],
            _user_id: Uuid,
            _permission: &str,
        ) -> Result<Vec<Dataset>, CognifyError> {
            let selected = match names {
                // No names given: hand back everything.
                [] => self.datasets.clone(),
                wanted => self
                    .datasets
                    .iter()
                    .filter(|candidate| wanted.contains(&candidate.name))
                    .cloned()
                    .collect(),
            };
            Ok(selected)
        }

        async fn get_dataset_data(&self, dataset_id: Uuid) -> Result<Vec<Data>, CognifyError> {
            let items = match self.data.get(&dataset_id) {
                Some(found) => found.clone(),
                None => Vec::new(),
            };
            Ok(items)
        }
    }

    /// Drives `fut` to completion on a fresh current-thread Tokio runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        runtime.block_on(fut)
    }

    /// Builds a dataset with the given name owned by `owner`.
    fn dataset_named(name: &str, owner: Uuid) -> Dataset {
        Dataset::new(name.to_string(), owner, None, Uuid::new_v4())
    }

    #[test]
    fn test_mock_resolver_filters_by_name() {
        let owner = Uuid::new_v4();
        let resolver = MockResolver {
            datasets: vec![dataset_named("alpha", owner), dataset_named("beta", owner)],
            data: HashMap::new(),
        };

        let resolved =
            block_on(resolver.resolve_datasets(&["alpha".to_string()], owner, "read")).unwrap();

        assert_eq!(resolved.len(), 1);
        assert_eq!(resolved[0].name, "alpha");
    }

    #[test]
    fn test_mock_resolver_returns_all_when_empty() {
        let owner = Uuid::new_v4();
        let resolver = MockResolver {
            datasets: vec![dataset_named("alpha", owner), dataset_named("beta", owner)],
            data: HashMap::new(),
        };

        let resolved = block_on(resolver.resolve_datasets(&[], owner, "read")).unwrap();

        assert_eq!(resolved.len(), 2);
    }

    #[test]
    fn test_mock_resolver_get_data_empty_dataset() {
        let resolver = MockResolver {
            datasets: Vec::new(),
            data: HashMap::new(),
        };

        let fetched = block_on(resolver.get_dataset_data(Uuid::new_v4()));

        assert!(fetched.unwrap().is_empty());
    }

    #[test]
    fn test_mock_resolver_get_data_with_items() {
        let dataset_id = Uuid::new_v4();
        let owner_id = Uuid::new_v4();
        let data_item = Data::builder(
            Uuid::new_v4(),
            "test.txt",
            "/storage/test.txt",
            "file://test.txt",
            "txt",
            "text/plain",
            "hash123",
            owner_id,
        )
        .build();

        let resolver = MockResolver {
            datasets: vec![dataset_named("ds", owner_id)],
            data: HashMap::from([(dataset_id, vec![data_item])]),
        };

        let items = block_on(resolver.get_dataset_data(dataset_id)).unwrap();

        assert_eq!(items.len(), 1);
        assert_eq!(items[0].name, "test.txt");
    }
}