Skip to main content

cognee_chunking/
cognify_pipeline.rs

1//! Extract text chunks pipeline.
2//!
3//! Orchestrates the initial stages of the cognify process:
4//! classify documents → chunk text.
5
6use std::sync::Arc;
7
8use cognee_models::{Data, Document, DocumentChunk, classify_documents};
9use cognee_storage::StorageTrait;
10use tracing::{debug, info, info_span, instrument};
11
12use crate::error::ChunkingError;
13use crate::text_chunker::chunk_text;
14use crate::token_counter::{TokenCounter, WordCounter};
15
16/// The extract text chunks pipeline.
17///
18/// This pipeline handles the first two stages of cognify:
19/// 1. Document classification (text/* only)
20/// 2. Text chunking
21pub struct ExtractTextChunksPipeline {
22    storage: Arc<dyn StorageTrait>,
23}
24
25impl ExtractTextChunksPipeline {
26    pub fn new(storage: Arc<dyn StorageTrait>) -> Self {
27        Self { storage }
28    }
29
30    /// Extract text chunks from a set of Data items.
31    ///
32    /// Implements:
33    /// 1. Document classification (text/* only)
34    /// 2. Text chunking
35    ///
36    /// Returns the generated chunks.
37    pub async fn extract_chunks(
38        &self,
39        data_items: Vec<Data>,
40        max_chunk_size: usize,
41    ) -> Result<Vec<DocumentChunk>, ChunkingError> {
42        self.extract_chunks_with_counter(data_items, max_chunk_size, &WordCounter)
43            .await
44    }
45
46    /// Extract text chunks with a custom token counter.
47    #[instrument(name = "chunking.extract_chunks", skip(self, data_items, counter), fields(max_chunk_size, data_count = data_items.len()))]
48    pub async fn extract_chunks_with_counter<C: TokenCounter>(
49        &self,
50        data_items: Vec<Data>,
51        max_chunk_size: usize,
52        counter: &C,
53    ) -> Result<Vec<DocumentChunk>, ChunkingError> {
54        if max_chunk_size == 0 {
55            return Err(ChunkingError::InvalidChunkSize(0));
56        }
57
58        let documents: Vec<Document> = classify_documents(&data_items);
59        info!(doc_count = documents.len(), "documents classified");
60
61        let mut all_chunks = Vec::new();
62        for document in &documents {
63            let _doc_span = info_span!(
64                "chunking.process_document",
65                document_id = %document.base.id,
66                mime_type = %document.mime_type,
67            )
68            .entered();
69
70            let content_bytes = self
71                .storage
72                .retrieve(&document.raw_data_location)
73                .await
74                .map_err(|e| ChunkingError::StorageError(e.to_string()))?;
75
76            let content = String::from_utf8(content_bytes)
77                .map_err(|e| ChunkingError::InvalidUtf8(e.to_string()))?;
78
79            let chunks = chunk_text(document.base.id, &content, max_chunk_size, counter);
80            debug!(chunk_count = chunks.len(), document_id = %document.base.id, "document chunked");
81            all_chunks.extend(chunks);
82        }
83
84        info!(total_chunks = all_chunks.len(), "chunking complete");
85        Ok(all_chunks)
86    }
87}
88
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use cognee_storage::MockStorage;
    use uuid::Uuid;

    /// No input items means no chunks and no error.
    #[tokio::test]
    async fn extract_chunks_empty_data() {
        let storage = Arc::new(MockStorage::new());
        let pipeline = ExtractTextChunksPipeline::new(storage);
        let chunks = pipeline.extract_chunks(vec![], 100).await.unwrap();
        assert!(chunks.is_empty());
    }

    /// A zero chunk size must be rejected with `InvalidChunkSize` specifically,
    /// not merely with "some error".
    #[tokio::test]
    async fn extract_chunks_invalid_chunk_size() {
        let storage = Arc::new(MockStorage::new());
        let pipeline = ExtractTextChunksPipeline::new(storage);
        let result = pipeline.extract_chunks(vec![], 0).await;
        assert!(matches!(
            result,
            Err(ChunkingError::InvalidChunkSize(0))
        ));
    }

    /// Stored plain-text content is retrieved and split into non-empty chunks.
    #[tokio::test]
    async fn extract_chunks_text_data() {
        let storage = Arc::new(MockStorage::new());

        // Store some content
        let location = storage
            .store(b"Hello world. This is a test.", "test.txt")
            .await
            .unwrap();

        let data = Data::builder(
            Uuid::new_v4(),
            "test.txt",
            location,
            "text://test",
            "txt",
            "text/plain",
            "hash123",
            Uuid::new_v4(),
        )
        .build();

        let pipeline = ExtractTextChunksPipeline::new(storage);
        let chunks = pipeline.extract_chunks(vec![data], 100).await.unwrap();

        assert!(!chunks.is_empty());
        // All chunks should have text content
        for chunk in &chunks {
            assert!(!chunk.text.is_empty());
        }
    }

    /// Content that is classified as text but is not valid UTF-8 surfaces as
    /// `InvalidUtf8` instead of being chunked.
    #[tokio::test]
    async fn extract_chunks_rejects_invalid_utf8() {
        let storage = Arc::new(MockStorage::new());

        // 0xFF can never appear in well-formed UTF-8.
        let location = storage
            .store(b"\xff\xfe\xfd", "binary.txt")
            .await
            .unwrap();

        let data = Data::builder(
            Uuid::new_v4(),
            "binary.txt",
            location,
            "text://binary",
            "txt",
            "text/plain",
            "hash789",
            Uuid::new_v4(),
        )
        .build();

        let pipeline = ExtractTextChunksPipeline::new(storage);
        let result = pipeline.extract_chunks(vec![data], 100).await;
        assert!(matches!(result, Err(ChunkingError::InvalidUtf8(_))));
    }

    /// Items that are not classified as documents are skipped without touching
    /// storage (nothing is stored at the referenced location).
    #[tokio::test]
    async fn extract_chunks_skips_unknown_extension() {
        let storage = Arc::new(MockStorage::new());

        let data = Data::builder(
            Uuid::new_v4(),
            "data.xyz",
            "/storage/data.xyz",
            "file://data.xyz",
            "xyz",
            "application/octet-stream",
            "hash456",
            Uuid::new_v4(),
        )
        .build();

        let pipeline = ExtractTextChunksPipeline::new(storage);
        let chunks = pipeline.extract_chunks(vec![data], 100).await.unwrap();
        assert!(chunks.is_empty());
    }
}