1use rama_core::{
4    Context, Service,
5    error::{BoxError, ErrorContext},
6    telemetry::tracing,
7};
8use rama_net::client::EstablishedClientConnection;
9
10use crate::{ClientUnixSocketInfo, UnixSocketInfo, UnixStream};
11use std::{convert::Infallible, path::PathBuf, sync::Arc};
12
/// Connector that establishes [`UnixStream`] connections to a target.
///
/// `ConnectorFactory` decides which [`UnixStreamConnector`] is used per request.
/// It defaults to `()`, whose connector delegates to [`UnixStream::connect`].
/// `T` describes where to connect to and defaults to [`UnixTarget`]. The
/// [`Service`] implementation is provided for the [`UnixTarget`] case.
pub struct UnixConnector<ConnectorFactory = (), T = UnixTarget> {
    connector_factory: ConnectorFactory,
    target: T,
}

// The second generic is named `T` (matching the struct definition) rather than
// `UnixTarget`, so it no longer shadows the `UnixTarget` type below.
impl<ConnectorFactory: std::fmt::Debug, T: std::fmt::Debug> std::fmt::Debug
    for UnixConnector<ConnectorFactory, T>
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UnixConnector")
            .field("connector_factory", &self.connector_factory)
            .field("target", &self.target)
            .finish()
    }
}

impl<ConnectorFactory: Clone, T: Clone> Clone for UnixConnector<ConnectorFactory, T> {
    fn clone(&self) -> Self {
        Self {
            connector_factory: self.connector_factory.clone(),
            target: self.target.clone(),
        }
    }
}

/// A fixed filesystem path that a [`UnixConnector`] connects to.
#[derive(Debug, Clone)]
pub struct UnixTarget(PathBuf);

impl UnixConnector {
    /// Create a [`UnixConnector`] that always connects to `path`, using the
    /// default (`()`) connector factory.
    pub fn fixed(path: impl Into<PathBuf>) -> Self {
        Self {
            target: UnixTarget(path.into()),
            connector_factory: (),
        }
    }
}
58
59impl<T> UnixConnector<(), T> {
60    pub fn with_connector<Connector>(
62        self,
63        connector: Connector,
64    ) -> UnixConnector<UnixStreamConnectorCloneFactory<Connector>, T>
65where {
66        UnixConnector {
67            connector_factory: UnixStreamConnectorCloneFactory(connector),
68            target: self.target,
69        }
70    }
71
72    pub fn with_connector_factory<Factory>(self, factory: Factory) -> UnixConnector<Factory, T>
74where {
75        UnixConnector {
76            connector_factory: factory,
77            target: self.target,
78        }
79    }
80}
81
82impl<State, Request, ConnectorFactory> Service<State, Request> for UnixConnector<ConnectorFactory>
83where
84    State: Clone + Send + Sync + 'static,
85    Request: Send + 'static,
86    ConnectorFactory: UnixStreamConnectorFactory<
87            State,
88            Connector: UnixStreamConnector<Error: Into<BoxError> + Send + 'static>,
89            Error: Into<BoxError> + Send + 'static,
90        > + Clone,
91{
92    type Response = EstablishedClientConnection<UnixStream, State, Request>;
93    type Error = BoxError;
94
95    async fn serve(
96        &self,
97        ctx: Context<State>,
98        req: Request,
99    ) -> Result<Self::Response, Self::Error> {
100        let CreatedUnixStreamConnector { mut ctx, connector } = self
101            .connector_factory
102            .make_connector(ctx)
103            .await
104            .map_err(Into::into)?;
105
106        let conn = connector
107            .connect(self.target.0.clone())
108            .await
109            .map_err(Into::into)?;
110
111        ctx.insert(ClientUnixSocketInfo(UnixSocketInfo::new(
112            conn.local_addr()
113                .inspect_err(|err| {
114                    tracing::debug!(
115                        "failed to receive local addr of established connection: {err:?}"
116                    )
117                })
118                .ok(),
119            conn.peer_addr()
120                .context("failed to retrieve peer address of established connection")?,
121        )));
122
123        Ok(EstablishedClientConnection { ctx, req, conn })
124    }
125}
126
/// Opens a [`UnixStream`] to a filesystem path.
///
/// Implementors must be `Send + Sync + 'static`. The future returned by
/// [`connect`](UnixStreamConnector::connect) must be `Send` and may borrow `self`.
pub trait UnixStreamConnector: Send + Sync + 'static {
    /// Error returned when the connection cannot be established.
    type Error;

    /// Connect to the Unix socket located at `path`.
    fn connect(
        &self,
        path: PathBuf,
    ) -> impl Future<Output = Result<UnixStream, Self::Error>> + Send + '_;
}
139
/// Default connector: delegates directly to [`UnixStream::connect`] and
/// surfaces its [`std::io::Error`] unchanged.
impl UnixStreamConnector for () {
    type Error = std::io::Error;

    fn connect(
        &self,
        path: PathBuf,
    ) -> impl Future<Output = Result<UnixStream, Self::Error>> + Send + '_ {
        UnixStream::connect(path)
    }
}
150
151impl<T: UnixStreamConnector> UnixStreamConnector for Arc<T> {
152    type Error = T::Error;
153
154    fn connect(
155        &self,
156        path: PathBuf,
157    ) -> impl Future<Output = Result<UnixStream, Self::Error>> + Send + '_ {
158        (**self).connect(path)
159    }
160}
161
/// Implements [`UnixStreamConnector`] for one of the `rama_core::combinators`
/// either-types (`$id`, whose variants are named by the `$param` list).
///
/// The generated `connect` dispatches to whichever variant is active and
/// converts that variant's error into a [`BoxError`], so every variant shares
/// a single `Error` type.
macro_rules! impl_stream_connector_either {
    ($id:ident, $($param:ident),+ $(,)?) => {
        impl<$($param),+> UnixStreamConnector for ::rama_core::combinators::$id<$($param),+>
        where
            $(
                $param: UnixStreamConnector<Error: Into<BoxError>>,
            )+
        {
            type Error = BoxError;

            async fn connect(
                &self,
                path: PathBuf,
            ) -> Result<UnixStream, Self::Error> {
                match self {
                    $(
                        ::rama_core::combinators::$id::$param(s) => s.connect(path).await.map_err(Into::into),
                    )+
                }
            }
        }
    };
}

// Generate the impls for the either-types. `impl_either!` (from rama_core) is
// assumed to invoke the macro above once per `Either*` type — confirm there.
::rama_core::combinators::impl_either!(impl_stream_connector_either);
187
188pub struct CreatedUnixStreamConnector<State, Connector> {
191    pub ctx: Context<State>,
192    pub connector: Connector,
193}
194
195impl<State, Connector> std::fmt::Debug for CreatedUnixStreamConnector<State, Connector>
196where
197    State: std::fmt::Debug,
198    Connector: std::fmt::Debug,
199{
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        f.debug_struct("CreatedUnixStreamConnector")
202            .field("ctx", &self.ctx)
203            .field("connector", &self.connector)
204            .finish()
205    }
206}
207
208impl<State, Connector> Clone for CreatedUnixStreamConnector<State, Connector>
209where
210    State: Clone,
211    Connector: Clone,
212{
213    fn clone(&self) -> Self {
214        Self {
215            ctx: self.ctx.clone(),
216            connector: self.connector.clone(),
217        }
218    }
219}
220
/// Produces a [`UnixStreamConnector`] for an individual request.
///
/// The factory receives the request [`Context`] by value and returns it again
/// inside a [`CreatedUnixStreamConnector`], alongside the connector to use.
pub trait UnixStreamConnectorFactory<State>: Send + Sync + 'static {
    /// Connector type produced by this factory.
    type Connector: UnixStreamConnector;
    /// Error returned when no connector could be created.
    type Error;

    /// Create a connector for the request whose context is `ctx`.
    fn make_connector(
        &self,
        ctx: Context<State>,
    ) -> impl Future<
        Output = Result<CreatedUnixStreamConnector<State, Self::Connector>, Self::Error>,
    > + Send
    + '_;
}
244
245impl<State: Send + Sync + 'static> UnixStreamConnectorFactory<State> for () {
246    type Connector = ();
247    type Error = Infallible;
248
249    fn make_connector(
250        &self,
251        ctx: Context<State>,
252    ) -> impl Future<
253        Output = Result<CreatedUnixStreamConnector<State, Self::Connector>, Self::Error>,
254    > + Send
255    + '_ {
256        std::future::ready(Ok(CreatedUnixStreamConnector { ctx, connector: () }))
257    }
258}
259
260pub struct UnixStreamConnectorCloneFactory<C>(pub(super) C);
267
268impl<C> std::fmt::Debug for UnixStreamConnectorCloneFactory<C>
269where
270    C: std::fmt::Debug,
271{
272    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273        f.debug_tuple("UnixStreamConnectorCloneFactory")
274            .field(&self.0)
275            .finish()
276    }
277}
278
279impl<C> Clone for UnixStreamConnectorCloneFactory<C>
280where
281    C: Clone,
282{
283    fn clone(&self) -> Self {
284        Self(self.0.clone())
285    }
286}
287
288impl<State, C> UnixStreamConnectorFactory<State> for UnixStreamConnectorCloneFactory<C>
289where
290    C: UnixStreamConnector + Clone,
291    State: Send + Sync + 'static,
292{
293    type Connector = C;
294    type Error = Infallible;
295
296    fn make_connector(
297        &self,
298        ctx: Context<State>,
299    ) -> impl Future<
300        Output = Result<CreatedUnixStreamConnector<State, Self::Connector>, Self::Error>,
301    > + Send
302    + '_ {
303        std::future::ready(Ok(CreatedUnixStreamConnector {
304            ctx,
305            connector: self.0.clone(),
306        }))
307    }
308}
309
310impl<State, F> UnixStreamConnectorFactory<State> for Arc<F>
311where
312    F: UnixStreamConnectorFactory<State>,
313    State: Send + Sync + 'static,
314{
315    type Connector = F::Connector;
316    type Error = F::Error;
317
318    fn make_connector(
319        &self,
320        ctx: Context<State>,
321    ) -> impl Future<
322        Output = Result<CreatedUnixStreamConnector<State, Self::Connector>, Self::Error>,
323    > + Send
324    + '_ {
325        (**self).make_connector(ctx)
326    }
327}
328
/// Implements [`UnixStreamConnectorFactory`] for one of the
/// `rama_core::combinators` either-types (`$id`, whose variants are named by
/// the `$param` list).
///
/// The produced connector is the same either-type over each variant's
/// connector. Errors from any variant are converted into a [`BoxError`].
macro_rules! impl_stream_connector_factory_either {
    ($id:ident, $($param:ident),+ $(,)?) => {
        impl<State, $($param),+> UnixStreamConnectorFactory<State> for ::rama_core::combinators::$id<$($param),+>
        where
            State: Send + Sync + 'static,
            $(
                $param: UnixStreamConnectorFactory<State, Connector: UnixStreamConnector<Error: Into<BoxError>>, Error: Into<BoxError>>,
            )+
        {
            type Connector = ::rama_core::combinators::$id<$($param::Connector),+>;
            type Error = BoxError;

            async fn make_connector(
                &self,
                ctx: Context<State>,
            ) -> Result<CreatedUnixStreamConnector<State, Self::Connector>, Self::Error> {
                // Delegate to the active variant's factory, then re-wrap its
                // connector in the same variant so the connector type mirrors
                // the factory type.
                match self {
                    $(
                        ::rama_core::combinators::$id::$param(s) => match s.make_connector(ctx).await {
                            Err(e) => Err(e.into()),
                            Ok(CreatedUnixStreamConnector{ ctx, connector }) => Ok(CreatedUnixStreamConnector{
                                ctx,
                                connector: ::rama_core::combinators::$id::$param(connector),
                            }),
                        },
                    )+
                }
            }
        }
    };
}

// Generate the impls for the either-types. `impl_either!` (from rama_core) is
// assumed to invoke the macro above once per `Either*` type — confirm there.
::rama_core::combinators::impl_either!(impl_stream_connector_factory_either);