use std::collections::HashMap;
use std::future::Future;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use anyhow::anyhow;
use nix::errno::Errno;
use slab::Slab;
use tokio::io::{self, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufWriter};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot, Mutex};
use tracing::error;
use crate::Operation;
/// Handle to one remote peer.
///
/// The handle is a thin wrapper around an [`Arc`] of the shared connection state, so cloning it
/// yields another handle to the same state.
pub struct Peer<W>(Arc<PeerInner<W>>)
where
    W: AsyncWrite;
/// State shared by every [`Peer`] handle and by the background tasks spawned in `Peer::new`.
struct PeerInner<W>
where
    W: AsyncWrite,
{
    /// Address string that was passed to `Peer::new`.
    addr: String,
    /// Requests that have been encoded and are waiting for a response. The slab key of an entry
    /// is written on the wire as the request id and read back from the response.
    pendings: Mutex<Slab<Pending<W>>>,
    /// Sending side of the queue that the request loop drains.
    request_chan: mpsc::Sender<Pending<W>>,
}
/// A request travelling from `send_request` to the response handler.
struct Pending<W>
where
    W: AsyncWrite,
{
    /// Operation that gets encoded into the outgoing request.
    op: Arc<Operation>,
    /// One-shot channel on which the response is handed back to the caller.
    notify: oneshot::Sender<Response<W>>,
    /// Writer that receives the response payload.
    body: W,
}
/// Outcome of a request, as handed back by `Peer::send_request`.
pub struct Response<W>
where
    W: AsyncWrite,
{
    /// Status code read from the response.
    pub status: i32,
    /// Payload length, in bytes, read from the response.
    pub length: u64,
    /// The writer supplied with the request; the response payload has been copied into it.
    pub body: W,
}
impl<W: 'static + Unpin + Send + AsyncWrite> Peer<W> {
const BUF_SIZE: usize = 1 << 16;
pub async fn new(addr: impl AsRef<str>) -> anyhow::Result<Self> {
let conn = TcpStream::connect(addr.as_ref()).await?;
let (reader, writer) = conn.into_split();
let (tx, rx) = mpsc::channel(1024);
let inner = Arc::new(PeerInner {
addr: addr.as_ref().into(),
pendings: Default::default(),
request_chan: tx,
});
let peer = Self(inner);
let _ = tokio::spawn(peer.request_loop(writer, rx));
let _ = tokio::spawn(peer.response_loop(reader));
Ok(peer)
}
fn request_loop(
&self,
mut writer: OwnedWriteHalf,
mut rx: mpsc::Receiver<Pending<W>>,
) -> impl 'static + Future<Output = ()> {
let peer = self.clone();
async move {
let mut buffer = Vec::<u8>::with_capacity(Self::BUF_SIZE);
loop {
if let Err(err) = peer.fill_buffer(&mut rx, &mut buffer).await {
error!("fill request buffer: {}", err);
continue;
}
if let Err(err) = writer.write_all(&buffer).await {
error!("write buffer into tcp stream: {}", err);
}
}
}
}
async fn fill_buffer(
&self,
rx: &mut mpsc::Receiver<Pending<W>>,
buffer: &mut Vec<u8>,
) -> anyhow::Result<()> {
buffer.clear();
let request = match rx.recv().await {
None => {
return Err(anyhow!("request channel is closed"));
}
Some(r) => r,
};
self.write_request(buffer, request).await?;
if buffer.len() >= Self::BUF_SIZE {
return Ok(());
}
while let Ok(r) = rx.try_recv() {
self.write_request(buffer, r).await?;
if buffer.len() >= Self::BUF_SIZE {
break;
}
}
Ok(())
}
async fn write_request(&self, buffer: &mut Vec<u8>, request: Pending<W>) -> anyhow::Result<()> {
let op = request.op.clone();
let id = self.0.pendings.lock().await.insert(request);
buffer.write_u64(id as u64).await?;
op.encode(buffer).await?;
Ok(())
}
fn response_loop(&self, mut reader: OwnedReadHalf) -> impl 'static + Future<Output = ()> {
let peer = self.clone();
async move {
loop {
if let Err(err) = peer.clone().proccess_resp(&mut reader).await {
error!("fail to process response: {}", err);
break;
}
}
}
}
async fn proccess_resp(self, reader: &mut OwnedReadHalf) -> io::Result<()> {
let id = reader.read_u64().await?;
let mut p = self.0.pendings.lock().await.remove(id as usize);
let status = reader.read_i32().await?;
if status != 0 {
error!("status: {}", Errno::from_i32(status));
return Ok(());
}
let length = reader.read_u64().await?;
if length > 0 {
let mut part = reader.take(length);
io::copy(&mut part, &mut p.body).await?;
}
let _ = p.notify.send(Response {
status: status,
length,
body: p.body,
});
Ok(())
}
pub async fn send_request(&self, op: Operation, w: W) -> anyhow::Result<Response<W>> {
let (tx, rx) = oneshot::channel();
let operation = Arc::new(op);
let pending = Pending {
op: operation.clone(),
notify: tx,
body: w,
};
if let Err(_) = self.0.request_chan.send(pending).await {
Err(anyhow!("request channel is closed"))
} else {
Ok(rx.await.unwrap())
}
}
}
impl<W: AsyncWrite> Clone for Peer<W> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}