use std::{convert::TryFrom, marker::PhantomData, mem::size_of};
use bytes::{BufMut, Bytes, BytesMut};
use futures_channel::oneshot;
use quinn::{SendStream, VarInt};
use serde::Serialize;
use super::Task;
use crate::{Error, Result};
/// Typed sending half of a QUIC stream.
///
/// Values of type `T` are serialized with `bincode`, prefixed with their length and
/// queued on an unbounded channel that is drained by a [`Task`] owning the
/// [`SendStream`].
pub struct Sender<T: Serialize> {
    /// Queue of already-serialized, length-prefixed messages for the task.
    sender: flume::Sender<Bytes>,
    /// Ties this sender to the message type it accepts; no `T` is ever stored.
    _type: PhantomData<T>,
    /// Task writing queued messages to the stream; also used to request finish/close.
    task: Task<Result<()>, Message>,
}

// `Clone` and `Debug` are implemented by hand on purpose: `#[derive]` would add
// `T: Clone` / `T: Debug` bounds because of the `PhantomData<T>` field, even though
// cloning or printing a `Sender` never touches a `T`.
impl<T: Serialize> Clone for Sender<T> {
    fn clone(&self) -> Self {
        Self {
            sender: self.sender.clone(),
            _type: PhantomData,
            task: self.task.clone(),
        }
    }
}

impl<T: Serialize> std::fmt::Debug for Sender<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Sender")
            .field("sender", &self.sender)
            .field("_type", &self._type)
            .field("task", &self.task)
            .finish()
    }
}
/// Messages processed by the loop that owns the [`SendStream`].
///
/// `Data` values come from the queue fed by `send`/`send_any`; `Finish` and
/// `Close` arrive on the shutdown channel (via `Task::close`).
#[derive(Clone, Debug)]
enum Message {
    /// A serialized, length-prefixed chunk, written verbatim with `write_chunk`.
    Data(Bytes),
    /// Finish the stream (`SendStream::finish`) and stop the loop.
    Finish,
    /// Reset the stream with error code 0 and stop the loop.
    Close,
}
impl<T: Serialize> Sender<T> {
pub(super) fn new(mut stream_sender: SendStream) -> Self {
let (sender, receiver) = flume::unbounded();
let (shutdown_sender, shutdown_receiver) = oneshot::channel();
let task = Task::new(
async move {
let mut receiver = receiver.into_stream();
let mut shutdown = shutdown_receiver;
while let Some(message) = allochronic_util::select! {
message: &mut receiver => message.map(Message::Data),
shutdown: &mut shutdown => shutdown.ok(),
} {
match message {
Message::Data(bytes) => stream_sender
.write_chunk(bytes)
.await
.map_err(Error::Write)?,
Message::Finish => {
stream_sender.finish().await.map_err(Error::Finish)?;
break;
}
Message::Close => {
stream_sender
.reset(VarInt::from_u32(0))
.map_err(|_error| Error::AlreadyClosed)?;
break;
}
}
}
Ok(())
},
shutdown_sender,
);
Self {
sender,
_type: PhantomData,
task,
}
}
pub fn send(&self, data: &T) -> Result<()> {
self.send_any(data)
}
#[allow(clippy::panic_in_result_fn, clippy::unwrap_in_result)]
pub(super) fn send_any<A: Serialize>(&self, data: &A) -> Result<()> {
let mut bytes = BytesMut::new();
#[allow(box_pointers)]
let len = bincode::serialized_size(&data).map_err(|error| Error::Serialize(*error))?;
#[allow(clippy::expect_used)]
bytes.reserve(
usize::try_from(len)
.expect("not a 64-bit system")
.checked_add(size_of::<u64>())
.expect("data trying to be sent is too big"),
);
bytes.put_u64_le(len);
let mut bytes = bytes.writer();
#[allow(box_pointers)]
bincode::serialize_into(&mut bytes, &data).map_err(|error| Error::Serialize(*error))?;
let bytes = bytes.into_inner().freeze();
#[allow(clippy::expect_used)]
{
debug_assert_eq!(
u64::try_from(bytes.len()).expect("not a 64-bit system"),
u64::try_from(size_of::<u64>())
.expect("not a 64-bit system")
.checked_add(len)
.expect("message to long")
);
}
self.sender.send(bytes).map_err(|_bytes| Error::Send)
}
pub async fn finish(&self) -> Result<()> {
self.task.close(Message::Finish).await?
}
pub async fn close(&self) -> Result<()> {
self.task.close(Message::Close).await?
}
}