use std::fs;
use std::path::PathBuf;
use std::pin::Pin;
use std::result::Result;
use std::sync::Arc;
use std::sync::Once;
use std::time::Duration;
use bytes::Buf;
use bytes::Bytes;
use http::HeaderMap;
use http::HeaderName;
use http::HeaderValue;
use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::sync::oneshot;
use tokio::time::timeout;
use tokio_stream::Stream;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::wrappers::TcpListenerStream;
use tonic::Response;
use tonic::async_trait;
use tonic::metadata::MetadataMap as TonicMetadata;
use tonic::transport::Server;
use tonic_prost::prost::Message as ProstMessage;
use crate::StatusCodeError;
use crate::StatusError;
use crate::client::CallOptions;
use crate::client::Channel;
use crate::client::Invoke as _;
use crate::client::RecvStream as _;
use crate::client::ResponseStreamItem;
use crate::client::SendOptions;
use crate::client::SendStream as _;
use crate::client::name_resolution::TCP_IP_NETWORK_TYPE;
use crate::client::transport::SecurityOpts;
use crate::client::transport::TransportOptions;
use crate::client::transport::registry::GLOBAL_TRANSPORT_REGISTRY;
use crate::core::RecvMessage;
use crate::core::RequestHeaders;
use crate::core::ResponseHeaders;
use crate::core::SendMessage;
use crate::core::Trailers;
use crate::credentials::CompositeChannelCredentials;
use crate::credentials::LocalChannelCredentials;
use crate::credentials::SecurityLevel;
use crate::credentials::call::CallCredentials;
use crate::credentials::call::CallDetails;
use crate::credentials::call::ClientConnectionSecurityInfo;
use crate::credentials::client::ClientHandshakeInfo;
use crate::credentials::common::Authority;
use crate::credentials::rustls::RootCertificates;
use crate::credentials::rustls::StaticProvider;
use crate::credentials::rustls::client::ClientTlsConfig;
use crate::credentials::rustls::client::RustlsChannelCredendials;
use crate::echo_pb::EchoRequest;
use crate::echo_pb::EchoResponse;
use crate::echo_pb::echo_server::Echo;
use crate::echo_pb::echo_server::EchoServer;
use crate::metadata::AsciiMetadataKey;
use crate::metadata::MetadataMap;
use crate::rt::GrpcRuntime;
use crate::rt::tokio::TokioRuntime;
#[derive(Debug)]
struct MockCallCredentials {
metadata: Vec<(&'static str, &'static str)>,
min_security_level: SecurityLevel,
should_fail: Option<crate::StatusError>,
}
#[async_trait]
impl CallCredentials for MockCallCredentials {
async fn get_metadata(
&self,
_call_details: &CallDetails,
_auth_info: &ClientConnectionSecurityInfo,
metadata: &mut MetadataMap,
) -> Result<(), crate::StatusError> {
if let Some(status) = &self.should_fail {
return Err(status.clone());
}
for (key, val) in &self.metadata {
metadata.insert(
key.parse::<AsciiMetadataKey>().unwrap(),
val.parse().unwrap(),
);
}
Ok(())
}
fn minimum_channel_security_level(&self) -> SecurityLevel {
self.min_security_level
}
}
/// Ten-second upper bound used as the timeout for waits that should finish promptly.
const DEFAULT_TEST_DURATION: Duration = Duration::from_millis(10_000);
/// Short ten-millisecond test duration.
const DEFAULT_TEST_SHORT_DURATION: Duration = Duration::from_micros(10_000);
#[tokio::test]
pub(crate) async fn tonic_transport_rpc() {
super::reg();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap(); let shutdown_notify = Arc::new(Notify::new());
let shutdown_notify_copy = shutdown_notify.clone();
println!("EchoServer listening on: {addr}");
let server_handle = tokio::spawn(async move {
let echo_server = EchoService {
response_headers: None,
};
let svc = EchoServer::new(echo_server);
let _ = Server::builder()
.add_service(svc)
.serve_with_incoming_shutdown(
TcpListenerStream::new(listener),
shutdown_notify_copy.notified(),
)
.await;
});
let builder = GLOBAL_TRANSPORT_REGISTRY
.get_transport(TCP_IP_NETWORK_TYPE)
.unwrap();
let config = Arc::new(TransportOptions::default());
let securty_opts = SecurityOpts {
credentials: LocalChannelCredentials::new_arc(),
authority: Authority::new("localhost".to_string(), None),
handshake_info: ClientHandshakeInfo::default(),
};
let (conn, _sec_info, mut disconnection_listener) = builder
.dyn_connect(
addr.to_string(),
GrpcRuntime::new(TokioRuntime::default()),
&securty_opts,
&config,
)
.await
.unwrap();
let (mut tx, mut rx) = conn
.dyn_invoke(
RequestHeaders::new()
.with_method_name("/grpc.examples.echo.Echo/BidirectionalStreamingEcho"),
CallOptions::default(),
)
.await;
let client_handle = tokio::spawn(async move {
let mut dummy_msg = WrappedEchoResponse(EchoResponse { message: "".into() });
match rx.recv(&mut dummy_msg).await {
ResponseStreamItem::Headers(_) => {
println!("Got headers");
}
item => panic!("Expected headers, got {:?}", item),
}
for i in 0..5 {
let message = format!("message {i}");
let request = EchoRequest {
message: message.clone(),
};
let req = WrappedEchoRequest(request);
println!("Sent request: {:?}", req.0);
assert!(
tx.send(&req, SendOptions::default()).await.is_ok(),
"Receiver dropped"
);
let mut recv_msg = WrappedEchoResponse(EchoResponse { message: "".into() });
match rx.recv(&mut recv_msg).await {
ResponseStreamItem::Message => {
let echo_response = recv_msg.0;
println!("Got response: {echo_response:?}");
assert_eq!(echo_response.message, message);
}
item => panic!("Expected message, got {:?}", item),
}
}
});
client_handle.await.unwrap();
assert_eq!(
disconnection_listener.try_recv(),
Err(oneshot::error::TryRecvError::Empty),
);
shutdown_notify.notify_waiters();
let res = timeout(DEFAULT_TEST_DURATION, disconnection_listener)
.await
.unwrap()
.unwrap();
assert_eq!(res, Ok(()));
server_handle.await.unwrap();
}
#[tokio::test]
async fn grpc_invoke_tonic_unary() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let shutdown_notify = Arc::new(Notify::new());
let shutdown_notify_copy = shutdown_notify.clone();
let server_handle = tokio::spawn(async move {
let echo_server = EchoService {
response_headers: None,
};
let svc = EchoServer::new(echo_server);
let _ = Server::builder()
.add_service(svc)
.serve_with_incoming_shutdown(
TcpListenerStream::new(listener),
shutdown_notify_copy.notified(),
)
.await;
});
let target = format!("dns:///{}", addr);
let channel = Channel::new(
&target,
LocalChannelCredentials::new_arc(),
Default::default(),
);
let (_, resp, trailers) = perform_unary_echo(&channel, "hello interop").await;
assert_eq!(resp.message, "hello interop");
assert!(
trailers.status().is_ok(),
"RPC failed: {:?}",
trailers.status()
);
shutdown_notify.notify_one();
server_handle.await.unwrap();
}
#[cfg(unix)]
mod unix_tests {
    use std::path::Component;
    use std::path::Path;

    use tempfile::tempdir;
    use tokio::net::UnixListener;
    use tokio_stream::wrappers::UnixListenerStream;

    use super::*;

    /// Serves `EchoService` on a Unix socket bound at `bind_path`, dials it via `target`, and
    /// asserts that a unary echo round-trips successfully before shutting the server down.
    async fn run_unix_test(bind_path: &Path, target: &str) {
        let listener = UnixListener::bind(bind_path).unwrap();
        let channel = Channel::new(
            target,
            LocalChannelCredentials::new_arc(),
            Default::default(),
        );
        let shutdown_notify = Arc::new(Notify::new());
        let shutdown_notify_copy = shutdown_notify.clone();
        let server_handle = tokio::spawn(async move {
            let echo_server = EchoService {
                response_headers: None,
            };
            let svc = EchoServer::new(echo_server);
            let _ = Server::builder()
                .add_service(svc)
                .serve_with_incoming_shutdown(
                    UnixListenerStream::new(listener),
                    shutdown_notify_copy.notified(),
                )
                .await;
        });
        let payload = "hello unix";
        let (_, resp, trailers) = perform_unary_echo(&channel, payload).await;
        assert_eq!(resp.message, payload);
        assert!(trailers.status().is_ok());
        shutdown_notify.notify_one();
        server_handle.await.unwrap();
    }

    #[tokio::test]
    async fn unix_absolute_path() {
        let dir = tempdir().expect("failed to create temp dir");
        let socket_path = dir.path().join("absolute.sock");
        let target = format!("unix://{}", socket_path.to_str().unwrap());
        run_unix_test(&socket_path, &target).await;
    }

    /// Dials the socket through a `unix:` target relative to the process working directory.
    ///
    /// The working directory is only read, never changed: it is process-global state shared
    /// with every other test running in parallel.
    #[tokio::test]
    async fn unix_relative_path() {
        let dir = tempdir().expect("failed to create temp dir");
        let socket_name = "relative.sock";
        let socket_path = dir.path().join(socket_name);
        let current_dir = std::env::current_dir().expect("failed to fetch current directory");
        let relative_path = get_relative_path(&socket_path, &current_dir).unwrap();
        let target = format!("unix:{}", relative_path.display());
        run_unix_test(&socket_path, &target).await;
    }

    #[cfg(target_os = "linux")]
    #[tokio::test]
    async fn unix_abstract_socket() {
        let abstract_path = format!("grpc-test-abstract-socket-{}", rand::random::<u64>());
        // A leading NUL byte selects the Linux abstract socket namespace.
        let bind_path = format!("\0{}", abstract_path);
        let target = format!("unix-abstract:{}", abstract_path);
        run_unix_test(&PathBuf::from(bind_path), &target).await;
    }

    /// Returns the path to `target` expressed relative to `base`.
    ///
    /// Strips the longest shared component prefix, then climbs out of what remains of
    /// `base` with `..` components before descending into what remains of `target`.
    ///
    /// # Errors
    ///
    /// Returns an error if the two paths share no leading component.
    fn get_relative_path(target: &Path, base: &Path) -> Result<PathBuf, String> {
        let common_components = target
            .components()
            .zip(base.components())
            .take_while(|(t, b)| t == b)
            .count();
        if common_components == 0 {
            return Err("no common ancestor".to_owned());
        }
        let climb_out = base
            .components()
            .skip(common_components)
            .map(|_| Component::ParentDir);
        let descend = target.components().skip(common_components);
        Ok(climb_out.chain(descend).collect())
    }
}
/// Guards the one-time installation performed by `init_provider`.
static INIT: Once = Once::new();

/// Installs `ring` as the process-wide default rustls crypto provider on the first call.
///
/// Later calls do nothing. The result of `install_default` is deliberately ignored; it is an
/// error when a default provider was already installed.
fn init_provider() {
    INIT.call_once(|| {
        let provider = rustls::crypto::ring::default_provider();
        let _ = provider.install_default();
    });
}
#[tokio::test]
async fn grpc_invoke_tonic_unary_tls() {
init_provider();
let certs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.parent()
.unwrap()
.join("examples/data/tls");
let server_cert = fs::read(certs_path.join("server.pem")).expect("failed to read server.pem");
let server_key = fs::read(certs_path.join("server.key")).expect("failed to read server.key");
let ca_cert = fs::read(certs_path.join("ca.pem")).expect("failed to read ca.pem");
let identity = tonic::transport::Identity::from_pem(server_cert, server_key);
let tls_config = tonic::transport::ServerTlsConfig::new().identity(identity);
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let shutdown_notify = Arc::new(Notify::new());
let shutdown_notify_copy = shutdown_notify.clone();
let server_handle = tokio::spawn(async move {
let echo_server = EchoService {
response_headers: None,
};
let svc = EchoServer::new(echo_server);
let _ = Server::builder()
.tls_config(tls_config)
.expect("failed to set tls config")
.add_service(svc)
.serve_with_incoming_shutdown(
TcpListenerStream::new(listener),
shutdown_notify_copy.notified(),
)
.await;
});
let root_certs = RootCertificates::from_pem(ca_cert);
let root_provider = StaticProvider::new(root_certs);
let config = ClientTlsConfig::new().with_root_certificates_provider(root_provider);
let creds = RustlsChannelCredendials::new(config).unwrap();
let call_creds = Arc::new(MockCallCredentials {
metadata: vec![("x-test-metadata", "test-value")],
min_security_level: SecurityLevel::PrivacyAndIntegrity,
should_fail: None,
});
let composite_creds = CompositeChannelCredentials::new(creds, call_creds);
let target = format!("dns:///{}", addr);
let channel = Channel::new(&target, Arc::new(composite_creds), Default::default());
let (headers, resp, trilers) = perform_unary_echo(&channel, "hello interop tls").await;
assert_eq!(
headers.metadata().get("x-test-metadata-echo").unwrap(),
"test-value"
);
assert_eq!(resp.message, "hello interop tls");
assert!(
trilers.status().is_ok(),
"RPC failed: {:?}",
trilers.status()
);
shutdown_notify.notify_one();
server_handle.await.unwrap();
}
#[tokio::test]
async fn grpc_invoke_failure_cases() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let shutdown_notify = Arc::new(Notify::new());
let shutdown_notify_copy = shutdown_notify.clone();
tokio::spawn(async move {
let echo_server = EchoService {
response_headers: None,
};
let svc = EchoServer::new(echo_server);
let _ = Server::builder()
.add_service(svc)
.serve_with_incoming_shutdown(
TcpListenerStream::new(listener),
shutdown_notify_copy.notified(),
)
.await;
});
let target = format!("dns:///{}", addr);
{
let creds = LocalChannelCredentials::new();
let call_creds = Arc::new(MockCallCredentials {
metadata: vec![],
min_security_level: SecurityLevel::PrivacyAndIntegrity,
should_fail: None,
});
let composite_creds = CompositeChannelCredentials::new(creds, call_creds);
let channel = Channel::new(&target, Arc::new(composite_creds), Default::default());
let trailers = perform_unary_echo_failure(&channel).await;
assert_eq!(
trailers.status().as_ref().unwrap_err().code(),
StatusCodeError::Unauthenticated
);
}
{
let creds = LocalChannelCredentials::new();
let call_creds = Arc::new(MockCallCredentials {
metadata: vec![],
min_security_level: SecurityLevel::NoSecurity,
should_fail: Some(crate::StatusError::new(
StatusCodeError::PermissionDenied,
"test message",
)),
});
let composite_creds = CompositeChannelCredentials::new(creds, call_creds);
let channel = Channel::new(&target, Arc::new(composite_creds), Default::default());
let trailers = perform_unary_echo_failure(&channel).await;
assert_eq!(
trailers.status().as_ref().unwrap_err().code(),
StatusCodeError::PermissionDenied
);
assert!(
trailers
.status()
.as_ref()
.unwrap_err()
.message()
.contains("test message")
);
}
{
let creds = LocalChannelCredentials::new();
let call_creds = Arc::new(MockCallCredentials {
metadata: vec![],
min_security_level: SecurityLevel::NoSecurity,
should_fail: Some(StatusError::new(
StatusCodeError::InvalidArgument,
"test message",
)),
});
let composite_creds = CompositeChannelCredentials::new(creds, call_creds);
let channel = Channel::new(&target, Arc::new(composite_creds), Default::default());
let trailers = perform_unary_echo_failure(&channel).await;
assert_eq!(
trailers.status().as_ref().unwrap_err().code(),
StatusCodeError::Internal
);
assert!(
trailers
.status()
.as_ref()
.unwrap_err()
.message()
.contains("test message")
);
}
shutdown_notify.notify_one();
}
/// Sends one `UnaryEcho` request carrying `message` and drains the response stream.
///
/// Returns the response headers, the decoded echo response, and the trailers.
///
/// # Panics
///
/// Panics if the send fails, or if the stream does not yield `Headers`, `Message`, `Trailers`
/// in that order. The panic message includes the unexpected item.
async fn perform_unary_echo(
    channel: &Channel,
    message: &str,
) -> (ResponseHeaders, EchoResponse, Trailers) {
    let (mut tx, mut rx) = channel
        .invoke(
            RequestHeaders::new().with_method_name("/grpc.examples.echo.Echo/UnaryEcho"),
            CallOptions::default(),
        )
        .await;
    let req = WrappedEchoRequest(EchoRequest {
        message: message.into(),
    });
    // Unary call: the single request is also the final one.
    tx.send(
        &req,
        SendOptions {
            final_msg: true,
            ..Default::default()
        },
    )
    .await
    .unwrap();

    let mut resp = WrappedEchoResponse(EchoResponse::default());
    let headers = match rx.recv(&mut resp).await {
        ResponseStreamItem::Headers(headers) => headers,
        other => panic!("Expected Headers first, got {other:?}"),
    };
    match rx.recv(&mut resp).await {
        ResponseStreamItem::Message => {}
        other => panic!("Expected Message after Headers, got {other:?}"),
    }
    // Move the decoded response out; `resp` is reused as the scratch buffer for the next recv.
    let echo_resp = std::mem::take(&mut resp.0);
    let trailers = match rx.recv(&mut resp).await {
        ResponseStreamItem::Trailers(trailers) => trailers,
        other => panic!("Expected Trailers after Message, got {other:?}"),
    };
    (headers, echo_resp, trailers)
}
/// Starts a `UnaryEcho` call without sending a request and returns the trailers it ends with.
///
/// Intended for calls expected to fail before any headers or messages arrive.
///
/// # Panics
///
/// Panics if the first item received is not `Trailers`. The panic message includes that item.
async fn perform_unary_echo_failure(channel: &Channel) -> Trailers {
    // Bind (rather than drop) the send half so it stays alive until the trailers arrive.
    let (_tx, mut rx) = channel
        .invoke(
            RequestHeaders::new().with_method_name("/grpc.examples.echo.Echo/UnaryEcho"),
            CallOptions::default(),
        )
        .await;
    let mut resp = WrappedEchoResponse(EchoResponse::default());
    match rx.recv(&mut resp).await {
        ResponseStreamItem::Trailers(trailers) => trailers,
        other => panic!("Expected Trailers due to failure, got {other:?}"),
    }
}
#[tokio::test]
async fn tonic_transport_invalid_base64_headers() {
super::reg();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let shutdown_notify = Arc::new(Notify::new());
let shutdown_notify_copy = shutdown_notify.clone();
let mut headers = HeaderMap::new();
headers.insert(
HeaderName::from_static("test-bin"),
HeaderValue::from_static("invalid base64 data"),
);
let response_headers = Some(TonicMetadata::from_headers(headers));
let server_handle = tokio::spawn(async move {
let echo_server = EchoService { response_headers };
let svc = EchoServer::new(echo_server);
let _ = Server::builder()
.add_service(svc)
.serve_with_incoming_shutdown(
TcpListenerStream::new(listener),
shutdown_notify_copy.notified(),
)
.await;
});
let builder = GLOBAL_TRANSPORT_REGISTRY
.get_transport(TCP_IP_NETWORK_TYPE)
.unwrap();
let config = Arc::new(TransportOptions::default());
let securty_opts = SecurityOpts {
credentials: LocalChannelCredentials::new_arc(),
authority: Authority::new("localhost".to_string(), None),
handshake_info: ClientHandshakeInfo::default(),
};
let (conn, _sec_info, _disconnection_listener) = builder
.dyn_connect(
addr.to_string(),
GrpcRuntime::new(TokioRuntime::default()),
&securty_opts,
&config,
)
.await
.unwrap();
let (mut tx, mut rx) = conn
.dyn_invoke(
RequestHeaders::new()
.with_method_name("/grpc.examples.echo.Echo/BidirectionalStreamingEcho"),
CallOptions::default(),
)
.await;
let mut dummy_msg = WrappedEchoResponse(EchoResponse { message: "".into() });
match rx.recv(&mut dummy_msg).await {
ResponseStreamItem::Trailers(trailers) => {
println!("Got trailers as expected due to invalid headers");
let status = trailers.status().as_ref().unwrap_err();
assert_eq!(status.code(), StatusCodeError::Internal);
}
item => panic!("Expected Trailers with error, got {:?}", item),
}
let request = EchoRequest {
message: "hello".into(),
};
let req = WrappedEchoRequest(request);
tokio::time::timeout(DEFAULT_TEST_DURATION, async {
while tx.send(&req, SendOptions::default()).await.is_ok() {}
})
.await
.expect("timed out waiting for stream to close");
shutdown_notify.notify_one();
server_handle.await.unwrap();
}
#[tokio::test]
async fn tonic_transport_recv_drop_cancels_send() {
super::reg();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let shutdown_notify = Arc::new(Notify::new());
let shutdown_notify_copy = shutdown_notify.clone();
let server_handle = tokio::spawn(async move {
let echo_server = EchoService {
response_headers: None,
};
let svc = EchoServer::new(echo_server);
let _ = Server::builder()
.add_service(svc)
.serve_with_incoming_shutdown(
TcpListenerStream::new(listener),
shutdown_notify_copy.notified(),
)
.await;
});
let builder = GLOBAL_TRANSPORT_REGISTRY
.get_transport(TCP_IP_NETWORK_TYPE)
.unwrap();
let config = Arc::new(TransportOptions::default());
let securty_opts = SecurityOpts {
credentials: LocalChannelCredentials::new_arc(),
authority: Authority::new("localhost".to_string(), None),
handshake_info: ClientHandshakeInfo::default(),
};
let (conn, _sec_info, _disconnection_listener) = builder
.dyn_connect(
addr.to_string(),
GrpcRuntime::new(TokioRuntime::default()),
&securty_opts,
&config,
)
.await
.unwrap();
let (mut tx, rx) = conn
.dyn_invoke(
RequestHeaders::new()
.with_method_name("/grpc.examples.echo.Echo/BidirectionalStreamingEcho"),
CallOptions::default(),
)
.await;
drop(rx);
let request = EchoRequest {
message: "hello".into(),
};
let req = WrappedEchoRequest(request);
tokio::time::timeout(DEFAULT_TEST_DURATION, async {
while tx.send(&req, SendOptions::default()).await.is_ok() {}
})
.await
.expect("timed out waiting for stream to close");
shutdown_notify.notify_one();
server_handle.await.unwrap();
}
/// Adapts a prost [`EchoRequest`] to the transport's [`SendMessage`] trait.
struct WrappedEchoRequest(EchoRequest);
/// Adapts a prost [`EchoResponse`] to the transport's [`RecvMessage`] trait.
struct WrappedEchoResponse(EchoResponse);

impl SendMessage for WrappedEchoRequest {
    /// Serializes the wrapped request into a newly allocated protobuf buffer.
    fn encode(&self) -> Result<Box<dyn Buf + Send + Sync>, String> {
        let WrappedEchoRequest(inner) = self;
        let bytes = Bytes::from(inner.encode_to_vec());
        Ok(Box::new(bytes))
    }
}

impl RecvMessage for WrappedEchoResponse {
    /// Replaces the wrapped response with one decoded from all remaining bytes of `data`.
    /// A decode failure is reported as the error's `Display` text.
    fn decode(&mut self, data: &mut dyn Buf) -> Result<(), String> {
        let len = data.remaining();
        let payload = data.copy_to_bytes(len);
        match EchoResponse::decode(payload) {
            Ok(decoded) => {
                self.0 = decoded;
                Ok(())
            }
            Err(err) => Err(err.to_string()),
        }
    }
}
/// Echo service implementation served by the tests in this file.
#[derive(Debug)]
struct EchoService {
    /// When set, replaces the response metadata of bidirectional-streaming responses.
    response_headers: Option<TonicMetadata>,
}

#[async_trait]
impl Echo for EchoService {
    /// Echoes the request message back. If the request carries `x-test-metadata`, its value is
    /// copied into the `x-test-metadata-echo` response header.
    async fn unary_echo(
        &self,
        request: tonic::Request<EchoRequest>,
    ) -> Result<tonic::Response<EchoResponse>, tonic::Status> {
        // Only one header is needed, so copy just that value instead of the whole map before
        // the request is consumed.
        let echoed = request.metadata().get("x-test-metadata").cloned();
        let message = request.into_inner().message;
        let mut response = tonic::Response::new(EchoResponse { message });
        if let Some(val) = echoed {
            response.metadata_mut().insert("x-test-metadata-echo", val);
        }
        Ok(response)
    }

    type ServerStreamingEchoStream = ReceiverStream<Result<EchoResponse, tonic::Status>>;

    /// Not used by these tests. Reports `Unimplemented` instead of panicking inside the server.
    async fn server_streaming_echo(
        &self,
        _: tonic::Request<EchoRequest>,
    ) -> Result<tonic::Response<Self::ServerStreamingEchoStream>, tonic::Status> {
        Err(tonic::Status::unimplemented(
            "server_streaming_echo is not implemented by the test EchoService",
        ))
    }

    /// Not used by these tests. Reports `Unimplemented` instead of panicking inside the server.
    async fn client_streaming_echo(
        &self,
        _: tonic::Request<tonic::Streaming<EchoRequest>>,
    ) -> Result<tonic::Response<EchoResponse>, tonic::Status> {
        Err(tonic::Status::unimplemented(
            "client_streaming_echo is not implemented by the test EchoService",
        ))
    }

    type BidirectionalStreamingEchoStream =
        Pin<Box<dyn Stream<Item = Result<EchoResponse, tonic::Status>> + Send + 'static>>;

    /// Echoes every inbound message back on the response stream until the client finishes.
    /// An inbound error ends the stream with that status.
    ///
    /// If `response_headers` is configured, it replaces the response metadata.
    async fn bidirectional_streaming_echo(
        &self,
        request: tonic::Request<tonic::Streaming<EchoRequest>>,
    ) -> Result<tonic::Response<Self::BidirectionalStreamingEchoStream>, tonic::Status> {
        if let Some(val) = request.metadata().get("x-test-metadata")
            && val == "test-value"
        {
            println!("Server received expected metadata");
        }
        let mut inbound = request.into_inner();
        let outbound = async_stream::try_stream! {
            while let Some(req) = inbound.next().await {
                let req = req?;
                // `req` is owned, so its message can be moved rather than cloned.
                let reply = EchoResponse {
                    message: req.message,
                };
                yield reply;
            }
            println!("Server closing stream");
        };
        let mut response =
            Response::new(Box::pin(outbound) as Self::BidirectionalStreamingEchoStream);
        if let Some(headers) = &self.response_headers {
            *response.metadata_mut() = headers.clone();
        }
        Ok(response)
    }
}