clamav_client/smol.rs
1use smol::{
2 fs::File,
3 io::{self, AsyncReadExt, AsyncWriteExt},
4 net::{AsyncToSocketAddrs, TcpStream},
5 stream::{Stream, StreamExt},
6};
7use std::path::Path;
8
9#[cfg(unix)]
10use smol::net::unix::UnixStream;
11
12use super::{
13 IoResult, DEFAULT_CHUNK_SIZE, END_OF_STREAM, INSTREAM, PING, PONG, RELOAD, RELOADING, SHUTDOWN,
14 VERSION,
15};
16
17async fn send_command<RW: AsyncReadExt + AsyncWriteExt + Unpin>(
18 mut stream: RW,
19 command: &[u8],
20 expected_response_length: Option<usize>,
21) -> IoResult {
22 stream.write_all(command).await?;
23 stream.flush().await?;
24
25 let mut response = match expected_response_length {
26 Some(len) => Vec::with_capacity(len),
27 None => Vec::new(),
28 };
29
30 stream.read_to_end(&mut response).await?;
31 Ok(response)
32}
33
34async fn scan<R: AsyncReadExt + Unpin, RW: AsyncReadExt + AsyncWriteExt + Unpin>(
35 mut input: R,
36 chunk_size: Option<usize>,
37 mut stream: RW,
38) -> IoResult {
39 stream.write_all(INSTREAM).await?;
40
41 let chunk_size = chunk_size
42 .unwrap_or(DEFAULT_CHUNK_SIZE)
43 .min(u32::MAX as usize);
44
45 let mut buffer = vec![0; chunk_size];
46
47 loop {
48 let len = input.read(&mut buffer[..]).await?;
49 if len != 0 {
50 stream.write_all(&(len as u32).to_be_bytes()).await?;
51 stream.write_all(&buffer[..len]).await?;
52 } else {
53 stream.write_all(END_OF_STREAM).await?;
54 stream.flush().await?;
55 break;
56 }
57 }
58
59 let mut response = Vec::new();
60 stream.read_to_end(&mut response).await?;
61 Ok(response)
62}
63
64async fn _scan_stream<
65 S: Stream<Item = Result<bytes::Bytes, std::io::Error>>,
66 RW: AsyncReadExt + AsyncWriteExt + Unpin,
67>(
68 input_stream: S,
69 chunk_size: Option<usize>,
70 mut output_stream: RW,
71) -> IoResult {
72 output_stream.write_all(INSTREAM).await?;
73
74 let chunk_size = chunk_size
75 .unwrap_or(DEFAULT_CHUNK_SIZE)
76 .min(u32::MAX as usize);
77
78 let mut input_stream = std::pin::pin!(input_stream);
79
80 while let Some(bytes) = input_stream.next().await {
81 let bytes = bytes?;
82 let bytes = bytes.as_ref();
83 for chunk in bytes.chunks(chunk_size) {
84 let len = chunk.len();
85 output_stream.write_all(&(len as u32).to_be_bytes()).await?;
86 output_stream.write_all(chunk).await?;
87 }
88 }
89
90 output_stream.write_all(END_OF_STREAM).await?;
91 output_stream.flush().await?;
92
93 let mut response = Vec::new();
94 output_stream.read_to_end(&mut response).await?;
95 Ok(response)
96}
97
/// Use a TCP connection to communicate with a ClamAV server
///
/// `A` must implement `AsyncToSocketAddrs` for this type to be usable as a
/// [`TransportProtocol`], for example a `"host:port"` string. The bound is
/// stated on the impl rather than on the struct, so the struct itself can be
/// named, copied and debug-printed freely.
#[derive(Copy, Clone, Debug)]
pub struct Tcp<A> {
    /// The address (host and port) of the ClamAV server
    pub host_address: A,
}
104
/// Use a Unix socket connection to communicate with a ClamAV server
///
/// `P` must implement `AsRef<Path>` for this type to be usable as a
/// [`TransportProtocol`]. The bound is stated on the impl rather than on the
/// struct, so the struct itself can be named, copied and debug-printed freely.
#[derive(Copy, Clone, Debug)]
#[cfg(unix)]
pub struct Socket<P> {
    /// The socket file path of the ClamAV server
    pub socket_path: P,
}
112
113/// The communication protocol to use
114pub trait TransportProtocol {
115 /// Bidirectional stream
116 type Stream: AsyncReadExt + AsyncWriteExt + Unpin;
117
118 /// Converts the protocol instance into the corresponding stream
119 fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>>;
120}
121
122impl<A: AsyncToSocketAddrs> TransportProtocol for Tcp<A> {
123 type Stream = TcpStream;
124
125 fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>> {
126 TcpStream::connect(&self.host_address)
127 }
128}
129
130#[cfg(unix)]
131impl<P: AsRef<Path>> TransportProtocol for Socket<P> {
132 type Stream = UnixStream;
133
134 fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>> {
135 UnixStream::connect(&self.socket_path)
136 }
137}
138
139impl<T: TransportProtocol> TransportProtocol for &T {
140 type Stream = T::Stream;
141
142 fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>> {
143 TransportProtocol::connect(*self)
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time assertions: these impls only type-check if the transports
    // are `Send + Sync`; nothing here executes at test time.
    trait _IsSendAndSync: Send + Sync {}

    impl _IsSendAndSync for Tcp<&str> {}

    #[cfg(unix)]
    impl _IsSendAndSync for Socket<&str> {}
}
157
158/// Sends a ping request to ClamAV
159///
160/// This function establishes a connection to a ClamAV server and sends the PING
161/// command to it. If the server is available, it responds with [`PONG`].
162///
163/// # Arguments
164///
165/// * `connection`: The connection type to use - either TCP or a Unix socket connection
166///
167/// # Returns
168///
169/// An [`IoResult`] containing the server's response as a vector of bytes
170///
171/// # Example
172///
173/// ```
174/// # smol::block_on(async {
175/// let clamd_tcp = clamav_client::smol::Tcp{ host_address: "localhost:3310" };
176/// let clamd_available = match clamav_client::smol::ping(clamd_tcp).await {
177/// Ok(ping_response) => ping_response == clamav_client::PONG,
178/// Err(_) => false,
179/// };
180/// # assert!(clamd_available);
181/// # })
182/// ```
183///
184pub async fn ping<T: TransportProtocol>(connection: T) -> IoResult {
185 let stream = connection.connect().await?;
186 send_command(stream, PING, Some(PONG.len())).await
187}
188
189/// Reloads the virus databases
190///
191/// This function establishes a connection to a ClamAV server and sends the
192/// RELOAD command to it. If the server is available, it responds with
193/// [`RELOADING`].
194///
195/// # Arguments
196///
197/// * `connection`: The connection type to use - either TCP or a Unix socket connection
198///
199/// # Returns
200///
201/// An [`IoResult`] containing the server's response as a vector of bytes
202///
203/// # Example
204///
205/// ```
206/// # smol::block_on(async {
207/// let clamd_tcp = clamav_client::smol::Tcp{ host_address: "localhost:3310" };
208/// let response = clamav_client::smol::reload(clamd_tcp).await.unwrap();
209/// # assert!(response == clamav_client::RELOADING);
210/// # })
211/// ```
212///
213pub async fn reload<T: TransportProtocol>(connection: T) -> IoResult {
214 let stream = connection.connect().await?;
215 send_command(stream, RELOAD, Some(RELOADING.len())).await
216}
217
218/// Gets the version number from ClamAV
219///
220/// This function establishes a connection to a ClamAV server and sends the
221/// VERSION command to it. If the server is available, it responds with its
222/// version number.
223///
224/// # Arguments
225///
226/// * `connection`: The connection type to use - either TCP or a Unix socket connection
227///
228/// # Returns
229///
230/// An [`IoResult`] containing the server's response as a vector of bytes
231///
232/// # Example
233///
234/// ```
235/// # smol::block_on(async {
236/// let clamd_tcp = clamav_client::smol::Tcp{ host_address: "localhost:3310" };
237/// let version = clamav_client::smol::get_version(clamd_tcp).await.unwrap();
238/// # assert!(version.starts_with(b"ClamAV"));
239/// # })
240/// ```
241///
242pub async fn get_version<T: TransportProtocol>(connection: T) -> IoResult {
243 let stream = connection.connect().await?;
244 send_command(stream, VERSION, None).await
245}
246
247/// Scans a file for viruses
248///
249/// This function reads data from a file located at the specified `file_path`
250/// and streams it to a ClamAV server for scanning.
251///
252/// # Arguments
253///
254/// * `file_path`: The path to the file to be scanned
255/// * `connection`: The connection type to use - either TCP or a Unix socket connection
256/// * `chunk_size`: An optional chunk size for reading data. If [`None`], a default chunk size is used
257///
258/// # Returns
259///
260/// An [`IoResult`] containing the server's response as a vector of bytes
261///
262pub async fn scan_file<P: AsRef<Path>, T: TransportProtocol>(
263 file_path: P,
264 connection: T,
265 chunk_size: Option<usize>,
266) -> IoResult {
267 let file = File::open(file_path).await?;
268 let stream = connection.connect().await?;
269 scan(file, chunk_size, stream).await
270}
271
272/// Scans a data buffer for viruses
273///
274/// This function streams the provided `buffer` data to a ClamAV server
275///
276/// # Arguments
277///
278/// * `buffer`: The data to be scanned
279/// * `connection`: The connection type to use - either TCP or a Unix socket connection
280/// * `chunk_size`: An optional chunk size for reading data. If [`None`], a default chunk size is used
281///
282/// # Returns
283///
284/// An [`IoResult`] containing the server's response as a vector of bytes
285///
286pub async fn scan_buffer<T: TransportProtocol>(
287 buffer: &[u8],
288 connection: T,
289 chunk_size: Option<usize>,
290) -> IoResult {
291 let stream = connection.connect().await?;
292 scan(buffer, chunk_size, stream).await
293}
294
295/// Scans a stream for viruses
296///
297/// This function sends the provided stream to a ClamAV server for scanning.
298///
299/// # Arguments
300///
301/// * `input_stream`: The stream to be scanned
302/// * `connection`: The connection type to use - either TCP or a Unix socket connection
303/// * `chunk_size`: An optional chunk size for reading data. If [`None`], a default chunk size is used
304///
305/// # Returns
306///
307/// An [`IoResult`] containing the server's response as a vector of bytes
308///
309pub async fn scan_stream<
310 S: Stream<Item = Result<bytes::Bytes, io::Error>>,
311 T: TransportProtocol,
312>(
313 input_stream: S,
314 connection: T,
315 chunk_size: Option<usize>,
316) -> IoResult {
317 let output_stream = connection.connect().await?;
318 _scan_stream(input_stream, chunk_size, output_stream).await
319}
320
321/// Shuts down a ClamAV server
322///
323/// This function establishes a connection to a ClamAV server and sends the
324/// SHUTDOWN command to it. If the server is available, it will perform a clean
325/// exit and shut itself down. The response will be empty.
326///
327/// # Arguments
328///
329/// * `connection`: The connection type to use - either TCP or a Unix socket connection
330///
331/// # Returns
332///
333/// An [`IoResult`] containing the server's response
334///
335pub async fn shutdown<T: TransportProtocol>(connection: T) -> IoResult {
336 let stream = connection.connect().await?;
337 send_command(stream, SHUTDOWN, None).await
338}