#[cfg(all(test, feature = "tokio", feature = "server"))]
#[path = "../../tests/session/client_health.rs"]
mod tests;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Weak};
use std::time::Duration;
use log::{debug, warn};
use rand::Rng;
use crate::bytes::{ByteBuffer, ByteBufferMut, DynamicByteBuffer, FixedByteBuffer};
use crate::cache::SharedValue;
use crate::crypto::{ClientCryptoTool, ClientData};
use crate::session::SessionControllerError;
use crate::session::common::SessionManager;
use crate::settings::Settings;
use crate::settings::keys::*;
use crate::tailer::{ClientConnectionHandler, IdentityType, PacketFlags, ReturnCode, Tailer};
use crate::utils::random::get_rng;
use crate::utils::sync::{AsyncExecutor, FuturePool, Mutex, WatchReceiver, WatchSender, sleep};
use crate::utils::unix_timestamp_ms;
/// A response forwarded from the `feed_*` entry points to whichever loop is waiting on it:
/// `(server next_in, local receive time in ms, handshake body, server identity)`.
///
/// The last two elements are only populated for handshake responses.
type HealthResponse<T> = (
    u32,
    u128,
    Option<DynamicByteBuffer>,
    Option<T>,
);
/// Result of waiting for a response to the outstanding health check or handshake packet.
enum DecaySleepEvent<T: IdentityType> {
    /// The timeout elapsed before any response arrived.
    Timeout,
    /// The response channel was closed (or the wait yielded nothing).
    Terminated,
    /// A response whose packet number matched the outstanding one was delivered.
    ResponseReceived {
        /// `next_in` advertised by the server, already clamped to the configured bounds.
        server_next_in: u32,
        /// Local timestamp (ms) taken when the response was accepted.
        receive_time: u128,
        /// Handshake response body; `None` for plain health-check responses.
        handshake_body: Option<DynamicByteBuffer>,
        /// Identity taken from the server's tailer; `None` for plain health-check responses.
        server_identity: Option<T>,
    },
}
/// Result of waiting for a pending health check to be piggybacked ("shadowridden") onto an
/// outgoing data packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DecayShadowrideEvent {
    /// The shadowride window elapsed without a notification.
    Timeout,
    /// The shadowride notification channel was closed.
    Terminated,
    /// A shadowride notification was received.
    Shadowridden,
}
/// Outcome of one attempt to send a health-check or handshake packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SendOutcome {
    /// The packet was handed to the session manager (or shadowridden onto a data packet).
    Sent,
    /// Sending failed but the retry budget is not exhausted yet.
    Retry,
    /// The provider should stop: manager gone, retries exhausted or a channel closed.
    Stop,
}
/// Mutable bookkeeping shared, behind a mutex, by the handshake, the decay loop and the
/// `feed_*` entry points of `ClientHealthProvider`.
pub(super) struct HealthState<T: IdentityType + Clone, AE: AsyncExecutor, CC: ClientConnectionHandler> {
    /// Provider settings (timeouts, RTT bounds, retry limit) plus the buffer pool.
    settings: Arc<Settings<AE>>,
    /// Smoothed round-trip time in ms; `None` until the first sample.
    smooth_rtt: Option<f64>,
    /// Round-trip-time variance in ms; initialised together with `smooth_rtt`.
    rtt_variance: Option<f64>,
    /// Shared counter supplying the low 32 bits of every packet number.
    counter: Arc<AtomicU32>,
    /// Consecutive failures (send errors or timeouts) since the last accepted response.
    retry_count: u64,
    /// Local timestamp (ms) at which the outstanding packet left.
    last_sent_time: u128,
    /// `next_in` advertised in the outstanding health check; subtracted from RTT samples.
    last_sent_next_in: u32,
    /// Packet number of the outstanding packet; responses carrying any other PN are dropped.
    current_pn: u64,
    /// `(pn, next_in)` of a health check waiting to ride on the next outgoing data packet.
    shadowride_pending: Option<(u64, u32)>,
    /// Client crypto state; replaced whenever handshake keys are derived.
    crypto_tool: SharedValue<ClientCryptoTool<T>>,
    /// Client-side handshake data, kept until the handshake response is processed.
    client_data: Option<ClientData>,
    /// Supplies the initial payload embedded in handshake packets.
    initial_data_generator: CC,
    /// Response receiver, parked here until the handshake takes it.
    response_rx: Option<WatchReceiver<HealthResponse<T>>>,
    /// Code reported on termination; the decay loop sets `ConnectionDecayed` on timeout.
    termination_code: ReturnCode,
}
impl<T: IdentityType + Clone, AE: AsyncExecutor, CC: ClientConnectionHandler> HealthState<T, AE, CC> {
    /// Creates the initial state: no RTT samples, no retries, nothing outstanding and
    /// `ReturnCode::Success` as the termination code.
    ///
    /// `response_rx` is parked in an `Option` so the handshake can take ownership of it once.
    fn new(settings: Arc<Settings<AE>>, crypto_tool: SharedValue<ClientCryptoTool<T>>, counter: Arc<AtomicU32>, initial_data_generator: CC, response_rx: WatchReceiver<HealthResponse<T>>) -> Self {
        Self {
            settings,
            smooth_rtt: None,
            rtt_variance: None,
            counter,
            retry_count: 0,
            last_sent_time: 0,
            last_sent_next_in: 0,
            current_pn: 0,
            shadowride_pending: None,
            crypto_tool,
            client_data: None,
            initial_data_generator,
            response_rx: Some(response_rx),
            termination_code: ReturnCode::Success,
        }
    }

    /// Smoothed RTT (same unit as the samples fed to `update_rtt`, i.e. ms), or the configured
    /// `RTT_DEFAULT` before the first sample has been taken.
    fn smooth_rtt_or_default(&self) -> f64 {
        self.smooth_rtt.unwrap_or_else(|| self.settings.get(&RTT_DEFAULT) as f64)
    }

    /// Allocates a fresh packet number: Unix time in seconds in the upper 32 bits and the
    /// incremented shared counter in the lower 32 bits.
    fn next_packet_number(&self) -> u64 {
        // `fetch_add` returns the previous value; add one so the PN carries the value just stored.
        let counter = self.counter.fetch_add(1, Ordering::Relaxed).wrapping_add(1);
        // NOTE: the seconds field is truncated to 32 bits (wraps in 2106).
        let timestamp = (unix_timestamp_ms() / 1000) as u32;
        (u64::from(timestamp) << 32) | u64::from(counter)
    }

    /// Picks the `next_in` to advertise, uniformly from
    /// `[HEALTH_CHECK_NEXT_IN_MIN, HEALTH_CHECK_NEXT_IN_MAX]`.
    ///
    /// # Panics
    /// `gen_range` panics if the configured minimum exceeds the maximum.
    fn compute_next_in(&self) -> u32 {
        let min = self.settings.get(&HEALTH_CHECK_NEXT_IN_MIN);
        let max = self.settings.get(&HEALTH_CHECK_NEXT_IN_MAX);
        get_rng().gen_range(min..=max) as u32
    }

    /// Response timeout in ms: `(SRTT + RTTVAR) * TIMEOUT_RTT_FACTOR` once RTT samples exist,
    /// `TIMEOUT_DEFAULT` before that, always clamped to `[TIMEOUT_MIN, TIMEOUT_MAX]`.
    fn compute_timeout(&self) -> u64 {
        let timeout_min = self.settings.get(&TIMEOUT_MIN);
        let timeout_max = self.settings.get(&TIMEOUT_MAX);
        let raw = match (self.smooth_rtt, self.rtt_variance) {
            (Some(srtt), Some(rttvar)) => {
                let factor = self.settings.get(&TIMEOUT_RTT_FACTOR);
                ((srtt + rttvar) * factor) as u64
            }
            _ => self.settings.get(&TIMEOUT_DEFAULT),
        };
        raw.clamp(timeout_min, timeout_max)
    }

    /// Records one more failure; returns `true` while the retry budget (`MAX_RETRIES`) allows
    /// another attempt.
    fn increment_retry(&mut self) -> bool {
        self.retry_count += 1;
        self.retry_count < self.settings.get(&MAX_RETRIES)
    }

    /// Folds one RTT sample into the smoothed estimators (RFC 6298 style).
    ///
    /// The sample is `receive_time - last_sent_time - last_sent_next_in`, clamped to
    /// `[RTT_MIN, RTT_MAX]`. The first sample sets SRTT to the sample and RTTVAR to half of it;
    /// later samples use `RTT_ALPHA` / `RTT_BETA` as gains.
    fn update_rtt(&mut self, receive_time: u128) {
        let rtt_min = self.settings.get(&RTT_MIN) as f64;
        let rtt_max = self.settings.get(&RTT_MAX) as f64;
        // NOTE(review): the advertised `next_in` is removed from the raw interval, presumably
        // because the peer defers its reply by that amount — confirm against the server side.
        let sample = ((receive_time as f64) - (self.last_sent_time as f64) - (self.last_sent_next_in as f64)).clamp(rtt_min, rtt_max);
        match (self.smooth_rtt, self.rtt_variance) {
            (Some(srtt), Some(rttvar)) => {
                let alpha = self.settings.get(&RTT_ALPHA);
                let beta = self.settings.get(&RTT_BETA);
                // RFC 6298 §2.3: RTTVAR must be updated with the SRTT value from *before* this
                // sample is folded in; using the updated SRTT underestimates the variance.
                let new_rttvar = (1.0 - beta) * rttvar + beta * (srtt - sample).abs();
                let new_srtt = (1.0 - alpha) * srtt + alpha * sample;
                self.rtt_variance = Some(new_rttvar);
                self.smooth_rtt = Some(new_srtt.clamp(rtt_min, rtt_max));
            }
            _ => {
                self.smooth_rtt = Some(sample);
                self.rtt_variance = Some(sample / 2.0);
            }
        }
    }

    /// Current client identity from the crypto tool.
    fn identity_value(&mut self) -> T {
        self.crypto_tool.get().identity()
    }

    /// Builds a standalone (tailer-only) health-check packet for `pn` advertising `next_in`.
    fn create_health_check_packet(&mut self, pn: u64, next_in: u32) -> DynamicByteBuffer {
        let identity = self.identity_value();
        // Size the buffer for the full tailer, exactly like `create_handshake_packet` does;
        // `T::length()` only covers the identity part of it.
        let buf = self.settings.pool().allocate(Some(Tailer::<T>::len()));
        Tailer::health_check(buf, &identity, next_in, pn).into_buffer()
    }

    /// Builds a handshake packet for `pn` advertising `next_in`.
    ///
    /// Side effects: stores the client handshake data for `process_handshake_response` and
    /// replaces the crypto tool with one keyed by the freshly derived initial key.
    fn create_handshake_packet(&mut self, pn: u64, next_in: u32) -> DynamicByteBuffer {
        let settings = self.settings.clone();
        let initial_data = self.initial_data_generator.initial_data();
        let (identity, client_data, handshake_secret, updated_tool) = {
            let crypto = self.crypto_tool.get();
            let identity = crypto.identity();
            let (client_data, handshake_secret, initial_key) = crypto.create_handshake(settings.pool(), initial_data.slice());
            let updated_tool = crypto.with_key(&initial_key);
            (identity, client_data, handshake_secret, updated_tool)
        };
        self.client_data = Some(client_data);
        self.crypto_tool.set(updated_tool);
        let tailer_buffer = settings.pool().allocate(Some(Tailer::<T>::len()));
        // NOTE(review): `as u16` silently truncates bodies longer than 65535 bytes.
        let tailer = Tailer::handshake(tailer_buffer, &identity, 0, next_in, pn, handshake_secret.len() as u16);
        // Wire layout: handshake body followed by the tailer.
        handshake_secret.append(tailer.buffer().slice())
    }

    /// Completes the handshake with the server's response body.
    ///
    /// Returns `(session_key, server_initial_data)`, or `None` if no handshake is in flight
    /// (client data already consumed) or the response could not be processed (logged).
    fn process_handshake_response(&mut self, handshake_body: DynamicByteBuffer) -> Option<(FixedByteBuffer<32>, DynamicByteBuffer)> {
        let client_data = self.client_data.take()?;
        self.crypto_tool
            .get()
            .process_handshake_response(client_data, handshake_body, self.settings.pool())
            .map_err(|err| {
                warn!("health provider: handshake response decryption failed: {err}");
            })
            .ok()
    }
}
/// Waits for whichever happens first: `timeout_ms` elapsing or the next response on
/// `response_rx`.
///
/// Yields `Timeout`, `ResponseReceived` (the unpacked response tuple) or `Terminated` when the
/// channel is closed — and, defensively, when the pool produces nothing at all.
async fn wait_for_response<T: IdentityType + Clone>(timeout_ms: u64, response_rx: &mut WatchReceiver<HealthResponse<T>>) -> DecaySleepEvent<T> {
    let deadline = async move {
        sleep(Duration::from_millis(timeout_ms)).await;
        DecaySleepEvent::<T>::Timeout
    };
    let arrival = async {
        let received = response_rx.recv().await;
        received.map_or(DecaySleepEvent::Terminated, |(server_next_in, receive_time, handshake_body, server_identity)| {
            DecaySleepEvent::ResponseReceived {
                server_next_in,
                receive_time,
                handshake_body,
                server_identity,
            }
        })
    };
    let mut racers = FuturePool::new();
    racers.add(deadline);
    racers.add(arrival);
    racers.next().await.unwrap_or(DecaySleepEvent::Terminated)
}
/// Drives the client side of session health: the initial handshake, the periodic health-check
/// ("decay") loop, and piggybacking health checks onto outgoing data packets.
pub struct ClientHealthProvider<
    T: IdentityType + Clone + 'static,
    AE: AsyncExecutor + 'static,
    SM: SessionManager + Send + Sync + 'static,
    CC: ClientConnectionHandler + 'static,
> {
    /// Session manager used to send packets; held weakly so the provider never keeps it alive.
    manager: Weak<SM>,
    /// Bookkeeping shared with the spawned decay loop.
    state: Arc<Mutex<HealthState<T, AE, CC>>>,
    /// Provider settings (also reachable through `state`).
    settings: Arc<Settings<AE>>,
    /// Delivers matching responses to whichever loop is currently waiting.
    response_tx: WatchSender<HealthResponse<T>>,
    /// Signals the decay loop that a pending health check was shadowridden.
    shadowride_tx: WatchSender<()>,
}
impl<T: IdentityType + Clone, AE: AsyncExecutor, SM: SessionManager + Send + Sync, CC: ClientConnectionHandler + 'static> ClientHealthProvider<T, AE, SM, CC> {
    /// Creates a provider around fresh health state; nothing is sent until `perform_handshake`
    /// is called.
    ///
    /// * `manager` – weak handle used for every outgoing packet; the provider stops once it can
    ///   no longer be upgraded.
    /// * `state_crypto` – crypto tool shared with the session; replaced as keys are derived.
    /// * `counter` – shared counter supplying the low 32 bits of each packet number.
    /// * `response_tx` / `response_rx` – channel carrying responses accepted by `feed_input` and
    ///   `feed_handshake_input`; `response_rx` is consumed by the handshake, the decay loop
    ///   subscribes its own receiver from `response_tx`.
    /// * `shadowride_tx` – signalled by `feed_output` when a pending health check has been
    ///   piggybacked onto an outgoing data packet.
    /// * `initial_data_generator` – supplies the payload embedded in handshake packets.
    pub fn new(manager: Weak<SM>, settings: Arc<Settings<AE>>, state_crypto: SharedValue<ClientCryptoTool<T>>, counter: Arc<AtomicU32>, response_tx: WatchSender<HealthResponse<T>>, shadowride_tx: WatchSender<()>, response_rx: WatchReceiver<HealthResponse<T>>, initial_data_generator: CC) -> Self {
        let state = Arc::new(Mutex::new(HealthState::new(settings.clone(), state_crypto, counter, initial_data_generator, response_rx)));
        Self {
            manager,
            state,
            settings,
            response_tx,
            shadowride_tx,
        }
    }

    /// Runs the initial handshake and, once it succeeds, spawns the health-check decay loop on
    /// the settings' executor.
    ///
    /// # Errors
    /// `SessionControllerError::InitialHandshakeFailed` if the handshake does not complete.
    ///
    /// # Panics
    /// If called more than once (the response receiver has already been taken).
    pub async fn perform_handshake(&self) -> Result<(), SessionControllerError> {
        let mut response_rx = self.state.lock().await.response_rx.take().expect("perform_handshake() must be called exactly once");
        let handshake_factor = self.settings.get(&HANDSHAKE_NEXT_IN_FACTOR);
        let Some(initial_server_next_in) = self.do_handshake(&mut response_rx, handshake_factor).await else {
            return Err(SessionControllerError::InitialHandshakeFailed(self.settings.get(&MAX_RETRIES)));
        };
        let timer_response_rx = self.response_tx.subscribe();
        let timer_shadowride_rx = self.shadowride_tx.subscribe();
        let manager = self.manager.clone();
        let state = self.state.clone();
        let executor = self.settings.executor().clone();
        executor.spawn(Self::timer_task(manager, state, timer_response_rx, timer_shadowride_rx, Some(initial_server_next_in)));
        debug!("health provider: decay cycle started");
        Ok(())
    }

    /// Returns the current client identity together with the recorded termination code
    /// (`ConnectionDecayed` once the decay loop ran out of retries, `Success` otherwise).
    pub(super) async fn termination_snapshot(&self) -> (T, ReturnCode) {
        let mut guard = self.state.lock().await;
        let identity = guard.crypto_tool.get().identity();
        let code = guard.termination_code;
        (identity, code)
    }

    /// Accepts an incoming health-check response.
    ///
    /// Responses whose packet number does not match the outstanding one are ignored.
    ///
    /// # Errors
    /// `SessionControllerError::HealthProviderDied` if the response cannot be delivered.
    pub async fn feed_input(&self, tailer: Tailer<T>) -> Result<(), SessionControllerError> {
        self.deliver_response("health check", tailer.packet_number(), tailer.time(), None, false, &tailer).await
    }

    /// Accepts an incoming handshake response (tailer plus body).
    ///
    /// Responses whose packet number does not match the outstanding one are ignored.
    ///
    /// # Errors
    /// `SessionControllerError::HealthProviderDied` if the response cannot be delivered.
    pub async fn feed_handshake_input(&self, tailer: Tailer<T>, body: DynamicByteBuffer) -> Result<(), SessionControllerError> {
        self.deliver_response("handshake response", tailer.packet_number(), tailer.time(), Some(body), true, &tailer).await
    }

    /// Shared tail of `feed_input` / `feed_handshake_input`: checks the packet number against
    /// the outstanding one and forwards `(clamped next_in, receive time, body, identity)`.
    ///
    /// `what` only labels the debug message for discarded packets. The server identity is read
    /// from `tailer` (after the PN check, as before) only when `with_identity` is set.
    async fn deliver_response(&self, what: &str, pn: u64, time: u32, body: Option<DynamicByteBuffer>, with_identity: bool, tailer: &Tailer<T>) -> Result<(), SessionControllerError> {
        let (server_next_in, receive_time, server_identity) = {
            let state = self.state.lock().await;
            if pn != state.current_pn {
                debug!("health provider: discarding {} with unexpected PN (got {:#018x}, expected {:#018x})", what, pn, state.current_pn);
                return Ok(());
            }
            let receive_time = unix_timestamp_ms();
            // Never let the peer's advertised delay escape our own configured bounds.
            let server_next_in = time.clamp(state.settings.get(&HEALTH_CHECK_NEXT_IN_MIN) as u32, state.settings.get(&HEALTH_CHECK_NEXT_IN_MAX) as u32);
            let server_identity = if with_identity { Some(tailer.identity()) } else { None };
            (server_next_in, receive_time, server_identity)
        };
        if self.response_tx.send((server_next_in, receive_time, body, server_identity)) {
            Ok(())
        } else {
            Err(SessionControllerError::HealthProviderDied)
        }
    }

    /// Offers an outgoing data packet for shadowriding.
    ///
    /// If a health check is pending, it is stamped onto `tailer` in place (`HEALTH_CHECK` flag,
    /// advertised `next_in`, pending PN), the send time is recorded and the decay loop is
    /// notified. Packets that already carry `HEALTH_CHECK` are left untouched.
    ///
    /// # Errors
    /// `SessionControllerError::HealthProviderDied` if the notification cannot be delivered.
    pub async fn feed_output(&self, tailer: Tailer<T>) -> Result<(), SessionControllerError> {
        if tailer.flags().contains(PacketFlags::HEALTH_CHECK) {
            return Ok(());
        }
        let shadowridden = {
            let mut state = self.state.lock().await;
            if let Some((pn, next_in)) = state.shadowride_pending.take() {
                tailer.set_flags(tailer.flags() | PacketFlags::HEALTH_CHECK);
                tailer.set_time(next_in);
                tailer.set_packet_number_raw(pn);
                state.last_sent_time = unix_timestamp_ms();
                debug!("health provider: health check shadowridden onto data packet (PN={pn:#018x})");
                true
            } else {
                false
            }
        };
        if shadowridden && !self.shadowride_tx.send(()) {
            return Err(SessionControllerError::HealthProviderDied);
        }
        Ok(())
    }

    /// Sends `packet` through the session manager; see `try_send_static`.
    async fn try_send(&self, packet: DynamicByteBuffer) -> SendOutcome {
        Self::try_send_static(&self.manager, packet, &self.state).await
    }

    /// Sends `packet` through the session manager.
    ///
    /// Returns `Stop` if the manager is gone or the retry budget is exhausted, `Retry` after a
    /// failed send that may be retried, and `Sent` otherwise.
    async fn try_send_static(manager: &Weak<SM>, packet: DynamicByteBuffer, state: &Arc<Mutex<HealthState<T, AE, CC>>>) -> SendOutcome {
        let Some(mgr) = manager.upgrade() else {
            warn!("health provider: session manager dropped unexpectedly");
            return SendOutcome::Stop;
        };
        if let Err(err) = mgr.send_packet(packet, true).await {
            warn!("health provider: failed to send packet: {err}");
            let mut st = state.lock().await;
            if st.increment_retry() {
                debug!("health provider: retry {}/{}", st.retry_count, st.settings.get(&MAX_RETRIES));
                return SendOutcome::Retry;
            }
            warn!("health provider: max retries ({}) reached, stopping", st.retry_count);
            return SendOutcome::Stop;
        }
        SendOutcome::Sent
    }

    /// Handshake loop: sends a handshake packet, waits `next_in * handshake_factor` plus the
    /// current timeout for a response, and retries within the retry budget.
    ///
    /// Returns the server's clamped `next_in` on success; `None` on retry exhaustion, channel
    /// closure, or a handshake response that cannot be processed.
    async fn do_handshake(&self, response_rx: &mut WatchReceiver<HealthResponse<T>>, handshake_factor: f64) -> Option<u32> {
        loop {
            let (packet, next_in) = {
                let mut st = self.state.lock().await;
                let pn = st.next_packet_number();
                let next_in = st.compute_next_in();
                st.current_pn = pn;
                st.last_sent_time = unix_timestamp_ms();
                let packet = st.create_handshake_packet(pn, next_in);
                (packet, next_in)
            };
            debug!("health provider: sending handshake packet");
            match self.try_send(packet).await {
                SendOutcome::Sent => {}
                SendOutcome::Retry => continue,
                SendOutcome::Stop => return None,
            }
            let timeout_ms = {
                let st = self.state.lock().await;
                let handshake_delay = (next_in as f64 * handshake_factor) as u64;
                handshake_delay + st.compute_timeout()
            };
            match wait_for_response(timeout_ms, response_rx).await {
                DecaySleepEvent::Timeout => {
                    let mut st = self.state.lock().await;
                    if st.increment_retry() {
                        debug!("health provider: handshake timeout, retry {}/{}", st.retry_count, st.settings.get(&MAX_RETRIES));
                        continue;
                    }
                    warn!("health provider: handshake failed after {} retries", st.retry_count);
                    return None;
                }
                DecaySleepEvent::Terminated => {
                    warn!("health provider: response channel closed during handshake");
                    return None;
                }
                DecaySleepEvent::ResponseReceived {
                    server_next_in,
                    handshake_body,
                    server_identity,
                    ..
                } => {
                    let mut st = self.state.lock().await;
                    st.retry_count = 0;
                    if let Some(body) = handshake_body {
                        match st.process_handshake_response(body) {
                            Some((session_key, _server_initial_data)) => {
                                // Switch to the session key, adopting the server-assigned
                                // identity when the response carried one.
                                let updated_tool = {
                                    let tool = st.crypto_tool.get();
                                    if let Some(new_identity) = server_identity {
                                        tool.with_key_and_identity(&session_key, new_identity)
                                    } else {
                                        tool.with_key(&session_key)
                                    }
                                };
                                st.crypto_tool.set(updated_tool);
                            }
                            None => return None,
                        }
                    }
                    debug!("health provider: handshake completed");
                    return Some(server_next_in);
                }
            }
        }
    }

    /// Delivers the health check `(pn, next_in)`.
    ///
    /// Without a server `next_in` the check is sent explicitly right away. Otherwise the
    /// provider sleeps until about one RTT before the server's advertised `next_in`, then opens
    /// a window of `2 * RTT` during which the check may ride on an outgoing data packet (see
    /// `feed_output`). If the window closes unclaimed, the check is sent explicitly.
    async fn send_or_shadowride(manager: &Weak<SM>, state: &Arc<Mutex<HealthState<T, AE, CC>>>, shadowride_rx: &mut WatchReceiver<()>, server_next_in: Option<u32>, pn: u64, next_in: u32) -> SendOutcome {
        let Some(srv_ni) = server_next_in else {
            let packet = {
                let mut st = state.lock().await;
                st.last_sent_time = unix_timestamp_ms();
                st.create_health_check_packet(pn, next_in)
            };
            return Self::try_send_static(manager, packet, state).await;
        };
        let rtt = state.lock().await.smooth_rtt_or_default();
        let pre_wait = ((srv_ni as f64) - rtt).max(0.0) as u64;
        sleep(Duration::from_millis(pre_wait)).await;
        {
            let mut st = state.lock().await;
            st.shadowride_pending = Some((pn, next_in));
        }
        let shadowride_window = (rtt * 2.0).max(1.0) as u64;
        let deadline = unix_timestamp_ms() + u128::from(shadowride_window);
        loop {
            // Clamp so a backwards wall-clock jump cannot stretch the window.
            let remaining = deadline.saturating_sub(unix_timestamp_ms()).min(u128::from(shadowride_window)) as u64;
            match Self::wait_for_shadowride(remaining, shadowride_rx).await {
                DecayShadowrideEvent::Shadowridden => {
                    // A notification only counts if `feed_output` really claimed this PN. One
                    // left over from a previous cycle (feed_output racing a window timeout)
                    // must not be mistaken for this one.
                    if state.lock().await.shadowride_pending.is_none() {
                        return SendOutcome::Sent;
                    }
                    debug!("health provider: ignoring stale shadowride notification");
                }
                DecayShadowrideEvent::Timeout => {
                    let mut st = state.lock().await;
                    if st.shadowride_pending.take().is_none() {
                        // `feed_output` claimed the PN just as the window closed: the check has
                        // already left on a data packet, so do not send a duplicate.
                        return SendOutcome::Sent;
                    }
                    st.last_sent_time = unix_timestamp_ms();
                    let packet = st.create_health_check_packet(pn, next_in);
                    drop(st);
                    return Self::try_send_static(manager, packet, state).await;
                }
                DecayShadowrideEvent::Terminated => {
                    warn!("health provider: shadowride channel closed unexpectedly");
                    return SendOutcome::Stop;
                }
            }
        }
    }

    /// Races a `window_ms` sleep against the next notification on `shadowride_rx`.
    async fn wait_for_shadowride(window_ms: u64, shadowride_rx: &mut WatchReceiver<()>) -> DecayShadowrideEvent {
        let mut pool = FuturePool::new();
        pool.add(async {
            sleep(Duration::from_millis(window_ms)).await;
            DecayShadowrideEvent::Timeout
        });
        pool.add(async {
            match shadowride_rx.recv().await {
                Some(()) => DecayShadowrideEvent::Shadowridden,
                None => DecayShadowrideEvent::Terminated,
            }
        });
        pool.next().await.unwrap_or(DecayShadowrideEvent::Terminated)
    }

    /// Health-check decay loop, spawned after a successful handshake.
    ///
    /// Each iteration allocates a new PN, delivers the check (explicitly or shadowridden) and
    /// waits `next_in + timeout` for the response. Responses update the RTT estimate and reset
    /// the retry budget; timeouts consume it, and exhausting it records
    /// `ReturnCode::ConnectionDecayed` and ends the loop.
    async fn timer_task(manager: Weak<SM>, state: Arc<Mutex<HealthState<T, AE, CC>>>, mut response_rx: WatchReceiver<HealthResponse<T>>, mut shadowride_rx: WatchReceiver<()>, initial_server_next_in: Option<u32>) {
        let mut server_next_in = initial_server_next_in;
        loop {
            let (pn, my_next_in) = {
                let mut st = state.lock().await;
                let pn = st.next_packet_number();
                let my_next_in = st.compute_next_in();
                st.current_pn = pn;
                st.last_sent_next_in = my_next_in;
                (pn, my_next_in)
            };
            match Self::send_or_shadowride(&manager, &state, &mut shadowride_rx, server_next_in, pn, my_next_in).await {
                SendOutcome::Sent => {}
                SendOutcome::Retry => {
                    // After a failed send, skip shadowriding and send the next check directly.
                    server_next_in = None;
                    continue;
                }
                SendOutcome::Stop => break,
            }
            let timeout_ms = {
                let st = state.lock().await;
                (my_next_in as u64) + st.compute_timeout()
            };
            match wait_for_response(timeout_ms, &mut response_rx).await {
                DecaySleepEvent::Timeout => {
                    let mut st = state.lock().await;
                    if st.increment_retry() {
                        debug!("health provider: health check timeout, retry {}/{}", st.retry_count, st.settings.get(&MAX_RETRIES));
                        server_next_in = None;
                        continue;
                    }
                    warn!("health provider: connection decayed after {} retries", st.retry_count);
                    st.termination_code = ReturnCode::ConnectionDecayed;
                    break;
                }
                DecaySleepEvent::Terminated => {
                    warn!("health provider: response channel closed unexpectedly");
                    break;
                }
                DecaySleepEvent::ResponseReceived {
                    server_next_in: srv_ni,
                    receive_time,
                    ..
                } => {
                    let mut st = state.lock().await;
                    st.update_rtt(receive_time);
                    st.retry_count = 0;
                    server_next_in = Some(srv_ni);
                    debug!("health provider: health check response received");
                }
            }
        }
    }
}