use crate::pb::spire::api::agent::delegatedidentity::v1::delegated_identity_client::DelegatedIdentityClient as DelegatedIdentityApiClient;
use crate::pb::spire::api::agent::delegatedidentity::v1::{
FetchJwtsviDsRequest, SubscribeToJwtBundlesRequest, SubscribeToJwtBundlesResponse,
SubscribeToX509BundlesRequest, SubscribeToX509BundlesResponse, SubscribeToX509sviDsRequest,
SubscribeToX509sviDsResponse,
};
use crate::pb::spire::api::types::Jwtsvid as ProtoJwtSvid;
use crate::selectors::Selector;
use spiffe::constants::DEFAULT_SVID;
use spiffe::error::GrpcClientError;
use spiffe::{
Endpoint, JwtBundle, JwtBundleSet, JwtSvid, TrustDomain, X509Bundle, X509BundleSet, X509Svid,
};
use std::str::FromStr;
use tokio_stream::{Stream, StreamExt};
/// Environment variable holding the SPIRE admin endpoint socket address.
pub const ADMIN_SOCKET_ENV: &str = "SPIRE_ADMIN_ENDPOINT_SOCKET";

/// Reads [`ADMIN_SOCKET_ENV`] and parses its value as an [`Endpoint`].
///
/// # Errors
///
/// Returns [`GrpcClientError::MissingEndpointSocket`] if the variable cannot be read
/// (unset or not valid Unicode), and a converted parse error if its value is not a
/// valid endpoint.
pub fn admin_endpoint_from_env() -> Result<Endpoint, GrpcClientError> {
    match std::env::var(ADMIN_SOCKET_ENV) {
        Ok(raw) => Ok(Endpoint::parse(&raw)?),
        Err(_) => Err(GrpcClientError::MissingEndpointSocket),
    }
}
/// Client for the SPIRE Agent Delegated Identity API.
///
/// Build one with [`DelegatedIdentityClient::connect`],
/// [`DelegatedIdentityClient::connect_to`], [`DelegatedIdentityClient::connect_env`],
/// or [`DelegatedIdentityClient::new`].
#[derive(Clone, Debug)]
pub struct DelegatedIdentityClient {
    /// Generated gRPC client running over a tonic transport channel.
    client: DelegatedIdentityApiClient<tonic::transport::Channel>,
}
/// Identifies the workload whose identity is being requested on its behalf.
#[derive(Clone, Debug)]
pub enum DelegateAttestationRequest {
    /// Identify the workload by process ID; sent as the request's `pid` with no selectors.
    Pid(i32),
    /// Identify the workload by an explicit selector list; sent with `pid` set to `0`.
    Selectors(Vec<Selector>),
}
impl DelegatedIdentityClient {
    /// Parses `endpoint` as an [`Endpoint`] and connects to it.
    ///
    /// # Errors
    ///
    /// Fails if the address cannot be parsed or the connection cannot be established.
    pub async fn connect_to(endpoint: impl AsRef<str>) -> Result<Self, GrpcClientError> {
        let parsed = Endpoint::parse(endpoint.as_ref())?;
        Self::connect(parsed).await
    }

    /// Connects to the endpoint read from the `SPIRE_ADMIN_ENDPOINT_SOCKET` environment
    /// variable (see [`admin_endpoint_from_env`]).
    ///
    /// # Errors
    ///
    /// Propagates the lookup/parse error from [`admin_endpoint_from_env`], or any
    /// connection error.
    pub async fn connect_env() -> Result<Self, GrpcClientError> {
        let target = admin_endpoint_from_env()?;
        Self::connect(target).await
    }

    /// Opens a gRPC channel to `endpoint` and wraps it in a client.
    ///
    /// # Errors
    ///
    /// Returns the (converted) connector error if the channel cannot be opened.
    pub async fn connect(endpoint: Endpoint) -> Result<Self, GrpcClientError> {
        let channel = spiffe::grpc::connector::connect(&endpoint).await?;
        Self::new(channel)
    }

    /// Wraps an already-established channel in a client.
    ///
    /// Currently always returns `Ok`.
    pub fn new(conn: tonic::transport::Channel) -> Result<Self, GrpcClientError> {
        let client = DelegatedIdentityApiClient::new(conn);
        Ok(Self { client })
    }
}
impl DelegatedIdentityClient {
pub async fn fetch_x509_svid(
&self,
attest_type: DelegateAttestationRequest,
) -> Result<X509Svid, GrpcClientError> {
let request = make_x509svid_request(attest_type);
self.client
.clone()
.subscribe_to_x509svi_ds(request)
.await?
.into_inner()
.message()
.await?
.ok_or(GrpcClientError::EmptyResponse)
.and_then(|resp| Self::parse_x509_svid_from_grpc_response(&resp))
}
pub async fn stream_x509_svids(
&self,
attest_type: DelegateAttestationRequest,
) -> Result<impl Stream<Item = Result<X509Svid, GrpcClientError>> + Send + '_, GrpcClientError>
{
let request = match attest_type {
DelegateAttestationRequest::Selectors(selectors) => SubscribeToX509sviDsRequest {
selectors: selectors.into_iter().map(Into::into).collect(),
pid: 0,
},
DelegateAttestationRequest::Pid(pid) => SubscribeToX509sviDsRequest {
selectors: Vec::new(),
pid,
},
};
let response = self.client.clone().subscribe_to_x509svi_ds(request).await?;
let stream = response.into_inner().map(|message| {
message
.map_err(GrpcClientError::from)
.and_then(|resp| Self::parse_x509_svid_from_grpc_response(&resp))
});
Ok(stream)
}
pub async fn fetch_x509_bundles(&self) -> Result<X509BundleSet, GrpcClientError> {
let request = SubscribeToX509BundlesRequest::default();
let response = self
.client
.clone()
.subscribe_to_x509_bundles(request)
.await?;
let initial = response
.into_inner()
.message()
.await?
.ok_or(GrpcClientError::EmptyResponse)?;
Self::parse_x509_bundle_set_from_grpc_response(initial)
}
pub async fn stream_x509_bundles(
&self,
) -> Result<
impl Stream<Item = Result<X509BundleSet, GrpcClientError>> + Send + 'static,
GrpcClientError,
> {
let request = SubscribeToX509BundlesRequest::default();
let response = self
.client
.clone()
.subscribe_to_x509_bundles(request)
.await?;
Ok(response.into_inner().map(|msg| {
msg.map_err(GrpcClientError::from)
.and_then(Self::parse_x509_bundle_set_from_grpc_response)
}))
}
pub async fn fetch_jwt_svids<T: AsRef<str> + ToString>(
&self,
audience: &[T],
attest_type: DelegateAttestationRequest,
) -> Result<Vec<JwtSvid>, GrpcClientError> {
let request = make_jwtsvid_request(audience, attest_type);
let resp = self
.client
.clone()
.fetch_jwtsvi_ds(request)
.await?
.into_inner()
.svids;
Self::parse_jwt_svid_from_grpc_response(resp)
}
pub async fn stream_jwt_bundles(
&self,
) -> Result<
impl Stream<Item = Result<JwtBundleSet, GrpcClientError>> + Send + 'static,
GrpcClientError,
> {
let request = SubscribeToJwtBundlesRequest::default();
let response = self
.client
.clone()
.subscribe_to_jwt_bundles(request)
.await?;
Ok(response.into_inner().map(|msg| {
msg.map_err(GrpcClientError::from)
.and_then(Self::parse_jwt_bundle_set_from_grpc_response)
}))
}
pub async fn fetch_jwt_bundles(&self) -> Result<JwtBundleSet, GrpcClientError> {
let request = SubscribeToJwtBundlesRequest::default();
let response = self
.client
.clone()
.subscribe_to_jwt_bundles(request)
.await?;
let initial = response
.into_inner()
.message()
.await?
.ok_or(GrpcClientError::EmptyResponse)?;
Self::parse_jwt_bundle_set_from_grpc_response(initial)
}
}
impl DelegatedIdentityClient {
    /// Extracts the default X.509-SVID from an X.509-SVID subscription message.
    ///
    /// Selects the entry at index [`DEFAULT_SVID`], concatenates its DER certificate
    /// chain into one buffer, and parses it together with the entry's private key.
    ///
    /// # Errors
    ///
    /// Returns [`GrpcClientError::EmptyResponse`] if there is no entry at that index or
    /// the entry has no `x509_svid`; otherwise the converted parse error.
    fn parse_x509_svid_from_grpc_response(
        response: &SubscribeToX509sviDsResponse,
    ) -> Result<X509Svid, GrpcClientError> {
        let svid = response
            .x509_svids
            .get(DEFAULT_SVID)
            .ok_or(GrpcClientError::EmptyResponse)?;
        let x509_svid = svid
            .x509_svid
            .as_ref()
            .ok_or(GrpcClientError::EmptyResponse)?;
        // The chain arrives as one DER blob per certificate, but the parser is given a
        // single contiguous buffer; `concat` sizes it once and copies each blob in order.
        let cert_chain: Vec<u8> = x509_svid.cert_chain.concat();
        X509Svid::parse_from_der(&cert_chain, svid.x509_svid_key.as_ref()).map_err(Into::into)
    }

    /// Parses every JWT token in `svids` as a [`JwtSvid`], preserving order.
    ///
    /// # Errors
    ///
    /// Returns [`GrpcClientError::JwtSvid`] for the first token that fails to parse.
    fn parse_jwt_svid_from_grpc_response(
        svids: Vec<ProtoJwtSvid>,
    ) -> Result<Vec<JwtSvid>, GrpcClientError> {
        svids
            .into_iter()
            .map(|svid| JwtSvid::from_str(&svid.token).map_err(GrpcClientError::JwtSvid))
            .collect()
    }

    /// Builds a [`JwtBundleSet`] from a JWT bundle subscription message.
    ///
    /// Each map entry is a trust-domain name and its raw JWT authorities.
    ///
    /// # Errors
    ///
    /// Fails on the first invalid trust-domain name or unparsable bundle.
    fn parse_jwt_bundle_set_from_grpc_response(
        response: SubscribeToJwtBundlesResponse,
    ) -> Result<JwtBundleSet, GrpcClientError> {
        let mut bundle_set = JwtBundleSet::new();
        for (td, bundle_data) in response.bundles {
            let trust_domain = TrustDomain::try_from(td)?;
            // `?` already converts via `From`, so no explicit `map_err` is needed.
            let bundle = JwtBundle::from_jwt_authorities(trust_domain, &bundle_data)?;
            bundle_set.add_bundle(bundle);
        }
        Ok(bundle_set)
    }

    /// Builds an [`X509BundleSet`] from an X.509 bundle subscription message.
    ///
    /// Each map entry is a trust-domain name and its DER-encoded CA certificates.
    ///
    /// # Errors
    ///
    /// Fails on the first invalid trust-domain name, or with
    /// [`GrpcClientError::X509Bundle`] on the first unparsable bundle.
    fn parse_x509_bundle_set_from_grpc_response(
        response: SubscribeToX509BundlesResponse,
    ) -> Result<X509BundleSet, GrpcClientError> {
        let mut bundle_set = X509BundleSet::new();
        for (td, bundle) in response.ca_certificates {
            let trust_domain = TrustDomain::try_from(td)?;
            let parsed = X509Bundle::parse_from_der(trust_domain, &bundle)
                .map_err(GrpcClientError::X509Bundle)?;
            bundle_set.add_bundle(parsed);
        }
        Ok(bundle_set)
    }
}
/// Builds an X.509-SVID subscription request for the given workload identification.
///
/// A `Pid` request carries no selectors; a `Selectors` request carries `pid == 0`.
fn make_x509svid_request(attest_type: DelegateAttestationRequest) -> SubscribeToX509sviDsRequest {
    let mut request = SubscribeToX509sviDsRequest {
        selectors: Vec::new(),
        pid: 0,
    };
    match attest_type {
        DelegateAttestationRequest::Pid(process_id) => request.pid = process_id,
        DelegateAttestationRequest::Selectors(list) => {
            request.selectors = list.into_iter().map(Into::into).collect();
        }
    }
    request
}
/// Builds a JWT-SVID fetch request for `audience` and the given workload identification.
///
/// Audience entries are converted with `ToString`. A `Pid` request carries no
/// selectors; a `Selectors` request carries `pid == 0`.
fn make_jwtsvid_request<T: AsRef<str> + ToString>(
    audience: &[T],
    attest_type: DelegateAttestationRequest,
) -> FetchJwtsviDsRequest {
    let mut request = FetchJwtsviDsRequest {
        audience: audience.iter().map(ToString::to_string).collect(),
        selectors: Vec::new(),
        pid: 0,
    };
    match attest_type {
        DelegateAttestationRequest::Pid(process_id) => request.pid = process_id,
        DelegateAttestationRequest::Selectors(list) => {
            request.selectors = list.into_iter().map(Into::into).collect();
        }
    }
    request
}