// Tests for this module are loaded from `tests/session/server.rs` via `#[path]` and are
// compiled only for test builds with the `tokio`, `server` and `client` features enabled.
#[cfg(all(test, feature = "tokio", feature = "server", feature = "client"))]
#[path = "../../tests/session/server.rs"]
mod tests;
use std::hash::Hash;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Weak as StdWeak};
use async_trait::async_trait;
use log::{debug, warn};
use rand::Rng;
use crate::bytes::{ByteBuffer, ByteBufferMut, DynamicByteBuffer};
use crate::cache::{CachedMapEntryTemplate, SharedMap};
use crate::crypto::{UserCryptoState, UserServerState};
use crate::session::error::SessionControllerError;
use crate::session::server_health::ServerHealthProvider;
use crate::settings::{Settings, keys};
use crate::tailer::{IdentityType, PacketFlags, ReturnCode, Tailer};
use crate::utils::bitset::AtomicBitSet;
use crate::utils::random::get_rng;
use crate::utils::sync::{AsyncExecutor, NotifyQueueSender};
use crate::utils::unix_timestamp_ms;
/// Transport-side hook through which a server session hands outgoing packets back for
/// delivery and (presumably) reports that a session should be torn down.
///
/// `T` is the identity type used to address a client session.
#[async_trait]
pub trait OutgoingRouter<T: Send + Sync>: Send + Sync {
    /// Routes an already-assembled `packet` toward the client identified by `identity`.
    ///
    /// NOTE(review): the meaning of the returned `bool` is not visible here; presumably
    /// `true` means the packet was accepted for delivery — confirm against implementors.
    async fn route_packet(&self, packet: DynamicByteBuffer, identity: &T) -> bool;
    /// Removes the session registered for `identity` (semantics defined by implementors).
    async fn remove_session(&self, identity: &T);
}
/// A packet received for a server session, split into its body bytes and its parsed tailer.
pub struct IncomingPacket<T: IdentityType> {
    /// Packet bytes without the tailer (presumably — confirm with the parser). When the
    /// tailer flags indicate a payload, the encrypted payload is expected in the last
    /// `tailer.payload_length()` bytes of this buffer.
    pub body: DynamicByteBuffer,
    /// Parsed tailer: flags, return code, timestamp, packet number and payload length.
    pub tailer: Tailer<T>,
}
/// Server-side state for a single client session: handles onto the client's crypto state,
/// flow tracking, packet numbering, delivery of decrypted payloads, and health checking.
pub struct ServerSessionManager<T: IdentityType + Clone + Eq + Hash + Send + ToString + 'static, AE: AsyncExecutor + 'static> {
    /// Cache handle onto this client's `UserServerState`, used to encrypt outgoing payloads.
    /// NOTE(review): kept separate from `crypto_recv`, presumably so the send and receive
    /// paths do not share one handle — TODO confirm.
    crypto_send: CachedMapEntryTemplate<T, UserServerState>,
    /// Cache handle onto this client's `UserServerState`, used to decrypt incoming payloads.
    crypto_recv: CachedMapEntryTemplate<T, UserServerState>,
    /// Identity of the client; written into outgoing tailers and used in log messages.
    identity: T,
    /// Bitset of flow indices marked active via `note_active_flow`, sampled by
    /// `select_active_flow`.
    active_flows: AtomicBitSet,
    /// Packet counter, shared through `counter()`. Its incremented value forms the low
    /// 32 bits of outgoing packet numbers.
    counter: Arc<AtomicU32>,
    /// Queue that receives successfully decrypted incoming payloads.
    incoming_tx: NotifyQueueSender<DynamicByteBuffer>,
    /// Health-check tracker, fed with the time and packet number of received health checks.
    health_provider: ServerHealthProvider,
    /// Carries the executor type parameter `AE`, which is not otherwise stored in a field.
    _phantom: PhantomData<AE>,
}
impl<T: IdentityType + Clone + Eq + Hash + Send + ToString, AE: AsyncExecutor> ServerSessionManager<T, AE> {
pub(crate) async fn assemble_session(crypto_state: UserCryptoState, response_body: DynamicByteBuffer, handshake_tailer: Tailer<T>, identity: T, users: &mut SharedMap<T, UserServerState>, incoming_tx: NotifyQueueSender<DynamicByteBuffer>, router: StdWeak<dyn OutgoingRouter<T>>, num_flows: usize, settings: Arc<Settings<AE>>) -> Result<(Arc<Self>, DynamicByteBuffer), SessionControllerError> {
let user_state = UserServerState::new(crypto_state);
users.insert(identity.clone(), user_state).await;
let crypto_send = users.create_cache_for(identity.clone());
let crypto_recv = users.create_cache_for(identity.clone());
let server_next_in = get_rng().gen_range(settings.get(&keys::HEALTH_CHECK_NEXT_IN_MIN)..=settings.get(&keys::HEALTH_CHECK_NEXT_IN_MAX)) as u32;
let response_body_len = response_body.len();
let tailer_buf = response_body.expand_end(T::length()).rebuffer_start(response_body_len);
let _response_tailer = Tailer::handshake(tailer_buf, &identity, ReturnCode::Success.into(), server_next_in, handshake_tailer.packet_number(), response_body_len as u16);
let response_packet = response_body.expand_end(Tailer::<T>::len());
let health_provider = ServerHealthProvider::new(router, identity.clone(), settings, server_next_in);
let session = Arc::new(Self {
crypto_send,
crypto_recv,
identity,
active_flows: AtomicBitSet::new(num_flows),
counter: Arc::new(AtomicU32::new(0)),
incoming_tx,
health_provider,
_phantom: PhantomData,
});
Ok((session, response_packet))
}
pub fn counter(&self) -> Arc<AtomicU32> {
Arc::clone(&self.counter)
}
pub fn note_active_flow(&self, flow_index: usize) {
self.active_flows.set(flow_index);
}
pub fn select_active_flow(&self, num_flows: usize) -> usize {
self.active_flows.random_set_index(num_flows)
}
pub async fn prepare_outgoing(&self, packet: DynamicByteBuffer, generated: bool) -> Result<DynamicByteBuffer, SessionControllerError> {
if generated {
return Ok(packet);
}
let mut entry = self.crypto_send.create_entry();
let user_state = entry.get_mut().await.map_err(SessionControllerError::MissingCache)?;
let encrypted_payload = user_state.crypto_mut().encrypt_payload(packet, None).map_err(SessionControllerError::CryptoError)?;
let payload_length = encrypted_payload.len() as u16;
drop(entry);
let packet_number = self.next_packet_number();
let encrypted_payload_len = encrypted_payload.len();
let tailer_buf = encrypted_payload.expand_end(T::length()).rebuffer_start(encrypted_payload_len);
let _tailer = Tailer::data(tailer_buf, &self.identity, payload_length, packet_number);
let assembled = encrypted_payload.expand_end(Tailer::<T>::len());
debug!("server session [{}]: sending data packet", self.identity.to_string());
Ok(assembled)
}
fn next_packet_number(&self) -> u64 {
let counter = self.counter.fetch_add(1, Ordering::Relaxed).wrapping_add(1);
let timestamp = (unix_timestamp_ms() / 1000) as u32;
((timestamp as u64) << 32) | (counter as u64)
}
pub async fn process_incoming(&self, incoming: IncomingPacket<T>) -> Result<(), SessionControllerError> {
let IncomingPacket {
body,
tailer,
} = incoming;
debug!("server session [{}]: received {:?} packet", self.identity.to_string(), tailer.flags());
if tailer.flags().is_termination() {
debug!("server session [{}]: connection terminated by client (code={})", self.identity.to_string(), tailer.code());
return Err(SessionControllerError::ConnectionTerminated(tailer.code()));
}
if tailer.flags().contains(PacketFlags::HEALTH_CHECK) && !tailer.flags().has_payload() {
self.health_provider.feed_health_check(tailer.time(), tailer.packet_number());
}
if tailer.flags().has_payload() {
let payload_len = tailer.payload_length() as usize;
let encrypted_payload = body.rebuffer_start(body.len() - payload_len);
let decrypt_result = {
let mut entry = self.crypto_recv.create_entry();
let user_state = entry.get_mut().await.map_err(SessionControllerError::MissingCache)?;
user_state.crypto_mut().decrypt_payload(encrypted_payload, None)
};
match decrypt_result {
Ok(decrypted) => {
if tailer.flags().contains(PacketFlags::HEALTH_CHECK) {
self.health_provider.feed_health_check(tailer.time(), tailer.packet_number());
}
self.incoming_tx.push(decrypted);
}
Err(err) => {
warn!("server session [{}]: payload decryption failed: {}", self.identity.to_string(), err);
}
}
}
Ok(())
}
}