clamav_client/tokio.rs
1use std::path::Path;
2use tokio::{
3 fs::File,
4 io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
5 net::{TcpStream, ToSocketAddrs},
6};
7
8#[cfg(unix)]
9use tokio::net::UnixStream;
10
11#[cfg(feature = "tokio-stream")]
12use tokio_stream::{Stream, StreamExt};
13
14use super::{
15 IoResult, DEFAULT_CHUNK_SIZE, END_OF_STREAM, INSTREAM, PING, PONG, RELOAD, RELOADING, SHUTDOWN,
16 VERSION,
17};
18
19async fn send_command<RW: AsyncRead + AsyncWrite + Unpin>(
20 mut stream: RW,
21 command: &[u8],
22 expected_response_length: Option<usize>,
23) -> IoResult {
24 stream.write_all(command).await?;
25 stream.flush().await?;
26
27 let mut response = match expected_response_length {
28 Some(len) => Vec::with_capacity(len),
29 None => Vec::new(),
30 };
31
32 stream.read_to_end(&mut response).await?;
33 Ok(response)
34}
35
36async fn scan<R: AsyncRead + Unpin, RW: AsyncRead + AsyncWrite + Unpin>(
37 mut input: R,
38 chunk_size: Option<usize>,
39 mut stream: RW,
40) -> IoResult {
41 stream.write_all(INSTREAM).await?;
42
43 let chunk_size = chunk_size
44 .unwrap_or(DEFAULT_CHUNK_SIZE)
45 .min(u32::MAX as usize);
46
47 let mut buffer = vec![0; chunk_size];
48
49 loop {
50 let len = input.read(&mut buffer[..]).await?;
51 if len != 0 {
52 stream.write_all(&(len as u32).to_be_bytes()).await?;
53 stream.write_all(&buffer[..len]).await?;
54 } else {
55 stream.write_all(END_OF_STREAM).await?;
56 stream.flush().await?;
57 break;
58 }
59 }
60
61 let mut response = Vec::new();
62 stream.read_to_end(&mut response).await?;
63 Ok(response)
64}
65
/// Uploads every item of `input_stream` to clamd with the `INSTREAM` command
/// over `output_stream`, then returns the server's complete reply.
///
/// Each item is split into frames of at most `chunk_size` bytes. Every frame
/// is written as a 4-byte big-endian length followed by its bytes, and
/// [`END_OF_STREAM`] is written after the input stream ends.
///
/// `chunk_size`: `None` and `Some(0)` both select [`DEFAULT_CHUNK_SIZE`].
/// Sizes above `u32::MAX` are capped because the frame length prefix is a
/// `u32`.
///
/// # Errors
///
/// Returns the first `Err` item yielded by `input_stream`, or any I/O error
/// from writing to or reading from `output_stream`.
#[cfg(feature = "tokio-stream")]
async fn _scan_stream<
    S: Stream<Item = Result<bytes::Bytes, std::io::Error>>,
    RW: AsyncRead + AsyncWrite + Unpin,
>(
    input_stream: S,
    chunk_size: Option<usize>,
    mut output_stream: RW,
) -> IoResult {
    output_stream.write_all(INSTREAM).await?;

    // `slice::chunks(0)` panics, so 0 is treated as "not specified". The
    // lower bound of 1 keeps the size non-zero whatever the default's value.
    let chunk_size = chunk_size
        .filter(|&size| size > 0)
        .unwrap_or(DEFAULT_CHUNK_SIZE)
        .clamp(1, u32::MAX as usize);

    let mut input_stream = std::pin::pin!(input_stream);

    while let Some(item) = input_stream.next().await {
        let bytes = item?;
        let data: &[u8] = bytes.as_ref();
        // `chunks` yields nothing for an empty item. That means no
        // zero-length frame is ever written, which is what the INSTREAM
        // protocol uses to mark end of stream.
        for chunk in data.chunks(chunk_size) {
            // `chunk.len() <= chunk_size <= u32::MAX`, so this cast cannot truncate.
            output_stream
                .write_all(&(chunk.len() as u32).to_be_bytes())
                .await?;
            output_stream.write_all(chunk).await?;
        }
    }

    output_stream.write_all(END_OF_STREAM).await?;
    output_stream.flush().await?;

    let mut response = Vec::new();
    output_stream.read_to_end(&mut response).await?;
    Ok(response)
}
100
/// Use a TCP connection to communicate with a ClamAV server
///
/// `A` is the server address, e.g. `"localhost:3310"`. Connecting requires
/// `A: ToSocketAddrs`. That bound lives on the [`TransportProtocol`]
/// implementation rather than on the struct, so the value can be built,
/// copied and debug-printed freely.
#[derive(Copy, Clone, Debug)]
pub struct Tcp<A> {
    /// The address (host and port) of the ClamAV server
    pub host_address: A,
}
107
/// Use a Unix socket connection to communicate with a ClamAV server
///
/// `P` is the path of clamd's local socket; its location depends on the
/// server's configuration. Connecting requires `P: AsRef<Path>`. That bound
/// lives on the [`TransportProtocol`] implementation rather than on the
/// struct, so the value can be built, copied and debug-printed freely.
#[derive(Copy, Clone, Debug)]
#[cfg(unix)]
pub struct Socket<P> {
    /// The socket file path of the ClamAV server
    pub socket_path: P,
}
115
116/// The communication protocol to use
117pub trait TransportProtocol {
118 /// Bidirectional stream
119 type Stream: AsyncRead + AsyncWrite + Unpin;
120
121 /// Converts the protocol instance into the corresponding stream
122 fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>>;
123}
124
125impl<A: ToSocketAddrs> TransportProtocol for Tcp<A> {
126 type Stream = TcpStream;
127
128 fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>> {
129 TcpStream::connect(&self.host_address)
130 }
131}
132
133#[cfg(unix)]
134impl<P: AsRef<Path>> TransportProtocol for Socket<P> {
135 type Stream = UnixStream;
136
137 fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>> {
138 UnixStream::connect(&self.socket_path)
139 }
140}
141
142impl<T: TransportProtocol> TransportProtocol for &T {
143 type Stream = T::Stream;
144
145 fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>> {
146 TransportProtocol::connect(*self)
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Type-checks only when `T` is both `Send` and `Sync`.
    const fn require_send_sync<T: Send + Sync>() {}

    // Compile-time assertions: evaluated during the build, so a violation is
    // a compile error rather than a failing test.
    const _: () = {
        require_send_sync::<Tcp<&'static str>>();
        #[cfg(unix)]
        require_send_sync::<Socket<&'static str>>();
    };
}
160
161/// Sends a ping request to ClamAV
162///
163/// This function establishes a connection to a ClamAV server and sends the PING
164/// command to it. If the server is available, it responds with [`PONG`].
165///
166/// # Arguments
167///
168/// * `connection`: The connection type to use - either TCP or a Unix socket connection
169///
170/// # Returns
171///
172/// An [`IoResult`] containing the server's response as a vector of bytes
173///
174/// # Example
175///
176/// ```
177/// # #[tokio::main(flavor = "current_thread")]
178/// # async fn main() {
179/// let clamd_tcp = clamav_client::tokio::Tcp{ host_address: "localhost:3310" };
180/// let clamd_available = match clamav_client::tokio::ping(clamd_tcp).await {
181/// Ok(ping_response) => ping_response == clamav_client::PONG,
182/// Err(_) => false,
183/// };
184/// # assert!(clamd_available);
185/// # }
186/// ```
187///
188pub async fn ping<T: TransportProtocol>(connection: T) -> IoResult {
189 let stream = connection.connect().await?;
190 send_command(stream, PING, Some(PONG.len())).await
191}
192
193/// Reloads the virus databases
194///
195/// This function establishes a connection to a ClamAV server and sends the
196/// RELOAD command to it. If the server is available, it responds with
197/// [`RELOADING`].
198///
199/// # Arguments
200///
201/// * `connection`: The connection type to use - either TCP or a Unix socket connection
202///
203/// # Returns
204///
205/// An [`IoResult`] containing the server's response as a vector of bytes
206///
207/// # Example
208///
209/// ```
210/// # #[tokio::main(flavor = "current_thread")]
211/// # async fn main() {
212/// let clamd_tcp = clamav_client::tokio::Tcp{ host_address: "localhost:3310" };
213/// let response = clamav_client::tokio::reload(clamd_tcp).await.unwrap();
214/// # assert!(response == clamav_client::RELOADING);
215/// # }
216/// ```
217///
218pub async fn reload<T: TransportProtocol>(connection: T) -> IoResult {
219 let stream = connection.connect().await?;
220 send_command(stream, RELOAD, Some(RELOADING.len())).await
221}
222
223/// Gets the version number from ClamAV
224///
225/// This function establishes a connection to a ClamAV server and sends the
226/// VERSION command to it. If the server is available, it responds with its
227/// version number.
228///
229/// # Arguments
230///
231/// * `connection`: The connection type to use - either TCP or a Unix socket connection
232///
233/// # Returns
234///
235/// An [`IoResult`] containing the server's response as a vector of bytes
236///
237/// # Example
238///
239/// ```
240/// # #[tokio::main(flavor = "current_thread")]
241/// # async fn main() {
242/// let clamd_tcp = clamav_client::tokio::Tcp{ host_address: "localhost:3310" };
243/// let version = clamav_client::tokio::get_version(clamd_tcp).await.unwrap();
244/// # assert!(version.starts_with(b"ClamAV"));
245/// # }
246/// ```
247///
248pub async fn get_version<T: TransportProtocol>(connection: T) -> IoResult {
249 let stream = connection.connect().await?;
250 send_command(stream, VERSION, None).await
251}
252
253/// Scans a file for viruses
254///
255/// This function reads data from a file located at the specified `file_path`
256/// and streams it to a ClamAV server for scanning.
257///
258/// # Arguments
259///
260/// * `file_path`: The path to the file to be scanned
261/// * `connection`: The connection type to use - either TCP or a Unix socket connection
262/// * `chunk_size`: An optional chunk size for reading data. If [`None`], a default chunk size is used
263///
264/// # Returns
265///
266/// An [`IoResult`] containing the server's response as a vector of bytes
267///
268pub async fn scan_file<P: AsRef<Path>, T: TransportProtocol>(
269 file_path: P,
270 connection: T,
271 chunk_size: Option<usize>,
272) -> IoResult {
273 let file = File::open(file_path).await?;
274 let stream = connection.connect().await?;
275 scan(file, chunk_size, stream).await
276}
277
278/// Scans a data buffer for viruses
279///
280/// This function streams the provided `buffer` data to a ClamAV server
281///
282/// # Arguments
283///
284/// * `buffer`: The data to be scanned
285/// * `connection`: The connection type to use - either TCP or a Unix socket connection
286/// * `chunk_size`: An optional chunk size for reading data. If [`None`], a default chunk size is used
287///
288/// # Returns
289///
290/// An [`IoResult`] containing the server's response as a vector of bytes
291///
292pub async fn scan_buffer<T: TransportProtocol>(
293 buffer: &[u8],
294 connection: T,
295 chunk_size: Option<usize>,
296) -> IoResult {
297 let stream = connection.connect().await?;
298 scan(buffer, chunk_size, stream).await
299}
300
/// Scans a stream for viruses
///
/// Forwards every item of `input_stream` to a ClamAV server for scanning.
///
/// # Arguments
///
/// * `input_stream`: The stream to be scanned
/// * `connection`: The connection type to use - either TCP or a Unix socket connection
/// * `chunk_size`: An optional chunk size for reading data. If [`None`], a default chunk size is used
///
/// # Returns
///
/// An [`IoResult`] containing the server's response as a vector of bytes
///
/// # Errors
///
/// Returns the I/O error from connecting, from an `Err` item in
/// `input_stream`, from uploading the data or from reading the reply.
///
#[cfg(feature = "tokio-stream")]
pub async fn scan_stream<
    S: Stream<Item = Result<bytes::Bytes, io::Error>>,
    T: TransportProtocol,
>(
    input_stream: S,
    connection: T,
    chunk_size: Option<usize>,
) -> IoResult {
    // Connect before handing the input over, exactly as before.
    let socket = connection.connect().await?;
    _scan_stream(input_stream, chunk_size, socket).await
}
327
328/// Shuts down a ClamAV server
329///
330/// This function establishes a connection to a ClamAV server and sends the
331/// SHUTDOWN command to it. If the server is available, it will perform a clean
332/// exit and shut itself down. The response will be empty.
333///
334/// # Arguments
335///
336/// * `connection`: The connection type to use - either TCP or a Unix socket connection
337///
338/// # Returns
339///
340/// An [`IoResult`] containing the server's response
341///
342pub async fn shutdown<T: TransportProtocol>(connection: T) -> IoResult {
343 let stream = connection.connect().await?;
344 send_command(stream, SHUTDOWN, None).await
345}