1use std::io::{BufReader, BufRead};
3
4use polars::frame::row::Row;
5use polars::prelude::{AnyValue, DataFrame, Field, Schema};
6
7use crate::format_specs::FortValue;
8use crate::{format_specs::FortFormat, serde_common::{DResult, DError}};
9
10
11pub fn read_to_dataframe<R: std::io::Read, S: AsRef<str>>(f: BufReader<R>, fmt: &FortFormat, colnames: &[S]) -> DResult<DataFrame> {
30 if fmt.non_pos_len() < colnames.len() {
31 return Err(DError::FormatSpecTooShort)
32 }
33
34 let col_iter = fmt.iter_non_pos_fields().zip(colnames.iter())
39 .filter_map(|(f, n)| {
40 if let Some(dt) = f.polars_dtype() {
41 Some(Field::new(n.as_ref(), dt))
42 } else {
43 None
44 }
45 });
46 let schema = Schema::from_iter(col_iter);
47
48 let mut rows = vec![];
49 for (line_num, line) in f.lines().enumerate() {
50 let line = line.map_err(|e| DError::TableReadError(e, line_num + 1))?;
51 let values: Vec<FortValue> = crate::de::from_str(&line, fmt)?;
52 let this_row: Vec<AnyValue> = values.into_iter().map(|v| v.into()).collect();
53 if this_row.len() != colnames.len() {
54 return Err(DError::TableLineEndedEarly { line_num: line_num + 1, ncol: colnames.len() })
55 }
56 rows.push(Row::new(this_row));
57 }
58
59
60 Ok(DataFrame::from_rows_and_schema(&rows, &schema).unwrap())
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::*;
    use stringreader::StringReader;

    /// Format specification string shared by both tests.
    const TABLE_FORMAT: &str = "(a5,1x,l1,1x,i4,1x,f4.1)";
    /// Column names passed to `read_to_dataframe` in both tests.
    const COLUMNS: [&str; 4] = ["Name", "Flag", "ID", "Score"];

    /// Build by hand the frame that `test_to_dataframe` expects to read back.
    fn expected_frame() -> DataFrame {
        let schema = Schema::from_iter([
            Field::new("Name", DataType::Utf8),
            Field::new("Flag", DataType::Boolean),
            Field::new("ID", DataType::Int64),
            Field::new("Score", DataType::Float64),
        ]);

        let rows = vec![
            Row::new(vec![
                AnyValue::Utf8Owned("Alpha".into()),
                AnyValue::Boolean(true),
                AnyValue::Int64(1234),
                AnyValue::Float64(9.5),
            ]),
            Row::new(vec![
                AnyValue::Utf8Owned("Beta".into()),
                AnyValue::Boolean(false),
                AnyValue::Int64(-678),
                AnyValue::Float64(-1.5),
            ]),
        ];

        DataFrame::from_rows_and_schema(&rows, &schema).unwrap()
    }

    /// A well-formed two-line table is read into the expected four columns.
    #[test]
    fn test_to_dataframe() -> DResult<()> {
        let reader = BufReader::new(StringReader::new("Alpha T 1234 9.5\nBeta F -678 -1.5"));
        let spec = FortFormat::parse(TABLE_FORMAT)?;
        let df = read_to_dataframe(reader, &spec, &COLUMNS)?;

        let expected = expected_frame();
        // Compare column by column, in the same order as the column list.
        for &name in COLUMNS.iter() {
            assert_eq!(df.column(name).unwrap(), expected.column(name).unwrap());
        }
        Ok(())
    }

    /// A first line that lacks its last field is reported as
    /// `TableLineEndedEarly` for line 1 with four expected columns.
    #[test]
    fn test_line_short() -> DResult<()> {
        let reader = BufReader::new(StringReader::new("Alpha T 1234\nBeta F -678 -1.5"));
        let spec = FortFormat::parse(TABLE_FORMAT)?;
        let failure = read_to_dataframe(reader, &spec, &COLUMNS).unwrap_err();

        match failure {
            DError::TableLineEndedEarly { line_num, ncol } => {
                assert_eq!(line_num, 1);
                assert_eq!(ncol, 4);
            }
            _ => panic!("Wrong error type"),
        }
        Ok(())
    }
}