#![warn(missing_docs)]
use core::fmt;
use std::boxed::Box;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use bytes::Bytes;
use octseq::OctetsInto;
use tracing::trace;
use crate::base::Message;
use crate::net::client::protocol::{
AsyncConnect, AsyncDgramRecv, AsyncDgramRecvEx, AsyncDgramSend,
AsyncDgramSendEx,
};
use crate::net::client::request::{
ComposeRequest, Error, GetResponse, SendRequest,
};
/// Default size, in bytes, of the buffer allocated to receive a response.
const DEF_RECV_SIZE: usize = 2000;

/// Configuration for a datagram [`Connection`].
#[derive(Clone, Debug)]
pub struct Config {
    /// Size, in bytes, of the buffer allocated to receive a response.
    recv_size: usize,
}

impl Default for Config {
    /// Returns a configuration whose receive buffer is `DEF_RECV_SIZE` bytes.
    fn default() -> Self {
        Config {
            recv_size: DEF_RECV_SIZE,
        }
    }
}

// NOTE(review): the `allow(dead_code)` attributes presumably silence
// warnings when this module is not exported; confirm whether still needed.
impl Config {
    /// Creates a configuration with the default settings.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the size, in bytes, of the buffer used to receive a response.
    #[allow(dead_code)]
    pub fn set_recv_size(&mut self, size: usize) {
        self.recv_size = size;
    }

    /// Returns the size, in bytes, of the buffer used to receive a response.
    #[allow(dead_code)]
    pub fn recv_size(&self) -> usize {
        self.recv_size
    }
}
#[derive(Clone, Debug)]
pub struct Connection<S> {
state: Arc<ConnectionState<S>>,
}
#[derive(Debug)]
struct ConnectionState<S> {
config: Config,
connect: S,
}
impl<S> Connection<S> {
pub fn new(connect: S) -> Self {
Self::with_config(connect, Default::default())
}
pub fn with_config(connect: S, config: Config) -> Self {
Self {
state: Arc::new(ConnectionState { config, connect }),
}
}
}
impl<S> Connection<S>
where
S: AsyncConnect,
S::Connection: AsyncDgramRecv + AsyncDgramSend + Unpin,
{
async fn handle_request_impl<Req: ComposeRequest>(
self,
mut request: Req,
) -> Result<Message<Bytes>, Error> {
let mut reuse_buf = None;
let mut sock = self
.state
.connect
.connect()
.await
.map_err(|_| Error::ConnectionClosed)?;
request.header_mut().set_id(0);
let request_msg = request.to_message()?;
let dgram = request_msg.as_slice();
let sent = sock
.send(dgram)
.await
.map_err(|err| Error::StreamWriteError(Arc::new(err)))?;
if sent != dgram.len() {
return Err(Error::ShortMessage);
}
let mut buf = reuse_buf.take().unwrap_or_else(|| {
vec![0; self.state.config.recv_size]
});
let len = sock
.recv(&mut buf)
.await
.map_err(|err| Error::StreamReadError(Arc::new(err)))?;
trace!("Received {len} bytes of message");
buf.truncate(len);
let answer = Message::try_from_octets(buf)
.expect("Response could not be parsed");
trace!("Received message is accepted");
Ok(answer.octets_into())
}
}
impl<S, Req> SendRequest<Req> for Connection<S>
where
    S: AsyncConnect + Clone + Send + Sync + 'static,
    S::Connection:
        AsyncDgramRecv + AsyncDgramSend + Send + Sync + Unpin + 'static,
    Req: ComposeRequest + Send + Sync + 'static,
{
    /// Starts a request and returns a handle for retrieving its response.
    ///
    /// No I/O happens here. The returned [`Request`] holds a lazy future
    /// that opens the socket, sends the query and reads the reply only
    /// when it is polled through [`GetResponse::get_response`].
    fn send_request(
        &self,
        request_msg: Req,
    ) -> Box<dyn GetResponse + Send + Sync> {
        // The future must be `'static`, so it owns its own handle clone.
        let handle = self.clone();
        let pending = handle.handle_request_impl(request_msg);
        Box::new(Request {
            fut: Box::pin(pending),
        })
    }
}
pub struct Request {
fut: Pin<
Box<dyn Future<Output = Result<Message<Bytes>, Error>> + Send + Sync>,
>,
}
impl Request {
    /// Drives the stored exchange future to completion and returns its result.
    ///
    /// NOTE(review): the stored future is built from an `async fn`, so
    /// awaiting this a second time after it has resolved will panic.
    async fn get_response_impl(&mut self) -> Result<Message<Bytes>, Error> {
        self.fut.as_mut().await
    }
}
impl fmt::Debug for Request {
    /// Formats the request as an opaque `Request { .. }`.
    ///
    /// The boxed future has no `Debug` implementation, so its contents are
    /// elided. The previous `todo!()` made any `{:?}` formatting panic.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        f.debug_struct("Request").finish_non_exhaustive()
    }
}
impl GetResponse for Request {
    /// Returns a boxed future that resolves to the response for this request.
    fn get_response(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Message<Bytes>, Error>>
                + Send
                + Sync
                + '_,
        >,
    > {
        let response = self.get_response_impl();
        Box::pin(response)
    }
}