use crate::box_error::BoxError;
use crate::client::runtime_components::sealed::ValidateConfig;
use crate::client::runtime_components::{RuntimeComponents, RuntimeComponentsBuilder};
use crate::impl_shared_conversions;
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::type_erasure::TypeErasedBox;
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt;
use std::fmt::Debug;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::SystemTime;
#[cfg(feature = "http-auth")]
pub mod http;
// `IdentityFuture<'a>` is the future returned by `ResolveIdentity::resolve_identity` and
// `ResolveCachedIdentity::resolve_cached_identity`. The macro arguments name its output types
// (`Identity` and `BoxError`); see `new_type_future!` for the exact generated API.
new_type_future! {
    #[doc = "Future for [`ResolveIdentity::resolve_identity`]."]
    pub struct IdentityFuture<'a, Identity, BoxError>;
}
/// Process-wide counter backing [`IdentityCachePartition::new`]; each call takes the next value.
static NEXT_CACHE_PARTITION: AtomicUsize = AtomicUsize::new(0);

/// Opaque identifier for a partition of an identity cache.
///
/// Values produced by [`IdentityCachePartition::new`] are drawn from a shared atomic counter, so
/// each call yields a distinct partition. `SharedIdentityResolver::new` uses this to tag
/// resolvers that do not supply a partition of their own.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct IdentityCachePartition(usize);

impl IdentityCachePartition {
    /// Allocates the next unused partition from the global counter.
    pub fn new() -> Self {
        // `Relaxed` suffices: only the uniqueness of the fetched value matters, not its ordering
        // relative to other memory operations.
        let id = NEXT_CACHE_PARTITION.fetch_add(1, Ordering::Relaxed);
        IdentityCachePartition(id)
    }

    /// Creates a partition with an explicit value (test utilities only).
    #[cfg(feature = "test-util")]
    pub fn new_for_tests(value: usize) -> IdentityCachePartition {
        IdentityCachePartition(value)
    }
}
/// Resolves identities through a cache.
///
/// Implementations are handed the underlying `resolver` on every call and decide when to invoke
/// it. NOTE(review): presumably a still-valid cached identity is returned without calling the
/// resolver — confirm against the concrete cache implementations.
pub trait ResolveCachedIdentity: fmt::Debug + Send + Sync {
    /// Resolves an identity, calling `resolver` (with the runtime components and config bag)
    /// as the cache implementation sees fit.
    fn resolve_cached_identity<'a>(
        &'a self,
        resolver: SharedIdentityResolver,
        runtime_components: &'a RuntimeComponents,
        config_bag: &'a ConfigBag,
    ) -> IdentityFuture<'a>;
    #[doc = include_str!("../../rustdoc/validate_base_client_config.md")]
    fn validate_base_client_config(
        &self,
        runtime_components: &RuntimeComponentsBuilder,
        cfg: &ConfigBag,
    ) -> Result<(), BoxError> {
        // Default: no validation. Binding the arguments silences unused-parameter warnings
        // without prefixing the (documented) parameter names with `_`.
        let _ = (runtime_components, cfg);
        Ok(())
    }
    #[doc = include_str!("../../rustdoc/validate_final_config.md")]
    fn validate_final_config(
        &self,
        runtime_components: &RuntimeComponents,
        cfg: &ConfigBag,
    ) -> Result<(), BoxError> {
        // Default: no validation (see the note in `validate_base_client_config`).
        let _ = (runtime_components, cfg);
        Ok(())
    }
}
/// Reference-counted, type-erased [`ResolveCachedIdentity`] implementation.
///
/// The wrapped cache lives behind an `Arc`, so clones share the same underlying cache.
#[derive(Clone, Debug)]
pub struct SharedIdentityCache(Arc<dyn ResolveCachedIdentity>);

impl SharedIdentityCache {
    /// Wraps `cache` so it can be shared between clients and operations.
    pub fn new(cache: impl ResolveCachedIdentity + 'static) -> Self {
        let shared: Arc<dyn ResolveCachedIdentity> = Arc::new(cache);
        SharedIdentityCache(shared)
    }
}
impl ResolveCachedIdentity for SharedIdentityCache {
fn resolve_cached_identity<'a>(
&'a self,
resolver: SharedIdentityResolver,
runtime_components: &'a RuntimeComponents,
config_bag: &'a ConfigBag,
) -> IdentityFuture<'a> {
self.0
.resolve_cached_identity(resolver, runtime_components, config_bag)
}
}
impl ValidateConfig for SharedIdentityResolver {}
impl ValidateConfig for SharedIdentityCache {
fn validate_base_client_config(
&self,
runtime_components: &RuntimeComponentsBuilder,
cfg: &ConfigBag,
) -> Result<(), BoxError> {
self.0.validate_base_client_config(runtime_components, cfg)
}
fn validate_final_config(
&self,
runtime_components: &RuntimeComponents,
cfg: &ConfigBag,
) -> Result<(), BoxError> {
self.0.validate_final_config(runtime_components, cfg)
}
}
// Conversion glue (macro imported from this crate) that turns `ResolveCachedIdentity`
// implementors into a `SharedIdentityCache` via `SharedIdentityCache::new`.
impl_shared_conversions!(convert SharedIdentityCache from ResolveCachedIdentity using SharedIdentityCache::new);
/// Resolves an [`Identity`], returned through an [`IdentityFuture`].
///
/// Only `resolve_identity` must be implemented; the other methods have defaults.
pub trait ResolveIdentity: Send + Sync + Debug {
    /// Resolves an identity using the given runtime components and config bag.
    fn resolve_identity<'a>(
        &'a self,
        runtime_components: &'a RuntimeComponents,
        config_bag: &'a ConfigBag,
    ) -> IdentityFuture<'a>;
    /// Returns an identity to use if identity resolution is interrupted.
    ///
    /// Defaults to `None`. NOTE(review): the interrupting conditions (e.g. timeouts) and the
    /// caller of this method are not visible in this module — confirm before relying on them.
    fn fallback_on_interrupt(&self) -> Option<Identity> {
        None
    }
    /// Reports where identities from this resolver are expected to be cached.
    ///
    /// Defaults to [`IdentityCacheLocation::RuntimeComponents`].
    fn cache_location(&self) -> IdentityCacheLocation {
        IdentityCacheLocation::RuntimeComponents
    }
    /// The identity cache partition this resolver wants to use, if any.
    ///
    /// Defaults to `None`, in which case [`SharedIdentityResolver::new`] allocates a fresh
    /// partition for it.
    fn cache_partition(&self) -> Option<IdentityCachePartition> {
        None
    }
}
/// Where identities produced by a [`ResolveIdentity`] implementation are cached.
///
/// This enum is `#[non_exhaustive]`: matches on it outside this crate need a wildcard arm.
#[non_exhaustive]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum IdentityCacheLocation {
    /// Default returned by [`ResolveIdentity::cache_location`].
    /// NOTE(review): presumably means the identity cache from the runtime components is used —
    /// confirm against the orchestrator.
    RuntimeComponents,
    /// NOTE(review): presumably means the resolver caches identities itself, so no outer cache
    /// should be applied — confirm against the orchestrator.
    IdentityResolver,
}
/// Reference-counted, type-erased [`ResolveIdentity`] implementation paired with the identity
/// cache partition assigned to it.
#[derive(Clone, Debug)]
pub struct SharedIdentityResolver {
    inner: Arc<dyn ResolveIdentity>,
    cache_partition: IdentityCachePartition,
}

impl SharedIdentityResolver {
    /// Wraps `resolver` for sharing.
    ///
    /// If the resolver reports its own `cache_partition()`, that partition is kept; otherwise a
    /// fresh one is allocated. Because a `SharedIdentityResolver` always reports its partition,
    /// re-wrapping one keeps the same partition.
    pub fn new(resolver: impl ResolveIdentity + 'static) -> Self {
        let cache_partition = resolver
            .cache_partition()
            .unwrap_or_else(IdentityCachePartition::new);
        let inner: Arc<dyn ResolveIdentity> = Arc::new(resolver);
        Self {
            inner,
            cache_partition,
        }
    }

    /// Returns the cache partition assigned to this resolver.
    pub fn cache_partition(&self) -> IdentityCachePartition {
        self.cache_partition
    }
}
impl ResolveIdentity for SharedIdentityResolver {
fn resolve_identity<'a>(
&'a self,
runtime_components: &'a RuntimeComponents,
config_bag: &'a ConfigBag,
) -> IdentityFuture<'a> {
self.inner.resolve_identity(runtime_components, config_bag)
}
fn cache_location(&self) -> IdentityCacheLocation {
self.inner.cache_location()
}
fn cache_partition(&self) -> Option<IdentityCachePartition> {
Some(self.cache_partition())
}
}
// Conversion glue (macro imported from this crate) that turns `ResolveIdentity` implementors
// into a `SharedIdentityResolver` via `SharedIdentityResolver::new`.
impl_shared_conversions!(convert SharedIdentityResolver from ResolveIdentity using SharedIdentityResolver::new);
// Type-erased formatter: downcasts the erased identity data back to its concrete type and
// returns it as `&dyn Debug` (the closures are built in `Identity::new` and `Builder::set_data`).
type DataDebug = Arc<dyn (Fn(&Arc<dyn Any + Send + Sync>) -> &dyn Debug) + Send + Sync>;
/// An identity: type-erased data, an optional expiration time, and typed properties.
#[derive(Clone)]
pub struct Identity {
    // The identity payload; read it back with `Identity::data::<T>()`.
    data: Arc<dyn Any + Send + Sync>,
    // Renders `data` with its concrete type's `Debug` impl, for the manual `Debug` impl below.
    data_debug: DataDebug,
    // Expiration time, if one was set.
    expiration: Option<SystemTime>,
    // Additional typed values keyed by their `TypeId`, so there is at most one per type.
    properties: HashMap<TypeId, Arc<TypeErasedBox>>,
}
impl Identity {
    /// Creates an identity wrapping `data`, with an optional `expiration` and no properties.
    ///
    /// Use [`Identity::builder`] to attach properties.
    pub fn new<T>(data: T, expiration: Option<SystemTime>) -> Self
    where
        T: Any + Debug + Send + Sync,
    {
        Self {
            expiration,
            properties: HashMap::new(),
            // The closure remembers `T`, so the erased data can still be Debug-formatted later.
            data_debug: Arc::new(|erased| erased.downcast_ref::<T>().expect("type-checked") as _),
            data: Arc::new(data),
        }
    }

    /// Returns a [`Builder`] for constructing an identity with properties.
    pub fn builder() -> Builder {
        Builder::default()
    }

    /// Returns the identity data if it has type `T`, or `None` otherwise.
    pub fn data<T: Any + Debug + Send + Sync + 'static>(&self) -> Option<&T> {
        self.data.downcast_ref::<T>()
    }

    /// Returns the expiration time, if one was set.
    pub fn expiration(&self) -> Option<SystemTime> {
        self.expiration
    }

    /// Returns the property of type `T`, if one was attached.
    pub fn property<T: Any + Debug + Send + Sync + 'static>(&self) -> Option<&T> {
        let erased = self.properties.get(&TypeId::of::<T>())?;
        erased.downcast_ref()
    }
}
/// Renders the concrete data (via the stored `data_debug` closure), the expiration, and each
/// property as a `property_<index>` field, in the properties map's iteration order.
impl Debug for Identity {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let data = (self.data_debug)(&self.data);
        let mut out = f.debug_struct("Identity");
        out.field("data", data);
        out.field("expiration", &self.expiration);
        self.properties
            .values()
            .enumerate()
            .for_each(|(index, property)| {
                out.field(&format!("property_{index}"), property);
            });
        out.finish()
    }
}
impl ResolveIdentity for Identity {
fn resolve_identity<'a>(
&'a self,
_runtime_components: &'a RuntimeComponents,
_config_bag: &'a ConfigBag,
) -> IdentityFuture<'a> {
IdentityFuture::ready(Ok(self.clone()))
}
}
/// Reasons an identity could not be built.
#[derive(Debug)]
enum ErrorKind {
    /// A required builder field was never set; holds the field's name.
    MissingRequiredField(&'static str),
}

/// Error returned by [`Builder::build`] when an [`Identity`] cannot be constructed.
#[derive(Debug)]
pub struct BuildError {
    kind: ErrorKind,
}

impl BuildError {
    /// Reports that the required field `field_name` was not provided.
    fn missing_required_field(field_name: &'static str) -> Self {
        let kind = ErrorKind::MissingRequiredField(field_name);
        Self { kind }
    }
}

impl fmt::Display for BuildError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.kind {
            ErrorKind::MissingRequiredField(field) => {
                write!(f, "missing required field: `{field}`")
            }
        }
    }
}

impl std::error::Error for BuildError {}
/// Builder for [`Identity`]. `data` is required; expiration and properties are optional.
#[derive(Default)]
pub struct Builder {
    data: Option<Arc<dyn Any + Send + Sync>>,
    // Always set together with `data` by `set_data`.
    data_debug: Option<DataDebug>,
    expiration: Option<SystemTime>,
    // Keyed by the `TypeId` of each property value, so setting a type again replaces it.
    properties: HashMap<TypeId, Arc<TypeErasedBox>>,
}

impl Builder {
    /// Sets the identity data (required), consuming and returning the builder.
    pub fn data<T: Any + Debug + Send + Sync + 'static>(mut self, data: T) -> Self {
        self.set_data(data);
        self
    }

    /// Sets the identity data (required) in place.
    pub fn set_data<T: Any + Debug + Send + Sync + 'static>(&mut self, data: T) {
        // The closure captures `T`, so the erased data can be Debug-formatted once built.
        self.data_debug = Some(Arc::new(|erased| {
            erased.downcast_ref::<T>().expect("type-checked") as _
        }));
        self.data = Some(Arc::new(data));
    }

    /// Sets the expiration time, consuming and returning the builder.
    pub fn expiration(mut self, expiration: SystemTime) -> Self {
        self.set_expiration(Some(expiration));
        self
    }

    /// Sets or clears the expiration time in place.
    pub fn set_expiration(&mut self, expiration: Option<SystemTime>) {
        self.expiration = expiration;
    }

    /// Attaches a property, consuming and returning the builder.
    pub fn property<T: Any + Debug + Send + Sync + 'static>(mut self, prop: T) -> Self {
        self.set_property(prop);
        self
    }

    /// Attaches a property in place, replacing any existing property of the same type.
    pub fn set_property<T: Any + Debug + Send + Sync + 'static>(&mut self, prop: T) {
        let key = TypeId::of::<T>();
        let value = Arc::new(TypeErasedBox::new(prop));
        self.properties.insert(key, value);
    }

    /// Builds the [`Identity`].
    ///
    /// # Errors
    /// Returns a [`BuildError`] if `data` was never set.
    pub fn build(self) -> Result<Identity, BuildError> {
        let Self {
            data,
            data_debug,
            expiration,
            properties,
        } = self;
        let data = data.ok_or_else(|| BuildError::missing_required_field("data"))?;
        let data_debug = data_debug.expect("should always be set when `data` is set");
        Ok(Identity {
            data,
            data_debug,
            expiration,
            properties,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use aws_smithy_async::time::{SystemTimeSource, TimeSource};

    #[test]
    fn check_send_sync() {
        fn is_send_sync<T: Send + Sync>(_: T) {}
        is_send_sync(Identity::new("foo", None));
    }

    #[test]
    fn create_retrieve_identity() {
        #[derive(Debug)]
        struct MyIdentityData {
            first: String,
            last: String,
        }
        let ts = SystemTimeSource::new();
        let expiration = ts.now();
        let identity = Identity::new(
            MyIdentityData {
                first: "foo".into(),
                last: "bar".into(),
            },
            Some(expiration),
        );
        assert_eq!("foo", identity.data::<MyIdentityData>().unwrap().first);
        assert_eq!("bar", identity.data::<MyIdentityData>().unwrap().last);
        assert_eq!(Some(expiration), identity.expiration());
    }

    #[test]
    fn insert_get_identity_properties() {
        #[derive(Debug)]
        struct MyIdentityData {
            first: String,
            last: String,
        }
        #[derive(Debug)]
        struct PropertyAlpha;
        #[derive(Debug)]
        struct PropertyBeta;
        let ts = SystemTimeSource::new();
        let expiration = ts.now();
        let identity = Identity::builder()
            .data(MyIdentityData {
                first: "foo".into(),
                last: "bar".into(),
            })
            .expiration(expiration)
            .property(PropertyAlpha)
            .property(PropertyBeta)
            .build()
            .unwrap();
        assert_eq!("foo", identity.data::<MyIdentityData>().unwrap().first);
        assert_eq!("bar", identity.data::<MyIdentityData>().unwrap().last);
        assert_eq!(Some(expiration), identity.expiration());
        assert!(identity.property::<PropertyAlpha>().is_some());
        assert!(identity.property::<PropertyBeta>().is_some());
    }

    /// `Builder::build` rejects a builder whose required `data` was never set.
    #[test]
    fn build_without_data_fails() {
        let err = Identity::builder()
            .expiration(SystemTimeSource::new().now())
            .build()
            .expect_err("`data` is required");
        assert_eq!("missing required field: `data`", err.to_string());
    }

    /// Asking for data or properties of the wrong type yields `None` rather than panicking.
    #[test]
    fn wrong_type_lookups_return_none() {
        let identity = Identity::new("foo", None);
        assert_eq!(Some(&"foo"), identity.data::<&'static str>());
        assert!(identity.data::<String>().is_none());
        assert!(identity.property::<String>().is_none());
        assert_eq!(None, identity.expiration());
    }

    /// Properties are keyed by type, so setting the same type twice keeps the latest value.
    #[test]
    fn setting_property_twice_keeps_latest() {
        #[derive(Debug, PartialEq)]
        struct Counter(u32);
        let identity = Identity::builder()
            .data("foo")
            .property(Counter(1))
            .property(Counter(2))
            .build()
            .unwrap();
        assert_eq!(Some(&Counter(2)), identity.property::<Counter>());
    }

    /// The manual `Debug` impl renders the concrete data and each property.
    #[test]
    fn debug_output_includes_data_and_properties() {
        #[derive(Debug)]
        struct PropertyAlpha;
        let identity = Identity::builder()
            .data("foo")
            .property(PropertyAlpha)
            .build()
            .unwrap();
        let rendered = format!("{identity:?}");
        assert!(rendered.starts_with("Identity"));
        assert!(rendered.contains("\"foo\""));
        assert!(rendered.contains("property_0"));
    }
}