#![warn(missing_docs)]
use crate::base::Message;
use crate::net::client::protocol::{
AsyncConnect, AsyncDgramRecv, AsyncDgramRecvEx, AsyncDgramSend,
AsyncDgramSendEx,
};
use crate::net::client::request::{
ComposeRequest, Error, GetResponse, SendRequest,
};
use crate::utils::config::DefMinMax;
use bytes::Bytes;
use core::fmt;
use octseq::OctetsInto;
use std::boxed::Box;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::vec::Vec;
use std::{error, io};
use tokio::sync::Semaphore;
use tokio::time::{timeout_at, Duration, Instant};
use tracing::trace;
/// Default and allowed range for the number of requests that may be
/// outstanding on one connection at the same time.
const MAX_PARALLEL: DefMinMax<usize> = DefMinMax::new(100, 1, 1_000);

/// Default and allowed range for how long a single transmission attempt
/// waits for a response.
const READ_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(5),
    Duration::from_millis(1),
    Duration::from_secs(60),
);

/// Default and allowed range for the number of retransmissions after the
/// first attempt.
const MAX_RETRIES: DefMinMax<u8> = DefMinMax::new(5, 0, 100);

/// UDP payload size set on outgoing requests unless configured otherwise.
const DEF_UDP_PAYLOAD_SIZE: u16 = 1_232;

/// Default size in bytes of the buffer a response is received into.
const DEF_RECV_SIZE: usize = 2_000;
/// Configuration of a datagram [`Connection`].
#[derive(Clone, Debug)]
pub struct Config {
    /// Maximum number of requests that may be outstanding at once.
    max_parallel: usize,

    /// How long a single transmission attempt waits for a response.
    read_timeout: Duration,

    /// Number of retransmissions after the first attempt.
    max_retries: u8,

    /// UDP payload size to set on outgoing requests, if any.
    udp_payload_size: Option<u16>,

    /// Size in bytes of the buffer a response is received into.
    recv_size: usize,
}
impl Config {
    /// Creates a configuration with every option at its default value.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets how many requests may be outstanding at the same time.
    ///
    /// The value is passed through `MAX_PARALLEL.limit`, which presumably
    /// restricts it to the allowed range. It only takes effect when a
    /// connection is created from this configuration.
    pub fn set_max_parallel(&mut self, value: usize) {
        let limited = MAX_PARALLEL.limit(value);
        self.max_parallel = limited;
    }

    /// Returns how many requests may be outstanding at the same time.
    pub fn max_parallel(&self) -> usize {
        self.max_parallel
    }

    /// Sets how long each transmission attempt waits for a response.
    ///
    /// The value is passed through `READ_TIMEOUT.limit`.
    pub fn set_read_timeout(&mut self, value: Duration) {
        let limited = READ_TIMEOUT.limit(value);
        self.read_timeout = limited;
    }

    /// Returns how long each transmission attempt waits for a response.
    pub fn read_timeout(&self) -> Duration {
        self.read_timeout
    }

    /// Sets how often a request is retransmitted after the first attempt.
    ///
    /// The value is passed through `MAX_RETRIES.limit`.
    pub fn set_max_retries(&mut self, value: u8) {
        let limited = MAX_RETRIES.limit(value);
        self.max_retries = limited;
    }

    /// Returns how often a request is retransmitted after the first attempt.
    pub fn max_retries(&self) -> u8 {
        self.max_retries
    }

    /// Sets the UDP payload size applied to outgoing requests.
    ///
    /// `None` leaves the request's payload size untouched.
    pub fn set_udp_payload_size(&mut self, value: Option<u16>) {
        self.udp_payload_size = value;
    }

    /// Returns the UDP payload size applied to outgoing requests, if any.
    pub fn udp_payload_size(&self) -> Option<u16> {
        self.udp_payload_size
    }

    /// Sets the size in bytes of the buffer a response is received into.
    ///
    /// The value is stored as given; it is not range-checked.
    pub fn set_recv_size(&mut self, size: usize) {
        self.recv_size = size;
    }

    /// Returns the size in bytes of the buffer a response is received into.
    pub fn recv_size(&self) -> usize {
        self.recv_size
    }
}
impl Default for Config {
    /// Builds the configuration used when none is supplied explicitly.
    fn default() -> Self {
        let max_parallel = MAX_PARALLEL.default();
        let read_timeout = READ_TIMEOUT.default();
        let max_retries = MAX_RETRIES.default();
        Self {
            max_parallel,
            read_timeout,
            max_retries,
            udp_payload_size: Some(DEF_UDP_PAYLOAD_SIZE),
            recv_size: DEF_RECV_SIZE,
        }
    }
}
/// A connection that sends each request over a datagram socket.
///
/// The value is a handle: all state lives behind an [`Arc`], so cloning it
/// only bumps a reference count and every clone shares the same
/// configuration and concurrency limit.
#[derive(Debug)]
pub struct Connection<S> {
    /// State shared by all clones of this connection.
    state: Arc<ConnectionState<S>>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would require
// `S: Clone`, although cloning only copies the `Arc` pointer.
impl<S> Clone for Connection<S> {
    fn clone(&self) -> Self {
        Self {
            state: Arc::clone(&self.state),
        }
    }
}
/// State shared by all clones of a [`Connection`].
#[derive(Debug)]
struct ConnectionState<S> {
    /// The configuration the connection was created with.
    config: Config,

    /// Connector used to obtain a new socket for every transmission attempt.
    connect: S,

    /// Limits concurrently outstanding requests; it is created with
    /// `config.max_parallel` permits.
    semaphore: Semaphore,
}
impl<S> Connection<S> {
pub fn new(connect: S) -> Self {
Self::with_config(connect, Default::default())
}
pub fn with_config(connect: S, config: Config) -> Self {
Self {
state: Arc::new(ConnectionState {
semaphore: Semaphore::new(config.max_parallel),
config,
connect,
}),
}
}
}
/// Request handling for connections whose connector yields datagram sockets.
impl<S> Connection<S>
where
S: AsyncConnect,
S::Connection: AsyncDgramRecv + AsyncDgramSend + Unpin,
{
/// Performs a complete request/response exchange.
///
/// First waits for a permit from the connection's semaphore, which is held
/// until this future finishes. Then makes up to `1 + max_retries` attempts.
/// Each attempt:
/// - obtains a new socket from the connector,
/// - gives the request a new random message ID,
/// - optionally sets the UDP payload size,
/// - sends the request,
/// - reads datagrams until one is accepted by `request.is_answer` or the
///   per-attempt `read_timeout` expires.
///
/// # Errors
///
/// Returns without retrying if connecting, composing the message, sending,
/// or receiving fails. The same applies if the socket reports fewer bytes
/// sent than the request contained. Returns a timeout error once every
/// attempt has expired without an acceptable answer.
///
/// # Panics
///
/// Panics if the semaphore has been closed.
async fn handle_request_impl<Req: ComposeRequest>(
self,
mut request: Req,
) -> Result<Message<Bytes>, Error> {
// Bound the number of concurrent requests. Binding the permit to a name
// (not `_`) keeps it alive until the function returns.
let _permit = self
.state
.semaphore
.acquire()
.await
.expect("semaphore closed");
// Receive buffer, reused across reads and attempts.
let mut buf = Vec::new();
// One initial transmission plus up to `max_retries` retransmissions.
for _ in 0..1 + self.state.config.max_retries {
let mut sock = self
.state
.connect
.connect()
.await
.map_err(QueryError::connect)?;
// Each attempt gets a new random ID before the message is composed.
request.header_mut().set_random_id();
if let Some(size) = self.state.config.udp_payload_size {
request.set_udp_payload_size(size)
}
let request_msg = request.to_message()?;
let dgram = request_msg.as_slice();
let sent = sock.send(dgram).await.map_err(QueryError::send)?;
// A partially sent datagram is treated as a send failure.
if sent != dgram.len() {
return Err(QueryError::short_send().into());
}
// All reads for this attempt share one deadline.
let deadline = Instant::now() + self.state.config.read_timeout;
while deadline > Instant::now() {
// Restore the full buffer length; a previous read may have
// truncated it to the size of the datagram it received.
buf.resize(self.state.config.recv_size, 0);
// A timeout ends this attempt and moves on to the next one; a receive
// error aborts the whole request.
let len =
match timeout_at(deadline, sock.recv(&mut buf)).await {
Ok(Ok(len)) => len,
Ok(Err(err)) => {
return Err(QueryError::receive(err).into());
}
Err(_) => {
trace!("Receive timed out");
break;
}
};
trace!("Received {len} bytes of message");
buf.truncate(len);
// Unparseable datagrams are ignored. The deadline bounds how long
// we keep reading, and the buffer is handed back for reuse.
let answer = match Message::try_from_octets(buf) {
Ok(answer) => answer,
Err(old_buf) => {
trace!("Received bytes were garbage, reading more");
buf = old_buf;
continue;
}
};
// Messages that do not answer this request are skipped; reclaim
// the buffer and keep reading.
if !request.is_answer(answer.for_slice()) {
trace!("Received message is not the answer we were waiting for, reading more");
buf = answer.into_octets();
continue;
}
trace!("Received message is accepted");
return Ok(answer.octets_into());
}
}
// Every attempt ended without an accepted answer.
Err(QueryError::timeout().into())
}
}
impl<S, Req> SendRequest<Req> for Connection<S>
where
    S: AsyncConnect + Clone + Send + Sync + 'static,
    S::Connection:
        AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static,
    Req: ComposeRequest + Send + Sync + 'static,
{
    /// Starts a request on this connection.
    ///
    /// The exchange is captured in a future owned by the returned object.
    /// Because the future comes from an `async fn`, nothing happens until
    /// the response is polled.
    fn send_request(
        &self,
        request_msg: Req,
    ) -> Box<dyn GetResponse + Send + Sync> {
        let handle = self.clone();
        let fut = handle.handle_request_impl(request_msg);
        Box::new(Request { fut: Box::pin(fut) })
    }
}
pub struct Request {
fut: Pin<
Box<dyn Future<Output = Result<Message<Bytes>, Error>> + Send + Sync>,
>,
}
impl Request {
    /// Drives the stored future and returns its result.
    ///
    /// NOTE(review): the stored future comes from an `async fn`. Polling it
    /// again after it has completed would panic, so this should presumably
    /// be awaited to completion at most once.
    async fn get_response_impl(&mut self) -> Result<Message<Bytes>, Error> {
        self.fut.as_mut().await
    }
}
impl fmt::Debug for Request {
    /// Formats the request as `Request { .. }`.
    ///
    /// The boxed future has no meaningful debug representation, so it is
    /// omitted. The previous `todo!()` made any debug formatting of a
    /// request panic.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        f.debug_struct("Request").finish_non_exhaustive()
    }
}
impl GetResponse for Request {
    /// Returns a boxed future that resolves to the response of this request.
    fn get_response(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Message<Bytes>, Error>>
                + Send
                + Sync
                + '_,
        >,
    > {
        let response = self.get_response_impl();
        Box::pin(response)
    }
}
#[derive(Debug)]
pub struct QueryError {
kind: QueryErrorKind,
io: std::io::Error,
}
impl QueryError {
fn new(kind: QueryErrorKind, io: io::Error) -> Self {
Self { kind, io }
}
fn connect(io: io::Error) -> Self {
Self::new(QueryErrorKind::Connect, io)
}
fn send(io: io::Error) -> Self {
Self::new(QueryErrorKind::Send, io)
}
fn short_send() -> Self {
Self::new(
QueryErrorKind::Send,
io::Error::other("short request sent"),
)
}
fn timeout() -> Self {
Self::new(
QueryErrorKind::Timeout,
io::Error::new(io::ErrorKind::TimedOut, "timeout expired"),
)
}
fn receive(io: io::Error) -> Self {
Self::new(QueryErrorKind::Receive, io)
}
}
impl QueryError {
    /// Returns which step of the query failed.
    pub fn kind(&self) -> QueryErrorKind {
        self.kind
    }

    /// Consumes the error and returns the underlying I/O error.
    pub fn io_error(self) -> std::io::Error {
        let Self { io, .. } = self;
        io
    }
}
impl From<QueryError> for std::io::Error {
fn from(err: QueryError) -> std::io::Error {
err.io
}
}
impl fmt::Display for QueryError {
    /// Writes the kind's context string, a colon, and the I/O error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let context = self.kind.error_str();
        write!(f, "{context}: {}", self.io)
    }
}
impl error::Error for QueryError {}
/// The step of a query that a `QueryError` refers to.
///
/// Also derives `PartialEq`, `Eq` and `Hash`, so callers can compare
/// `err.kind()` against a variant directly.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum QueryErrorKind {
    /// Obtaining a socket from the connector failed.
    Connect,

    /// Sending the request failed, including a short send.
    Send,

    /// No acceptable response arrived before all attempts expired.
    Timeout,

    /// Receiving a datagram failed.
    Receive,
}
impl QueryErrorKind {
fn error_str(self) -> &'static str {
match self {
Self::Connect => "connecting failed",
Self::Send => "sending request failed",
Self::Timeout | Self::Receive => "reading response failed",
}
}
}
impl fmt::Display for QueryErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
Self::Connect => "connecting failed",
Self::Send => "sending request failed",
Self::Timeout => "request timeout",
Self::Receive => "reading response failed",
})
}
}