use crate::tools::BytesGatherer;
use crate::transport::ddos::ddos::DdosConnectionGuard;
use bytes::Bytes;
use log::{info, trace, warn};
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
/// A single request handed from a [`TransportServer`] to its handler loop.
///
/// The handler answers by sending a [`BytesGatherer`] through `reply`.
pub struct IncomingRequest {
    /// Address of the caller that sent the request.
    pub caller_address: String,
    /// Raw request payload.
    pub bytes: Bytes,
    /// One-shot channel on which the reply is delivered back to the transport.
    pub reply: oneshot::Sender<BytesGatherer>,
    /// Guard of the connection this request arrived on; only reachable through
    /// [`IncomingRequest::report_bad_request`].
    ddos_connection_guard: Arc<DdosConnectionGuard>,
}
impl IncomingRequest {
    /// Bundles a received payload with its reply channel and the guard of the connection it
    /// arrived on.
    pub fn new(
        caller_address: String,
        bytes: Bytes,
        reply: oneshot::Sender<BytesGatherer>,
        ddos_connection_guard: Arc<DdosConnectionGuard>,
    ) -> Self {
        Self {
            caller_address,
            bytes,
            reply,
            ddos_connection_guard,
        }
    }

    /// Forwards a bad-request report to the connection's DDoS guard.
    pub fn report_bad_request(&self) {
        let guard = &self.ddos_connection_guard;
        guard.report_bad_request();
    }
}
/// Lifecycle state of a transport server.
///
/// NOTE(review): nothing in this file reads or writes this enum; the variant descriptions below
/// follow the variant names and should be confirmed against the code that drives the state.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ServerState {
    /// Constructed but not listening yet.
    Created,
    /// Accepting requests.
    Listening,
    /// Stopped.
    Shutdown,
}
/// Turns [`IncomingRequest`]s produced by a [`TransportServer`] into replies.
///
/// Implementors provide [`handle`](Self::handle); the default [`run`](Self::run) supplies the
/// receive-and-reply loop.
pub trait TransportServerHandler {
    /// Computes the reply for one request payload.
    async fn handle(&self, bytes: Bytes) -> BytesGatherer;

    /// Serves requests from `rx` until `cancellation_token` is cancelled or the channel is closed
    /// and drained.
    ///
    /// Requests are handled one at a time: each payload is passed to [`handle`](Self::handle) and
    /// the result is sent back on the request's `reply` channel. Cancellation is only checked
    /// between requests, so an in-flight `handle` call runs to completion first. A reply that
    /// cannot be delivered is logged and dropped.
    ///
    /// # Errors
    ///
    /// The default implementation never fails; it always returns `Ok(())`.
    async fn run(
        &self,
        cancellation_token: CancellationToken,
        mut rx: mpsc::Receiver<IncomingRequest>,
    ) -> anyhow::Result<()> {
        loop {
            tokio::select! {
                _ = cancellation_token.cancelled() => {
                    break;
                },
                receipt = rx.recv() => {
                    let Some(incoming) = receipt else {
                        warn!("channel closed");
                        break;
                    };
                    // The payload is caller-supplied and may be large, so only its size is
                    // logged at info level; the full dump is reserved for trace.
                    info!(
                        "received packet from {} ({} bytes)",
                        incoming.caller_address,
                        incoming.bytes.len()
                    );
                    trace!("packet from {}: {:?}", incoming.caller_address, incoming.bytes);
                    // `incoming` is owned here, so the payload is moved into the handler
                    // instead of being cloned.
                    let reply = self.handle(incoming.bytes).await;
                    match incoming.reply.send(reply) {
                        Ok(()) => {
                            trace!("sent reply to {}", incoming.caller_address);
                        }
                        // A oneshot send only fails when the receiving half is already gone.
                        Err(_) => {
                            warn!("failed to send reply to {}", incoming.caller_address);
                        }
                    }
                },
            }
        }
        Ok(())
    }
}
/// A bound server endpoint that delivers received requests to a handler loop.
#[async_trait::async_trait]
pub trait TransportServer: Send + Sync {
    /// Address under which this server can be reached (the tests pass it to
    /// [`TransportFactory::rpc`]).
    fn get_address(&self) -> &String;

    /// Accepts requests and forwards each one as an [`IncomingRequest`] through `handler`.
    ///
    /// NOTE(review): implementations are expected to return once `cancellation_token` is
    /// cancelled; confirm each transport honours this.
    async fn listen(
        &self,
        cancellation_token: CancellationToken,
        handler: mpsc::Sender<IncomingRequest>,
    ) -> anyhow::Result<()>;
}
/// Entry point of a transport implementation: creates servers and performs client-side RPCs.
#[async_trait::async_trait]
pub trait TransportFactory: Send + Sync {
    /// Addresses used to bootstrap connectivity.
    ///
    /// NOTE(review): semantics inferred from the name; confirm with the implementations.
    async fn get_bootstrap_addresses(&self) -> Vec<String>;

    /// Creates a server rooted at `base_path` and bound to `port`.
    ///
    /// The generic tests pass `port = 0`, presumably to request an OS-assigned port, and
    /// `force_local_network = true`.
    async fn create_server(
        &self,
        base_path: &str,
        port: u16,
        force_local_network: bool,
    ) -> anyhow::Result<Arc<dyn TransportServer>>;

    /// Sends `bytes` to the server at `address` and returns its reply payload.
    async fn rpc(&self, address: &str, bytes: Bytes) -> anyhow::Result<Bytes>;
}
#[cfg(any(test, feature = "generic-tests"))]
pub mod tests {
    //! Transport-agnostic checks: each function takes a [`TransportFactory`] so every transport
    //! implementation can run the same scenarios.
    use crate::tools::time::{MILLIS_IN_MILLISECOND, MILLIS_IN_SECOND};
    use crate::tools::time_provider::time_provider::{RealTimeProvider, TimeProvider};
    use crate::tools::tools::get_temp_dir;
    use crate::tools::BytesGatherer;
    use crate::transport::transport::{IncomingRequest, TransportFactory, TransportServerHandler};
    use bytes::Bytes;
    use log::{info, trace};
    use std::sync::Arc;
    use tokio::join;
    use tokio::sync::mpsc;
    use tokio_util::sync::CancellationToken;

    /// Payload every test handler answers with.
    const REPLY: &str = "here is the reply";

    /// Handler that ignores the request payload and always replies with [`REPLY`].
    struct FixedReplyHandler;

    impl TransportServerHandler for FixedReplyHandler {
        async fn handle(&self, _: Bytes) -> BytesGatherer {
            BytesGatherer::from_bytes(Bytes::from(REPLY))
        }
    }

    /// Starts one server, performs 20 sequential RPCs against it checking every reply, then
    /// cancels the shared token and waits for the handler loop and listener to finish.
    ///
    /// # Errors
    ///
    /// Returns an error if the temp dir or the server cannot be created.
    ///
    /// # Panics
    ///
    /// Panics if the handler loop, the listener or the client side reports an error.
    pub async fn rpc_test(transport_factory: Arc<dyn TransportFactory>) -> anyhow::Result<()> {
        let time_provider = Arc::new(RealTimeProvider::default());
        let cancellation_token = CancellationToken::new();
        // Keep the first element alive for the whole test: `let (_, ..)` would drop it at the end
        // of this statement, which matters if it is a guard owning the temporary directory.
        // NOTE(review): confirm what `get_temp_dir` returns as its first element.
        let (_temp_dir, temp_dir_str) = get_temp_dir()?;
        let transport_server = transport_factory.create_server(&temp_dir_str, 0u16, true).await?;
        let address = transport_server.get_address().clone();
        trace!("server address is {}", address);
        let (tx, rx) = mpsc::channel::<IncomingRequest>(32);
        let my_handler = FixedReplyHandler;
        info!("running server and clients in parallel");
        let results = join!(
            my_handler.run(cancellation_token.clone(), rx),
            transport_server.listen(cancellation_token.clone(), tx),
            async {
                info!("waiting for server to start");
                time_provider.sleep_millis(MILLIS_IN_SECOND.const_mul(1)).await;
                // The client calls live in their own block so `?` can be used without skipping
                // the shutdown below; an early return there would leave the server futures
                // running and `join!` waiting forever.
                let client_result = async {
                    for _ in 0..20 {
                        info!("calling server");
                        let response = transport_factory.rpc(&address, Bytes::from("hello")).await?;
                        anyhow::ensure!(
                            response == Bytes::from(REPLY),
                            "unexpected reply: {:?}",
                            response
                        );
                        time_provider.sleep_millis(MILLIS_IN_MILLISECOND.const_mul(100)).await;
                    }
                    Ok::<(), anyhow::Error>(())
                }
                .await;
                info!("shutting down servers");
                cancellation_token.cancel();
                time_provider.sleep_millis(MILLIS_IN_SECOND.const_mul(1)).await;
                client_result
            }
        );
        assert!(results.0.is_ok(), "handler failed: {:?}", results.0);
        assert!(results.1.is_ok(), "server failed: {:?}", results.1);
        assert!(results.2.is_ok(), "client failed: {:?}", results.2);
        Ok(())
    }

    /// Creates two servers on port 0 with the same base path, runs their handler loops and
    /// listeners side by side for about a second, then cancels and checks that all of them
    /// finished without error.
    ///
    /// # Errors
    ///
    /// Returns an error if the temp dir or either server cannot be created.
    ///
    /// # Panics
    ///
    /// Panics if any handler loop or listener reports an error.
    pub async fn bind_port_zero_test(transport_factory: Arc<dyn TransportFactory>) -> anyhow::Result<()> {
        let time_provider = Arc::new(RealTimeProvider::default());
        info!("starting test");
        let cancellation_token = CancellationToken::new();
        // Keep the first element alive for the whole test: `let (_, ..)` would drop it at the end
        // of this statement, which matters if it is a guard owning the temporary directory.
        // NOTE(review): confirm what `get_temp_dir` returns as its first element.
        let (_temp_dir, temp_dir_str) = get_temp_dir()?;
        let transport_server_1 = transport_factory.create_server(&temp_dir_str, 0u16, true).await?;
        let transport_server_2 = transport_factory.create_server(&temp_dir_str, 0u16, true).await?;
        let (tx_1, rx_1) = mpsc::channel::<IncomingRequest>(32);
        let (tx_2, rx_2) = mpsc::channel::<IncomingRequest>(32);
        let my_handler = FixedReplyHandler;
        info!("running server and clients in parallel");
        let results = join!(
            my_handler.run(cancellation_token.clone(), rx_1),
            my_handler.run(cancellation_token.clone(), rx_2),
            transport_server_1.listen(cancellation_token.clone(), tx_1),
            transport_server_2.listen(cancellation_token.clone(), tx_2),
            async {
                info!("waiting for server to start");
                time_provider.sleep_millis(MILLIS_IN_SECOND.const_mul(1)).await;
                info!("shutting down servers");
                cancellation_token.cancel();
                time_provider.sleep_millis(MILLIS_IN_SECOND.const_mul(1)).await;
                Ok::<(), anyhow::Error>(())
            }
        );
        assert!(results.0.is_ok(), "handler 1 failed: {:?}", results.0);
        assert!(results.1.is_ok(), "handler 2 failed: {:?}", results.1);
        assert!(results.2.is_ok(), "server 1 failed: {:?}", results.2);
        assert!(results.3.is_ok(), "server 2 failed: {:?}", results.3);
        assert!(results.4.is_ok(), "driver failed: {:?}", results.4);
        Ok(())
    }
}