use std::sync::Arc;
use crate::attributes::Attributes;
use crate::credentials::ChannelCredentials;
use crate::credentials::ProtocolInfo;
use crate::credentials::SecurityLevel;
use crate::credentials::call::CallCredentials;
use crate::credentials::call::CompositeCallCredentials;
use crate::credentials::common::Authority;
use crate::private;
use crate::rt::GrpcEndpoint;
use crate::rt::GrpcRuntime;
/// Pairs an endpoint with the security information that accompanies it.
///
/// This is the success value returned by the `connect` method of
/// [`ChannelCredentials`] as implemented in this file.
pub struct HandshakeOutput<T, C: ClientConnectionSecurityContext> {
/// The endpoint value; for `connect` this is the credentials' `Output<Input>` type.
pub endpoint: T,
/// Security details (protocol name, level, context and attributes) returned with `endpoint`.
pub security: ClientConnectionSecurityInfo<C>,
}
/// Security context stored inside a [`ClientConnectionSecurityInfo`].
///
/// Implementors must be `Send + Sync + 'static`.
pub trait ClientConnectionSecurityContext: Send + Sync + 'static {
/// Reports whether `authority` is acceptable for this security context.
///
/// NOTE(review): the exact validation rules are defined by each implementor —
/// confirm against the concrete credential types.
///
/// The default implementation ignores `authority` and always returns `false`.
fn validate_authority(&self, authority: &Authority) -> bool {
false
}
}
/// Lets a boxed trait object be used wherever a
/// [`ClientConnectionSecurityContext`] is required by forwarding to the value
/// inside the box.
impl ClientConnectionSecurityContext for Box<dyn ClientConnectionSecurityContext> {
    /// Returns whatever the boxed context's own `validate_authority` returns.
    fn validate_authority(&self, authority: &Authority) -> bool {
        // Borrow the inner trait object explicitly so the call dispatches on
        // the boxed value instead of recursing into this impl.
        let inner: &dyn ClientConnectionSecurityContext = &**self;
        inner.validate_authority(authority)
    }
}
/// Security details of a client connection, generic over the context type `C`.
///
/// All fields are private; they are set through `new` and read through the
/// accessor methods of the same names.
pub struct ClientConnectionSecurityInfo<C> {
/// Name of the security protocol, as a `'static` string.
security_protocol: &'static str,
/// Security level recorded for the connection.
security_level: SecurityLevel,
/// Context value of type `C` supplied by the creator of this struct.
security_context: C,
/// Additional attributes stored alongside the security details.
attributes: Attributes,
}
/// A [`ClientConnectionSecurityInfo`] whose context is a boxed
/// [`ClientConnectionSecurityContext`] trait object.
///
/// This is the return type of `ClientConnectionSecurityInfo::into_boxed`.
pub type DynClientConnectionSecurityInfo =
ClientConnectionSecurityInfo<Box<dyn ClientConnectionSecurityContext>>;
impl<C> ClientConnectionSecurityInfo<C> {
pub fn new(
security_protocol: &'static str,
security_level: SecurityLevel,
security_context: C,
attributes: Attributes,
) -> Self {
Self {
security_protocol,
security_level,
security_context,
attributes,
}
}
pub fn security_protocol(&self) -> &'static str {
self.security_protocol
}
pub fn security_level(&self) -> SecurityLevel {
self.security_level
}
pub fn security_context(&self) -> &C {
&self.security_context
}
pub fn attributes(&self) -> &Attributes {
&self.attributes
}
pub fn into_boxed(self) -> DynClientConnectionSecurityInfo
where
C: ClientConnectionSecurityContext + 'static,
{
ClientConnectionSecurityInfo {
security_protocol: self.security_protocol,
security_level: self.security_level,
security_context: Box::new(self.security_context),
attributes: self.attributes,
}
}
}
/// Attribute holder passed by reference to the `connect` method of
/// `ChannelCredentials` in this file.
///
/// `Default` produces a value holding `Attributes::default()`.
#[derive(Default, Clone)]
pub struct ClientHandshakeInfo {
/// Attributes carried by this value; readable through `attributes()`.
attributes: Attributes,
}
impl ClientHandshakeInfo {
pub fn new(attributes: Attributes) -> Self {
Self { attributes }
}
pub fn attributes(&self) -> &Attributes {
&self.attributes
}
}
/// Channel credentials of type `T` paired with call credentials.
pub struct CompositeChannelCredentials<T> {
/// The wrapped channel credentials; `connect` and `info` delegate to this value.
channel_creds: T,
/// Call credentials returned by `get_call_credentials`. When the wrapped
/// channel credentials already expose call credentials, `new` stores a
/// composite of those and the ones passed to it.
call_creds: Arc<dyn CallCredentials>,
}
impl<T: ChannelCredentials> CompositeChannelCredentials<T> {
    /// Wraps `channel_creds` together with `call_creds`.
    ///
    /// If `channel_creds` already exposes call credentials, the stored call
    /// credentials are a [`CompositeCallCredentials`] built from the existing
    /// ones (first argument) and `call_creds` (second argument). Otherwise
    /// `call_creds` is stored as given.
    pub fn new(channel_creds: T, call_creds: Arc<dyn CallCredentials>) -> Self {
        let call_creds: Arc<dyn CallCredentials> =
            match channel_creds.get_call_credentials(private::Internal) {
                // Only a handle to the inherited credentials is cloned; the
                // borrow of `channel_creds` ends with this statement.
                Some(inherited) => Arc::new(CompositeCallCredentials::new(
                    Arc::clone(inherited),
                    call_creds,
                )),
                None => call_creds,
            };
        Self {
            channel_creds,
            call_creds,
        }
    }
}
/// Channel-credential behaviour is forwarded to the wrapped `T`; only the call
/// credentials come from this wrapper.
impl<T: ChannelCredentials> ChannelCredentials for CompositeChannelCredentials<T> {
    type ContextType = T::ContextType;
    type Output<I> = T::Output<I>;

    /// Forwards all arguments to the wrapped credentials' `connect` and
    /// returns its result unchanged.
    async fn connect<Input: GrpcEndpoint>(
        &self,
        authority: &Authority,
        source: Input,
        info: &ClientHandshakeInfo,
        runtime: &GrpcRuntime,
        token: private::Internal,
    ) -> Result<HandshakeOutput<Self::Output<Input>, Self::ContextType>, String> {
        <T as ChannelCredentials>::connect(
            &self.channel_creds,
            authority,
            source,
            info,
            runtime,
            token,
        )
        .await
    }

    /// Returns the wrapped credentials' protocol info.
    fn info(&self) -> &ProtocolInfo {
        <T as ChannelCredentials>::info(&self.channel_creds)
    }

    /// Always returns `Some` with the call credentials stored by `new`.
    fn get_call_credentials(&self, _: private::Internal) -> Option<&Arc<dyn CallCredentials>> {
        let stored = &self.call_creds;
        Some(stored)
    }
}
#[cfg(test)]
mod tests {
    use tokio::net::TcpListener;
    use tonic::async_trait;
    use super::*;
    use crate::StatusError;
    use crate::credentials::call::CallCredentials;
    use crate::credentials::call::CallDetails;
    use crate::credentials::call::ClientConnectionSecurityInfo;
    use crate::credentials::local::LocalChannelCredentials;
    use crate::metadata::AsciiMetadataKey;
    use crate::metadata::AsciiMetadataValue;
    use crate::metadata::MetadataMap;
    use crate::rt;
    use crate::rt::TcpOptions;

    /// Call credentials stub: inserts one fixed `key`/`value` metadata entry
    /// and reports a configurable minimum security level.
    #[derive(Debug)]
    struct MockCallCredentials {
        key: &'static str,
        value: &'static str,
        min_security_level: SecurityLevel,
    }

    #[async_trait]
    impl CallCredentials for MockCallCredentials {
        /// Inserts the configured entry into `metadata`; the call details and
        /// security info are ignored.
        async fn get_metadata(
            &self,
            _call_details: &CallDetails,
            _auth_info: &ClientConnectionSecurityInfo,
            metadata: &mut MetadataMap,
        ) -> Result<(), StatusError> {
            let header_name = self.key.parse::<AsciiMetadataKey>().unwrap();
            let header_value = AsciiMetadataValue::try_from(self.value).unwrap();
            metadata.insert(header_name, header_value);
            Ok(())
        }

        /// Returns the level configured on the mock.
        fn minimum_channel_security_level(&self) -> SecurityLevel {
            self.min_security_level
        }
    }

    /// Builds a mock as the trait object accepted by
    /// `CompositeChannelCredentials::new`.
    fn mock_call_creds(
        key: &'static str,
        value: &'static str,
        min_security_level: SecurityLevel,
    ) -> Arc<dyn CallCredentials> {
        Arc::new(MockCallCredentials {
            key,
            value,
            min_security_level,
        })
    }

    /// Wraps local channel credentials twice and checks that the resulting
    /// call credentials emit both metadata entries, report
    /// `PrivacyAndIntegrity` as the minimum level, and that `connect` still
    /// yields the `local` / `NoSecurity` result.
    #[tokio::test]
    async fn test_multiple_composition() {
        let local_creds = LocalChannelCredentials::new();
        let first_call_creds = mock_call_creds("auth1", "val1", SecurityLevel::IntegrityOnly);
        let second_call_creds =
            mock_call_creds("auth2", "val2", SecurityLevel::PrivacyAndIntegrity);

        let inner = CompositeChannelCredentials::new(local_creds, first_call_creds);
        let outer = CompositeChannelCredentials::new(inner, second_call_creds);

        // Both layers of call credentials must contribute metadata.
        let merged = outer.get_call_credentials(private::Internal).unwrap();
        let details = CallDetails::new("service".to_string(), "method".to_string());
        let security_info = ClientConnectionSecurityInfo::new(
            "local",
            SecurityLevel::NoSecurity,
            Attributes::new(),
        );
        let mut headers = MetadataMap::new();
        merged
            .get_metadata(&details, &security_info, &mut headers)
            .await
            .unwrap();
        assert_eq!(headers.get("auth1").unwrap(), "val1");
        assert_eq!(headers.get("auth2").unwrap(), "val2");
        assert_eq!(
            merged.minimum_channel_security_level(),
            SecurityLevel::PrivacyAndIntegrity
        );

        // Connect over a loopback listener bound to an OS-assigned port; the
        // listener binding stays alive until the end of the test.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let listen_addr = listener.local_addr().unwrap();
        let authority = Authority::new("localhost".to_string(), Some(listen_addr.port()));
        let runtime = rt::default_runtime();
        let stream = runtime
            .tcp_stream(listen_addr, TcpOptions::default())
            .await
            .unwrap();
        let handshake_info = ClientHandshakeInfo::default();
        let handshake = outer
            .connect(
                &authority,
                stream,
                &handshake_info,
                &runtime,
                private::Internal,
            )
            .await
            .unwrap();
        assert_eq!(handshake.security.security_level(), SecurityLevel::NoSecurity);
        assert_eq!(handshake.security.security_protocol(), "local");
    }
}