use futures::{SinkExt, StreamExt};
use tokio::time;
use tracing::{debug, info, instrument, warn};
use crate::config::ClientConfig;
use crate::core::packet::Packet;
use crate::protocol::message::Message;
use crate::error::{ProtocolError, Result};
use crate::protocol::handshake::{
client_derive_session_key, client_secure_handshake_init, client_secure_handshake_verify,
};
use crate::protocol::heartbeat::{build_ping, is_pong};
use crate::protocol::keepalive::KeepAliveManager;
use crate::service::secure::SecureConnection;
use crate::transport::remote;
use crate::utils::replay_cache::ReplayCache;
use crate::utils::timeout::with_timeout_error;
/// A client connection that has completed the secure handshake.
///
/// Instances are produced by [`Client::connect`] / [`Client::connect_with_config`] and
/// bundle the encrypted connection with the keep-alive bookkeeping used while receiving.
pub struct Client {
/// Transport wrapped with the session key derived during the handshake.
conn: SecureConnection,
/// Tracks send/receive activity; consulted to decide when to ping and when the peer
/// should be treated as dead.
keep_alive: KeepAliveManager,
/// Configuration the client was connected with (address, timeouts, heartbeat interval).
config: ClientConfig,
// Replay-protection state kept for the lifetime of the connection. No `Client` method
// visible alongside this struct reads it, hence the lint suppression.
// NOTE(review): confirm whether replay state is meant to be consulted after the handshake.
#[allow(dead_code)]
replay_cache: ReplayCache,
}
impl Client {
#[instrument(skip(addr), fields(address = %addr))]
pub async fn connect(addr: &str) -> Result<Self> {
let config = ClientConfig {
address: addr.to_string(),
..Default::default()
};
Self::connect_with_config(config).await
}
#[instrument(skip(config), fields(address = %config.address))]
pub async fn connect_with_config(config: ClientConfig) -> Result<Self> {
let mut framed = with_timeout_error(
async { remote::connect(&config.address).await },
config.connection_timeout,
)
.await?;
let (client_state, init_msg) = client_secure_handshake_init()?;
let init_bytes = bincode::serialize(&init_msg)?;
framed
.send(Packet {
version: 1,
payload: init_bytes,
})
.await?;
let response = with_timeout_error(
async {
let packet = framed
.next()
.await
.ok_or(ProtocolError::ConnectionClosed)?
.map_err(|e| ProtocolError::TransportError(e.to_string()))?;
bincode::deserialize::<Message>(&packet.payload)
.map_err(|e| ProtocolError::DeserializeError(e.to_string()))
},
config.connection_timeout,
)
.await?;
let (server_pub_key, server_nonce, nonce_verification) = match response {
Message::SecureHandshakeResponse {
pub_key,
nonce,
nonce_verification,
} => (pub_key, nonce, nonce_verification),
_ => {
return Err(ProtocolError::HandshakeError(
"Invalid server response message type".into(),
))
}
};
let (client_state_verified, verify_msg) = client_secure_handshake_verify(
client_state,
server_pub_key,
server_nonce,
nonce_verification,
&config.address,
&mut ReplayCache::new(),
)?;
let verify_bytes = bincode::serialize(&verify_msg)?;
framed
.send(Packet {
version: 1,
payload: verify_bytes,
})
.await?;
let key = client_derive_session_key(client_state_verified)?;
let conn = SecureConnection::new(framed, key);
let dead_timeout = config.heartbeat_interval.mul_f32(4.0); let keep_alive = KeepAliveManager::with_settings(config.heartbeat_interval, dead_timeout);
info!("Connection established successfully");
Ok(Self {
conn,
keep_alive,
config,
replay_cache: ReplayCache::new(),
})
}
#[instrument(skip(self, msg))]
pub async fn send(&mut self, msg: Message) -> Result<()> {
let result = self.conn.secure_send(msg).await;
if result.is_ok() {
self.keep_alive.update_send();
}
result
}
#[instrument(skip(self))]
pub async fn recv(&mut self) -> Result<Message> {
let result = self.conn.secure_recv().await;
if result.is_ok() {
self.keep_alive.update_recv();
}
result
}
#[instrument(skip(self))]
pub async fn send_keepalive(&mut self) -> Result<()> {
debug!("Sending keep-alive ping");
let ping = build_ping();
self.send(ping).await
}
#[instrument(skip(self))]
pub async fn recv_with_keepalive(
&mut self,
timeout_duration: std::time::Duration,
) -> Result<Message> {
let mut ping_interval = time::interval(self.keep_alive.ping_interval());
let timeout = time::sleep(timeout_duration);
tokio::pin!(timeout);
loop {
tokio::select! {
_ = ping_interval.tick() => {
if self.keep_alive.should_ping() {
self.send_keepalive().await?;
}
if self.keep_alive.is_connection_dead() {
warn!(dead_seconds = ?self.keep_alive.time_since_last_recv().as_secs(),
"Connection appears dead");
return Err(ProtocolError::ConnectionTimeout);
}
}
recv_result = self.conn.secure_recv::<Message>() => {
match recv_result {
Ok(msg) => {
self.keep_alive.update_recv();
if !is_pong(&msg) {
return Ok(msg);
} else {
debug!("Received pong response");
}
}
Err(ProtocolError::Timeout) => {
continue;
}
Err(e) => return Err(e),
}
}
_ = &mut timeout => {
return Err(ProtocolError::Timeout);
}
}
}
}
#[instrument(skip(self, msg))]
pub async fn send_and_wait(&mut self, msg: Message) -> Result<Message> {
self.send(msg).await?;
self.recv_with_keepalive(self.config.response_timeout).await
}
}