kubert 0.25.0

Kubernetes runtime helpers. Based on kube-rs.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
//! Helpers for configuring and running an HTTPS server, especially for admission controllers and
//! API extensions
//!
//! Unlike a normal `hyper` server, this server reloads its TLS credentials for each connection to
//! support certificate rotation.
//!
//! # TLS Feature Flags
//!
//! The server module requires that one of the [TLS implementation Cargo
//! features](crate#tls-features) be enabled in order to run the server.
//! If neither TLS implementation is selected, the server still binds its
//! socket, but every incoming connection is rejected because TLS is unavailable.
//! However, this module itself is still enabled if neither TLS feature flag is
//! selected. This is to allow the server module to be used in a library crate
//! which does not require either particular TLS implementation, so that the
//! top-level binary crate may choose which TLS implementation is used.

#![cfg_attr(
    not(any(feature = "rustls-tls", feature = "openssl-tls")),
    allow(dead_code, unused_variables)
)]

use std::{convert::Infallible, net::SocketAddr, path::PathBuf, str::FromStr, sync::Arc};
use thiserror::Error;
use tokio::net::{TcpListener, TcpStream};
use tower::Service;
use tracing::{debug, error, info, info_span, Instrument};

#[cfg(feature = "rustls-tls")]
mod tls_rustls;

#[cfg(feature = "openssl-tls")]
mod tls_openssl;

#[cfg(test)]
mod tests;

// NOTE: when the `clap` feature is enabled this struct derives `clap::Parser`,
// which uses the `///` doc comments in this struct as the CLI's `--help`
// text. Editing them changes user-visible output, so maintainer notes in this
// struct use plain `//` comments instead.
/// Command-line arguments used to configure a server
#[derive(Clone, Debug)]
#[cfg_attr(feature = "clap", derive(clap::Parser))]
#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
pub struct ServerArgs {
    /// The server's address
    // Exposed as `--server-addr` by clap; defaults to `0.0.0.0:443` there.
    #[cfg_attr(feature = "clap", clap(long, default_value = "0.0.0.0:443"))]
    pub server_addr: SocketAddr,

    /// The path to the server's TLS key file.
    ///
    /// This should be a PEM-encoded file containing a single PKCS#8 or RSA
    /// private key.
    // Optional at parse time, but `ServerArgs::bind` fails with
    // `Error::NoTlsKey` when this is `None`.
    #[cfg_attr(feature = "clap", clap(long))]
    pub server_tls_key: Option<TlsKeyPath>,

    /// The path to the server's TLS certificate file.
    ///
    /// This should be a PEM-encoded file containing at least one TLS end-entity
    /// certificate.
    // Optional at parse time, but `ServerArgs::bind` fails with
    // `Error::NoTlsCerts` when this is `None`.
    #[cfg_attr(feature = "clap", clap(long))]
    pub server_tls_certs: Option<TlsCertPath>,
}

/// A running server
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
pub struct Bound {
    local_addr: SocketAddr,
    tcp: tokio::net::TcpListener,
    tls: Arc<TlsPaths>,
}

/// Handle to a server whose accept loop is running on a background Tokio task.
///
/// Returned by [`Bound::spawn`]. The handle can report the listening address,
/// abort the accept loop, or wait for it to finish.
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
pub struct SpawnedServer {
    local_addr: SocketAddr,
    task: tokio::task::JoinHandle<()>,
}

/// Failures that can be reported while configuring and binding a server.
///
/// The enum is `#[non_exhaustive]`, so downstream `match`es need a wildcard arm.
#[derive(Debug, Error)]
#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
#[non_exhaustive]
pub enum Error {
    /// `server_tls_key` was left unset.
    #[error("--server-tls-key must be set")]
    NoTlsKey,

    /// `server_tls_certs` was left unset.
    #[error("--server-tls-certs must be set")]
    NoTlsCerts,

    /// Reading the certificate file failed with an I/O error.
    #[error("failed to read TLS certificates: {0}")]
    TlsCertsReadError(#[source] std::io::Error),

    /// Reading the private-key file failed with an I/O error.
    #[error("failed to read TLS key: {0}")]
    TlsKeyReadError(#[source] std::io::Error),

    /// The key/certificate material was read but could not be turned into a
    /// usable TLS configuration.
    #[error("failed to load TLS credentials: {0}")]
    InvalidTlsCredentials(#[source] Box<dyn std::error::Error + Send + Sync>),

    /// Binding the listening socket to the given address failed.
    #[error("failed to bind {0:?}: {1}")]
    Bind(SocketAddr, #[source] std::io::Error),

    /// The listener was bound, but its local address could not be queried.
    #[error("failed to get bound local address: {0}")]
    LocalAddr(#[source] std::io::Error),
}

/// Filesystem location of the server's TLS private key.
#[derive(Clone, Debug)]
pub struct TlsKeyPath(PathBuf);

/// Filesystem location of the server's TLS certificate bundle.
#[derive(Clone, Debug)]
pub struct TlsCertPath(PathBuf);

/// The key and certificate locations the server loads its TLS configuration from.
///
/// When neither TLS backend feature is enabled, these paths may never be read.
#[derive(Clone, Debug)]
struct TlsPaths {
    key: TlsKeyPath,
    certs: TlsCertPath,
}

// === impl ServerArgs ===

impl ServerArgs {
    /// Validates the configured TLS credentials and binds the server socket.
    ///
    /// The TLS key and certificate files are loaded once here, so that a
    /// misconfiguration is reported before the socket is bound. After the
    /// server is spawned, they are loaded again for every accepted connection.
    ///
    /// # Errors
    ///
    /// - [`Error::NoTlsKey`] / [`Error::NoTlsCerts`] if the corresponding path
    ///   was not configured.
    /// - An error from the enabled TLS backend if the credentials cannot be
    ///   loaded.
    /// - [`Error::Bind`] if the socket cannot be bound, or
    ///   [`Error::LocalAddr`] if its local address cannot be read.
    ///
    /// # TLS backends
    ///
    /// One of [the "rustls-tls" or "openssl-tls" Cargo features][tls-features]
    /// must be enabled for the server to be usable; if both are enabled,
    /// rustls is used. If neither is enabled, this method does **not** panic
    /// and does not validate the credentials. The socket is still bound, but
    /// every connection accepted by the spawned server is dropped because TLS
    /// is unavailable. See [the module-level documentation][tls-doc] for
    /// details.
    ///
    /// [tls-features]: crate#tls-features
    /// [tls-doc]: crate::server#tls-feature-flags
    pub async fn bind(self) -> Result<Bound, Error> {
        let tls = {
            let key = self.server_tls_key.ok_or(Error::NoTlsKey)?;
            let certs = self.server_tls_certs.ok_or(Error::NoTlsCerts)?;

            // Ensure the TLS key and certificate files load properly before binding the socket and
            // spawning the server. rustls takes precedence when both backends are enabled.
            #[cfg(all(not(feature = "rustls-tls"), feature = "openssl-tls"))]
            let _ = tls_openssl::load_tls(&key, &certs).await?;
            #[cfg(feature = "rustls-tls")]
            let _ = tls_rustls::load_tls(&key, &certs).await?;

            Arc::new(TlsPaths { key, certs })
        };

        // `SocketAddr` is `Copy`, so it stays usable for the error below.
        let tcp = TcpListener::bind(self.server_addr)
            .await
            .map_err(|e| Error::Bind(self.server_addr, e))?;
        let local_addr = tcp.local_addr().map_err(Error::LocalAddr)?;
        Ok(Bound {
            local_addr,
            tcp,
            tls,
        })
    }
}

impl Bound {
    /// Returns the local address the server's listening socket is bound to.
    ///
    /// If the configured address used port `0`, this reports the port that
    /// was actually assigned.
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// Starts serving HTTPS on the already-bound socket, handling every
    /// accepted connection with `service` on a background Tokio task.
    ///
    /// The socket was bound by [`ServerArgs::bind`]; this method only starts
    /// accepting connections on it. Once `drain` is signaled, the server stops
    /// accepting new connections, and open connections are asked to shut down
    /// gracefully after their in-flight requests complete.
    ///
    /// TLS credentials are read from the configured paths _for each connection_ to support
    /// certificate rotation. As such, it is not recommended to expose this server to the open
    /// internet or to clients that open many short-lived connections. It is primarily intended for
    /// kubernetes admission controllers.
    ///
    /// # Panics
    ///
    /// Panics if called outside the context of a Tokio runtime, because the
    /// accept loop is started with [`tokio::spawn`].
    pub fn spawn<S, B>(self, service: S, drain: drain::Watch) -> SpawnedServer
    where
        S: Service<hyper::Request<hyper::body::Incoming>, Response = hyper::Response<B>>
            + Clone
            + Send
            + 'static,
        S::Error: std::error::Error + Send + Sync,
        S::Future: Send,
        B: hyper::body::Body + Send + 'static,
        B::Data: Send,
        B::Error: std::error::Error + Send + Sync,
    {
        let Self {
            local_addr,
            tcp,
            tls,
        } = self;

        let span = info_span!("server", port = %local_addr.port());
        let task = tokio::spawn(accept_loop(tcp, drain, service, tls).instrument(span));

        SpawnedServer { local_addr, task }
    }
}

// === impl SpawnedServer ===

impl SpawnedServer {
    /// The address the spawned server's listening socket is bound to.
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr, .. } = self;
        *local_addr
    }

    /// Cancels the server's accept-loop task immediately, without draining.
    ///
    /// Only the accept loop is aborted. Connection tasks that were already
    /// spawned with `tokio::spawn` are independent tasks and are not
    /// cancelled by this call.
    pub fn abort(&self) {
        let Self { task, .. } = self;
        task.abort();
    }

    /// Consumes the handle and waits for the accept-loop task to finish.
    ///
    /// Yields a [`tokio::task::JoinError`] if the task panicked or was aborted.
    pub async fn join(self) -> Result<(), tokio::task::JoinError> {
        let Self { task, .. } = self;
        task.await
    }
}

/// Accepts TCP connections on `tcp` until `drain` is signaled, spawning a
/// `serve_conn` task (instrumented with the client's address) per connection.
///
/// Failures to accept, to set `TCP_NODELAY`, or to read the peer address are
/// logged and skip only the affected connection; the loop keeps running.
async fn accept_loop<S, B>(tcp: TcpListener, drain: drain::Watch, service: S, tls: Arc<TlsPaths>)
where
    S: Service<hyper::Request<hyper::body::Incoming>, Response = hyper::Response<B>>
        + Clone
        + Send
        + 'static,
    S::Error: std::error::Error + Send + Sync,
    S::Future: Send,
    B: hyper::body::Body + Send + 'static,
    B::Data: Send,
    B::Error: std::error::Error + Send + Sync,
{
    // Pause applied after an accept error that is not tied to a single
    // connection. Without it, a persistent failure such as file-descriptor
    // exhaustion turns this loop into a busy spin that also floods the log.
    const ACCEPT_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_millis(100);

    tracing::debug!("listening");
    loop {
        tracing::trace!("accepting");
        // Wait for the shutdown to be signaled or for the next connection to be accepted.
        // `biased` checks the drain signal first, so shutdown wins over a pending accept.
        let socket = tokio::select! {
            biased;

            release = drain.clone().signaled() => {
                drop(release);
                return;
            }

            res = tcp.accept() => match res {
                Ok((socket, _)) => socket,
                Err(error) => {
                    error!(%error, "Failed to accept connection");
                    // Errors that concern one failed connection can be retried at
                    // once; anything else is likely to recur on the next call.
                    let per_connection = matches!(
                        error.kind(),
                        std::io::ErrorKind::ConnectionRefused
                            | std::io::ErrorKind::ConnectionAborted
                            | std::io::ErrorKind::ConnectionReset
                    );
                    if !per_connection {
                        tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
                    }
                    continue;
                }
            },
        };

        if let Err(error) = socket.set_nodelay(true) {
            error!(%error, "Failed to set TCP_NODELAY");
            continue;
        }

        let client_addr = match socket.peer_addr() {
            Ok(addr) => addr,
            Err(error) => {
                error!(%error, "Failed to get peer address");
                continue;
            }
        };

        tokio::spawn(
            serve_conn(socket, drain.clone(), service.clone(), tls.clone()).instrument(info_span!(
                "conn",
                client.ip = %client_addr.ip(),
                client.port = %client_addr.port(),
            )),
        );
    }
}

/// Serves one accepted TCP connection over HTTPS.
///
/// The TLS credentials are reloaded from `tls` for this connection, so rotated
/// certificates are picked up. The TLS handshake is then performed, and the
/// resulting stream is served with `service` via hyper-util's `auto`
/// connection builder. Failures to load credentials or to complete the
/// handshake are logged, and the connection is dropped.
///
/// Returns as soon as the connection finishes on its own. If `drain` is
/// signaled first, the connection is asked to shut down gracefully, and the
/// drain is held open until it has finished.
async fn serve_conn<S, B>(socket: TcpStream, drain: drain::Watch, service: S, tls: Arc<TlsPaths>)
where
    S: Service<hyper::Request<hyper::body::Incoming>, Response = hyper::Response<B>>
        + Clone
        + Send
        + 'static,
    S::Error: std::error::Error + Send + Sync,
    S::Future: Send,
    B: hyper::body::Body + Send + 'static,
    B::Data: Send,
    B::Error: std::error::Error + Send + Sync,
{
    tracing::debug!("accepted TCP connection");

    let socket = {
        let TlsPaths { key, certs } = &*tls;
        // Reload the TLS credentials for each connection so that rotated
        // certificates take effect without restarting the server.

        #[cfg(all(not(feature = "rustls-tls"), feature = "openssl-tls"))]
        let res = tls_openssl::load_tls(key, certs).await;
        #[cfg(feature = "rustls-tls")]
        let res = tls_rustls::load_tls(key, certs).await;
        #[cfg(not(any(feature = "rustls-tls", feature = "openssl-tls")))]
        let res = {
            // Uninhabited type: without a TLS backend there is never an
            // acceptor value, only this error.
            enum Accept {}
            Err::<Accept, _>(std::io::Error::other("TLS support not enabled"))
        };
        let tls = match res {
            Ok(tls) => tls,
            Err(error) => {
                info!(%error, "Connection failed");
                return;
            }
        };
        tracing::trace!("loaded TLS credentials");

        #[cfg(all(not(feature = "rustls-tls"), feature = "openssl-tls"))]
        let res = tls_openssl::accept(&tls, socket).await;
        #[cfg(feature = "rustls-tls")]
        let res = tls_rustls::accept(&tls, socket).await;
        #[cfg(not(any(feature = "rustls-tls", feature = "openssl-tls")))]
        let res = Err::<TcpStream, _>(std::io::Error::other("TLS support not enabled"));
        let socket = match res {
            Ok(s) => s,
            Err(error) => {
                info!(%error, "TLS handshake failed");
                return;
            }
        };
        tracing::trace!("TLS handshake completed");

        socket
    };

    // Executor used by hyper for any background work it spawns (e.g. HTTP/2
    // streams); keeps that work inside the current connection's span.
    #[derive(Copy, Clone, Debug)]
    struct Executor;
    impl<F> hyper::rt::Executor<F> for Executor
    where
        F: std::future::Future + Send + 'static,
        F::Output: Send + 'static,
    {
        fn execute(&self, fut: F) {
            tokio::spawn(fut.in_current_span());
        }
    }

    // NOTE(review): `tower_http::decompression::Decompression` is client-side
    // middleware. It decompresses *response* bodies and adds an
    // `Accept-Encoding` request header. Wrapped around `Compression`, it
    // appears to undo the response compression, and it does not decompress
    // request bodies (that is `RequestDecompression`, which changes the request
    // body type seen by `service`). Confirm the intended behavior.
    #[cfg(any(feature = "server-brotli", feature = "server-gzip"))]
    let service = tower_http::decompression::Decompression::new(
        tower_http::compression::Compression::new(service),
    );

    let mut builder = hyper_util::server::conn::auto::Builder::new(Executor);
    // Prevent port scanners, etc, from holding connections open.
    builder
        .http1()
        .header_read_timeout(std::time::Duration::from_secs(2))
        .timer(hyper_util::rt::TokioTimer::default());
    let graceful = hyper_util::server::graceful::GracefulShutdown::new();
    let conn = graceful.watch(
        builder
            .serve_connection(
                hyper_util::rt::TokioIo::new(socket),
                hyper_util::service::TowerToHyperService::new(service),
            )
            .into_owned(),
    );
    let conn_task = tokio::spawn(
        async move {
            match conn.await {
                Ok(()) => debug!("Connection closed"),
                Err(error) => info!(%error, "Connection lost"),
            }
        }
        .in_current_span(),
    );

    // Wait for whichever happens first: the connection ending on its own, or
    // the drain signal. Returning as soon as the connection ends matters.
    // Otherwise this task, with its `drain::Watch` and `GracefulShutdown`,
    // would stay parked until shutdown for every connection ever accepted.
    tokio::select! {
        res = conn_task => {
            // The connection task only fails if it panicked or was cancelled.
            if let Err(error) = res {
                error!(%error, "Connection task failed");
            }
        }
        release = drain.signaled() => {
            // Tell the connection to terminate gracefully once in-flight
            // requests have completed, and hold the drain open until it does.
            release.release_after(graceful.shutdown()).await;
        }
    }
}

// === impl TlsCertPath ===

impl FromStr for TlsCertPath {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        s.parse().map(Self)
    }
}

// === impl TlsKeyPath ===

impl FromStr for TlsKeyPath {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        s.parse().map(Self)
    }
}