use futures::{SinkExt, StreamExt};
use network_protocol::core::packet::Packet;
use network_protocol::error::{ProtocolError, Result};
use network_protocol::protocol::message::Message;
use network_protocol::service::secure::SecureConnection;
use network_protocol::transport::remote;
use rand_core::{OsRng, RngCore};
use sha2::{Digest, Sha256};
use x25519_dalek::{EphemeralSecret, PublicKey};
/// Benchmark client that exchanges [`Message`]s with a server over a
/// [`SecureConnection`] created by [`BenchmarkClient::connect`].
pub struct BenchmarkClient {
    // Connection keyed with the session key derived during the handshake in `connect`;
    // all traffic goes through `secure_send` / `secure_recv`.
    conn: SecureConnection,
}
impl BenchmarkClient {
pub async fn connect(addr: &str) -> Result<Self> {
let mut framed = remote::connect(addr).await?;
let client_secret = EphemeralSecret::random_from_rng(OsRng);
let client_public = PublicKey::from(&client_secret);
let mut client_nonce = [0u8; 16];
RngCore::fill_bytes(&mut OsRng, &mut client_nonce);
#[allow(clippy::expect_used)]
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("Time went backwards")
.as_millis() as u64;
let init_msg = Message::SecureHandshakeInit {
pub_key: client_public.to_bytes(),
nonce: client_nonce,
timestamp,
};
let init_bytes = bincode::serialize(&init_msg)?;
framed
.send(Packet {
version: 1,
payload: init_bytes,
})
.await?;
let packet = framed.next().await.ok_or(ProtocolError::Timeout)??;
let response: Message = bincode::deserialize(&packet.payload)?;
let (server_pub_key, server_nonce, _nonce_verification) = match response {
Message::SecureHandshakeResponse {
pub_key,
nonce,
nonce_verification,
} => (pub_key, nonce, nonce_verification),
_ => {
return Err(ProtocolError::HandshakeError(
"Invalid server response message type".into(),
))
}
};
println!("[benchmark_client] Received server response");
let expected_client_nonce_hash = {
let mut hasher = Sha256::new();
hasher.update(client_nonce);
hasher.finalize().to_vec()
};
println!("[benchmark_client] Server nonce verification: {_nonce_verification:?}");
println!("[benchmark_client] Expected client nonce hash: {expected_client_nonce_hash:?}");
let server_nonce_hash = {
let mut hasher = Sha256::new();
hasher.update(server_nonce);
let mut result = [0u8; 32];
result.copy_from_slice(&hasher.finalize()[..]);
result
};
let verify_msg = Message::SecureHandshakeConfirm {
nonce_verification: server_nonce_hash,
};
let verify_bytes = bincode::serialize(&verify_msg)?;
framed
.send(Packet {
version: 1,
payload: verify_bytes,
})
.await?;
println!("[benchmark_client] Sent client verification");
let server_public = PublicKey::from(server_pub_key);
let shared_secret = client_secret.diffie_hellman(&server_public);
let key = {
let mut hasher = Sha256::new();
hasher.update(shared_secret.as_bytes());
hasher.update(client_nonce);
hasher.update(server_nonce);
let mut result = [0u8; 32];
result.copy_from_slice(&hasher.finalize()[..]);
result
};
println!("[benchmark_client] Derived session key");
let conn = SecureConnection::new(framed, key);
Ok(Self { conn })
}
pub async fn send(&mut self, msg: Message) -> Result<()> {
self.conn.secure_send(msg).await
}
pub async fn recv(&mut self) -> Result<Message> {
self.conn.secure_recv().await
}
}