// cognee_models/data_input.rs

1use serde::{Deserialize, Serialize};
2use std::future::Future;
3use tokio::fs::File;
4use tokio::io::AsyncReadExt;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub enum DataInput {
8    /// Raw text content
9    Text(String),
10
11    /// Local file path
12    FilePath(String),
13
14    /// HTTP/HTTPS URL
15    Url(String),
16
17    /// S3 path (s3://bucket/key) — TODO stub
18    S3Path(String),
19
20    /// In-memory binary data with a filename for MIME detection
21    Binary { data: Vec<u8>, name: String },
22
23    /// DataItem wrapper — wraps any other input with a custom label and optional metadata
24    DataItem {
25        data: Box<DataInput>,
26        label: String,
27        external_metadata: Option<String>,
28    },
29}
30
31impl DataInput {
32    /// Process the input data by chunks, calling the provided callback for each chunk.
33    /// This allows efficient streaming processing without loading entire files into memory.
34    ///
35    /// # Arguments
36    /// * `callback` - An async callback function that receives each chunk of data
37    pub async fn process_by_chunks<F, Fut, E>(&self, mut callback: F) -> Result<(), E>
38    where
39        F: FnMut(&[u8]) -> Fut,
40        Fut: Future<Output = Result<(), E>>,
41        E: From<std::io::Error>,
42    {
43        const BUFFER_SIZE: usize = 8192; // 8KB buffer
44
45        match self {
46            Self::Text(text) => {
47                callback(text.as_bytes()).await?;
48            }
49            Self::FilePath(path) => {
50                let clean_path = path.strip_prefix("file://").unwrap_or(path);
51
52                let mut file = File::open(clean_path).await.map_err(E::from)?;
53                let mut buffer = vec![0u8; BUFFER_SIZE];
54
55                loop {
56                    let bytes_read = file.read(&mut buffer).await.map_err(E::from)?;
57                    if bytes_read == 0 {
58                        break;
59                    }
60                    callback(&buffer[..bytes_read]).await?;
61                }
62            }
63            Self::Url(_url) => {
64                return Err(E::from(std::io::Error::new(
65                    std::io::ErrorKind::Unsupported,
66                    "URL inputs must be resolved before streaming. Use cognee_ingestion::resolve_url_input() or AddPipeline::add().",
67                )));
68            }
69            // TODO(COG-4456): implement S3 path ingestion — fetch bytes from S3 using
70            // aws-sdk-s3 or object_store, then route through the same MIME-based dispatch
71            // used for URL inputs (text → UTF-8, image/audio/pdf → Binary).
72            Self::S3Path(_s3_path) => {
73                return Err(E::from(std::io::Error::new(
74                    std::io::ErrorKind::Unsupported,
75                    "S3 processing not yet supported",
76                )));
77            }
78            Self::Binary { data, .. } => {
79                // Process binary data in chunks
80                for chunk in data.chunks(BUFFER_SIZE) {
81                    callback(chunk).await?;
82                }
83            }
84            Self::DataItem { data, .. } => {
85                // Box::pin breaks the infinite layout cycle caused by recursive async delegation
86                Box::pin(data.process_by_chunks(callback)).await?;
87            }
88        }
89
90        Ok(())
91    }
92
93    /// Classify a string into the appropriate DataInput variant
94    pub fn from_string(s: String) -> Self {
95        if s.starts_with("http://") || s.starts_with("https://") {
96            Self::Url(s)
97        } else if s.starts_with("s3://") {
98            Self::S3Path(s)
99        } else if s.starts_with('/') || s.starts_with("file://") || s.contains(":\\") {
100            Self::FilePath(s)
101        } else {
102            Self::Text(s)
103        }
104    }
105
106    /// Get the type of this input as a string
107    pub fn classify(&self) -> &str {
108        match self {
109            Self::Text(_) => "text",
110            Self::FilePath(_) => "file",
111            Self::Url(_) => "url",
112            Self::S3Path(_) => "s3",
113            Self::Binary { .. } => "binary",
114            Self::DataItem { data, .. } => data.classify(),
115        }
116    }
117
118    /// Get the inner string value (not applicable for Binary/DataItem)
119    pub fn as_str(&self) -> &str {
120        match self {
121            Self::Text(s) | Self::FilePath(s) | Self::Url(s) | Self::S3Path(s) => s,
122            Self::Binary { name, .. } => name,
123            Self::DataItem { data, .. } => data.as_str(),
124        }
125    }
126}
127
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Runs `process_by_chunks` and returns a copy of every chunk the callback saw, in order.
    async fn collect_chunks(input: &DataInput) -> std::io::Result<Vec<Vec<u8>>> {
        let mut chunks = Vec::new();
        input
            .process_by_chunks(|chunk| {
                chunks.push(chunk.to_vec());
                async { Ok::<(), std::io::Error>(()) }
            })
            .await?;
        Ok(chunks)
    }

    #[test]
    fn test_classify_text() {
        let input = DataInput::from_string("Hello, world!".to_string());
        assert!(matches!(input, DataInput::Text(_)));
        assert_eq!(input.classify(), "text");
    }

    #[test]
    fn test_classify_url() {
        let input = DataInput::from_string("https://example.com".to_string());
        assert!(matches!(input, DataInput::Url(_)));
        assert_eq!(input.classify(), "url");
    }

    #[test]
    fn test_classify_file_path() {
        let input = DataInput::from_string("/path/to/file.txt".to_string());
        assert!(matches!(input, DataInput::FilePath(_)));
        assert_eq!(input.classify(), "file");
    }

    #[test]
    fn test_classify_windows_path() {
        for input in [
            "C:\\path\\to\\file.txt".to_string(),
            "file://C:/path/to/file.txt".to_string(),
            "/path/to/file.txt".to_string(),
        ] {
            let data_input = DataInput::from_string(input);
            assert!(matches!(data_input, DataInput::FilePath(_)));
            assert_eq!(data_input.classify(), "file");
        }
    }

    #[test]
    fn test_classify_s3_path() {
        let input = DataInput::from_string("s3://my-bucket/key/file.txt".to_string());
        assert!(matches!(input, DataInput::S3Path(_)));
        assert_eq!(input.classify(), "s3");
    }

    #[test]
    fn test_binary_classify() {
        let input = DataInput::Binary {
            data: vec![0u8; 10],
            name: "test.png".to_string(),
        };
        assert_eq!(input.classify(), "binary");
        assert_eq!(input.as_str(), "test.png");
    }

    #[test]
    fn test_data_item_delegates_classify() {
        let inner = DataInput::Text("hello".to_string());
        let item = DataInput::DataItem {
            data: Box::new(inner),
            label: "my label".to_string(),
            external_metadata: None,
        };
        assert_eq!(item.classify(), "text");
    }

    #[test]
    fn test_data_item_delegates_as_str() {
        let item = DataInput::DataItem {
            data: Box::new(DataInput::FilePath("/tmp/report.pdf".to_string())),
            label: "my label".to_string(),
            external_metadata: Some("{}".to_string()),
        };
        // The wrapper's own label is not what `as_str` exposes.
        assert_eq!(item.as_str(), "/tmp/report.pdf");
    }

    #[tokio::test]
    async fn test_url_process_by_chunks_error_message() {
        let input = DataInput::Url("https://example.com".to_string());
        let err = input
            .process_by_chunks(|_| async { Ok::<(), std::io::Error>(()) })
            .await
            .unwrap_err();

        assert_eq!(err.kind(), std::io::ErrorKind::Unsupported);
        assert_eq!(
            err.to_string(),
            "URL inputs must be resolved before streaming. Use cognee_ingestion::resolve_url_input() or AddPipeline::add()."
        );
    }

    #[tokio::test]
    async fn test_s3_process_by_chunks_unsupported() {
        let input = DataInput::S3Path("s3://my-bucket/key/file.txt".to_string());
        let err = collect_chunks(&input).await.unwrap_err();

        assert_eq!(err.kind(), std::io::ErrorKind::Unsupported);
        assert_eq!(err.to_string(), "S3 processing not yet supported");
    }

    #[tokio::test]
    async fn test_text_process_by_chunks_single_chunk() {
        let input = DataInput::Text("hello".to_string());
        let chunks = collect_chunks(&input).await.unwrap();

        assert_eq!(chunks, vec![b"hello".to_vec()]);
    }

    #[tokio::test]
    async fn test_binary_process_by_chunks_splits_at_buffer_size() {
        // Two full 8 KiB chunks plus a 5-byte tail.
        let input = DataInput::Binary {
            data: vec![7u8; 8192 * 2 + 5],
            name: "blob.bin".to_string(),
        };
        let chunks = collect_chunks(&input).await.unwrap();
        let sizes: Vec<usize> = chunks.iter().map(Vec::len).collect();

        assert_eq!(sizes, vec![8192, 8192, 5]);
        assert!(chunks.iter().flatten().all(|&byte| byte == 7));
    }

    #[tokio::test]
    async fn test_empty_binary_process_by_chunks_never_calls_back() {
        let input = DataInput::Binary {
            data: Vec::new(),
            name: "empty.bin".to_string(),
        };
        let chunks = collect_chunks(&input).await.unwrap();

        assert!(chunks.is_empty());
    }

    #[tokio::test]
    async fn test_data_item_process_by_chunks_delegates() {
        let item = DataInput::DataItem {
            data: Box::new(DataInput::Binary {
                data: vec![1, 2, 3],
                name: "tiny.bin".to_string(),
            }),
            label: "my label".to_string(),
            external_metadata: None,
        };
        let chunks = collect_chunks(&item).await.unwrap();

        assert_eq!(chunks, vec![vec![1u8, 2, 3]]);
    }
}