use std::sync::Arc;
use tonic::async_trait;
use crate::credentials::ChannelCredentials;
use crate::credentials::ProtocolInfo;
use crate::credentials::ServerCredentials;
use crate::credentials::call::CallCredentials;
use crate::credentials::client::ClientConnectionSecurityContext;
use crate::credentials::client::ClientHandshakeInfo;
use crate::credentials::client::HandshakeOutput;
use crate::credentials::common::Authority;
use crate::credentials::server::HandshakeOutput as ServerHandshakeOutput;
use crate::private;
use crate::rt::GrpcEndpoint;
use crate::rt::GrpcRuntime;
use crate::send_future::SendFuture;
/// Type-erased transport endpoint used on both sides of the `dyn` credential
/// boundary: inputs are boxed before dynamic dispatch and outputs are returned
/// boxed.
type BoxEndpoint = Box<dyn GrpcEndpoint>;
/// Dyn-compatible counterpart of [`ChannelCredentials`].
///
/// `ChannelCredentials::connect` is generic over its input endpoint type, which
/// makes that trait unusable as a trait object. This trait fixes both the input
/// and output endpoint to [`BoxEndpoint`] and the security context to a boxed
/// [`ClientConnectionSecurityContext`], so credentials can be held as
/// `Arc<dyn DynChannelCredentials>`.
#[async_trait]
pub(crate) trait DynChannelCredentials: Send + Sync {
    /// Performs the client-side handshake on `source` for `authority`.
    ///
    /// On success returns the (boxed) resulting endpoint together with a boxed
    /// security context; on failure returns an error message.
    async fn dyn_connect(
        &self,
        authority: &Authority,
        source: BoxEndpoint,
        info: &ClientHandshakeInfo,
        runtime: &GrpcRuntime,
    ) -> Result<HandshakeOutput<BoxEndpoint, Box<dyn ClientConnectionSecurityContext>>, String>;

    /// Returns the protocol information describing these credentials.
    fn info(&self) -> &ProtocolInfo;

    /// Returns the call credentials associated with these channel credentials,
    /// if any.
    fn get_call_credentials(&self) -> Option<&Arc<dyn CallCredentials>>;
}
#[async_trait]
impl<T> DynChannelCredentials for T
where
T: ChannelCredentials,
T::Output<BoxEndpoint>: GrpcEndpoint,
{
async fn dyn_connect(
&self,
authority: &Authority,
source: BoxEndpoint,
info: &ClientHandshakeInfo,
runtime: &GrpcRuntime,
) -> Result<HandshakeOutput<BoxEndpoint, Box<dyn ClientConnectionSecurityContext>>, String>
{
let output = self
.connect(authority, source, info, runtime, private::Internal)
.make_send()
.await?;
let stream = output.endpoint;
let sec_info = output.security;
Ok(HandshakeOutput {
endpoint: Box::new(stream),
security: sec_info.into_boxed(),
})
}
fn info(&self) -> &ProtocolInfo {
self.info()
}
fn get_call_credentials(&self) -> Option<&Arc<dyn CallCredentials>> {
self.get_call_credentials(private::Internal)
}
}
impl ChannelCredentials for Arc<dyn DynChannelCredentials> {
type ContextType = Box<dyn ClientConnectionSecurityContext>;
type Output<I> = BoxEndpoint;
async fn connect<Input: GrpcEndpoint>(
&self,
authority: &Authority,
source: Input,
info: &ClientHandshakeInfo,
runtime: &GrpcRuntime,
_token: private::Internal,
) -> Result<HandshakeOutput<Self::Output<Input>, Self::ContextType>, String> {
(**self)
.dyn_connect(authority, Box::new(source), info, runtime)
.await
}
fn get_call_credentials(&self, _: private::Internal) -> Option<&Arc<dyn CallCredentials>> {
(**self).get_call_credentials()
}
fn info(&self) -> &ProtocolInfo {
(**self).info()
}
}
/// Dyn-compatible counterpart of [`ServerCredentials`].
///
/// Fixes the input and output endpoint to [`BoxEndpoint`] so server
/// credentials can be used as a trait object (e.g.
/// `Box<dyn DynServerCredentials>`).
#[async_trait]
pub(crate) trait DynServerCredentials: Send + Sync {
    /// Performs the server-side handshake on an accepted `source` endpoint.
    ///
    /// On success returns the (boxed) resulting endpoint together with its
    /// security information; on failure returns an error message.
    async fn dyn_accept(
        &self,
        source: BoxEndpoint,
        runtime: GrpcRuntime,
    ) -> Result<ServerHandshakeOutput<BoxEndpoint>, String>;

    /// Returns the protocol information describing these credentials.
    fn info(&self) -> &ProtocolInfo;
}
#[async_trait]
impl<T> DynServerCredentials for T
where
T: ServerCredentials,
T::Output<BoxEndpoint>: GrpcEndpoint,
{
async fn dyn_accept(
&self,
source: BoxEndpoint,
runtime: GrpcRuntime,
) -> Result<ServerHandshakeOutput<BoxEndpoint>, String> {
let output = SendFuture::make_send(self.accept(source, runtime, private::Internal)).await?;
Ok(ServerHandshakeOutput {
endpoint: Box::new(output.endpoint),
security: output.security,
})
}
fn info(&self) -> &ProtocolInfo {
self.info()
}
}
#[cfg(test)]
mod tests {
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpListener;
    use tokio::net::TcpStream;

    use super::*;
    use crate::credentials::LocalChannelCredentials;
    use crate::credentials::LocalServerCredentials;
    use crate::credentials::SecurityLevel;
    use crate::credentials::client::ClientHandshakeInfo;
    use crate::rt::AsyncIoAdapter;
    use crate::rt::TcpOptions;
    use crate::rt::tokio::TokioIoStream;
    use crate::rt::{self};

    /// Bytes the plain TCP server sends to the dyn-wrapped client endpoint.
    const CLIENT_PAYLOAD: &[u8] = b"hello dynamic grpc";
    /// Bytes the plain TCP client sends to the dyn-wrapped server endpoint.
    const SERVER_PAYLOAD: &[u8] = b"hello dynamic grpc server";

    /// Drives `LocalChannelCredentials` through `dyn DynChannelCredentials`
    /// and checks that the returned endpoint carries data and reports the
    /// local security protocol.
    #[tokio::test]
    async fn test_dyn_client_credential_dispatch() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let dyn_creds = LocalChannelCredentials::new_arc() as Arc<dyn DynChannelCredentials>;
        let authority = Authority::new("localhost".to_string(), Some(addr.port()));
        let runtime = rt::default_runtime();
        let source = runtime
            .tcp_stream(addr, TcpOptions::default())
            .await
            .unwrap();
        let info = ClientHandshakeInfo::default();

        // `expect` (instead of `assert!(is_ok())` + `unwrap`) reports the
        // handshake's error string if the call fails.
        let output = dyn_creds
            .dyn_connect(&authority, source, &info, &runtime)
            .await
            .expect("dyn_connect with local credentials should succeed");
        let endpoint = output.endpoint;
        let security_info = output.security;

        assert!(!endpoint.get_local_address().is_empty());
        assert_eq!(security_info.security_protocol(), "local");
        assert_eq!(security_info.security_level(), SecurityLevel::NoSecurity);

        let (mut server_stream, _) = listener.accept().await.unwrap();
        assert_eq!(
            endpoint.get_local_address(),
            &server_stream.peer_addr().unwrap().to_string()
        );

        server_stream.write_all(CLIENT_PAYLOAD).await.unwrap();
        let mut buf = vec![0u8; CLIENT_PAYLOAD.len()];
        AsyncIoAdapter::new(endpoint)
            .read_exact(&mut buf)
            .await
            .unwrap();
        assert_eq!(buf, CLIENT_PAYLOAD);

        assert!(
            security_info
                .security_context()
                .validate_authority(&authority)
        );
    }

    /// Drives `LocalServerCredentials` through `dyn DynServerCredentials` and
    /// checks that the accepted endpoint receives the client's bytes.
    #[tokio::test]
    async fn test_dyn_server_credential_dispatch() {
        let creds = LocalServerCredentials::new();
        let dyn_creds: Box<dyn DynServerCredentials> = Box::new(creds);
        let info = dyn_creds.info();
        assert_eq!(info.security_protocol, "local");

        let runtime = rt::default_runtime();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        let client_handle = tokio::spawn(async move {
            let mut stream = TcpStream::connect(server_addr).await.unwrap();
            stream.write_all(SERVER_PAYLOAD).await.unwrap();
            // Block on a read so the connection stays open until the server
            // side is done (or the task is aborted).
            let mut buf = vec![0u8; 1];
            let _ = stream.read(&mut buf).await;
        });

        let (stream, _) = listener.accept().await.unwrap();
        let server_stream = TokioIoStream::new_from_tcp(stream).unwrap();
        let output = dyn_creds
            .dyn_accept(Box::new(server_stream) as Box<dyn GrpcEndpoint>, runtime)
            .await
            .expect("dyn_accept with local credentials should succeed");
        let endpoint = output.endpoint;
        let security_info = output.security;

        assert_eq!(security_info.security_protocol(), "local");
        assert_eq!(security_info.security_level(), SecurityLevel::NoSecurity);

        // Size the buffer from the payload instead of a hard-coded length.
        let mut buf = vec![0u8; SERVER_PAYLOAD.len()];
        AsyncIoAdapter::new(endpoint)
            .read_exact(&mut buf)
            .await
            .unwrap();
        assert_eq!(&buf[..], SERVER_PAYLOAD);

        client_handle.abort();
    }
}