ryo-app 0.1.0

[preview] Application layer for RYO - Project management, Intent handling, API
Documentation
//! tarpc transport codec for RYO RPC communication.
//!
//! # Why MessagePackNamed?
//!
//! RYO uses MessagePack for RPC serialization via tarpc. There are two serialization modes:
//!
//! | Mode | Serialization | `skip_serializing_if` |
//! |------|--------------|----------------------|
//! | Array-based (default) | `[value1, value2, ...]` | **Incompatible** |
//! | Named (map-based) | `{"field1": value1, ...}` | Compatible |
//!
//! Many response types use `#[serde(skip_serializing_if = "...")]` to reduce payload size.
//! This requires **named serialization** where fields are identified by name, not position.
//!
//! Using array-based serialization with `skip_serializing_if` causes deserialization failures,
//! because omitting a field shifts every later value into the wrong position:
//! ```text
//! invalid type: boolean `false`, expected a sequence
//! ```
//!
//! # Usage
//!
//! Always use the helper functions to create transports:
//!
//! ```ignore
//! use ryo_app::codec::create_client_transport;
//! use tokio::net::UnixStream;
//!
//! let stream = UnixStream::connect(socket_path).await?;
//! let transport = create_client_transport(stream);
//! let client = RyoServiceClient::new(config, transport).spawn();
//! ```
//!
//! # Important
//!
//! **DO NOT** use `tokio_serde::formats::MessagePack::default()` directly.
//! It uses array-based serialization which is incompatible with `skip_serializing_if`.

use serde::{de::DeserializeOwned, Serialize};
use std::io;
use std::marker::PhantomData;
use std::pin::Pin;
use tokio_util::bytes::{Bytes, BytesMut};

/// MessagePack codec with named (map-based) serialization.
///
/// This codec uses `rmp_serde::to_vec_named` for serialization, which produces
/// map-based output compatible with `skip_serializing_if` attributes.
///
/// `Item` is the type decoded from incoming frames and `SinkItem` the type
/// encoded into outgoing frames. Both are purely phantom: the codec stores no
/// value of either type, so it is zero-sized.
///
/// # Example
///
/// ```ignore
/// let transport = tarpc::serde_transport::new(
///     tokio_util::codec::LengthDelimitedCodec::builder().new_framed(stream),
///     MessagePackNamed::default(),
/// );
/// ```
pub struct MessagePackNamed<Item, SinkItem> {
    // `fn() -> Item` / `fn(SinkItem)` record the direction of data flow without
    // tying the codec's auto traits (`Send`, `Sync`, `Unpin`) to `Item` or
    // `SinkItem`: function-pointer types implement those unconditionally.
    _item: PhantomData<fn() -> Item>,
    _sink_item: PhantomData<fn(SinkItem)>,
}

// `Default`, `Clone` and `Debug` are implemented by hand instead of derived.
// `#[derive]` would add `Item: Trait` / `SinkItem: Trait` bounds even though no
// value of either type is stored, making the codec lose these impls for
// message types that don't implement the corresponding trait.

impl<Item, SinkItem> Default for MessagePackNamed<Item, SinkItem> {
    fn default() -> Self {
        Self {
            _item: PhantomData,
            _sink_item: PhantomData,
        }
    }
}

impl<Item, SinkItem> Clone for MessagePackNamed<Item, SinkItem> {
    fn clone(&self) -> Self {
        // Stateless: every instance is interchangeable.
        Self::default()
    }
}

impl<Item, SinkItem> std::fmt::Debug for MessagePackNamed<Item, SinkItem> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as `#[derive(Debug)]` output; `PhantomData<T>` is `Debug`
        // for every `T`, so no bounds on the type parameters are needed.
        f.debug_struct("MessagePackNamed")
            .field("_item", &self._item)
            .field("_sink_item", &self._sink_item)
            .finish()
    }
}

impl<Item, SinkItem> tokio_serde::Deserializer<Item> for MessagePackNamed<Item, SinkItem>
where
    Item: DeserializeOwned,
{
    type Error = io::Error;

    /// Decode one complete frame into an `Item`.
    ///
    /// Any `rmp_serde` decode error is reported as an `io::Error` of kind
    /// `InvalidData`, wrapping the original error.
    fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result<Item, Self::Error> {
        let frame: &[u8] = src;
        match rmp_serde::from_slice(frame) {
            Ok(item) => Ok(item),
            Err(decode_err) => Err(io::Error::new(io::ErrorKind::InvalidData, decode_err)),
        }
    }
}

impl<Item, SinkItem> tokio_serde::Serializer<SinkItem> for MessagePackNamed<Item, SinkItem>
where
    SinkItem: Serialize,
{
    type Error = io::Error;

    fn serialize(self: Pin<&mut Self>, item: &SinkItem) -> Result<Bytes, Self::Error> {
        rmp_serde::to_vec_named(item)
            .map(Into::into)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
    }
}

// ============================================================================
// Transport Factory Functions
// ============================================================================

use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::LengthDelimitedCodec;

/// Wrap `stream` in length-prefixed framing.
///
/// Both transport constructors in this module go through this helper, so the
/// client and server sides always use the same `LengthDelimitedCodec`
/// configuration (the codec's defaults).
fn framed<T: AsyncRead + AsyncWrite>(
    stream: T,
) -> tokio_util::codec::Framed<T, LengthDelimitedCodec> {
    let codec = LengthDelimitedCodec::new();
    tokio_util::codec::Framed::new(stream, codec)
}

/// Build the client half of a RYO tarpc connection on top of `stream`.
///
/// Prefer this over assembling a transport by hand. It pairs
/// `LengthDelimitedCodec` framing with the map-based [`MessagePackNamed`]
/// codec, which is what `skip_serializing_if` fields require. The returned
/// transport yields `RyoServiceResponse` messages and accepts
/// `RyoServiceRequest` messages.
///
/// # Example
///
/// ```ignore
/// use ryo_app::codec::create_client_transport;
/// use tokio::net::UnixStream;
///
/// let stream = UnixStream::connect(socket_path).await?;
/// let transport = create_client_transport(stream);
/// let client = RyoServiceClient::new(config, transport).spawn();
/// ```
pub fn create_client_transport<T: AsyncRead + AsyncWrite + Unpin>(
    stream: T,
) -> impl futures::Stream<
    Item = Result<tarpc::Response<crate::service::RyoServiceResponse>, std::io::Error>,
> + futures::Sink<
    tarpc::ClientMessage<crate::service::RyoServiceRequest>,
    Error = std::io::Error,
> {
    // Client direction: responses are decoded, requests are encoded.
    let framed_stream = framed(stream);
    let codec = MessagePackNamed::default();
    tarpc::serde_transport::new(framed_stream, codec)
}

/// Build the server half of a RYO tarpc connection on top of `stream`.
///
/// Prefer this over assembling a transport by hand. It pairs
/// `LengthDelimitedCodec` framing with the map-based [`MessagePackNamed`]
/// codec, which is what `skip_serializing_if` fields require. The returned
/// transport yields incoming `RyoServiceRequest` client messages and accepts
/// outgoing `RyoServiceResponse` responses.
///
/// # Example
///
/// ```ignore
/// use ryo_app::codec::create_server_transport;
/// use tokio::net::UnixListener;
///
/// let (stream, _) = listener.accept().await?;
/// let transport = create_server_transport(stream);
/// let channel = tarpc::server::BaseChannel::with_defaults(transport);
/// ```
pub fn create_server_transport<T: AsyncRead + AsyncWrite + Unpin>(
    stream: T,
) -> impl futures::Stream<
    Item = Result<tarpc::ClientMessage<crate::service::RyoServiceRequest>, std::io::Error>,
> + futures::Sink<tarpc::Response<crate::service::RyoServiceResponse>, Error = std::io::Error> {
    // Server direction: requests are decoded, responses are encoded.
    let framed_stream = framed(stream);
    let codec = MessagePackNamed::default();
    tarpc::serde_transport::new(framed_stream, codec)
}

#[cfg(test)]
mod tests {
    // These tests drive the real `MessagePackNamed` codec through its
    // `tokio_serde` trait impls instead of calling `rmp_serde` directly, so a
    // regression in the codec itself (e.g. switching to array-based
    // `rmp_serde::to_vec`) is caught here.

    use super::MessagePackNamed;
    use serde::{de::DeserializeOwned, Deserialize, Serialize};
    use std::io;
    use std::pin::Pin;
    use tokio_util::bytes::{Bytes, BytesMut};

    /// Encode `value` with the codec's `tokio_serde::Serializer` impl.
    fn encode<T: Serialize>(value: &T) -> Bytes {
        let mut codec = MessagePackNamed::<T, T>::default();
        <MessagePackNamed<T, T> as tokio_serde::Serializer<T>>::serialize(
            Pin::new(&mut codec),
            value,
        )
        .expect("MessagePackNamed failed to serialize")
    }

    /// Decode `bytes` with the codec's `tokio_serde::Deserializer` impl.
    fn decode<T: DeserializeOwned>(bytes: &[u8]) -> io::Result<T> {
        let buf = BytesMut::from(bytes);
        let mut codec = MessagePackNamed::<T, T>::default();
        <MessagePackNamed<T, T> as tokio_serde::Deserializer<T>>::deserialize(
            Pin::new(&mut codec),
            &buf,
        )
    }

    /// True if `needle` occurs anywhere in `haystack` (`needle` must be non-empty).
    fn contains(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|window| window == needle)
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct TestStruct {
        name: String,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        items: Vec<String>,
        #[serde(default)]
        count: usize,
    }

    #[test]
    fn test_messagepack_named_roundtrip() {
        let original = TestStruct {
            name: "test".to_string(),
            items: vec![], // Will be skipped during serialization
            count: 42,
        };

        let encoded = encode(&original);

        // Named encoding carries field names as keys; the empty `items` is omitted.
        assert!(contains(&encoded, b"name"));
        assert!(contains(&encoded, b"count"));
        assert!(!contains(&encoded, b"items"));

        let decoded: TestStruct = decode(&encoded).expect("roundtrip decode failed");
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_skip_serializing_if_with_named() {
        #[derive(Debug, PartialEq, Serialize, Deserialize)]
        struct Response {
            #[serde(default, skip_serializing_if = "Vec::is_empty")]
            patterns: Vec<String>,
            #[serde(default)]
            applied: bool,
            #[serde(default)]
            files_modified: usize,
        }

        // Simulate SuggestGenerateResponse with list=true returning patterns
        let response = Response {
            patterns: vec!["pattern1".to_string()],
            applied: false,
            files_modified: 0,
        };
        let decoded: Response =
            decode(&encode(&response)).expect("decode with patterns failed");
        assert_eq!(decoded, response);

        // Simulate empty response (patterns skipped)
        let empty_response = Response {
            patterns: vec![],
            applied: false,
            files_modified: 0,
        };
        let encoded = encode(&empty_response);
        assert!(!contains(&encoded, b"patterns"));
        let decoded: Response = decode(&encoded).expect("decode without patterns failed");
        assert_eq!(decoded, empty_response);
    }

    #[test]
    fn test_malformed_frame_is_invalid_data() {
        // 0xc1 is the byte the MessagePack spec reserves as "never used".
        let err = decode::<TestStruct>(&[0xc1]).expect_err("reserved marker must not decode");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}