// rustframes/dataframe/io.rs

use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter};

use csv::{ReaderBuilder, WriterBuilder};

use super::{DataFrame, Series};
use crate::dataframe::core::SeriesType;

/// Error returned by [`DataFrame::parse_bool`] when a string is not one of the
/// recognised boolean spellings.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// boxed into `Box<dyn std::error::Error>` and propagated with `?`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BoolParseError;

impl std::fmt::Display for BoolParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("invalid boolean literal")
    }
}

impl std::error::Error for BoolParseError {}

11impl DataFrame {
12    /// Read CSV with automatic type inference
13    pub fn from_csv(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
14        Self::from_csv_with_options(path, CsvReadOptions::default())
15    }
16
17    /// Read CSV with custom options
18    pub fn from_csv_with_options(
19        path: &str,
20        options: CsvReadOptions,
21    ) -> Result<Self, Box<dyn std::error::Error>> {
22        let file = File::open(path)?;
23        let mut rdr = ReaderBuilder::new()
24            .delimiter(options.delimiter)
25            .has_headers(options.has_headers)
26            .from_reader(BufReader::new(file));
27
28        let headers = if options.has_headers {
29            rdr.headers()?.clone()
30        } else {
31            // Generate default column names
32            csv::StringRecord::from(
33                (0..rdr.headers()?.len())
34                    .map(|i| format!("column_{}", i))
35                    .collect::<Vec<_>>(),
36            )
37        };
38
39        // First pass: collect all data as strings and infer types
40        let mut raw_data: Vec<Vec<String>> = vec![Vec::new(); headers.len()];
41        for result in rdr.records() {
42            let record = result?;
43            for (i, field) in record.iter().enumerate() {
44                if i < raw_data.len() {
45                    raw_data[i].push(field.to_string());
46                }
47            }
48        }
49
50        // Infer column types
51        let mut column_types = Vec::new();
52        for col_data in &raw_data {
53            column_types.push(Self::infer_column_type(col_data));
54        }
55
56        // Convert to appropriate Series types
57        let mut series_data = Vec::new();
58        for (i, col_data) in raw_data.into_iter().enumerate() {
59            let series = match column_types[i] {
60                SeriesType::Int64 => {
61                    let parsed: Result<Vec<i64>, _> =
62                        col_data.iter().map(|s| s.trim().parse::<i64>()).collect();
63                    match parsed {
64                        Ok(values) => Series::Int64(values),
65                        Err(_) => Series::Utf8(col_data), // Fallback to string
66                    }
67                }
68                SeriesType::Float64 => {
69                    let parsed: Result<Vec<f64>, _> =
70                        col_data.iter().map(|s| s.trim().parse::<f64>()).collect();
71                    match parsed {
72                        Ok(values) => Series::Float64(values),
73                        Err(_) => Series::Utf8(col_data), // Fallback to string
74                    }
75                }
76                SeriesType::Bool => {
77                    let parsed: Result<Vec<bool>, _> = col_data
78                        .iter()
79                        .map(|s| Self::parse_bool(s.trim()))
80                        .collect();
81                    match parsed {
82                        Ok(values) => Series::Bool(values),
83                        Err(_) => Series::Utf8(col_data), // Fallback to string
84                    }
85                }
86                SeriesType::Utf8 => Series::Utf8(col_data),
87            };
88            series_data.push(series);
89        }
90
91        let column_names: Vec<String> = headers.iter().map(|h| h.to_string()).collect();
92        Ok(DataFrame::new(
93            column_names.into_iter().zip(series_data).collect(),
94        ))
95    }
96
97    /// Write DataFrame to CSV
98    pub fn to_csv(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
99        self.to_csv_with_options(path, CsvWriteOptions::default())
100    }
101
102    /// Write DataFrame to CSV with custom options
103    pub fn to_csv_with_options(
104        &self,
105        path: &str,
106        options: CsvWriteOptions,
107    ) -> Result<(), Box<dyn std::error::Error>> {
108        let file = File::create(path)?;
109        let mut wtr = WriterBuilder::new()
110            .delimiter(options.delimiter)
111            .from_writer(BufWriter::new(file));
112
113        // Write headers
114        if options.write_headers {
115            wtr.write_record(&self.columns)?;
116        }
117
118        // Write data rows
119        for row_idx in 0..self.len() {
120            let mut record = Vec::new();
121            for series in &self.data {
122                let value = match series {
123                    Series::Int64(v) => v[row_idx].to_string(),
124                    Series::Float64(v) => {
125                        if options.float_precision > 0 {
126                            format!("{:.prec$}", v[row_idx], prec = options.float_precision)
127                        } else {
128                            v[row_idx].to_string()
129                        }
130                    }
131                    Series::Bool(v) => v[row_idx].to_string(),
132                    Series::Utf8(v) => v[row_idx].clone(),
133                };
134                record.push(value);
135            }
136            wtr.write_record(&record)?;
137        }
138
139        wtr.flush()?;
140        Ok(())
141    }
142
143    /// Read from JSON Lines format
144    pub fn from_jsonl(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
145        use std::fs;
146        let content = fs::read_to_string(path)?;
147
148        let mut all_columns: std::collections::HashSet<String> = std::collections::HashSet::new();
149        let mut records: Vec<HashMap<String, serde_json::Value>> = Vec::new();
150
151        // Parse each line and collect all possible column names
152        for line in content.lines() {
153            if line.trim().is_empty() {
154                continue;
155            }
156
157            let record: HashMap<String, serde_json::Value> = serde_json::from_str(line)?;
158            for key in record.keys() {
159                all_columns.insert(key.clone());
160            }
161            records.push(record);
162        }
163
164        let columns: Vec<String> = all_columns.into_iter().collect();
165        let mut column_data: HashMap<String, Vec<String>> = HashMap::new();
166
167        // Initialize column data
168        for col in &columns {
169            column_data.insert(col.clone(), Vec::new());
170        }
171
172        // Fill data, handling missing values
173        for record in records {
174            for col in &columns {
175                let value = match record.get(col) {
176                    Some(serde_json::Value::String(s)) => s.clone(),
177                    Some(serde_json::Value::Number(n)) => n.to_string(),
178                    Some(serde_json::Value::Bool(b)) => b.to_string(),
179                    Some(serde_json::Value::Null) => "".to_string(),
180                    Some(_) => "".to_string(), // Arrays, objects -> empty string
181                    None => "".to_string(),    // Missing field
182                };
183                column_data.get_mut(col).unwrap().push(value);
184            }
185        }
186
187        // Convert to Series with type inference
188        let mut series_data = Vec::new();
189        let mut final_columns = Vec::new();
190
191        for col in columns {
192            let col_values = column_data.remove(&col).unwrap();
193            let col_type = Self::infer_column_type(&col_values);
194
195            let series = match col_type {
196                SeriesType::Int64 => {
197                    let parsed: Vec<i64> = col_values
198                        .iter()
199                        .map(|s| s.trim().parse::<i64>().unwrap_or(0))
200                        .collect();
201                    Series::Int64(parsed)
202                }
203                SeriesType::Float64 => {
204                    let parsed: Vec<f64> = col_values
205                        .iter()
206                        .map(|s| s.trim().parse::<f64>().unwrap_or(0.0))
207                        .collect();
208                    Series::Float64(parsed)
209                }
210                SeriesType::Bool => {
211                    let parsed: Vec<bool> = col_values
212                        .iter()
213                        .map(|s| Self::parse_bool(s.trim()).unwrap_or(false))
214                        .collect();
215                    Series::Bool(parsed)
216                }
217                SeriesType::Utf8 => Series::Utf8(col_values),
218            };
219
220            final_columns.push(col);
221            series_data.push(series);
222        }
223
224        Ok(DataFrame::new(
225            final_columns.into_iter().zip(series_data).collect(),
226        ))
227    }
228
229    /// Write DataFrame to JSON Lines format
230    pub fn to_jsonl(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
231        use std::fs::File;
232        use std::io::Write;
233
234        let mut file = File::create(path)?;
235
236        for row_idx in 0..self.len() {
237            let mut record = serde_json::Map::new();
238
239            for (col_idx, col_name) in self.columns.iter().enumerate() {
240                let value = match &self.data[col_idx] {
241                    Series::Int64(v) => {
242                        serde_json::Value::Number(serde_json::Number::from(v[row_idx]))
243                    }
244                    Series::Float64(v) => {
245                        if let Some(n) = serde_json::Number::from_f64(v[row_idx]) {
246                            serde_json::Value::Number(n)
247                        } else {
248                            serde_json::Value::Null
249                        }
250                    }
251                    Series::Bool(v) => serde_json::Value::Bool(v[row_idx]),
252                    Series::Utf8(v) => serde_json::Value::String(v[row_idx].clone()),
253                };
254                record.insert(col_name.clone(), value);
255            }
256
257            let line = serde_json::to_string(&record)?;
258            writeln!(file, "{}", line)?;
259        }
260
261        Ok(())
262    }
263
264    /// Write DataFrame to regular JSON format (array of objects)
265    pub fn to_json(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
266        use std::fs::File;
267        use std::io::Write;
268
269        let mut records = Vec::new();
270
271        for row_idx in 0..self.len() {
272            let mut record = serde_json::Map::new();
273
274            for (col_idx, col_name) in self.columns.iter().enumerate() {
275                let value = match &self.data[col_idx] {
276                    Series::Int64(v) => {
277                        serde_json::Value::Number(serde_json::Number::from(v[row_idx]))
278                    }
279                    Series::Float64(v) => {
280                        if let Some(n) = serde_json::Number::from_f64(v[row_idx]) {
281                            serde_json::Value::Number(n)
282                        } else {
283                            serde_json::Value::Null
284                        }
285                    }
286                    Series::Bool(v) => serde_json::Value::Bool(v[row_idx]),
287                    Series::Utf8(v) => serde_json::Value::String(v[row_idx].clone()),
288                };
289                record.insert(col_name.clone(), value);
290            }
291
292            records.push(serde_json::Value::Object(record));
293        }
294
295        let json_array = serde_json::Value::Array(records);
296        let mut file = File::create(path)?;
297        writeln!(file, "{}", serde_json::to_string_pretty(&json_array)?)?;
298
299        Ok(())
300    }
301
302    /// Infer the type of a column from string data
303    pub fn infer_column_type(data: &[String]) -> SeriesType {
304        if data.is_empty() {
305            return SeriesType::Utf8;
306        }
307
308        let mut int_count = 0;
309        let mut float_count = 0;
310        let mut bool_count = 0;
311        let total = data.len();
312
313        for value in data {
314            let trimmed = value.trim();
315            if trimmed.is_empty() {
316                continue;
317            }
318
319            if trimmed.parse::<i64>().is_ok() {
320                int_count += 1;
321            } else if trimmed.parse::<f64>().is_ok() {
322                float_count += 1;
323            } else if Self::parse_bool(trimmed).is_ok() {
324                bool_count += 1;
325            }
326        }
327
328        let threshold = (total as f64 * 0.8).ceil() as usize; // 80% threshold
329
330        if bool_count >= threshold {
331            SeriesType::Bool
332        } else if int_count >= threshold {
333            SeriesType::Int64
334        } else if (int_count + float_count) >= threshold {
335            SeriesType::Float64
336        } else {
337            SeriesType::Utf8
338        }
339    }
340
341    /// Parse boolean from string
342    pub fn parse_bool(s: &str) -> Result<bool, BoolParseError> {
343        match s.to_lowercase().as_str() {
344            "true" | "t" | "yes" | "y" | "1" => Ok(true),
345            "false" | "f" | "no" | "n" | "0" => Ok(false),
346            _ => Err(BoolParseError),
347        }
348    }
349}
350
/// Settings for [`DataFrame::from_csv_with_options`].
///
/// By default input is comma-separated and the first row holds column names.
#[derive(Debug, Clone)]
pub struct CsvReadOptions {
    /// Byte separating fields within a record (e.g. `b','`, `b'\t'`, `b';'`).
    pub delimiter: u8,
    /// Whether the first row holds column names. When `false`, columns are
    /// named `column_0`, `column_1`, ...
    pub has_headers: bool,
}

impl Default for CsvReadOptions {
    fn default() -> Self {
        Self {
            has_headers: true,
            delimiter: b',',
        }
    }
}

/// Settings for [`DataFrame::to_csv_with_options`].
///
/// By default output is comma-separated and starts with a header row, and
/// floats get no fixed precision.
#[derive(Debug, Clone)]
pub struct CsvWriteOptions {
    /// Byte written between fields (e.g. `b','`, `b'\t'`, `b';'`).
    pub delimiter: u8,
    /// Whether to emit the column names as the first record.
    pub write_headers: bool,
    /// Digits after the decimal point for `Float64` cells. `0` disables
    /// fixed-precision formatting, so exactly zero decimals cannot be requested.
    pub float_precision: usize,
}

impl Default for CsvWriteOptions {
    fn default() -> Self {
        let float_precision = 0; // sentinel: no fixed-precision formatting
        Self {
            float_precision,
            write_headers: true,
            delimiter: b',',
        }
    }
}