use async_std::io::{Read as AsyncRead};
use async_std::prelude::*;
use async_std::task::{ready, Context, Poll};
use std::io;
use std::pin::Pin;
use std::time::Duration;
use std::io::Write;
use flate2::{GzBuilder, Compression};
/// In-memory byte sink: everything written to it is appended to `buf`.
///
/// It is the output target of the gzip encoder, so compressed bytes can be
/// pulled back out of `buf` and handed to readers.
#[derive(Debug)]
struct WriteBuf {
    /// Bytes written so far, in write order.
    buf: Vec<u8>,
}

impl Write for WriteBuf {
    /// Appends all of `data`; never fails and never performs a short write.
    fn write(&mut self, data: &[u8]) -> std::io::Result<usize> {
        let written = data.len();
        self.buf.extend_from_slice(data);
        Ok(written)
    }

    /// Nothing is held outside `buf`, so there is nothing to flush.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
pin_project_lite::pin_project! {
    /// Reading half of an SSE stream: an [`AsyncRead`] that yields the bytes of
    /// every message pushed through the paired [`Sender`], optionally
    /// gzip-compressed.
    #[derive(Debug)]
    pub struct Encoder {
        // Current chunk being handed out to readers (raw or compressed bytes).
        buf: Vec<u8>,
        // Index into `buf` of the next byte not yet returned by `poll_read`.
        cursor: usize,
        // Incoming messages from the `Sender`; pinned because it is polled as a stream.
        #[pin]
        receiver: async_channel::Receiver<Vec<u8>>,
        // Gzip encoder whose compressed output accumulates in its `WriteBuf`.
        gz: flate2::write::GzEncoder<WriteBuf>,
        // When true, each message is passed through `gz` before being read out.
        gz_enabled: bool,
    }
}
impl AsyncRead for Encoder {
    /// Copies the next pending bytes of the event stream into `buf`.
    ///
    /// Once everything buffered so far has been handed out, the next message is
    /// pulled from the channel and compressed first if gzip is enabled. Once the
    /// channel is closed and drained, any remaining gzip output (final block and
    /// trailer) is emitted, and then `Ok(0)` signals end of stream.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from the gzip encoder.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // A zero-length read completes immediately; don't wait for (or consume)
        // a message when there is nowhere to put its bytes.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let mut this = self.project();
        // Refill until at least one unread byte is available. Looping matters
        // because a chunk can be empty (an empty message, or a gzip flush that
        // produced no new output), and returning `Ok(0)` for it would be
        // misread by callers as end of stream.
        while this.buf.len() <= *this.cursor {
            match ready!(this.receiver.as_mut().poll_next(cx)) {
                Some(mut chunk) => {
                    log::trace!("> Received a new buffer with len {}", chunk.len());
                    if *this.gz_enabled {
                        this.gz.write_all(&chunk)?;
                        // Flush so this message's compressed bytes are emitted
                        // now instead of being held back by the compressor.
                        this.gz.flush()?;
                        // Take the compressed output, and give `chunk`'s
                        // allocation back to the encoder as its cleared buffer.
                        let inner = this.gz.get_mut();
                        std::mem::swap(&mut inner.buf, &mut chunk);
                        inner.buf.clear();
                    }
                    *this.buf = chunk;
                    *this.cursor = 0;
                }
                None => {
                    if *this.gz_enabled {
                        // Without the final deflate block and CRC/size trailer,
                        // the gzip body would be truncated. Nothing can arrive
                        // after the channel reports `None`, so finish exactly
                        // once and stop routing through gzip; the next poll then
                        // reports end of stream.
                        *this.gz_enabled = false;
                        this.gz.try_finish()?;
                        let inner = this.gz.get_mut();
                        if !inner.buf.is_empty() {
                            *this.buf = std::mem::take(&mut inner.buf);
                            *this.cursor = 0;
                            continue;
                        }
                    }
                    log::trace!("> Encoder done reading");
                    return Poll::Ready(Ok(0));
                }
            }
        }
        let pending = &this.buf[*this.cursor..];
        let n = buf.len().min(pending.len());
        buf[..n].copy_from_slice(&pending[..n]);
        *this.cursor += n;
        Poll::Ready(Ok(n))
    }
}
/// Writing half of an SSE stream, created together with an [`Encoder`] by
/// [`encode`].
///
/// Cloning yields another handle onto the same channel, so every clone feeds
/// the same `Encoder`. The channel built by `encode` holds at most one pending
/// message, so sends wait while the `Encoder` has not yet taken the previous one.
#[derive(Debug, Clone)]
pub struct Sender(async_channel::Sender<Vec<u8>>);
/// Creates a connected [`Sender`] / [`Encoder`] pair for one SSE stream.
///
/// When `is_gzip` is true, the encoder gzip-compresses each message (default
/// compression level) before it is read; otherwise bytes pass through as-is.
/// The channel between the two halves holds at most one pending message.
pub fn encode(is_gzip: bool) -> (Sender, Encoder) {
    let (tx, rx) = async_channel::bounded(1);
    // The gzip encoder is always constructed; reads only route through it when
    // `gz_enabled` is set.
    let sink = WriteBuf { buf: Vec::new() };
    let encoder = Encoder {
        buf: Vec::new(),
        cursor: 0,
        receiver: rx,
        gz: GzBuilder::new().write(sink, Compression::default()),
        gz_enabled: is_gzip,
    };
    (Sender(tx), encoder)
}
impl Sender {
    /// Queues one already-formatted chunk for the paired [`Encoder`].
    ///
    /// # Errors
    ///
    /// Returns an `io::ErrorKind::ConnectionAborted` error ("sse disconnected")
    /// if the channel has been closed, e.g. because the `Encoder` was dropped.
    async fn inner_send(&self, bytes: impl Into<Vec<u8>>) -> io::Result<()> {
        self.0
            .send(bytes.into())
            .await
            .map_err(|_| io::Error::new(io::ErrorKind::ConnectionAborted, "sse disconnected"))
    }

    /// Appends a single `field:value\n` line to `msg`.
    fn push_field(msg: &mut String, field: &str, value: &str) {
        msg.push_str(field);
        msg.push(':');
        msg.push_str(value);
        msg.push('\n');
    }

    /// Sends one named event as a single message.
    ///
    /// Emits `event:<name>\n`, an optional `id:<id>\n`, one `data:<line>\n` per
    /// line of `data`, and a terminating blank line. Line breaks (`\n`, `\r\n`,
    /// `\r`) inside `data` become separate `data:` lines, which SSE clients
    /// re-join with `\n`. `name` and `id` are written verbatim and must not
    /// contain line breaks.
    ///
    /// # Errors
    ///
    /// Fails with `ConnectionAborted` if the stream has been disconnected.
    pub async fn send(&self, name: &str, data: &str, id: Option<&str>) -> io::Result<()> {
        let mut msg = String::new();
        Self::push_field(&mut msg, "event", name);
        if let Some(id) = id {
            Self::push_field(&mut msg, "id", id);
        }
        // A raw line break would end the `data` field early, and the remainder
        // would be parsed as an unknown field, silently truncating the payload.
        let data = data.replace("\r\n", "\n").replace('\r', "\n");
        for line in data.split('\n') {
            Self::push_field(&mut msg, "data", line);
        }
        // Blank line: tells the client to dispatch the event.
        msg.push('\n');
        self.inner_send(msg).await
    }

    /// Tells the client how long to wait before reconnecting.
    ///
    /// Emits an optional `id:<id>\n` followed by `retry:<ms>\n\n`, where `<ms>`
    /// is `dur` in whole milliseconds. Everything goes out as one message, so
    /// another clone of this `Sender` cannot interleave between the lines.
    ///
    /// # Errors
    ///
    /// Fails with `ConnectionAborted` if the stream has been disconnected.
    // NOTE(review): `dead_code` presumably allowed because nothing in the crate
    // calls this yet — confirm before removing the attribute.
    #[allow(dead_code)]
    pub async fn send_retry(&self, dur: Duration, id: Option<&str>) -> io::Result<()> {
        let mut msg = String::new();
        if let Some(id) = id {
            Self::push_field(&mut msg, "id", id);
        }
        // The SSE `retry` field is the reconnection time in milliseconds;
        // sending whole seconds made clients reconnect ~1000x too soon.
        Self::push_field(&mut msg, "retry", &dur.as_millis().to_string());
        msg.push('\n');
        self.inner_send(msg).await
    }
}