1mod error;
67pub use error::Error;
68
69use pin_project::pin_project;
70use std::{
71 error::Error as StdError,
72 io::{Read, Write},
73 net::{TcpStream, ToSocketAddrs},
74 path::Path,
75 pin::{pin, Pin},
76 task::{Context, Poll},
77};
78use tokio_stream::Stream;
79
80#[cfg(unix)]
81use std::os::unix::net::UnixStream;
82
/// `INSTREAM` command using the clamd `z` prefix (NUL-terminated command and reply).
const START: &[u8; 10] = b"zINSTREAM\0";
/// Zero-length chunk header; tells clamd the streamed payload is complete.
const FINISH: &[u8; 4] = &[0u8; 4];
/// Largest number of payload bytes sent in a single length-prefixed chunk.
const CHUNK_SIZE: usize = 4 * 1024;
86
87#[pin_project]
89pub struct ScannedStream<'a, St: ?Sized, RW: Read + Write> {
90 #[pin]
91 input: &'a mut St,
92 inner: RW,
93 started: bool,
94 finished: bool,
95}
96
/// Writes `$bytes` to the clamd connection `$stream`; on failure, the error is
/// returned from the enclosing `poll_next` as the next stream item.
macro_rules! write_clamav {
    ($stream:expr, $bytes:expr) => {
        match write_stream($stream, $bytes) {
            Ok(_) => {}
            Err(failure) => return Poll::Ready(Some(Err(failure))),
        }
    };
}
104
/// Reads clamd's verdict from `$stream`; a failed read or a non-clean verdict
/// is returned from the enclosing `poll_next` as the next stream item.
macro_rules! read_clamav {
    ($stream:expr) => {
        match read_stream_response($stream) {
            Ok(_) => {}
            Err(failure) => return Poll::Ready(Some(Err(failure))),
        }
    };
}
112
113impl<'a, St, RW, E> Stream for ScannedStream<'a, St, RW>
114where
115 St: Stream<Item = Result<bytes::Bytes, E>> + Unpin + ?Sized,
116 RW: Read + Write,
117 E: StdError + Send + Sync + 'static,
118{
119 type Item = Result<bytes::Bytes, Error>;
120
121 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
122 let me = self.project();
123 match me.input.poll_next(cx) {
124 Poll::Pending => Poll::Pending,
125 Poll::Ready(Some(Ok(bytes))) => {
126 if !*me.started {
127 *me.started = true;
128 write_clamav!(me.inner, START);
129 }
130
131 for chunk in bytes.as_ref().chunks(CHUNK_SIZE) {
132 let len = chunk.len() as u32;
133 write_clamav!(me.inner, &len.to_be_bytes());
134 write_clamav!(me.inner, chunk);
135 }
136
137 Poll::Ready(Some(Ok(bytes)))
138 }
139 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(Error::Stream(Box::new(err))))),
140 Poll::Ready(None) => {
141 if *me.finished {
142 return Poll::Ready(None);
143 }
144
145 *me.finished = true;
146 write_clamav!(me.inner, FINISH);
147 read_clamav!(me.inner);
148
149 Poll::Ready(None)
150 }
151 }
152 }
153}
154
155impl<'a, St, RW, E> ScannedStream<'a, St, RW>
156where
157 St: Stream<Item = Result<bytes::Bytes, E>> + Unpin + ?Sized,
158 RW: Read + Write,
159 E: StdError,
160{
161 pub fn new(input: &'a mut St, inner: RW) -> Self {
163 Self {
164 input,
165 inner,
166 started: false,
167 finished: false,
168 }
169 }
170
171 pub fn tcp(
173 input: &'a mut St,
174 addr: impl ToSocketAddrs,
175 ) -> Result<ScannedStream<'a, St, TcpStream>, Error> {
176 let inner = TcpStream::connect(addr)?;
177 Ok(ScannedStream::new(input, inner))
178 }
179
180 #[cfg(unix)]
182 pub fn socket(
183 input: &'a mut St,
184 path: impl AsRef<Path>,
185 ) -> Result<ScannedStream<'a, St, UnixStream>, Error> {
186 let inner = UnixStream::connect(path)?;
187 Ok(ScannedStream::new(input, inner))
188 }
189}
190
191fn write_stream(stream: &mut impl Write, buf: &[u8]) -> Result<(), Error> {
192 stream.write_all(buf)?;
193 Ok(())
194}
195
196fn read_stream_response(stream: &mut impl Read) -> Result<(), Error> {
197 let mut body: Vec<u8> = vec![];
198 stream.read_to_end(&mut body)?;
199
200 let res = std::str::from_utf8(&body)?;
201
202 if res.contains("OK") && !res.contains("FOUND") {
203 Ok(())
204 } else {
205 Err(Error::Scan(res.to_string()))
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use std::io::{self, Cursor};
    use tokio_stream::StreamExt;

    /// A clean scan passes the payload through untouched and frames it as
    /// command, 4-byte big-endian length, data, and zero terminator.
    #[tokio::test]
    async fn it_returns_original_inputs_when_success() {
        let mut source = tokio_stream::iter(stream_from_str("Hello World"));
        let mut clamd = MockStream::new("OK");

        let outcome = consume(ScannedStream::new(&mut source, &mut clamd)).await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "Hello World");

        let length_prefix = ("Hello World".len() as u32).to_be_bytes().to_vec();
        let expected = [
            "zINSTREAM\0".to_string(),
            String::from_utf8(length_prefix).unwrap(),
            "Hello World".to_string(),
            String::from_utf8(vec![0, 0, 0, 0]).unwrap(),
        ];
        assert_eq!(clamd.written, expected);
    }

    /// An infected verdict surfaces as an error carrying clamd's reply.
    #[tokio::test]
    async fn it_returns_an_error_when_found_any_virus() {
        let mut source = tokio_stream::iter(stream_from_str("Hello World"));
        let mut clamd = MockStream::new("FOUND test virus");

        let outcome = consume(ScannedStream::new(&mut source, &mut clamd)).await;
        assert!(outcome.is_err());
        let failure = outcome.unwrap_err();
        assert_eq!(failure.to_string(), "FOUND test virus");
    }

    /// In-memory stand-in for a clamd connection: records each `write` call
    /// as a string and replays a canned reply on `read`.
    struct MockStream {
        written: Vec<String>,
        output: Cursor<Vec<u8>>,
    }

    impl MockStream {
        fn new(value: &str) -> Self {
            let output = Cursor::new(Vec::from(value.as_bytes()));
            Self {
                written: Vec::new(),
                output,
            }
        }
    }

    impl Read for MockStream {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            Read::read(&mut self.output, buf)
        }
    }

    impl Write for MockStream {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            let text = String::from_utf8(buf.to_vec()).unwrap();
            self.written.push(text);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// A single-item input yielding `value` as one `Bytes` chunk.
    fn stream_from_str(value: &'static str) -> impl Iterator<Item = Result<Bytes, Error>> {
        std::iter::once(Ok(Bytes::from(value)))
    }

    /// Drains `stream`, concatenating the chunks into a UTF-8 string and
    /// stopping at the first error.
    async fn consume<S>(mut stream: S) -> Result<String, Error>
    where
        S: Stream<Item = Result<Bytes, Error>> + Unpin,
    {
        let mut collected: Vec<u8> = Vec::new();

        while let Some(item) = stream.next().await {
            collected.extend_from_slice(&item?);
        }

        Ok(std::str::from_utf8(&collected)?.to_owned())
    }
}