1use bytes::{Buf, BufMut, Bytes, BytesMut};
12use futures::SinkExt;
13use futures::StreamExt;
14use std::io::Cursor;
15use std::path::Path;
16use std::path::PathBuf;
17use std::pin::Pin;
18use tokio::fs::File;
19use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
20use tokio::net::{TcpStream, ToSocketAddrs, UnixStream};
21use tokio_util::codec::Decoder;
22use tokio_util::codec::Encoder;
23use tokio_util::codec::Framed;
24use tracing::trace;
25
26use crate::error::Result;
27
28mod error;
29
30pub use error::ClamdError;
31
/// Default capacity, in bytes, of the buffer that input is read into before it is sent
/// to clamd as `INSTREAM` chunks (8 KiB).
pub const DEFAULT_CHUNK_SIZE: usize = 8 * 1024;
34
35enum ClamdRequestMessage {
36 Ping,
37 Version,
38 Reload,
39 Shutdown,
40 Stats,
41 StartStream,
42 StreamChunk(Bytes),
43 EndStream,
44}
45
/// Codec for clamd's `z`-prefixed, NUL-delimited protocol.
///
/// It encodes `ClamdRequestMessage`s and decodes NUL-terminated UTF-8 replies.
struct ClamdZeroDelimitedCodec {
    /// Offset into the read buffer up to which the decoder has already searched for a
    /// NUL terminator without finding one. Reset to 0 after each decoded reply.
    next_index: usize,
}

impl ClamdZeroDelimitedCodec {
    /// Creates a codec with no partially scanned input.
    fn new() -> Self {
        let next_index = 0;
        ClamdZeroDelimitedCodec { next_index }
    }
}
55
56impl Encoder<ClamdRequestMessage> for ClamdZeroDelimitedCodec {
57 type Error = ClamdError;
58
59 fn encode(&mut self, item: ClamdRequestMessage, dst: &mut BytesMut) -> Result<()> {
60 match item {
61 ClamdRequestMessage::Ping => {
62 dst.reserve(6);
63 dst.put(&b"zPING"[..]);
64 dst.put_u8(0);
65 Ok(())
66 }
67 ClamdRequestMessage::Version => {
68 dst.reserve(9);
69 dst.put(&b"zVERSION"[..]);
70 dst.put_u8(0);
71 Ok(())
72 }
73 ClamdRequestMessage::Reload => {
74 dst.reserve(8);
75 dst.put(&b"zRELOAD"[..]);
76 dst.put_u8(0);
77 Ok(())
78 }
79 ClamdRequestMessage::Stats => {
80 dst.reserve(7);
81 dst.put(&b"zSTATS"[..]);
82 dst.put_u8(0);
83 Ok(())
84 }
85 ClamdRequestMessage::Shutdown => {
86 dst.reserve(10);
87 dst.put(&b"zSHUTDOWN"[..]);
88 dst.put_u8(0);
89 Ok(())
90 }
91 ClamdRequestMessage::StartStream => {
92 dst.reserve(10);
93 dst.put(&b"zINSTREAM"[..]);
94 dst.put_u8(0);
95 Ok(())
96 }
97 ClamdRequestMessage::StreamChunk(bytes) => {
98 dst.reserve(4);
99 dst.put_u32(bytes.len().try_into().map_err(ClamdError::ChunkSizeError)?);
100 dst.extend_from_slice(&bytes);
101 Ok(())
102 }
103
104 ClamdRequestMessage::EndStream => {
105 dst.reserve(4);
106 dst.put_u32(0);
107 Ok(())
108 }
109 }
110 }
111}
112
113impl Decoder for ClamdZeroDelimitedCodec {
114 type Item = String;
115
116 type Error = ClamdError;
117
118 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
119 if let Some(rel_split_pos) = src[self.next_index..].iter().position(|&x| x == 0u8) {
120 let split_pos = rel_split_pos + self.next_index;
121 let chunk = src.split_to(split_pos).freeze();
122 src.advance(1);
123 self.next_index = 0;
124 let s = String::from_utf8(chunk.into()).map_err(ClamdError::DecodingUtf8Error)?;
125 Ok(Some(s))
126 } else {
127 self.next_index = src.len();
128 Ok(None)
129 }
130 }
131}
132
133enum SocketType<T: ToSocketAddrs + ToOwned<Owned = T>> {
134 Tcp(T),
135 #[cfg(target_family = "unix")]
136 Unix(PathBuf),
137}
138
/// How the client manages its connections to clamd.
#[derive(Debug, Clone, Copy)]
enum ConnectionType {
    /// Open a fresh connection for every request.
    Oneshot,
    /// Reuse one connection across requests (not implemented: connecting with it hits
    /// `todo!()`).
    KeepAlive,
}
144
145enum SocketWrapper {
146 Tcp(TcpStream),
147 Unix(UnixStream),
148}
149
150impl AsyncRead for SocketWrapper {
151 fn poll_read(
152 mut self: std::pin::Pin<&mut Self>,
153 cx: &mut std::task::Context<'_>,
154 buf: &mut tokio::io::ReadBuf<'_>,
155 ) -> std::task::Poll<std::io::Result<()>> {
156 match &mut *self {
157 SocketWrapper::Tcp(tcp) => Pin::new(tcp).poll_read(cx, buf),
158 SocketWrapper::Unix(unix) => Pin::new(unix).poll_read(cx, buf),
159 }
160 }
161}
162
163impl AsyncWrite for SocketWrapper {
164 fn poll_write(
165 mut self: std::pin::Pin<&mut Self>,
166 cx: &mut std::task::Context<'_>,
167 buf: &[u8],
168 ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
169 match &mut *self {
170 SocketWrapper::Tcp(tcp) => Pin::new(tcp).poll_write(cx, buf),
171 SocketWrapper::Unix(unix) => Pin::new(unix).poll_write(cx, buf),
172 }
173 }
174
175 fn poll_flush(
176 mut self: std::pin::Pin<&mut Self>,
177 cx: &mut std::task::Context<'_>,
178 ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
179 match &mut *self {
180 SocketWrapper::Tcp(tcp) => Pin::new(tcp).poll_flush(cx),
181 SocketWrapper::Unix(unix) => Pin::new(unix).poll_flush(cx),
182 }
183 }
184
185 fn poll_shutdown(
186 mut self: std::pin::Pin<&mut Self>,
187 cx: &mut std::task::Context<'_>,
188 ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
189 match &mut *self {
190 SocketWrapper::Tcp(tcp) => Pin::new(tcp).poll_shutdown(cx),
191 SocketWrapper::Unix(unix) => Pin::new(unix).poll_shutdown(cx),
192 }
193 }
194}
195
196enum SocketTypeBuilder<'a, T: ToSocketAddrs + Clone, B: ToOwned<Owned = T> + ?Sized> {
197 Tcp(&'a B),
198 #[cfg(target_family = "unix")]
199 Unix(&'a Path),
200}
201
202pub struct ClamdClientBuilder<'a, T: ToSocketAddrs + Clone, B: ToOwned<Owned = T> + ?Sized> {
215 socket_type: SocketTypeBuilder<'a, T, B>,
216 connection_type: ConnectionType,
217 chunk_size: usize,
218}
219
220impl<'a, T, B> ClamdClientBuilder<'a, T, B>
221where
222 T: ToSocketAddrs + Clone,
223 B: ToOwned<Owned = T> + ?Sized,
224{
225 pub fn unix_socket<P: AsRef<Path> + ?Sized>(path: &'a P) -> Self {
241 Self {
242 socket_type: SocketTypeBuilder::Unix(path.as_ref()),
243 connection_type: ConnectionType::Oneshot,
244 chunk_size: DEFAULT_CHUNK_SIZE,
245 }
246 }
247 pub fn tcp_socket(addr: &'a B) -> Self {
249 Self {
250 socket_type: SocketTypeBuilder::Tcp(addr),
251 connection_type: ConnectionType::Oneshot,
252 chunk_size: DEFAULT_CHUNK_SIZE,
253 }
254 }
255
256 pub fn chunk_size(&'a mut self, chunk_size: usize) -> &'a mut Self {
258 self.chunk_size = chunk_size;
259 self
260 }
261
262 pub fn build(&'a self) -> ClamdClient<T> {
264 ClamdClient {
265 socket_type: match self.socket_type {
266 SocketTypeBuilder::Tcp(t) => SocketType::Tcp(t.to_owned()),
267 SocketTypeBuilder::Unix(u) => SocketType::Unix(u.to_owned()),
268 },
269 connection_type: self.connection_type,
270 chunk_size: self.chunk_size,
271 }
272 }
273}
274
275pub struct ClamdClient<T: ToSocketAddrs + ToOwned<Owned = T>> {
294 socket_type: SocketType<T>,
296 connection_type: ConnectionType,
297 chunk_size: usize,
298}
299
300impl<T: ToSocketAddrs + ToOwned<Owned = T>> ClamdClient<T> {
301 async fn connect(&mut self) -> Result<Framed<SocketWrapper, ClamdZeroDelimitedCodec>> {
302 let codec = ClamdZeroDelimitedCodec::new();
303 match &self.connection_type {
304 ConnectionType::Oneshot => match &self.socket_type {
305 SocketType::Tcp(address) => Ok(Framed::new(
306 SocketWrapper::Tcp(
307 TcpStream::connect(address)
308 .await
309 .map_err(ClamdError::ConnectError)?,
310 ),
311 codec,
312 )),
313 SocketType::Unix(path) => Ok(Framed::new(
314 SocketWrapper::Unix(
315 UnixStream::connect(path)
316 .await
317 .map_err(ClamdError::ConnectError)?,
318 ),
319 codec,
320 )),
321 },
322 ConnectionType::KeepAlive => todo!(),
323 }
324 }
325
326 pub async fn ping(&mut self) -> Result<()> {
329 let mut sock = self.connect().await?;
330 sock.send(ClamdRequestMessage::Ping).await?;
331 trace!("Sent ping to clamd");
332 if let Some(s) = sock.next().await.transpose()? {
333 if s == "PONG" {
334 trace!("Received pong from clamd");
335 Ok(())
336 } else {
337 Err(ClamdError::InvalidResponse(s))
338 }
339 } else {
340 Err(ClamdError::NoResponse)
341 }
342 }
343
344 pub async fn version(&mut self) -> Result<String> {
346 let mut sock = self.connect().await?;
347 sock.send(ClamdRequestMessage::Version).await?;
348 trace!("Sent version request to clamd");
349
350 if let Some(s) = sock.next().await.transpose()? {
351 trace!("Received version from clamd");
352 Ok(s)
353 } else {
354 Err(ClamdError::NoResponse)
355 }
356 }
357
358 pub async fn reload(&mut self) -> Result<()> {
360 let mut sock = self.connect().await?;
361 sock.send(ClamdRequestMessage::Reload).await?;
362 trace!("Sent reload request to clamd");
363 if let Some(s) = sock.next().await.transpose()? {
364 if s == "RELOADING" {
365 trace!("Clamd started reload");
366 drop(sock);
368 self.ping().await?;
370 trace!("Clamd finished reload");
371 Ok(())
372 } else {
373 Err(ClamdError::InvalidResponse(s))
374 }
375 } else {
376 Err(ClamdError::NoResponse)
377 }
378 }
379
380 pub async fn stats(&mut self) -> Result<String> {
382 let mut sock = self.connect().await?;
383 sock.send(ClamdRequestMessage::Stats).await?;
384 trace!("Sent stats request to clamd");
385
386 if let Some(s) = sock.next().await.transpose()? {
387 if s.ends_with("END") {
388 trace!("Got stats from clamd");
389 Ok(s)
390 } else {
391 Err(ClamdError::IncompleteResponse(s))
392 }
393 } else {
394 Err(ClamdError::NoResponse)
395 }
396 }
397
398 pub async fn shutdown(mut self) -> Result<()> {
400 let mut sock = self.connect().await?;
401 trace!("Sent shutdown request to clamd");
402 sock.send(ClamdRequestMessage::Shutdown).await?;
403 Ok(())
404 }
405
406 pub async fn scan_reader<R: AsyncRead + AsyncReadExt + Unpin>(
419 &mut self,
420 mut to_scan: R,
421 ) -> Result<()> {
422 let mut sock = self.connect().await?;
423 let mut buf = BytesMut::with_capacity(self.chunk_size);
424
425 sock.send(ClamdRequestMessage::StartStream).await?;
426 trace!("Starting bytes stream to clamd");
427
428 while to_scan.read_buf(&mut buf).await? != 0 {
429 trace!("Sending {} bytes to clamd", buf.len());
430 sock.send(ClamdRequestMessage::StreamChunk(buf.split().freeze()))
431 .await?;
432 }
433 trace!("Hit EOF, closing stream to clamd");
434 sock.send(ClamdRequestMessage::EndStream).await?;
435 if let Some(s) = sock.next().await.transpose()? {
436 let msg = s
437 .split_once(':')
438 .map(|(_, msg)| msg.trim())
439 .ok_or_else(|| ClamdError::IncompleteResponse(s.clone()))?;
440
441 if msg == "OK" {
442 Ok(())
443 } else {
444 Err(ClamdError::ScanError(msg.to_owned()))
445 }
446 } else {
447 Err(ClamdError::NoResponse)
448 }
449 }
450
451 pub async fn scan_bytes(&mut self, to_scan: &[u8]) -> Result<()> {
454 let cursor = Cursor::new(to_scan);
455 self.scan_reader(cursor).await
456 }
457
458 pub async fn scan_file(&mut self, path_to_scan: impl AsRef<Path>) -> Result<()> {
462 let reader = File::open(path_to_scan).await?;
463 self.scan_reader(reader).await
464 }
465}
466
#[cfg(test)]
mod tests {

    use super::*;
    use tracing_test::traced_test;

    const TCP_ADDRESS: &str = "127.0.0.1:3310";
    const UNIX_SOCKET_PATH: &str = "/run/clamav/clamd.sock";

    /// Size of the random payload that is expected to scan clean.
    const NUM_BYTES: usize = 1024 * 1024;

    /// Client for the local clamd TCP listener used by these integration tests.
    fn tcp_client() -> ClamdClient<String> {
        ClamdClientBuilder::tcp_socket(TCP_ADDRESS).build()
    }

    /// Client for the local clamd Unix socket used by these integration tests.
    fn unix_client() -> ClamdClient<String> {
        ClamdClientBuilder::<String, str>::unix_socket(UNIX_SOCKET_PATH).build()
    }

    /// PING, VERSION and STATS must all succeed with non-empty replies.
    async fn assert_common_operations(mut client: ClamdClient<String>) -> eyre::Result<()> {
        client.ping().await?;
        let version = client.version().await?;
        assert!(!version.is_empty());
        let stats = client.stats().await?;
        assert!(!stats.is_empty());
        Ok(())
    }

    /// A megabyte of random bytes must come back as clean.
    async fn assert_random_bytes_clean(mut client: ClamdClient<String>) -> eyre::Result<()> {
        let payload: Vec<u8> = (0..NUM_BYTES).map(|_| rand::random::<u8>()).collect();
        client.scan_bytes(&payload).await?;
        Ok(())
    }

    /// The EICAR test archive (downloaded from eicar.org) must be detected.
    async fn assert_detects_eicar(mut client: ClamdClient<String>) -> eyre::Result<()> {
        let eicar_bytes = reqwest::get("https://secure.eicar.org/eicarcom2.zip")
            .await?
            .bytes()
            .await?;
        match client.scan_bytes(&eicar_bytes).await.unwrap_err() {
            ClamdError::ScanError(verdict) => assert_eq!(verdict, "Win.Test.EICAR_HDB-1 FOUND"),
            _ => panic!("Scan error expected"),
        }
        Ok(())
    }

    #[tokio::test]
    #[traced_test]
    async fn tcp_common_operations() -> eyre::Result<()> {
        assert_common_operations(tcp_client()).await
    }

    #[tokio::test]
    #[traced_test]
    async fn tcp_random_bytes() -> eyre::Result<()> {
        assert_random_bytes_clean(tcp_client()).await
    }

    #[tokio::test]
    #[traced_test]
    async fn tcp_eicar() -> eyre::Result<()> {
        assert_detects_eicar(tcp_client()).await
    }

    #[tokio::test]
    #[traced_test]
    async fn tcp_reload() -> eyre::Result<()> {
        tcp_client().reload().await?;
        Ok(())
    }

    #[tokio::test]
    #[traced_test]
    async fn unix_socket_common_operations() -> eyre::Result<()> {
        assert_common_operations(unix_client()).await
    }

    #[tokio::test]
    #[traced_test]
    async fn unix_socket_random_bytes() -> eyre::Result<()> {
        assert_random_bytes_clean(unix_client()).await
    }

    #[tokio::test]
    #[traced_test]
    async fn unix_socket_eicar() -> eyre::Result<()> {
        assert_detects_eicar(unix_client()).await
    }

    #[tokio::test]
    #[traced_test]
    async fn unix_socket_reload() -> eyre::Result<()> {
        unix_client().reload().await?;
        Ok(())
    }
}