1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::{Arc, Mutex, Weak};
3use std::time::Duration;
4
5use bytes::Bytes;
6use futures_channel::oneshot;
7use futures_util::future::{select, Either};
8use futures_util::task::AtomicWaker;
9use http::header::PROXY_AUTHORIZATION;
10use http::{Request, Response};
11use http_body::{Body, Frame, SizeHint};
12use http_body_util::BodyExt;
13use hyper::rt::Timer;
14use openwire_core::{
15 Authenticator, BoxTaskHandle, BoxWireService, CallContext, CookieJar, DnsResolver,
16 EventListenerFactory, Exchange, InterceptorLayer, NoopEventListenerFactory, RedirectPolicy,
17 RequestBody, ResponseBody, RetryPolicy, SharedEventListenerFactory, SharedInterceptor,
18 SharedTimer, TcpConnector, TlsConnector, WireError, WireExecutor,
19};
20use openwire_tokio::{SystemDnsResolver, TokioExecutor, TokioTcpConnector, TokioTimer};
21use pin_project_lite::pin_project;
22use tower::layer::Layer;
23use tower::util::BoxCloneSyncService;
24use tower::Service;
25use tracing::instrument::WithSubscriber;
26use tracing::Instrument;
27
28use crate::auth::SharedAuthenticator;
29use crate::bridge::BridgeInterceptor;
30use crate::connection::{
31 Address, CachedAddresses, ConnectionPool, DefaultRoutePlanner, ExchangeFinder, PoolSettings,
32 RequestAdmissionLimiter, RequestAdmissionPermit, ResolvedAddress, RoutePlanner,
33};
34use crate::cookie::SharedCookieJar;
35use crate::policy::{
36 AuthPolicyConfig, FollowUpPolicyService, PolicyConfig, RedirectPolicyConfig, RetryPolicyConfig,
37};
38use crate::proxy::{
39 resolved_proxy_candidates_with_sticky, ProxyRules, ProxySelector, SelectedProxy,
40 SharedProxySelector,
41};
42use crate::sync_util::lock_mutex;
43use crate::transport::{ConnectorStack, TransportService, TransportServiceInit};
44
/// HTTP client handle.
///
/// Cloning is cheap: every clone shares the same `ClientInner` through an `Arc`.
#[derive(Clone)]
pub struct Client {
    inner: Arc<ClientInner>,
}
49
/// State shared by every clone of a [`Client`].
struct ClientInner {
    /// Factory used to create the per-call event listener (see `CallContext::from_factory`).
    event_listener_factory: SharedEventListenerFactory,
    /// Executor used to spawn queued calls; also handed to the pool reaper and transport.
    executor: Arc<dyn WireExecutor>,
    /// Timer used to enforce call deadlines.
    timer: SharedTimer,
    /// Client-wide request defaults; per-call `CallOptions` are layered on top at execution.
    request_config: EffectiveRequestConfig,
    /// Fully assembled interceptor / follow-up policy / transport service chain.
    service: BoxWireService,
    /// Controller for the pool reaper task; aborted from `Drop for ClientInner`.
    pool_reaper: Arc<PoolReaperController>,
    #[cfg(feature = "websocket")]
    connector: ConnectorStack,
    #[cfg(feature = "websocket")]
    proxy_selector: SharedProxySelector,
}
62
/// A single prepared request that may be executed once, either inline (`execute`) or on the
/// client's executor (`enqueue`).
pub struct Call {
    client: Client,
    request: Request<RequestBody>,
    /// Per-call overrides layered over the client defaults when the call runs.
    options: CallOptions,
    /// Executed/canceled flags shared with every [`CallHandle`] obtained from this call.
    state: Arc<CallState>,
}
69
/// Cloneable handle for cancelling or inspecting a [`Call`] from another place or task.
#[derive(Clone)]
pub struct CallHandle {
    state: Arc<CallState>,
}
79
/// A [`Call`] that has been spawned on the client's executor via `Call::enqueue`.
pub struct QueuedCall {
    handle: CallHandle,
    /// Receives the call's result from the spawned task.
    receiver: oneshot::Receiver<Result<Response<ResponseBody>, WireError>>,
    /// Handle of the spawned task.
    // NOTE(review): whether dropping a `BoxTaskHandle` detaches or aborts the task is defined
    // outside this file — confirm before relying on drop semantics.
    _task: BoxTaskHandle,
}
89
/// Transport-level settings collected by [`ClientBuilder`] and handed to the transport service.
#[derive(Clone)]
pub(crate) struct TransportConfig {
    /// Connect timeout; `None` by default.
    pub(crate) connect_timeout: Option<Duration>,
    /// Idle timeout for pooled connections (default 300s). When set, `build` wires up the
    /// pool reaper.
    pub(crate) pool_idle_timeout: Option<Duration>,
    /// Maximum idle connections kept per address (default 5).
    pub(crate) pool_max_idle_per_host: usize,
    pub(crate) http2_keep_alive_interval: Option<Duration>,
    pub(crate) http2_keep_alive_while_idle: bool,
    /// Connection-count limits; both default to `usize::MAX`.
    pub(crate) max_connections_total: usize,
    pub(crate) max_connections_per_host: usize,
    /// Request-admission limits fed to `RequestAdmissionLimiter` (defaults 64 and 5).
    pub(crate) max_requests_total: usize,
    pub(crate) max_requests_per_host: usize,
}
102
/// Per-call overrides of the client defaults.
///
/// Every field is optional; `None` means "use the client-wide default" when the call is
/// resolved into an `EffectiveRequestConfig`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CallOptions {
    call_timeout: Option<Duration>,
    connect_timeout: Option<Duration>,
    follow_redirects: Option<bool>,
    max_redirects: Option<usize>,
    retry_on_connection_failure: Option<bool>,
    max_retries: Option<usize>,
    retry_canceled_requests: Option<bool>,
    allow_insecure_redirects: Option<bool>,
}
114
/// Fully resolved request settings: client defaults with any [`CallOptions`] overrides applied.
///
/// `Call::execute_marked` inserts a copy into the request extensions so downstream layers
/// can read it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) struct EffectiveRequestConfig {
    pub(crate) call_timeout: Option<Duration>,
    pub(crate) connect_timeout: Option<Duration>,
    pub(crate) follow_redirects: bool,
    pub(crate) max_redirects: usize,
    pub(crate) retry_on_connection_failure: bool,
    pub(crate) max_retries: usize,
    pub(crate) retry_canceled_requests: bool,
    pub(crate) allow_insecure_redirects: bool,
}
126
/// Builder for [`Client`]; start from [`ClientBuilder::new`] or [`Client::builder`].
pub struct ClientBuilder {
    /// Interceptors wrapped around the follow-up policy layer.
    application_interceptors: Vec<SharedInterceptor>,
    /// Interceptors wrapped around the transport (inside the bridge interceptor).
    network_interceptors: Vec<SharedInterceptor>,
    event_listener_factory: SharedEventListenerFactory,
    executor: Arc<dyn WireExecutor>,
    timer: SharedTimer,
    call_timeout: Option<Duration>,
    transport: TransportConfig,
    policy: PolicyConfig,
    dns_resolver: Arc<dyn DnsResolver>,
    tcp_connector: Arc<dyn TcpConnector>,
    /// `None` means "use the default" (a rustls connector when `tls-rustls` is enabled).
    tls_connector: Option<Arc<dyn TlsConnector>>,
    route_planner: Arc<dyn RoutePlanner>,
    proxy_selector: SharedProxySelector,
}
142
/// Lower bound for the pool reaper's cadence (5 seconds).
// NOTE(review): the consumer lives outside this view; presumably the sweep interval is
// clamped to `[MIN_POOL_REAPER_CADENCE, MAX_POOL_REAPER_CADENCE]`.
const MIN_POOL_REAPER_CADENCE: Duration = Duration::from_millis(5_000);
/// Upper bound for the pool reaper's cadence (one minute).
const MAX_POOL_REAPER_CADENCE: Duration = Duration::from_millis(60_000);
145
/// Owns the (lazily started) pool reaper task and guards its handle with a mutex.
#[derive(Default)]
struct PoolReaperController {
    state: Mutex<PoolReaperState>,
}
150
/// Mutex-protected state of [`PoolReaperController`].
#[derive(Default)]
struct PoolReaperState {
    /// Handle of the running reaper task; `None` until started or after `abort`.
    handle: Option<BoxTaskHandle>,
}
155
/// Flags shared between a [`Call`] and its [`CallHandle`]s.
#[derive(Default)]
struct CallState {
    /// Flipped from `false` to `true` exactly once by `Call::mark_executed`.
    executed: AtomicBool,
    /// Set by `CallState::cancel`.
    canceled: AtomicBool,
    /// Waker registered by `poll_canceled`, woken on cancellation.
    waker: AtomicWaker,
}
162
163impl ClientBuilder {
164 pub fn new() -> Self {
165 Self::default()
166 }
167
168 pub fn application_interceptor<I>(mut self, interceptor: I) -> Self
169 where
170 I: openwire_core::Interceptor,
171 {
172 self.application_interceptors.push(Arc::new(interceptor));
173 self
174 }
175
176 pub fn network_interceptor<I>(mut self, interceptor: I) -> Self
177 where
178 I: openwire_core::Interceptor,
179 {
180 self.network_interceptors.push(Arc::new(interceptor));
181 self
182 }
183
184 pub fn event_listener_factory<F>(mut self, factory: F) -> Self
185 where
186 F: EventListenerFactory,
187 {
188 self.event_listener_factory = Arc::new(factory);
189 self
190 }
191
192 pub fn executor<E>(mut self, executor: E) -> Self
193 where
194 E: WireExecutor,
195 {
196 self.executor = Arc::new(executor);
197 self
198 }
199
200 pub fn timer<T>(mut self, timer: T) -> Self
201 where
202 T: Timer + Send + Sync + 'static,
203 {
204 self.timer = SharedTimer::new(timer);
205 self
206 }
207
208 pub fn dns_resolver<R>(mut self, resolver: R) -> Self
209 where
210 R: DnsResolver,
211 {
212 self.dns_resolver = Arc::new(resolver);
213 self
214 }
215
216 pub fn tcp_connector<C>(mut self, connector: C) -> Self
217 where
218 C: TcpConnector,
219 {
220 self.tcp_connector = Arc::new(connector);
221 self
222 }
223
224 pub fn tls_connector<T>(mut self, connector: T) -> Self
225 where
226 T: TlsConnector,
227 {
228 self.tls_connector = Some(Arc::new(connector));
229 self
230 }
231
232 pub fn route_planner<P>(mut self, planner: P) -> Self
233 where
234 P: RoutePlanner,
235 {
236 self.route_planner = Arc::new(planner);
237 self
238 }
239
240 pub fn cookie_jar<J>(mut self, jar: J) -> Self
242 where
243 J: CookieJar,
244 {
245 self.policy.cookie_jar = Some(Arc::new(jar) as SharedCookieJar);
246 self
247 }
248
249 pub fn proxy_selector<S>(mut self, selector: S) -> Self
251 where
252 S: ProxySelector,
253 {
254 self.proxy_selector = Arc::new(selector);
255 self
256 }
257
258 pub fn authenticator<A>(mut self, authenticator: A) -> Self
259 where
260 A: Authenticator,
261 {
262 self.policy.auth.authenticator = Some(Arc::new(authenticator) as SharedAuthenticator);
263 self
264 }
265
266 pub fn proxy_authenticator<A>(mut self, authenticator: A) -> Self
267 where
268 A: Authenticator,
269 {
270 self.policy.auth.proxy_authenticator = Some(Arc::new(authenticator) as SharedAuthenticator);
271 self
272 }
273
274 pub fn max_auth_attempts(mut self, max_auth_attempts: usize) -> Self {
276 self.policy.auth.max_auth_attempts = max_auth_attempts;
277 self
278 }
279
280 pub fn call_timeout(mut self, timeout: Duration) -> Self {
281 self.call_timeout = Some(timeout);
282 self
283 }
284
285 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
286 self.transport.connect_timeout = Some(timeout);
287 self
288 }
289
290 pub fn follow_redirects(mut self, enabled: bool) -> Self {
291 self.policy
292 .redirect
293 .default_mut()
294 .set_follow_redirects(enabled);
295 self
296 }
297
298 pub fn max_redirects(mut self, max_redirects: usize) -> Self {
299 self.policy
300 .redirect
301 .default_mut()
302 .set_max_redirects(max_redirects);
303 self
304 }
305
306 pub fn allow_insecure_redirects(mut self, enabled: bool) -> Self {
307 self.policy
308 .redirect
309 .default_mut()
310 .set_allow_insecure_redirects(enabled);
311 self
312 }
313
314 pub fn retry_on_connection_failure(mut self, enabled: bool) -> Self {
315 self.policy
316 .retry
317 .default_mut()
318 .set_retry_on_connection_failure(enabled);
319 self
320 }
321
322 pub fn max_retries(mut self, max_retries: usize) -> Self {
323 self.policy.retry.default_mut().set_max_retries(max_retries);
324 self
325 }
326
327 pub fn pool_idle_timeout(mut self, timeout: Duration) -> Self {
328 self.transport.pool_idle_timeout = Some(timeout);
329 self
330 }
331
332 pub fn pool_max_idle_per_host(mut self, max_idle: usize) -> Self {
333 self.transport.pool_max_idle_per_host = max_idle;
334 self
335 }
336
337 pub fn http2_keep_alive_interval(mut self, interval: Duration) -> Self {
338 self.transport.http2_keep_alive_interval = Some(interval);
339 self
340 }
341
342 pub fn http2_keep_alive_while_idle(mut self, enabled: bool) -> Self {
343 self.transport.http2_keep_alive_while_idle = enabled;
344 self
345 }
346
347 pub fn max_connections_total(mut self, max_connections: usize) -> Self {
348 self.transport.max_connections_total = max_connections;
349 self
350 }
351
352 pub fn max_connections_per_host(mut self, max_connections: usize) -> Self {
353 self.transport.max_connections_per_host = max_connections;
354 self
355 }
356
357 pub fn max_requests_total(mut self, max_requests: usize) -> Self {
358 self.transport.max_requests_total = max_requests;
359 self
360 }
361
362 pub fn max_requests_per_host(mut self, max_requests: usize) -> Self {
363 self.transport.max_requests_per_host = max_requests;
364 self
365 }
366
367 pub fn retry_canceled_requests(mut self, enabled: bool) -> Self {
368 self.policy
369 .retry
370 .default_mut()
371 .set_retry_canceled_requests(enabled);
372 self
373 }
374
375 pub fn retry_policy<P>(mut self, policy: P) -> Self
376 where
377 P: RetryPolicy,
378 {
379 self.policy.retry.set_custom(policy);
380 self
381 }
382
383 pub fn redirect_policy<P>(mut self, policy: P) -> Self
384 where
385 P: RedirectPolicy,
386 {
387 self.policy.redirect.set_custom(policy);
388 self
389 }
390
391 pub fn build(self) -> Result<Client, WireError> {
392 #[cfg(feature = "tls-rustls")]
393 let tls_connector = match self.tls_connector {
394 Some(tls_connector) => Some(tls_connector),
395 None => Some(
396 Arc::new(openwire_rustls::RustlsTlsConnector::builder().build()?)
397 as Arc<dyn TlsConnector>,
398 ),
399 };
400
401 #[cfg(not(feature = "tls-rustls"))]
402 let tls_connector = self.tls_connector;
403
404 let request_config = EffectiveRequestConfig::from_defaults(
405 self.call_timeout,
406 self.transport.connect_timeout,
407 &self.policy,
408 );
409
410 let pool = Arc::new(ConnectionPool::new(PoolSettings {
411 idle_timeout: self.transport.pool_idle_timeout,
412 max_idle_per_address: self.transport.pool_max_idle_per_host,
413 }));
414 let pool_reaper = Arc::new(PoolReaperController::default());
415 let on_pooled_connection_published = if pool.settings().idle_timeout.is_some() {
416 let reaper = pool_reaper.clone();
417 let executor = self.executor.clone();
418 let timer = self.timer.clone();
419 let weak_pool = Arc::downgrade(&pool);
420 Some(Arc::new(move || {
421 reaper.ensure_started(executor.clone(), timer.clone(), weak_pool.clone());
422 }) as Arc<dyn Fn() + Send + Sync>)
423 } else {
424 None
425 };
426 let proxy_selector = self.proxy_selector;
427 let request_admission = RequestAdmissionLimiter::new(
428 self.transport.max_requests_total,
429 self.transport.max_requests_per_host,
430 );
431 let exchange_finder = Arc::new(ExchangeFinder::new(pool, proxy_selector.clone()));
432 let connector = ConnectorStack {
433 dns_resolver: self.dns_resolver,
434 tcp_connector: self.tcp_connector,
435 tls_connector,
436 connect_timeout: self.transport.connect_timeout,
437 executor: self.executor.clone(),
438 timer: self.timer.clone(),
439 route_planner: self.route_planner,
440 proxy_authenticator: self.policy.auth.proxy_authenticator.clone(),
441 max_proxy_auth_attempts: self.policy.auth.max_auth_attempts,
442 };
443
444 let transport = TransportService::new(TransportServiceInit {
445 connector: connector.clone(),
446 config: self.transport.clone(),
447 executor: self.executor.clone(),
448 timer: self.timer.clone(),
449 exchange_finder,
450 request_admission,
451 proxy_selector: proxy_selector.clone(),
452 on_pooled_connection_published,
453 });
454 let service = build_service_chain(
455 transport,
456 self.application_interceptors,
457 self.network_interceptors,
458 self.policy.clone(),
459 );
460
461 Ok(Client {
462 inner: Arc::new(ClientInner {
463 event_listener_factory: self.event_listener_factory,
464 executor: self.executor,
465 timer: self.timer,
466 request_config,
467 service,
468 pool_reaper,
469 #[cfg(feature = "websocket")]
470 connector,
471 #[cfg(feature = "websocket")]
472 proxy_selector,
473 }),
474 })
475 }
476}
477
478impl Default for ClientBuilder {
479 fn default() -> Self {
480 Self {
481 application_interceptors: Vec::new(),
482 network_interceptors: Vec::new(),
483 event_listener_factory: Arc::new(NoopEventListenerFactory),
484 executor: Arc::new(TokioExecutor::new()),
485 timer: SharedTimer::new(TokioTimer::new()),
486 call_timeout: None,
487 transport: TransportConfig {
488 connect_timeout: None,
489 pool_idle_timeout: Some(Duration::from_secs(300)),
490 pool_max_idle_per_host: 5,
491 http2_keep_alive_interval: None,
492 http2_keep_alive_while_idle: false,
493 max_connections_total: usize::MAX,
494 max_connections_per_host: usize::MAX,
495 max_requests_total: 64,
496 max_requests_per_host: 5,
497 },
498 policy: PolicyConfig {
499 cookie_jar: None,
500 auth: AuthPolicyConfig {
501 authenticator: None,
502 proxy_authenticator: None,
503 max_auth_attempts: 3,
504 },
505 retry: RetryPolicyConfig::default(),
506 redirect: RedirectPolicyConfig::default(),
507 },
508 dns_resolver: Arc::new(SystemDnsResolver),
509 tcp_connector: Arc::new(TokioTcpConnector),
510 tls_connector: None,
511 route_planner: Arc::new(DefaultRoutePlanner::default()),
512 proxy_selector: Arc::new(ProxyRules::new()),
513 }
514 }
515}
516
517impl Client {
518 pub fn builder() -> ClientBuilder {
519 ClientBuilder::new()
520 }
521
522 pub fn new_call(&self, request: Request<RequestBody>) -> Call {
523 Call {
524 client: self.clone(),
525 request,
526 options: CallOptions::default(),
527 state: Arc::new(CallState::default()),
528 }
529 }
530
531 pub async fn execute(
532 &self,
533 request: Request<RequestBody>,
534 ) -> Result<Response<ResponseBody>, WireError> {
535 self.new_call(request).execute().await
536 }
537
538 #[cfg(feature = "websocket")]
539 pub fn new_websocket(&self, request: Request<RequestBody>) -> WebSocketCall<'_> {
540 WebSocketCall {
541 client: self,
542 request,
543 handshake_timeout: None,
544 close_timeout: None,
545 max_frame_size: None,
546 max_message_size: None,
547 send_queue_size: None,
548 ping_interval: None,
549 pong_timeout: None,
550 subprotocols: Vec::new(),
551 deliver_control_frames: false,
552 engine: None,
553 }
554 }
555
556 #[cfg(feature = "websocket")]
557 pub(crate) fn ws_connector(&self) -> &ConnectorStack {
558 &self.inner.connector
559 }
560
561 #[cfg(feature = "websocket")]
562 pub(crate) fn ws_proxy_selector(&self) -> &SharedProxySelector {
563 &self.inner.proxy_selector
564 }
565
566 #[cfg(feature = "websocket")]
567 pub(crate) fn event_listener_factory(&self) -> &SharedEventListenerFactory {
568 &self.inner.event_listener_factory
569 }
570}
571
/// A prepared WebSocket handshake bound to a [`Client`]; configure it, then call `execute`.
///
/// `None`/empty fields mean "not configured here"; how they are defaulted is decided by
/// `crate::websocket::transport::execute`.
#[cfg(feature = "websocket")]
pub struct WebSocketCall<'a> {
    pub(crate) client: &'a Client,
    pub(crate) request: Request<RequestBody>,
    pub(crate) handshake_timeout: Option<Duration>,
    pub(crate) close_timeout: Option<Duration>,
    pub(crate) max_frame_size: Option<usize>,
    pub(crate) max_message_size: Option<usize>,
    pub(crate) send_queue_size: Option<usize>,
    pub(crate) ping_interval: Option<Duration>,
    pub(crate) pong_timeout: Option<Duration>,
    pub(crate) subprotocols: Vec<String>,
    pub(crate) deliver_control_frames: bool,
    pub(crate) engine: Option<openwire_core::websocket::SharedWebSocketEngine>,
}
587
#[cfg(feature = "websocket")]
impl<'a> WebSocketCall<'a> {
    /// Sets the handshake timeout.
    pub fn handshake_timeout(self, value: Duration) -> Self {
        Self {
            handshake_timeout: Some(value),
            ..self
        }
    }

    /// Sets the close timeout.
    pub fn close_timeout(self, value: Duration) -> Self {
        Self {
            close_timeout: Some(value),
            ..self
        }
    }

    /// Sets the maximum frame size.
    pub fn max_frame_size(self, value: usize) -> Self {
        Self {
            max_frame_size: Some(value),
            ..self
        }
    }

    /// Sets the maximum message size.
    pub fn max_message_size(self, value: usize) -> Self {
        Self {
            max_message_size: Some(value),
            ..self
        }
    }

    /// Sets the send queue size.
    pub fn send_queue_size(self, value: usize) -> Self {
        Self {
            send_queue_size: Some(value),
            ..self
        }
    }

    /// Sets the ping interval.
    pub fn ping_interval(self, value: Duration) -> Self {
        Self {
            ping_interval: Some(value),
            ..self
        }
    }

    /// Sets the pong timeout.
    pub fn pong_timeout(self, value: Duration) -> Self {
        Self {
            pong_timeout: Some(value),
            ..self
        }
    }

    /// Replaces (does not append to) the requested subprotocol list.
    pub fn subprotocols<I, S>(self, protocols: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            subprotocols: protocols.into_iter().map(Into::into).collect(),
            ..self
        }
    }

    /// Sets whether control frames are delivered to the caller.
    pub fn deliver_control_frames(self, on: bool) -> Self {
        Self {
            deliver_control_frames: on,
            ..self
        }
    }

    /// Uses a specific WebSocket engine for this call.
    pub fn engine(self, engine: openwire_core::websocket::SharedWebSocketEngine) -> Self {
        Self {
            engine: Some(engine),
            ..self
        }
    }

    /// Performs the handshake through `crate::websocket::transport::execute`.
    pub async fn execute(
        self,
    ) -> Result<crate::websocket::WebSocket, openwire_core::websocket::WebSocketError> {
        crate::websocket::transport::execute(self).await
    }
}
650
651impl Drop for ClientInner {
652 fn drop(&mut self) {
653 self.pool_reaper.abort();
654 }
655}
656
657impl PoolReaperController {
658 fn ensure_started(
659 &self,
660 executor: Arc<dyn WireExecutor>,
661 timer: SharedTimer,
662 pool: Weak<ConnectionPool>,
663 ) {
664 let mut state = lock_mutex(&self.state);
665 if state.handle.is_some() {
666 return;
667 }
668
669 let Some(pool) = pool.upgrade() else {
670 return;
671 };
672
673 match spawn_pool_reaper(executor, timer, &pool) {
674 Ok(handle) => state.handle = handle,
675 Err(error) => tracing::warn!(%error, "failed to start pool reaper task"),
676 }
677 }
678
679 fn abort(&self) {
680 if let Some(handle) = lock_mutex(&self.state).handle.take() {
681 handle.abort();
682 }
683 }
684}
685
686impl CallState {
687 fn cancel(&self) {
688 self.canceled.store(true, Ordering::Release);
689 self.waker.wake();
690 }
691
692 fn is_canceled(&self) -> bool {
693 self.canceled.load(Ordering::Acquire)
694 }
695
696 fn is_executed(&self) -> bool {
697 self.executed.load(Ordering::Acquire)
698 }
699
700 fn poll_canceled(&self, cx: &mut std::task::Context<'_>) -> bool {
701 if self.is_canceled() {
702 return true;
703 }
704
705 self.waker.register(cx.waker());
706 self.is_canceled()
707 }
708}
709
710impl Call {
711 pub fn handle(&self) -> CallHandle {
714 CallHandle {
715 state: self.state.clone(),
716 }
717 }
718
719 pub fn cancel(&self) {
721 self.state.cancel();
722 }
723
724 pub fn is_canceled(&self) -> bool {
725 self.state.is_canceled()
726 }
727
728 pub fn is_executed(&self) -> bool {
729 self.state.is_executed()
730 }
731
732 pub fn try_clone(&self) -> Option<Self> {
735 Some(Self {
736 client: self.client.clone(),
737 request: clone_request(&self.request)?,
738 options: self.options,
739 state: Arc::new(CallState::default()),
740 })
741 }
742
743 pub fn options(mut self, options: CallOptions) -> Self {
744 self.options.apply(options);
745 self
746 }
747
748 pub fn call_timeout(mut self, timeout: Duration) -> Self {
749 self.options.call_timeout = Some(timeout);
750 self
751 }
752
753 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
754 self.options.connect_timeout = Some(timeout);
755 self
756 }
757
758 pub fn follow_redirects(mut self, enabled: bool) -> Self {
759 self.options.follow_redirects = Some(enabled);
760 self
761 }
762
763 pub fn max_redirects(mut self, max_redirects: usize) -> Self {
764 self.options.max_redirects = Some(max_redirects);
765 self
766 }
767
768 pub fn retry_on_connection_failure(mut self, enabled: bool) -> Self {
769 self.options.retry_on_connection_failure = Some(enabled);
770 self
771 }
772
773 pub fn max_retries(mut self, max_retries: usize) -> Self {
774 self.options.max_retries = Some(max_retries);
775 self
776 }
777
778 pub fn retry_canceled_requests(mut self, enabled: bool) -> Self {
779 self.options.retry_canceled_requests = Some(enabled);
780 self
781 }
782
783 pub fn allow_insecure_redirects(mut self, enabled: bool) -> Self {
784 self.options.allow_insecure_redirects = Some(enabled);
785 self
786 }
787
788 pub fn enqueue(self) -> Result<QueuedCall, WireError> {
791 self.mark_executed()?;
792 let handle = self.handle();
793 let executor = self.client.inner.executor.clone();
794 let (sender, receiver) = oneshot::channel();
795 let task = executor.spawn(Box::pin(async move {
796 let result = self.execute_marked().await;
797 let _ = sender.send(result);
798 }))?;
799
800 Ok(QueuedCall {
801 handle,
802 receiver,
803 _task: task,
804 })
805 }
806
807 pub async fn execute(self) -> Result<Response<ResponseBody>, WireError> {
808 self.mark_executed()?;
809 self.execute_marked().await
810 }
811
812 fn mark_executed(&self) -> Result<(), WireError> {
813 if self
814 .state
815 .executed
816 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
817 .is_ok()
818 {
819 Ok(())
820 } else {
821 Err(WireError::invalid_request("call has already been executed"))
822 }
823 }
824
825 async fn execute_marked(mut self) -> Result<Response<ResponseBody>, WireError> {
826 let request_config = self
827 .client
828 .inner
829 .request_config
830 .with_overrides(self.options);
831 self.request.extensions_mut().insert(request_config);
832 self.request.extensions_mut().insert(self.options);
833 let ctx = CallContext::from_factory(
834 &self.client.inner.event_listener_factory,
835 &self.request,
836 request_config.call_timeout,
837 );
838
839 let span = tracing::info_span!(
840 "openwire.call",
841 call_id = ctx.call_id().as_u64(),
842 method = %self.request.method(),
843 uri = %self.request.uri(),
844 );
845
846 async move {
847 ctx.listener().call_start(&ctx, &self.request);
848
849 let mut service = self.client.inner.service.clone();
850 let execute_ctx = ctx.clone();
851 let execute = async move {
852 tower::ServiceExt::ready(&mut service)
853 .await
854 .map_err(|error| WireError::internal("service chain not ready", error))?;
855 service
856 .call(Exchange::new(self.request, execute_ctx, 1))
857 .await
858 };
859
860 let execute =
861 with_call_deadline(self.client.inner.timer.clone(), ctx.deadline(), execute);
862 let result = with_call_cancellation(self.state.clone(), execute).await;
863
864 match result {
865 Ok(response) => Ok(attach_call_lifecycle(
866 response,
867 ctx.clone(),
868 self.state.clone(),
869 )),
870 Err(error) => {
871 ctx.listener().call_failed(&ctx, &error);
872 Err(error)
873 }
874 }
875 }
876 .instrument(span)
877 .with_current_subscriber()
878 .await
879 }
880}
881
882impl CallHandle {
883 pub fn cancel(&self) {
884 self.state.cancel();
885 }
886
887 pub fn is_canceled(&self) -> bool {
888 self.state.is_canceled()
889 }
890
891 pub fn is_executed(&self) -> bool {
892 self.state.is_executed()
893 }
894}
895
896impl QueuedCall {
897 pub fn handle(&self) -> CallHandle {
898 self.handle.clone()
899 }
900
901 pub fn cancel(&self) {
902 self.handle.cancel();
903 }
904
905 pub fn is_canceled(&self) -> bool {
906 self.handle.is_canceled()
907 }
908
909 pub fn is_executed(&self) -> bool {
910 self.handle.is_executed()
911 }
912
913 pub async fn await_response(self) -> Result<Response<ResponseBody>, WireError> {
914 match self.receiver.await {
915 Ok(result) => result,
916 Err(_closed) => Err(WireError::canceled(
917 "queued call ended before producing a response",
918 )),
919 }
920 }
921}
922
923pub(crate) fn attach_request_admission(
924 response: Response<ResponseBody>,
925 permit: RequestAdmissionPermit,
926) -> Response<ResponseBody> {
927 let (parts, body) = response.into_parts();
928 Response::from_parts(
929 parts,
930 ResponseBody::new(
931 RequestAdmissionBody {
932 inner: body,
933 _permit: Some(permit),
934 }
935 .boxed(),
936 ),
937 )
938}
939
940fn attach_call_lifecycle(
941 response: Response<ResponseBody>,
942 ctx: CallContext,
943 state: Arc<CallState>,
944) -> Response<ResponseBody> {
945 let (parts, body) = response.into_parts();
946 Response::from_parts(
947 parts,
948 ResponseBody::new(CallLifecycleBody::new(body, ctx, state).boxed()),
949 )
950}
951
952fn build_service_chain(
953 transport: TransportService,
954 application_interceptors: Vec<SharedInterceptor>,
955 network_interceptors: Vec<SharedInterceptor>,
956 policy: PolicyConfig,
957) -> BoxWireService {
958 let mut network: BoxWireService = BoxCloneSyncService::new(transport);
959 for interceptor in network_interceptors.iter().rev() {
960 network =
961 BoxCloneSyncService::new(InterceptorLayer::new(interceptor.clone()).layer(network));
962 }
963 network = BoxCloneSyncService::new(
964 InterceptorLayer::new(Arc::new(BridgeInterceptor) as SharedInterceptor).layer(network),
965 );
966
967 let mut service: BoxWireService =
968 BoxCloneSyncService::new(FollowUpPolicyService::new(network, policy));
969 for interceptor in application_interceptors.iter().rev() {
970 service =
971 BoxCloneSyncService::new(InterceptorLayer::new(interceptor.clone()).layer(service));
972 }
973
974 service
975}
976
977impl CallOptions {
978 pub fn new() -> Self {
979 Self::default()
980 }
981
982 pub fn call_timeout(mut self, timeout: Duration) -> Self {
983 self.call_timeout = Some(timeout);
984 self
985 }
986
987 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
988 self.connect_timeout = Some(timeout);
989 self
990 }
991
992 pub fn follow_redirects(mut self, enabled: bool) -> Self {
993 self.follow_redirects = Some(enabled);
994 self
995 }
996
997 pub fn max_redirects(mut self, max_redirects: usize) -> Self {
998 self.max_redirects = Some(max_redirects);
999 self
1000 }
1001
1002 pub fn retry_on_connection_failure(mut self, enabled: bool) -> Self {
1003 self.retry_on_connection_failure = Some(enabled);
1004 self
1005 }
1006
1007 pub fn max_retries(mut self, max_retries: usize) -> Self {
1008 self.max_retries = Some(max_retries);
1009 self
1010 }
1011
1012 pub fn retry_canceled_requests(mut self, enabled: bool) -> Self {
1013 self.retry_canceled_requests = Some(enabled);
1014 self
1015 }
1016
1017 pub fn allow_insecure_redirects(mut self, enabled: bool) -> Self {
1018 self.allow_insecure_redirects = Some(enabled);
1019 self
1020 }
1021
1022 pub(crate) fn has_retry_overrides(self) -> bool {
1023 self.retry_on_connection_failure.is_some()
1024 || self.max_retries.is_some()
1025 || self.retry_canceled_requests.is_some()
1026 }
1027
1028 pub(crate) fn has_redirect_overrides(self) -> bool {
1029 self.follow_redirects.is_some()
1030 || self.max_redirects.is_some()
1031 || self.allow_insecure_redirects.is_some()
1032 }
1033
1034 fn apply(&mut self, other: Self) {
1035 self.call_timeout = other.call_timeout.or(self.call_timeout);
1036 self.connect_timeout = other.connect_timeout.or(self.connect_timeout);
1037 self.follow_redirects = other.follow_redirects.or(self.follow_redirects);
1038 self.max_redirects = other.max_redirects.or(self.max_redirects);
1039 self.retry_on_connection_failure = other
1040 .retry_on_connection_failure
1041 .or(self.retry_on_connection_failure);
1042 self.max_retries = other.max_retries.or(self.max_retries);
1043 self.retry_canceled_requests = other
1044 .retry_canceled_requests
1045 .or(self.retry_canceled_requests);
1046 self.allow_insecure_redirects = other
1047 .allow_insecure_redirects
1048 .or(self.allow_insecure_redirects);
1049 }
1050}
1051
1052impl EffectiveRequestConfig {
1053 fn from_defaults(
1054 call_timeout: Option<Duration>,
1055 connect_timeout: Option<Duration>,
1056 policy: &PolicyConfig,
1057 ) -> Self {
1058 let retry = policy.retry.default_config();
1059 let redirect = policy.redirect.default_config();
1060 Self {
1061 call_timeout,
1062 connect_timeout,
1063 follow_redirects: redirect.follow_redirects(),
1064 max_redirects: redirect.max_redirects(),
1065 retry_on_connection_failure: retry.retry_on_connection_failure(),
1066 max_retries: retry.max_retries(),
1067 retry_canceled_requests: retry.retry_canceled_requests(),
1068 allow_insecure_redirects: redirect.allow_insecure_redirects(),
1069 }
1070 }
1071
1072 fn with_overrides(self, options: CallOptions) -> Self {
1073 Self {
1074 call_timeout: options.call_timeout.or(self.call_timeout),
1075 connect_timeout: options.connect_timeout.or(self.connect_timeout),
1076 follow_redirects: options.follow_redirects.unwrap_or(self.follow_redirects),
1077 max_redirects: options.max_redirects.unwrap_or(self.max_redirects),
1078 retry_on_connection_failure: options
1079 .retry_on_connection_failure
1080 .unwrap_or(self.retry_on_connection_failure),
1081 max_retries: options.max_retries.unwrap_or(self.max_retries),
1082 retry_canceled_requests: options
1083 .retry_canceled_requests
1084 .unwrap_or(self.retry_canceled_requests),
1085 allow_insecure_redirects: options
1086 .allow_insecure_redirects
1087 .unwrap_or(self.allow_insecure_redirects),
1088 }
1089 }
1090}
1091
1092pub(crate) fn cache_request_addresses(
1093 request: &mut Request<RequestBody>,
1094 proxy_selector: &dyn ProxySelector,
1095) -> Result<Arc<[ResolvedAddress]>, WireError> {
1096 let previous_selected_proxy = request.extensions().get::<SelectedProxy>().cloned();
1097 let candidates = resolved_proxy_candidates_with_sticky(
1098 proxy_selector.select(request.uri())?,
1099 previous_selected_proxy.as_ref(),
1100 );
1101 clear_proxy_authorization_if_proxy_dropped_from_candidates(request, &candidates);
1102
1103 let mut addresses = Vec::new();
1104 for candidate in candidates {
1105 let resolved = ResolvedAddress::new(
1106 Address::from_uri(request.uri(), candidate.as_ref())?,
1107 candidate,
1108 );
1109 if !addresses.iter().any(|existing: &ResolvedAddress| {
1110 existing.address() == resolved.address()
1111 && existing.selected_proxy() == resolved.selected_proxy()
1112 }) {
1113 addresses.push(resolved);
1114 }
1115 }
1116 let addresses = Arc::<[ResolvedAddress]>::from(addresses);
1117 let extensions = request.extensions_mut();
1118 extensions.insert(CachedAddresses(addresses.clone()));
1119 Ok(addresses)
1120}
1121
1122pub(crate) fn clear_proxy_authorization_if_proxy_changed(
1123 request: &mut Request<RequestBody>,
1124 selected_proxy: Option<&SelectedProxy>,
1125) {
1126 let previous_selected_proxy = request.extensions().get::<SelectedProxy>().cloned();
1127 if previous_selected_proxy.is_some()
1128 && !same_selected_proxy_endpoint(previous_selected_proxy.as_ref(), selected_proxy)
1129 {
1130 request.headers_mut().remove(PROXY_AUTHORIZATION);
1131 }
1132}
1133
1134fn clear_proxy_authorization_if_proxy_dropped_from_candidates(
1135 request: &mut Request<RequestBody>,
1136 candidates: &[Option<SelectedProxy>],
1137) {
1138 let previous_selected_proxy = request.extensions().get::<SelectedProxy>().cloned();
1139 if previous_selected_proxy.is_some()
1140 && !candidates.iter().any(|candidate| {
1141 same_selected_proxy_endpoint(candidate.as_ref(), previous_selected_proxy.as_ref())
1142 })
1143 {
1144 request.headers_mut().remove(PROXY_AUTHORIZATION);
1145 }
1146}
1147
1148fn same_selected_proxy_endpoint(
1149 left: Option<&SelectedProxy>,
1150 right: Option<&SelectedProxy>,
1151) -> bool {
1152 match (left, right) {
1153 (Some(left), Some(right)) => left.same_endpoint(right),
1154 (None, None) => true,
1155 _ => false,
1156 }
1157}
1158
1159async fn with_call_deadline<F>(
1160 timer: SharedTimer,
1161 deadline: Option<std::time::Instant>,
1162 future: F,
1163) -> Result<Response<ResponseBody>, WireError>
1164where
1165 F: std::future::Future<Output = Result<Response<ResponseBody>, WireError>>,
1166{
1167 let Some(deadline) = deadline else {
1168 return future.await;
1169 };
1170
1171 let timeout = deadline.saturating_duration_since(std::time::Instant::now());
1172 let future = Box::pin(future);
1173 let sleep = timer.sleep(timeout);
1174
1175 match select(future, sleep).await {
1176 Either::Left((result, _sleep)) => result,
1177 Either::Right((_ready, _future)) => Err(WireError::timeout(format!(
1178 "call timed out after {timeout:?}"
1179 ))),
1180 }
1181}
1182
1183async fn with_call_cancellation<F>(
1184 state: Arc<CallState>,
1185 future: F,
1186) -> Result<Response<ResponseBody>, WireError>
1187where
1188 F: std::future::Future<Output = Result<Response<ResponseBody>, WireError>>,
1189{
1190 if state.is_canceled() {
1191 return Err(call_canceled_error());
1192 }
1193
1194 match select(Box::pin(future), Box::pin(CallCanceled { state })).await {
1195 Either::Left((result, _canceled)) => result,
1196 Either::Right((_ready, _future)) => Err(call_canceled_error()),
1197 }
1198}
1199
1200fn clone_request(request: &Request<RequestBody>) -> Option<Request<RequestBody>> {
1201 let mut cloned = Request::builder()
1202 .method(request.method().clone())
1203 .uri(request.uri().clone())
1204 .version(request.version())
1205 .body(request.body().try_clone()?)
1206 .ok()?;
1207 *cloned.headers_mut() = request.headers().clone();
1208 *cloned.extensions_mut() = request.extensions().clone();
1209 Some(cloned)
1210}
1211
1212fn call_canceled_error() -> WireError {
1213 WireError::canceled("call canceled")
1214}
1215
/// Future that resolves once the shared call state reports cancellation.
///
/// Used by `with_call_cancellation` to race a call against `Call::cancel`-style signalling.
struct CallCanceled {
    // Shared per-call state polled for the cancellation flag.
    state: Arc<CallState>,
}
1219
1220impl std::future::Future for CallCanceled {
1221 type Output = ();
1222
1223 fn poll(
1224 self: std::pin::Pin<&mut Self>,
1225 cx: &mut std::task::Context<'_>,
1226 ) -> std::task::Poll<Self::Output> {
1227 if self.state.poll_canceled(cx) {
1228 std::task::Poll::Ready(())
1229 } else {
1230 std::task::Poll::Pending
1231 }
1232 }
1233}
1234
1235fn spawn_pool_reaper(
1236 executor: Arc<dyn WireExecutor>,
1237 timer: SharedTimer,
1238 pool: &Arc<ConnectionPool>,
1239) -> Result<Option<BoxTaskHandle>, WireError> {
1240 let Some(idle_timeout) = pool.settings().idle_timeout else {
1241 return Ok(None);
1242 };
1243
1244 let cadence = pool_reaper_cadence(idle_timeout);
1245 let weak_pool = Arc::downgrade(pool);
1246 executor
1247 .spawn(Box::pin(async move {
1248 loop {
1249 timer.sleep(cadence).await;
1250 let Some(pool) = weak_pool.upgrade() else {
1251 break;
1252 };
1253 pool.prune_all();
1254 }
1255 }))
1256 .map(Some)
1257}
1258
1259fn pool_reaper_cadence(idle_timeout: Duration) -> Duration {
1260 (idle_timeout / 2).clamp(MIN_POOL_REAPER_CADENCE, MAX_POOL_REAPER_CADENCE)
1261}
1262
/// Response-body wrapper that emits exactly one terminal listener event for the call
/// (`call_end` or `call_failed`). The event fires when the body ends, errors, is canceled,
/// or is dropped.
struct CallLifecycleBody {
    // Wrapped body; taken (and dropped) as soon as the call is finished.
    inner: Option<ResponseBody>,
    // Call context whose listener receives the terminal event.
    ctx: CallContext,
    // Shared call state consulted for cancellation.
    state: Arc<CallState>,
    // Set once a terminal event has been emitted, so later paths become no-ops.
    finished: bool,
}
1269
1270impl CallLifecycleBody {
1271 fn new(inner: ResponseBody, ctx: CallContext, state: Arc<CallState>) -> Self {
1272 Self {
1273 inner: Some(inner),
1274 ctx,
1275 state,
1276 finished: false,
1277 }
1278 }
1279
1280 fn finish_successfully(&mut self) {
1281 if self.finished {
1282 return;
1283 }
1284 self.finished = true;
1285 let _ = self.inner.take();
1286 self.ctx.listener().call_end(&self.ctx);
1287 }
1288
1289 fn finish_with_error(&mut self, error: &WireError) {
1290 if self.finished {
1291 return;
1292 }
1293 self.finished = true;
1294 let _ = self.inner.take();
1295 self.ctx.listener().call_failed(&self.ctx, error);
1296 }
1297
1298 fn finish_abandoned(&mut self) {
1299 if self.finished {
1300 return;
1301 }
1302 self.finished = true;
1303 drop(self.inner.take());
1304 if self.state.is_canceled() {
1305 self.ctx
1306 .listener()
1307 .call_failed(&self.ctx, &call_canceled_error());
1308 } else {
1309 self.ctx.listener().call_end(&self.ctx);
1310 }
1311 }
1312}
1313
1314impl Drop for CallLifecycleBody {
1315 fn drop(&mut self) {
1316 self.finish_abandoned();
1317 }
1318}
1319
1320impl Body for CallLifecycleBody {
1321 type Data = Bytes;
1322 type Error = WireError;
1323
1324 fn poll_frame(
1325 self: std::pin::Pin<&mut Self>,
1326 cx: &mut std::task::Context<'_>,
1327 ) -> std::task::Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
1328 let this = self.get_mut();
1329 if this.finished {
1330 return std::task::Poll::Ready(None);
1331 }
1332 if this.state.poll_canceled(cx) {
1333 let error = call_canceled_error();
1334 this.finish_with_error(&error);
1335 return std::task::Poll::Ready(Some(Err(error)));
1336 }
1337 let Some(inner) = this.inner.as_mut() else {
1338 return std::task::Poll::Ready(None);
1339 };
1340
1341 match std::pin::Pin::new(inner).poll_frame(cx) {
1342 std::task::Poll::Ready(Some(Ok(frame))) => std::task::Poll::Ready(Some(Ok(frame))),
1343 std::task::Poll::Ready(Some(Err(error))) => {
1344 this.finish_with_error(&error);
1345 std::task::Poll::Ready(Some(Err(error)))
1346 }
1347 std::task::Poll::Ready(None) => {
1348 this.finish_successfully();
1349 std::task::Poll::Ready(None)
1350 }
1351 std::task::Poll::Pending => std::task::Poll::Pending,
1352 }
1353 }
1354
1355 fn is_end_stream(&self) -> bool {
1356 match self.inner.as_ref() {
1357 Some(inner) => inner.is_end_stream(),
1358 None => true,
1359 }
1360 }
1361
1362 fn size_hint(&self) -> SizeHint {
1363 self.inner
1364 .as_ref()
1365 .map_or_else(SizeHint::default, http_body::Body::size_hint)
1366 }
1367}
1368
pin_project! {
    // Response-body wrapper that keeps a request-admission permit alive until the body is
    // dropped, so the admission slot stays held while the response is still being read.
    struct RequestAdmissionBody {
        // Wrapped body; structurally pinned so frames can be polled through the projection.
        #[pin]
        inner: ResponseBody,
        // Held only for its drop side effect (presumably releasing the admission slot —
        // confirm against `RequestAdmissionPermit`).
        _permit: Option<RequestAdmissionPermit>,
    }
}
1376
1377impl Body for RequestAdmissionBody {
1378 type Data = Bytes;
1379 type Error = WireError;
1380
1381 fn poll_frame(
1382 self: std::pin::Pin<&mut Self>,
1383 cx: &mut std::task::Context<'_>,
1384 ) -> std::task::Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
1385 self.project().inner.poll_frame(cx)
1386 }
1387
1388 fn is_end_stream(&self) -> bool {
1389 self.inner.is_end_stream()
1390 }
1391
1392 fn size_hint(&self) -> SizeHint {
1393 self.inner.size_hint()
1394 }
1395}
1396
#[cfg(test)]
mod tests {
    // Unit tests for address caching, proxy-credential handling, pool reaper lifecycle,
    // builder defaults and call option/cloning behavior.
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    use bytes::Bytes;
    use futures_util::stream;
    use http::header::PROXY_AUTHORIZATION;
    use http::Request;
    use openwire_core::{BoxFuture, RequestBody, TaskHandle, WireError};

    use super::{
        cache_request_addresses, pool_reaper_cadence, spawn_pool_reaper, CallOptions,
        ClientBuilder, ConnectionPool, EffectiveRequestConfig, PoolReaperController, PoolSettings,
    };
    use crate::connection::CachedAddresses;
    use crate::proxy::{Proxy, ProxyRules, ProxySelection, ProxySelector, SelectedProxy};

    /// Proxy selector that always returns the same selection, regardless of the URI.
    #[derive(Clone)]
    struct StaticProxySelector(ProxySelection);

    impl ProxySelector for StaticProxySelector {
        fn select(&self, _uri: &http::Uri) -> Result<ProxySelection, WireError> {
            Ok(self.0.clone())
        }
    }

    /// Task handle that only counts how many times it was aborted.
    struct CountingTaskHandle {
        aborts: Arc<AtomicUsize>,
    }

    impl TaskHandle for CountingTaskHandle {
        fn abort(&self) {
            self.aborts.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Executor that never runs the spawned future; it only counts spawns and aborts.
    #[derive(Clone, Default)]
    struct CountingExecutor {
        spawns: Arc<AtomicUsize>,
        aborts: Arc<AtomicUsize>,
    }

    impl CountingExecutor {
        fn spawns(&self) -> usize {
            self.spawns.load(Ordering::Relaxed)
        }

        fn aborts(&self) -> usize {
            self.aborts.load(Ordering::Relaxed)
        }
    }

    impl openwire_core::WireExecutor for CountingExecutor {
        fn spawn(
            &self,
            _future: BoxFuture<()>,
        ) -> Result<openwire_core::BoxTaskHandle, openwire_core::WireError> {
            self.spawns.fetch_add(1, Ordering::Relaxed);
            Ok(Box::new(CountingTaskHandle {
                aborts: self.aborts.clone(),
            }))
        }
    }

    /// Executor whose first spawn fails; later spawns succeed and are counted.
    #[derive(Clone, Default)]
    struct FailOnceExecutor {
        spawns: Arc<AtomicUsize>,
        aborts: Arc<AtomicUsize>,
    }

    impl FailOnceExecutor {
        fn spawns(&self) -> usize {
            self.spawns.load(Ordering::Relaxed)
        }

        fn aborts(&self) -> usize {
            self.aborts.load(Ordering::Relaxed)
        }
    }

    impl openwire_core::WireExecutor for FailOnceExecutor {
        fn spawn(&self, _future: BoxFuture<()>) -> Result<openwire_core::BoxTaskHandle, WireError> {
            let attempt = self.spawns.fetch_add(1, Ordering::Relaxed);
            if attempt == 0 {
                return Err(WireError::internal(
                    "scripted spawn failure",
                    std::io::Error::other("scripted spawn failure"),
                ));
            }

            Ok(Box::new(CountingTaskHandle {
                aborts: self.aborts.clone(),
            }))
        }
    }

    #[test]
    fn pool_reaper_cadence_is_clamped() {
        assert_eq!(
            pool_reaper_cadence(Duration::from_secs(2)),
            Duration::from_secs(5)
        );
        assert_eq!(
            pool_reaper_cadence(Duration::from_secs(20)),
            Duration::from_secs(10)
        );
        assert_eq!(
            pool_reaper_cadence(Duration::from_secs(180)),
            Duration::from_secs(60)
        );
    }

    #[test]
    fn spawn_pool_reaper_skips_when_idle_timeout_is_disabled() {
        let executor = CountingExecutor::default();
        let timer = openwire_core::SharedTimer::new(openwire_tokio::TokioTimer::new());
        let pool = Arc::new(ConnectionPool::new(PoolSettings {
            idle_timeout: None,
            max_idle_per_address: usize::MAX,
        }));

        let handle =
            spawn_pool_reaper(Arc::new(executor.clone()), timer, &pool).expect("spawn reaper");

        assert!(handle.is_none());
        assert_eq!(executor.spawns(), 0);
        assert_eq!(executor.aborts(), 0);
    }

    #[test]
    fn dropping_final_client_aborts_pool_reaper_task() {
        let executor = CountingExecutor::default();
        let timer = openwire_core::SharedTimer::new(openwire_tokio::TokioTimer::new());
        let pool = Arc::new(ConnectionPool::new(PoolSettings::default()));
        let reaper = PoolReaperController::default();

        reaper.ensure_started(Arc::new(executor.clone()), timer, Arc::downgrade(&pool));
        assert_eq!(executor.spawns(), 1);
        assert_eq!(executor.aborts(), 0);

        reaper.abort();
        assert_eq!(executor.aborts(), 1);
    }

    #[test]
    fn cache_request_addresses_inserts_cached_extension() {
        let mut request = Request::builder()
            .uri("http://example.com/resource")
            .body(RequestBody::empty())
            .expect("request");

        let addresses =
            cache_request_addresses(&mut request, &ProxyRules::new()).expect("addresses");

        assert_eq!(
            request
                .extensions()
                .get::<CachedAddresses>()
                .map(|cached| cached.0.clone()),
            Some(addresses)
        );
    }

    #[test]
    fn cache_request_addresses_prioritize_previously_selected_proxy_when_candidates_still_include_it(
    ) {
        let fallback = Proxy::http("http://first.test:8080").expect("fallback proxy");
        let sticky = Proxy::http("http://sticky.test:8080").expect("sticky proxy");
        let mut request = Request::builder()
            .uri("http://example.com/resource")
            .header(PROXY_AUTHORIZATION, "Basic cHJveHk6b2xk")
            .body(RequestBody::empty())
            .expect("request");
        request
            .extensions_mut()
            .insert(SelectedProxy::from_proxy(&sticky));

        let addresses = cache_request_addresses(
            &mut request,
            &StaticProxySelector(
                ProxySelection::direct()
                    .push_proxy(fallback.clone())
                    .push_proxy(sticky.clone()),
            ),
        )
        .expect("addresses");

        assert_eq!(
            request
                .headers()
                .get(PROXY_AUTHORIZATION)
                .and_then(|value| value.to_str().ok()),
            Some("Basic cHJveHk6b2xk")
        );
        assert_eq!(
            addresses
                .first()
                .and_then(|candidate| candidate.selected_proxy()),
            Some(&SelectedProxy::from_proxy(&sticky))
        );
        assert_eq!(
            addresses
                .get(1)
                .and_then(|candidate| candidate.selected_proxy()),
            None
        );
        assert_eq!(
            addresses
                .get(2)
                .and_then(|candidate| candidate.selected_proxy()),
            Some(&SelectedProxy::from_proxy(&fallback))
        );
    }

    #[test]
    fn cache_request_addresses_matches_sticky_proxy_by_endpoint_across_scheme_specific_selectors() {
        let previous_proxy = Proxy::http("http://proxy.test:8080").expect("previous proxy");
        let current_proxy = Proxy::https("http://proxy.test:8080").expect("current proxy");
        let mut request = Request::builder()
            .uri("https://example.com/resource")
            .header(PROXY_AUTHORIZATION, "Basic cHJveHk6b2xk")
            .body(RequestBody::empty())
            .expect("request");
        request
            .extensions_mut()
            .insert(SelectedProxy::from_proxy(&previous_proxy));

        let addresses = cache_request_addresses(
            &mut request,
            &StaticProxySelector(ProxySelection::direct().push_proxy(current_proxy.clone())),
        )
        .expect("addresses");

        assert_eq!(
            request
                .headers()
                .get(PROXY_AUTHORIZATION)
                .and_then(|value| value.to_str().ok()),
            Some("Basic cHJveHk6b2xk")
        );
        assert_eq!(
            addresses
                .first()
                .and_then(|candidate| candidate.selected_proxy()),
            Some(&SelectedProxy::from_proxy(&current_proxy))
        );
    }

    #[test]
    fn cache_request_addresses_clear_proxy_authorization_when_current_candidates_drop_proxy() {
        let proxy = Proxy::http("http://first.test:8080").expect("proxy");
        let mut request = Request::builder()
            .uri("http://example.com/resource")
            .header(PROXY_AUTHORIZATION, "Basic cHJveHk6b2xk")
            .body(RequestBody::empty())
            .expect("request");
        request
            .extensions_mut()
            .insert(SelectedProxy::from_proxy(&proxy));

        let addresses =
            cache_request_addresses(&mut request, &StaticProxySelector(ProxySelection::direct()))
                .expect("addresses");

        assert!(request.headers().get(PROXY_AUTHORIZATION).is_none());
        assert_eq!(
            addresses
                .first()
                .and_then(|candidate| candidate.selected_proxy()),
            None
        );
    }

    #[test]
    fn client_builder_defaults_use_bounded_pool_and_request_limits() {
        let builder = ClientBuilder::default();

        assert_eq!(builder.transport.connect_timeout, None);
        assert_eq!(
            builder.transport.pool_idle_timeout,
            Some(Duration::from_secs(300))
        );
        assert_eq!(builder.transport.pool_max_idle_per_host, 5);
        assert_eq!(builder.transport.max_requests_total, 64);
        assert_eq!(builder.transport.max_requests_per_host, 5);
    }

    #[test]
    fn call_options_merge_prefers_newly_supplied_overrides() {
        let mut options = CallOptions::new()
            .call_timeout(Duration::from_millis(50))
            .follow_redirects(true)
            .max_retries(1);
        options.apply(
            CallOptions::new()
                .call_timeout(Duration::from_millis(25))
                .connect_timeout(Duration::from_millis(10))
                .max_retries(3),
        );

        assert_eq!(options.call_timeout, Some(Duration::from_millis(25)));
        assert_eq!(options.connect_timeout, Some(Duration::from_millis(10)));
        assert_eq!(options.follow_redirects, Some(true));
        assert_eq!(options.max_retries, Some(3));
    }

    #[test]
    fn effective_request_config_applies_call_overrides() {
        let defaults = EffectiveRequestConfig {
            call_timeout: Some(Duration::from_secs(1)),
            connect_timeout: Some(Duration::from_millis(250)),
            follow_redirects: true,
            max_redirects: 10,
            retry_on_connection_failure: true,
            max_retries: 1,
            retry_canceled_requests: false,
            allow_insecure_redirects: false,
        };

        let effective = defaults.with_overrides(
            CallOptions::new()
                .call_timeout(Duration::from_millis(25))
                .follow_redirects(false)
                .max_redirects(2)
                .retry_on_connection_failure(false)
                .max_retries(0)
                .retry_canceled_requests(true)
                .allow_insecure_redirects(true),
        );

        assert_eq!(effective.call_timeout, Some(Duration::from_millis(25)));
        assert_eq!(effective.connect_timeout, Some(Duration::from_millis(250)));
        assert!(!effective.follow_redirects);
        assert_eq!(effective.max_redirects, 2);
        assert!(!effective.retry_on_connection_failure);
        assert_eq!(effective.max_retries, 0);
        assert!(effective.retry_canceled_requests);
        assert!(effective.allow_insecure_redirects);
    }

    #[test]
    fn call_try_clone_preserves_replayable_request_and_options_with_fresh_state() {
        let client = crate::Client::builder().build().expect("client");
        let request = Request::builder()
            .method("POST")
            .uri("http://example.com/resource")
            .header("x-test", "yes")
            .body(RequestBody::from_static(b"hello"))
            .expect("request");
        let call = client
            .new_call(request)
            .call_timeout(Duration::from_secs(1))
            .max_retries(0);
        call.cancel();

        let cloned = call.try_clone().expect("replayable clone");

        assert!(call.is_canceled());
        assert!(!cloned.is_canceled());
        assert!(!cloned.is_executed());
        assert_eq!(cloned.options.call_timeout, Some(Duration::from_secs(1)));
        assert_eq!(cloned.options.max_retries, Some(0));
        assert_eq!(cloned.request.method(), http::Method::POST);
        assert_eq!(
            cloned
                .request
                .headers()
                .get("x-test")
                .and_then(|value| value.to_str().ok()),
            Some("yes")
        );
    }

    #[test]
    fn call_try_clone_rejects_streaming_request_body() {
        let client = crate::Client::builder().build().expect("client");
        let request = Request::builder()
            .uri("http://example.com/stream")
            .body(RequestBody::from_stream(stream::empty::<
                Result<Bytes, WireError>,
            >()))
            .expect("request");
        let call = client.new_call(request);

        assert!(call.try_clone().is_none());
    }

    #[test]
    fn pool_reaper_retries_after_spawn_failure() {
        let executor = FailOnceExecutor::default();
        let timer = openwire_core::SharedTimer::new(openwire_tokio::TokioTimer::new());
        let pool = Arc::new(ConnectionPool::new(PoolSettings::default()));
        let reaper = PoolReaperController::default();

        // First attempt fails at spawn time.
        reaper.ensure_started(
            Arc::new(executor.clone()),
            timer.clone(),
            Arc::downgrade(&pool),
        );
        assert_eq!(executor.spawns(), 1);
        assert_eq!(executor.aborts(), 0);

        // Second attempt retries and succeeds.
        reaper.ensure_started(Arc::new(executor.clone()), timer, Arc::downgrade(&pool));
        assert_eq!(executor.spawns(), 2);

        // Once running, further starts are no-ops.
        reaper.ensure_started(
            Arc::new(executor.clone()),
            openwire_core::SharedTimer::new(openwire_tokio::TokioTimer::new()),
            Arc::downgrade(&pool),
        );
        assert_eq!(executor.spawns(), 2);

        reaper.abort();
        assert_eq!(executor.aborts(), 1);
    }
}