1use std::net::SocketAddr;
2use std::task::{Context, Poll};
3use std::time::Duration;
4
5use hyper::Uri;
6use tower::util::BoxCloneSyncService;
7use tower::{Service, ServiceExt};
8
9use crate::{BoxFuture, CallContext, WireError};
10
11#[derive(Debug, Clone)]
12pub struct ConnectionInfo {
13 pub id: crate::ConnectionId,
14 pub remote_addr: Option<SocketAddr>,
15 pub local_addr: Option<SocketAddr>,
16 pub tls: bool,
17}
18
/// Server names attached to a connection for coalescing decisions.
///
/// The default value carries no names.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct CoalescingInfo {
    /// Names recorded as verified for this connection.
    pub verified_server_names: Vec<String>,
}

impl CoalescingInfo {
    /// Wraps the given list of verified server names.
    pub fn new(verified_server_names: Vec<String>) -> Self {
        CoalescingInfo {
            verified_server_names,
        }
    }

    /// Returns `true` when no server names are recorded.
    pub fn is_empty(&self) -> bool {
        let Self {
            verified_server_names: names,
        } = self;
        names.is_empty()
    }
}
35
36#[derive(Debug, Clone, Default)]
37pub struct Connected {
38 info: Option<ConnectionInfo>,
39 coalescing: CoalescingInfo,
40 proxied: bool,
41 negotiated_h2: bool,
42}
43
44impl Connected {
45 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn info(mut self, info: ConnectionInfo) -> Self {
50 self.info = Some(info);
51 self
52 }
53
54 pub fn coalescing(mut self, coalescing: CoalescingInfo) -> Self {
55 self.coalescing = coalescing;
56 self
57 }
58
59 pub fn proxy(mut self, proxied: bool) -> Self {
60 self.proxied = proxied;
61 self
62 }
63
64 pub fn negotiated_h2(mut self, negotiated_h2: bool) -> Self {
65 self.negotiated_h2 = negotiated_h2;
66 self
67 }
68
69 pub fn is_proxied(&self) -> bool {
70 self.proxied
71 }
72
73 pub fn is_negotiated_h2(&self) -> bool {
74 self.negotiated_h2
75 }
76
77 pub fn connection_info(&self) -> Option<&ConnectionInfo> {
78 self.info.as_ref()
79 }
80
81 pub fn connection_info_or_default(&self) -> ConnectionInfo {
82 self.info.clone().unwrap_or_else(|| ConnectionInfo {
83 id: crate::next_connection_id(),
84 remote_addr: None,
85 local_addr: None,
86 tls: false,
87 })
88 }
89
90 pub fn coalescing_info(&self) -> &CoalescingInfo {
91 &self.coalescing
92 }
93}
94
/// Implemented by I/O types that can describe the connection they represent.
pub trait Connection {
    /// Returns the metadata describing this connection.
    fn connected(&self) -> Connected;
}
98
/// Convenience bundle of every capability required of a connection stream:
/// hyper's `Read` and `Write`, [`Connection`], `Unpin`, `Send` and `'static`.
///
/// The trait has no methods of its own; it exists so the bundle can be named
/// (for example as `dyn ConnectionIo`).
pub trait ConnectionIo:
    hyper::rt::Read + hyper::rt::Write + Connection + Unpin + Send + 'static
{
}

// Blanket impl: any type meeting the bounds is automatically a `ConnectionIo`.
impl<T> ConnectionIo for T where
    T: hyper::rt::Read + hyper::rt::Write + Connection + Unpin + Send + 'static
{
}
108
109pub type BoxConnection = Box<dyn ConnectionIo>;
110
111impl Connection for BoxConnection {
112 fn connected(&self) -> Connected {
113 (**self).connected()
114 }
115}
116
117#[derive(Clone, Debug)]
118pub struct DnsRequest {
119 pub ctx: CallContext,
120 pub host: String,
121 pub port: u16,
122}
123
124impl DnsRequest {
125 pub fn new(ctx: CallContext, host: String, port: u16) -> Self {
126 Self { ctx, host, port }
127 }
128}
129
130#[derive(Clone, Debug)]
131pub struct TcpConnectRequest {
132 pub ctx: CallContext,
133 pub addr: SocketAddr,
134 pub timeout: Option<Duration>,
135}
136
137impl TcpConnectRequest {
138 pub fn new(ctx: CallContext, addr: SocketAddr, timeout: Option<Duration>) -> Self {
139 Self { ctx, addr, timeout }
140 }
141}
142
/// ALPN preference selector for TLS connections.
///
/// NOTE(review): nothing in this file consumes the value; the exact effect of
/// each variant is defined by whichever TLS connector reads it.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum TlsAlpnPreference {
    /// Default variant.
    #[default]
    Auto,
    /// Presumably restricts negotiation to HTTP/1.x — confirm against the TLS connector.
    Http1Only,
}
154
155pub struct TlsConnectRequest {
156 pub ctx: CallContext,
157 pub uri: Uri,
158 pub stream: BoxConnection,
159}
160
161impl TlsConnectRequest {
162 pub fn new(ctx: CallContext, uri: Uri, stream: BoxConnection) -> Self {
163 Self { ctx, uri, stream }
164 }
165}
166
/// Type-erased, cloneable service turning a [`DnsRequest`] into socket addresses, failing with [`WireError`].
pub type BoxDnsService = BoxCloneSyncService<DnsRequest, Vec<SocketAddr>, WireError>;
/// Type-erased, cloneable service turning a [`TcpConnectRequest`] into a [`BoxConnection`], failing with [`WireError`].
pub type BoxTcpService = BoxCloneSyncService<TcpConnectRequest, BoxConnection, WireError>;
/// Type-erased, cloneable service turning a [`TlsConnectRequest`] into a [`BoxConnection`], failing with [`WireError`].
pub type BoxTlsService = BoxCloneSyncService<TlsConnectRequest, BoxConnection, WireError>;
170
/// Resolves a host name and port to socket addresses.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait DnsResolver: Send + Sync + 'static {
    /// Starts a lookup for `host` and `port` on behalf of `ctx`.
    ///
    /// Returns a boxed future yielding the resolved addresses or a [`WireError`].
    fn resolve(
        &self,
        ctx: CallContext,
        host: String,
        port: u16,
    ) -> BoxFuture<Result<Vec<SocketAddr>, WireError>>;
}
179
/// Opens connections to a socket address.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait TcpConnector: Send + Sync + 'static {
    /// Starts connecting to `addr` on behalf of `ctx`, with an optional `timeout`.
    ///
    /// Returns a boxed future yielding the connected stream or a [`WireError`].
    fn connect(
        &self,
        ctx: CallContext,
        addr: SocketAddr,
        timeout: Option<Duration>,
    ) -> BoxFuture<Result<BoxConnection, WireError>>;
}
188
/// Turns an existing stream into another [`BoxConnection`] for the given URI.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait TlsConnector: Send + Sync + 'static {
    /// Starts processing `stream` for `uri` on behalf of `ctx`.
    ///
    /// Takes ownership of `stream` and returns a boxed future yielding the
    /// resulting connection or a [`WireError`].
    fn connect(
        &self,
        ctx: CallContext,
        uri: Uri,
        stream: BoxConnection,
    ) -> BoxFuture<Result<BoxConnection, WireError>>;
}
197
198#[derive(Clone)]
199pub struct TowerDnsResolver(BoxDnsService);
200
201impl TowerDnsResolver {
202 pub fn new<S>(service: S) -> Self
203 where
204 S: Service<DnsRequest, Response = Vec<SocketAddr>, Error = WireError>
205 + Clone
206 + Send
207 + Sync
208 + 'static,
209 S::Future: Send + 'static,
210 {
211 Self(BoxCloneSyncService::new(service))
212 }
213
214 pub fn service(&self) -> BoxDnsService {
215 self.0.clone()
216 }
217}
218
219impl DnsResolver for TowerDnsResolver {
220 fn resolve(
221 &self,
222 ctx: CallContext,
223 host: String,
224 port: u16,
225 ) -> BoxFuture<Result<Vec<SocketAddr>, WireError>> {
226 let service = self.0.clone();
227 Box::pin(async move { service.oneshot(DnsRequest::new(ctx, host, port)).await })
228 }
229}
230
231#[derive(Clone)]
232pub struct TowerTcpConnector(BoxTcpService);
233
234impl TowerTcpConnector {
235 pub fn new<S>(service: S) -> Self
236 where
237 S: Service<TcpConnectRequest, Response = BoxConnection, Error = WireError>
238 + Clone
239 + Send
240 + Sync
241 + 'static,
242 S::Future: Send + 'static,
243 {
244 Self(BoxCloneSyncService::new(service))
245 }
246
247 pub fn service(&self) -> BoxTcpService {
248 self.0.clone()
249 }
250}
251
252impl TcpConnector for TowerTcpConnector {
253 fn connect(
254 &self,
255 ctx: CallContext,
256 addr: SocketAddr,
257 timeout: Option<Duration>,
258 ) -> BoxFuture<Result<BoxConnection, WireError>> {
259 let service = self.0.clone();
260 Box::pin(async move {
261 service
262 .oneshot(TcpConnectRequest::new(ctx, addr, timeout))
263 .await
264 })
265 }
266}
267
268#[derive(Clone)]
269pub struct TowerTlsConnector(BoxTlsService);
270
271impl TowerTlsConnector {
272 pub fn new<S>(service: S) -> Self
273 where
274 S: Service<TlsConnectRequest, Response = BoxConnection, Error = WireError>
275 + Clone
276 + Send
277 + Sync
278 + 'static,
279 S::Future: Send + 'static,
280 {
281 Self(BoxCloneSyncService::new(service))
282 }
283
284 pub fn service(&self) -> BoxTlsService {
285 self.0.clone()
286 }
287}
288
289impl TlsConnector for TowerTlsConnector {
290 fn connect(
291 &self,
292 ctx: CallContext,
293 uri: Uri,
294 stream: BoxConnection,
295 ) -> BoxFuture<Result<BoxConnection, WireError>> {
296 let service = self.0.clone();
297 Box::pin(async move {
298 service
299 .oneshot(TlsConnectRequest::new(ctx, uri, stream))
300 .await
301 })
302 }
303}
304
/// Wrapper around a resolver value; see the `Service<DnsRequest>` impl for how it is driven.
#[derive(Clone)]
pub struct DnsResolverService<R> {
    /// The wrapped resolver.
    resolver: R,
}

impl<R> DnsResolverService<R> {
    /// Wraps `resolver`.
    pub fn new(resolver: R) -> Self {
        DnsResolverService { resolver }
    }

    /// Unwraps the service, handing back the resolver.
    pub fn into_inner(self) -> R {
        let Self { resolver } = self;
        resolver
    }
}
319
320impl<R> Service<DnsRequest> for DnsResolverService<R>
321where
322 R: DnsResolver,
323{
324 type Response = Vec<SocketAddr>;
325 type Error = WireError;
326 type Future = BoxFuture<Result<Self::Response, Self::Error>>;
327
328 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
329 Poll::Ready(Ok(()))
330 }
331
332 fn call(&mut self, request: DnsRequest) -> Self::Future {
333 self.resolver
334 .resolve(request.ctx, request.host, request.port)
335 }
336}
337
/// Wrapper around a connector value; see the `Service<TcpConnectRequest>` impl for how it is driven.
#[derive(Clone)]
pub struct TcpConnectorService<C> {
    /// The wrapped connector.
    connector: C,
}

impl<C> TcpConnectorService<C> {
    /// Wraps `connector`.
    pub fn new(connector: C) -> Self {
        TcpConnectorService { connector }
    }

    /// Unwraps the service, handing back the connector.
    pub fn into_inner(self) -> C {
        let Self { connector } = self;
        connector
    }
}
352
353impl<C> Service<TcpConnectRequest> for TcpConnectorService<C>
354where
355 C: TcpConnector,
356{
357 type Response = BoxConnection;
358 type Error = WireError;
359 type Future = BoxFuture<Result<Self::Response, Self::Error>>;
360
361 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
362 Poll::Ready(Ok(()))
363 }
364
365 fn call(&mut self, request: TcpConnectRequest) -> Self::Future {
366 self.connector
367 .connect(request.ctx, request.addr, request.timeout)
368 }
369}
370
/// Wrapper around a connector value; see the `Service<TlsConnectRequest>` impl for how it is driven.
#[derive(Clone)]
pub struct TlsConnectorService<C> {
    /// The wrapped connector.
    connector: C,
}

impl<C> TlsConnectorService<C> {
    /// Wraps `connector`.
    pub fn new(connector: C) -> Self {
        TlsConnectorService { connector }
    }

    /// Unwraps the service, handing back the connector.
    pub fn into_inner(self) -> C {
        let Self { connector } = self;
        connector
    }
}
385
386impl<C> Service<TlsConnectRequest> for TlsConnectorService<C>
387where
388 C: TlsConnector,
389{
390 type Response = BoxConnection;
391 type Error = WireError;
392 type Future = BoxFuture<Result<Self::Response, Self::Error>>;
393
394 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
395 Poll::Ready(Ok(()))
396 }
397
398 fn call(&mut self, request: TlsConnectRequest) -> Self::Future {
399 self.connector
400 .connect(request.ctx, request.uri, request.stream)
401 }
402}
403
#[cfg(test)]
mod tests {
    //! Unit tests for the `Connected` builder and for the adapters that bridge
    //! the resolver/connector traits and Tower services in both directions.

    use std::io;
    use std::net::SocketAddr;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};
    use std::task::{Context, Poll};
    use std::time::Duration;

    use http::Request;
    use hyper::Uri;
    use tower::{service_fn, Service, ServiceExt};

    use super::{
        BoxConnection, Connected, Connection, ConnectionInfo, DnsRequest, DnsResolver,
        DnsResolverService, TcpConnectRequest, TcpConnector, TcpConnectorService,
        TlsConnectRequest, TlsConnector, TlsConnectorService, TowerDnsResolver, TowerTcpConnector,
        TowerTlsConnector,
    };
    use crate::{
        BoxFuture, CallContext, NoopEventListenerFactory, RequestBody,
        SharedEventListenerFactory, WireError,
    };

    /// Builds a `CallContext` for a request to `http://example.com/` whose body
    /// is `RequestBody::absent()`, using the no-op event listener factory.
    fn make_call_context() -> CallContext {
        let factory: SharedEventListenerFactory = Arc::new(NoopEventListenerFactory);
        let request = Request::builder()
            .uri("http://example.com/")
            .body(RequestBody::absent())
            .expect("request");
        CallContext::from_factory(&factory, &request, None)
    }

    /// Boxes a `NoopConnection` as the stand-in stream used by the connector tests.
    fn dummy_connection() -> BoxConnection {
        let stream = NoopConnection;
        Box::new(stream)
    }

    /// Stream stub: reads complete immediately without filling the buffer,
    /// writes claim to accept every byte, flush and shutdown succeed at once.
    struct NoopConnection;

    impl hyper::rt::Read for NoopConnection {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: hyper::rt::ReadBufCursor<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    impl hyper::rt::Write for NoopConnection {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            let accepted = buf.len();
            Poll::Ready(Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    impl Connection for NoopConnection {
        fn connected(&self) -> Connected {
            Connected::new()
        }
    }

    #[test]
    fn connected_negotiated_h2_tracks_requested_state() {
        let enabled = Connected::new().negotiated_h2(true);
        let disabled = Connected::new().negotiated_h2(false);

        assert!(enabled.is_negotiated_h2());
        assert!(!disabled.is_negotiated_h2());
    }

    #[test]
    fn connected_connection_info_or_default_preserves_explicit_metadata() {
        let expected = ConnectionInfo {
            id: crate::next_connection_id(),
            remote_addr: Some(SocketAddr::from(([192, 0, 2, 10], 443))),
            local_addr: Some(SocketAddr::from(([192, 0, 2, 20], 50000))),
            tls: true,
        };

        let connected = Connected::new().info(expected.clone());
        let actual = connected.connection_info_or_default();

        assert_eq!(actual.id, expected.id);
        assert_eq!(actual.remote_addr, expected.remote_addr);
        assert_eq!(actual.local_addr, expected.local_addr);
        assert_eq!(actual.tls, expected.tls);
    }

    #[test]
    fn connected_connection_info_or_default_falls_back_to_placeholder() {
        let placeholder = Connected::new().connection_info_or_default();

        assert_eq!(placeholder.remote_addr, None);
        assert_eq!(placeholder.local_addr, None);
        assert!(!placeholder.tls);
    }

    #[tokio::test]
    async fn tower_dns_resolver_calls_service() {
        // Shared log of the (host, port) pairs the service was asked to resolve.
        let seen = Arc::new(Mutex::new(Vec::new()));
        let recorder = Arc::clone(&seen);
        let resolver = TowerDnsResolver::new(service_fn(move |request: DnsRequest| {
            let recorder = Arc::clone(&recorder);
            async move {
                recorder
                    .lock()
                    .expect("dns calls")
                    .push((request.host, request.port));
                Ok::<_, WireError>(vec![SocketAddr::from(([127, 0, 0, 1], 443))])
            }
        }));

        let resolved = resolver
            .resolve(make_call_context(), "example.com".to_owned(), 443)
            .await
            .expect("resolved addrs");

        assert_eq!(
            seen.lock().expect("dns calls").as_slice(),
            &[("example.com".to_owned(), 443)]
        );
        assert_eq!(resolved, vec![SocketAddr::from(([127, 0, 0, 1], 443))]);
    }

    /// Resolver stub that checks its inputs and returns one fixed address.
    struct StaticDnsResolver;

    impl DnsResolver for StaticDnsResolver {
        fn resolve(
            &self,
            _ctx: CallContext,
            host: String,
            port: u16,
        ) -> BoxFuture<Result<Vec<SocketAddr>, WireError>> {
            Box::pin(async move {
                assert_eq!(host, "resolver.test");
                assert_eq!(port, 8443);
                Ok(vec![SocketAddr::from(([192, 0, 2, 10], 8443))])
            })
        }
    }

    #[tokio::test]
    async fn dns_resolver_service_calls_resolver() {
        let mut service = DnsResolverService::new(StaticDnsResolver);
        let request = DnsRequest::new(make_call_context(), "resolver.test".to_owned(), 8443);

        let resolved = service
            .ready()
            .await
            .expect("service ready")
            .call(request)
            .await
            .expect("resolved addrs");

        assert_eq!(resolved, vec![SocketAddr::from(([192, 0, 2, 10], 8443))]);
    }

    #[tokio::test]
    async fn tower_tcp_connector_calls_service() {
        // Shared log of the (address, timeout) pairs the service was asked to connect to.
        let seen = Arc::new(Mutex::new(Vec::new()));
        let recorder = Arc::clone(&seen);
        let connector = TowerTcpConnector::new(service_fn(move |request: TcpConnectRequest| {
            let recorder = Arc::clone(&recorder);
            async move {
                recorder
                    .lock()
                    .expect("tcp calls")
                    .push((request.addr, request.timeout));
                Ok::<_, WireError>(dummy_connection())
            }
        }));

        let target = SocketAddr::from(([127, 0, 0, 1], 8080));
        let limit = Some(Duration::from_secs(5));

        connector
            .connect(make_call_context(), target, limit)
            .await
            .expect("tcp stream");

        assert_eq!(
            seen.lock().expect("tcp calls").as_slice(),
            &[(target, limit)]
        );
    }

    /// Connector stub that checks its inputs and returns a dummy stream.
    struct StaticTcpConnector;

    impl TcpConnector for StaticTcpConnector {
        fn connect(
            &self,
            _ctx: CallContext,
            addr: SocketAddr,
            timeout: Option<Duration>,
        ) -> BoxFuture<Result<BoxConnection, WireError>> {
            Box::pin(async move {
                assert_eq!(addr, SocketAddr::from(([192, 0, 2, 20], 80)));
                assert_eq!(timeout, Some(Duration::from_secs(1)));
                Ok(dummy_connection())
            })
        }
    }

    #[tokio::test]
    async fn tcp_connector_service_calls_connector() {
        let mut service = TcpConnectorService::new(StaticTcpConnector);
        let request = TcpConnectRequest::new(
            make_call_context(),
            SocketAddr::from(([192, 0, 2, 20], 80)),
            Some(Duration::from_secs(1)),
        );

        service
            .ready()
            .await
            .expect("service ready")
            .call(request)
            .await
            .expect("tcp stream");
    }

    #[tokio::test]
    async fn tower_tls_connector_calls_service() {
        // Shared log of the URIs (as strings) the service was asked to handle.
        let seen = Arc::new(Mutex::new(Vec::new()));
        let recorder = Arc::clone(&seen);
        let connector = TowerTlsConnector::new(service_fn(move |request: TlsConnectRequest| {
            let recorder = Arc::clone(&recorder);
            async move {
                recorder
                    .lock()
                    .expect("tls calls")
                    .push(request.uri.to_string());
                // Hand the incoming stream straight back as the "secured" one.
                Ok::<_, WireError>(request.stream)
            }
        }));

        connector
            .connect(
                make_call_context(),
                "https://tls.test/".parse().expect("uri"),
                dummy_connection(),
            )
            .await
            .expect("tls stream");

        assert_eq!(
            seen.lock().expect("tls calls").as_slice(),
            &["https://tls.test/".to_owned()]
        );
    }

    /// Connector stub that checks the URI and returns the stream unchanged.
    struct StaticTlsConnector;

    impl TlsConnector for StaticTlsConnector {
        fn connect(
            &self,
            _ctx: CallContext,
            uri: Uri,
            stream: BoxConnection,
        ) -> BoxFuture<Result<BoxConnection, WireError>> {
            Box::pin(async move {
                assert_eq!(uri, "https://service.test/".parse::<Uri>().expect("uri"));
                Ok(stream)
            })
        }
    }

    #[tokio::test]
    async fn tls_connector_service_calls_connector() {
        let mut service = TlsConnectorService::new(StaticTlsConnector);
        let request = TlsConnectRequest::new(
            make_call_context(),
            "https://service.test/".parse().expect("uri"),
            dummy_connection(),
        );

        service
            .ready()
            .await
            .expect("service ready")
            .call(request)
            .await
            .expect("tls stream");
    }
}