use std::future::{poll_fn, Future};
use std::io;
use std::pin::Pin;
use std::sync::Mutex;
use std::task::{ready, Context, Poll};

use bytes::Bytes;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::mpsc;
/// Default payload size limit in bytes: 1 MiB (1024 * 1024 = 1_048_576).
///
/// NOTE(review): presumably what callers pass as `max_payload` to `Channel::new` — confirm.
pub const DEFAULT_MAX_PAYLOAD: usize = 1024 * 1024;
pub struct Channel {
nc: async_nats::Client,
output_subject: String,
request: serde_json::Value,
read_rx: mpsc::Receiver<Bytes>,
read_buf: Vec<u8>,
read_pos: usize,
closed: bool,
max_payload: usize,
}
impl Channel {
pub(crate) fn new(
nc: async_nats::Client,
output_subject: String,
request: serde_json::Value,
read_rx: mpsc::Receiver<Bytes>,
max_payload: usize,
) -> Self {
Self {
nc,
output_subject,
request,
read_rx,
read_buf: Vec::new(),
read_pos: 0,
closed: false,
max_payload,
}
}
pub fn request(&self) -> &serde_json::Value {
&self.request
}
pub async fn close(&mut self) -> io::Result<()> {
if self.closed {
return Ok(());
}
self.closed = true;
let _ = self
.nc
.publish(
format!("{}.done", self.output_subject),
Bytes::from_static(b"{}"),
)
.await;
Ok(())
}
pub fn is_closed(&self) -> bool {
self.closed
}
}
impl AsyncRead for Channel {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if self.read_pos < self.read_buf.len() {
let remaining = &self.read_buf[self.read_pos..];
let n = std::cmp::min(remaining.len(), buf.remaining());
buf.put_slice(&remaining[..n]);
self.read_pos += n;
if self.read_pos >= self.read_buf.len() {
self.read_buf.clear();
self.read_pos = 0;
}
return Poll::Ready(Ok(()));
}
match self.read_rx.poll_recv(cx) {
Poll::Ready(Some(data)) => {
let n = std::cmp::min(data.len(), buf.remaining());
buf.put_slice(&data[..n]);
if n < data.len() {
self.read_buf = data[n..].to_vec();
self.read_pos = 0;
}
Poll::Ready(Ok(()))
}
Poll::Ready(None) => Poll::Ready(Ok(())),
Poll::Pending => Poll::Pending,
}
}
}
impl AsyncWrite for Channel {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
if self.closed {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"channel closed",
)));
}
if buf.len() > self.max_payload {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!(
"payload size {} exceeds max {} bytes",
buf.len(),
self.max_payload
),
)));
}
let nc = self.nc.clone();
let subject = self.output_subject.clone();
let data = Bytes::copy_from_slice(buf);
let len = buf.len();
tokio::spawn(async move {
let _ = nc.publish(subject, data).await;
});
Poll::Ready(Ok(len))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.closed = true;
Poll::Ready(Ok(()))
}
}