use tonic::{metadata::KeyAndValueRef, Request, Response, Status};
use tracing::info;
use agp_config_grpc::client::ClientConfig;
use agp_config_grpc::testutils::helloworld::greeter_server::Greeter;
use agp_config_grpc::testutils::helloworld::{HelloReply, HelloRequest};
/// `Greeter` service implementation used by the gRPC configuration tests.
///
/// It keeps the client configuration the test clients are built from, so the
/// handler can check that the configured headers arrive on the server side.
#[derive(Default)]
pub struct TestGreeter {
    /// Client configuration whose `headers` the handler expects on every request.
    config: ClientConfig,
}

impl TestGreeter {
    /// Creates a greeter that checks incoming requests against `config`.
    pub fn new(config: ClientConfig) -> TestGreeter {
        TestGreeter { config }
    }
}
#[tonic::async_trait]
impl Greeter for TestGreeter {
async fn say_hello(
&self,
request: Request<HelloRequest>,
) -> Result<Response<HelloReply>, Status> {
info!("Got a request from {:?}", request.remote_addr());
for key_and_value in request.metadata().iter() {
match key_and_value {
KeyAndValueRef::Ascii(ref key, ref value) => {
info!("Ascii: {:?}: {:?}", key, value)
}
KeyAndValueRef::Binary(ref key, ref value) => {
info!("Binary: {:?}: {:?}", key, value)
}
}
}
for (key, value) in self.config.headers.iter() {
let header = request.metadata().get(key);
assert!(header.is_some());
let header = header.unwrap();
assert_eq!(header.to_str().unwrap(), value);
}
let reply = HelloReply {
message: format!("Hello {}!", request.into_inner().name),
};
Ok(Response::new(reply))
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::path::Path;

    use super::*;
    use agp_config_auth::basic::Config as BasicAuthConfig;
    use agp_config_auth::bearer::Config as BearerAuthConfig;
    use agp_config_grpc::client::AuthenticationConfig as ClientAuthenticationConfig;
    use agp_config_grpc::server::AuthenticationConfig as ServerAuthenticationConfig;
    use agp_config_tls::client::TlsClientConfig;
    use agp_config_tls::server::TlsServerConfig;
    use tracing::debug;
    use tracing::info;
    use tracing_test::traced_test;

    use agp_config_grpc::testutils::helloworld::greeter_client::GreeterClient;
    use agp_config_grpc::testutils::helloworld::greeter_server::GreeterServer;
    use agp_config_grpc::testutils::helloworld::HelloRequest;
    use agp_config_grpc::{client::ClientConfig, server::ServerConfig};

    /// Directory holding the test fixtures (the `tls/` certificates and keys).
    static TEST_DATA_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/data");

    /// Headers every test client sends; `TestGreeter` checks that each of them
    /// reaches the server with this value.
    fn custom_headers() -> HashMap<String, String> {
        HashMap::from([("x-custom-header".to_string(), "custom-value".to_string())])
    }

    /// Client TLS settings shared by the TLS tests.
    ///
    /// TLS is enabled (`insecure = false`) with version `tls1.3`,
    /// `insecure_skip_verify` turned on, and the CA file taken from the fixtures.
    fn tls_client_settings() -> TlsClientConfig {
        TlsClientConfig::new()
            .with_insecure(false)
            .with_insecure_skip_verify(true)
            .with_tls_version("tls1.3")
            .with_ca_file(&format!("{}/tls/ca.crt", TEST_DATA_PATH))
    }

    /// Reads `TEST_DATA_PATH/tls/<name>`, panicking with the full path on failure.
    fn read_tls_fixture(name: &str) -> String {
        let path = Path::new(TEST_DATA_PATH).join("tls").join(name);
        std::fs::read_to_string(&path)
            .unwrap_or_else(|e| panic!("failed to read {}: {}", path.display(), e))
    }

    /// Server TLS settings built from the fixture server certificate and key.
    fn tls_server_settings() -> TlsServerConfig {
        let cert = read_tls_fixture("server.crt");
        let key = read_tls_fixture("server.key");
        TlsServerConfig::new()
            .with_insecure(false)
            .with_cert_pem(&cert)
            .with_key_pem(&key)
    }

    /// Serves a `TestGreeter` (which checks headers from `client_config`) on the
    /// endpoint of `server_config` until the server future completes.
    async fn run_server(
        client_config: ClientConfig,
        server_config: ServerConfig,
    ) -> Result<(), Box<dyn std::error::Error>> {
        info!("GreeterServer listening on {}", server_config.endpoint);
        let greeter = TestGreeter::new(client_config);
        let server_future = match server_config.to_server_future(&[GreeterServer::new(greeter)])
        {
            Ok(future) => future,
            Err(e) => panic!("error: {:?}", e),
        };
        server_future.await?;
        Ok(())
    }

    /// Starts a server for `server_config` in a background task, then sends one
    /// request built from `client_config` and asserts that it succeeds.
    ///
    /// The server task is detached and keeps running after this returns, so
    /// callers can send further requests to the same endpoint.
    async fn setup_client_and_server(client_config: ClientConfig, server_config: ServerConfig) {
        // Several tests may try to install the crypto provider; only the first
        // call can succeed, so the result is deliberately ignored.
        let _result = rustls::crypto::aws_lc_rs::default_provider().install_default();

        let server_client_config = client_config.clone();
        // NOTE(review): nothing here waits for the server to bind before the client
        // sends its request; this relies on task scheduling order — confirm it cannot race.
        let _server = tokio::spawn(async move {
            run_server(server_client_config, server_config)
                .await
                .unwrap();
        });

        let channel = match client_config.to_channel() {
            Ok(channel) => channel,
            Err(e) => panic!("error: {:?}", e),
        };
        let mut client = GreeterClient::new(channel);
        let request = tonic::Request::new(HelloRequest {
            name: "Gateway".into(),
        });
        let response = client.say_hello(request).await;
        assert!(response.is_ok(), "error: {:?}", response.err());
        debug!("RESPONSE={:?}", response);
    }

    /// Sends one request through a fresh channel built from `client_config` and
    /// asserts that the server rejects it. Used with deliberately wrong credentials.
    async fn assert_request_rejected(client_config: ClientConfig) {
        let channel = client_config.to_channel().unwrap();
        let mut client = GreeterClient::new(channel);
        let request = tonic::Request::new(HelloRequest { name: "wee".into() });
        let response = client.say_hello(request).await;
        // Print the unexpected response itself: `response.err()` would always be
        // `None` exactly when this assertion fails.
        assert!(
            response.is_err(),
            "expected the request to be rejected, got: {:?}",
            response
        );
    }

    #[tokio::test]
    #[traced_test]
    async fn test_grpc_configuration() {
        let client_config = ClientConfig::with_endpoint("http://[::1]:50051")
            .with_headers(custom_headers())
            .with_tls_setting(TlsClientConfig::new().with_insecure(true));
        let server_config = ServerConfig::with_endpoint("[::1]:50051")
            .with_tls_settings(TlsServerConfig::new().with_insecure(true));
        setup_client_and_server(client_config, server_config).await
    }

    #[tokio::test]
    #[traced_test]
    async fn test_tls_grpc_configuration() {
        let client_config = ClientConfig::with_endpoint("https://[::1]:50052")
            .with_headers(custom_headers())
            .with_tls_setting(tls_client_settings());
        let server_config =
            ServerConfig::with_endpoint("[::1]:50052").with_tls_settings(tls_server_settings());
        setup_client_and_server(client_config, server_config).await
    }

    #[tokio::test]
    #[traced_test]
    async fn test_tls_auth_grpc_configuration() {
        let client_config = ClientConfig::with_endpoint("https://[::1]:50053")
            .with_headers(custom_headers())
            .with_tls_setting(tls_client_settings())
            .with_auth(ClientAuthenticationConfig::Basic(BasicAuthConfig::new(
                "user", "password",
            )));
        let server_config = ServerConfig::with_endpoint("[::1]:50053")
            .with_tls_settings(tls_server_settings())
            .with_auth(ServerAuthenticationConfig::Basic(BasicAuthConfig::new(
                "user", "password",
            )));
        setup_client_and_server(client_config.clone(), server_config).await;

        // Same server, wrong password: the request must fail.
        assert_request_rejected(client_config.with_auth(ClientAuthenticationConfig::Basic(
            BasicAuthConfig::new("user", "wrong"),
        )))
        .await;
    }

    #[tokio::test]
    #[traced_test]
    async fn test_tls_bearer_auth_grpc_configuration() {
        let client_config = ClientConfig::with_endpoint("https://[::1]:50054")
            .with_headers(custom_headers())
            .with_tls_setting(tls_client_settings())
            .with_auth(ClientAuthenticationConfig::Bearer(BearerAuthConfig::new(
                "token",
            )));
        let server_config = ServerConfig::with_endpoint("[::1]:50054")
            .with_tls_settings(tls_server_settings())
            .with_auth(ServerAuthenticationConfig::Bearer(BearerAuthConfig::new(
                "token",
            )));
        setup_client_and_server(client_config.clone(), server_config).await;

        // Same server, wrong token: the request must fail.
        assert_request_rejected(client_config.with_auth(ClientAuthenticationConfig::Bearer(
            BearerAuthConfig::new("wrong"),
        )))
        .await;
    }
}