use crate::bundle::spiffebundle;
use crate::bundle::x509bundle;
use crate::spiffeid::{self, TrustDomain};
use crate::spiffetls::tlsconfig;
use crate::workloadapi::Context;
use hyper::body::Body;
use hyper::service::Service;
use hyper::{Request, Response, StatusCode};
use rustls::{Certificate, ClientConfig, RootCertStore, ServerName};
use std::io::{Read, Write};
use std::net::{IpAddr, TcpStream};
use std::sync::Arc;
use std::task::{Context as TaskContext, Poll};
use std::time::Duration;
use url::Url;
/// Error returned by the federation fetch, watch and handler helpers.
///
/// Wraps a human-readable message. Messages built through [`wrap_error`] carry
/// a `"federation: "` prefix.
#[derive(Debug, Clone)]
pub struct Error(String);

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully-qualified call: the method form `self.0.fmt(f)` needs a formatting
        // trait imported (and is ambiguous once both `Display` and `Debug` are).
        // Delegating to `str`'s `Display` keeps width/alignment flags honoured.
        std::fmt::Display::fmt(&self.0, f)
    }
}

impl std::error::Error for Error {}

/// Result alias used throughout the federation module.
pub type Result<T> = std::result::Result<T, Error>;

/// Wraps any displayable cause into an [`Error`], prefixing it with `"federation: "`.
fn wrap_error(message: impl std::fmt::Display) -> Error {
    Error(format!("federation: {}", message))
}
/// A configuration step applied to `FetchOptions` before a bundle is fetched.
///
/// `fetch_bundle` applies its options in order; an option may reject the
/// configuration accumulated so far by returning an error.
pub trait FetchOption {
    /// Applies this option to `options`.
    ///
    /// # Errors
    /// Returns an error if the option conflicts with one applied earlier.
    fn apply(&self, options: &mut FetchOptions) -> Result<()>;
}

/// Authenticates the bundle endpoint with SPIFFE instead of Web PKI.
///
/// The TLS client config is built by `tlsconfig::tls_client_config` from
/// `bundle_source`, with the server authorized through
/// `tlsconfig::authorize_id(endpoint_id)`. Presumably this accepts only that
/// SPIFFE ID; confirm in `tlsconfig`.
///
/// Applying the option fails if an authentication method has already been
/// selected (by this option or by `with_web_pki_roots`).
pub fn with_spiffe_auth(
    bundle_source: Arc<dyn x509bundle::Source + Send + Sync>,
    endpoint_id: spiffeid::ID,
) -> impl FetchOption {
    FetchOptionFn(move |options: &mut FetchOptions| -> Result<()> {
        ensure_auth_method_unset(options)?;
        options.auth_method = AuthMethod::Spiffe {
            bundle_source: Arc::clone(&bundle_source),
            endpoint_id: endpoint_id.clone(),
        };
        Ok(())
    })
}

/// Authenticates the bundle endpoint with Web PKI against `roots` instead of
/// the operating system's root store.
///
/// Applying the option fails if an authentication method has already been
/// selected (by this option or by `with_spiffe_auth`).
pub fn with_web_pki_roots(roots: RootCertStore) -> impl FetchOption {
    FetchOptionFn(move |options: &mut FetchOptions| -> Result<()> {
        ensure_auth_method_unset(options)?;
        options.auth_method = AuthMethod::WebPki {
            roots: roots.clone(),
        };
        Ok(())
    })
}

/// Rejects a second authentication option: SPIFFE and Web PKI authentication
/// are mutually exclusive, so only the untouched default may be overridden.
fn ensure_auth_method_unset(options: &FetchOptions) -> Result<()> {
    match options.auth_method {
        AuthMethod::Default => Ok(()),
        _ => Err(wrap_error(
            "cannot use both SPIFFE and Web PKI authentication",
        )),
    }
}
pub fn fetch_bundle(
trust_domain: TrustDomain,
url: &str,
options: &[Box<dyn FetchOption>],
) -> Result<spiffebundle::Bundle> {
let mut opts = FetchOptions::default();
for option in options {
option.apply(&mut opts)?;
}
let parsed = Url::parse(url).map_err(|err| wrap_error(format!("invalid URL: {}", err)))?;
let body = fetch_url(&parsed, &opts)?;
spiffebundle::Bundle::parse(trust_domain, &body).map_err(|err| wrap_error(err))
}
/// Receives the outcome of each fetch performed by [`watch_bundle`].
pub trait BundleWatcher: Send + Sync {
    /// Returns how long to wait before the next fetch.
    ///
    /// `refresh_hint` is the refresh hint of the last bundle delivered through
    /// `on_update`. It is zero when no bundle has been delivered yet or that
    /// bundle has no hint.
    fn next_refresh(&self, refresh_hint: Duration) -> Duration;
    /// Called with a fetched bundle whenever it differs from the last one delivered.
    fn on_update(&self, bundle: spiffebundle::Bundle);
    /// Called with each fetch error; the watch keeps going afterwards.
    fn on_error(&self, err: Error);
}

/// Repeatedly fetches the bundle for `trust_domain` from `url` until `ctx` is
/// cancelled, reporting results to `watcher`.
///
/// A failed fetch goes to `on_error`. A successful one goes to `on_update` only
/// when it differs (per `Bundle::equal`) from the last delivered bundle, and
/// the first success always counts as a change. Between fetches the loop
/// sleeps for `watcher.next_refresh(hint)`.
///
/// # Errors
/// Returns a "context canceled" error once `ctx` is cancelled while waiting
/// for the next refresh. Fetch errors never end the watch.
pub async fn watch_bundle(
    ctx: &Context,
    trust_domain: TrustDomain,
    url: &str,
    watcher: Arc<dyn BundleWatcher>,
    options: Vec<Box<dyn FetchOption>>,
) -> Result<()> {
    let mut current: Option<spiffebundle::Bundle> = None;
    loop {
        // NOTE(review): `fetch_bundle` performs blocking socket I/O, so each fetch
        // stalls this executor thread and cannot be interrupted by `ctx`.
        // Consider `tokio::task::spawn_blocking` once the option types are `Send`.
        match fetch_bundle(trust_domain.clone(), url, &options) {
            Err(err) => watcher.on_error(err),
            Ok(bundle) => {
                let unchanged = match &current {
                    Some(previous) => previous.equal(&bundle),
                    None => false,
                };
                if !unchanged {
                    watcher.on_update(bundle.clone_bundle());
                    current = Some(bundle);
                }
            }
        }

        let hint = match &current {
            Some(bundle) => bundle.refresh_hint().unwrap_or_default(),
            None => Duration::default(),
        };
        let delay = watcher.next_refresh(hint);
        tokio::select! {
            _ = tokio::time::sleep(delay) => {},
            _ = ctx.cancelled() => return Err(wrap_error("context canceled")),
        }
    }
}
/// How the bundle endpoint's TLS server certificate is authenticated.
#[derive(Clone)]
enum AuthMethod {
    /// No option selected: Web PKI against the operating system's root store.
    Default,
    /// Selected by `with_spiffe_auth`.
    Spiffe {
        bundle_source: Arc<dyn x509bundle::Source + Send + Sync>,
        endpoint_id: spiffeid::ID,
    },
    /// Selected by `with_web_pki_roots`: Web PKI against caller-supplied roots.
    WebPki {
        roots: RootCertStore,
    },
}

/// Settings accumulated from `FetchOption`s for a single fetch.
#[doc(hidden)]
pub struct FetchOptions {
    auth_method: AuthMethod,
}

impl Default for FetchOptions {
    /// Starts with no authentication option selected.
    fn default() -> Self {
        FetchOptions {
            auth_method: AuthMethod::Default,
        }
    }
}

/// Adapts a closure into a `FetchOption`.
struct FetchOptionFn<F>(F);

impl<F: Fn(&mut FetchOptions) -> Result<()> + Send + Sync> FetchOption for FetchOptionFn<F> {
    fn apply(&self, options: &mut FetchOptions) -> Result<()> {
        let FetchOptionFn(callback) = self;
        callback(options)
    }
}
/// A configuration step applied to `HandlerConfig` by [`new_handler`].
pub trait HandlerOption {
    /// Applies this option to `config`.
    ///
    /// # Errors
    /// Returns an error if the option cannot be applied.
    fn apply(&self, config: &mut HandlerConfig) -> Result<()>;
}

/// Returns an option that makes the handler report failures to `log`,
/// replacing the default `crate::logger::null_logger()`.
pub fn with_handler_logger(log: crate::workloadapi::LoggerRef) -> Box<dyn HandlerOption> {
    let set_logger = move |config: &mut HandlerConfig| -> Result<()> {
        config.log = log.clone();
        Ok(())
    };
    Box::new(HandlerOptionFn(set_logger))
}

/// Builds a [`BundleHandler`] that serves the bundle `source` returns for
/// `trust_domain`.
///
/// Options are applied in order on top of a configuration that logs to
/// `crate::logger::null_logger()`.
///
/// # Errors
/// Returns the first error produced by an option.
pub fn new_handler(
    trust_domain: TrustDomain,
    source: Arc<dyn spiffebundle::Source + Send + Sync>,
    options: Vec<Box<dyn HandlerOption>>,
) -> Result<BundleHandler> {
    let mut config = HandlerConfig {
        log: Arc::new(crate::logger::null_logger()),
    };
    options
        .into_iter()
        .try_for_each(|option| option.apply(&mut config))?;

    let HandlerConfig { log } = config;
    Ok(BundleHandler {
        trust_domain,
        source,
        log,
    })
}
/// HTTP handler that serves one trust domain's SPIFFE bundle.
///
/// Build it with [`new_handler`].
/// - Non-`GET` requests get `405 Method Not Allowed` with `Allow: GET`.
/// - Otherwise the bundle that `source` returns for `trust_domain` is
///   marshalled and sent as `application/json` with `200 OK`.
/// - If the bundle cannot be obtained or marshalled, the cause is logged and
///   `500 Internal Server Error` is returned.
pub struct BundleHandler {
    trust_domain: TrustDomain,
    source: Arc<dyn spiffebundle::Source + Send + Sync>,
    log: crate::workloadapi::LoggerRef,
}

impl Service<Request<Body>> for BundleHandler {
    type Response = Response<Body>;
    type Error = hyper::Error;
    type Future = std::future::Ready<std::result::Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut TaskContext<'_>) -> Poll<std::result::Result<(), Self::Error>> {
        // No per-request resources to acquire: always ready.
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        if req.method() != hyper::Method::GET {
            return std::future::ready(Ok(method_not_allowed_response()));
        }
        let response = match self.source.get_bundle_for_trust_domain(self.trust_domain.clone()) {
            Ok(bundle) => match bundle.marshal() {
                Ok(body) => Response::builder()
                    .status(StatusCode::OK)
                    .header("Content-Type", "application/json")
                    .body(Body::from(body))
                    .expect("static 200 response headers are valid"),
                Err(err) => {
                    self.log.errorf(format_args!(
                        "unable to marshal bundle for trust domain {:?}: {}",
                        self.trust_domain, err
                    ));
                    bundle_unavailable_response(&self.trust_domain)
                }
            },
            Err(err) => {
                self.log.errorf(format_args!(
                    "unable to get bundle for trust domain {:?}: {}",
                    self.trust_domain, err
                ));
                bundle_unavailable_response(&self.trust_domain)
            }
        };
        std::future::ready(Ok(response))
    }
}

/// Builds the `405` reply sent for any method other than `GET`.
fn method_not_allowed_response() -> Response<Body> {
    // RFC 7231 §6.5.5: a 405 response must carry an `Allow` header listing the
    // methods the resource supports.
    Response::builder()
        .status(StatusCode::METHOD_NOT_ALLOWED)
        .header("Allow", "GET")
        .body(Body::from("method is not allowed"))
        .expect("static 405 response headers are valid")
}

/// Builds the `500` reply sent when the bundle for `trust_domain` cannot be
/// served. The detailed cause is logged by the caller, not echoed to the client.
fn bundle_unavailable_response(trust_domain: &TrustDomain) -> Response<Body> {
    Response::builder()
        .status(StatusCode::INTERNAL_SERVER_ERROR)
        .body(Body::from(format!(
            "unable to serve bundle for {:?}",
            trust_domain
        )))
        .expect("static 500 response headers are valid")
}
/// Settings accumulated from `HandlerOption`s while building a `BundleHandler`.
#[doc(hidden)]
pub struct HandlerConfig {
    /// Logger handed to the built handler, which reports serving failures to it.
    log: crate::workloadapi::LoggerRef,
}

/// Adapts a closure into a `HandlerOption`.
struct HandlerOptionFn<F>(F);

impl<F: Fn(&mut HandlerConfig) -> Result<()> + Send + Sync> HandlerOption for HandlerOptionFn<F> {
    fn apply(&self, config: &mut HandlerConfig) -> Result<()> {
        let HandlerOptionFn(callback) = self;
        callback(config)
    }
}
fn fetch_url(url: &Url, options: &FetchOptions) -> Result<Vec<u8>> {
let host = url
.host_str()
.ok_or_else(|| wrap_error("URL missing host"))?;
let port = url
.port_or_known_default()
.ok_or_else(|| wrap_error("URL missing port"))?;
let addr = format!("{}:{}", host, port);
let mut stream = match url.scheme() {
"https" => {
let server_name = server_name_for_host(host)?;
let tls_config = tls_config_for_auth(options)?;
let tcp = TcpStream::connect(&addr).map_err(|err| wrap_error(err))?;
let conn = rustls::ClientConnection::new(Arc::new(tls_config), server_name)
.map_err(|err| wrap_error(format!("unable to create TLS connection: {}", err)))?;
HttpStream::Tls(rustls::StreamOwned::new(conn, tcp))
}
"http" => HttpStream::Plain(
TcpStream::connect(&addr).map_err(|err| wrap_error(err))?,
),
scheme => {
return Err(wrap_error(format!("unsupported URL scheme: {}", scheme)));
}
};
let path = match (url.path(), url.query()) {
("", None) => "/".to_string(),
(path, None) => path.to_string(),
(path, Some(query)) => format!("{}?{}", path, query),
};
let request = format!(
"GET {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\nAccept: */*\r\n\r\n",
path, host
);
stream
.write_all(request.as_bytes())
.map_err(|err| wrap_error(err))?;
stream.flush().map_err(|err| wrap_error(err))?;
let mut response = Vec::new();
stream.read_to_end(&mut response).map_err(|err| wrap_error(err))?;
parse_http_body(&response)
}
fn tls_config_for_auth(options: &FetchOptions) -> Result<ClientConfig> {
match &options.auth_method {
AuthMethod::Default => tlsconfig::webpki_client_config(Some(system_roots()?)),
AuthMethod::WebPki { roots } => tlsconfig::webpki_client_config(Some(roots.clone())),
AuthMethod::Spiffe {
bundle_source,
endpoint_id,
} => tlsconfig::tls_client_config(
bundle_source.clone(),
tlsconfig::authorize_id(endpoint_id.clone()),
),
}
.map_err(|err| wrap_error(err))
}
/// Builds a root store from the operating system's trusted CA certificates.
///
/// Certificates rustls cannot parse are skipped instead of failing the whole
/// load. Native stores often contain a few roots that webpki rejects, and one
/// such entry must not disable default Web PKI fetching. Skipping roots can
/// only make verification stricter.
///
/// # Errors
/// Returns an error if the native store cannot be read or yields no usable
/// root certificate.
fn system_roots() -> Result<RootCertStore> {
    let native = rustls_native_certs::load_native_certs()
        .map_err(|err| wrap_error(format!("unable to load native certs: {}", err)))?;

    let mut roots = RootCertStore::empty();
    let mut invalid = 0usize;
    for cert in native {
        if roots.add(&Certificate(cert.as_ref().to_vec())).is_err() {
            invalid += 1;
        }
    }

    if roots.is_empty() {
        return Err(wrap_error(format!(
            "no valid native root certificates found ({} invalid)",
            invalid
        )));
    }
    Ok(roots)
}
/// Extracts the body of a raw HTTP/1.x response, requiring status `200`.
///
/// `response` is the full byte stream read until the server closed the
/// connection. The body is framed according to the response headers:
/// - `Transfer-Encoding: chunked`: chunks are reassembled, and chunk framing,
///   extensions and trailer fields are dropped. Servers commonly use this for
///   HTTP/1.1 replies without a precomputed length.
/// - otherwise `Content-Length`, when present, bounds the body;
/// - otherwise every byte after the header block is the body.
///
/// # Errors
/// Returns an error if:
/// - the header block or status line is malformed, or the status is not `200`;
/// - `Content-Length` is invalid or larger than the received data;
/// - the chunked framing is malformed.
fn parse_http_body(response: &[u8]) -> Result<Vec<u8>> {
    let split = response
        .windows(4)
        .position(|window| window == b"\r\n\r\n")
        .ok_or_else(|| wrap_error("invalid HTTP response"))?;
    let header = String::from_utf8_lossy(&response[..split]);
    let payload = &response[split + 4..];

    // `str::lines` also strips the `\r` of each `\r\n` line ending.
    let mut header_lines = header.lines();
    let status_line = header_lines
        .next()
        .ok_or_else(|| wrap_error("invalid HTTP response"))?;
    let mut parts = status_line.split_whitespace();
    let _proto = parts.next().ok_or_else(|| wrap_error("invalid HTTP response"))?;
    let status = parts.next().ok_or_else(|| wrap_error("invalid HTTP response"))?;
    if status != "200" {
        return Err(wrap_error(format!("unexpected HTTP status {}", status)));
    }

    let mut chunked = false;
    let mut content_length: Option<usize> = None;
    for line in header_lines {
        let (name, value) = match line.split_once(':') {
            Some((name, value)) => (name.trim(), value.trim()),
            None => continue,
        };
        if name.eq_ignore_ascii_case("transfer-encoding") {
            // When `chunked` is applied it must be the final transfer coding.
            chunked = value
                .rsplit(',')
                .next()
                .map_or(false, |coding| coding.trim().eq_ignore_ascii_case("chunked"));
        } else if name.eq_ignore_ascii_case("content-length") {
            let length = value
                .parse::<usize>()
                .map_err(|_| wrap_error(format!("invalid Content-Length header: {}", value)))?;
            content_length = Some(length);
        }
    }

    // Transfer-Encoding takes precedence over Content-Length (RFC 7230 §3.3.3).
    if chunked {
        return decode_chunked_body(payload);
    }
    match content_length {
        Some(length) => payload
            .get(..length)
            .map(<[u8]>::to_vec)
            .ok_or_else(|| wrap_error("truncated HTTP response body")),
        None => Ok(payload.to_vec()),
    }
}

/// Reassembles a `Transfer-Encoding: chunked` body (RFC 7230 §4.1).
///
/// Chunk extensions after the size are ignored. Everything after the
/// terminating zero-size chunk (trailer fields) is discarded.
fn decode_chunked_body(mut remaining: &[u8]) -> Result<Vec<u8>> {
    let mut body = Vec::new();
    loop {
        let line_end = remaining
            .windows(2)
            .position(|window| window == b"\r\n")
            .ok_or_else(|| wrap_error("invalid chunked HTTP response body"))?;
        let size_line = String::from_utf8_lossy(&remaining[..line_end]);
        let size_field = size_line.split(';').next().unwrap_or("").trim();
        let size = usize::from_str_radix(size_field, 16).map_err(|_| {
            wrap_error(format!(
                "invalid chunk size in HTTP response body: {:?}",
                size_field
            ))
        })?;
        remaining = &remaining[line_end + 2..];

        if size == 0 {
            return Ok(body);
        }
        let chunk = remaining
            .get(..size)
            .ok_or_else(|| wrap_error("truncated chunked HTTP response body"))?;
        body.extend_from_slice(chunk);
        // Each chunk's data is followed by its own CRLF.
        remaining = remaining[size..]
            .strip_prefix(b"\r\n")
            .ok_or_else(|| wrap_error("invalid chunked HTTP response body"))?;
    }
}
/// Converts a URL host into the rustls `ServerName` used for TLS verification.
///
/// IP literals become `ServerName::IpAddress`. `Url::host_str` returns IPv6
/// literals in bracketed form (`"[::1]"`), so the brackets are stripped before
/// parsing. Any other host is validated as a DNS name.
///
/// # Errors
/// Returns an error if `host` is neither an IP literal nor a valid DNS name.
fn server_name_for_host(host: &str) -> Result<ServerName> {
    let unbracketed = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    if let Ok(ip) = unbracketed.parse::<IpAddr>() {
        return Ok(ServerName::IpAddress(ip));
    }
    ServerName::try_from(host).map_err(|err| wrap_error(format!("invalid server name: {}", err)))
}
/// Connection to the bundle endpoint: plain TCP for `http`, or a rustls client
/// session over TCP for `https`.
enum HttpStream {
    Plain(TcpStream),
    Tls(rustls::StreamOwned<rustls::ClientConnection, TcpStream>),
}

impl HttpStream {
    /// Borrows whichever transport is active as a reader.
    fn reader(&mut self) -> &mut dyn Read {
        match self {
            HttpStream::Plain(tcp) => tcp,
            HttpStream::Tls(tls) => tls,
        }
    }

    /// Borrows whichever transport is active as a writer.
    fn writer(&mut self) -> &mut dyn Write {
        match self {
            HttpStream::Plain(tcp) => tcp,
            HttpStream::Tls(tls) => tls,
        }
    }
}

impl Read for HttpStream {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.reader().read(buf)
    }
}

impl Write for HttpStream {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.writer().write(buf)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.writer().flush()
    }
}