use serde::{de::DeserializeOwned, Serialize};
use std::io;
use std::marker::PhantomData;
use std::pin::Pin;
use tokio_util::bytes::{Bytes, BytesMut};
/// Zero-sized codec handle for MessagePack encoding with named struct fields.
///
/// `Item` is the type produced when decoding and `SinkItem` the type accepted
/// when encoding. No value of either type is ever stored; the struct only
/// carries the two type parameters.
pub struct MessagePackNamed<Item, SinkItem> {
    // Function-pointer markers record the type parameters without owning
    // values of those types.
    _item: PhantomData<fn() -> Item>,
    _sink_item: PhantomData<fn(SinkItem)>,
}

// Written by hand instead of `#[derive(Debug)]`: the derive would require
// `Item: Debug` and `SinkItem: Debug`, even though only `PhantomData` markers
// are formatted. The field layout matches what the derive would print.
impl<Item, SinkItem> std::fmt::Debug for MessagePackNamed<Item, SinkItem> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MessagePackNamed")
            .field("_item", &self._item)
            .field("_sink_item", &self._sink_item)
            .finish()
    }
}
impl<Item, SinkItem> Default for MessagePackNamed<Item, SinkItem> {
fn default() -> Self {
Self {
_item: PhantomData,
_sink_item: PhantomData,
}
}
}
/// Cloning yields another marker-only codec; no bounds on `Item` or
/// `SinkItem` are needed.
impl<Item, SinkItem> Clone for MessagePackNamed<Item, SinkItem> {
    fn clone(&self) -> Self {
        // Both fields are `PhantomData`, which is `Copy`, so the markers are
        // simply copied across.
        Self {
            _item: self._item,
            _sink_item: self._sink_item,
        }
    }
}
/// Decoding half of the codec: turns a buffer of MessagePack bytes into an
/// `Item`.
impl<Item, SinkItem> tokio_serde::Deserializer<Item> for MessagePackNamed<Item, SinkItem>
where
    Item: DeserializeOwned,
{
    type Error = io::Error;

    /// Decodes one `Item` from the bytes held in `src`.
    ///
    /// # Errors
    ///
    /// Any failure reported by `rmp_serde::from_slice` is wrapped in an
    /// `io::Error` of kind `InvalidData`.
    fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result<Item, Self::Error> {
        match rmp_serde::from_slice::<Item>(src) {
            Ok(item) => Ok(item),
            Err(cause) => Err(io::Error::new(io::ErrorKind::InvalidData, cause)),
        }
    }
}
impl<Item, SinkItem> tokio_serde::Serializer<SinkItem> for MessagePackNamed<Item, SinkItem>
where
SinkItem: Serialize,
{
type Error = io::Error;
fn serialize(self: Pin<&mut Self>, item: &SinkItem) -> Result<Bytes, Self::Error> {
rmp_serde::to_vec_named(item)
.map(Into::into)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}
}
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::LengthDelimitedCodec;
/// Wraps `stream` in a length-delimited `Framed` adapter, using the codec
/// builder's unmodified default configuration.
fn framed<T>(stream: T) -> tokio_util::codec::Framed<T, LengthDelimitedCodec>
where
    T: AsyncRead + AsyncWrite,
{
    let codec_builder = LengthDelimitedCodec::builder();
    codec_builder.new_framed(stream)
}
/// Builds the client side of the RPC transport over `stream`.
///
/// The stream is wrapped in length-delimited framing via `framed` and then
/// combined with a `MessagePackNamed` codec through
/// `tarpc::serde_transport::new`. The result yields
/// `tarpc::Response<RyoServiceResponse>` values and accepts
/// `tarpc::ClientMessage<RyoServiceRequest>` values, with `std::io::Error` as
/// the error type in both directions.
pub fn create_client_transport<T: AsyncRead + AsyncWrite + Unpin>(
    stream: T,
) -> impl futures::Stream<
    Item = Result<tarpc::Response<crate::service::RyoServiceResponse>, std::io::Error>,
> + futures::Sink<
    tarpc::ClientMessage<crate::service::RyoServiceRequest>,
    Error = std::io::Error,
> {
    let frames = framed(stream);
    // The codec's type parameters are inferred from the return type.
    let codec = MessagePackNamed::default();
    tarpc::serde_transport::new(frames, codec)
}
/// Builds the server side of the RPC transport over `stream`.
///
/// The stream is wrapped in length-delimited framing via `framed` and then
/// combined with a `MessagePackNamed` codec through
/// `tarpc::serde_transport::new`. The result yields
/// `tarpc::ClientMessage<RyoServiceRequest>` values and accepts
/// `tarpc::Response<RyoServiceResponse>` values, with `std::io::Error` as the
/// error type in both directions.
pub fn create_server_transport<T: AsyncRead + AsyncWrite + Unpin>(
    stream: T,
) -> impl futures::Stream<
    Item = Result<tarpc::ClientMessage<crate::service::RyoServiceRequest>, std::io::Error>,
> + futures::Sink<tarpc::Response<crate::service::RyoServiceResponse>, Error = std::io::Error> {
    let frames = framed(stream);
    // The codec's type parameters are inferred from the return type.
    let codec = MessagePackNamed::default();
    tarpc::serde_transport::new(frames, codec)
}
#[cfg(test)]
mod tests {
    /// Returns `true` when `needle` occurs as a contiguous run inside
    /// `haystack`. Callers pass non-empty needles only.
    fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|window| window == needle)
    }

    /// A struct with a skipped empty field survives a named-map round trip,
    /// and the skipped field's key is absent from the encoded bytes.
    #[test]
    fn test_messagepack_named_roundtrip() {
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
        struct TestStruct {
            name: String,
            #[serde(default, skip_serializing_if = "Vec::is_empty")]
            items: Vec<String>,
            #[serde(default)]
            count: usize,
        }

        let original = TestStruct {
            name: "test".to_string(),
            items: vec![],
            count: 42,
        };

        let encoded = rmp_serde::to_vec_named(&original).unwrap();
        // `items` is empty, so `skip_serializing_if` must omit its key.
        assert!(!contains_bytes(&encoded, b"items"));

        let decoded: TestStruct = rmp_serde::from_slice(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// `skip_serializing_if` round-trips both when the field is present and
    /// when it is omitted from the encoding.
    #[test]
    fn test_skip_serializing_if_with_named() {
        use serde::{Deserialize, Serialize};

        #[derive(Debug, PartialEq, Serialize, Deserialize)]
        struct Response {
            #[serde(default, skip_serializing_if = "Vec::is_empty")]
            patterns: Vec<String>,
            #[serde(default)]
            applied: bool,
            #[serde(default)]
            files_modified: usize,
        }

        // Non-empty field: the key is written and the value round-trips.
        let response = Response {
            patterns: vec!["pattern1".to_string()],
            applied: false,
            files_modified: 0,
        };
        let encoded = rmp_serde::to_vec_named(&response).unwrap();
        assert!(contains_bytes(&encoded, b"patterns"));
        let decoded: Response = rmp_serde::from_slice(&encoded).unwrap();
        assert_eq!(response, decoded);

        // Empty field: the key is skipped and decoding falls back to the
        // default value.
        let empty_response = Response {
            patterns: vec![],
            applied: false,
            files_modified: 0,
        };
        let encoded = rmp_serde::to_vec_named(&empty_response).unwrap();
        assert!(!contains_bytes(&encoded, b"patterns"));
        let decoded: Response = rmp_serde::from_slice(&encoded).unwrap();
        assert!(decoded.patterns.is_empty());
        assert_eq!(empty_response, decoded);
    }
}