datacake_rpc/
request.rs

1use std::borrow::Cow;
2use std::fmt::{Debug, Formatter};
3use std::net::SocketAddr;
4use std::ops::Deref;
5
6use async_trait::async_trait;
7use rkyv::validation::validators::DefaultValidator;
8use rkyv::{Archive, CheckBytes, Deserialize, Serialize};
9
10use crate::view::DataView;
11use crate::{Body, Status};
12
13#[async_trait]
14/// The deserializer trait for converting the request body into
15/// the desired type specified by [Self::Content].
16///
17/// This trait is automatically implemented for the [Body] type
18/// and any type implementing [rkyv]'s (de)serializer traits.
19pub trait RequestContents {
20    /// The deserialized message type.
21    type Content: Send + Sized + 'static;
22
23    async fn from_body(body: Body) -> Result<Self::Content, Status>;
24}
25
26#[async_trait]
27impl RequestContents for Body {
28    type Content = Self;
29
30    async fn from_body(body: Body) -> Result<Self, Status> {
31        Ok(body)
32    }
33}
34
35#[async_trait]
36impl<Msg> RequestContents for Msg
37where
38    Msg: Archive + Send + Sync + 'static,
39    Msg::Archived: CheckBytes<DefaultValidator<'static>> + Send + Sync + 'static,
40{
41    type Content = DataView<Self>;
42
43    async fn from_body(body: Body) -> Result<Self::Content, Status> {
44        let bytes = crate::utils::to_aligned(body.0)
45            .await
46            .map_err(Status::internal)?;
47
48        DataView::using(bytes).map_err(|_| Status::invalid())
49    }
50}
51
52#[repr(C)]
53#[derive(Serialize, Deserialize, Archive, PartialEq)]
54#[cfg_attr(test, derive(Debug))]
55#[archive(check_bytes)]
56pub struct MessageMetadata {
57    #[with(rkyv::with::AsOwned)]
58    /// The name of the service being targeted.
59    pub(crate) service_name: Cow<'static, str>,
60    #[with(rkyv::with::AsOwned)]
61    /// The message name/path.
62    pub(crate) path: Cow<'static, str>,
63}
64
65/// A zero-copy view of the message data and any additional metadata provided
66/// by the RPC system.
67///
68/// The request contains the original request buffer which is used to create
69/// the 'view' of the given message type.
70pub struct Request<Msg>
71where
72    Msg: RequestContents,
73{
74    pub(crate) remote_addr: SocketAddr,
75
76    // A small hack to stop linters miss-guiding users
77    // into thinking their messages are `!Sized` when in fact they are.
78    // We don't want to box in release mode however.
79    #[cfg(debug_assertions)]
80    pub(crate) view: Box<Msg::Content>,
81    #[cfg(not(debug_assertions))]
82    pub(crate) view: Msg::Content,
83}
84
85impl<Msg> Debug for Request<Msg>
86where
87    Msg: RequestContents,
88    Msg::Content: Debug,
89{
90    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
91        f.debug_struct("Request")
92            .field("view", &self.view)
93            .field("remote_addr", &self.remote_addr)
94            .finish()
95    }
96}
97
98impl<Msg> Deref for Request<Msg>
99where
100    Msg: RequestContents,
101{
102    type Target = Msg::Content;
103
104    fn deref(&self) -> &Self::Target {
105        &self.view
106    }
107}
108
109impl<Msg> Request<Msg>
110where
111    Msg: RequestContents,
112{
113    pub(crate) fn new(remote_addr: SocketAddr, view: Msg::Content) -> Self {
114        Self {
115            remote_addr,
116            #[cfg(debug_assertions)]
117            view: Box::new(view),
118            #[cfg(not(debug_assertions))]
119            view,
120        }
121    }
122
123    #[cfg(debug_assertions)]
124    /// Consumes the request into the value of the message.
125    pub fn into_inner(self) -> Msg::Content {
126        *self.view
127    }
128
129    #[cfg(not(debug_assertions))]
130    /// Consumes the request into the value of the message.
131    pub fn into_inner(self) -> Msg::Content {
132        self.view
133    }
134
135    /// The remote address of the incoming message.
136    pub fn remote_addr(&self) -> SocketAddr {
137        self.remote_addr
138    }
139}
140
#[cfg(feature = "test-utils")]
impl<Msg> Request<Msg>
where
    Msg: RequestContents
        + rkyv::Serialize<
            rkyv::ser::serializers::AllocSerializer<{ crate::SCRATCH_SPACE }>,
        >,
{
    /// A test utility for creating a mocked request.
    ///
    /// This takes the owned value of the msg and acts like the target request:
    /// the value is serialized and then passed through
    /// [RequestContents::from_body]. The remote address of the mocked
    /// request is always `127.0.0.1:80`.
    ///
    /// This should be used for testing only.
    ///
    /// # Panics
    ///
    /// Panics if the message cannot be serialized, or if the serialized
    /// bytes cannot be converted back into the message contents.
    pub async fn using_owned(msg: Msg) -> Self {
        use std::net::{Ipv4Addr, SocketAddrV4};

        let serialized = rkyv::to_bytes(&msg).unwrap();
        let body = Body::from(serialized.to_vec());
        let view = Msg::from_body(body).await.unwrap();

        let loopback = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 80);
        Self::new(SocketAddr::V4(loopback), view)
    }
}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the metadata fixture shared by the tests in this module.
    fn sample_metadata() -> MessageMetadata {
        MessageMetadata {
            service_name: Cow::Borrowed("test"),
            path: Cow::Borrowed("demo"),
        }
    }

    /// Metadata must survive a serialize/deserialize round trip unchanged.
    #[test]
    fn test_metadata() {
        let original = sample_metadata();

        let serialized = rkyv::to_bytes::<_, 1024>(&original).expect("Serialize");
        let restored: MessageMetadata =
            rkyv::from_bytes(&serialized).expect("Deserialize");

        assert_eq!(original, restored, "Deserialized value should match");
    }

    /// A request built over a view must expose both the remote address and
    /// the original message value.
    #[test]
    fn test_request() {
        let expected = sample_metadata();
        let addr: SocketAddr = "127.0.0.1:8000".parse().unwrap();

        let serialized = rkyv::to_bytes::<_, 1024>(&expected).expect("Serialize");
        let view: DataView<MessageMetadata, _> =
            DataView::using(serialized).expect("Create view");

        let req = Request::<MessageMetadata>::new(addr, view);
        assert_eq!(req.remote_addr(), addr, "Remote addr should match.");

        let owned = req.to_owned().unwrap();
        assert_eq!(owned, expected, "Deserialized value should match.");
    }
}