cognee_models/
data_input.rs1use serde::{Deserialize, Serialize};
2use std::future::Future;
3#[cfg(not(target_arch = "wasm32"))]
6use tokio::fs::File;
7#[cfg(not(target_arch = "wasm32"))]
8use tokio::io::AsyncReadExt;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub enum DataInput {
12 Text(String),
14
15 FilePath(String),
17
18 Url(String),
20
21 S3Path(String),
23
24 Binary { data: Vec<u8>, name: String },
26
27 DataItem {
29 data: Box<DataInput>,
30 label: String,
31 external_metadata: Option<String>,
32 },
33}
34
35impl DataInput {
36 pub async fn process_by_chunks<F, Fut, E>(&self, mut callback: F) -> Result<(), E>
42 where
43 F: FnMut(&[u8]) -> Fut,
44 Fut: Future<Output = Result<(), E>>,
45 E: From<std::io::Error>,
46 {
47 const BUFFER_SIZE: usize = 8192; match self {
50 Self::Text(text) => {
51 callback(text.as_bytes()).await?;
52 }
53 Self::FilePath(path) => {
54 #[cfg(not(target_arch = "wasm32"))]
55 {
56 let clean_path = path.strip_prefix("file://").unwrap_or(path);
57
58 let mut file = File::open(clean_path).await.map_err(E::from)?;
59 let mut buffer = vec![0u8; BUFFER_SIZE];
60
61 loop {
62 let bytes_read = file.read(&mut buffer).await.map_err(E::from)?;
63 if bytes_read == 0 {
64 break;
65 }
66 callback(&buffer[..bytes_read]).await?;
67 }
68 }
69 #[cfg(target_arch = "wasm32")]
72 {
73 let _ = path;
74 return Err(E::from(std::io::Error::new(
75 std::io::ErrorKind::Unsupported,
76 "Local file paths are not supported on wasm32; resolve inputs to Text or Binary before streaming.",
77 )));
78 }
79 }
80 Self::Url(_url) => {
81 return Err(E::from(std::io::Error::new(
82 std::io::ErrorKind::Unsupported,
83 "URL inputs must be resolved before streaming. Use cognee_ingestion::resolve_url_input() or AddPipeline::add().",
84 )));
85 }
86 Self::S3Path(_s3_path) => {
90 return Err(E::from(std::io::Error::new(
91 std::io::ErrorKind::Unsupported,
92 "S3 processing not yet supported",
93 )));
94 }
95 Self::Binary { data, .. } => {
96 for chunk in data.chunks(BUFFER_SIZE) {
98 callback(chunk).await?;
99 }
100 }
101 Self::DataItem { data, .. } => {
102 Box::pin(data.process_by_chunks(callback)).await?;
104 }
105 }
106
107 Ok(())
108 }
109
110 pub fn from_string(s: String) -> Self {
112 if s.starts_with("http://") || s.starts_with("https://") {
113 Self::Url(s)
114 } else if s.starts_with("s3://") {
115 Self::S3Path(s)
116 } else if s.starts_with('/') || s.starts_with("file://") || s.contains(":\\") {
117 Self::FilePath(s)
118 } else {
119 Self::Text(s)
120 }
121 }
122
123 pub fn classify(&self) -> &str {
125 match self {
126 Self::Text(_) => "text",
127 Self::FilePath(_) => "file",
128 Self::Url(_) => "url",
129 Self::S3Path(_) => "s3",
130 Self::Binary { .. } => "binary",
131 Self::DataItem { data, .. } => data.classify(),
132 }
133 }
134
135 pub fn as_str(&self) -> &str {
137 match self {
138 Self::Text(s) | Self::FilePath(s) | Self::Url(s) | Self::S3Path(s) => s,
139 Self::Binary { name, .. } => name,
140 Self::DataItem { data, .. } => data.as_str(),
141 }
142 }
143}
144
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Runs `DataInput::from_string` on a borrowed string.
    fn parse(raw: &str) -> DataInput {
        DataInput::from_string(raw.to_owned())
    }

    /// Plain prose is kept as inline text.
    #[test]
    fn test_classify_text() {
        let parsed = parse("Hello, world!");
        assert_eq!(parsed.classify(), "text");
        assert!(matches!(parsed, DataInput::Text(_)));
    }

    /// An `https://` string becomes a URL input.
    #[test]
    fn test_classify_url() {
        let parsed = parse("https://example.com");
        assert_eq!(parsed.classify(), "url");
        assert!(matches!(parsed, DataInput::Url(_)));
    }

    /// A string starting with `/` becomes a file path input.
    #[test]
    fn test_classify_file_path() {
        let parsed = parse("/path/to/file.txt");
        assert_eq!(parsed.classify(), "file");
        assert!(matches!(parsed, DataInput::FilePath(_)));
    }

    /// Drive-letter, `file://` and `/`-rooted spellings are all file paths.
    #[test]
    fn test_classify_windows_path() {
        let candidates = [
            "C:\\path\\to\\file.txt",
            "file://C:/path/to/file.txt",
            "/path/to/file.txt",
        ];

        for parsed in candidates.into_iter().map(parse) {
            assert!(matches!(parsed, DataInput::FilePath(_)));
            assert_eq!(parsed.classify(), "file");
        }
    }

    /// An `s3://` string becomes an S3 path input.
    #[test]
    fn test_classify_s3_path() {
        let parsed = parse("s3://my-bucket/key/file.txt");
        assert_eq!(parsed.classify(), "s3");
        assert!(matches!(parsed, DataInput::S3Path(_)));
    }

    /// Binary inputs report the `binary` tag and expose their name.
    #[test]
    fn test_binary_classify() {
        let name = String::from("test.png");
        let data = vec![0u8; 10];
        let blob = DataInput::Binary { data, name };

        assert_eq!(blob.as_str(), "test.png");
        assert_eq!(blob.classify(), "binary");
    }

    /// A wrapped input reports the tag of the value it wraps.
    #[test]
    fn test_data_item_delegates_classify() {
        let wrapped = DataInput::DataItem {
            data: Box::new(DataInput::Text(String::from("hello"))),
            label: String::from("my label"),
            external_metadata: None,
        };

        assert_eq!(wrapped.classify(), "text");
    }

    /// Streaming a URL input fails with the documented `Unsupported` error.
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn test_url_process_by_chunks_error_message() {
        let url_input = DataInput::Url(String::from("https://example.com"));

        let outcome = url_input
            .process_by_chunks(|_chunk| async { Ok::<(), std::io::Error>(()) })
            .await;
        let failure = outcome.unwrap_err();

        assert_eq!(
            failure.to_string(),
            "URL inputs must be resolved before streaming. Use cognee_ingestion::resolve_url_input() or AddPipeline::add()."
        );
        assert_eq!(failure.kind(), std::io::ErrorKind::Unsupported);
    }
}