1extern crate csv;
42extern crate either;
43#[cfg(test)]
44#[macro_use]
45extern crate matches;
46#[cfg_attr(test, macro_use(array))]
47extern crate ndarray;
48extern crate serde;
49
50use csv::{Reader, Writer};
51use either::Either;
52use ndarray::iter::Iter;
53use ndarray::{Array1, Array2, Dim};
54use serde::de::DeserializeOwned;
55use serde::{Serialize, Serializer};
56use std::cell::Cell;
57use std::error::Error;
58use std::fmt::{Display, Formatter};
59use std::io::{Read, Write};
60use std::iter::once;
61
62pub trait Array2Reader {
64 fn deserialize_array2<A: DeserializeOwned>(
66 self,
67 shape: (usize, usize),
68 ) -> Result<Array2<A>, ReadError>;
69
70 fn deserialize_array2_dynamic<A: DeserializeOwned>(self) -> Result<Array2<A>, ReadError>;
71}
72
73#[derive(Debug)]
74pub enum ReadError {
75 Csv(csv::Error),
76 NRows {
77 expected: usize,
78 actual: usize,
79 },
80 NColumns {
81 at_row_index: usize,
82 expected: usize,
83 actual: usize,
84 },
85}
86
87impl Display for ReadError {
88 fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> {
89 match self {
90 ReadError::Csv(csv_error) => csv_error.fmt(f),
91 ReadError::NRows { expected, actual } => {
92 write!(f, "Expected {} rows but got {} rows", expected, actual)
93 }
94 ReadError::NColumns {
95 at_row_index,
96 expected,
97 actual,
98 } => write!(
99 f,
100 "On row {}, expected {} columns but got {} columns",
101 at_row_index, expected, actual
102 ),
103 }
104 }
105}
106
107impl Error for ReadError {}
108
109impl<'a, R: Read> Array2Reader for &'a mut Reader<R> {
110 fn deserialize_array2<A: DeserializeOwned>(
111 self,
112 shape: (usize, usize),
113 ) -> Result<Array2<A>, ReadError> {
114 let (n_rows, n_columns) = shape;
115
116 let rows = self.deserialize::<Vec<A>>();
117 let values = rows.enumerate().flat_map(|(row_index, row)| match row {
118 Err(e) => Either::Left(once(Err(ReadError::Csv(e)))),
119 Ok(row_vec) => Either::Right(if row_vec.len() == n_columns {
120 Either::Right(row_vec.into_iter().map(Ok))
121 } else {
122 Either::Left(once(Err(ReadError::NColumns {
123 at_row_index: row_index,
124 expected: n_columns,
125 actual: row_vec.len(),
126 })))
127 }),
128 });
129 let array1_result: Result<Array1<A>, _> = values.collect();
130 array1_result.and_then(|array1| {
131 let array1_len = array1.len();
132 #[allow(deprecated)]
133 array1.into_shape(shape).map_err(|_| ReadError::NRows {
134 expected: n_rows,
135 actual: array1_len / n_columns,
136 })
137 })
138 }
139
140 fn deserialize_array2_dynamic<A: DeserializeOwned>(self) -> Result<Array2<A>, ReadError> {
141 let mut row_count = 0;
142 let mut last_columns = None;
143
144 let rows = self.deserialize::<Vec<A>>();
145 let values = rows.enumerate().flat_map(|(row_index, row)| {
146 row_count += 1;
147 match row {
148 Err(e) => Either::Left(once(Err(ReadError::Csv(e)))),
149 Ok(row_vec) => {
150 if let Some(last_columns) = last_columns {
151 if last_columns != row_vec.len() {
152 return Either::Right(Either::Left(once(Err(ReadError::NColumns {
153 at_row_index: row_index,
154 expected: last_columns,
155 actual: row_vec.len(),
156 }))));
157 }
158 };
159 last_columns = Some(row_vec.len());
160 Either::Right(Either::Right(row_vec.into_iter().map(Ok)))
161 }
162 }
163 });
164 let array1_result: Result<Array1<A>, _> = values.collect();
165 array1_result.map(|array1| {
166 #[allow(deprecated)]
167 array1
168 .into_shape((row_count, last_columns.unwrap_or(0)))
169 .unwrap()
170 })
171 }
172}
173
174pub trait Array2Writer {
176 fn serialize_array2<A: Serialize>(self, array: &Array2<A>) -> Result<(), csv::Error>;
178}
179
180impl<'a, W: Write> Array2Writer for &'a mut Writer<W> {
181 fn serialize_array2<A: Serialize>(self, array: &Array2<A>) -> Result<(), csv::Error> {
182 struct Row1DIter<'b, B>(Cell<Option<Iter<'b, B, Dim<[usize; 1]>>>>);
189
190 impl<'b, B> Serialize for Row1DIter<'b, B>
191 where
192 B: Serialize,
193 {
194 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
195 where
196 S: Serializer,
197 {
198 serializer.collect_seq(self.0.take().unwrap())
199 }
200 }
201
202 for row in array.outer_iter() {
203 self.serialize(Row1DIter(Cell::new(Some(row.iter()))))?;
204 }
205 self.flush()?;
206 Ok(())
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::ReadError::*;
    use super::*;
    use csv::{Reader, ReaderBuilder, WriterBuilder};
    use std::io::Cursor;

    /// Builds a header-less CSV reader over an in-memory string.
    fn in_memory_reader(content: &'static str) -> Reader<impl Read> {
        ReaderBuilder::new()
            .has_headers(false)
            .from_reader(Cursor::new(content))
    }

    /// A 2 x 3 fixture with rows `1,2,3` and `4,5,6`.
    fn test_reader() -> Reader<impl Read> {
        in_memory_reader("1,2,3\n4,5,6\n")
    }

    #[test]
    fn test_read_float() {
        let actual: Array2<f64> = test_reader().deserialize_array2((2, 3)).unwrap();
        let expected = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_read_integer() {
        let actual: Array2<u64> = test_reader().deserialize_array2((2, 3)).unwrap();
        let expected = array![[1, 2, 3], [4, 5, 6]];
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_read_dynamic() {
        let actual: Array2<u64> = test_reader().deserialize_array2_dynamic().unwrap();
        let expected = array![[1, 2, 3], [4, 5, 6]];
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_read_dynamic_empty() {
        let actual: Array2<u64> = in_memory_reader("").deserialize_array2_dynamic().unwrap();
        assert_eq!(actual.shape(), &[0, 0]);
    }

    #[test]
    fn test_read_dynamic_ragged_rows() {
        // A flexible reader lets unequal-length records through the csv crate,
        // so the column check in `deserialize_array2_dynamic` must catch them.
        let mut reader = ReaderBuilder::new()
            .has_headers(false)
            .flexible(true)
            .from_reader(Cursor::new("1,2,3\n4,5\n"));
        assert_matches! {
            reader.deserialize_array2_dynamic::<i8>().unwrap_err(),
            NColumns { at_row_index: 1, expected: 3, actual: 2 }
        }
    }

    #[test]
    fn test_read_csv_error() {
        // `x` is not a valid `i8`, which must surface as the CSV variant.
        assert_matches! {
            in_memory_reader("1,2,3\n4,x,6\n")
                .deserialize_array2::<i8>((2, 3))
                .unwrap_err(),
            Csv(_)
        }
    }

    #[test]
    fn test_read_too_few_rows() {
        assert_matches! {
            test_reader().deserialize_array2::<i8>((3, 3)).unwrap_err(),
            NRows { expected: 3, actual: 2 }
        }
    }

    #[test]
    fn test_read_too_many_rows() {
        assert_matches! {
            test_reader().deserialize_array2::<i8>((1, 3)).unwrap_err(),
            NRows { expected: 1, actual: 2 }
        }
    }

    #[test]
    fn test_read_too_few_columns() {
        assert_matches! {
            test_reader().deserialize_array2::<i8>((2, 4)).unwrap_err(),
            NColumns { at_row_index: 0, expected: 4, actual: 3 }
        }
    }

    #[test]
    fn test_read_too_many_columns() {
        assert_matches! {
            test_reader().deserialize_array2::<i8>((2, 2)).unwrap_err(),
            NColumns { at_row_index: 0, expected: 2, actual: 3 }
        }
    }

    #[test]
    fn test_write_ok() {
        let mut writer = WriterBuilder::new().has_headers(false).from_writer(vec![]);

        assert_matches! {
            writer.serialize_array2(&array![[1, 2, 3], [4, 5, 6]]),
            Ok(())
        }
        assert_eq!(
            writer.into_inner().expect("flush failed"),
            b"1,2,3\n4,5,6\n"
        );
    }

    #[test]
    fn test_write_transposed() {
        let mut writer = WriterBuilder::new().has_headers(false).from_writer(vec![]);

        assert_matches! {
            writer.serialize_array2(&array![[1, 4], [2, 5], [3, 6]].t().to_owned()),
            Ok(())
        }

        assert_eq!(
            writer.into_inner().expect("flush failed"),
            b"1,2,3\n4,5,6\n"
        );
    }

    #[test]
    fn test_write_err() {
        // The 8-byte destination is too small for the 12-byte output.
        let destination: &mut [u8] = &mut [0; 8];
        let mut writer = WriterBuilder::new()
            .has_headers(false)
            .from_writer(Cursor::new(destination));

        assert_matches! {
            writer.serialize_array2(&array![[1, 2, 3], [4, 5, 6]]),
            Err(_)
        }
    }
}