mod grpc_raft_service;
pub(crate) mod grpc_transport;
#[cfg(test)]
mod grpc_raft_service_test;
#[cfg(test)]
mod grpc_transport_test;
use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use d_engine_core::RaftNodeConfig;
use d_engine_core::Result;
use d_engine_core::SystemError;
use d_engine_core::TlsConfig;
use d_engine_core::TypeConfig;
use d_engine_proto::client::raft_client_service_server::RaftClientServiceServer;
use d_engine_proto::server::cluster::cluster_management_service_server::ClusterManagementServiceServer;
use d_engine_proto::server::election::raft_election_service_server::RaftElectionServiceServer;
use d_engine_proto::server::replication::raft_replication_service_server::RaftReplicationServiceServer;
use d_engine_proto::server::storage::snapshot_service_server::SnapshotServiceServer;
use futures::FutureExt;
use rcgen::CertifiedKey;
use rcgen::generate_simple_self_signed;
use tokio::sync::watch;
use tonic::codec::CompressionEncoding;
use tonic::transport::Certificate;
use tonic::transport::Identity;
use tonic::transport::ServerTlsConfig;
use tonic_health::server::health_reporter;
use tracing::debug;
use tracing::error;
use tracing::info;
use tracing::warn;
use crate::Node;
pub(crate) async fn start_rpc_server<T>(
node: Arc<Node<T>>,
listen_address: SocketAddr,
config: RaftNodeConfig,
mut shutdown_signal: watch::Receiver<()>,
) -> Result<()>
where
T: TypeConfig,
{
let (mut health_reporter, health_service) = health_reporter();
health_reporter.set_serving::<RaftClientServiceServer<Node<T>>>().await;
health_reporter.set_serving::<RaftElectionServiceServer<Node<T>>>().await;
health_reporter.set_serving::<RaftReplicationServiceServer<Node<T>>>().await;
health_reporter.set_serving::<ClusterManagementServiceServer<Node<T>>>().await;
health_reporter.set_serving::<SnapshotServiceServer<Node<T>>>().await;
let control_config = &config.network.control;
let data_config = &config.network.data;
let mut server_builder = tonic::transport::Server::builder()
.timeout(Duration::from_millis(control_config.request_timeout_in_ms))
.concurrency_limit_per_connection(control_config.concurrency_limit)
.max_concurrent_streams(control_config.max_concurrent_streams)
.tcp_keepalive(Some(Duration::from_secs(
control_config.tcp_keepalive_in_secs,
)))
.http2_keepalive_interval(Some(Duration::from_secs(
control_config.http2_keep_alive_interval_in_secs,
)))
.http2_keepalive_timeout(Some(Duration::from_secs(
control_config.http2_keep_alive_timeout_in_secs,
)))
.initial_stream_window_size(data_config.stream_window_size)
.initial_connection_window_size(data_config.connection_window_size)
.http2_adaptive_window(Some(data_config.adaptive_window))
.max_frame_size(Some(data_config.max_frame_size))
.tcp_nodelay(config.network.tcp_nodelay);
if config.tls.enable_tls {
if config.tls.generate_self_signed_certificates {
if Path::new(&config.tls.certificate_authority_root_path).exists() {
warn!(
"Certificate authority root already exists, remove the file if you want to generate new certificates. Skipping self signed certificates generation."
);
} else {
info!("Generating self signed certificates");
generate_self_signed_certificates(config.tls.clone());
}
}
let cert = std::fs::read_to_string(config.tls.server_certificate_path.clone())
.expect("error, failed to read server certificat");
let key = std::fs::read_to_string(config.tls.server_private_key_path.clone())
.expect("error, failed to read server private key");
let server_identity = Identity::from_pem(cert, key);
let tls = ServerTlsConfig::new().identity(server_identity);
if config.tls.enable_mtls {
let client_ca_cert =
std::fs::read_to_string(config.tls.client_certificate_authority_root_path.clone())
.expect("error, failed to read client certificate authority root");
let client_ca_cert = Certificate::from_pem(client_ca_cert);
let tls = tls.client_ca_root(client_ca_cert);
server_builder = server_builder.tls_config(tls).expect("error, failed to setup mTLS");
info!("gRPC mTLS enabled");
} else {
server_builder = server_builder.tls_config(tls).expect("error, failed to setup TLS");
info!("gRPC TLS enabled");
}
}
if let Err(e) = server_builder
.add_service(health_service)
.add_service({
let server = RaftClientServiceServer::from_arc(node.clone())
.accept_compressed(CompressionEncoding::Gzip);
if config.raft.rpc_compression.client_response {
server.send_compressed(CompressionEncoding::Gzip)
} else {
server
}
})
.add_service({
let server = RaftElectionServiceServer::from_arc(node.clone())
.accept_compressed(CompressionEncoding::Gzip);
if config.raft.rpc_compression.election_response {
server.send_compressed(CompressionEncoding::Gzip)
} else {
server
}
})
.add_service({
let server = RaftReplicationServiceServer::from_arc(node.clone())
.accept_compressed(CompressionEncoding::Gzip);
if config.raft.rpc_compression.replication_response {
server.send_compressed(CompressionEncoding::Gzip)
} else {
server
}
})
.add_service({
let server = ClusterManagementServiceServer::from_arc(node.clone())
.accept_compressed(CompressionEncoding::Gzip);
if config.raft.rpc_compression.cluster_response {
server.send_compressed(CompressionEncoding::Gzip)
} else {
server
}
})
.add_service({
let server =
SnapshotServiceServer::from_arc(node).accept_compressed(CompressionEncoding::Gzip);
if config.raft.rpc_compression.snapshot_response {
server.send_compressed(CompressionEncoding::Gzip)
} else {
server
}
})
.serve_with_shutdown(
listen_address,
shutdown_signal.changed().map(|_s| {
info!("Stopping RPC server. {}", listen_address);
}),
)
.await
{
error!("error to start internal rpc server :{:?}.", e);
return Err(SystemError::ServerUnavailable.into());
}
debug!("rpc service finished!");
Ok(())
}
/// Generates a self-signed certificate for `localhost` and writes the PEM-encoded
/// certificate and private key to `config.server_certificate_path` and
/// `config.server_private_key_path` respectively.
///
/// # Panics
/// Panics if certificate generation fails or if either file cannot be written.
///
/// NOTE(review): only the server certificate and private key are written here;
/// the caller's "already exists" check looks at the certificate authority root
/// path, which this function does not create — confirm that is intended.
fn generate_self_signed_certificates(config: TlsConfig) {
    // The only subject alternative name on the generated certificate.
    let san_entries = vec![String::from("localhost")];
    let generated =
        generate_simple_self_signed(san_entries).expect("Certificate generation failed");

    let certificate_pem = generated.cert.pem();
    std::fs::write(&config.server_certificate_path, certificate_pem)
        .expect("Should succeed to write server certificate");

    let private_key_pem = generated.key_pair.serialize_pem();
    std::fs::write(&config.server_private_key_path, private_key_pem)
        .expect("Should succeed to write server private key");
}