ryo_app/codec.rs
1//! tarpc transport codec for RYO RPC communication.
2//!
3//! # Why MessagePackNamed?
4//!
5//! RYO uses MessagePack for RPC serialization via tarpc. There are two serialization modes:
6//!
7//! | Mode | Serialization | `skip_serializing_if` |
8//! |------|--------------|----------------------|
9//! | Array-based (default) | `[value1, value2, ...]` | **Incompatible** |
10//! | Named (map-based) | `{"field1": value1, ...}` | Compatible |
11//!
12//! Many response types use `#[serde(skip_serializing_if = "...")]` to reduce payload size.
13//! This requires **named serialization** where fields are identified by name, not position.
14//!
15//! Using array-based serialization with `skip_serializing_if` causes deserialization failures:
16//! ```text
17//! invalid type: boolean `false`, expected a sequence
18//! ```
19//!
20//! # Usage
21//!
22//! Always use the helper functions to create transports:
23//!
24//! ```ignore
25//! use ryo_app::codec::create_client_transport;
26//! use tokio::net::UnixStream;
27//!
28//! let stream = UnixStream::connect(socket_path).await?;
29//! let transport = create_client_transport(stream);
30//! let client = RyoServiceClient::new(config, transport).spawn();
31//! ```
32//!
33//! # Important
34//!
35//! **DO NOT** use `tokio_serde::formats::MessagePack::default()` directly.
36//! It uses array-based serialization which is incompatible with `skip_serializing_if`.
37
38use serde::{de::DeserializeOwned, Serialize};
39use std::io;
40use std::marker::PhantomData;
41use std::pin::Pin;
42use tokio_util::bytes::{Bytes, BytesMut};
43
/// MessagePack codec with named (map-based) serialization.
///
/// Values are encoded with `rmp_serde::to_vec_named`, which writes structs as
/// maps keyed by field name. That keeps the encoding compatible with
/// `#[serde(skip_serializing_if = "...")]`, where fields may be omitted.
///
/// The type parameters only record which message type is decoded (`Item`) and
/// which is encoded (`SinkItem`); the codec stores no values of either type.
///
/// # Example
///
/// ```ignore
/// let transport = tarpc::serde_transport::new(
///     tokio_util::codec::LengthDelimitedCodec::builder().new_framed(stream),
///     MessagePackNamed::default(),
/// );
/// ```
pub struct MessagePackNamed<Item, SinkItem> {
    /// Marker for the decoded message type; `fn() -> Item` avoids owning an `Item`.
    _item: PhantomData<fn() -> Item>,
    /// Marker for the encoded message type; `fn(SinkItem)` avoids owning a `SinkItem`.
    _sink_item: PhantomData<fn(SinkItem)>,
}

// Written by hand instead of `#[derive(Debug)]`: the derive would require
// `Item: Debug` and `SinkItem: Debug`, even though only `PhantomData` markers
// are stored. This impl lists the same fields as the derive would, without
// placing any bound on the message types.
impl<Item, SinkItem> std::fmt::Debug for MessagePackNamed<Item, SinkItem> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MessagePackNamed")
            .field("_item", &self._item)
            .field("_sink_item", &self._sink_item)
            .finish()
    }
}
62
63impl<Item, SinkItem> Default for MessagePackNamed<Item, SinkItem> {
64 fn default() -> Self {
65 Self {
66 _item: PhantomData,
67 _sink_item: PhantomData,
68 }
69 }
70}
71
72impl<Item, SinkItem> Clone for MessagePackNamed<Item, SinkItem> {
73 fn clone(&self) -> Self {
74 Self::default()
75 }
76}
77
78impl<Item, SinkItem> tokio_serde::Deserializer<Item> for MessagePackNamed<Item, SinkItem>
79where
80 Item: DeserializeOwned,
81{
82 type Error = io::Error;
83
84 fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result<Item, Self::Error> {
85 rmp_serde::from_slice(src).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
86 }
87}
88
89impl<Item, SinkItem> tokio_serde::Serializer<SinkItem> for MessagePackNamed<Item, SinkItem>
90where
91 SinkItem: Serialize,
92{
93 type Error = io::Error;
94
95 fn serialize(self: Pin<&mut Self>, item: &SinkItem) -> Result<Bytes, Self::Error> {
96 rmp_serde::to_vec_named(item)
97 .map(Into::into)
98 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
99 }
100}
101
102// ============================================================================
103// Transport Factory Functions
104// ============================================================================
105
106use tokio::io::{AsyncRead, AsyncWrite};
107use tokio_util::codec::LengthDelimitedCodec;
108
109/// Create a framed transport with LengthDelimitedCodec.
110///
111/// Helper to reduce boilerplate when creating transports.
112fn framed<T: AsyncRead + AsyncWrite>(
113 stream: T,
114) -> tokio_util::codec::Framed<T, LengthDelimitedCodec> {
115 LengthDelimitedCodec::builder().new_framed(stream)
116}
117
118/// Create a tarpc transport with the correct codec for client-side use.
119///
120/// This is the recommended way to create a client transport. It ensures:
121/// - Named MessagePack serialization (compatible with `skip_serializing_if`)
122/// - Proper framing with `LengthDelimitedCodec`
123///
124/// # Example
125///
126/// ```ignore
127/// use ryo_app::codec::create_client_transport;
128/// use tokio::net::UnixStream;
129///
130/// let stream = UnixStream::connect(socket_path).await?;
131/// let transport = create_client_transport(stream);
132/// let client = RyoServiceClient::new(config, transport).spawn();
133/// ```
134pub fn create_client_transport<T: AsyncRead + AsyncWrite + Unpin>(
135 stream: T,
136) -> impl futures::Stream<
137 Item = Result<tarpc::Response<crate::service::RyoServiceResponse>, std::io::Error>,
138> + futures::Sink<
139 tarpc::ClientMessage<crate::service::RyoServiceRequest>,
140 Error = std::io::Error,
141> {
142 tarpc::serde_transport::new(framed(stream), MessagePackNamed::default())
143}
144
145/// Create a tarpc transport with the correct codec for server-side use.
146///
147/// This is the recommended way to create a server transport. It ensures:
148/// - Named MessagePack serialization (compatible with `skip_serializing_if`)
149/// - Proper framing with `LengthDelimitedCodec`
150///
151/// # Example
152///
153/// ```ignore
154/// use ryo_app::codec::create_server_transport;
155/// use tokio::net::UnixListener;
156///
157/// let (stream, _) = listener.accept().await?;
158/// let transport = create_server_transport(stream);
159/// let channel = tarpc::server::BaseChannel::with_defaults(transport);
160/// ```
161pub fn create_server_transport<T: AsyncRead + AsyncWrite + Unpin>(
162 stream: T,
163) -> impl futures::Stream<
164 Item = Result<tarpc::ClientMessage<crate::service::RyoServiceRequest>, std::io::Error>,
165> + futures::Sink<tarpc::Response<crate::service::RyoServiceResponse>, Error = std::io::Error> {
166 tarpc::serde_transport::new(framed(stream), MessagePackNamed::default())
167}
168
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    /// Returns `true` when `key` occurs as a raw byte sequence inside `payload`.
    ///
    /// Named encoding writes struct field names into the payload, so this is
    /// enough to tell whether a given field was emitted or skipped. `key` must
    /// be non-empty (`slice::windows` panics on a zero window size).
    fn mentions_field(payload: &[u8], key: &str) -> bool {
        payload
            .windows(key.len())
            .any(|window| window == key.as_bytes())
    }

    /// A struct with a skipped empty field survives an encode/decode round trip,
    /// and the payload carries field names for exactly the emitted fields.
    #[test]
    fn test_messagepack_named_roundtrip() {
        #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
        struct TestStruct {
            name: String,
            #[serde(default, skip_serializing_if = "Vec::is_empty")]
            items: Vec<String>,
            #[serde(default)]
            count: usize,
        }

        let original = TestStruct {
            name: "test".to_string(),
            items: vec![], // Will be skipped during serialization
            count: 42,
        };

        // Serialize with the named (map-based) encoding.
        let encoded = rmp_serde::to_vec_named(&original).unwrap();

        // Emitted fields are identified by name; the empty `items` is omitted.
        assert!(mentions_field(&encoded, "name"));
        assert!(mentions_field(&encoded, "count"));
        assert!(!mentions_field(&encoded, "items"));

        // Deserialize; the skipped field is restored through `#[serde(default)]`.
        let decoded: TestStruct = rmp_serde::from_slice(&encoded).unwrap();

        assert_eq!(original, decoded);
    }

    /// `skip_serializing_if` works with named encoding both when the optional
    /// field is present and when it is skipped.
    #[test]
    fn test_skip_serializing_if_with_named() {
        #[derive(Debug, PartialEq, Serialize, Deserialize)]
        struct Response {
            #[serde(default, skip_serializing_if = "Vec::is_empty")]
            patterns: Vec<String>,
            #[serde(default)]
            applied: bool,
            #[serde(default)]
            files_modified: usize,
        }

        // Simulate SuggestGenerateResponse with list=true returning patterns
        let response = Response {
            patterns: vec!["pattern1".to_string()],
            applied: false,
            files_modified: 0,
        };

        let encoded = rmp_serde::to_vec_named(&response).unwrap();
        assert!(mentions_field(&encoded, "patterns"));
        let decoded: Response = rmp_serde::from_slice(&encoded).unwrap();
        // Previously this decode result was discarded; check the whole value.
        assert_eq!(response, decoded);

        // Simulate empty response (patterns skipped)
        let empty_response = Response {
            patterns: vec![],
            applied: false,
            files_modified: 0,
        };

        let encoded = rmp_serde::to_vec_named(&empty_response).unwrap();
        // The empty vector must be left out of the payload entirely.
        assert!(!mentions_field(&encoded, "patterns"));
        assert!(mentions_field(&encoded, "applied"));
        assert!(mentions_field(&encoded, "files_modified"));

        let decoded: Response = rmp_serde::from_slice(&encoded).unwrap();
        assert!(decoded.patterns.is_empty());
        assert_eq!(empty_response, decoded);
    }
}
235}