use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;
use crate::attributes::Attributes;
use crate::client::name_resolution::TCP_IP_NETWORK_TYPE;
use crate::client::name_resolution::UNIX_NETWORK_TYPE;
use crate::credentials::ChannelCredentials;
use crate::credentials::ProtocolInfo;
use crate::credentials::SecurityLevel;
use crate::credentials::ServerCredentials;
use crate::credentials::call::CallCredentials;
use crate::credentials::client::ClientConnectionSecurityContext;
use crate::credentials::client::ClientConnectionSecurityInfo;
use crate::credentials::client::ClientHandshakeInfo;
use crate::credentials::client::HandshakeOutput;
use crate::credentials::common::Authority;
use crate::credentials::server;
use crate::credentials::server::ServerConnectionSecurityInfo;
use crate::private;
use crate::rt::GrpcEndpoint;
use crate::rt::GrpcRuntime;
/// Security protocol name that the local credentials in this file report,
/// both in their `ProtocolInfo` and in the per-connection security info.
pub const PROTOCOL_NAME: &str = "local";
/// Client-side credentials for connections to local endpoints.
///
/// The type carries no configuration; the private unit field only prevents
/// construction through a struct literal outside this module, so values are
/// obtained via [`LocalChannelCredentials::new`],
/// [`LocalChannelCredentials::new_arc`] or `Default`.
#[derive(Clone, Debug, Default)]
pub struct LocalChannelCredentials {
    _private: (),
}

impl LocalChannelCredentials {
    /// Creates a new set of local channel credentials.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new set of local channel credentials already wrapped in an
    /// [`Arc`], for callers that need shared ownership.
    pub fn new_arc() -> Arc<Self> {
        Arc::new(Self::new())
    }
}
/// Security context attached to client connections established with
/// [`LocalChannelCredentials`]. It is a stateless unit type.
#[derive(Clone, Debug)]
pub struct LocalConnectionSecurityContext;

impl ClientConnectionSecurityContext for LocalConnectionSecurityContext {
    /// Accepts every authority: no check is performed and the result is
    /// always `true`.
    fn validate_authority(&self, _: &Authority) -> bool {
        true
    }
}
/// Decides whether an endpoint counts as local and, if so, which
/// [`SecurityLevel`] local credentials report for it.
///
/// * `peer_addr` - textual address of the remote side of the connection.
/// * `network_type` - network type of the endpoint; only
///   `TCP_IP_NETWORK_TYPE` and `UNIX_NETWORK_TYPE` can be local.
///
/// Classification:
///
/// * TCP/IP: `peer_addr` must parse as a socket address whose IP is a
///   loopback address; the result is [`SecurityLevel::NoSecurity`].
///   IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are canonicalized before
///   the check, so an IPv4 loopback peer seen through a dual-stack socket is
///   still recognized as loopback.
/// * Unix: an address starting with a NUL byte yields
///   [`SecurityLevel::NoSecurity`]; any other address yields
///   [`SecurityLevel::PrivacyAndIntegrity`].
///   NOTE(review): the NUL prefix presumably denotes a Linux abstract socket,
///   which is not guarded by filesystem permissions - confirm.
///
/// # Errors
///
/// Returns a descriptive message when a TCP/IP peer address cannot be parsed,
/// when it is not a loopback address, or when the network type is neither
/// TCP/IP nor Unix.
fn security_level_for_endpoint(
    peer_addr: &str,
    network_type: &str,
) -> Result<SecurityLevel, String> {
    if network_type == TCP_IP_NETWORK_TYPE {
        let socket_addr = SocketAddr::from_str(peer_addr).map_err(|e| {
            format!(
                "local credentials could not parse peer address {:?}: {}",
                peer_addr, e
            )
        })?;
        // `Ipv6Addr::is_loopback` only matches `::1`; canonicalizing first
        // turns `::ffff:127.x.y.z` into its IPv4 form so it is matched too.
        if socket_addr.ip().to_canonical().is_loopback() {
            return Ok(SecurityLevel::NoSecurity);
        }
    } else if network_type == UNIX_NETWORK_TYPE {
        let level = if peer_addr.starts_with('\0') {
            SecurityLevel::NoSecurity
        } else {
            SecurityLevel::PrivacyAndIntegrity
        };
        return Ok(level);
    }

    Err(format!(
        "local credentials rejected connection to non-local address {}",
        peer_addr
    ))
}
impl ChannelCredentials for LocalChannelCredentials {
    type ContextType = LocalConnectionSecurityContext;
    // No wrapping is applied: the endpoint type handed in is handed back.
    type Output<I> = I;

    /// Performs the client-side "handshake" for local credentials.
    ///
    /// Nothing is exchanged on the wire. The endpoint's peer address and
    /// network type are classified by `security_level_for_endpoint`; if that
    /// rejects the endpoint its error message is returned, otherwise `source`
    /// is passed through unchanged together with the resulting security info.
    async fn connect<Input: GrpcEndpoint>(
        &self,
        _authority: &Authority,
        source: Input,
        _info: &ClientHandshakeInfo,
        _runtime: &GrpcRuntime,
        _token: private::Internal,
    ) -> Result<HandshakeOutput<Self::Output<Input>, Self::ContextType>, String> {
        let level =
            security_level_for_endpoint(source.get_peer_address(), source.get_network_type())?;

        let security = ClientConnectionSecurityInfo::new(
            PROTOCOL_NAME,
            level,
            LocalConnectionSecurityContext,
            Attributes::new(),
        );

        Ok(HandshakeOutput {
            endpoint: source,
            security,
        })
    }

    /// Returns the protocol info shared by all local channel credentials; its
    /// security protocol is [`PROTOCOL_NAME`].
    fn info(&self) -> &ProtocolInfo {
        static LOCAL_PROTOCOL_INFO: ProtocolInfo = ProtocolInfo::new(PROTOCOL_NAME);
        &LOCAL_PROTOCOL_INFO
    }

    /// Local credentials carry no call credentials, so this is always `None`.
    fn get_call_credentials(&self, _: private::Internal) -> Option<&Arc<dyn CallCredentials>> {
        None
    }
}
/// Server-side credentials for connections from local endpoints.
///
/// The type carries no configuration; the private unit field only prevents
/// construction through a struct literal outside this module, so values are
/// obtained via [`LocalServerCredentials::new`] or `Default`.
#[derive(Clone, Debug, Default)]
pub struct LocalServerCredentials {
    _private: (),
}

impl LocalServerCredentials {
    /// Creates a new set of local server credentials.
    pub fn new() -> Self {
        Self::default()
    }
}
impl ServerCredentials for LocalServerCredentials {
    // No wrapping is applied: the endpoint type handed in is handed back.
    type Output<I> = I;

    /// Performs the server-side "handshake" for local credentials.
    ///
    /// Nothing is exchanged on the wire. The accepted endpoint's peer address
    /// and network type are classified by `security_level_for_endpoint`; if
    /// that rejects the endpoint its error message is returned, otherwise
    /// `source` is passed through unchanged together with the resulting
    /// security info.
    async fn accept<Input: GrpcEndpoint>(
        &self,
        source: Input,
        _runtime: GrpcRuntime,
        _token: private::Internal,
    ) -> Result<server::HandshakeOutput<Self::Output<Input>>, String> {
        let level =
            security_level_for_endpoint(source.get_peer_address(), source.get_network_type())?;

        let security =
            ServerConnectionSecurityInfo::new(PROTOCOL_NAME, level, Attributes::new());

        Ok(server::HandshakeOutput {
            endpoint: source,
            security,
        })
    }

    /// Returns the protocol info shared by all local server credentials; its
    /// security protocol is [`PROTOCOL_NAME`].
    fn info(&self) -> &ProtocolInfo {
        static LOCAL_PROTOCOL_INFO: ProtocolInfo = ProtocolInfo::new(PROTOCOL_NAME);
        &LOCAL_PROTOCOL_INFO
    }
}
/// Unit and loopback integration tests for the local credentials.
#[cfg(test)]
mod test {
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpListener;
    use tokio::net::TcpStream;

    use super::*;
    use crate::credentials::ChannelCredentials;
    use crate::credentials::SecurityLevel;
    use crate::credentials::ServerCredentials;
    use crate::credentials::client::ClientConnectionSecurityContext;
    use crate::credentials::client::ClientHandshakeInfo;
    use crate::credentials::common::Authority;
    use crate::rt;
    use crate::rt::AsyncIoAdapter;
    use crate::rt::GrpcEndpoint;
    use crate::rt::TcpOptions;
    use crate::rt::tokio::TokioIoStream;

    /// Loopback TCP peers and Unix-socket peers are accepted with the
    /// expected security level.
    #[test]
    fn test_security_level_for_endpoint_success() {
        assert_eq!(
            security_level_for_endpoint("127.0.0.1:8080", TCP_IP_NETWORK_TYPE),
            Ok(SecurityLevel::NoSecurity)
        );
        assert_eq!(
            security_level_for_endpoint("[::1]:8080", TCP_IP_NETWORK_TYPE),
            Ok(SecurityLevel::NoSecurity)
        );
        assert_eq!(
            security_level_for_endpoint("/file/path/name.sock", UNIX_NETWORK_TYPE),
            Ok(SecurityLevel::PrivacyAndIntegrity)
        );
        assert_eq!(
            security_level_for_endpoint("\0abstract-sock", UNIX_NETWORK_TYPE),
            Ok(SecurityLevel::NoSecurity)
        );
    }

    /// Non-loopback, unparsable and unknown-network endpoints are rejected.
    #[test]
    fn test_security_level_for_endpoint_failure() {
        assert!(security_level_for_endpoint("192.168.1.1:8080", TCP_IP_NETWORK_TYPE).is_err());
        assert!(security_level_for_endpoint("invalid", TCP_IP_NETWORK_TYPE).is_err());
        // A loopback-looking address does not help when the network type is
        // neither TCP/IP nor Unix.
        assert!(security_level_for_endpoint("127.0.0.1:8080", "not-a-network-type").is_err());
    }

    /// Client handshake over a real loopback TCP connection: the endpoint is
    /// passed through usable and the security info reports "local".
    #[tokio::test]
    async fn test_local_client_credentials() {
        let creds = LocalChannelCredentials::new();

        let info = creds.info();
        assert_eq!(info.security_protocol(), "local");

        // Port 0 lets the OS pick a free port.
        let addr = "127.0.0.1:0";
        let listener = TcpListener::bind(addr).await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        let authority = Authority::new("localhost".to_string(), Some(server_addr.port()));

        let runtime = rt::default_runtime();
        let endpoint = runtime
            .tcp_stream(server_addr, TcpOptions::default())
            .await
            .unwrap();
        let handshake_info = ClientHandshakeInfo::default();

        let output = creds
            .connect(
                &authority,
                endpoint,
                &handshake_info,
                &runtime,
                private::Internal,
            )
            .await
            .unwrap();

        let endpoint = output.endpoint;
        let security_info = output.security;

        assert_eq!(security_info.security_protocol(), "local");
        assert_eq!(security_info.security_level(), SecurityLevel::NoSecurity);

        // The client's local address must be what the server sees as its peer.
        let (mut server_stream, _) = listener.accept().await.unwrap();
        assert_eq!(
            endpoint.get_local_address(),
            &server_stream.peer_addr().unwrap().to_string()
        );

        // Data written by the server must be readable through the endpoint
        // returned by the handshake.
        let test_data = b"hello grpc";
        server_stream.write_all(test_data).await.unwrap();

        let mut buf = vec![0u8; test_data.len()];
        AsyncIoAdapter::new(endpoint)
            .read_exact(&mut buf)
            .await
            .unwrap();
        assert_eq!(buf, test_data);

        assert!(
            security_info
                .security_context()
                .validate_authority(&authority)
        );
    }

    /// Server handshake over a real loopback TCP connection: the endpoint is
    /// passed through usable and the security info reports "local".
    #[tokio::test]
    async fn test_local_server_credentials() {
        const PAYLOAD: &[u8] = b"hello grpc";

        let creds = LocalServerCredentials::new();

        let info = creds.info();
        // Use the accessor, as the client test does, rather than the field.
        assert_eq!(info.security_protocol(), "local");

        // Port 0 lets the OS pick a free port.
        let addr = "127.0.0.1:0";
        let runtime = rt::default_runtime();
        let listener = TcpListener::bind(addr).await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        let client_handle = tokio::spawn(async move {
            let mut stream = TcpStream::connect(server_addr).await.unwrap();
            stream.write_all(PAYLOAD).await.unwrap();

            // Block on a read so the connection stays open until the task is
            // aborted or the server side closes.
            let mut buf = vec![0u8; 1];
            let _ = stream.read(&mut buf).await;
        });

        let (stream, _) = listener.accept().await.unwrap();
        let server_stream = TokioIoStream::new_from_tcp(stream).unwrap();

        let output = creds
            .accept(server_stream, runtime, private::Internal)
            .await
            .unwrap();

        let endpoint = output.endpoint;
        let security_info = output.security;

        assert_eq!(security_info.security_protocol(), "local");
        assert_eq!(security_info.security_level(), SecurityLevel::NoSecurity);

        let mut buf = vec![0u8; PAYLOAD.len()];
        AsyncIoAdapter::new(endpoint)
            .read_exact(&mut buf)
            .await
            .unwrap();
        assert_eq!(&buf[..], PAYLOAD);

        client_handle.abort();
    }
}