// cognee_models/data_input.rs
use serde::{Deserialize, Serialize};
use std::future::Future;
use std::io;
use tokio::fs::File;
use tokio::io::AsyncReadExt;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub enum DataInput {
8 Text(String),
10
11 FilePath(String),
13
14 Url(String),
16
17 S3Path(String),
19
20 Binary { data: Vec<u8>, name: String },
22
23 DataItem {
25 data: Box<DataInput>,
26 label: String,
27 external_metadata: Option<String>,
28 },
29}
30
31impl DataInput {
32 pub async fn process_by_chunks<F, Fut, E>(&self, mut callback: F) -> Result<(), E>
38 where
39 F: FnMut(&[u8]) -> Fut,
40 Fut: Future<Output = Result<(), E>>,
41 E: From<std::io::Error>,
42 {
43 const BUFFER_SIZE: usize = 8192; match self {
46 Self::Text(text) => {
47 callback(text.as_bytes()).await?;
48 }
49 Self::FilePath(path) => {
50 let clean_path = path.strip_prefix("file://").unwrap_or(path);
51
52 let mut file = File::open(clean_path).await.map_err(E::from)?;
53 let mut buffer = vec![0u8; BUFFER_SIZE];
54
55 loop {
56 let bytes_read = file.read(&mut buffer).await.map_err(E::from)?;
57 if bytes_read == 0 {
58 break;
59 }
60 callback(&buffer[..bytes_read]).await?;
61 }
62 }
63 Self::Url(_url) => {
64 return Err(E::from(std::io::Error::new(
65 std::io::ErrorKind::Unsupported,
66 "URL inputs must be resolved before streaming. Use cognee_ingestion::resolve_url_input() or AddPipeline::add().",
67 )));
68 }
69 Self::S3Path(_s3_path) => {
73 return Err(E::from(std::io::Error::new(
74 std::io::ErrorKind::Unsupported,
75 "S3 processing not yet supported",
76 )));
77 }
78 Self::Binary { data, .. } => {
79 for chunk in data.chunks(BUFFER_SIZE) {
81 callback(chunk).await?;
82 }
83 }
84 Self::DataItem { data, .. } => {
85 Box::pin(data.process_by_chunks(callback)).await?;
87 }
88 }
89
90 Ok(())
91 }
92
93 pub fn from_string(s: String) -> Self {
95 if s.starts_with("http://") || s.starts_with("https://") {
96 Self::Url(s)
97 } else if s.starts_with("s3://") {
98 Self::S3Path(s)
99 } else if s.starts_with('/') || s.starts_with("file://") || s.contains(":\\") {
100 Self::FilePath(s)
101 } else {
102 Self::Text(s)
103 }
104 }
105
106 pub fn classify(&self) -> &str {
108 match self {
109 Self::Text(_) => "text",
110 Self::FilePath(_) => "file",
111 Self::Url(_) => "url",
112 Self::S3Path(_) => "s3",
113 Self::Binary { .. } => "binary",
114 Self::DataItem { data, .. } => data.classify(),
115 }
116 }
117
118 pub fn as_str(&self) -> &str {
120 match self {
121 Self::Text(s) | Self::FilePath(s) | Self::Url(s) | Self::S3Path(s) => s,
122 Self::Binary { name, .. } => name,
123 Self::DataItem { data, .. } => data.as_str(),
124 }
125 }
126}
127
128#[cfg(test)]
129#[allow(
130 clippy::unwrap_used,
131 clippy::expect_used,
132 reason = "test code — panics are acceptable failures"
133)]
134mod tests {
135 use super::*;
136
137 #[test]
138 fn test_classify_text() {
139 let input = DataInput::from_string("Hello, world!".to_string());
140 assert!(matches!(input, DataInput::Text(_)));
141 assert_eq!(input.classify(), "text");
142 }
143
144 #[test]
145 fn test_classify_url() {
146 let input = DataInput::from_string("https://example.com".to_string());
147 assert!(matches!(input, DataInput::Url(_)));
148 assert_eq!(input.classify(), "url");
149 }
150
151 #[test]
152 fn test_classify_file_path() {
153 let input = DataInput::from_string("/path/to/file.txt".to_string());
154 assert!(matches!(input, DataInput::FilePath(_)));
155 assert_eq!(input.classify(), "file");
156 }
157
158 #[test]
159 fn test_classify_windows_path() {
160 for input in [
161 "C:\\path\\to\\file.txt".to_string(),
162 "file://C:/path/to/file.txt".to_string(),
163 "/path/to/file.txt".to_string(),
164 ] {
165 let data_input = DataInput::from_string(input);
166 assert!(matches!(data_input, DataInput::FilePath(_)));
167 assert_eq!(data_input.classify(), "file");
168 }
169 }
170
171 #[test]
172 fn test_classify_s3_path() {
173 let input = DataInput::from_string("s3://my-bucket/key/file.txt".to_string());
174 assert!(matches!(input, DataInput::S3Path(_)));
175 assert_eq!(input.classify(), "s3");
176 }
177
178 #[test]
179 fn test_binary_classify() {
180 let input = DataInput::Binary {
181 data: vec![0u8; 10],
182 name: "test.png".to_string(),
183 };
184 assert_eq!(input.classify(), "binary");
185 assert_eq!(input.as_str(), "test.png");
186 }
187
188 #[test]
189 fn test_data_item_delegates_classify() {
190 let inner = DataInput::Text("hello".to_string());
191 let item = DataInput::DataItem {
192 data: Box::new(inner),
193 label: "my label".to_string(),
194 external_metadata: None,
195 };
196 assert_eq!(item.classify(), "text");
197 }
198
199 #[tokio::test]
200 async fn test_url_process_by_chunks_error_message() {
201 let input = DataInput::Url("https://example.com".to_string());
202 let err = input
203 .process_by_chunks(|_| async { Ok::<(), std::io::Error>(()) })
204 .await
205 .unwrap_err();
206
207 assert_eq!(err.kind(), std::io::ErrorKind::Unsupported);
208 assert_eq!(
209 err.to_string(),
210 "URL inputs must be resolved before streaming. Use cognee_ingestion::resolve_url_input() or AddPipeline::add()."
211 );
212 }
213}