clamav_stream/
lib.rs

//! A [`ScannedStream`] sends the inner stream to [clamav](https://www.clamav.net/) to scan its
//! contents while passing it through to the stream consumer.
3//!
//! If clamav detects a virus, the stream yields an `Err` item once all contents have been passed
//! through; otherwise it just passes the inner stream through to the consumer.
6//!
7//! ## When the byte stream is clean
8//!
//! There is no difference between consuming a [`ScannedStream`] and consuming its inner stream.
10//! ```rust,no_run
11//! use clamav_stream::ScannedStream;
12//!
13//! use bytes::Bytes;
14//! use std::net::TcpStream;
15//! use tokio::fs::File;
16//! use tokio_stream::StreamExt;
17//! use tokio_util::io::ReaderStream;
18//!
19//! #[tokio::main]
20//! async fn main() {
21//!     let file = File::open("tests/clean.txt").await.unwrap();
22//!     let mut input = ReaderStream::new(file);
23//!
24//!     let addr = "localhost:3310"; // tcp address to clamav server.
25//!     let mut stream = ScannedStream::<_, TcpStream>::tcp(&mut input, addr).unwrap();
26//!
27//!     // The result of consuming ScannedStream is equal to consuming the input stream.
28//!     assert_eq!(stream.next().await, Some(Ok(Bytes::from("file contents 1st"))));
29//!     assert_eq!(stream.next().await, Some(Ok(Bytes::from("file contents 2nd"))));
30//!     // ... continue until all contents are consumed ...
31//!     assert_eq!(stream.next().await, Some(Ok(Bytes::from("file contents last"))));
32//!     assert_eq!(stream.next().await, None);
33//! }
34//! ```
35//!
36//! ## When the byte stream is infected
37//!
38//! An Err is returned after all contents are consumed.
39//! ```rust,no_run
40//! use clamav_stream::{Error, ScannedStream};
41//!
42//! use bytes::Bytes;
43//! use std::net::TcpStream;
44//! use tokio::fs::File;
45//! use tokio_stream::StreamExt;
46//! use tokio_util::io::ReaderStream;
47//!
48//! #[tokio::main]
49//! async fn main() {
50//!     let file = File::open("tests/eicar.txt").await.unwrap();
51//!     let mut input = ReaderStream::new(file);
52//!
53//!     let addr = "localhost:3310"; // tcp address to clamav server.
54//!     let mut stream = ScannedStream::<_, TcpStream>::tcp(&mut input, addr).unwrap();
55//!
56//!     // An Err is returned after all contents are consumed.
57//!     assert_eq!(stream.next().await, Some(Ok(Bytes::from("file contents 1st"))));
58//!     assert_eq!(stream.next().await, Some(Ok(Bytes::from("file contents 2nd"))));
59//!     // ... continue until all contents are consumed ...
60//!     assert_eq!(stream.next().await, Some(Ok(Bytes::from("file contents last"))));
61//!     assert_eq!(stream.next().await, Some(Err(Error::Scan("message from clamav".into()))));
62//!     assert_eq!(stream.next().await, None);
63//! }
64//! ```
65
66mod error;
67pub use error::Error;
68
69use pin_project::pin_project;
70use std::{
71    error::Error as StdError,
72    io::{Read, Write},
73    net::{TcpStream, ToSocketAddrs},
74    path::Path,
75    pin::{pin, Pin},
76    task::{Context, Poll},
77};
78use tokio_stream::Stream;
79
80#[cfg(unix)]
81use std::os::unix::net::UnixStream;
82
/// Preamble written to the clamav connection once, before the first data chunk is forwarded.
const START: &[u8; 10] = b"zINSTREAM\0";
/// Four zero bytes (a zero length prefix) written after the input ends, before the reply is read.
const FINISH: &[u8; 4] = &[0, 0, 0, 0];
/// Upper bound, in bytes, on each length-prefixed slice written to clamav.
const CHUNK_SIZE: usize = 4096;
86
/// A wrapper stream holding a byte stream. It sends the inner stream to
/// [clamav](https://www.clamav.net/) to scan it while passing it through to the consumer.
#[pin_project]
pub struct ScannedStream<'a, St: ?Sized, RW: Read + Write> {
    /// Borrowed source stream; its items are what the consumer ultimately receives.
    #[pin]
    input: &'a mut St,
    /// Reader/writer the scan traffic goes through (a `TcpStream` or `UnixStream` when built
    /// with the `tcp` / `socket` constructors).
    inner: RW,
    /// Becomes `true` once the `START` preamble has been written to `inner`.
    started: bool,
    /// Becomes `true` once the end of `input` has been observed and handled.
    finished: bool,
}
96
/// Writes `$bytes` to `$stream` via `write_stream`; on failure, returns early from the
/// enclosing `poll_next`-shaped function with the error as a ready stream item.
macro_rules! write_clamav {
    ($stream:expr, $bytes:expr) => {
        match write_stream($stream, $bytes) {
            Ok(()) => {}
            Err(err) => return Poll::Ready(Some(Err(err))),
        }
    };
}
104
/// Reads the scan reply from `$stream` via `read_stream_response`; on failure, returns early
/// from the enclosing `poll_next`-shaped function with the error as a ready stream item.
macro_rules! read_clamav {
    ($stream:expr) => {
        match read_stream_response($stream) {
            Ok(()) => {}
            Err(err) => return Poll::Ready(Some(Err(err))),
        }
    };
}
112
113impl<'a, St, RW, E> Stream for ScannedStream<'a, St, RW>
114where
115    St: Stream<Item = Result<bytes::Bytes, E>> + Unpin + ?Sized,
116    RW: Read + Write,
117    E: StdError + Send + Sync + 'static,
118{
119    type Item = Result<bytes::Bytes, Error>;
120
121    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
122        let me = self.project();
123        match me.input.poll_next(cx) {
124            Poll::Pending => Poll::Pending,
125            Poll::Ready(Some(Ok(bytes))) => {
126                if !*me.started {
127                    *me.started = true;
128                    write_clamav!(me.inner, START);
129                }
130
131                for chunk in bytes.as_ref().chunks(CHUNK_SIZE) {
132                    let len = chunk.len() as u32;
133                    write_clamav!(me.inner, &len.to_be_bytes());
134                    write_clamav!(me.inner, chunk);
135                }
136
137                Poll::Ready(Some(Ok(bytes)))
138            }
139            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(Error::Stream(Box::new(err))))),
140            Poll::Ready(None) => {
141                if *me.finished {
142                    return Poll::Ready(None);
143                }
144
145                *me.finished = true;
146                write_clamav!(me.inner, FINISH);
147                read_clamav!(me.inner);
148
149                Poll::Ready(None)
150            }
151        }
152    }
153}
154
155impl<'a, St, RW, E> ScannedStream<'a, St, RW>
156where
157    St: Stream<Item = Result<bytes::Bytes, E>> + Unpin + ?Sized,
158    RW: Read + Write,
159    E: StdError,
160{
161    /// Create a new [`ScannedStream`]
162    pub fn new(input: &'a mut St, inner: RW) -> Self {
163        Self {
164            input,
165            inner,
166            started: false,
167            finished: false,
168        }
169    }
170
171    /// Create a new [`ScannedStream`] connecting to clamav server with tcp socket.
172    pub fn tcp(
173        input: &'a mut St,
174        addr: impl ToSocketAddrs,
175    ) -> Result<ScannedStream<'a, St, TcpStream>, Error> {
176        let inner = TcpStream::connect(addr)?;
177        Ok(ScannedStream::new(input, inner))
178    }
179
180    /// Create a new [`ScannedStream`] connecting to clamav server with unix socket.
181    #[cfg(unix)]
182    pub fn socket(
183        input: &'a mut St,
184        path: impl AsRef<Path>,
185    ) -> Result<ScannedStream<'a, St, UnixStream>, Error> {
186        let inner = UnixStream::connect(path)?;
187        Ok(ScannedStream::new(input, inner))
188    }
189}
190
191fn write_stream(stream: &mut impl Write, buf: &[u8]) -> Result<(), Error> {
192    stream.write_all(buf)?;
193    Ok(())
194}
195
196fn read_stream_response(stream: &mut impl Read) -> Result<(), Error> {
197    let mut body: Vec<u8> = vec![];
198    stream.read_to_end(&mut body)?;
199
200    let res = std::str::from_utf8(&body)?;
201
202    if res.contains("OK") && !res.contains("FOUND") {
203        Ok(())
204    } else {
205        Err(Error::Scan(res.to_string()))
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use std::io::{self, Cursor};
    use tokio_stream::StreamExt;

    /// In-memory stand-in for the clamav connection: records each write as a string and
    /// replays a canned reply on read.
    struct MockStream {
        written: Vec<String>,
        output: Cursor<Vec<u8>>,
    }

    impl MockStream {
        /// Builds a mock whose reads return `value`.
        fn new(value: &str) -> Self {
            Self {
                written: Vec::new(),
                output: Cursor::new(value.as_bytes().to_owned()),
            }
        }
    }

    impl Read for MockStream {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            self.output.read(buf)
        }
    }

    impl Write for MockStream {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            let text = std::str::from_utf8(buf).unwrap().to_owned();
            self.written.push(text);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// One-item source yielding `value` as a successful chunk.
    fn stream_from_str(value: &'static str) -> impl Iterator<Item = Result<Bytes, Error>> {
        std::iter::once(Ok(Bytes::from(value)))
    }

    /// Drains `stream`, concatenating the chunks into a string; stops at the first error.
    async fn consume<S>(mut stream: S) -> Result<String, Error>
    where
        S: Stream<Item = Result<Bytes, Error>> + Unpin,
    {
        let mut collected: Vec<u8> = Vec::new();

        while let Some(item) = stream.next().await {
            let piece = item?;
            collected.extend_from_slice(piece.as_ref());
        }

        Ok(std::str::from_utf8(&collected)?.to_owned())
    }

    #[tokio::test]
    async fn it_returns_original_inputs_when_success() {
        let mut input = tokio_stream::iter(stream_from_str("Hello World"));
        let mut inner = MockStream::new("OK");

        let result = consume(ScannedStream::new(&mut input, &mut inner)).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Hello World");

        // Preamble, big-endian length prefix, payload, terminator — in that order.
        let len_prefix =
            String::from_utf8(("Hello World".len() as u32).to_be_bytes().to_vec()).unwrap();
        let terminator = String::from_utf8(vec![0, 0, 0, 0]).unwrap();
        let expected = [
            "zINSTREAM\0".to_string(),
            len_prefix,
            "Hello World".to_string(),
            terminator,
        ];
        assert_eq!(inner.written, expected);
    }

    #[tokio::test]
    async fn it_returns_an_error_when_found_any_virus() {
        let mut input = tokio_stream::iter(stream_from_str("Hello World"));
        let mut inner = MockStream::new("FOUND test virus");

        let result = consume(ScannedStream::new(&mut input, &mut inner)).await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "FOUND test virus");
    }
}
299}