use async_trait::async_trait;
use std::path::Path;
use tokio::{
fs::File,
io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::{TcpStream, ToSocketAddrs},
};
#[cfg(unix)]
use tokio::net::UnixStream;
#[cfg(feature = "tokio-stream")]
use tokio_stream::{Stream, StreamExt};
use super::{IoResult, DEFAULT_CHUNK_SIZE, END_OF_STREAM, INSTREAM, PING, PONG, SHUTDOWN, VERSION};
/// Writes a single `command` to `stream` and collects everything the peer sends back.
///
/// The command is flushed before reading. The reply is read until the peer closes its
/// side of the connection (`read_to_end`). `expected_response_length`, when given, is
/// only used to pre-size the reply buffer; it does not limit how much is read.
///
/// # Errors
///
/// Propagates any I/O error from writing, flushing or reading `stream`.
async fn send_command<RW>(
    mut stream: RW,
    command: &[u8],
    expected_response_length: Option<usize>,
) -> IoResult
where
    RW: AsyncRead + AsyncWrite + Unpin,
{
    stream.write_all(command).await?;
    stream.flush().await?;

    let mut reply = expected_response_length.map_or_else(Vec::new, Vec::with_capacity);
    stream.read_to_end(&mut reply).await?;
    Ok(reply)
}
/// Streams `input` to the server using the `INSTREAM` command and returns the raw reply.
///
/// After the `INSTREAM` command, the payload is sent as a sequence of chunks. Each chunk
/// is prefixed with its length as a 4-byte big-endian integer. `END_OF_STREAM` follows
/// the last chunk. The stream is flushed and then read until the peer closes it.
///
/// `chunk_size` is the maximum number of bytes per chunk:
/// - `None` and `Some(0)` both select `DEFAULT_CHUNK_SIZE`.
/// - Values above `u32::MAX` are capped so the length prefix cannot overflow.
///
/// # Errors
///
/// Propagates any I/O error raised while reading `input` or while writing to or
/// reading from `stream`.
async fn scan<R: AsyncRead + Unpin, RW: AsyncRead + AsyncWrite + Unpin>(
    mut input: R,
    chunk_size: Option<usize>,
    mut stream: RW,
) -> IoResult {
    stream.write_all(INSTREAM).await?;

    // A zero-length buffer makes every `read` below return `Ok(0)`, which looks exactly
    // like EOF. The loop would then send `END_OF_STREAM` before any data and the input
    // would never be transmitted. Treat 0 like `None`; the lower clamp is a final guard.
    let chunk_size = chunk_size
        .filter(|&size| size > 0)
        .unwrap_or(DEFAULT_CHUNK_SIZE)
        .clamp(1, u32::MAX as usize);
    let mut buffer = vec![0; chunk_size];

    loop {
        let len = input.read(&mut buffer).await?;
        if len == 0 {
            break;
        }
        // `len <= chunk_size <= u32::MAX`, so this cast cannot truncate.
        stream.write_all(&(len as u32).to_be_bytes()).await?;
        stream.write_all(&buffer[..len]).await?;
    }

    stream.write_all(END_OF_STREAM).await?;
    stream.flush().await?;

    let mut response = Vec::new();
    stream.read_to_end(&mut response).await?;
    Ok(response)
}
/// Forwards every item of `input_stream` to the server using the `INSTREAM` command and
/// returns the raw reply.
///
/// Each `Bytes` item is split into chunks of at most `chunk_size` bytes. Every chunk is
/// written with a 4-byte big-endian length prefix. After the input stream ends,
/// `END_OF_STREAM` is written and the output stream is flushed. The reply is then read
/// until the peer closes the connection.
///
/// `None` and `Some(0)` for `chunk_size` both select `DEFAULT_CHUNK_SIZE`. Values above
/// `u32::MAX` are capped so the length prefix cannot overflow.
///
/// # Errors
///
/// - Returns the first error item yielded by `input_stream`.
/// - Propagates any I/O error while writing to or reading from `output_stream`.
#[cfg(feature = "tokio-stream")]
async fn _scan_stream<
    S: Stream<Item = Result<bytes::Bytes, io::Error>>,
    RW: AsyncRead + AsyncWrite + Unpin,
>(
    input_stream: S,
    chunk_size: Option<usize>,
    mut output_stream: RW,
) -> IoResult {
    output_stream.write_all(INSTREAM).await?;

    // `slice::chunks` panics when given 0, so treat `Some(0)` like `None`.
    // The lower clamp is a final guard.
    let chunk_size = chunk_size
        .filter(|&size| size > 0)
        .unwrap_or(DEFAULT_CHUNK_SIZE)
        .clamp(1, u32::MAX as usize);

    let mut input_stream = std::pin::pin!(input_stream);
    while let Some(bytes) = input_stream.next().await {
        let bytes = bytes?;
        let bytes = bytes.as_ref();
        // An empty item yields no chunks, so nothing is written for it.
        for chunk in bytes.chunks(chunk_size) {
            // `chunk.len() <= chunk_size <= u32::MAX`, so this cast cannot truncate.
            let len = chunk.len() as u32;
            output_stream.write_all(&len.to_be_bytes()).await?;
            output_stream.write_all(chunk).await?;
        }
    }

    output_stream.write_all(END_OF_STREAM).await?;
    output_stream.flush().await?;

    let mut response = Vec::new();
    output_stream.read_to_end(&mut response).await?;
    Ok(response)
}
/// Deprecated shorthand for [`ping`] over a Unix socket at `socket_path`.
#[deprecated(since = "0.5.0", note = "Use `ping` instead")]
#[cfg(unix)]
pub async fn ping_socket<P>(socket_path: P) -> IoResult
where
    P: AsRef<Path>,
{
    let transport = Socket { socket_path };
    ping(transport).await
}
/// Deprecated shorthand for [`scan_file`] over a Unix socket at `socket_path`.
///
/// Note that `file_path` and `socket_path` must share the same type `P`.
#[deprecated(since = "0.5.0", note = "Use `scan_file` instead")]
#[cfg(unix)]
pub async fn scan_file_socket<P>(
    file_path: P,
    socket_path: P,
    chunk_size: Option<usize>,
) -> IoResult
where
    P: AsRef<Path>,
{
    let transport = Socket { socket_path };
    scan_file(file_path, transport, chunk_size).await
}
/// Deprecated shorthand for [`scan_buffer`] over a Unix socket at `socket_path`.
#[deprecated(since = "0.5.0", note = "Use `scan_buffer` instead")]
#[cfg(unix)]
pub async fn scan_buffer_socket<P>(
    buffer: &[u8],
    socket_path: P,
    chunk_size: Option<usize>,
) -> IoResult
where
    P: AsRef<Path>,
{
    let transport = Socket { socket_path };
    scan_buffer(buffer, transport, chunk_size).await
}
/// Deprecated shorthand for [`scan_stream`] over a Unix socket at `socket_path`.
#[deprecated(since = "0.5.0", note = "Use `scan_stream` instead")]
#[cfg(all(unix, feature = "tokio-stream"))]
pub async fn scan_stream_socket<S, P>(
    input_stream: S,
    socket_path: P,
    chunk_size: Option<usize>,
) -> IoResult
where
    S: Stream<Item = Result<bytes::Bytes, io::Error>>,
    P: AsRef<Path>,
{
    let transport = Socket { socket_path };
    scan_stream(input_stream, transport, chunk_size).await
}
/// Deprecated shorthand for [`ping`] over TCP to `host_address`.
#[deprecated(since = "0.5.0", note = "Use `ping` instead")]
pub async fn ping_tcp<A>(host_address: A) -> IoResult
where
    A: ToSocketAddrs,
{
    let transport = Tcp { host_address };
    ping(transport).await
}
/// Deprecated shorthand for [`scan_file`] over TCP to `host_address`.
#[deprecated(since = "0.5.0", note = "Use `scan_file` instead")]
pub async fn scan_file_tcp<P, A>(
    file_path: P,
    host_address: A,
    chunk_size: Option<usize>,
) -> IoResult
where
    P: AsRef<Path>,
    A: ToSocketAddrs,
{
    let transport = Tcp { host_address };
    scan_file(file_path, transport, chunk_size).await
}
/// Deprecated shorthand for [`scan_buffer`] over TCP to `host_address`.
#[deprecated(since = "0.5.0", note = "Use `scan_buffer` instead")]
pub async fn scan_buffer_tcp<A>(
    buffer: &[u8],
    host_address: A,
    chunk_size: Option<usize>,
) -> IoResult
where
    A: ToSocketAddrs,
{
    let transport = Tcp { host_address };
    scan_buffer(buffer, transport, chunk_size).await
}
/// Deprecated shorthand for [`scan_stream`] over TCP to `host_address`.
#[deprecated(since = "0.5.0", note = "Use `scan_stream` instead")]
#[cfg(feature = "tokio-stream")]
pub async fn scan_stream_tcp<S, A>(
    input_stream: S,
    host_address: A,
    chunk_size: Option<usize>,
) -> IoResult
where
    S: Stream<Item = Result<bytes::Bytes, io::Error>>,
    A: ToSocketAddrs,
{
    let transport = Tcp { host_address };
    scan_stream(input_stream, transport, chunk_size).await
}
/// TCP transport: each request opens a new connection to `host_address`.
///
/// `host_address` can be anything usable as a socket address for
/// `TcpStream::connect`, e.g. `"localhost:3310"`. The `ToSocketAddrs` requirement is
/// enforced by the `TransportProtocol` implementation rather than by the struct
/// itself, so the struct definition carries no trait bounds.
#[derive(Copy, Clone, Debug)]
pub struct Tcp<A> {
    /// Address the TCP connection is made to.
    pub host_address: A,
}
/// Unix-domain socket transport: each request opens a new connection to `socket_path`.
///
/// The `AsRef<Path>` requirement is enforced by the `TransportProtocol` implementation
/// rather than by the struct itself, so the struct definition carries no trait bounds.
#[cfg(unix)]
#[derive(Copy, Clone, Debug)]
pub struct Socket<P> {
    /// Filesystem path of the Unix socket to connect to.
    pub socket_path: P,
}
/// A way of opening a fresh connection to the server.
///
/// The public request functions in this module (`ping`, `scan_file`, `shutdown`, ...)
/// each call [`connect`](TransportProtocol::connect) once and use the returned stream
/// for a single exchange. Because of `?Send`, the futures returned by `connect` are not
/// required to be `Send`.
#[async_trait(?Send)]
pub trait TransportProtocol {
    /// Bidirectional byte stream returned by [`connect`](TransportProtocol::connect).
    type Stream: Unpin + AsyncRead + AsyncWrite;

    /// Opens a new connection.
    async fn connect(&self) -> io::Result<Self::Stream>;
}
/// Connects over TCP, producing a [`TcpStream`].
#[async_trait(?Send)]
impl<A> TransportProtocol for Tcp<A>
where
    A: ToSocketAddrs,
{
    type Stream = TcpStream;

    /// Opens a TCP connection to `host_address`.
    async fn connect(&self) -> io::Result<Self::Stream> {
        let address = &self.host_address;
        TcpStream::connect(address).await
    }
}
/// Connects over a Unix-domain socket, producing a [`UnixStream`].
#[cfg(unix)]
#[async_trait(?Send)]
impl<P> TransportProtocol for Socket<P>
where
    P: AsRef<Path>,
{
    type Stream = UnixStream;

    /// Opens a connection to the Unix socket at `socket_path`.
    async fn connect(&self) -> io::Result<Self::Stream> {
        let path: &Path = self.socket_path.as_ref();
        UnixStream::connect(path).await
    }
}
/// Sends `PING` over a new connection and returns the raw reply bytes.
///
/// The reply buffer is pre-sized to `PONG.len()`.
///
/// # Errors
///
/// Propagates connection and I/O errors.
pub async fn ping<T>(connection: T) -> IoResult
where
    T: TransportProtocol,
{
    let reply_len = PONG.len();
    send_command(connection.connect().await?, PING, Some(reply_len)).await
}
/// Sends `VERSION` over a new connection and returns the raw reply bytes.
///
/// # Errors
///
/// Propagates connection and I/O errors.
pub async fn get_version<T>(connection: T) -> IoResult
where
    T: TransportProtocol,
{
    send_command(connection.connect().await?, VERSION, None).await
}
/// Streams the file at `file_path` to the server for scanning and returns the raw reply.
///
/// `chunk_size` controls the maximum number of bytes sent per `INSTREAM` chunk.
///
/// # Errors
///
/// Fails if the file cannot be opened, the connection cannot be established, or any
/// I/O error occurs while streaming.
pub async fn scan_file<P, T>(
    file_path: P,
    connection: T,
    chunk_size: Option<usize>,
) -> IoResult
where
    P: AsRef<Path>,
    T: TransportProtocol,
{
    // The file is opened before the connection, so a bad path fails without connecting.
    let source = File::open(file_path).await?;
    let transport_stream = connection.connect().await?;
    scan(source, chunk_size, transport_stream).await
}
/// Streams the in-memory `buffer` to the server for scanning and returns the raw reply.
///
/// `chunk_size` controls the maximum number of bytes sent per `INSTREAM` chunk.
///
/// # Errors
///
/// Propagates connection and I/O errors.
pub async fn scan_buffer<T>(
    buffer: &[u8],
    connection: T,
    chunk_size: Option<usize>,
) -> IoResult
where
    T: TransportProtocol,
{
    let transport_stream = connection.connect().await?;
    scan(buffer, chunk_size, transport_stream).await
}
/// Streams every `Bytes` item of `input_stream` to the server for scanning and returns
/// the raw reply.
///
/// `chunk_size` controls the maximum number of bytes sent per `INSTREAM` chunk.
///
/// # Errors
///
/// Propagates connection errors, I/O errors, and the first error item yielded by
/// `input_stream`.
#[cfg(feature = "tokio-stream")]
pub async fn scan_stream<S, T>(
    input_stream: S,
    connection: T,
    chunk_size: Option<usize>,
) -> IoResult
where
    S: Stream<Item = Result<bytes::Bytes, io::Error>>,
    T: TransportProtocol,
{
    _scan_stream(input_stream, chunk_size, connection.connect().await?).await
}
/// Sends `SHUTDOWN` over a new connection and returns whatever the server replies.
///
/// # Errors
///
/// Propagates connection and I/O errors.
pub async fn shutdown<T>(connection: T) -> IoResult
where
    T: TransportProtocol,
{
    let transport_stream = connection.connect().await?;
    send_command(transport_stream, SHUTDOWN, None).await
}