use std::io;
use minarrow::Field;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use crate::models::protocol::codec::LightstreamCodec;
use crate::traits::stream_buffer::StreamBuffer;
/// Async writer that encodes named messages and tables through a
/// [`LightstreamCodec`] and writes the resulting frames to an [`AsyncWrite`] sink.
pub struct LightstreamWriter<W, B = Vec<u8>>
where
    W: AsyncWrite + Unpin + Send,
    B: StreamBuffer + Unpin,
{
    /// Codec holding the name -> tag registrations and producing encoded frames.
    codec: LightstreamCodec<B>,
    /// Destination sink that every encoded frame is written to.
    dest: W,
}
impl<W: AsyncWrite + Unpin + Send, B: StreamBuffer + Unpin> LightstreamWriter<W, B> {
    /// Creates a writer that sends frames into `dest`, using a newly
    /// constructed codec (`LightstreamCodec::new()`).
    pub fn new(dest: W) -> Self {
        Self {
            codec: LightstreamCodec::new(),
            dest,
        }
    }

    /// Registers a message type under `name` with the codec and returns the
    /// tag the codec assigned to it.
    pub fn register_message(&mut self, name: impl Into<String>) -> u8 {
        self.codec.register_message(name)
    }

    /// Registers a table type under `name` with the given `schema` and returns
    /// the tag the codec assigned to it.
    pub fn register_table(&mut self, name: impl Into<String>, schema: Vec<Field>) -> u8 {
        self.codec.register_table(name, schema)
    }

    /// Encodes `payload` as a message frame for the type registered as `name`
    /// and writes the entire frame to the destination.
    ///
    /// The destination is not flushed; call [`flush`](Self::flush) when it
    /// buffers writes.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::InvalidInput`] error if `name` is not
    /// registered. Propagates any error from the codec's encoder or from
    /// writing to the destination.
    pub async fn send(&mut self, name: &str, payload: &[u8]) -> io::Result<()> {
        let tag = self
            .codec
            .tag_by_name(name)
            .ok_or_else(|| Self::unknown_type_name(name))?;
        let frame = self.codec.encode_message(tag, payload)?;
        self.dest.write_all(frame.as_ref()).await
    }

    /// Encodes `table` as a table frame for the type registered as `name` and
    /// writes the entire frame to the destination.
    ///
    /// The destination is not flushed; call [`flush`](Self::flush) when it
    /// buffers writes.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::InvalidInput`] error if `name` is not
    /// registered. Propagates any error from the codec's encoder or from
    /// writing to the destination.
    pub async fn send_table(&mut self, name: &str, table: &minarrow::Table) -> io::Result<()> {
        let tag = self
            .codec
            .tag_by_name(name)
            .ok_or_else(|| Self::unknown_type_name(name))?;
        let frame = self.codec.encode_table(tag, table)?;
        self.dest.write_all(frame.as_ref()).await
    }

    /// Flushes the destination sink.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the destination's flush.
    pub async fn flush(&mut self) -> io::Result<()> {
        self.dest.flush().await
    }

    /// Shuts down the destination sink.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the destination's shutdown.
    pub async fn shutdown(&mut self) -> io::Result<()> {
        self.dest.shutdown().await
    }

    /// Serializes `msg` with prost and sends the bytes as a message frame for
    /// the type registered as `name` (see [`send`](Self::send)).
    ///
    /// # Errors
    ///
    /// Same as [`send`](Self::send).
    #[cfg(feature = "protobuf")]
    pub async fn send_proto<M: prost::Message>(&mut self, name: &str, msg: &M) -> io::Result<()> {
        let bytes = msg.encode_to_vec();
        self.send(name, &bytes).await
    }

    /// Serializes `msg` as MessagePack and sends the bytes as a message frame
    /// for the type registered as `name` (see [`send`](Self::send)).
    ///
    /// # Errors
    ///
    /// Returns an error if MessagePack serialization fails. Otherwise the
    /// errors are the same as for [`send`](Self::send).
    #[cfg(feature = "msgpack")]
    pub async fn send_msgpack<M: serde::Serialize>(
        &mut self,
        name: &str,
        msg: &M,
    ) -> io::Result<()> {
        let bytes = encode_msgpack(msg)?;
        self.send(name, &bytes).await
    }

    /// Returns a shared reference to the underlying codec.
    pub fn codec(&self) -> &LightstreamCodec<B> {
        &self.codec
    }

    /// Builds the error reported when a send names a type that was never registered.
    fn unknown_type_name(name: &str) -> io::Error {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("unknown type name '{}'", name),
        )
    }
}
/// Serializes `msg` into a freshly allocated MessagePack buffer.
///
/// The serializer is configured with `BytesMode::ForceAll`. Any
/// serialization failure is reported as an [`io::ErrorKind::InvalidData`]
/// error wrapping the serializer's error.
#[cfg(feature = "msgpack")]
fn encode_msgpack<M: serde::Serialize>(msg: &M) -> io::Result<Vec<u8>> {
    use rmp_serde::config::BytesMode;

    let mut encoded: Vec<u8> = Vec::new();
    // Keep the serializer in its own scope so its borrow of `encoded` ends
    // before the buffer is returned.
    let outcome = {
        let mut ser = rmp_serde::Serializer::new(&mut encoded).with_bytes(BytesMode::ForceAll);
        msg.serialize(&mut ser)
    };
    match outcome {
        Ok(_) => Ok(encoded),
        Err(err) => Err(io::Error::new(io::ErrorKind::InvalidData, err)),
    }
}