use async_std::{
fs::File,
io::{self, ReadExt, WriteExt},
net::{TcpStream, ToSocketAddrs},
path::Path,
stream::{Stream, StreamExt},
};
#[cfg(unix)]
use async_std::os::unix::net::UnixStream;
use super::{
IoResult, DEFAULT_CHUNK_SIZE, END_OF_STREAM, INSTREAM, PING, PONG, RELOAD, RELOADING, SHUTDOWN,
VERSION,
};
/// Writes a single `command` to `stream` and collects everything the peer sends back.
///
/// The reply is read until the peer closes its side of the connection.
/// `expected_response_length`, when given, is only used to pre-size the
/// response buffer; it does not limit how many bytes are read.
///
/// # Errors
///
/// Returns any I/O error raised while writing, flushing or reading.
async fn send_command<RW: ReadExt + WriteExt + Unpin>(
    mut stream: RW,
    command: &[u8],
    expected_response_length: Option<usize>,
) -> IoResult {
    stream.write_all(command).await?;
    stream.flush().await?;

    let mut reply = expected_response_length.map_or_else(Vec::new, Vec::with_capacity);
    stream.read_to_end(&mut reply).await?;
    Ok(reply)
}
/// Uploads the contents of `input` with the `INSTREAM` command and returns the raw reply.
///
/// `input` is read in pieces of at most `chunk_size` bytes (falling back to
/// `DEFAULT_CHUNK_SIZE` when `None`). Each piece is written as a 4-byte
/// big-endian length followed by the bytes themselves. After the last piece
/// `END_OF_STREAM` is written, the stream is flushed, and the reply is read
/// until the peer closes the connection.
///
/// The effective chunk size is clamped to `1..=u32::MAX`:
/// - The upper bound keeps every length prefix representable as a `u32`.
/// - The lower bound prevents a zero-sized buffer. With such a buffer every
///   `read` returns 0, so an empty upload would be sent instead of `input`.
///
/// # Errors
///
/// Propagates any I/O error from reading `input` or from talking to `stream`.
async fn scan<R: ReadExt + Unpin, RW: ReadExt + WriteExt + Unpin>(
    mut input: R,
    chunk_size: Option<usize>,
    mut stream: RW,
) -> IoResult {
    stream.write_all(INSTREAM).await?;

    let chunk_size = chunk_size
        .unwrap_or(DEFAULT_CHUNK_SIZE)
        .clamp(1, u32::MAX as usize);
    let mut buffer = vec![0; chunk_size];

    loop {
        let len = input.read(&mut buffer[..]).await?;
        if len == 0 {
            // Input exhausted.
            break;
        }
        // `len <= chunk_size <= u32::MAX`, so this cast cannot truncate.
        stream.write_all(&(len as u32).to_be_bytes()).await?;
        stream.write_all(&buffer[..len]).await?;
    }

    stream.write_all(END_OF_STREAM).await?;
    stream.flush().await?;

    let mut response = Vec::new();
    stream.read_to_end(&mut response).await?;
    Ok(response)
}
/// Uploads every item of `input_stream` with the `INSTREAM` command and returns the raw reply.
///
/// Each `Bytes` item is split into pieces of at most `chunk_size` bytes
/// (falling back to `DEFAULT_CHUNK_SIZE` when `None`). Every piece is written
/// as a 4-byte big-endian length followed by the bytes. `END_OF_STREAM` is
/// written once the input stream ends, and the reply is then read until the
/// peer closes the connection.
///
/// The effective chunk size is clamped to `1..=u32::MAX`:
/// - `slice::chunks` panics on a chunk size of zero.
/// - The upper bound keeps every length prefix representable as a `u32`.
///
/// # Errors
///
/// Returns the first error yielded by `input_stream`, or any I/O error from
/// talking to `output_stream`.
async fn _scan_stream<
    S: Stream<Item = Result<bytes::Bytes, std::io::Error>>,
    RW: ReadExt + WriteExt + Unpin,
>(
    input_stream: S,
    chunk_size: Option<usize>,
    mut output_stream: RW,
) -> IoResult {
    output_stream.write_all(INSTREAM).await?;

    let chunk_size = chunk_size
        .unwrap_or(DEFAULT_CHUNK_SIZE)
        .clamp(1, u32::MAX as usize);

    let mut input_stream = std::pin::pin!(input_stream);
    while let Some(bytes) = input_stream.next().await {
        let bytes = bytes?;
        let bytes = bytes.as_ref();
        // An empty item yields no chunks, so no zero-length prefix is written for it.
        for chunk in bytes.chunks(chunk_size) {
            // `chunk.len() <= chunk_size <= u32::MAX`, so this cast cannot truncate.
            output_stream
                .write_all(&(chunk.len() as u32).to_be_bytes())
                .await?;
            output_stream.write_all(chunk).await?;
        }
    }

    output_stream.write_all(END_OF_STREAM).await?;
    output_stream.flush().await?;

    let mut response = Vec::new();
    output_stream.read_to_end(&mut response).await?;
    Ok(response)
}
/// Connection settings for reaching the server over TCP.
///
/// Used as a [`TransportProtocol`] when `A` implements [`ToSocketAddrs`],
/// for example `Tcp { host_address: "127.0.0.1:3310" }`. The bound lives on
/// the `impl` rather than on the struct, so the type itself stays
/// unconstrained.
#[derive(Copy, Clone, Debug)]
pub struct Tcp<A> {
    /// Address (or anything resolvable to one) of the TCP listener.
    pub host_address: A,
}
/// Connection settings for reaching the server over a Unix domain socket.
///
/// Used as a [`TransportProtocol`] when `P: AsRef<Path>`. The bound lives on
/// the `impl` rather than on the struct, so the type itself stays
/// unconstrained.
#[derive(Copy, Clone, Debug)]
#[cfg(unix)]
pub struct Socket<P> {
    /// Filesystem path of the Unix socket to connect to.
    pub socket_path: P,
}
/// A way of opening a new connection to the server.
///
/// The request functions in this module call `connect` once per request and
/// consume the resulting stream.
pub trait TransportProtocol {
    /// Bidirectional, unpinned byte stream produced by a successful connection.
    type Stream: Unpin + ReadExt + WriteExt;

    /// Opens a new connection; the future resolves to the connected stream.
    fn connect(&self) -> impl core::future::Future<Output = io::Result<Self::Stream>>;
}
impl<A> TransportProtocol for Tcp<A>
where
    A: ToSocketAddrs,
{
    type Stream = TcpStream;

    /// Dials `host_address` over TCP.
    fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>> {
        let address = &self.host_address;
        TcpStream::connect(address)
    }
}
#[cfg(unix)]
impl<P> TransportProtocol for Socket<P>
where
    P: AsRef<Path>,
{
    type Stream = UnixStream;

    /// Connects to the Unix domain socket at `socket_path`.
    fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>> {
        let path = &self.socket_path;
        UnixStream::connect(path)
    }
}
/// Lets a borrowed transport be passed anywhere an owned one is accepted.
impl<T> TransportProtocol for &T
where
    T: TransportProtocol,
{
    type Stream = <T as TransportProtocol>::Stream;

    fn connect(&self) -> impl std::future::Future<Output = io::Result<Self::Stream>> {
        // `self` is `&&T`; deref coercion hands `&T` to the inner implementation.
        <T as TransportProtocol>::connect(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Only type-checks when `T` is both `Send` and `Sync`.
    fn _assert_send_sync<T: Send + Sync>() {}

    /// Compile-time check that the transport settings can be shared across threads.
    fn _transports_are_send_sync() {
        _assert_send_sync::<Tcp<&str>>();
        #[cfg(unix)]
        _assert_send_sync::<Socket<&str>>();
    }
}
/// Sends the `PING` command on a new connection and returns the raw reply bytes.
///
/// # Errors
///
/// Fails if connecting, writing the command, or reading the reply fails.
pub async fn ping<T: TransportProtocol>(connection: T) -> IoResult {
    // The reply buffer is pre-sized for a `PONG` answer.
    let capacity_hint = PONG.len();
    send_command(connection.connect().await?, PING, Some(capacity_hint)).await
}
/// Sends the `RELOAD` command on a new connection and returns the raw reply bytes.
///
/// # Errors
///
/// Fails if connecting, writing the command, or reading the reply fails.
pub async fn reload<T: TransportProtocol>(connection: T) -> IoResult {
    // The reply buffer is pre-sized for a `RELOADING` answer.
    let capacity_hint = RELOADING.len();
    send_command(connection.connect().await?, RELOAD, Some(capacity_hint)).await
}
/// Sends the `VERSION` command on a new connection and returns the raw reply bytes.
///
/// # Errors
///
/// Fails if connecting, writing the command, or reading the reply fails.
pub async fn get_version<T: TransportProtocol>(connection: T) -> IoResult {
    send_command(connection.connect().await?, VERSION, None).await
}
/// Streams the file at `file_path` to the server for scanning and returns the raw reply.
///
/// The file is opened before any connection is made, so a bad path fails
/// without touching the network. `chunk_size` controls how many bytes are
/// sent per piece; `None` selects the module default.
///
/// # Errors
///
/// Fails if the file cannot be opened or read, or if the connection fails.
pub async fn scan_file<P: AsRef<Path>, T: TransportProtocol>(
    file_path: P,
    connection: T,
    chunk_size: Option<usize>,
) -> IoResult {
    let source = File::open(file_path).await?;
    let conn = connection.connect().await?;
    scan(source, chunk_size, conn).await
}
/// Streams an in-memory `buffer` to the server for scanning and returns the raw reply.
///
/// `chunk_size` controls how many bytes are sent per piece; `None` selects
/// the module default.
///
/// # Errors
///
/// Fails if connecting to or communicating with the server fails.
pub async fn scan_buffer<T: TransportProtocol>(
    buffer: &[u8],
    connection: T,
    chunk_size: Option<usize>,
) -> IoResult {
    scan(buffer, chunk_size, connection.connect().await?).await
}
/// Streams every `Bytes` item of `input_stream` to the server for scanning and returns the raw reply.
///
/// `chunk_size` bounds how many bytes are sent per piece; `None` selects the
/// module default.
///
/// # Errors
///
/// Returns the first error yielded by `input_stream`, or any connection or
/// I/O error.
pub async fn scan_stream<S, T>(
    input_stream: S,
    connection: T,
    chunk_size: Option<usize>,
) -> IoResult
where
    S: Stream<Item = Result<bytes::Bytes, io::Error>>,
    T: TransportProtocol,
{
    _scan_stream(input_stream, chunk_size, connection.connect().await?).await
}
/// Sends the `SHUTDOWN` command on a new connection and returns the raw reply bytes.
///
/// # Errors
///
/// Fails if connecting, writing the command, or reading the reply fails.
pub async fn shutdown<T: TransportProtocol>(connection: T) -> IoResult {
    send_command(connection.connect().await?, SHUTDOWN, None).await
}