use crate::pb::spire::api::agent::delegatedidentity::v1::delegated_identity_client::DelegatedIdentityClient as DelegatedIdentityApiClient;
use crate::pb::spire::api::agent::delegatedidentity::v1::{
FetchJwtsviDsRequest, SubscribeToJwtBundlesRequest, SubscribeToJwtBundlesResponse,
SubscribeToX509BundlesRequest, SubscribeToX509BundlesResponse, SubscribeToX509sviDsRequest,
SubscribeToX509sviDsResponse,
};
use crate::pb::spire::api::types::Jwtsvid as ProtoJwtSvid;
use crate::selectors::Selector;
use spiffe::constants::DEFAULT_SVID;
use spiffe::transport::{Endpoint, TransportError};
use spiffe::{
JwtBundle, JwtBundleError, JwtBundleSet, JwtSvid, JwtSvidError, SpiffeIdError, TrustDomain,
X509Bundle, X509BundleError, X509BundleSet, X509Svid, X509SvidError,
};
use std::str::FromStr as _;
use futures::{Stream, StreamExt as _};
/// Name of the environment variable that [`admin_endpoint_from_env`] reads to locate the
/// admin endpoint socket used by [`DelegatedIdentityClient::connect_env`].
pub const ADMIN_SOCKET_ENV: &str = "SPIRE_ADMIN_ENDPOINT_SOCKET";
/// Errors produced by [`DelegatedIdentityClient`] and the helpers in this module.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum DelegatedIdentityError {
/// The [`ADMIN_SOCKET_ENV`] environment variable is not set.
#[error("missing admin endpoint socket path ({ADMIN_SOCKET_ENV})")]
MissingEndpointSocket,
/// The [`ADMIN_SOCKET_ENV`] environment variable is set but its value is not valid
/// UTF-8; the raw value is carried for diagnostics.
#[error("admin endpoint socket path is not a valid UTF-8 string: {}", .0.display())]
NotUnicodeEndpointSocket(std::ffi::OsString),
/// The endpoint string could not be parsed into an `Endpoint`.
#[error("invalid endpoint: {0}")]
Endpoint(#[from] spiffe::transport::EndpointError),
/// Transport-level failure. This also receives gRPC error statuses and tonic transport
/// errors through the `From<tonic::Status>` and `From<tonic::transport::Error>` impls.
#[error(transparent)]
Transport(#[from] TransportError),
/// A response carried no usable data: a subscription stream ended before its first
/// message, or an expected SVID / certificate entry was missing from a message.
#[error("empty response")]
EmptyResponse,
/// A JWT-SVID token returned by the agent could not be parsed.
#[error("JWT SVID error: {0}")]
JwtSvid(#[from] JwtSvidError),
/// An X.509 bundle could not be parsed from the returned certificate data.
#[error("X.509 bundle error: {0}")]
X509Bundle(#[from] X509BundleError),
/// An X.509-SVID could not be parsed from the returned chain and key.
#[error("X.509 SVID error: {0}")]
X509Svid(#[from] X509SvidError),
/// A JWT bundle could not be built from the returned authorities.
#[error("JWT bundle error: {0}")]
JwtBundle(#[from] JwtBundleError),
/// A SPIFFE identifier component (e.g. a trust-domain name in a bundle map) was invalid.
#[error("SPIFFE ID error: {0}")]
SpiffeId(#[from] SpiffeIdError),
}
/// Reads [`ADMIN_SOCKET_ENV`] from the process environment and parses it as an [`Endpoint`].
///
/// # Errors
///
/// - [`DelegatedIdentityError::MissingEndpointSocket`] if the variable is not set.
/// - [`DelegatedIdentityError::NotUnicodeEndpointSocket`] if its value is not valid UTF-8.
/// - Any error from [`Endpoint::parse`], converted into [`DelegatedIdentityError`].
pub fn admin_endpoint_from_env() -> Result<Endpoint, DelegatedIdentityError> {
    let Some(value) = std::env::var_os(ADMIN_SOCKET_ENV) else {
        return Err(DelegatedIdentityError::MissingEndpointSocket);
    };
    // `into_string` hands back the original `OsString` on failure, so the raw value is kept.
    let socket_path = value
        .into_string()
        .map_err(DelegatedIdentityError::NotUnicodeEndpointSocket)?;
    Ok(Endpoint::parse(socket_path.as_str())?)
}
/// Client for the SPIRE agent's Delegated Identity API (`delegatedidentity::v1`).
///
/// Every request method clones the inner generated client before issuing its RPC, so the
/// methods only need `&self`.
#[derive(Debug, Clone)]
pub struct DelegatedIdentityClient {
// Generated gRPC client running over a `tonic` channel.
client: DelegatedIdentityApiClient<tonic::transport::Channel>,
}
/// Describes which workload an SVID request is made on behalf of.
///
/// When turned into a protobuf request, `Pid` sends an empty selector list and
/// `Selectors` sends `pid: 0`; exactly one of the two is populated.
#[derive(Debug, Clone)]
pub enum DelegateAttestationRequest {
/// Identify the workload by process ID (presumably attested by the agent — confirm
/// against the SPIRE Delegated Identity API docs).
Pid(i32),
/// Identify the workload by an explicit list of selectors.
Selectors(Vec<Selector>),
}
impl DelegatedIdentityClient {
    /// Parses `endpoint` with [`Endpoint::parse`] and connects to it.
    ///
    /// # Errors
    ///
    /// Fails if the string is not a valid endpoint, or with any error from
    /// [`connect`](Self::connect).
    pub async fn connect_to(endpoint: impl AsRef<str>) -> Result<Self, DelegatedIdentityError> {
        let target = Endpoint::parse(endpoint.as_ref())?;
        Self::connect(target).await
    }

    /// Connects to the endpoint named by the [`ADMIN_SOCKET_ENV`] environment variable.
    ///
    /// # Errors
    ///
    /// Any error from [`admin_endpoint_from_env`] or [`connect`](Self::connect).
    pub async fn connect_env() -> Result<Self, DelegatedIdentityError> {
        Self::connect(admin_endpoint_from_env()?).await
    }

    /// Opens a channel to `endpoint` via `spiffe::transport::connector::connect` and wraps it.
    ///
    /// # Errors
    ///
    /// Any connection error, converted into [`DelegatedIdentityError`].
    pub async fn connect(endpoint: Endpoint) -> Result<Self, DelegatedIdentityError> {
        let channel = spiffe::transport::connector::connect(&endpoint).await?;
        Self::new(channel)
    }

    /// Wraps an already-established channel.
    ///
    /// This never returns `Err`; the `Result` is part of the public signature.
    pub fn new(conn: tonic::transport::Channel) -> Result<Self, DelegatedIdentityError> {
        let client = DelegatedIdentityApiClient::new(conn);
        Ok(Self { client })
    }
}
impl DelegatedIdentityClient {
    /// Fetches the X.509-SVID for the workload selected by `attest_type`.
    ///
    /// Opens an X.509-SVID subscription, reads only its first message and parses the SVID
    /// at index [`DEFAULT_SVID`] from it. Use [`stream_x509_svids`](Self::stream_x509_svids)
    /// to observe subsequent updates.
    ///
    /// # Errors
    ///
    /// - [`DelegatedIdentityError::Transport`] if the RPC fails or the first stream item is
    ///   a gRPC error status.
    /// - [`DelegatedIdentityError::EmptyResponse`] if the stream ends before yielding a
    ///   message, or the message has no SVID at [`DEFAULT_SVID`] or no certificate data.
    /// - A parse error from [`X509Svid::parse_from_der`], converted into
    ///   [`DelegatedIdentityError`].
    pub async fn fetch_x509_svid(
        &self,
        attest_type: DelegateAttestationRequest,
    ) -> Result<X509Svid, DelegatedIdentityError> {
        let request = make_x509svid_request(attest_type);
        self.client
            .clone()
            .subscribe_to_x509svi_ds(request)
            .await?
            .into_inner()
            .message()
            .await?
            .ok_or(DelegatedIdentityError::EmptyResponse)
            .and_then(|resp| Self::parse_x509_svid_from_grpc_response(&resp))
    }

    /// Subscribes to X.509-SVID updates for the workload selected by `attest_type`.
    ///
    /// Each stream item is the SVID at index [`DEFAULT_SVID`] of one subscription message,
    /// or the error encountered while receiving or parsing that message. The returned
    /// stream does not borrow `self`, so it may outlive the client or be moved into a
    /// spawned task (consistent with the other `stream_*` methods).
    ///
    /// # Errors
    ///
    /// [`DelegatedIdentityError::Transport`] if the subscription RPC itself fails.
    pub async fn stream_x509_svids(
        &self,
        attest_type: DelegateAttestationRequest,
    ) -> Result<
        impl Stream<Item = Result<X509Svid, DelegatedIdentityError>> + Send + 'static + use<>,
        DelegatedIdentityError,
    > {
        // Shares request construction with `fetch_x509_svid` instead of duplicating it.
        let request = make_x509svid_request(attest_type);
        let response = self.client.clone().subscribe_to_x509svi_ds(request).await?;
        Ok(response.into_inner().map(|message| {
            message
                .map_err(DelegatedIdentityError::from)
                .and_then(|resp| Self::parse_x509_svid_from_grpc_response(&resp))
        }))
    }

    /// Fetches the current set of X.509 bundles.
    ///
    /// Reads only the first message of an X.509 bundle subscription and parses every
    /// trust-domain entry in it.
    ///
    /// # Errors
    ///
    /// - [`DelegatedIdentityError::Transport`] on RPC failure or a gRPC error item.
    /// - [`DelegatedIdentityError::EmptyResponse`] if the stream ends before a message.
    /// - Trust-domain or bundle parse errors, converted into [`DelegatedIdentityError`].
    pub async fn fetch_x509_bundles(&self) -> Result<X509BundleSet, DelegatedIdentityError> {
        let request = SubscribeToX509BundlesRequest::default();
        let response = self
            .client
            .clone()
            .subscribe_to_x509_bundles(request)
            .await?;
        let initial = response
            .into_inner()
            .message()
            .await?
            .ok_or(DelegatedIdentityError::EmptyResponse)?;
        Self::parse_x509_bundle_set_from_grpc_response(initial)
    }

    /// Subscribes to X.509 bundle updates.
    ///
    /// Each stream item is the bundle set parsed from one subscription message, or the
    /// error encountered while receiving or parsing it. The stream does not borrow `self`.
    ///
    /// # Errors
    ///
    /// [`DelegatedIdentityError::Transport`] if the subscription RPC itself fails.
    pub async fn stream_x509_bundles(
        &self,
    ) -> Result<
        impl Stream<Item = Result<X509BundleSet, DelegatedIdentityError>> + Send + 'static + use<>,
        DelegatedIdentityError,
    > {
        let request = SubscribeToX509BundlesRequest::default();
        let response = self
            .client
            .clone()
            .subscribe_to_x509_bundles(request)
            .await?;
        Ok(response.into_inner().map(|msg| {
            msg.map_err(DelegatedIdentityError::from)
                .and_then(Self::parse_x509_bundle_set_from_grpc_response)
        }))
    }

    /// Fetches JWT-SVIDs for `audience` on behalf of the workload selected by `attest_type`.
    ///
    /// Every token in the response is parsed with `JwtSvid`'s `FromStr` implementation;
    /// parsing stops at the first failure.
    ///
    /// # Errors
    ///
    /// - [`DelegatedIdentityError::Transport`] if the RPC fails.
    /// - A token parse error, converted into [`DelegatedIdentityError`].
    pub async fn fetch_jwt_svids<T: AsRef<str> + Sync + ToString>(
        &self,
        audience: &[T],
        attest_type: DelegateAttestationRequest,
    ) -> Result<Vec<JwtSvid>, DelegatedIdentityError> {
        let request = make_jwtsvid_request(audience, attest_type);
        let resp = self
            .client
            .clone()
            .fetch_jwtsvi_ds(request)
            .await?
            .into_inner()
            .svids;
        Self::parse_jwt_svid_from_grpc_response(resp)
    }

    /// Subscribes to JWT bundle updates.
    ///
    /// Each stream item is the bundle set parsed from one subscription message, or the
    /// error encountered while receiving or parsing it. The stream does not borrow `self`.
    ///
    /// # Errors
    ///
    /// [`DelegatedIdentityError::Transport`] if the subscription RPC itself fails.
    pub async fn stream_jwt_bundles(
        &self,
    ) -> Result<
        impl Stream<Item = Result<JwtBundleSet, DelegatedIdentityError>> + Send + 'static + use<>,
        DelegatedIdentityError,
    > {
        let request = SubscribeToJwtBundlesRequest::default();
        let response = self
            .client
            .clone()
            .subscribe_to_jwt_bundles(request)
            .await?;
        Ok(response.into_inner().map(|msg| {
            msg.map_err(DelegatedIdentityError::from)
                .and_then(Self::parse_jwt_bundle_set_from_grpc_response)
        }))
    }

    /// Fetches the current set of JWT bundles.
    ///
    /// Reads only the first message of a JWT bundle subscription and parses every
    /// trust-domain entry in it.
    ///
    /// # Errors
    ///
    /// - [`DelegatedIdentityError::Transport`] on RPC failure or a gRPC error item.
    /// - [`DelegatedIdentityError::EmptyResponse`] if the stream ends before a message.
    /// - Trust-domain or bundle parse errors, converted into [`DelegatedIdentityError`].
    pub async fn fetch_jwt_bundles(&self) -> Result<JwtBundleSet, DelegatedIdentityError> {
        let request = SubscribeToJwtBundlesRequest::default();
        let response = self
            .client
            .clone()
            .subscribe_to_jwt_bundles(request)
            .await?;
        let initial = response
            .into_inner()
            .message()
            .await?
            .ok_or(DelegatedIdentityError::EmptyResponse)?;
        Self::parse_jwt_bundle_set_from_grpc_response(initial)
    }
}
impl DelegatedIdentityClient {
fn parse_x509_svid_from_grpc_response(
response: &SubscribeToX509sviDsResponse,
) -> Result<X509Svid, DelegatedIdentityError> {
let svid = response
.x509_svids
.get(DEFAULT_SVID)
.ok_or(DelegatedIdentityError::EmptyResponse)?;
let x509_svid = svid
.x509_svid
.as_ref()
.ok_or(DelegatedIdentityError::EmptyResponse)?;
let total_length: usize = x509_svid
.cert_chain
.iter()
.map(prost::bytes::Bytes::len)
.sum();
let mut cert_chain_bytes = Vec::with_capacity(total_length);
for c in &x509_svid.cert_chain {
cert_chain_bytes.extend_from_slice(c);
}
X509Svid::parse_from_der(&cert_chain_bytes, svid.x509_svid_key.as_ref()).map_err(Into::into)
}
fn parse_jwt_svid_from_grpc_response(
svids: Vec<ProtoJwtSvid>,
) -> Result<Vec<JwtSvid>, DelegatedIdentityError> {
svids
.into_iter()
.map(|r| JwtSvid::from_str(&r.token).map_err(DelegatedIdentityError::from))
.collect()
}
fn parse_jwt_bundle_set_from_grpc_response(
response: SubscribeToJwtBundlesResponse,
) -> Result<JwtBundleSet, DelegatedIdentityError> {
let mut bundle_set = JwtBundleSet::new();
for (td, bundle_data) in response.bundles {
let trust_domain = TrustDomain::try_from(td)?;
let bundle = JwtBundle::from_jwt_authorities(trust_domain, &bundle_data)
.map_err(DelegatedIdentityError::from)?;
bundle_set.add_bundle(bundle);
}
Ok(bundle_set)
}
fn parse_x509_bundle_set_from_grpc_response(
response: SubscribeToX509BundlesResponse,
) -> Result<X509BundleSet, DelegatedIdentityError> {
let mut bundle_set = X509BundleSet::new();
for (td, bundle) in response.ca_certificates {
let trust_domain = TrustDomain::try_from(td)?;
let parsed = X509Bundle::parse_from_der(trust_domain, &bundle)
.map_err(DelegatedIdentityError::from)?;
bundle_set.add_bundle(parsed);
}
Ok(bundle_set)
}
}
/// Treats a gRPC error status from the agent as a transport error.
impl From<tonic::Status> for DelegatedIdentityError {
    fn from(status: tonic::Status) -> Self {
        // Goes through the thiserror-generated `From<TransportError>` (the `Transport` variant).
        Self::from(TransportError::Status(status))
    }
}
/// Treats a tonic channel/transport failure as a transport error.
impl From<tonic::transport::Error> for DelegatedIdentityError {
    fn from(err: tonic::transport::Error) -> Self {
        // Goes through the thiserror-generated `From<TransportError>` (the `Transport` variant).
        Self::from(TransportError::Tonic(err))
    }
}
/// Builds an X.509-SVID subscription request.
///
/// Exactly one of `selectors` / `pid` is populated: a PID request sends no selectors, and
/// a selector request sends `pid: 0`.
fn make_x509svid_request(attest_type: DelegateAttestationRequest) -> SubscribeToX509sviDsRequest {
    let (selectors, pid) = match attest_type {
        DelegateAttestationRequest::Pid(pid) => (Vec::new(), pid),
        DelegateAttestationRequest::Selectors(selectors) => {
            (selectors.into_iter().map(Into::into).collect(), 0)
        }
    };
    SubscribeToX509sviDsRequest { selectors, pid }
}
/// Builds a JWT-SVID fetch request for `audience`.
///
/// Every audience entry is converted with `ToString`. As with the X.509 request, a PID
/// request sends no selectors and a selector request sends `pid: 0`.
fn make_jwtsvid_request<T: AsRef<str> + ToString>(
    audience: &[T],
    attest_type: DelegateAttestationRequest,
) -> FetchJwtsviDsRequest {
    let audience = audience.iter().map(ToString::to_string).collect();
    let (selectors, pid) = match attest_type {
        DelegateAttestationRequest::Pid(pid) => (Vec::new(), pid),
        DelegateAttestationRequest::Selectors(selectors) => {
            (selectors.into_iter().map(Into::into).collect(), 0)
        }
    };
    FetchJwtsviDsRequest {
        audience,
        selectors,
        pid,
    }
}