Skip to main content

cognee_models/
data_input.rs

1use serde::{Deserialize, Serialize};
2use std::future::Future;
3// Local-filesystem streaming is unavailable on wasm32 (no OS filesystem); the
4// FilePath arm below is cfg'd out there, so these imports are too.
5#[cfg(not(target_arch = "wasm32"))]
6use tokio::fs::File;
7#[cfg(not(target_arch = "wasm32"))]
8use tokio::io::AsyncReadExt;
9
/// An input value tagged by where its content comes from.
///
/// [`DataInput::from_string`] picks a variant from a raw string's prefix, and
/// [`DataInput::process_by_chunks`] streams the bytes of the variants that can be
/// read directly (`Text`, `FilePath`, `Binary`, and `DataItem` wrapping one of those).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DataInput {
    /// Raw text content, held inline. Streamed as the string's UTF-8 bytes.
    Text(String),

    /// Local file path. A leading `file://` prefix is stripped before the file
    /// is opened for streaming.
    FilePath(String),

    /// HTTP/HTTPS URL. `process_by_chunks` does not fetch it: it returns an
    /// `Unsupported` I/O error telling the caller to resolve the URL first.
    Url(String),

    /// S3 path (s3://bucket/key). Not implemented yet: `process_by_chunks`
    /// returns an `Unsupported` I/O error (tracked as COG-4456).
    S3Path(String),

    /// In-memory binary data with a filename for MIME detection
    Binary { data: Vec<u8>, name: String },

    /// DataItem wrapper — wraps any other input with a custom label and optional metadata.
    ///
    /// `classify`, `as_str` and `process_by_chunks` all delegate to the wrapped input.
    DataItem {
        /// The wrapped input; boxed because the enum is recursive.
        data: Box<DataInput>,
        /// Custom label attached to the wrapped input.
        label: String,
        /// Optional metadata carried alongside the input.
        /// NOTE(review): the expected format of this string is not visible here — confirm with callers.
        external_metadata: Option<String>,
    },
}
34
35impl DataInput {
36    /// Process the input data by chunks, calling the provided callback for each chunk.
37    /// This allows efficient streaming processing without loading entire files into memory.
38    ///
39    /// # Arguments
40    /// * `callback` - An async callback function that receives each chunk of data
41    pub async fn process_by_chunks<F, Fut, E>(&self, mut callback: F) -> Result<(), E>
42    where
43        F: FnMut(&[u8]) -> Fut,
44        Fut: Future<Output = Result<(), E>>,
45        E: From<std::io::Error>,
46    {
47        const BUFFER_SIZE: usize = 8192; // 8KB buffer
48
49        match self {
50            Self::Text(text) => {
51                callback(text.as_bytes()).await?;
52            }
53            Self::FilePath(path) => {
54                #[cfg(not(target_arch = "wasm32"))]
55                {
56                    let clean_path = path.strip_prefix("file://").unwrap_or(path);
57
58                    let mut file = File::open(clean_path).await.map_err(E::from)?;
59                    let mut buffer = vec![0u8; BUFFER_SIZE];
60
61                    loop {
62                        let bytes_read = file.read(&mut buffer).await.map_err(E::from)?;
63                        if bytes_read == 0 {
64                            break;
65                        }
66                        callback(&buffer[..bytes_read]).await?;
67                    }
68                }
69                // wasm32 has no local filesystem; callers must resolve a FilePath to
70                // Text/Binary before streaming (mirrors the Url/S3Path arms).
71                #[cfg(target_arch = "wasm32")]
72                {
73                    let _ = path;
74                    return Err(E::from(std::io::Error::new(
75                        std::io::ErrorKind::Unsupported,
76                        "Local file paths are not supported on wasm32; resolve inputs to Text or Binary before streaming.",
77                    )));
78                }
79            }
80            Self::Url(_url) => {
81                return Err(E::from(std::io::Error::new(
82                    std::io::ErrorKind::Unsupported,
83                    "URL inputs must be resolved before streaming. Use cognee_ingestion::resolve_url_input() or AddPipeline::add().",
84                )));
85            }
86            // TODO(COG-4456): implement S3 path ingestion — fetch bytes from S3 using
87            // aws-sdk-s3 or object_store, then route through the same MIME-based dispatch
88            // used for URL inputs (text → UTF-8, image/audio/pdf → Binary).
89            Self::S3Path(_s3_path) => {
90                return Err(E::from(std::io::Error::new(
91                    std::io::ErrorKind::Unsupported,
92                    "S3 processing not yet supported",
93                )));
94            }
95            Self::Binary { data, .. } => {
96                // Process binary data in chunks
97                for chunk in data.chunks(BUFFER_SIZE) {
98                    callback(chunk).await?;
99                }
100            }
101            Self::DataItem { data, .. } => {
102                // Box::pin breaks the infinite layout cycle caused by recursive async delegation
103                Box::pin(data.process_by_chunks(callback)).await?;
104            }
105        }
106
107        Ok(())
108    }
109
110    /// Classify a string into the appropriate DataInput variant
111    pub fn from_string(s: String) -> Self {
112        if s.starts_with("http://") || s.starts_with("https://") {
113            Self::Url(s)
114        } else if s.starts_with("s3://") {
115            Self::S3Path(s)
116        } else if s.starts_with('/') || s.starts_with("file://") || s.contains(":\\") {
117            Self::FilePath(s)
118        } else {
119            Self::Text(s)
120        }
121    }
122
123    /// Get the type of this input as a string
124    pub fn classify(&self) -> &str {
125        match self {
126            Self::Text(_) => "text",
127            Self::FilePath(_) => "file",
128            Self::Url(_) => "url",
129            Self::S3Path(_) => "s3",
130            Self::Binary { .. } => "binary",
131            Self::DataItem { data, .. } => data.classify(),
132        }
133    }
134
135    /// Get the inner string value (not applicable for Binary/DataItem)
136    pub fn as_str(&self) -> &str {
137        match self {
138            Self::Text(s) | Self::FilePath(s) | Self::Url(s) | Self::S3Path(s) => s,
139            Self::Binary { name, .. } => name,
140            Self::DataItem { data, .. } => data.as_str(),
141        }
142    }
143}
144
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Run a string literal through `DataInput::from_string`.
    fn parse(raw: &str) -> DataInput {
        DataInput::from_string(raw.to_string())
    }

    #[test]
    fn test_classify_text() {
        let parsed = parse("Hello, world!");
        assert!(matches!(parsed, DataInput::Text(_)));
        assert_eq!(parsed.classify(), "text");
    }

    #[test]
    fn test_classify_url() {
        let parsed = parse("https://example.com");
        assert!(matches!(parsed, DataInput::Url(_)));
        assert_eq!(parsed.classify(), "url");
    }

    #[test]
    fn test_classify_file_path() {
        let parsed = parse("/path/to/file.txt");
        assert!(matches!(parsed, DataInput::FilePath(_)));
        assert_eq!(parsed.classify(), "file");
    }

    #[test]
    fn test_classify_windows_path() {
        // Backslash drive path, file:// URL with a drive letter, and a Unix
        // absolute path must all land in the FilePath variant.
        let candidates = [
            "C:\\path\\to\\file.txt",
            "file://C:/path/to/file.txt",
            "/path/to/file.txt",
        ];

        for raw in candidates {
            let parsed = parse(raw);
            assert!(matches!(parsed, DataInput::FilePath(_)));
            assert_eq!(parsed.classify(), "file");
        }
    }

    #[test]
    fn test_classify_s3_path() {
        let parsed = parse("s3://my-bucket/key/file.txt");
        assert!(matches!(parsed, DataInput::S3Path(_)));
        assert_eq!(parsed.classify(), "s3");
    }

    #[test]
    fn test_binary_classify() {
        let name = String::from("test.png");
        let data = vec![0u8; 10];
        let blob = DataInput::Binary { data, name };

        assert_eq!(blob.classify(), "binary");
        assert_eq!(blob.as_str(), "test.png");
    }

    #[test]
    fn test_data_item_delegates_classify() {
        let wrapped = DataInput::DataItem {
            data: Box::new(DataInput::Text(String::from("hello"))),
            label: String::from("my label"),
            external_metadata: None,
        };
        assert_eq!(wrapped.classify(), "text");
    }

    // tokio is only a dependency on non-wasm targets (see Cargo.toml), so this
    // async test is compiled out on wasm to keep `cargo test --target wasm32`
    // building. The synchronous #[test]s above remain compiled on wasm as a
    // lightweight API drift check; this one runs on native targets only.
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_url_process_by_chunks_error_message() {
        let expected_message = "URL inputs must be resolved before streaming. Use cognee_ingestion::resolve_url_input() or AddPipeline::add().";

        let url_input = DataInput::Url(String::from("https://example.com"));
        let outcome = url_input
            .process_by_chunks(|_chunk| async { Ok::<(), std::io::Error>(()) })
            .await;
        let err = outcome.unwrap_err();

        assert_eq!(err.kind(), std::io::ErrorKind::Unsupported);
        assert_eq!(err.to_string(), expected_message);
    }
}