1use rama_core::{
4 Context, Service,
5 error::{BoxError, ErrorContext},
6 telemetry::tracing,
7};
8use rama_net::client::EstablishedClientConnection;
9
10use crate::{ClientUnixSocketInfo, UnixSocketInfo, UnixStream};
11use std::{convert::Infallible, path::PathBuf, sync::Arc};
12
13pub struct UnixConnector<ConnectorFactory = (), T = UnixTarget> {
15 connector_factory: ConnectorFactory,
16 target: T,
17}
18
19impl<ConnectorFactory: std::fmt::Debug, UnixTarget: std::fmt::Debug> std::fmt::Debug
20 for UnixConnector<ConnectorFactory, UnixTarget>
21{
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("UnixConnector")
24 .field("connector_factory", &self.connector_factory)
25 .field("target", &self.target)
26 .finish()
27 }
28}
29
30impl<ConnectorFactory: Clone, UnixTarget: Clone> Clone
31 for UnixConnector<ConnectorFactory, UnixTarget>
32{
33 fn clone(&self) -> Self {
34 Self {
35 connector_factory: self.connector_factory.clone(),
36 target: self.target.clone(),
37 }
38 }
39}
40
/// Default target type of a `UnixConnector`: the filesystem path of the
/// Unix socket to connect to.
#[derive(Debug, Clone)]
pub struct UnixTarget(PathBuf);
44
45impl UnixConnector {
46 pub fn fixed(path: impl Into<PathBuf>) -> Self {
52 Self {
53 target: UnixTarget(path.into()),
54 connector_factory: (),
55 }
56 }
57}
58
59impl<T> UnixConnector<(), T> {
60 pub fn with_connector<Connector>(
62 self,
63 connector: Connector,
64 ) -> UnixConnector<UnixStreamConnectorCloneFactory<Connector>, T>
65where {
66 UnixConnector {
67 connector_factory: UnixStreamConnectorCloneFactory(connector),
68 target: self.target,
69 }
70 }
71
72 pub fn with_connector_factory<Factory>(self, factory: Factory) -> UnixConnector<Factory, T>
74where {
75 UnixConnector {
76 connector_factory: factory,
77 target: self.target,
78 }
79 }
80}
81
82impl<State, Request, ConnectorFactory> Service<State, Request> for UnixConnector<ConnectorFactory>
83where
84 State: Clone + Send + Sync + 'static,
85 Request: Send + 'static,
86 ConnectorFactory: UnixStreamConnectorFactory<
87 State,
88 Connector: UnixStreamConnector<Error: Into<BoxError> + Send + 'static>,
89 Error: Into<BoxError> + Send + 'static,
90 > + Clone,
91{
92 type Response = EstablishedClientConnection<UnixStream, State, Request>;
93 type Error = BoxError;
94
95 async fn serve(
96 &self,
97 ctx: Context<State>,
98 req: Request,
99 ) -> Result<Self::Response, Self::Error> {
100 let CreatedUnixStreamConnector { mut ctx, connector } = self
101 .connector_factory
102 .make_connector(ctx)
103 .await
104 .map_err(Into::into)?;
105
106 let conn = connector
107 .connect(self.target.0.clone())
108 .await
109 .map_err(Into::into)?;
110
111 ctx.insert(ClientUnixSocketInfo(UnixSocketInfo::new(
112 conn.local_addr()
113 .inspect_err(|err| {
114 tracing::debug!(
115 "failed to receive local addr of established connection: {err:?}"
116 )
117 })
118 .ok(),
119 conn.peer_addr()
120 .context("failed to retrieve peer address of established connection")?,
121 )));
122
123 Ok(EstablishedClientConnection { ctx, req, conn })
124 }
125}
126
127pub trait UnixStreamConnector: Send + Sync + 'static {
130 type Error;
132
133 fn connect(
135 &self,
136 path: PathBuf,
137 ) -> impl Future<Output = Result<UnixStream, Self::Error>> + Send + '_;
138}
139
140impl UnixStreamConnector for () {
141 type Error = std::io::Error;
142
143 fn connect(
144 &self,
145 path: PathBuf,
146 ) -> impl Future<Output = Result<UnixStream, Self::Error>> + Send + '_ {
147 UnixStream::connect(path)
148 }
149}
150
151impl<T: UnixStreamConnector> UnixStreamConnector for Arc<T> {
152 type Error = T::Error;
153
154 fn connect(
155 &self,
156 path: PathBuf,
157 ) -> impl Future<Output = Result<UnixStream, Self::Error>> + Send + '_ {
158 (**self).connect(path)
159 }
160}
161
162macro_rules! impl_stream_connector_either {
163 ($id:ident, $($param:ident),+ $(,)?) => {
164 impl<$($param),+> UnixStreamConnector for ::rama_core::combinators::$id<$($param),+>
165 where
166 $(
167 $param: UnixStreamConnector<Error: Into<BoxError>>,
168 )+
169 {
170 type Error = BoxError;
171
172 async fn connect(
173 &self,
174 path: PathBuf,
175 ) -> Result<UnixStream, Self::Error> {
176 match self {
177 $(
178 ::rama_core::combinators::$id::$param(s) => s.connect(path).await.map_err(Into::into),
179 )+
180 }
181 }
182 }
183 };
184}
185
186::rama_core::combinators::impl_either!(impl_stream_connector_either);
187
188pub struct CreatedUnixStreamConnector<State, Connector> {
191 pub ctx: Context<State>,
192 pub connector: Connector,
193}
194
195impl<State, Connector> std::fmt::Debug for CreatedUnixStreamConnector<State, Connector>
196where
197 State: std::fmt::Debug,
198 Connector: std::fmt::Debug,
199{
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 f.debug_struct("CreatedUnixStreamConnector")
202 .field("ctx", &self.ctx)
203 .field("connector", &self.connector)
204 .finish()
205 }
206}
207
208impl<State, Connector> Clone for CreatedUnixStreamConnector<State, Connector>
209where
210 State: Clone,
211 Connector: Clone,
212{
213 fn clone(&self) -> Self {
214 Self {
215 ctx: self.ctx.clone(),
216 connector: self.connector.clone(),
217 }
218 }
219}
220
221pub trait UnixStreamConnectorFactory<State>: Send + Sync + 'static {
229 type Connector: UnixStreamConnector;
231 type Error;
234
235 fn make_connector(
237 &self,
238 ctx: Context<State>,
239 ) -> impl Future<
240 Output = Result<CreatedUnixStreamConnector<State, Self::Connector>, Self::Error>,
241 > + Send
242 + '_;
243}
244
245impl<State: Send + Sync + 'static> UnixStreamConnectorFactory<State> for () {
246 type Connector = ();
247 type Error = Infallible;
248
249 fn make_connector(
250 &self,
251 ctx: Context<State>,
252 ) -> impl Future<
253 Output = Result<CreatedUnixStreamConnector<State, Self::Connector>, Self::Error>,
254 > + Send
255 + '_ {
256 std::future::ready(Ok(CreatedUnixStreamConnector { ctx, connector: () }))
257 }
258}
259
260pub struct UnixStreamConnectorCloneFactory<C>(pub(super) C);
267
268impl<C> std::fmt::Debug for UnixStreamConnectorCloneFactory<C>
269where
270 C: std::fmt::Debug,
271{
272 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273 f.debug_tuple("UnixStreamConnectorCloneFactory")
274 .field(&self.0)
275 .finish()
276 }
277}
278
279impl<C> Clone for UnixStreamConnectorCloneFactory<C>
280where
281 C: Clone,
282{
283 fn clone(&self) -> Self {
284 Self(self.0.clone())
285 }
286}
287
288impl<State, C> UnixStreamConnectorFactory<State> for UnixStreamConnectorCloneFactory<C>
289where
290 C: UnixStreamConnector + Clone,
291 State: Send + Sync + 'static,
292{
293 type Connector = C;
294 type Error = Infallible;
295
296 fn make_connector(
297 &self,
298 ctx: Context<State>,
299 ) -> impl Future<
300 Output = Result<CreatedUnixStreamConnector<State, Self::Connector>, Self::Error>,
301 > + Send
302 + '_ {
303 std::future::ready(Ok(CreatedUnixStreamConnector {
304 ctx,
305 connector: self.0.clone(),
306 }))
307 }
308}
309
310impl<State, F> UnixStreamConnectorFactory<State> for Arc<F>
311where
312 F: UnixStreamConnectorFactory<State>,
313 State: Send + Sync + 'static,
314{
315 type Connector = F::Connector;
316 type Error = F::Error;
317
318 fn make_connector(
319 &self,
320 ctx: Context<State>,
321 ) -> impl Future<
322 Output = Result<CreatedUnixStreamConnector<State, Self::Connector>, Self::Error>,
323 > + Send
324 + '_ {
325 (**self).make_connector(ctx)
326 }
327}
328
329macro_rules! impl_stream_connector_factory_either {
330 ($id:ident, $($param:ident),+ $(,)?) => {
331 impl<State, $($param),+> UnixStreamConnectorFactory<State> for ::rama_core::combinators::$id<$($param),+>
332 where
333 State: Send + Sync + 'static,
334 $(
335 $param: UnixStreamConnectorFactory<State, Connector: UnixStreamConnector<Error: Into<BoxError>>, Error: Into<BoxError>>,
336 )+
337 {
338 type Connector = ::rama_core::combinators::$id<$($param::Connector),+>;
339 type Error = BoxError;
340
341 async fn make_connector(
342 &self,
343 ctx: Context<State>,
344 ) -> Result<CreatedUnixStreamConnector<State, Self::Connector>, Self::Error> {
345 match self {
346 $(
347 ::rama_core::combinators::$id::$param(s) => match s.make_connector(ctx).await {
348 Err(e) => Err(e.into()),
349 Ok(CreatedUnixStreamConnector{ ctx, connector }) => Ok(CreatedUnixStreamConnector{
350 ctx,
351 connector: ::rama_core::combinators::$id::$param(connector),
352 }),
353 },
354 )+
355 }
356 }
357 }
358 };
359}
360
361::rama_core::combinators::impl_either!(impl_stream_connector_factory_either);