1use std::fs::File;
4use std::io::BufReader;
5
6use kyu_common::{KyuError, KyuResult};
7use kyu_types::{LogicalType, TypedValue};
8
9use crate::{DataReader, parse_field};
10
11pub struct CsvReader {
13 reader: csv::Reader<BufReader<File>>,
14 schema: Vec<LogicalType>,
15}
16
17impl CsvReader {
18 pub fn open(path: &str, schema: &[LogicalType]) -> KyuResult<Self> {
20 let file =
21 File::open(path).map_err(|e| KyuError::Copy(format!("cannot open '{path}': {e}")))?;
22 let reader = csv::ReaderBuilder::new()
23 .has_headers(true)
24 .from_reader(BufReader::new(file));
25 Ok(Self {
26 reader,
27 schema: schema.to_vec(),
28 })
29 }
30}
31
32impl DataReader for CsvReader {
33 fn schema(&self) -> &[LogicalType] {
34 &self.schema
35 }
36}
37
38impl Iterator for CsvReader {
39 type Item = KyuResult<Vec<TypedValue>>;
40
41 fn next(&mut self) -> Option<Self::Item> {
42 let record = match self.reader.records().next()? {
43 Ok(r) => r,
44 Err(e) => return Some(Err(KyuError::Copy(format!("CSV parse error: {e}")))),
45 };
46
47 let mut values = Vec::with_capacity(self.schema.len());
48 for (i, ty) in self.schema.iter().enumerate() {
49 let field = record.get(i).unwrap_or("");
50 match parse_field(field, ty) {
51 Ok(v) => values.push(v),
52 Err(e) => return Some(Err(e)),
53 }
54 }
55 Some(Ok(values))
56 }
57}
58
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::path::{Path, PathBuf};

    /// Scratch directory under the system temp dir, removed when dropped so
    /// that a failing assertion does not leave files behind.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        /// Creates the directory. The process id is part of the name, so
        /// separate concurrent test runs do not share (or delete) each
        /// other's files.
        fn new(tag: &str) -> Self {
            let dir = std::env::temp_dir().join(format!("kyu_{tag}_{}", std::process::id()));
            std::fs::create_dir_all(&dir).unwrap();
            Self(dir)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup: an error here must not mask a test failure.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Writes `content` to `dir/name` and returns the file path as a `String`.
    fn write_csv(dir: &Path, name: &str, content: &str) -> String {
        let path = dir.join(name);
        let mut f = File::create(&path).unwrap();
        f.write_all(content.as_bytes()).unwrap();
        path.to_str().unwrap().to_string()
    }

    #[test]
    fn read_csv_basic() {
        let dir = ScratchDir::new("csv_reader_test");
        let path = write_csv(dir.path(), "basic.csv", "id,name\n1,Alice\n2,Bob\n");

        let schema = vec![LogicalType::Int64, LogicalType::String];
        let reader = CsvReader::open(&path, &schema).unwrap();
        let rows: Vec<_> = reader.collect::<Result<Vec<_>, _>>().unwrap();

        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0][0], TypedValue::Int64(1));
        assert_eq!(
            rows[0][1],
            TypedValue::String(smol_str::SmolStr::new("Alice"))
        );
        assert_eq!(rows[1][0], TypedValue::Int64(2));
        assert_eq!(
            rows[1][1],
            TypedValue::String(smol_str::SmolStr::new("Bob"))
        );
    }

    #[test]
    fn read_csv_with_nulls() {
        let dir = ScratchDir::new("csv_null_test");
        let path = write_csv(dir.path(), "nulls.csv", "id,value\n1,\n2,42\n");

        let schema = vec![LogicalType::Int64, LogicalType::Int64];
        let reader = CsvReader::open(&path, &schema).unwrap();
        let rows: Vec<_> = reader.collect::<Result<Vec<_>, _>>().unwrap();

        assert_eq!(rows[0][1], TypedValue::Null);
        assert_eq!(rows[1][1], TypedValue::Int64(42));
    }

    #[test]
    fn open_missing_file_is_copy_error() {
        let dir = ScratchDir::new("csv_missing_test");
        let path = dir.path().join("does_not_exist.csv");

        let result = CsvReader::open(path.to_str().unwrap(), &[LogicalType::Int64]);

        assert!(matches!(result, Err(KyuError::Copy(_))));
    }
}