rustframes/dataframe/
io.rs1use super::{DataFrame, Series};
2use crate::dataframe::core::SeriesType;
3use csv::{ReaderBuilder, WriterBuilder};
4use std::collections::HashMap;
5use std::fs::File;
6use std::io::{BufReader, BufWriter};
7
/// Error returned by [`DataFrame::parse_bool`] when the input is not one of the
/// recognised boolean tokens.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BoolParseError;

impl std::fmt::Display for BoolParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("invalid boolean value")
    }
}

// Lets the error propagate with `?` into `Box<dyn std::error::Error>`.
impl std::error::Error for BoolParseError {}

11impl DataFrame {
12 pub fn from_csv(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
14 Self::from_csv_with_options(path, CsvReadOptions::default())
15 }
16
17 pub fn from_csv_with_options(
19 path: &str,
20 options: CsvReadOptions,
21 ) -> Result<Self, Box<dyn std::error::Error>> {
22 let file = File::open(path)?;
23 let mut rdr = ReaderBuilder::new()
24 .delimiter(options.delimiter)
25 .has_headers(options.has_headers)
26 .from_reader(BufReader::new(file));
27
28 let headers = if options.has_headers {
29 rdr.headers()?.clone()
30 } else {
31 csv::StringRecord::from(
33 (0..rdr.headers()?.len())
34 .map(|i| format!("column_{}", i))
35 .collect::<Vec<_>>(),
36 )
37 };
38
39 let mut raw_data: Vec<Vec<String>> = vec![Vec::new(); headers.len()];
41 for result in rdr.records() {
42 let record = result?;
43 for (i, field) in record.iter().enumerate() {
44 if i < raw_data.len() {
45 raw_data[i].push(field.to_string());
46 }
47 }
48 }
49
50 let mut column_types = Vec::new();
52 for col_data in &raw_data {
53 column_types.push(Self::infer_column_type(col_data));
54 }
55
56 let mut series_data = Vec::new();
58 for (i, col_data) in raw_data.into_iter().enumerate() {
59 let series = match column_types[i] {
60 SeriesType::Int64 => {
61 let parsed: Result<Vec<i64>, _> =
62 col_data.iter().map(|s| s.trim().parse::<i64>()).collect();
63 match parsed {
64 Ok(values) => Series::Int64(values),
65 Err(_) => Series::Utf8(col_data), }
67 }
68 SeriesType::Float64 => {
69 let parsed: Result<Vec<f64>, _> =
70 col_data.iter().map(|s| s.trim().parse::<f64>()).collect();
71 match parsed {
72 Ok(values) => Series::Float64(values),
73 Err(_) => Series::Utf8(col_data), }
75 }
76 SeriesType::Bool => {
77 let parsed: Result<Vec<bool>, _> = col_data
78 .iter()
79 .map(|s| Self::parse_bool(s.trim()))
80 .collect();
81 match parsed {
82 Ok(values) => Series::Bool(values),
83 Err(_) => Series::Utf8(col_data), }
85 }
86 SeriesType::Utf8 => Series::Utf8(col_data),
87 };
88 series_data.push(series);
89 }
90
91 let column_names: Vec<String> = headers.iter().map(|h| h.to_string()).collect();
92 Ok(DataFrame::new(
93 column_names.into_iter().zip(series_data).collect(),
94 ))
95 }
96
97 pub fn to_csv(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
99 self.to_csv_with_options(path, CsvWriteOptions::default())
100 }
101
102 pub fn to_csv_with_options(
104 &self,
105 path: &str,
106 options: CsvWriteOptions,
107 ) -> Result<(), Box<dyn std::error::Error>> {
108 let file = File::create(path)?;
109 let mut wtr = WriterBuilder::new()
110 .delimiter(options.delimiter)
111 .from_writer(BufWriter::new(file));
112
113 if options.write_headers {
115 wtr.write_record(&self.columns)?;
116 }
117
118 for row_idx in 0..self.len() {
120 let mut record = Vec::new();
121 for series in &self.data {
122 let value = match series {
123 Series::Int64(v) => v[row_idx].to_string(),
124 Series::Float64(v) => {
125 if options.float_precision > 0 {
126 format!("{:.prec$}", v[row_idx], prec = options.float_precision)
127 } else {
128 v[row_idx].to_string()
129 }
130 }
131 Series::Bool(v) => v[row_idx].to_string(),
132 Series::Utf8(v) => v[row_idx].clone(),
133 };
134 record.push(value);
135 }
136 wtr.write_record(&record)?;
137 }
138
139 wtr.flush()?;
140 Ok(())
141 }
142
143 pub fn from_jsonl(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
145 use std::fs;
146 let content = fs::read_to_string(path)?;
147
148 let mut all_columns: std::collections::HashSet<String> = std::collections::HashSet::new();
149 let mut records: Vec<HashMap<String, serde_json::Value>> = Vec::new();
150
151 for line in content.lines() {
153 if line.trim().is_empty() {
154 continue;
155 }
156
157 let record: HashMap<String, serde_json::Value> = serde_json::from_str(line)?;
158 for key in record.keys() {
159 all_columns.insert(key.clone());
160 }
161 records.push(record);
162 }
163
164 let columns: Vec<String> = all_columns.into_iter().collect();
165 let mut column_data: HashMap<String, Vec<String>> = HashMap::new();
166
167 for col in &columns {
169 column_data.insert(col.clone(), Vec::new());
170 }
171
172 for record in records {
174 for col in &columns {
175 let value = match record.get(col) {
176 Some(serde_json::Value::String(s)) => s.clone(),
177 Some(serde_json::Value::Number(n)) => n.to_string(),
178 Some(serde_json::Value::Bool(b)) => b.to_string(),
179 Some(serde_json::Value::Null) => "".to_string(),
180 Some(_) => "".to_string(), None => "".to_string(), };
183 column_data.get_mut(col).unwrap().push(value);
184 }
185 }
186
187 let mut series_data = Vec::new();
189 let mut final_columns = Vec::new();
190
191 for col in columns {
192 let col_values = column_data.remove(&col).unwrap();
193 let col_type = Self::infer_column_type(&col_values);
194
195 let series = match col_type {
196 SeriesType::Int64 => {
197 let parsed: Vec<i64> = col_values
198 .iter()
199 .map(|s| s.trim().parse::<i64>().unwrap_or(0))
200 .collect();
201 Series::Int64(parsed)
202 }
203 SeriesType::Float64 => {
204 let parsed: Vec<f64> = col_values
205 .iter()
206 .map(|s| s.trim().parse::<f64>().unwrap_or(0.0))
207 .collect();
208 Series::Float64(parsed)
209 }
210 SeriesType::Bool => {
211 let parsed: Vec<bool> = col_values
212 .iter()
213 .map(|s| Self::parse_bool(s.trim()).unwrap_or(false))
214 .collect();
215 Series::Bool(parsed)
216 }
217 SeriesType::Utf8 => Series::Utf8(col_values),
218 };
219
220 final_columns.push(col);
221 series_data.push(series);
222 }
223
224 Ok(DataFrame::new(
225 final_columns.into_iter().zip(series_data).collect(),
226 ))
227 }
228
229 pub fn to_jsonl(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
231 use std::fs::File;
232 use std::io::Write;
233
234 let mut file = File::create(path)?;
235
236 for row_idx in 0..self.len() {
237 let mut record = serde_json::Map::new();
238
239 for (col_idx, col_name) in self.columns.iter().enumerate() {
240 let value = match &self.data[col_idx] {
241 Series::Int64(v) => {
242 serde_json::Value::Number(serde_json::Number::from(v[row_idx]))
243 }
244 Series::Float64(v) => {
245 if let Some(n) = serde_json::Number::from_f64(v[row_idx]) {
246 serde_json::Value::Number(n)
247 } else {
248 serde_json::Value::Null
249 }
250 }
251 Series::Bool(v) => serde_json::Value::Bool(v[row_idx]),
252 Series::Utf8(v) => serde_json::Value::String(v[row_idx].clone()),
253 };
254 record.insert(col_name.clone(), value);
255 }
256
257 let line = serde_json::to_string(&record)?;
258 writeln!(file, "{}", line)?;
259 }
260
261 Ok(())
262 }
263
264 pub fn to_json(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
266 use std::fs::File;
267 use std::io::Write;
268
269 let mut records = Vec::new();
270
271 for row_idx in 0..self.len() {
272 let mut record = serde_json::Map::new();
273
274 for (col_idx, col_name) in self.columns.iter().enumerate() {
275 let value = match &self.data[col_idx] {
276 Series::Int64(v) => {
277 serde_json::Value::Number(serde_json::Number::from(v[row_idx]))
278 }
279 Series::Float64(v) => {
280 if let Some(n) = serde_json::Number::from_f64(v[row_idx]) {
281 serde_json::Value::Number(n)
282 } else {
283 serde_json::Value::Null
284 }
285 }
286 Series::Bool(v) => serde_json::Value::Bool(v[row_idx]),
287 Series::Utf8(v) => serde_json::Value::String(v[row_idx].clone()),
288 };
289 record.insert(col_name.clone(), value);
290 }
291
292 records.push(serde_json::Value::Object(record));
293 }
294
295 let json_array = serde_json::Value::Array(records);
296 let mut file = File::create(path)?;
297 writeln!(file, "{}", serde_json::to_string_pretty(&json_array)?)?;
298
299 Ok(())
300 }
301
302 pub fn infer_column_type(data: &[String]) -> SeriesType {
304 if data.is_empty() {
305 return SeriesType::Utf8;
306 }
307
308 let mut int_count = 0;
309 let mut float_count = 0;
310 let mut bool_count = 0;
311 let total = data.len();
312
313 for value in data {
314 let trimmed = value.trim();
315 if trimmed.is_empty() {
316 continue;
317 }
318
319 if trimmed.parse::<i64>().is_ok() {
320 int_count += 1;
321 } else if trimmed.parse::<f64>().is_ok() {
322 float_count += 1;
323 } else if Self::parse_bool(trimmed).is_ok() {
324 bool_count += 1;
325 }
326 }
327
328 let threshold = (total as f64 * 0.8).ceil() as usize; if bool_count >= threshold {
331 SeriesType::Bool
332 } else if int_count >= threshold {
333 SeriesType::Int64
334 } else if (int_count + float_count) >= threshold {
335 SeriesType::Float64
336 } else {
337 SeriesType::Utf8
338 }
339 }
340
341 pub fn parse_bool(s: &str) -> Result<bool, BoolParseError> {
343 match s.to_lowercase().as_str() {
344 "true" | "t" | "yes" | "y" | "1" => Ok(true),
345 "false" | "f" | "no" | "n" | "0" => Ok(false),
346 _ => Err(BoolParseError),
347 }
348 }
349}
350
/// Options controlling how [`DataFrame::from_csv_with_options`] reads a CSV file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CsvReadOptions {
    /// Field separator byte (default `b','`).
    pub delimiter: u8,
    /// Whether the first row holds column names (default `true`). When `false`,
    /// columns are named `column_0`, `column_1`, ... and the first row is data.
    pub has_headers: bool,
}

impl Default for CsvReadOptions {
    /// Comma-delimited input with a header row.
    fn default() -> Self {
        Self {
            delimiter: b',',
            has_headers: true,
        }
    }
}

/// Options controlling how [`DataFrame::to_csv_with_options`] writes a CSV file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CsvWriteOptions {
    /// Field separator byte (default `b','`).
    pub delimiter: u8,
    /// Whether to emit a header row of column names (default `true`).
    pub write_headers: bool,
    /// Decimal places used for `Float64` values. `0` (the default) means Rust's
    /// default `f64` formatting rather than zero decimal places.
    pub float_precision: usize,
}

impl Default for CsvWriteOptions {
    /// Comma-delimited output with a header row and default float formatting.
    fn default() -> Self {
        Self {
            delimiter: b',',
            write_headers: true,
            float_precision: 0,
        }
    }
}