csv_stream/
stream.rs

1use pin_project::pin_project;
2use serde::Serialize;
3
4use crate::{Result, Writer};
5
6/// A Streamable CSV creator
7///
8/// # Example
9///
10/// ```
11/// use std::error::Error;
12/// use csv_stream::WriterBuilder;
13/// use serde::Serialize;
14/// use futures::StreamExt;
15///
16/// # #[tokio::main]
17/// # async fn main() { example().await.unwrap(); }
18/// async fn example() -> Result<(), Box<dyn Error>> {
19///     #[derive(Serialize)]
20///     struct Row { foo: usize, bar: usize }
21///     let rows = [
22///         Row{ foo: 1, bar: 2 },
23///         Row{ foo: 3, bar: 4 },
24///     ];
25///     // a Stream over rows
26///     let stream = futures::stream::iter(rows);
27///
28///     let mut csv_stream = WriterBuilder::default().build_stream(stream);
29///
30///     let mut buf = vec![];
31///     while let Some(row) = csv_stream.next().await {
32///         let row = row.unwrap();
33///         buf.extend_from_slice(&row);
34///     }
35///
36///     let data = String::from_utf8(buf)?;
37///     assert_eq!(data, "foo,bar\n1,2\n3,4\n");
38///     Ok(())
39/// }
40/// ```
41#[pin_project]
42pub struct Stream<S> {
43    #[pin]
44    stream: S,
45
46    writer: Writer,
47}
48impl<S> Stream<S> {
49    pub fn new(stream: S, writer: Writer) -> Self {
50        Self { stream, writer }
51    }
52}
53
54impl<S: futures::Stream> futures::Stream for Stream<S>
55where
56    S::Item: Serialize,
57{
58    type Item = Result<Vec<u8>>;
59
60    fn poll_next(
61        self: std::pin::Pin<&mut Self>,
62        cx: &mut std::task::Context<'_>,
63    ) -> std::task::Poll<Option<Self::Item>> {
64        let p = self.project();
65        let s = match p.stream.poll_next(cx) {
66            std::task::Poll::Pending => return std::task::Poll::Pending,
67            std::task::Poll::Ready(None) => return std::task::Poll::Ready(None),
68            std::task::Poll::Ready(Some(s)) => s,
69        };
70
71        let mut buf = vec![];
72        p.writer.serialize(&mut buf, s)?;
73        std::task::Poll::Ready(Some(Ok(buf)))
74    }
75}
76
#[cfg(test)]
mod tests {
    use crate::{Terminator, WriterBuilder};
    use serde::Serialize;

    use super::Stream;
    use futures::StreamExt;

    #[derive(Serialize)]
    struct Row<'a> {
        city: &'a str,
        country: &'a str,
        // Serde lets the CSV header name differ from the Rust field name.
        #[serde(rename = "popcount")]
        population: u64,
    }

    const ROWS: [Row<'static>; 2] = [
        Row {
            city: "Boston",
            country: "United States",
            population: 4628910,
        },
        Row {
            city: "Concord",
            country: "United States",
            population: 42695,
        },
    ];

    /// Drives a CSV stream over `ROWS` using `writer` and joins every emitted
    /// chunk into one UTF-8 string, panicking on any serialization error.
    async fn render(writer: crate::Writer) -> String {
        let chunks: Vec<Vec<u8>> = Stream::new(futures::stream::iter(ROWS), writer)
            .map(|chunk| chunk.unwrap())
            .collect()
            .await;
        String::from_utf8(chunks.concat()).unwrap()
    }

    #[tokio::test]
    async fn serialize() {
        let output = render(WriterBuilder::default().build()).await;

        let expected = "city,country,popcount\n\
                        Boston,United States,4628910\n\
                        Concord,United States,42695\n";
        assert_eq!(output, expected);
    }

    #[tokio::test]
    async fn config() {
        let writer = WriterBuilder::default()
            .has_headers(false)
            .delimiter(b';')
            .terminator(Terminator::CRLF)
            .build();

        let output = render(writer).await;

        assert_eq!(
            output,
            "Boston;United States;4628910\r\nConcord;United States;42695\r\n"
        );
    }
}