1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
//! Easily read and write homogeneous CSV data to and from 2D ndarrays.
//!
//! ```rust
//! extern crate csv;
//! extern crate ndarray;
//! extern crate ndarray_csv;
//!
//! use csv::{ReaderBuilder, WriterBuilder};
//! use ndarray::{Array, Array2};
//! use ndarray_csv::{Array2Reader, Array2Writer};
//! use std::error::Error;
//! use std::fs::File;
//!
//! fn main() -> Result<(), Box<dyn Error>> {
//!     // Our 2x3 test array
//!     let array = Array::from(vec![1, 2, 3, 4, 5, 6]).into_shape((2, 3)).unwrap();
//!
//!     // Write the array into the file.
//!     {
//!         let file = File::create("test.csv")?;
//!         let mut writer = WriterBuilder::new().has_headers(false).from_writer(file);
//!         writer.serialize_array2(&array)?;
//!     }
//!
//!     // Read an array back from the file
//!     let file = File::open("test.csv")?;
//!     let mut reader = ReaderBuilder::new().has_headers(false).from_reader(file);
//!     let array_read: Array2<u64> = reader.deserialize_array2((2, 3))?;
//!
//!     // Ensure that we got the original array back
//!     assert_eq!(array_read, array);
//!     Ok(())
//! }
//! ```
//!
//! This project uses [cargo-make](https://sagiegurari.github.io/cargo-make/) for builds; to build,
//! run `cargo make all`.
//!
//! To prevent denial-of-service attacks, do not read in untrusted CSV streams of unbounded length;
//! bound the input first, for example by wrapping the source in `std::io::Read::take`.
extern crate csv;
extern crate either;
#[cfg(test)]
#[macro_use]
extern crate matches;
#[cfg_attr(test, macro_use(array))]
extern crate ndarray;
extern crate serde;

use csv::{Reader, Writer};
use either::Either;
use ndarray::iter::Iter;
use ndarray::{Array1, Array2, Dim};
use serde::de::DeserializeOwned;
use serde::{Serialize, Serializer};
use std::cell::Cell;
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::io::{Read, Write};
use std::iter::once;

/// An extension trait; this is implemented by `&mut csv::Reader`
///
/// Both methods take `self` by value. Because the implementor in this file is a mutable
/// reference, the reader itself is only borrowed for the duration of the call.
pub trait Array2Reader {
    /// Read CSV data into a new ndarray with the given shape
    ///
    /// `shape` is `(n_rows, n_columns)`.
    ///
    /// # Errors
    ///
    /// The implementation in this file returns `ReadError::Csv` when a row cannot be produced
    /// by the reader, `ReadError::NColumns` when a row's length differs from `n_columns`, and
    /// `ReadError::NRows` when the values read cannot be reshaped to `shape`.
    fn deserialize_array2<A: DeserializeOwned>(
        self,
        shape: (usize, usize),
    ) -> Result<Array2<A>, ReadError>;

    /// Read CSV data into a new ndarray whose shape is taken from the data itself
    ///
    /// In the implementation in this file, the number of rows is the number of rows yielded
    /// by the reader and the number of columns is the length of the first row (0 if there are
    /// no rows).
    ///
    /// # Errors
    ///
    /// The implementation in this file returns `ReadError::Csv` when a row cannot be produced
    /// by the reader, and `ReadError::NColumns` when a row's length differs from that of the
    /// first row.
    fn deserialize_array2_dynamic<A: DeserializeOwned>(self) -> Result<Array2<A>, ReadError>;
}

/// The ways in which reading CSV data into a 2D array can fail.
#[derive(Debug)]
pub enum ReadError {
    /// An error reported by the `csv` reader while producing a row.
    Csv(csv::Error),
    /// The values read could not be arranged into the requested number of rows.
    NRows {
        /// Number of rows requested by the caller.
        expected: usize,
        /// Number of rows found, computed as the number of values read divided by the
        /// requested number of columns.
        actual: usize,
    },
    /// A row did not have the expected number of columns.
    NColumns {
        /// Zero-based index of the offending row among the rows yielded by the reader.
        at_row_index: usize,
        /// Number of columns required: the caller-supplied column count, or the length of
        /// the first row when the shape is inferred from the data.
        expected: usize,
        /// Number of columns found in the offending row.
        actual: usize,
    },
}

impl Display for ReadError {
    /// Writes a human-readable description of the error.
    ///
    /// A `Csv` error is rendered by the wrapped `csv::Error` itself; the shape-mismatch
    /// variants state the expected and actual counts.
    fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> {
        match self {
            ReadError::NColumns {
                at_row_index: row,
                expected: wanted,
                actual: found,
            } => write!(
                f,
                "On row {}, expected {} columns but got {} columns",
                row, wanted, found
            ),
            ReadError::NRows {
                expected: wanted,
                actual: found,
            } => write!(f, "Expected {} rows but got {} rows", wanted, found),
            // Delegate so the message is exactly what the `csv` crate would print.
            ReadError::Csv(inner) => Display::fmt(inner, f),
        }
    }
}

impl Error for ReadError {}

impl<'a, R: Read> Array2Reader for &'a mut Reader<R> {
    /// Reads every row, checks each one against the column count in `shape`, and reshapes the
    /// collected values to `shape`.
    ///
    /// Reading stops at the first row that fails; that row's error is returned. If all rows
    /// are accepted but the values cannot be reshaped to `shape`, `ReadError::NRows` is
    /// returned.
    fn deserialize_array2<A: DeserializeOwned>(
        self,
        shape: (usize, usize),
    ) -> Result<Array2<A>, ReadError> {
        let (n_rows, n_columns) = shape;

        // Flatten the rows into one stream of cells. A row that cannot be used is replaced by
        // a single `Err`, which makes the `collect` below stop at that row.
        let cells = self
            .deserialize::<Vec<A>>()
            .enumerate()
            .flat_map(|(at_row_index, record)| {
                let row = match record {
                    Ok(row) => row,
                    Err(csv_error) => return Either::Left(once(Err(ReadError::Csv(csv_error)))),
                };
                let actual = row.len();
                if actual != n_columns {
                    return Either::Left(once(Err(ReadError::NColumns {
                        at_row_index,
                        expected: n_columns,
                        actual,
                    })));
                }
                Either::Right(row.into_iter().map(Ok))
            });

        let flat: Result<Array1<A>, ReadError> = cells.collect();
        let flat = flat?;
        let n_cells = flat.len();
        // Every accepted row has exactly `n_columns` cells, so a failed reshape is reported
        // as a row-count mismatch.
        flat.into_shape(shape).map_err(|_| ReadError::NRows {
            expected: n_rows,
            actual: n_cells / n_columns,
        })
    }

    /// Reads every row and infers the shape from the data: one array row per CSV row, with
    /// the column count taken from the first row.
    ///
    /// Reading stops at the first row that fails to deserialize or whose length differs from
    /// the first row's; that row's error is returned. With no rows at all the result has
    /// shape `(0, 0)`.
    fn deserialize_array2_dynamic<A: DeserializeOwned>(self) -> Result<Array2<A>, ReadError> {
        // Both are updated from inside the closure while the rows are being consumed.
        let mut rows_seen: usize = 0;
        let mut width: Option<usize> = None;

        let cells = self
            .deserialize::<Vec<A>>()
            .enumerate()
            .flat_map(|(at_row_index, record)| {
                rows_seen += 1;
                let row = match record {
                    Ok(row) => row,
                    Err(csv_error) => return Either::Left(once(Err(ReadError::Csv(csv_error)))),
                };
                let actual = row.len();
                // The first row has nothing to be compared against, so it always passes.
                let expected = width.unwrap_or(actual);
                if expected != actual {
                    return Either::Left(once(Err(ReadError::NColumns {
                        at_row_index,
                        expected,
                        actual,
                    })));
                }
                width = Some(actual);
                Either::Right(row.into_iter().map(Ok))
            });

        let flat: Result<Array1<A>, ReadError> = cells.collect();
        let flat = flat?;
        // All accepted rows share one width, so the cell count is `rows_seen * width`.
        Ok(flat.into_shape((rows_seen, width.unwrap_or(0))).unwrap())
    }
}

/// An extension trait; this is implemented by `&mut csv::Writer`
pub trait Array2Writer {
    /// Write this ndarray into CSV format
    ///
    /// The implementation in this file writes one CSV record per row of `array` (iterating
    /// over its outermost axis) and flushes the writer before returning.
    ///
    /// # Errors
    ///
    /// The implementation in this file returns the error from the first record that fails to
    /// serialize, or from the final flush.
    fn serialize_array2<A: Serialize>(self, array: &Array2<A>) -> Result<(), csv::Error>;
}

impl<'a, W: Write> Array2Writer for &'a mut Writer<W> {
    fn serialize_array2<A: Serialize>(self, array: &Array2<A>) -> Result<(), csv::Error> {
        /// This wraps the iterator for a row so that we can implement Serialize.
        ///
        /// Serialize is not implemented for iterators: https://github.com/serde-rs/serde/issues/571
        ///
        /// This solution from Hyeonu wraps the iterator:
        /// https://users.rust-lang.org/t/how-to-serialize-an-iterator-to-json/59272/3
        struct Row1DIter<'b, B>(Cell<Option<Iter<'b, B, Dim<[usize; 1]>>>>);

        impl<'b, B> Serialize for Row1DIter<'b, B>
        where
            B: Serialize,
        {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: Serializer,
            {
                serializer.collect_seq(self.0.take().unwrap())
            }
        }

        for row in array.outer_iter() {
            self.serialize(Row1DIter(Cell::new(Some(row.iter()))))?;
        }
        self.flush()?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::ReadError::*;
    use super::*;
    use csv::{Reader, ReaderBuilder, WriterBuilder};
    use std::io::Cursor;

    /// Builds a headerless reader over an in-memory string.
    fn in_memory_reader(content: &'static str) -> Reader<impl Read> {
        ReaderBuilder::new()
            .has_headers(false)
            .from_reader(Cursor::new(content))
    }

    /// Like `in_memory_reader`, but the reader accepts rows of differing lengths, so that
    /// ragged input reaches this crate's own column check instead of failing inside `csv`.
    fn flexible_in_memory_reader(content: &'static str) -> Reader<impl Read> {
        ReaderBuilder::new()
            .has_headers(false)
            .flexible(true)
            .from_reader(Cursor::new(content))
    }

    /// The 2x3 fixture shared by most read tests.
    fn test_reader() -> Reader<impl Read> {
        in_memory_reader("1,2,3\n4,5,6\n")
    }

    #[test]
    fn test_read_float() {
        let actual: Array2<f64> = test_reader().deserialize_array2((2, 3)).unwrap();
        let expected = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_read_integer() {
        let actual: Array2<u64> = test_reader().deserialize_array2((2, 3)).unwrap();
        let expected = array![[1, 2, 3], [4, 5, 6]];
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_read_dynamic() {
        let actual: Array2<u64> = test_reader().deserialize_array2_dynamic().unwrap();
        let expected = array![[1, 2, 3], [4, 5, 6]];
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_read_dynamic_empty() {
        // With no rows there is no width to infer, so both dimensions are zero.
        let actual: Array2<u64> = in_memory_reader("").deserialize_array2_dynamic().unwrap();
        assert_eq!(actual.dim(), (0, 0));
    }

    #[test]
    fn test_read_dynamic_csv_error() {
        assert_matches! {
            in_memory_reader("1,2,3\n4,x,6\n")
                .deserialize_array2_dynamic::<i8>()
                .unwrap_err(),
            Csv(_)
        }
    }

    #[test]
    fn test_read_dynamic_ragged_rows() {
        assert_matches! {
            flexible_in_memory_reader("1,2,3\n4,5\n")
                .deserialize_array2_dynamic::<i8>()
                .unwrap_err(),
            NColumns { at_row_index: 1, expected: 3, actual: 2 }
        }
    }

    #[test]
    fn test_read_csv_error() {
        // Check the variant as well, so a shape error cannot satisfy this test by accident.
        assert_matches! {
            in_memory_reader("1,2,3\n4,x,6\n")
                .deserialize_array2::<i8>((2, 3))
                .unwrap_err(),
            Csv(_)
        }
    }

    #[test]
    fn test_read_too_few_rows() {
        assert_matches! {
            test_reader().deserialize_array2::<i8>((3, 3)).unwrap_err(),
            NRows { expected: 3, actual: 2 }
        }
    }

    #[test]
    fn test_read_too_many_rows() {
        assert_matches! {
            test_reader().deserialize_array2::<i8>((1, 3)).unwrap_err(),
            NRows { expected: 1, actual: 2 }
        }
    }

    #[test]
    fn test_read_too_few_columns() {
        assert_matches! {
            test_reader().deserialize_array2::<i8>((2, 4)).unwrap_err(),
            NColumns { at_row_index: 0, expected: 4, actual: 3 }
        }
    }

    #[test]
    fn test_read_too_many_columns() {
        assert_matches! {
            test_reader().deserialize_array2::<i8>((2, 2)).unwrap_err(),
            NColumns { at_row_index: 0, expected: 2, actual: 3 }
        }
    }

    #[test]
    fn test_write_ok() {
        let mut writer = WriterBuilder::new().has_headers(false).from_writer(vec![]);

        assert_matches! {
            writer.serialize_array2(&array![[1, 2, 3], [4, 5, 6]]),
            Ok(())
        }
        assert_eq!(
            writer.into_inner().expect("flush failed"),
            b"1,2,3\n4,5,6\n"
        );
    }

    #[test]
    fn test_write_transposed() {
        let mut writer = WriterBuilder::new().has_headers(false).from_writer(vec![]);

        assert_matches! {
            writer.serialize_array2(&array![[1, 4], [2, 5], [3, 6]].t().to_owned()),
            Ok(())
        }

        assert_eq!(
            writer.into_inner().expect("flush failed"),
            b"1,2,3\n4,5,6\n"
        );
    }

    #[test]
    fn test_write_err() {
        let destination: &mut [u8] = &mut [0; 8];
        let mut writer = WriterBuilder::new()
            .has_headers(false)
            .from_writer(Cursor::new(destination));

        // The destination is too short
        assert_matches! {
            writer.serialize_array2(&array![[1, 2, 3], [4, 5, 6]]),
            Err(_)
        }
    }
}