use std::path::PathBuf;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::sleep;
use network_protocol::error::Result;
use network_protocol::protocol::message::Message;
use network_protocol::service::tls_client::TlsClient;
use network_protocol::transport::tls::{TlsClientConfig, TlsServerConfig};
/// Port on `127.0.0.1` that the server in `test_tls_communication` listens on.
const TEST_PORT: u16 = 49152;

/// Port on `127.0.0.1` that the server in `test_tls_tampering_protection` listens on.
///
/// Distinct from [`TEST_PORT`], presumably so both tests can run at the same time.
const TEST_PORT_TAMPER: u16 = 49153;

/// Relative path of the PEM certificate used by the test servers; created by
/// `generate_test_certificates` when missing.
const CERT_PATH: &str = "tests/test_cert.pem";

/// Relative path of the PEM private key matching [`CERT_PATH`]; created by
/// `generate_test_certificates` when missing.
const KEY_PATH: &str = "tests/test_key.pem";
/// Returns the paths of the certificate and private key used by the TLS tests,
/// generating a self-signed pair first if either file does not exist yet.
///
/// Every test in this file calls this helper, and the test harness runs tests on
/// separate threads in parallel by default. The check-and-generate step is
/// therefore guarded by a process-wide [`std::sync::Once`]: only the first caller
/// performs it, and concurrent callers block until it has finished. Without that
/// guard, one test could observe a missing or partially written file while
/// another test is still writing it.
///
/// # Returns
///
/// `Ok((cert_path, key_path))` built from `CERT_PATH` and `KEY_PATH`.
///
/// # Panics
///
/// Panics if `TlsServerConfig::generate_self_signed` fails, or if a previous
/// call panicked inside the guarded section (the `Once` is then poisoned).
// `expect` is deliberate: this is test-fixture setup, and a failure to create the
// certificates should abort the test immediately with a clear message.
#[allow(clippy::expect_used)]
fn generate_test_certificates() -> Result<(PathBuf, PathBuf)> {
    // Shared by every caller in this test binary.
    static GENERATE_ONCE: std::sync::Once = std::sync::Once::new();

    let cert_path = PathBuf::from(CERT_PATH);
    let key_path = PathBuf::from(KEY_PATH);

    GENERATE_ONCE.call_once(|| {
        // Regenerate the pair if either half is missing so cert and key stay matched.
        if !cert_path.exists() || !key_path.exists() {
            TlsServerConfig::generate_self_signed(&cert_path, &key_path)
                .expect("Failed to generate test certificates");
        }
    });

    Ok((cert_path, key_path))
}
/// End-to-end check that a `TlsClient` can exchange messages with the TLS daemon.
///
/// The test:
/// 1. starts the daemon on `127.0.0.1:TEST_PORT` in a background task,
/// 2. connects a client configured with `insecure()` (presumably relaxing
///    certificate verification for the test certificate; confirm against
///    `TlsClientConfig`),
/// 3. asserts that `Message::Ping` is answered with `Message::Pong`,
/// 4. asserts that a `Message::Custom` with command `"ECHO"` and payload
///    `[1, 2, 3, 4]` comes back with the same command and payload,
/// 5. signals shutdown and waits up to two seconds for the server task.
///
/// # Errors
///
/// Returns an error if certificate setup, connecting, or either request fails.
#[tokio::test]
async fn test_tls_communication() -> Result<()> {
    let (cert_path, key_path) = generate_test_certificates()?;
    let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);

    // Build the address once; the spawned task needs its own owned copy because
    // it is moved into the `async move` block.
    let server_addr = format!("127.0.0.1:{TEST_PORT}");
    let server_addr_clone = server_addr.clone();
    let server_handle = tokio::spawn(async move {
        let config = TlsServerConfig::new(cert_path, key_path);
        // The server's own result is discarded; a server that failed to start
        // surfaces as a client-side connection error below.
        let _ = network_protocol::service::tls_daemon::start_with_shutdown(
            &server_addr_clone,
            config,
            shutdown_rx,
        )
        .await;
    });

    // Give the server task time to start listening before the client connects.
    sleep(Duration::from_millis(100)).await;

    let config = TlsClientConfig::new("localhost").insecure();
    let mut client = TlsClient::connect(&server_addr, config).await?;

    // Basic liveness round trip.
    let response = client.request(Message::Ping).await?;
    assert!(matches!(response, Message::Pong));

    // Custom message round trip; the message is moved into the request because
    // it is not needed afterwards.
    let test_message = Message::Custom {
        command: "ECHO".to_string(),
        payload: vec![1, 2, 3, 4],
    };
    let response = client.request(test_message).await?;
    if let Message::Custom { command, payload } = response {
        assert_eq!(command, "ECHO");
        assert_eq!(payload, vec![1, 2, 3, 4]);
    } else {
        // Panicking is the intended way to fail a test on an unexpected variant.
        #[allow(clippy::panic)]
        {
            panic!("Expected Custom message, got: {response:?}");
        }
    }

    // Close the connection first, then let the server notice before shutting down.
    drop(client);
    sleep(Duration::from_millis(100)).await;

    // Best-effort teardown: neither a failed send nor a slow shutdown fails the test.
    let _ = shutdown_tx.send(()).await;
    let _ = tokio::time::timeout(Duration::from_secs(2), server_handle).await;
    Ok(())
}
/// Ping/Pong round trip against a TLS daemon listening on `127.0.0.1:TEST_PORT_TAMPER`.
///
/// NOTE(review): despite its name, this test does not modify or inject any
/// traffic. It only verifies that `Message::Ping` sent over the TLS connection is
/// answered with `Message::Pong`. An actual tampering scenario still needs to be
/// added if that is the intent.
///
/// # Errors
///
/// Returns an error if certificate setup, connecting, or the request fails.
#[tokio::test]
async fn test_tls_tampering_protection() -> Result<()> {
    // Pause used both after starting the server and after closing the client.
    let settle = Duration::from_millis(100);

    let (cert_path, key_path) = generate_test_certificates()?;
    let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);

    let server_addr = format!("127.0.0.1:{TEST_PORT_TAMPER}");
    // Owned copy for the background task; the original is kept for the client.
    let listen_addr = server_addr.clone();
    let server_task = tokio::spawn(async move {
        let server_config = TlsServerConfig::new(cert_path, key_path);
        // The server's result is discarded on purpose.
        let _ = network_protocol::service::tls_daemon::start_with_shutdown(
            &listen_addr,
            server_config,
            shutdown_rx,
        )
        .await;
    });
    sleep(settle).await;

    {
        // The client lives only for this scope and is dropped (closing the
        // connection) before the shutdown sequence below.
        let client_config = TlsClientConfig::new("localhost").insecure();
        let mut client = TlsClient::connect(&server_addr, client_config).await?;
        let response = client.request(Message::Ping).await?;
        assert!(matches!(response, Message::Pong));
    }
    sleep(settle).await;

    // Best-effort teardown: errors and timeouts are ignored.
    let _ = shutdown_tx.send(()).await;
    let _ = tokio::time::timeout(Duration::from_secs(2), server_task).await;
    Ok(())
}