Skip to main content

kube_client/client/
builder.rs

1use bytes::Bytes;
2use http::{Request, Response, header::HeaderMap};
3use hyper::{
4    body::Incoming,
5    rt::{Read, Write},
6};
7use hyper_timeout::TimeoutConnector;
8
9use hyper_util::{
10    client::legacy::connect::{Connection, HttpConnector},
11    rt::TokioExecutor,
12};
13
14use jiff::Timestamp;
15use std::time::Duration;
16use tower::{BoxError, Layer, Service, ServiceBuilder, ServiceExt as _, util::BoxService};
17use tower_http::{ServiceExt as _, classify::ServerErrorsFailureClass, trace::TraceLayer};
18use tracing::Span;
19
20use super::body::Body;
21use crate::{Client, Config, Error, Result, client::ConfigExt};
22
23/// HTTP body of a dynamic backing type.
24///
25/// The suggested implementation type is [`crate::client::Body`].
26pub type DynBody = dyn http_body::Body<Data = Bytes, Error = BoxError> + Send + Unpin;
27
28/// Builder for [`Client`] instances with customized [tower](`Service`) middleware.
29pub struct ClientBuilder<Svc> {
30    service: Svc,
31    default_ns: String,
32    valid_until: Option<Timestamp>,
33}
34
35impl<Svc> ClientBuilder<Svc> {
36    /// Construct a [`ClientBuilder`] from scratch with a fully custom [`Service`] stack.
37    ///
38    /// This method is only intended for advanced use cases, most users will want to use [`ClientBuilder::try_from`] instead,
39    /// which provides a default stack as a starting point.
40    pub fn new(service: Svc, default_namespace: impl Into<String>) -> Self
41    where
42        Svc: Service<Request<Body>>,
43    {
44        Self {
45            service,
46            default_ns: default_namespace.into(),
47            valid_until: None,
48        }
49    }
50
51    /// Add a [`Layer`] to the current [`Service`] stack.
52    pub fn with_layer<L: Layer<Svc>>(self, layer: &L) -> ClientBuilder<L::Service> {
53        let Self {
54            service: stack,
55            default_ns,
56            valid_until,
57        } = self;
58        ClientBuilder {
59            service: layer.layer(stack),
60            default_ns,
61            valid_until,
62        }
63    }
64
65    /// Sets an expiration timestamp for the client.
66    pub fn with_valid_until(self, valid_until: Option<Timestamp>) -> Self {
67        ClientBuilder {
68            service: self.service,
69            default_ns: self.default_ns,
70            valid_until,
71        }
72    }
73
74    /// Build a [`Client`] instance with the current [`Service`] stack.
75    pub fn build<B>(self) -> Client
76    where
77        Svc: Service<Request<Body>, Response = Response<B>> + Send + 'static,
78        Svc::Future: Send + 'static,
79        Svc::Error: Into<BoxError>,
80        B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
81        B::Error: Into<BoxError>,
82    {
83        Client::new(self.service, self.default_ns).with_valid_until(self.valid_until)
84    }
85}
86
87pub type GenericService = BoxService<Request<Body>, Response<Box<DynBody>>, BoxError>;
88
89impl TryFrom<Config> for ClientBuilder<GenericService> {
90    type Error = Error;
91
92    /// Builds a default [`ClientBuilder`] stack from a given configuration
93    fn try_from(config: Config) -> Result<Self> {
94        let mut connector = HttpConnector::new();
95        connector.enforce_http(false);
96
97        #[cfg(all(feature = "aws-lc-rs", feature = "rustls-tls"))]
98        {
99            if rustls::crypto::CryptoProvider::get_default().is_none() {
100                // the only error here is if it's been initialized in between: we can ignore it
101                // since our semantic is only to set the default value if it does not exist.
102                let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
103            }
104        }
105
106        match config.proxy_url.as_ref() {
107            Some(proxy_url) if proxy_url.scheme_str() == Some("socks5") => {
108                #[cfg(feature = "socks5")]
109                {
110                    let connector = hyper_util::client::legacy::connect::proxy::SocksV5::new(
111                        proxy_url.clone(),
112                        connector,
113                    );
114                    make_generic_builder(connector, config)
115                }
116
117                #[cfg(not(feature = "socks5"))]
118                Err(Error::ProxyProtocolDisabled {
119                    proxy_url: proxy_url.clone(),
120                    protocol_feature: "kube/socks5",
121                })
122            }
123
124            Some(proxy_url) if proxy_url.scheme_str() == Some("http") => {
125                #[cfg(feature = "http-proxy")]
126                {
127                    let connector =
128                        hyper_util::client::legacy::connect::proxy::Tunnel::new(proxy_url.clone(), connector);
129                    make_generic_builder(connector, config)
130                }
131
132                #[cfg(not(feature = "http-proxy"))]
133                Err(Error::ProxyProtocolDisabled {
134                    proxy_url: proxy_url.clone(),
135                    protocol_feature: "kube/http-proxy",
136                })
137            }
138
139            Some(proxy_url) => Err(Error::ProxyProtocolUnsupported {
140                proxy_url: proxy_url.clone(),
141            }),
142
143            None => make_generic_builder(connector, config),
144        }
145    }
146}
147
148/// Helper function for implementation of [`TryFrom<Config>`] for [`ClientBuilder`].
149/// Ignores [`Config::proxy_url`], which at this point is already handled.
150fn make_generic_builder<H>(connector: H, config: Config) -> Result<ClientBuilder<GenericService>, Error>
151where
152    H: 'static + Clone + Send + Sync + Service<http::Uri>,
153    H::Response: 'static + Connection + Read + Write + Send + Unpin,
154    H::Future: 'static + Send,
155    H::Error: 'static + Send + Sync + std::error::Error,
156{
157    let default_ns = config.default_namespace.clone();
158    let auth_layer = config.auth_layer()?;
159
160    let client: hyper_util::client::legacy::Client<_, Body> = {
161        // Current TLS feature precedence when more than one are set:
162        // 1. rustls-tls
163        // 2. openssl-tls
164        // Create a custom client to use something else.
165        // If TLS features are not enabled, http connector will be used.
166        #[cfg(feature = "rustls-tls")]
167        let connector = config.rustls_https_connector_with_connector(connector)?;
168        #[cfg(all(not(feature = "rustls-tls"), feature = "openssl-tls"))]
169        let connector = config.openssl_https_connector_with_connector(connector)?;
170        #[cfg(all(not(feature = "rustls-tls"), not(feature = "openssl-tls")))]
171        if config.cluster_url.scheme() == Some(&http::uri::Scheme::HTTPS) {
172            // no tls stack situation only works with http scheme
173            return Err(Error::TlsRequired);
174        }
175
176        let mut connector = TimeoutConnector::new(connector);
177
178        // Set the timeouts for the client
179        connector.set_connect_timeout(config.connect_timeout);
180        connector.set_read_timeout(config.read_timeout);
181        connector.set_write_timeout(config.write_timeout);
182
183        hyper_util::client::legacy::Builder::new(TokioExecutor::new()).build(connector)
184    };
185
186    let stack = ServiceBuilder::new().layer(config.base_uri_layer()).into_inner();
187    #[cfg(feature = "gzip")]
188    let stack = ServiceBuilder::new()
189        .layer(stack)
190        .layer(
191            tower_http::decompression::DecompressionLayer::new()
192                .no_br()
193                .no_deflate()
194                .no_zstd()
195                .gzip(!config.disable_compression),
196        )
197        .into_inner();
198
199    let service = ServiceBuilder::new()
200        .layer(stack)
201        .option_layer(auth_layer)
202        .layer(config.extra_headers_layer()?)
203        .layer(
204            // Attribute names follow [Semantic Conventions].
205            // [Semantic Conventions]: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/http.md
206            TraceLayer::new_for_http()
207                .make_span_with(|req: &Request<Body>| {
208                    tracing::debug_span!(
209                        "HTTP",
210                         http.method = %req.method(),
211                         http.url = %req.uri(),
212                         http.status_code = tracing::field::Empty,
213                         otel.name = req.extensions().get::<&'static str>().unwrap_or(&"HTTP"),
214                         otel.kind = "client",
215                         otel.status_code = tracing::field::Empty,
216                    )
217                })
218                .on_request(|_req: &Request<Body>, _span: &Span| {
219                    tracing::debug!("requesting");
220                })
221                .on_response(|res: &Response<Incoming>, _latency: Duration, span: &Span| {
222                    let status = res.status();
223                    span.record("http.status_code", status.as_u16());
224                    if status.is_client_error() || status.is_server_error() {
225                        span.record("otel.status_code", "ERROR");
226                    }
227                })
228                // Explicitly disable `on_body_chunk`. The default does nothing.
229                .on_body_chunk(())
230                .on_eos(|_: Option<&HeaderMap>, _duration: Duration, _span: &Span| {
231                    tracing::debug!("stream closed");
232                })
233                .on_failure(|ec: ServerErrorsFailureClass, _latency: Duration, span: &Span| {
234                    // Called when
235                    // - Calling the inner service errored
236                    // - Polling `Body` errored
237                    // - the response was classified as failure (5xx)
238                    // - End of stream was classified as failure
239                    span.record("otel.status_code", "ERROR");
240                    match ec {
241                        ServerErrorsFailureClass::StatusCode(status) => {
242                            span.record("http.status_code", status.as_u16());
243                            tracing::error!("failed with status {}", status)
244                        }
245                        ServerErrorsFailureClass::Error(err) => {
246                            tracing::error!("failed with error {}", err)
247                        }
248                    }
249                }),
250        )
251        .map_err(BoxError::from)
252        .service(client);
253
254    let (_, expiration) = config.exec_identity_pem();
255
256    let client = ClientBuilder::new(
257        service
258            .map_response_body(|body| {
259                Box::new(http_body_util::BodyExt::map_err(body, BoxError::from)) as Box<DynBody>
260            })
261            .boxed(),
262        default_ns,
263    )
264    .with_valid_until(expiration);
265
266    Ok(client)
267}
268
#[cfg(test)]
mod tests {
    #[cfg(feature = "gzip")]
    use super::*;

    /// With the `gzip` feature on, the default stack advertises `Accept-Encoding: gzip`.
    /// Setting `Config::disable_compression` removes that header entirely.
    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_no_accept_encoding_header_sent_when_compression_disabled()
    -> Result<(), Box<dyn std::error::Error>> {
        use http::Uri;
        use std::net::SocketAddr;
        use tokio::net::{TcpListener, TcpStream};

        // Listen on an OS-assigned loopback port and point the client at it.
        let bind_addr = SocketAddr::from(([127, 0, 0, 1], 0));
        let listener = TcpListener::bind(bind_addr).await?;
        let server_uri: Uri = format!("http://{}", listener.local_addr()?).parse()?;

        // Minimal HTTP/1 echo server: the response body is the request's Accept-Encoding
        // value, or empty when the header is absent.
        tokio::spawn(async move {
            use http_body_util::Full;
            use hyper::{server::conn::http1, service::service_fn};
            use hyper_util::rt::{TokioIo, TokioTimer};
            use std::convert::Infallible;

            loop {
                let (stream, _peer) = listener.accept().await.unwrap();
                let io: TokioIo<TcpStream> = TokioIo::new(stream);

                tokio::spawn(async move {
                    http1::Builder::new()
                        .timer(TokioTimer::new())
                        .serve_connection(
                            io,
                            service_fn(|req| async move {
                                let echoed = match req.headers().get(http::header::ACCEPT_ENCODING) {
                                    Some(value) => Bytes::copy_from_slice(value.as_bytes()),
                                    None => Bytes::new(),
                                };
                                Ok::<_, Infallible>(Response::new(Full::new(echoed)))
                            }),
                        )
                        .await
                        .unwrap();
                });
            }
        });

        // Default config keeps compression on, so the server should see `gzip`.
        let config = Config::new(server_uri);
        let client = make_generic_builder(HttpConnector::new(), config.clone())?.build();
        let body = client.request_text(http::Request::default()).await?;
        assert_eq!(&body, "gzip");

        // With compression disabled no Accept-Encoding header goes out, so the echo is empty.
        let uncompressed = Config {
            disable_compression: true,
            ..config
        };
        let client = make_generic_builder(HttpConnector::new(), uncompressed)?.build();
        let body = client.request_text(http::Request::default()).await?;
        assert_eq!(&body, "");

        Ok(())
    }
}