clamav_client/
async_std.rs

1use async_std::{
2    fs::File,
3    io::{self, ReadExt, WriteExt},
4    net::{TcpStream, ToSocketAddrs},
5    path::Path,
6    stream::{Stream, StreamExt},
7};
8
9#[cfg(unix)]
10use async_std::os::unix::net::UnixStream;
11
12use super::{
13    IoResult, DEFAULT_CHUNK_SIZE, END_OF_STREAM, INSTREAM, PING, PONG, RELOAD, RELOADING, SHUTDOWN,
14    VERSION,
15};
16
17async fn send_command<RW: ReadExt + WriteExt + Unpin>(
18    mut stream: RW,
19    command: &[u8],
20    expected_response_length: Option<usize>,
21) -> IoResult {
22    stream.write_all(command).await?;
23    stream.flush().await?;
24
25    let mut response = match expected_response_length {
26        Some(len) => Vec::with_capacity(len),
27        None => Vec::new(),
28    };
29
30    stream.read_to_end(&mut response).await?;
31    Ok(response)
32}
33
34async fn scan<R: ReadExt + Unpin, RW: ReadExt + WriteExt + Unpin>(
35    mut input: R,
36    chunk_size: Option<usize>,
37    mut stream: RW,
38) -> IoResult {
39    stream.write_all(INSTREAM).await?;
40
41    let chunk_size = chunk_size
42        .unwrap_or(DEFAULT_CHUNK_SIZE)
43        .min(u32::MAX as usize);
44
45    let mut buffer = vec![0; chunk_size];
46
47    loop {
48        let len = input.read(&mut buffer[..]).await?;
49        if len != 0 {
50            stream.write_all(&(len as u32).to_be_bytes()).await?;
51            stream.write_all(&buffer[..len]).await?;
52        } else {
53            stream.write_all(END_OF_STREAM).await?;
54            stream.flush().await?;
55            break;
56        }
57    }
58
59    let mut response = Vec::new();
60    stream.read_to_end(&mut response).await?;
61    Ok(response)
62}
63
64async fn _scan_stream<
65    S: Stream<Item = Result<bytes::Bytes, std::io::Error>>,
66    RW: ReadExt + WriteExt + Unpin,
67>(
68    input_stream: S,
69    chunk_size: Option<usize>,
70    mut output_stream: RW,
71) -> IoResult {
72    output_stream.write_all(INSTREAM).await?;
73
74    let chunk_size = chunk_size
75        .unwrap_or(DEFAULT_CHUNK_SIZE)
76        .min(u32::MAX as usize);
77
78    let mut input_stream = std::pin::pin!(input_stream);
79
80    while let Some(bytes) = input_stream.next().await {
81        let bytes = bytes?;
82        let bytes = bytes.as_ref();
83        for chunk in bytes.chunks(chunk_size) {
84            let len = chunk.len();
85            output_stream.write_all(&(len as u32).to_be_bytes()).await?;
86            output_stream.write_all(chunk).await?;
87        }
88    }
89
90    output_stream.write_all(END_OF_STREAM).await?;
91    output_stream.flush().await?;
92
93    let mut response = Vec::new();
94    output_stream.read_to_end(&mut response).await?;
95    Ok(response)
96}
97
/// Use a TCP connection to communicate with a ClamAV server
///
/// `A` is the address type, for example a `&str` such as `"localhost:3310"`.
/// It must implement `ToSocketAddrs` for `Tcp<A>` to be usable as a
/// [`TransportProtocol`]. That requirement lives on the trait impl rather than
/// on the struct, so the type can be named and stored generically.
#[derive(Copy, Clone, Debug)]
pub struct Tcp<A> {
    /// The address (host and port) of the ClamAV server
    pub host_address: A,
}
104
/// Use a Unix socket connection to communicate with a ClamAV server
///
/// `P` is the path type, for example a `&str`. It must implement
/// `AsRef<Path>` for `Socket<P>` to be usable as a [`TransportProtocol`].
/// That requirement lives on the trait impl rather than on the struct, so the
/// type can be named and stored generically.
#[cfg(unix)]
#[derive(Copy, Clone, Debug)]
pub struct Socket<P> {
    /// The socket file path of the ClamAV server
    pub socket_path: P,
}
112
113/// The communication protocol to use
114pub trait TransportProtocol {
115    /// Bidirectional stream
116    type Stream: ReadExt + WriteExt + Unpin;
117
118    /// Converts the protocol instance into the corresponding stream
119    fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>>;
120}
121
122impl<A: ToSocketAddrs> TransportProtocol for Tcp<A> {
123    type Stream = TcpStream;
124
125    fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>> {
126        TcpStream::connect(&self.host_address)
127    }
128}
129
130#[cfg(unix)]
131impl<P: AsRef<Path>> TransportProtocol for Socket<P> {
132    type Stream = UnixStream;
133
134    fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>> {
135        UnixStream::connect(&self.socket_path)
136    }
137}
138
139impl<T: TransportProtocol> TransportProtocol for &T {
140    type Stream = T::Stream;
141
142    fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>> {
143        TransportProtocol::connect(*self)
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time assertions: this constant only type-checks if every
    // transport listed below is `Send + Sync`. Nothing executes at test time.
    const _: fn() = || {
        fn assert_send_sync<T: Send + Sync>() {}

        assert_send_sync::<Tcp<&str>>();
        #[cfg(unix)]
        assert_send_sync::<Socket<&str>>();
    };
}
157
158/// Sends a ping request to ClamAV
159///
160/// This function establishes a connection to a ClamAV server and sends the PING
161/// command to it. If the server is available, it responds with [`PONG`].
162///
163/// # Arguments
164///
165/// * `connection`: The connection type to use - either TCP or a Unix socket connection
166///
167/// # Returns
168///
169/// An [`IoResult`] containing the server's response as a vector of bytes
170///
171/// # Example
172///
173/// ```
174/// # #[async_std::main]
175/// # async fn main() {
176/// let clamd_tcp = clamav_client::async_std::Tcp{ host_address: "localhost:3310" };
177/// let clamd_available = match clamav_client::async_std::ping(clamd_tcp).await {
178///     Ok(ping_response) => ping_response == clamav_client::PONG,
179///     Err(_) => false,
180/// };
181/// # assert!(clamd_available);
182/// # }
183/// ```
184///
185pub async fn ping<T: TransportProtocol>(connection: T) -> IoResult {
186    let stream = connection.connect().await?;
187    send_command(stream, PING, Some(PONG.len())).await
188}
189
190/// Reloads the virus databases
191///
192/// This function establishes a connection to a ClamAV server and sends the
193/// RELOAD command to it. If the server is available, it responds with
194/// [`RELOADING`].
195///
196/// # Arguments
197///
198/// * `connection`: The connection type to use - either TCP or a Unix socket connection
199///
200/// # Returns
201///
202/// An [`IoResult`] containing the server's response as a vector of bytes
203///
204/// # Example
205///
206/// ```
207/// # #[async_std::main]
208/// # async fn main() {
209/// let clamd_tcp = clamav_client::async_std::Tcp{ host_address: "localhost:3310" };
210/// let response = clamav_client::async_std::reload(clamd_tcp).await.unwrap();
211/// # assert!(response == clamav_client::RELOADING);
212/// # }
213/// ```
214///
215pub async fn reload<T: TransportProtocol>(connection: T) -> IoResult {
216    let stream = connection.connect().await?;
217    send_command(stream, RELOAD, Some(RELOADING.len())).await
218}
219
220/// Gets the version number from ClamAV
221///
222/// This function establishes a connection to a ClamAV server and sends the
223/// VERSION command to it. If the server is available, it responds with its
224/// version number.
225///
226/// # Arguments
227///
228/// * `connection`: The connection type to use - either TCP or a Unix socket connection
229///
230/// # Returns
231///
232/// An [`IoResult`] containing the server's response as a vector of bytes
233///
234/// # Example
235///
236/// ```
237/// # #[async_std::main]
238/// # async fn main() {
239/// let clamd_tcp = clamav_client::async_std::Tcp{ host_address: "localhost:3310" };
240/// let version = clamav_client::async_std::get_version(clamd_tcp).await.unwrap();
241/// # assert!(version.starts_with(b"ClamAV"));
242/// # }
243/// ```
244///
245pub async fn get_version<T: TransportProtocol>(connection: T) -> IoResult {
246    let stream = connection.connect().await?;
247    send_command(stream, VERSION, None).await
248}
249
250/// Scans a file for viruses
251///
252/// This function reads data from a file located at the specified `file_path`
253/// and streams it to a ClamAV server for scanning.
254///
255/// # Arguments
256///
257/// * `file_path`: The path to the file to be scanned
258/// * `connection`: The connection type to use - either TCP or a Unix socket connection
259/// * `chunk_size`: An optional chunk size for reading data. If [`None`], a default chunk size is used
260///
261/// # Returns
262///
263/// An [`IoResult`] containing the server's response as a vector of bytes
264///
265pub async fn scan_file<P: AsRef<Path>, T: TransportProtocol>(
266    file_path: P,
267    connection: T,
268    chunk_size: Option<usize>,
269) -> IoResult {
270    let file = File::open(file_path).await?;
271    let stream = connection.connect().await?;
272    scan(file, chunk_size, stream).await
273}
274
275/// Scans a data buffer for viruses
276///
277/// This function streams the provided `buffer` data to a ClamAV server
278///
279/// # Arguments
280///
281/// * `buffer`: The data to be scanned
282/// * `connection`: The connection type to use - either TCP or a Unix socket connection
283/// * `chunk_size`: An optional chunk size for reading data. If [`None`], a default chunk size is used
284///
285/// # Returns
286///
287/// An [`IoResult`] containing the server's response as a vector of bytes
288///
289pub async fn scan_buffer<T: TransportProtocol>(
290    buffer: &[u8],
291    connection: T,
292    chunk_size: Option<usize>,
293) -> IoResult {
294    let stream = connection.connect().await?;
295    scan(buffer, chunk_size, stream).await
296}
297
298/// Scans a stream for viruses
299///
300/// This function sends the provided stream to a ClamAV server for scanning.
301///
302/// # Arguments
303///
304/// * `input_stream`: The stream to be scanned
305/// * `connection`: The connection type to use - either TCP or a Unix socket connection
306/// * `chunk_size`: An optional chunk size for reading data. If [`None`], a default chunk size is used
307///
308/// # Returns
309///
310/// An [`IoResult`] containing the server's response as a vector of bytes
311///
312pub async fn scan_stream<
313    S: Stream<Item = Result<bytes::Bytes, io::Error>>,
314    T: TransportProtocol,
315>(
316    input_stream: S,
317    connection: T,
318    chunk_size: Option<usize>,
319) -> IoResult {
320    let output_stream = connection.connect().await?;
321    _scan_stream(input_stream, chunk_size, output_stream).await
322}
323
324/// Shuts down a ClamAV server
325///
326/// This function establishes a connection to a ClamAV server and sends the
327/// SHUTDOWN command to it. If the server is available, it will perform a clean
328/// exit and shut itself down. The response will be empty.
329///
330/// # Arguments
331///
332/// * `connection`: The connection type to use - either TCP or a Unix socket connection
333///
334/// # Returns
335///
336/// An [`IoResult`] containing the server's response
337///
338pub async fn shutdown<T: TransportProtocol>(connection: T) -> IoResult {
339    let stream = connection.connect().await?;
340    send_command(stream, SHUTDOWN, None).await
341}