use std::fmt::Debug;
use std::sync::Arc;
use tonic::async_trait;
use crate::StatusError;
use crate::attributes::Attributes;
use crate::credentials::SecurityLevel;
use crate::metadata::MetadataMap;
/// Describes the call for which [`CallCredentials::get_metadata`] is asked to
/// produce metadata.
#[derive(Debug, Clone)]
pub struct CallDetails {
    /// Service URL supplied at construction.
    service_url: String,
    /// Method name supplied at construction.
    method_name: String,
}

impl CallDetails {
    /// Creates call details from a service URL and a method name.
    ///
    /// Both arguments accept anything convertible into an owned `String`
    /// (e.g. `&str` or `String`).
    pub fn new(service_url: impl Into<String>, method_name: impl Into<String>) -> Self {
        Self {
            service_url: service_url.into(),
            method_name: method_name.into(),
        }
    }

    /// Returns the service URL this call targets.
    pub fn service_url(&self) -> &str {
        &self.service_url
    }

    /// Returns the method name this call invokes.
    pub fn method_name(&self) -> &str {
        &self.method_name
    }
}
/// Security properties of the connection a call is being made on, as handed
/// to [`CallCredentials::get_metadata`].
pub struct ClientConnectionSecurityInfo {
    /// Identifier of the security protocol in use.
    security_protocol: &'static str,
    /// Security level provided by the connection.
    security_level: SecurityLevel,
    /// Additional attributes attached to the connection.
    attributes: Attributes,
}

impl ClientConnectionSecurityInfo {
    /// Bundles a protocol identifier, a security level and connection
    /// attributes into a single record.
    pub fn new(
        security_protocol: &'static str,
        security_level: SecurityLevel,
        attributes: Attributes,
    ) -> Self {
        ClientConnectionSecurityInfo {
            attributes,
            security_level,
            security_protocol,
        }
    }

    /// Returns the security level the connection provides (by value).
    pub fn security_level(&self) -> SecurityLevel {
        self.security_level
    }

    /// Returns the security protocol identifier given at construction.
    pub fn security_protocol(&self) -> &'static str {
        self.security_protocol
    }

    /// Borrows the attributes attached to this connection.
    pub fn attributes(&self) -> &Attributes {
        &self.attributes
    }
}
/// Per-call credentials: implementations populate request metadata for each
/// outgoing call.
///
/// `Send + Sync + Debug` are required supertraits, which lets implementations
/// be shared across threads as `Arc<dyn CallCredentials>`.
#[async_trait]
pub trait CallCredentials: Send + Sync + Debug {
    /// Adds this credential's metadata for the call described by
    /// `call_details` into `metadata`.
    ///
    /// `auth_info` describes the security of the connection the call is
    /// being made on.
    ///
    /// # Errors
    ///
    /// Returns a [`StatusError`] if the metadata cannot be produced.
    async fn get_metadata(
        &self,
        call_details: &CallDetails,
        auth_info: &ClientConnectionSecurityInfo,
        metadata: &mut MetadataMap,
    ) -> Result<(), StatusError>;
    /// Returns the minimum channel security level these credentials require.
    ///
    /// Defaults to [`SecurityLevel::PrivacyAndIntegrity`]; implementations may
    /// override it.
    // NOTE(review): where this requirement is enforced is not visible in this
    // file — confirm at the call site.
    fn minimum_channel_security_level(&self) -> SecurityLevel {
        SecurityLevel::PrivacyAndIntegrity
    }
}
/// Call credentials assembled from two or more child [`CallCredentials`].
///
/// Every child contributes metadata to each call, in the order the children
/// were added.
#[derive(Debug)]
pub struct CompositeCallCredentials {
    /// Child credentials in insertion order. Always holds at least two
    /// entries: `new` takes two and the only other mutator appends.
    creds: Vec<Arc<dyn CallCredentials>>,
}

impl CompositeCallCredentials {
    /// Creates a composite from its first two children, kept in that order.
    #[must_use]
    pub fn new(first: Arc<dyn CallCredentials>, second: Arc<dyn CallCredentials>) -> Self {
        Self {
            creds: vec![first, second],
        }
    }

    /// Appends another child after the existing ones and returns the
    /// extended composite.
    ///
    /// This consumes `self` builder-style. `#[must_use]` ensures that writing
    /// `composite.with_call_credentials(x);` as a bare statement warns,
    /// instead of silently dropping the composite along with the new child.
    #[must_use]
    pub fn with_call_credentials(mut self, creds: Arc<dyn CallCredentials>) -> Self {
        self.creds.push(creds);
        self
    }
}
#[async_trait]
impl CallCredentials for CompositeCallCredentials {
    /// Asks every child, in insertion order, to add its metadata.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first child's error. Entries already written
    /// by earlier children remain in `metadata`.
    async fn get_metadata(
        &self,
        call_details: &CallDetails,
        auth_info: &ClientConnectionSecurityInfo,
        metadata: &mut MetadataMap,
    ) -> Result<(), StatusError> {
        for child in self.creds.iter() {
            child
                .get_metadata(call_details, auth_info, metadata)
                .await?;
        }
        Ok(())
    }

    /// Returns the largest (per `SecurityLevel`'s `Ord`) minimum level
    /// requested by any child.
    ///
    /// # Panics
    ///
    /// Only if there are no children. The constructor always installs two, so
    /// this cannot happen.
    fn minimum_channel_security_level(&self) -> SecurityLevel {
        let highest = self
            .creds
            .iter()
            .map(|child| child.minimum_channel_security_level())
            .reduce(Ord::max);
        highest.expect("CompositeCallCredentials must hold at least two children.")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::AsciiMetadataKey;
    use crate::metadata::AsciiMetadataValue;

    /// Test double that inserts a single `key: value` metadata entry and
    /// reports a configurable minimum channel security level.
    #[derive(Debug)]
    struct MockCallCredentials {
        key: String,
        value: String,
        security_level: SecurityLevel,
    }

    impl MockCallCredentials {
        /// Builds a mock already wrapped in an `Arc`, ready to hand to
        /// `CompositeCallCredentials`.
        fn arc(key: &str, value: &str, security_level: SecurityLevel) -> Arc<Self> {
            Arc::new(Self {
                key: key.to_string(),
                value: value.to_string(),
                security_level,
            })
        }
    }

    #[async_trait]
    impl CallCredentials for MockCallCredentials {
        async fn get_metadata(
            &self,
            _call_details: &CallDetails,
            _auth_info: &ClientConnectionSecurityInfo,
            metadata: &mut MetadataMap,
        ) -> Result<(), StatusError> {
            metadata.insert(
                self.key.parse::<AsciiMetadataKey>().unwrap(),
                AsciiMetadataValue::try_from(&self.value).unwrap(),
            );
            Ok(())
        }

        fn minimum_channel_security_level(&self) -> SecurityLevel {
            self.security_level
        }
    }

    /// Security info for a connection at `PrivacyAndIntegrity`, shared by the
    /// async tests below.
    fn secure_auth_info() -> ClientConnectionSecurityInfo {
        ClientConnectionSecurityInfo::new(
            "test",
            SecurityLevel::PrivacyAndIntegrity,
            Attributes::new(),
        )
    }

    #[tokio::test]
    async fn test_composite_call_credentials() {
        let cred1 = MockCallCredentials::arc("key1", "value1", SecurityLevel::IntegrityOnly);
        let cred2 =
            MockCallCredentials::arc("key2", "value2", SecurityLevel::PrivacyAndIntegrity);
        let composite = CompositeCallCredentials::new(cred1, cred2);
        // Build through the public constructor rather than private fields so
        // the test exercises the same API real callers use.
        let call_details = CallDetails::new("url", "method");
        let auth_info = secure_auth_info();
        let mut metadata = MetadataMap::new();
        composite
            .get_metadata(&call_details, &auth_info, &mut metadata)
            .await
            .unwrap();
        assert_eq!(metadata.get("key1").unwrap(), "value1");
        assert_eq!(metadata.get("key2").unwrap(), "value2");
        assert_eq!(
            composite.minimum_channel_security_level(),
            SecurityLevel::PrivacyAndIntegrity
        );
    }

    #[tokio::test]
    async fn test_composite_with_additional_call_credentials() {
        let composite = CompositeCallCredentials::new(
            MockCallCredentials::arc("key1", "value1", SecurityLevel::IntegrityOnly),
            MockCallCredentials::arc("key2", "value2", SecurityLevel::IntegrityOnly),
        )
        .with_call_credentials(MockCallCredentials::arc(
            "key3",
            "value3",
            SecurityLevel::PrivacyAndIntegrity,
        ));
        let call_details = CallDetails::new("url", "method");
        let auth_info = secure_auth_info();
        let mut metadata = MetadataMap::new();
        composite
            .get_metadata(&call_details, &auth_info, &mut metadata)
            .await
            .unwrap();
        assert_eq!(metadata.get("key1").unwrap(), "value1");
        assert_eq!(metadata.get("key2").unwrap(), "value2");
        assert_eq!(metadata.get("key3").unwrap(), "value3");
        // Only the appended child asks for PrivacyAndIntegrity, so the
        // composite's level must be taken from it.
        assert_eq!(
            composite.minimum_channel_security_level(),
            SecurityLevel::PrivacyAndIntegrity
        );
    }

    #[test]
    fn test_composite_minimum_level_comes_from_children() {
        let composite = CompositeCallCredentials::new(
            MockCallCredentials::arc("key1", "value1", SecurityLevel::IntegrityOnly),
            MockCallCredentials::arc("key2", "value2", SecurityLevel::IntegrityOnly),
        );
        // Every child accepts IntegrityOnly, so the composite must not fall
        // back to the trait default of PrivacyAndIntegrity.
        assert_eq!(
            composite.minimum_channel_security_level(),
            SecurityLevel::IntegrityOnly
        );
    }
}