Skip to main content

openwire_core/
transport.rs

1use std::net::SocketAddr;
2use std::task::{Context, Poll};
3use std::time::Duration;
4
5use hyper::Uri;
6use tower::util::BoxCloneSyncService;
7use tower::{Service, ServiceExt};
8
9use crate::{BoxFuture, CallContext, WireError};
10
11#[derive(Debug, Clone)]
12pub struct ConnectionInfo {
13    pub id: crate::ConnectionId,
14    pub remote_addr: Option<SocketAddr>,
15    pub local_addr: Option<SocketAddr>,
16    pub tls: bool,
17}
18
/// Server names recorded for a connection in `verified_server_names`.
///
/// An empty list (the `Default`) means no names were recorded. The name suggests
/// this feeds connection-coalescing decisions; confirm at the call sites.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct CoalescingInfo {
    /// Recorded server names, kept in the order they were supplied.
    pub verified_server_names: Vec<String>,
}

impl CoalescingInfo {
    /// Wraps the given list of names without copying or reordering it.
    pub fn new(verified_server_names: Vec<String>) -> Self {
        CoalescingInfo { verified_server_names }
    }

    /// Returns `true` when no server names were recorded.
    pub fn is_empty(&self) -> bool {
        self.verified_server_names.first().is_none()
    }
}
35
36#[derive(Debug, Clone, Default)]
37pub struct Connected {
38    info: Option<ConnectionInfo>,
39    coalescing: CoalescingInfo,
40    proxied: bool,
41    negotiated_h2: bool,
42}
43
44impl Connected {
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    pub fn info(mut self, info: ConnectionInfo) -> Self {
50        self.info = Some(info);
51        self
52    }
53
54    pub fn coalescing(mut self, coalescing: CoalescingInfo) -> Self {
55        self.coalescing = coalescing;
56        self
57    }
58
59    pub fn proxy(mut self, proxied: bool) -> Self {
60        self.proxied = proxied;
61        self
62    }
63
64    pub fn negotiated_h2(mut self, negotiated_h2: bool) -> Self {
65        self.negotiated_h2 = negotiated_h2;
66        self
67    }
68
69    pub fn is_proxied(&self) -> bool {
70        self.proxied
71    }
72
73    pub fn is_negotiated_h2(&self) -> bool {
74        self.negotiated_h2
75    }
76
77    pub fn connection_info(&self) -> Option<&ConnectionInfo> {
78        self.info.as_ref()
79    }
80
81    pub fn connection_info_or_default(&self) -> ConnectionInfo {
82        self.info.clone().unwrap_or_else(|| ConnectionInfo {
83            id: crate::next_connection_id(),
84            remote_addr: None,
85            local_addr: None,
86            tls: false,
87        })
88    }
89
90    pub fn coalescing_info(&self) -> &CoalescingInfo {
91        &self.coalescing
92    }
93}
94
95pub trait Connection {
96    fn connected(&self) -> Connected;
97}
98
99pub trait ConnectionIo:
100    hyper::rt::Read + hyper::rt::Write + Connection + Unpin + Send + 'static
101{
102}
103
104impl<T> ConnectionIo for T where
105    T: hyper::rt::Read + hyper::rt::Write + Connection + Unpin + Send + 'static
106{
107}
108
109pub type BoxConnection = Box<dyn ConnectionIo>;
110
111impl Connection for BoxConnection {
112    fn connected(&self) -> Connected {
113        (**self).connected()
114    }
115}
116
117#[derive(Clone, Debug)]
118pub struct DnsRequest {
119    pub ctx: CallContext,
120    pub host: String,
121    pub port: u16,
122}
123
124impl DnsRequest {
125    pub fn new(ctx: CallContext, host: String, port: u16) -> Self {
126        Self { ctx, host, port }
127    }
128}
129
130#[derive(Clone, Debug)]
131pub struct TcpConnectRequest {
132    pub ctx: CallContext,
133    pub addr: SocketAddr,
134    pub timeout: Option<Duration>,
135}
136
137impl TcpConnectRequest {
138    pub fn new(ctx: CallContext, addr: SocketAddr, timeout: Option<Duration>) -> Self {
139        Self { ctx, addr, timeout }
140    }
141}
142
/// Per-call hint for how a TLS connector should advertise ALPN.
///
/// - [`TlsAlpnPreference::Auto`] (the default) leaves protocol selection to the
///   connector's default behavior.
/// - [`TlsAlpnPreference::Http1Only`] is used by WebSocket-over-TLS calls. These must
///   complete an HTTP/1.1 upgrade handshake and therefore must not negotiate HTTP/2.
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub enum TlsAlpnPreference {
    /// Let the connector decide which protocols to offer.
    #[default]
    Auto,
    /// Offer HTTP/1.1 only; never negotiate HTTP/2.
    Http1Only,
}
154
155pub struct TlsConnectRequest {
156    pub ctx: CallContext,
157    pub uri: Uri,
158    pub stream: BoxConnection,
159}
160
161impl TlsConnectRequest {
162    pub fn new(ctx: CallContext, uri: Uri, stream: BoxConnection) -> Self {
163        Self { ctx, uri, stream }
164    }
165}
166
167pub type BoxDnsService = BoxCloneSyncService<DnsRequest, Vec<SocketAddr>, WireError>;
168pub type BoxTcpService = BoxCloneSyncService<TcpConnectRequest, BoxConnection, WireError>;
169pub type BoxTlsService = BoxCloneSyncService<TlsConnectRequest, BoxConnection, WireError>;
170
171pub trait DnsResolver: Send + Sync + 'static {
172    fn resolve(
173        &self,
174        ctx: CallContext,
175        host: String,
176        port: u16,
177    ) -> BoxFuture<Result<Vec<SocketAddr>, WireError>>;
178}
179
180pub trait TcpConnector: Send + Sync + 'static {
181    fn connect(
182        &self,
183        ctx: CallContext,
184        addr: SocketAddr,
185        timeout: Option<Duration>,
186    ) -> BoxFuture<Result<BoxConnection, WireError>>;
187}
188
189pub trait TlsConnector: Send + Sync + 'static {
190    fn connect(
191        &self,
192        ctx: CallContext,
193        uri: Uri,
194        stream: BoxConnection,
195    ) -> BoxFuture<Result<BoxConnection, WireError>>;
196}
197
198#[derive(Clone)]
199pub struct TowerDnsResolver(BoxDnsService);
200
201impl TowerDnsResolver {
202    pub fn new<S>(service: S) -> Self
203    where
204        S: Service<DnsRequest, Response = Vec<SocketAddr>, Error = WireError>
205            + Clone
206            + Send
207            + Sync
208            + 'static,
209        S::Future: Send + 'static,
210    {
211        Self(BoxCloneSyncService::new(service))
212    }
213
214    pub fn service(&self) -> BoxDnsService {
215        self.0.clone()
216    }
217}
218
219impl DnsResolver for TowerDnsResolver {
220    fn resolve(
221        &self,
222        ctx: CallContext,
223        host: String,
224        port: u16,
225    ) -> BoxFuture<Result<Vec<SocketAddr>, WireError>> {
226        let service = self.0.clone();
227        Box::pin(async move { service.oneshot(DnsRequest::new(ctx, host, port)).await })
228    }
229}
230
231#[derive(Clone)]
232pub struct TowerTcpConnector(BoxTcpService);
233
234impl TowerTcpConnector {
235    pub fn new<S>(service: S) -> Self
236    where
237        S: Service<TcpConnectRequest, Response = BoxConnection, Error = WireError>
238            + Clone
239            + Send
240            + Sync
241            + 'static,
242        S::Future: Send + 'static,
243    {
244        Self(BoxCloneSyncService::new(service))
245    }
246
247    pub fn service(&self) -> BoxTcpService {
248        self.0.clone()
249    }
250}
251
252impl TcpConnector for TowerTcpConnector {
253    fn connect(
254        &self,
255        ctx: CallContext,
256        addr: SocketAddr,
257        timeout: Option<Duration>,
258    ) -> BoxFuture<Result<BoxConnection, WireError>> {
259        let service = self.0.clone();
260        Box::pin(async move {
261            service
262                .oneshot(TcpConnectRequest::new(ctx, addr, timeout))
263                .await
264        })
265    }
266}
267
268#[derive(Clone)]
269pub struct TowerTlsConnector(BoxTlsService);
270
271impl TowerTlsConnector {
272    pub fn new<S>(service: S) -> Self
273    where
274        S: Service<TlsConnectRequest, Response = BoxConnection, Error = WireError>
275            + Clone
276            + Send
277            + Sync
278            + 'static,
279        S::Future: Send + 'static,
280    {
281        Self(BoxCloneSyncService::new(service))
282    }
283
284    pub fn service(&self) -> BoxTlsService {
285        self.0.clone()
286    }
287}
288
289impl TlsConnector for TowerTlsConnector {
290    fn connect(
291        &self,
292        ctx: CallContext,
293        uri: Uri,
294        stream: BoxConnection,
295    ) -> BoxFuture<Result<BoxConnection, WireError>> {
296        let service = self.0.clone();
297        Box::pin(async move {
298            service
299                .oneshot(TlsConnectRequest::new(ctx, uri, stream))
300                .await
301        })
302    }
303}
304
305#[derive(Clone)]
306pub struct DnsResolverService<R> {
307    resolver: R,
308}
309
310impl<R> DnsResolverService<R> {
311    pub fn new(resolver: R) -> Self {
312        Self { resolver }
313    }
314
315    pub fn into_inner(self) -> R {
316        self.resolver
317    }
318}
319
320impl<R> Service<DnsRequest> for DnsResolverService<R>
321where
322    R: DnsResolver,
323{
324    type Response = Vec<SocketAddr>;
325    type Error = WireError;
326    type Future = BoxFuture<Result<Self::Response, Self::Error>>;
327
328    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
329        Poll::Ready(Ok(()))
330    }
331
332    fn call(&mut self, request: DnsRequest) -> Self::Future {
333        self.resolver
334            .resolve(request.ctx, request.host, request.port)
335    }
336}
337
338#[derive(Clone)]
339pub struct TcpConnectorService<C> {
340    connector: C,
341}
342
343impl<C> TcpConnectorService<C> {
344    pub fn new(connector: C) -> Self {
345        Self { connector }
346    }
347
348    pub fn into_inner(self) -> C {
349        self.connector
350    }
351}
352
353impl<C> Service<TcpConnectRequest> for TcpConnectorService<C>
354where
355    C: TcpConnector,
356{
357    type Response = BoxConnection;
358    type Error = WireError;
359    type Future = BoxFuture<Result<Self::Response, Self::Error>>;
360
361    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
362        Poll::Ready(Ok(()))
363    }
364
365    fn call(&mut self, request: TcpConnectRequest) -> Self::Future {
366        self.connector
367            .connect(request.ctx, request.addr, request.timeout)
368    }
369}
370
371#[derive(Clone)]
372pub struct TlsConnectorService<C> {
373    connector: C,
374}
375
376impl<C> TlsConnectorService<C> {
377    pub fn new(connector: C) -> Self {
378        Self { connector }
379    }
380
381    pub fn into_inner(self) -> C {
382        self.connector
383    }
384}
385
386impl<C> Service<TlsConnectRequest> for TlsConnectorService<C>
387where
388    C: TlsConnector,
389{
390    type Response = BoxConnection;
391    type Error = WireError;
392    type Future = BoxFuture<Result<Self::Response, Self::Error>>;
393
394    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
395        Poll::Ready(Ok(()))
396    }
397
398    fn call(&mut self, request: TlsConnectRequest) -> Self::Future {
399        self.connector
400            .connect(request.ctx, request.uri, request.stream)
401    }
402}
403
#[cfg(test)]
mod tests {
    use std::io;
    use std::net::SocketAddr;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};
    use std::task::{Context, Poll};
    use std::time::Duration;

    use http::Request;
    use hyper::Uri;
    use tower::{service_fn, Service, ServiceExt};

    use super::{
        BoxConnection, Connected, Connection, ConnectionInfo, DnsRequest, DnsResolver,
        DnsResolverService, TcpConnectRequest, TcpConnector, TcpConnectorService,
        TlsConnectRequest, TlsConnector, TlsConnectorService, TowerDnsResolver, TowerTcpConnector,
        TowerTlsConnector,
    };
    use crate::{CallContext, NoopEventListenerFactory, RequestBody, WireError};

    // ---- fixtures -------------------------------------------------------------

    /// `CallContext` for a body-less request to `http://example.com/`, wired to the
    /// no-op event listener factory.
    fn test_call_context() -> CallContext {
        let factory: crate::SharedEventListenerFactory = Arc::new(NoopEventListenerFactory);
        let request = Request::builder()
            .uri("http://example.com/")
            .body(RequestBody::absent())
            .expect("request");
        CallContext::from_factory(&factory, &request, None)
    }

    /// IPv4 loopback address on `port`.
    fn loopback(port: u16) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], port))
    }

    /// Address in the 192.0.2.0/24 documentation range.
    fn test_net(host: u8, port: u16) -> SocketAddr {
        SocketAddr::from(([192, 0, 2, host], port))
    }

    fn noop_stream() -> BoxConnection {
        Box::new(NullStream)
    }

    /// Stream stub: reads finish immediately without filling the buffer, writes
    /// accept every byte, and the metadata is the default `Connected`.
    struct NullStream;

    impl hyper::rt::Read for NullStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: hyper::rt::ReadBufCursor<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    impl hyper::rt::Write for NullStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            let accepted = buf.len();
            Poll::Ready(Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    impl Connection for NullStream {
        fn connected(&self) -> Connected {
            Connected::default()
        }
    }

    // ---- Connected ------------------------------------------------------------

    #[test]
    fn connected_negotiated_h2_tracks_requested_state() {
        for requested in [true, false] {
            assert_eq!(
                Connected::new().negotiated_h2(requested).is_negotiated_h2(),
                requested
            );
        }
    }

    #[test]
    fn connected_connection_info_or_default_preserves_explicit_metadata() {
        let expected = ConnectionInfo {
            id: crate::next_connection_id(),
            remote_addr: Some(test_net(10, 443)),
            local_addr: Some(test_net(20, 50000)),
            tls: true,
        };

        let connected = Connected::new().info(expected.clone());
        let actual = connected.connection_info_or_default();

        assert_eq!(actual.id, expected.id);
        assert_eq!(actual.remote_addr, expected.remote_addr);
        assert_eq!(actual.local_addr, expected.local_addr);
        assert_eq!(actual.tls, expected.tls);
    }

    #[test]
    fn connected_connection_info_or_default_falls_back_to_placeholder() {
        let placeholder = Connected::new().connection_info_or_default();

        assert_eq!(placeholder.remote_addr, None);
        assert_eq!(placeholder.local_addr, None);
        assert!(!placeholder.tls);
    }

    // ---- DNS ------------------------------------------------------------------

    #[tokio::test]
    async fn tower_dns_resolver_calls_service() {
        let recorded = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&recorded);
        let resolver = TowerDnsResolver::new(service_fn(move |request: DnsRequest| {
            let sink = Arc::clone(&sink);
            async move {
                sink.lock()
                    .expect("dns calls")
                    .push((request.host, request.port));
                Ok::<_, WireError>(vec![loopback(443)])
            }
        }));

        let resolved = resolver
            .resolve(test_call_context(), "example.com".to_owned(), 443)
            .await
            .expect("resolved addrs");

        let seen = recorded.lock().expect("dns calls");
        assert_eq!(seen.as_slice(), &[("example.com".to_owned(), 443)]);
        assert_eq!(resolved, vec![loopback(443)]);
    }

    /// Resolver that checks its inputs and always answers 192.0.2.10:8443.
    struct FixedDnsResolver;

    impl DnsResolver for FixedDnsResolver {
        fn resolve(
            &self,
            _ctx: CallContext,
            host: String,
            port: u16,
        ) -> crate::BoxFuture<Result<Vec<SocketAddr>, WireError>> {
            Box::pin(async move {
                assert_eq!(host, "resolver.test");
                assert_eq!(port, 8443);
                Ok(vec![test_net(10, 8443)])
            })
        }
    }

    #[tokio::test]
    async fn dns_resolver_service_calls_resolver() {
        let mut service = DnsResolverService::new(FixedDnsResolver);
        let request = DnsRequest::new(test_call_context(), "resolver.test".to_owned(), 8443);

        let ready = service.ready().await.expect("service ready");
        let resolved = ready.call(request).await.expect("resolved addrs");

        assert_eq!(resolved, vec![test_net(10, 8443)]);
    }

    // ---- TCP ------------------------------------------------------------------

    #[tokio::test]
    async fn tower_tcp_connector_calls_service() {
        let recorded = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&recorded);
        let connector = TowerTcpConnector::new(service_fn(move |request: TcpConnectRequest| {
            let sink = Arc::clone(&sink);
            async move {
                sink.lock()
                    .expect("tcp calls")
                    .push((request.addr, request.timeout));
                Ok::<_, WireError>(noop_stream())
            }
        }));

        connector
            .connect(
                test_call_context(),
                loopback(8080),
                Some(Duration::from_secs(5)),
            )
            .await
            .expect("tcp stream");

        let seen = recorded.lock().expect("tcp calls");
        assert_eq!(
            seen.as_slice(),
            &[(loopback(8080), Some(Duration::from_secs(5)))]
        );
    }

    /// Connector that checks its inputs and always returns a `NullStream`.
    struct FixedTcpConnector;

    impl TcpConnector for FixedTcpConnector {
        fn connect(
            &self,
            _ctx: CallContext,
            addr: SocketAddr,
            timeout: Option<Duration>,
        ) -> crate::BoxFuture<Result<BoxConnection, WireError>> {
            Box::pin(async move {
                assert_eq!(addr, test_net(20, 80));
                assert_eq!(timeout, Some(Duration::from_secs(1)));
                Ok(noop_stream())
            })
        }
    }

    #[tokio::test]
    async fn tcp_connector_service_calls_connector() {
        let mut service = TcpConnectorService::new(FixedTcpConnector);
        let request = TcpConnectRequest::new(
            test_call_context(),
            test_net(20, 80),
            Some(Duration::from_secs(1)),
        );

        let ready = service.ready().await.expect("service ready");
        ready.call(request).await.expect("tcp stream");
    }

    // ---- TLS ------------------------------------------------------------------

    #[tokio::test]
    async fn tower_tls_connector_calls_service() {
        let recorded = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&recorded);
        let connector = TowerTlsConnector::new(service_fn(move |request: TlsConnectRequest| {
            let sink = Arc::clone(&sink);
            async move {
                sink.lock()
                    .expect("tls calls")
                    .push(request.uri.to_string());
                Ok::<_, WireError>(request.stream)
            }
        }));

        let target: Uri = "https://tls.test/".parse().expect("uri");
        connector
            .connect(test_call_context(), target, noop_stream())
            .await
            .expect("tls stream");

        let seen = recorded.lock().expect("tls calls");
        assert_eq!(seen.as_slice(), &["https://tls.test/".to_owned()]);
    }

    /// Connector that checks the target URI and passes the stream through unchanged.
    struct FixedTlsConnector;

    impl TlsConnector for FixedTlsConnector {
        fn connect(
            &self,
            _ctx: CallContext,
            uri: Uri,
            stream: BoxConnection,
        ) -> crate::BoxFuture<Result<BoxConnection, WireError>> {
            Box::pin(async move {
                assert_eq!(uri, "https://service.test/".parse::<Uri>().expect("uri"));
                Ok(stream)
            })
        }
    }

    #[tokio::test]
    async fn tls_connector_service_calls_connector() {
        let mut service = TlsConnectorService::new(FixedTlsConnector);
        let request = TlsConnectRequest::new(
            test_call_context(),
            "https://service.test/".parse().expect("uri"),
            noop_stream(),
        );

        let ready = service.ready().await.expect("service ready");
        ready.call(request).await.expect("tls stream");
    }
}