rama_unix/unix/client/
connector.rs

1//! Unix (domain) socket client module for Rama.
2
3use rama_core::{
4    Context, Service,
5    error::{BoxError, ErrorContext},
6    telemetry::tracing,
7};
8use rama_net::client::EstablishedClientConnection;
9
10use crate::{ClientUnixSocketInfo, UnixSocketInfo, UnixStream};
11use std::{convert::Infallible, path::PathBuf, sync::Arc};
12
13/// A connector which can be used to establish a Unix connection to a server.
14pub struct UnixConnector<ConnectorFactory = (), T = UnixTarget> {
15    connector_factory: ConnectorFactory,
16    target: T,
17}
18
19impl<ConnectorFactory: std::fmt::Debug, UnixTarget: std::fmt::Debug> std::fmt::Debug
20    for UnixConnector<ConnectorFactory, UnixTarget>
21{
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("UnixConnector")
24            .field("connector_factory", &self.connector_factory)
25            .field("target", &self.target)
26            .finish()
27    }
28}
29
30impl<ConnectorFactory: Clone, UnixTarget: Clone> Clone
31    for UnixConnector<ConnectorFactory, UnixTarget>
32{
33    fn clone(&self) -> Self {
34        Self {
35            connector_factory: self.connector_factory.clone(),
36            target: self.target.clone(),
37        }
38    }
39}
40
/// Target of a [`UnixConnector`] that always connects to the same, fixed
/// socket path (a [`PathBuf`]).
///
/// The wrapped path is cloned for every connection attempt.
#[derive(Debug, Clone)]
pub struct UnixTarget(PathBuf);
44
45impl UnixConnector {
46    /// Create a new [`UnixConnector`], which is used to establish a connection to a server
47    /// at a fixed path.
48    ///
49    /// You can use middleware around the [`UnixConnector`]
50    /// or add connection pools, retry logic and more.
51    pub fn fixed(path: impl Into<PathBuf>) -> Self {
52        Self {
53            target: UnixTarget(path.into()),
54            connector_factory: (),
55        }
56    }
57}
58
59impl<T> UnixConnector<(), T> {
60    /// Consume `self` to attach the given `Connector` (a [`UnixStreamConnector`]) as a new [`UnixConnector`].
61    pub fn with_connector<Connector>(
62        self,
63        connector: Connector,
64    ) -> UnixConnector<UnixStreamConnectorCloneFactory<Connector>, T>
65where {
66        UnixConnector {
67            connector_factory: UnixStreamConnectorCloneFactory(connector),
68            target: self.target,
69        }
70    }
71
72    /// Consume `self` to attach the given `Factory` (a [`UnixStreamConnectorFactory`]) as a new [`UnixConnector`].
73    pub fn with_connector_factory<Factory>(self, factory: Factory) -> UnixConnector<Factory, T>
74where {
75        UnixConnector {
76            connector_factory: factory,
77            target: self.target,
78        }
79    }
80}
81
82impl<State, Request, ConnectorFactory> Service<State, Request> for UnixConnector<ConnectorFactory>
83where
84    State: Clone + Send + Sync + 'static,
85    Request: Send + 'static,
86    ConnectorFactory: UnixStreamConnectorFactory<
87            State,
88            Connector: UnixStreamConnector<Error: Into<BoxError> + Send + 'static>,
89            Error: Into<BoxError> + Send + 'static,
90        > + Clone,
91{
92    type Response = EstablishedClientConnection<UnixStream, State, Request>;
93    type Error = BoxError;
94
95    async fn serve(
96        &self,
97        ctx: Context<State>,
98        req: Request,
99    ) -> Result<Self::Response, Self::Error> {
100        let CreatedUnixStreamConnector { mut ctx, connector } = self
101            .connector_factory
102            .make_connector(ctx)
103            .await
104            .map_err(Into::into)?;
105
106        let conn = connector
107            .connect(self.target.0.clone())
108            .await
109            .map_err(Into::into)?;
110
111        ctx.insert(ClientUnixSocketInfo(UnixSocketInfo::new(
112            conn.local_addr()
113                .inspect_err(|err| {
114                    tracing::debug!(
115                        "failed to receive local addr of established connection: {err:?}"
116                    )
117                })
118                .ok(),
119            conn.peer_addr()
120                .context("failed to retrieve peer address of established connection")?,
121        )));
122
123        Ok(EstablishedClientConnection { ctx, req, conn })
124    }
125}
126
127/// Trait used by the `UnixConnector`
128/// to actually establish the [`UnixStream`].
129pub trait UnixStreamConnector: Send + Sync + 'static {
130    /// Type of error that can occurr when establishing the connection failed.
131    type Error;
132
133    /// Connect to the path and return the established [`UnixStream`].
134    fn connect(
135        &self,
136        path: PathBuf,
137    ) -> impl Future<Output = Result<UnixStream, Self::Error>> + Send + '_;
138}
139
140impl UnixStreamConnector for () {
141    type Error = std::io::Error;
142
143    fn connect(
144        &self,
145        path: PathBuf,
146    ) -> impl Future<Output = Result<UnixStream, Self::Error>> + Send + '_ {
147        UnixStream::connect(path)
148    }
149}
150
151impl<T: UnixStreamConnector> UnixStreamConnector for Arc<T> {
152    type Error = T::Error;
153
154    fn connect(
155        &self,
156        path: PathBuf,
157    ) -> impl Future<Output = Result<UnixStream, Self::Error>> + Send + '_ {
158        (**self).connect(path)
159    }
160}
161
162macro_rules! impl_stream_connector_either {
163    ($id:ident, $($param:ident),+ $(,)?) => {
164        impl<$($param),+> UnixStreamConnector for ::rama_core::combinators::$id<$($param),+>
165        where
166            $(
167                $param: UnixStreamConnector<Error: Into<BoxError>>,
168            )+
169        {
170            type Error = BoxError;
171
172            async fn connect(
173                &self,
174                path: PathBuf,
175            ) -> Result<UnixStream, Self::Error> {
176                match self {
177                    $(
178                        ::rama_core::combinators::$id::$param(s) => s.connect(path).await.map_err(Into::into),
179                    )+
180                }
181            }
182        }
183    };
184}
185
186::rama_core::combinators::impl_either!(impl_stream_connector_either);
187
188/// Contains a `Connector` created by a [`UnixStreamConnectorFactory`],
189/// together with the [`Context`] used to create it in relation to.
190pub struct CreatedUnixStreamConnector<State, Connector> {
191    pub ctx: Context<State>,
192    pub connector: Connector,
193}
194
195impl<State, Connector> std::fmt::Debug for CreatedUnixStreamConnector<State, Connector>
196where
197    State: std::fmt::Debug,
198    Connector: std::fmt::Debug,
199{
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        f.debug_struct("CreatedUnixStreamConnector")
202            .field("ctx", &self.ctx)
203            .field("connector", &self.connector)
204            .finish()
205    }
206}
207
208impl<State, Connector> Clone for CreatedUnixStreamConnector<State, Connector>
209where
210    State: Clone,
211    Connector: Clone,
212{
213    fn clone(&self) -> Self {
214        Self {
215            ctx: self.ctx.clone(),
216            connector: self.connector.clone(),
217        }
218    }
219}
220
221/// Factory to create a [`UnixStreamConnector`]. This is used by the Unix
222/// stream service to create a stream within a specific [`Context`].
223///
224/// In the most simplest case you use a [`UnixStreamConnectorCloneFactory`]
225/// to use a [`Clone`]able [`UnixStreamConnectorCloneFactory`], but in more
226/// advanced cases you can use variants of [`UnixStreamConnector`] specific
227/// to the given contexts.
228pub trait UnixStreamConnectorFactory<State>: Send + Sync + 'static {
229    /// `UnixStreamConnector` created by this [`UnixStreamConnectorFactory`]
230    type Connector: UnixStreamConnector;
231    /// Error returned in case [`UnixStreamConnectorFactory`] was
232    /// not able to create a [`UnixStreamConnector`].
233    type Error;
234
235    /// Try to create a [`UnixStreamConnector`], and return an error or otherwise.
236    fn make_connector(
237        &self,
238        ctx: Context<State>,
239    ) -> impl Future<
240        Output = Result<CreatedUnixStreamConnector<State, Self::Connector>, Self::Error>,
241    > + Send
242    + '_;
243}
244
245impl<State: Send + Sync + 'static> UnixStreamConnectorFactory<State> for () {
246    type Connector = ();
247    type Error = Infallible;
248
249    fn make_connector(
250        &self,
251        ctx: Context<State>,
252    ) -> impl Future<
253        Output = Result<CreatedUnixStreamConnector<State, Self::Connector>, Self::Error>,
254    > + Send
255    + '_ {
256        std::future::ready(Ok(CreatedUnixStreamConnector { ctx, connector: () }))
257    }
258}
259
260/// Utility implementation of a [`UnixStreamConnectorFactory`] which is implemented
261/// to allow one to use a [`Clone`]able [`UnixStreamConnector`] as a [`UnixStreamConnectorFactory`]
262/// by cloning itself.
263///
264/// This struct cannot be created by third party crates
265/// and instead is to be used via other API's provided by this crate.
266pub struct UnixStreamConnectorCloneFactory<C>(pub(super) C);
267
268impl<C> std::fmt::Debug for UnixStreamConnectorCloneFactory<C>
269where
270    C: std::fmt::Debug,
271{
272    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273        f.debug_tuple("UnixStreamConnectorCloneFactory")
274            .field(&self.0)
275            .finish()
276    }
277}
278
279impl<C> Clone for UnixStreamConnectorCloneFactory<C>
280where
281    C: Clone,
282{
283    fn clone(&self) -> Self {
284        Self(self.0.clone())
285    }
286}
287
288impl<State, C> UnixStreamConnectorFactory<State> for UnixStreamConnectorCloneFactory<C>
289where
290    C: UnixStreamConnector + Clone,
291    State: Send + Sync + 'static,
292{
293    type Connector = C;
294    type Error = Infallible;
295
296    fn make_connector(
297        &self,
298        ctx: Context<State>,
299    ) -> impl Future<
300        Output = Result<CreatedUnixStreamConnector<State, Self::Connector>, Self::Error>,
301    > + Send
302    + '_ {
303        std::future::ready(Ok(CreatedUnixStreamConnector {
304            ctx,
305            connector: self.0.clone(),
306        }))
307    }
308}
309
310impl<State, F> UnixStreamConnectorFactory<State> for Arc<F>
311where
312    F: UnixStreamConnectorFactory<State>,
313    State: Send + Sync + 'static,
314{
315    type Connector = F::Connector;
316    type Error = F::Error;
317
318    fn make_connector(
319        &self,
320        ctx: Context<State>,
321    ) -> impl Future<
322        Output = Result<CreatedUnixStreamConnector<State, Self::Connector>, Self::Error>,
323    > + Send
324    + '_ {
325        (**self).make_connector(ctx)
326    }
327}
328
329macro_rules! impl_stream_connector_factory_either {
330    ($id:ident, $($param:ident),+ $(,)?) => {
331        impl<State, $($param),+> UnixStreamConnectorFactory<State> for ::rama_core::combinators::$id<$($param),+>
332        where
333            State: Send + Sync + 'static,
334            $(
335                $param: UnixStreamConnectorFactory<State, Connector: UnixStreamConnector<Error: Into<BoxError>>, Error: Into<BoxError>>,
336            )+
337        {
338            type Connector = ::rama_core::combinators::$id<$($param::Connector),+>;
339            type Error = BoxError;
340
341            async fn make_connector(
342                &self,
343                ctx: Context<State>,
344            ) -> Result<CreatedUnixStreamConnector<State, Self::Connector>, Self::Error> {
345                match self {
346                    $(
347                        ::rama_core::combinators::$id::$param(s) => match s.make_connector(ctx).await {
348                            Err(e) => Err(e.into()),
349                            Ok(CreatedUnixStreamConnector{ ctx, connector }) => Ok(CreatedUnixStreamConnector{
350                                ctx,
351                                connector: ::rama_core::combinators::$id::$param(connector),
352                            }),
353                        },
354                    )+
355                }
356            }
357        }
358    };
359}
360
361::rama_core::combinators::impl_either!(impl_stream_connector_factory_either);