// ai_workbench_lib/modules/file_splitter/splitter.rs
use super::types::*;
use super::splitter_types::{TextSplitter, CsvSplitter, JsonSplitter, CodeSplitter};
use anyhow::Result;
use std::path::Path;

/// Splits file contents into chunks using a format-aware strategy.
///
/// Holds the [`SplitConfig`] that is handed to every splitting strategy
/// (text, CSV/TSV, JSON, source code) when `split` is invoked.
pub struct FileSplitter {
    // Passed by reference to each strategy's `split` call.
    config: SplitConfig,
}

11impl FileSplitter {
12 pub fn new() -> Self {
14 Self {
15 config: SplitConfig::default(),
16 }
17 }
18
19 pub fn with_config(config: SplitConfig) -> Self {
21 Self { config }
22 }
23
24 pub fn split_file(&self, file_path: &Path, data: &[u8]) -> Result<Vec<FileChunk>> {
26 let file_type = FileType::from_extension(file_path);
27 self.split_with_type(data, file_type)
28 }
29
30 pub fn split_with_type(&self, data: &[u8], file_type: FileType) -> Result<Vec<FileChunk>> {
32 if data.is_empty() {
33 return Ok(vec![]);
34 }
35
36 let strategy: Box<dyn SplitStrategy> = match file_type {
37 FileType::Csv => {
38 let csv_splitter = CsvSplitter::new();
39 if !csv_splitter.validate_format(data)? {
41 return self.split_with_type(data, FileType::Text);
43 }
44 Box::new(csv_splitter)
45 },
46 FileType::Tsv => {
47 let tsv_splitter = CsvSplitter::new_tsv();
48 if !tsv_splitter.validate_format(data)? {
50 return self.split_with_type(data, FileType::Text);
52 }
53 Box::new(tsv_splitter)
54 },
55 FileType::Json => {
56 let json_splitter = JsonSplitter::new();
57 match json_splitter.split(data, &self.config) {
59 Ok(chunks) => return Ok(chunks),
60 Err(_) => return self.split_with_type(data, FileType::Text),
61 }
62 },
63 file_type if file_type.is_source_code() => {
65 Box::new(CodeSplitter::new(file_type))
66 },
67 _ => Box::new(TextSplitter),
69 };
70
71 strategy.split(data, &self.config)
72 }
73
74}
75
76impl Default for FileSplitter {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    // Extension-based routing: each sample lands on the expected splitter.
    #[test]
    fn test_file_type_detection_by_extension() {
        let splitter = FileSplitter::new();

        let txt_chunks = splitter.split_file(Path::new("test.txt"), b"hello world").unwrap();
        assert_eq!(txt_chunks.len(), 1);

        let csv_chunks = splitter.split_file(Path::new("data.csv"), b"a,b\n1,2").unwrap();
        assert_eq!(csv_chunks[0].metadata.file_type, FileType::Csv);

        let tsv_chunks = splitter.split_file(Path::new("data.tsv"), b"a\tb\n1\t2").unwrap();
        assert_eq!(tsv_chunks[0].metadata.file_type, FileType::Tsv);

        let json_chunks = splitter
            .split_file(Path::new("logs.json"), br#"{"message": "test"}"#)
            .unwrap();
        assert_eq!(json_chunks[0].metadata.file_type, FileType::Json);
    }

    // JSON-lines input and a single JSON object both yield JSON-typed chunks.
    #[test]
    fn test_json_file_splitting() {
        let splitter = FileSplitter::new();

        let json_lines = br#"{"timestamp": "2023-01-01", "level": "INFO", "message": "Application started"}
{"timestamp": "2023-01-01", "level": "ERROR", "message": "Database connection failed"}
{"timestamp": "2023-01-01", "level": "INFO", "message": "Retrying connection"}"#;

        let line_chunks = splitter.split_file(Path::new("app.log.json"), json_lines).unwrap();
        assert!(!line_chunks.is_empty());
        assert_eq!(line_chunks[0].metadata.file_type, FileType::Json);

        let json_object = br#"{"users": [{"name": "John", "age": 25}, {"name": "Jane", "age": 30}]}"#;
        let object_chunks = splitter.split_file(Path::new("users.json"), json_object).unwrap();
        assert!(!object_chunks.is_empty());
        assert_eq!(object_chunks[0].metadata.file_type, FileType::Json);
    }

    // Content that fails CSV validation is re-split as plain text.
    #[test]
    fn test_invalid_csv_fallback() {
        let splitter = FileSplitter::new();

        let invalid_csv = "a,b,c\nthis is not csv\njust random text";
        let chunks = splitter.split_with_type(invalid_csv.as_bytes(), FileType::Csv).unwrap();

        assert_eq!(chunks[0].metadata.file_type, FileType::Text);
    }

    // Binary input is rejected with an error rather than chunked.
    #[test]
    fn test_binary_file_rejection() {
        let splitter = FileSplitter::new();

        let binary_data = &[0xFF, 0xFE, 0x00, 0x01, 0x02, 0x03];
        assert!(splitter.split_with_type(binary_data, FileType::Binary).is_err());
    }

    // Empty input produces no chunks at all.
    #[test]
    fn test_empty_file() {
        let splitter = FileSplitter::new();
        let chunks = splitter.split_file(Path::new("empty.txt"), b"").unwrap();
        assert!(chunks.is_empty());
    }
}

#[cfg(test)]
mod integration_tests {
    use super::*;
    use std::path::Path;

    /// A multi-chunk text split keeps every chunk non-empty, bounded in size,
    /// and preserves (almost) all original lines across the chunk set.
    #[test]
    fn test_large_text_file_splitting() {
        let large_text = (0..10000)
            .map(|i| format!("This is line number {} with some additional content to make it longer", i))
            .collect::<Vec<_>>()
            .join("\n");

        let config = SplitConfig {
            chunk_size_mb: 0.1,
            ..Default::default()
        };

        let splitter = FileSplitter::with_config(config);
        let chunks = splitter.split_file(Path::new("large.txt"), large_text.as_bytes()).unwrap();

        assert!(chunks.len() > 1);

        for chunk in &chunks {
            // Idiomatic emptiness check (clippy::len_zero) instead of `len() > 0`.
            assert!(!chunk.data.is_empty());
            // Upper bound with slack over the configured 0.1 MB chunk size.
            assert!(chunk.data.len() <= 120_000);
        }

        // Re-joining all chunks should recover nearly every original line.
        let reconstructed = chunks
            .iter()
            .map(|chunk| String::from_utf8_lossy(&chunk.data))
            .collect::<Vec<_>>()
            .join("\n");

        let original_lines: std::collections::HashSet<&str> = large_text.lines().collect();
        let reconstructed_lines: std::collections::HashSet<&str> = reconstructed.lines().collect();

        let preserved_ratio = original_lines.intersection(&reconstructed_lines).count() as f64
            / original_lines.len() as f64;
        assert!(preserved_ratio > 0.95);
    }

    /// Every CSV chunk repeats the header row and keeps rows structurally intact.
    #[test]
    fn test_large_csv_file_splitting() {
        let mut csv_data = "id,name,email,department,salary\n".to_string();
        for i in 0..5000 {
            csv_data.push_str(&format!(
                "{},Employee {},employee{}@company.com,Dept{},{}\n",
                i, i, i, i % 10, 50000 + (i % 1000) * 100
            ));
        }

        let config = SplitConfig {
            chunk_size_mb: 0.05,
            ..Default::default()
        };

        let splitter = FileSplitter::with_config(config);
        let chunks = splitter.split_file(Path::new("employees.csv"), csv_data.as_bytes()).unwrap();

        assert!(chunks.len() > 1);

        for chunk in &chunks {
            assert!(chunk.metadata.has_headers);
            let text = String::from_utf8_lossy(&chunk.data);
            let lines: Vec<&str> = text.lines().collect();
            // Idiomatic emptiness check instead of `len() >= 1`.
            assert!(!lines.is_empty());
            assert_eq!(lines[0], "id,name,email,department,salary");

            // Every data row keeps exactly the 5 header columns.
            for line in &lines[1..] {
                let fields: Vec<&str> = line.split(',').collect();
                assert_eq!(fields.len(), 5);
            }
        }
    }

    /// Small samples of each supported format produce a single, correctly-typed chunk.
    #[test]
    fn test_mixed_file_types() {
        let splitter = FileSplitter::new();

        let test_cases = vec![
            ("data.txt", "Simple text file\nwith multiple lines\nof content", FileType::Text),
            ("data.csv", "a,b,c\n1,2,3\n4,5,6", FileType::Csv),
            ("data.tsv", "a\tb\tc\n1\t2\t3\n4\t5\t6", FileType::Tsv),
            ("data.json", r#"{"name": "test", "value": 123}"#, FileType::Json),
            ("data.xml", "<root><item>value</item></root>", FileType::Xml),
        ];

        for (filename, content, expected_type) in test_cases {
            let chunks = splitter.split_file(Path::new(filename), content.as_bytes()).unwrap();
            assert_eq!(chunks.len(), 1);
            assert_eq!(chunks[0].metadata.file_type, expected_type);
            assert_eq!(chunks[0].chunk_id, 0);
        }
    }

    /// Edge cases: one very long line, a header-only CSV, and blank lines.
    #[test]
    fn test_edge_cases() {
        let splitter = FileSplitter::new();

        // A single 1 MB line without newlines stays in one chunk.
        let long_line = "a".repeat(1_000_000);
        let chunks = splitter.split_file(Path::new("long.txt"), long_line.as_bytes()).unwrap();
        assert_eq!(chunks.len(), 1);

        // A CSV consisting of only a header row.
        let header_only = "name,age,city";
        let chunks = splitter.split_file(Path::new("headers.csv"), header_only.as_bytes()).unwrap();
        assert_eq!(chunks.len(), 1);
        assert!(chunks[0].metadata.has_headers);

        // Blank lines count as units: 6 lines total, empties included.
        let with_empty_lines = "line1\n\n\nline2\n\nline3";
        let chunks = splitter.split_file(Path::new("gaps.txt"), with_empty_lines.as_bytes()).unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].metadata.unit_count, Some(6));
    }

    /// Splitting ~50k lines finishes quickly and preserves total byte volume.
    #[test]
    fn test_performance_characteristics() {
        let splitter = FileSplitter::new();

        let large_content = (0..50000)
            .map(|i| format!("Line {} with some content that makes it reasonably long", i))
            .collect::<Vec<_>>()
            .join("\n");

        let start = std::time::Instant::now();
        let chunks = splitter.split_file(Path::new("performance.txt"), large_content.as_bytes()).unwrap();
        let duration = start.elapsed();

        // Generous wall-clock bound; guards only against pathological slowness.
        assert!(duration.as_secs() < 5);
        assert!(chunks.len() > 1);

        // Total chunk bytes stay within 5% of the original size.
        let total_chunk_size: usize = chunks.iter().map(|c| c.data.len()).sum();
        let original_size = large_content.len();

        let size_ratio = total_chunk_size as f64 / original_size as f64;
        assert!(size_ratio > 0.95 && size_ratio < 1.05);
    }
}