use std::io::{Error, ErrorKind, IoSlice, Result};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use async_io::Timer;
use async_net::TcpStream;
use asyncs::select;
use bytes::buf::BufMut;
use futures::io::BufReader;
use futures::prelude::*;
use futures_lite::AsyncReadExt;
#[cfg(feature = "tls")]
pub use futures_rustls::client::TlsStream;
use ignore_result::Ignore;
use tracing::{debug, trace};
use crate::deadline::Deadline;
use crate::endpoint::{EndpointRef, IterableEndpoints};
#[cfg(feature = "tls")]
use crate::tls::TlsClient;
/// A client connection: either a plain TCP stream or, when the `tls` feature
/// is enabled, a TLS stream layered over TCP.
#[derive(Debug)]
pub enum Connection {
/// Plain TCP stream.
Raw(TcpStream),
/// TLS stream over TCP, held in a `Box`.
#[cfg(feature = "tls")]
Tls(Box<TlsStream<TcpStream>>),
}
/// Extension of [`AsyncReadExt`] that reads straight into the writable chunk of a [`BufMut`].
pub trait AsyncReadToBuf: AsyncReadExt {
    /// Performs one `read` into the current writable chunk of `buf` and advances `buf` by the
    /// number of bytes read.
    ///
    /// Returns the byte count reported by the underlying `read`. When that count is `0`, `buf`
    /// is not advanced.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying `read`.
    async fn read_to_buf(&mut self, buf: &mut impl BufMut) -> Result<usize>
    where
        Self: Unpin,
    {
        let spare = buf.chunk_mut();
        // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`, so the slice pointer cast keeps
        // address and length intact.
        // NOTE(review): this exposes possibly-uninitialized bytes as `&mut [u8]`; it relies on the
        // reader only ever writing into the slice — confirm for every reader used here.
        let dest = unsafe { &mut *(spare.as_uninit_slice_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };
        let filled = self.read(dest).await?;
        if filled > 0 {
            // SAFETY: `read` reported `filled` bytes written at the start of the chunk obtained above.
            unsafe { buf.advance_mut(filled) };
        }
        Ok(filled)
    }
}
/// Blanket implementation: every type implementing [`AsyncReadExt`] gets `read_to_buf`.
impl<T> AsyncReadToBuf for T where T: AsyncReadExt {}
impl AsyncRead for Connection {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_read(cx, buf),
#[cfg(feature = "tls")]
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
impl AsyncWrite for Connection {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_write(cx, buf),
#[cfg(feature = "tls")]
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<Result<usize>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
#[cfg(feature = "tls")]
Self::Tls(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_flush(cx),
#[cfg(feature = "tls")]
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_close(cx),
#[cfg(feature = "tls")]
Self::Tls(stream) => Pin::new(stream).poll_close(cx),
}
}
}
/// Read-side handle over a mutably borrowed [`Connection`]; constructed by
/// [`Connection::split`].
pub struct ConnReader<'a> {
// Connection that reads are forwarded to.
conn: &'a mut Connection,
}
impl AsyncRead for ConnReader<'_> {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
Pin::new(&mut self.get_mut().conn).poll_read(cx, buf)
}
}
/// Write-side handle over a mutably borrowed [`Connection`]; constructed by
/// [`Connection::split`].
pub struct ConnWriter<'a> {
// Connection that writes, flushes and closes are forwarded to.
conn: &'a mut Connection,
}
impl AsyncWrite for ConnWriter<'_> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
Pin::new(&mut self.get_mut().conn).poll_write(cx, buf)
}
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<Result<usize>> {
Pin::new(&mut self.get_mut().conn).poll_write_vectored(cx, bufs)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.get_mut().conn).poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.get_mut().conn).poll_close(cx)
}
}
impl Connection {
pub fn new_raw(stream: TcpStream) -> Self {
Self::Raw(stream)
}
pub fn split(&mut self) -> (ConnReader<'_>, ConnWriter<'_>) {
let reader = ConnReader { conn: self };
let writer = ConnWriter { conn: unsafe { std::ptr::read(&reader.conn) } };
(reader, writer)
}
#[cfg(feature = "tls")]
pub fn new_tls(stream: TlsStream<TcpStream>) -> Self {
Self::Tls(stream.into())
}
pub async fn command(mut self, cmd: &str) -> Result<String> {
self.write_all(cmd.as_bytes()).await?;
self.flush().await?;
let mut line = String::new();
let mut reader = BufReader::new(self);
reader.read_line(&mut line).await?;
reader.close().await.ignore();
Ok(line)
}
pub async fn command_isro(self) -> Result<bool> {
let r = self.command("isro").await?;
if r == "rw" {
Ok(true)
} else {
Ok(false)
}
}
}
/// Establishes [`Connection`]s to endpoints, bounded by a per-attempt timeout.
#[derive(Clone)]
pub struct Connector {
// TLS client used for endpoints flagged as TLS; `None` when not configured.
#[cfg(feature = "tls")]
tls: Option<TlsClient>,
// Upper bound on the time a single `connect` call may take.
timeout: Duration,
}
impl Connector {
    /// Connection timeout used by the constructors until changed via [`Connector::set_timeout`].
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);

    /// Creates a connector without a TLS client and with a 10 second connection timeout.
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "tls")]
            tls: None,
            timeout: Self::DEFAULT_TIMEOUT,
        }
    }

    /// Creates a connector that uses `client` for TLS endpoints, with a 10 second connection timeout.
    #[cfg(feature = "tls")]
    pub fn with_tls(client: TlsClient) -> Self {
        Self { tls: Some(client), timeout: Self::DEFAULT_TIMEOUT }
    }

    /// Returns the per-attempt connection timeout.
    pub fn timeout(&self) -> Duration {
        self.timeout
    }

    /// Sets the per-attempt connection timeout.
    pub fn set_timeout(&mut self, timeout: Duration) {
        self.timeout = timeout;
    }

    /// Connects to `endpoint` with no time limit applied.
    ///
    /// TLS endpoints go through the configured TLS client; all others use a plain TCP connect.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::Unsupported` for a TLS endpoint when the `tls` feature is disabled or
    /// no TLS client is configured; otherwise propagates the connect error.
    async fn connect_endpoint(&self, endpoint: EndpointRef<'_>) -> Result<Connection> {
        if endpoint.tls {
            #[cfg(feature = "tls")]
            return match self.tls.as_ref() {
                None => Err(Error::new(ErrorKind::Unsupported, "tls not configured")),
                Some(client) => client.connect(endpoint.host, endpoint.port).await.map(Connection::new_tls),
            };
            #[cfg(not(feature = "tls"))]
            return Err(Error::new(ErrorKind::Unsupported, "tls not supported"));
        }
        TcpStream::connect((endpoint.host, endpoint.port)).await.map(Connection::new_raw)
    }

    /// Connects to `endpoint`, giving up when either `deadline` completes or the configured
    /// timeout elapses.
    ///
    /// The selection is biased: the connect attempt is checked first, then `deadline`, then the
    /// timeout timer.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::TimedOut` when `deadline` or the timeout fires first; otherwise
    /// propagates the error from the connect attempt.
    pub async fn connect(&self, endpoint: EndpointRef<'_>, deadline: &mut Deadline) -> Result<Connection> {
        // NOTE(review): `Pin::new_unchecked` below pins `deadline` through a plain `&mut`. That is
        // only sound if `Deadline` is `Unpin` or callers never move it after it has been polled —
        // confirm against the `Deadline` definition.
        select! {
            biased;
            r = self.connect_endpoint(endpoint) => r,
            _ = unsafe { Pin::new_unchecked(deadline) } => Err(Error::new(ErrorKind::TimedOut, "deadline exceed")),
            _ = Timer::after(self.timeout) => Err(Error::new(ErrorKind::TimedOut, format!("connection timeout {:?} exceed", self.timeout))),
        }
    }

    /// Walks `endpoints` until one answers the `isro` probe with `Ok(true)`, and returns it.
    ///
    /// Each endpoint yielded by `peek` is connected to (with a never-expiring deadline) and
    /// probed via [`Connection::command_isro`]. On any failure or a `false` answer the iterator
    /// is stepped. After every `endpoints.len()` attempts the loop sleeps with a backoff that
    /// starts at 100ms and doubles up to 60s; between other attempts it sleeps 5ms.
    ///
    /// Returns `None` once `peek` yields no endpoint.
    pub async fn seek_for_writable(self, endpoints: &mut IterableEndpoints) -> Option<EndpointRef<'_>> {
        let n = endpoints.len();
        let max_timeout = Duration::from_secs(60);
        let mut i = 0;
        let mut timeout = Duration::from_millis(100);
        let mut deadline = Deadline::never();
        while let Some(endpoint) = endpoints.peek() {
            i += 1;
            match self.connect(endpoint, &mut deadline).await {
                Ok(conn) => match conn.command_isro().await {
                    Ok(true) => {
                        // The transmute changes only the lifetime parameter, so the returned
                        // reference can outlive this loop iteration's borrow of `endpoints`.
                        // NOTE(review): assumes the reference stays valid for the whole
                        // `endpoints` borrow — confirm against `IterableEndpoints::peek`.
                        return Some(unsafe { std::mem::transmute::<EndpointRef<'_>, EndpointRef<'_>>(endpoint) });
                    },
                    Ok(false) => trace!("succeeds to contact readonly {}", endpoint),
                    Err(err) => trace!(%err, r#"fails to complete "isro" to {}"#, endpoint),
                },
                Err(err) => trace!(%err, "fails to contact {}", endpoint),
            }
            endpoints.step();
            if i % n == 0 {
                // A full round over all endpoints failed: back off before the next round.
                debug!(
                    sleep = timeout.as_millis(),
                    "fails to contact writable server from endpoints {:?}",
                    endpoints.endpoints()
                );
                Timer::after(timeout).await;
                timeout = max_timeout.min(timeout * 2);
            } else {
                Timer::after(Duration::from_millis(5)).await;
            }
        }
        None
    }
}
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use super::Connector;
    use crate::deadline::Deadline;
    use crate::endpoint::EndpointRef;

    /// A connector built with `Connector::new()` rejects a TLS endpoint as unsupported.
    #[asyncs::test]
    async fn raw() {
        let connector = Connector::new();
        let tls_endpoint = EndpointRef::new("host1", 2181, true);
        let mut deadline = Deadline::never();
        let outcome = connector.connect(tls_endpoint, &mut deadline).await;
        let err = outcome.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::Unsupported);
    }
}