ndarray_csv/
lib.rs

1//! Easily read and write homogeneous CSV data to and from 2D ndarrays.
2//!
3//! ```rust
4//! extern crate csv;
5//! extern crate ndarray;
6//! extern crate ndarray_csv;
7//!
8//! use csv::{ReaderBuilder, WriterBuilder};
9//! use ndarray::{array, Array2};
10//! use ndarray_csv::{Array2Reader, Array2Writer};
11//! use std::error::Error;
12//! use std::fs::File;
13//!
14//! fn main() -> Result<(), Box<dyn Error>> {
15//!     // Our 2x3 test array
16//!     let array = array![[1, 2, 3], [4, 5, 6]];
17//!
18//!     // Write the array into the file.
19//!     {
20//!         let file = File::create("test.csv")?;
21//!         let mut writer = WriterBuilder::new().has_headers(false).from_writer(file);
22//!         writer.serialize_array2(&array)?;
23//!     }
24//!
25//!     // Read an array back from the file
26//!     let file = File::open("test.csv")?;
27//!     let mut reader = ReaderBuilder::new().has_headers(false).from_reader(file);
28//!     let array_read: Array2<u64> = reader.deserialize_array2((2, 3))?;
29//!
30//!     // Ensure that we got the original array back
31//!     assert_eq!(array_read, array);
32//!     Ok(())
33//! }
34//! ```
35//!
36//! This project uses [cargo-make](https://sagiegurari.github.io/cargo-make/) for builds; to build,
37//! run `cargo make all`.
38//!
//! To prevent denial-of-service attacks, do not read untrusted CSV streams of unbounded length;
//! cap the amount of input first, for example with `std::io::Read::take`.
41extern crate csv;
42extern crate either;
43#[cfg(test)]
44#[macro_use]
45extern crate matches;
46#[cfg_attr(test, macro_use(array))]
47extern crate ndarray;
48extern crate serde;
49
50use csv::{Reader, Writer};
51use either::Either;
52use ndarray::iter::Iter;
53use ndarray::{Array1, Array2, Dim};
54use serde::de::DeserializeOwned;
55use serde::{Serialize, Serializer};
56use std::cell::Cell;
57use std::error::Error;
58use std::fmt::{Display, Formatter};
59use std::io::{Read, Write};
60use std::iter::once;
61
62/// An extension trait; this is implemented by `&mut csv::Reader`
63pub trait Array2Reader {
64    /// Read CSV data into a new ndarray with the given shape
65    fn deserialize_array2<A: DeserializeOwned>(
66        self,
67        shape: (usize, usize),
68    ) -> Result<Array2<A>, ReadError>;
69
70    fn deserialize_array2_dynamic<A: DeserializeOwned>(self) -> Result<Array2<A>, ReadError>;
71}
72
73#[derive(Debug)]
74pub enum ReadError {
75    Csv(csv::Error),
76    NRows {
77        expected: usize,
78        actual: usize,
79    },
80    NColumns {
81        at_row_index: usize,
82        expected: usize,
83        actual: usize,
84    },
85}
86
87impl Display for ReadError {
88    fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> {
89        match self {
90            ReadError::Csv(csv_error) => csv_error.fmt(f),
91            ReadError::NRows { expected, actual } => {
92                write!(f, "Expected {} rows but got {} rows", expected, actual)
93            }
94            ReadError::NColumns {
95                at_row_index,
96                expected,
97                actual,
98            } => write!(
99                f,
100                "On row {}, expected {} columns but got {} columns",
101                at_row_index, expected, actual
102            ),
103        }
104    }
105}
106
107impl Error for ReadError {}
108
109impl<'a, R: Read> Array2Reader for &'a mut Reader<R> {
110    fn deserialize_array2<A: DeserializeOwned>(
111        self,
112        shape: (usize, usize),
113    ) -> Result<Array2<A>, ReadError> {
114        let (n_rows, n_columns) = shape;
115
116        let rows = self.deserialize::<Vec<A>>();
117        let values = rows.enumerate().flat_map(|(row_index, row)| match row {
118            Err(e) => Either::Left(once(Err(ReadError::Csv(e)))),
119            Ok(row_vec) => Either::Right(if row_vec.len() == n_columns {
120                Either::Right(row_vec.into_iter().map(Ok))
121            } else {
122                Either::Left(once(Err(ReadError::NColumns {
123                    at_row_index: row_index,
124                    expected: n_columns,
125                    actual: row_vec.len(),
126                })))
127            }),
128        });
129        let array1_result: Result<Array1<A>, _> = values.collect();
130        array1_result.and_then(|array1| {
131            let array1_len = array1.len();
132            #[allow(deprecated)]
133            array1.into_shape(shape).map_err(|_| ReadError::NRows {
134                expected: n_rows,
135                actual: array1_len / n_columns,
136            })
137        })
138    }
139
140    fn deserialize_array2_dynamic<A: DeserializeOwned>(self) -> Result<Array2<A>, ReadError> {
141        let mut row_count = 0;
142        let mut last_columns = None;
143
144        let rows = self.deserialize::<Vec<A>>();
145        let values = rows.enumerate().flat_map(|(row_index, row)| {
146            row_count += 1;
147            match row {
148                Err(e) => Either::Left(once(Err(ReadError::Csv(e)))),
149                Ok(row_vec) => {
150                    if let Some(last_columns) = last_columns {
151                        if last_columns != row_vec.len() {
152                            return Either::Right(Either::Left(once(Err(ReadError::NColumns {
153                                at_row_index: row_index,
154                                expected: last_columns,
155                                actual: row_vec.len(),
156                            }))));
157                        }
158                    };
159                    last_columns = Some(row_vec.len());
160                    Either::Right(Either::Right(row_vec.into_iter().map(Ok)))
161                }
162            }
163        });
164        let array1_result: Result<Array1<A>, _> = values.collect();
165        array1_result.map(|array1| {
166            #[allow(deprecated)]
167            array1
168                .into_shape((row_count, last_columns.unwrap_or(0)))
169                .unwrap()
170        })
171    }
172}
173
174/// An extension trait; this is implemented by `&mut csv::Writer`
175pub trait Array2Writer {
176    /// Write this ndarray into CSV format
177    fn serialize_array2<A: Serialize>(self, array: &Array2<A>) -> Result<(), csv::Error>;
178}
179
180impl<'a, W: Write> Array2Writer for &'a mut Writer<W> {
181    fn serialize_array2<A: Serialize>(self, array: &Array2<A>) -> Result<(), csv::Error> {
182        /// This wraps the iterator for a row so that we can implement Serialize.
183        ///
184        /// Serialize is not implemented for iterators: https://github.com/serde-rs/serde/issues/571
185        ///
186        /// This solution from Hyeonu wraps the iterator:
187        /// https://users.rust-lang.org/t/how-to-serialize-an-iterator-to-json/59272/3
188        struct Row1DIter<'b, B>(Cell<Option<Iter<'b, B, Dim<[usize; 1]>>>>);
189
190        impl<'b, B> Serialize for Row1DIter<'b, B>
191        where
192            B: Serialize,
193        {
194            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
195            where
196                S: Serializer,
197            {
198                serializer.collect_seq(self.0.take().unwrap())
199            }
200        }
201
202        for row in array.outer_iter() {
203            self.serialize(Row1DIter(Cell::new(Some(row.iter()))))?;
204        }
205        self.flush()?;
206        Ok(())
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::ReadError::*;
    use super::*;
    use csv::{Reader, ReaderBuilder, WriterBuilder};
    use std::io::Cursor;

    /// Builds a header-less CSV reader over an in-memory string.
    fn in_memory_reader(content: &'static str) -> Reader<impl Read> {
        ReaderBuilder::new()
            .has_headers(false)
            .from_reader(Cursor::new(content))
    }

    /// A reader over the 2x3 fixture shared by most of the read tests.
    fn test_reader() -> Reader<impl Read> {
        in_memory_reader("1,2,3\n4,5,6\n")
    }

    #[test]
    fn test_read_float() {
        let actual: Array2<f64> = test_reader().deserialize_array2((2, 3)).unwrap();
        let expected = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_read_integer() {
        let actual: Array2<u64> = test_reader().deserialize_array2((2, 3)).unwrap();
        let expected = array![[1, 2, 3], [4, 5, 6]];
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_read_dynamic() {
        let actual: Array2<u64> = test_reader().deserialize_array2_dynamic().unwrap();
        let expected = array![[1, 2, 3], [4, 5, 6]];
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_read_dynamic_empty() {
        // Without any records there is no row to take a column count from.
        let actual: Array2<u64> = in_memory_reader("").deserialize_array2_dynamic().unwrap();
        assert_eq!(actual.dim(), (0, 0));
    }

    #[test]
    fn test_read_dynamic_ragged_rows() {
        // A flexible reader lets rows of differing length through to our own check.
        let mut reader = ReaderBuilder::new()
            .has_headers(false)
            .flexible(true)
            .from_reader(Cursor::new("1,2,3\n4,5\n"));

        assert_matches! {
            reader.deserialize_array2_dynamic::<i8>().unwrap_err(),
            NColumns { at_row_index: 1, expected: 3, actual: 2 }
        }
    }

    #[test]
    fn test_read_csv_error() {
        // "x" is not a valid i8, so the failure must come from the CSV deserializer.
        assert_matches! {
            in_memory_reader("1,2,3\n4,x,6\n")
                .deserialize_array2::<i8>((2, 3))
                .unwrap_err(),
            Csv(_)
        }
    }

    #[test]
    fn test_read_too_few_rows() {
        assert_matches! {
            test_reader().deserialize_array2::<i8>((3, 3)).unwrap_err(),
            NRows { expected: 3, actual: 2 }
        }
    }

    #[test]
    fn test_read_too_many_rows() {
        assert_matches! {
            test_reader().deserialize_array2::<i8>((1, 3)).unwrap_err(),
            NRows { expected: 1, actual: 2 }
        }
    }

    #[test]
    fn test_read_too_few_columns() {
        assert_matches! {
            test_reader().deserialize_array2::<i8>((2, 4)).unwrap_err(),
            NColumns { at_row_index: 0, expected: 4, actual: 3 }
        }
    }

    #[test]
    fn test_read_too_many_columns() {
        assert_matches! {
            test_reader().deserialize_array2::<i8>((2, 2)).unwrap_err(),
            NColumns { at_row_index: 0, expected: 2, actual: 3 }
        }
    }

    #[test]
    fn test_write_ok() {
        let mut writer = WriterBuilder::new().has_headers(false).from_writer(vec![]);

        assert_matches! {
            writer.serialize_array2(&array![[1, 2, 3], [4, 5, 6]]),
            Ok(())
        }
        assert_eq!(
            writer.into_inner().expect("flush failed"),
            b"1,2,3\n4,5,6\n"
        );
    }

    #[test]
    fn test_write_transposed() {
        let mut writer = WriterBuilder::new().has_headers(false).from_writer(vec![]);

        // The transposed array must be written in logical row order.
        assert_matches! {
            writer.serialize_array2(&array![[1, 4], [2, 5], [3, 6]].t().to_owned()),
            Ok(())
        }

        assert_eq!(
            writer.into_inner().expect("flush failed"),
            b"1,2,3\n4,5,6\n"
        );
    }

    #[test]
    fn test_write_err() {
        let destination: &mut [u8] = &mut [0; 8];
        let mut writer = WriterBuilder::new()
            .has_headers(false)
            .from_writer(Cursor::new(destination));

        // The destination is too short
        assert_matches! {
            writer.serialize_array2(&array![[1, 2, 3], [4, 5, 6]]),
            Err(_)
        }
    }
}