clamav_client/
tokio.rs

1use std::path::Path;
2use tokio::{
3    fs::File,
4    io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
5    net::{TcpStream, ToSocketAddrs},
6};
7
8#[cfg(unix)]
9use tokio::net::UnixStream;
10
11#[cfg(feature = "tokio-stream")]
12use tokio_stream::{Stream, StreamExt};
13
14use super::{
15    IoResult, DEFAULT_CHUNK_SIZE, END_OF_STREAM, INSTREAM, PING, PONG, RELOAD, RELOADING, SHUTDOWN,
16    VERSION,
17};
18
19async fn send_command<RW: AsyncRead + AsyncWrite + Unpin>(
20    mut stream: RW,
21    command: &[u8],
22    expected_response_length: Option<usize>,
23) -> IoResult {
24    stream.write_all(command).await?;
25    stream.flush().await?;
26
27    let mut response = match expected_response_length {
28        Some(len) => Vec::with_capacity(len),
29        None => Vec::new(),
30    };
31
32    stream.read_to_end(&mut response).await?;
33    Ok(response)
34}
35
36async fn scan<R: AsyncRead + Unpin, RW: AsyncRead + AsyncWrite + Unpin>(
37    mut input: R,
38    chunk_size: Option<usize>,
39    mut stream: RW,
40) -> IoResult {
41    stream.write_all(INSTREAM).await?;
42
43    let chunk_size = chunk_size
44        .unwrap_or(DEFAULT_CHUNK_SIZE)
45        .min(u32::MAX as usize);
46
47    let mut buffer = vec![0; chunk_size];
48
49    loop {
50        let len = input.read(&mut buffer[..]).await?;
51        if len != 0 {
52            stream.write_all(&(len as u32).to_be_bytes()).await?;
53            stream.write_all(&buffer[..len]).await?;
54        } else {
55            stream.write_all(END_OF_STREAM).await?;
56            stream.flush().await?;
57            break;
58        }
59    }
60
61    let mut response = Vec::new();
62    stream.read_to_end(&mut response).await?;
63    Ok(response)
64}
65
/// Forwards every item of `input_stream` to the server using the `INSTREAM` command.
///
/// Each `Bytes` item is split into pieces of at most `chunk_size` bytes. Every
/// piece is sent as a 4-byte big-endian length followed by the data. Empty
/// items produce no frames, because `slice::chunks` yields nothing for an
/// empty slice. When the input stream ends, the `END_OF_STREAM` marker is
/// written, the output is flushed, and the reply is read until EOF.
///
/// `chunk_size` falls back to `DEFAULT_CHUNK_SIZE` when `None` and is clamped
/// to `1..=u32::MAX`:
/// - the upper bound keeps every length within the 4-byte prefix;
/// - the lower bound stops a `Some(0)` request from panicking inside
///   `slice::chunks`.
///
/// # Errors
///
/// The first `Err` yielded by `input_stream` aborts the scan and is returned,
/// as is any I/O error on the connection.
#[cfg(feature = "tokio-stream")]
async fn _scan_stream<
    S: Stream<Item = Result<bytes::Bytes, std::io::Error>>,
    RW: AsyncRead + AsyncWrite + Unpin,
>(
    input_stream: S,
    chunk_size: Option<usize>,
    mut output_stream: RW,
) -> IoResult {
    output_stream.write_all(INSTREAM).await?;

    let chunk_size = chunk_size
        .unwrap_or(DEFAULT_CHUNK_SIZE)
        .clamp(1, u32::MAX as usize);

    // `StreamExt::next` needs a pinned stream; pin it on the stack.
    let mut input_stream = std::pin::pin!(input_stream);

    while let Some(item) = input_stream.next().await {
        let bytes = item?;
        let bytes: &[u8] = bytes.as_ref();
        for chunk in bytes.chunks(chunk_size) {
            // `chunk.len() <= chunk_size <= u32::MAX`, so this cast cannot truncate.
            output_stream
                .write_all(&(chunk.len() as u32).to_be_bytes())
                .await?;
            output_stream.write_all(chunk).await?;
        }
    }

    output_stream.write_all(END_OF_STREAM).await?;
    output_stream.flush().await?;

    let mut response = Vec::new();
    output_stream.read_to_end(&mut response).await?;
    Ok(response)
}
100
101/// Use a TCP connection to communicate with a ClamAV server
102#[derive(Copy, Clone)]
103pub struct Tcp<A: ToSocketAddrs> {
104    /// The address (host and port) of the ClamAV server
105    pub host_address: A,
106}
107
/// Use a Unix socket connection to communicate with a ClamAV server
///
/// `P` is any value that can be viewed as a [`Path`], for example a `&str`.
/// `Debug` is derived (when `P: Debug`) so connection settings can be logged
/// or inspected.
#[cfg(unix)]
#[derive(Copy, Clone, Debug)]
pub struct Socket<P: AsRef<Path>> {
    /// The socket file path of the ClamAV server
    pub socket_path: P,
}
115
116/// The communication protocol to use
117pub trait TransportProtocol {
118    /// Bidirectional stream
119    type Stream: AsyncRead + AsyncWrite + Unpin;
120
121    /// Converts the protocol instance into the corresponding stream
122    fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>>;
123}
124
125impl<A: ToSocketAddrs> TransportProtocol for Tcp<A> {
126    type Stream = TcpStream;
127
128    fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>> {
129        TcpStream::connect(&self.host_address)
130    }
131}
132
133#[cfg(unix)]
134impl<P: AsRef<Path>> TransportProtocol for Socket<P> {
135    type Stream = UnixStream;
136
137    fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>> {
138        UnixStream::connect(&self.socket_path)
139    }
140}
141
142impl<T: TransportProtocol> TransportProtocol for &T {
143    type Stream = T::Stream;
144
145    fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>> {
146        TransportProtocol::connect(*self)
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Only type-checks when `T` is both `Send` and `Sync`.
    fn _require_send_sync<T: Send + Sync>() {}

    /// Compile-time assertions that the transport types can be shared across
    /// threads. Never called; type-checking this body is the whole test.
    fn _transports_are_send_sync() {
        _require_send_sync::<Tcp<&str>>();
        #[cfg(unix)]
        _require_send_sync::<Socket<&str>>();
    }
}
160
161/// Sends a ping request to ClamAV
162///
163/// This function establishes a connection to a ClamAV server and sends the PING
164/// command to it. If the server is available, it responds with [`PONG`].
165///
166/// # Arguments
167///
168/// * `connection`: The connection type to use - either TCP or a Unix socket connection
169///
170/// # Returns
171///
172/// An [`IoResult`] containing the server's response as a vector of bytes
173///
174/// # Example
175///
176/// ```
177/// # #[tokio::main(flavor = "current_thread")]
178/// # async fn main() {
179/// let clamd_tcp = clamav_client::tokio::Tcp{ host_address: "localhost:3310" };
180/// let clamd_available = match clamav_client::tokio::ping(clamd_tcp).await {
181///     Ok(ping_response) => ping_response == clamav_client::PONG,
182///     Err(_) => false,
183/// };
184/// # assert!(clamd_available);
185/// # }
186/// ```
187///
188pub async fn ping<T: TransportProtocol>(connection: T) -> IoResult {
189    let stream = connection.connect().await?;
190    send_command(stream, PING, Some(PONG.len())).await
191}
192
193/// Reloads the virus databases
194///
195/// This function establishes a connection to a ClamAV server and sends the
196/// RELOAD command to it. If the server is available, it responds with
197/// [`RELOADING`].
198///
199/// # Arguments
200///
201/// * `connection`: The connection type to use - either TCP or a Unix socket connection
202///
203/// # Returns
204///
205/// An [`IoResult`] containing the server's response as a vector of bytes
206///
207/// # Example
208///
209/// ```
210/// # #[tokio::main(flavor = "current_thread")]
211/// # async fn main() {
212/// let clamd_tcp = clamav_client::tokio::Tcp{ host_address: "localhost:3310" };
213/// let response = clamav_client::tokio::reload(clamd_tcp).await.unwrap();
214/// # assert!(response == clamav_client::RELOADING);
215/// # }
216/// ```
217///
218pub async fn reload<T: TransportProtocol>(connection: T) -> IoResult {
219    let stream = connection.connect().await?;
220    send_command(stream, RELOAD, Some(RELOADING.len())).await
221}
222
223/// Gets the version number from ClamAV
224///
225/// This function establishes a connection to a ClamAV server and sends the
226/// VERSION command to it. If the server is available, it responds with its
227/// version number.
228///
229/// # Arguments
230///
231/// * `connection`: The connection type to use - either TCP or a Unix socket connection
232///
233/// # Returns
234///
235/// An [`IoResult`] containing the server's response as a vector of bytes
236///
237/// # Example
238///
239/// ```
240/// # #[tokio::main(flavor = "current_thread")]
241/// # async fn main() {
242/// let clamd_tcp = clamav_client::tokio::Tcp{ host_address: "localhost:3310" };
243/// let version = clamav_client::tokio::get_version(clamd_tcp).await.unwrap();
244/// # assert!(version.starts_with(b"ClamAV"));
245/// # }
246/// ```
247///
248pub async fn get_version<T: TransportProtocol>(connection: T) -> IoResult {
249    let stream = connection.connect().await?;
250    send_command(stream, VERSION, None).await
251}
252
253/// Scans a file for viruses
254///
255/// This function reads data from a file located at the specified `file_path`
256/// and streams it to a ClamAV server for scanning.
257///
258/// # Arguments
259///
260/// * `file_path`: The path to the file to be scanned
261/// * `connection`: The connection type to use - either TCP or a Unix socket connection
262/// * `chunk_size`: An optional chunk size for reading data. If [`None`], a default chunk size is used
263///
264/// # Returns
265///
266/// An [`IoResult`] containing the server's response as a vector of bytes
267///
268pub async fn scan_file<P: AsRef<Path>, T: TransportProtocol>(
269    file_path: P,
270    connection: T,
271    chunk_size: Option<usize>,
272) -> IoResult {
273    let file = File::open(file_path).await?;
274    let stream = connection.connect().await?;
275    scan(file, chunk_size, stream).await
276}
277
278/// Scans a data buffer for viruses
279///
280/// This function streams the provided `buffer` data to a ClamAV server
281///
282/// # Arguments
283///
284/// * `buffer`: The data to be scanned
285/// * `connection`: The connection type to use - either TCP or a Unix socket connection
286/// * `chunk_size`: An optional chunk size for reading data. If [`None`], a default chunk size is used
287///
288/// # Returns
289///
290/// An [`IoResult`] containing the server's response as a vector of bytes
291///
292pub async fn scan_buffer<T: TransportProtocol>(
293    buffer: &[u8],
294    connection: T,
295    chunk_size: Option<usize>,
296) -> IoResult {
297    let stream = connection.connect().await?;
298    scan(buffer, chunk_size, stream).await
299}
300
/// Scans a stream for viruses
///
/// Connects to a ClamAV server, then forwards every `Bytes` item produced by
/// `input_stream` to it for scanning.
///
/// # Arguments
///
/// * `input_stream`: The stream to be scanned
/// * `connection`: The connection type to use - either TCP or a Unix socket connection
/// * `chunk_size`: An optional chunk size for reading data. If [`None`], a default chunk size is used
///
/// # Returns
///
/// An [`IoResult`] containing the server's response as a vector of bytes
///
/// # Errors
///
/// The first `Err` item yielded by `input_stream` aborts the scan and is
/// returned, as is any error while connecting to or talking with the server.
#[cfg(feature = "tokio-stream")]
pub async fn scan_stream<S, T>(input_stream: S, connection: T, chunk_size: Option<usize>) -> IoResult
where
    S: Stream<Item = Result<bytes::Bytes, io::Error>>,
    T: TransportProtocol,
{
    let clamd = connection.connect().await?;
    _scan_stream(input_stream, chunk_size, clamd).await
}
327
328/// Shuts down a ClamAV server
329///
330/// This function establishes a connection to a ClamAV server and sends the
331/// SHUTDOWN command to it. If the server is available, it will perform a clean
332/// exit and shut itself down. The response will be empty.
333///
334/// # Arguments
335///
336/// * `connection`: The connection type to use - either TCP or a Unix socket connection
337///
338/// # Returns
339///
340/// An [`IoResult`] containing the server's response
341///
342pub async fn shutdown<T: TransportProtocol>(connection: T) -> IoResult {
343    let stream = connection.connect().await?;
344    send_command(stream, SHUTDOWN, None).await
345}