use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::AtomicU32;
use log::{debug, info};
use rand::Rng;
use rand::seq::SliceRandom;
use crate::bytes::{ByteBuffer, ByteBufferMut, DynamicByteBuffer};
use crate::cache::SharedValue;
use crate::certificate::{CertificateError, ClientCertificate};
use crate::crypto::{ClientCryptoTool, KEY_LENGTH, PAYLOAD_CRYPTO_OVERHEAD};
use crate::flow::client::ClientFlowManager;
use crate::flow::decoy::{DecoyFactory, random_decoy_factory};
use crate::flow::probe::ProbeFactory;
use crate::flow::{FlowConfig, FlowControllerError};
use crate::session::{ClientSessionManager, SessionManager};
use crate::settings::{Settings, keys};
use crate::socket::error::ClientSocketError;
use crate::tailer::{ClientConnectionHandler, IdentityType, Tailer};
use crate::utils::random::{SupportRng, get_rng, jittered_chunk_size};
use crate::utils::socket::Socket;
use crate::utils::sync::{AsyncExecutor, Mutex, NotifyQueueReceiver, assert_runtime, create_notify_queue};
/// Configures and constructs a [`ClientSocket`].
///
/// Create one with `ClientSocketBuilder::new`, adjust it through the `with_*` methods, then call
/// `build` to open the flows and start the session.
pub struct ClientSocketBuilder<T, AE, CC>
where
    T: IdentityType + Clone,
    AE: AsyncExecutor + 'static,
    CC: ClientConnectionHandler,
{
    /// Settings to use; `None` makes `build` fall back to `Settings::default()`.
    settings: Option<Arc<Settings<AE>>>,
    /// Explicit per-address flow configurations. When non-empty, exactly these addresses are
    /// used and each of them must be listed in the certificate.
    flow_overrides: HashMap<SocketAddr, FlowConfig>,
    /// Certificate supplying the candidate server addresses; also handed to the crypto tool.
    certificate: ClientCertificate,
    /// Connection handler passed to the session manager and queried for the identity bytes.
    initial_data_generator: CC,
    /// Factory handed to every flow manager for decoy traffic.
    decoy_factory: DecoyFactory<T, AE>,
    /// Optional factory handed to every flow manager for active probing.
    probe_factory: Option<ProbeFactory<AE>>,
}
/// Builder-style configuration and async construction of a [`ClientSocket`].
impl<T: IdentityType + Clone + 'static, AE: AsyncExecutor + 'static, CC: ClientConnectionHandler + 'static> ClientSocketBuilder<T, AE, CC> {
/// Creates a builder for `certificate` with no explicit settings, no flow overrides, the
/// random decoy factory (`random_decoy_factory()`), and no probe factory.
///
/// `initial_data_generator` is queried (via `version`) for the client identity bytes in
/// `build` and is then moved into the session manager.
pub fn new(certificate: ClientCertificate, initial_data_generator: CC) -> Self {
Self {
settings: None,
flow_overrides: HashMap::new(),
certificate,
initial_data_generator,
decoy_factory: random_decoy_factory(),
probe_factory: None,
}
}
/// Uses `settings` for the socket instead of the `Settings::default()` fallback.
pub fn with_settings(mut self, settings: Arc<Settings<AE>>) -> Self {
self.settings = Some(settings);
self
}
/// Replaces the decoy factory handed to every flow manager.
pub fn with_decoy_factory(mut self, factory: DecoyFactory<T, AE>) -> Self {
self.decoy_factory = factory;
self
}
/// Uses a decoy factory built for the decoy communication mode `DP`.
pub fn with_decoy<DP: crate::flow::decoy::DecoyCommunicationMode<T, AE> + 'static>(mut self) -> Self {
self.decoy_factory = crate::flow::decoy::decoy_factory::<T, AE, DP>();
self
}
/// Sets the probe factory handed to every flow manager (none is used by default).
pub fn with_probe_factory(mut self, factory: ProbeFactory<AE>) -> Self {
self.probe_factory = Some(factory);
self
}
/// Sets a probe factory built for the active-probe handler `PM`.
pub fn with_probe<PM: crate::flow::probe::ActiveProbeHandler<AE> + Default + 'static>(mut self) -> Self {
self.probe_factory = Some(crate::flow::probe::probe_factory::<AE, PM>());
self
}
/// Pins the flow configuration for `addr`, replacing any earlier override for the same address.
///
/// As soon as at least one override exists, `build` opens flows for exactly the overridden
/// addresses (instead of a random subset) and fails if any of them is not in the certificate.
pub fn with_flow_config(mut self, addr: SocketAddr, config: FlowConfig) -> Self {
self.flow_overrides.insert(addr, config);
self
}
/// Validates the configuration, opens one flow per selected address, starts the session and
/// returns the ready socket.
///
/// Address selection: with no flow overrides, a random non-empty subset of the certificate's
/// addresses is chosen, each paired with `FlowConfig::random`; otherwise exactly the
/// overridden addresses are used with their given configs.
///
/// The socket's `max_data_payload` is the smallest `max_user_payload` over all opened flows.
///
/// A background task is spawned on the settings' executor. It forwards every packet returned by
/// the session's `receive_packet` into the socket's incoming queue and stops, with a debug log,
/// at the first receive error.
///
/// # Errors
/// - `UnsupportedRuntime` if `assert_runtime` fails.
/// - `CertificateError(NoAddresses)` if the certificate lists no addresses.
/// - `AddressNotInCertificate` if a flow override targets an address missing from the certificate.
/// - `FlowError` if a flow config fails its MTU assertion, a flow manager cannot be created, or
///   the configuration leaves zero bytes for user data.
/// - `SocketError` if a `Socket` cannot be created for a selected address.
/// - `SessionError` if the session manager cannot be created or fails to start.
pub async fn build(mut self) -> Result<ClientSocket<T, AE, CC>, ClientSocketError> {
assert_runtime().map_err(ClientSocketError::UnsupportedRuntime)?;
let cert_addrs = self.certificate.addresses();
if cert_addrs.is_empty() {
return Err(ClientSocketError::CertificateError(CertificateError::NoAddresses));
}
// Fall back to default settings when the caller supplied none.
let settings = self.settings.take().unwrap_or_else(|| Arc::new(Settings::default()));
// Every explicit override must refer to an address the certificate actually lists.
for addr in self.flow_overrides.keys() {
if !cert_addrs.contains(addr) {
return Err(ClientSocketError::AddressNotInCertificate(*addr));
}
}
// Choose the (address, config) pairs to open flows for. `cert_addrs` is non-empty here, so the
// random branch always picks between 1 and all of the addresses; the override branch is
// non-empty by its own condition. Either way `addr_configs` has at least one entry.
let addr_configs: Vec<(SocketAddr, FlowConfig)> = if self.flow_overrides.is_empty() {
let mut rng = get_rng();
let n = rng.gen_range(1..=cert_addrs.len());
cert_addrs.choose_multiple(&mut rng, n).map(|&addr| (addr, FlowConfig::random(&settings))).collect()
} else {
self.flow_overrides.drain().collect()
};
// Client identity derived from the handler's `version` output for `T::length()`.
// NOTE(review): confirm the exact contract of `version` in the tailer module.
let identity_bytes = T::from_bytes(self.initial_data_generator.version(T::length()).slice());
// Fresh random static key of KEY_LENGTH bytes for this socket's crypto tool.
let static_key = get_rng().random_byte_buffer::<KEY_LENGTH>();
let cipher = SharedValue::new(ClientCryptoTool::new(self.certificate.clone(), identity_bytes, &static_key));
let tailer_wire_len = Tailer::<T>::encrypted_len_c2s();
// Running minimum of the usable user payload across flows; usize::MAX means "no flow seen yet".
let mut max_data_payload = usize::MAX;
// One counter shared by every flow manager and the session manager.
let counter = Arc::new(AtomicU32::new(0));
let mut flows = Vec::with_capacity(addr_configs.len());
for (addr, config) in addr_configs {
config.assert(settings.mtu()).map_err(ClientSocketError::FlowError)?;
max_data_payload = max_data_payload.min(config.max_user_payload(settings.mtu(), PAYLOAD_CRYPTO_OVERHEAD, tailer_wire_len));
let sock = Socket::new(addr, None).await.map_err(ClientSocketError::SocketError)?;
// Each flow gets its own cache created from the shared cipher.
let cipher_cache = cipher.create_cache();
let flow = ClientFlowManager::new(config, cipher_cache, settings.clone(), sock, self.probe_factory.as_ref(), &self.decoy_factory, Arc::clone(&counter), addr).await.map_err(ClientSocketError::FlowError)?;
flows.push(flow);
}
// Defensive fallback: `addr_configs` is never empty (see above), so the sentinel should always
// have been replaced by a real minimum by this point.
let max_data_payload = if max_data_payload == usize::MAX {
settings.mtu()
} else {
max_data_payload
};
// A zero payload budget would leave no room for user data at all.
if max_data_payload == 0 {
return Err(ClientSocketError::FlowError(FlowControllerError::AssertionFailed {
message: "flow configuration leaves no room for user data (max_data_payload = 0); reduce fake-body constant length or increase MTU".to_string(),
}));
}
info!("client socket built: max_data_payload={}B (mtu={}B, {} flow(s))", max_data_payload, settings.mtu(), flows.len());
let session = ClientSessionManager::new(cipher, flows, settings.clone(), counter, self.initial_data_generator).map_err(ClientSocketError::SessionError)?;
let (incoming_tx, incoming_rx) = create_notify_queue::<DynamicByteBuffer>();
// Background receiver: pushes packets into the queue drained by `ClientSocket::receive`.
// NOTE(review): this task is spawned before `session.start()`. If `start` fails below, the
// detached task still holds a session clone — confirm it terminates in that case.
let receive_session = session.clone();
settings.executor().spawn(async move {
loop {
match receive_session.receive_packet().await {
Ok(buffer) => {
incoming_tx.push(buffer);
}
Err(err) => {
debug!("client bg-recv: terminated: {err}");
break;
}
}
}
});
session.start().await.map_err(ClientSocketError::SessionError)?;
Ok(ClientSocket {
session,
incoming_rx: Mutex::new(incoming_rx),
max_data_payload,
settings,
})
}
}
/// A started client socket, produced by `ClientSocketBuilder::build`.
pub struct ClientSocket<T, AE, CC>
where
    T: IdentityType + Clone + 'static,
    AE: AsyncExecutor + 'static,
    CC: ClientConnectionHandler + 'static,
{
    /// Session manager driving every flow of this socket.
    session: Arc<ClientSessionManager<T, AE, Arc<ClientFlowManager<T, AE>>, CC>>,
    /// Receiving end of the queue fed by the background receive task; the mutex serializes
    /// concurrent `receive` calls.
    incoming_rx: Mutex<NotifyQueueReceiver<DynamicByteBuffer>>,
    /// Largest user payload, in bytes, accepted by every flow; fixed at build time.
    max_data_payload: usize,
    /// Settings shared with the session; `send_bytes` reads its chunking parameters from here.
    settings: Arc<Settings<AE>>,
}
impl<T: IdentityType + Clone + 'static, AE: AsyncExecutor + 'static, CC: ClientConnectionHandler + 'static> ClientSocket<T, AE, CC> {
    /// Sends one already-assembled packet through the client session.
    ///
    /// No chunking is performed; use [`Self::send_bytes`] to split arbitrary data.
    ///
    /// # Errors
    /// Returns [`ClientSocketError::SessionError`] if the session fails to send the packet.
    pub async fn send(&self, packet: DynamicByteBuffer) -> Result<(), ClientSocketError> {
        // NOTE(review): the meaning of the `false` flag is defined by `send_packet` in the
        // session module — confirm before changing it.
        self.session.send_packet(packet, false).await.map_err(ClientSocketError::SessionError)
    }

    /// Sends an arbitrary byte slice, splitting it into packets no larger than
    /// [`Self::max_data_payload`].
    ///
    /// If the remaining data fits in one packet, it is sent whole. Otherwise a chunk size is drawn
    /// from `jittered_chunk_size`, driven by the `SEND_BYTES_CHUNK` and `SEND_BYTES_JITTER`
    /// settings. The drawn size is clamped to `1..=max_data_payload`, so a degenerate value can
    /// neither stall the loop (zero-length chunks) nor slice past the end of `data` or exceed the
    /// payload limit. Empty input sends nothing.
    ///
    /// # Errors
    /// Propagates the first error from [`Self::send`]; chunks already sent are not retracted.
    pub async fn send_bytes(&self, data: &[u8]) -> Result<(), ClientSocketError> {
        let jitter = self.settings.get(&keys::SEND_BYTES_JITTER);
        let chunk = self.settings.get(&keys::SEND_BYTES_CHUNK) as usize;
        let mut offset = 0;
        while offset < data.len() {
            let remaining = data.len() - offset;
            let chunk_size = if remaining <= self.max_data_payload {
                remaining
            } else {
                // In this branch remaining > max_data_payload, so clamping to
                // [1, max_data_payload] keeps the slice in bounds and guarantees progress.
                // `min` then `max` (rather than `clamp`) cannot panic even if the limit were 0.
                jittered_chunk_size(self.max_data_payload, chunk, jitter)
                    .min(self.max_data_payload)
                    .max(1)
            };
            let buffer = self.settings.pool().allocate(Some(chunk_size));
            buffer.slice_mut().copy_from_slice(&data[offset..offset + chunk_size]);
            self.send(buffer).await?;
            offset += chunk_size;
        }
        Ok(())
    }

    /// Largest user payload, in bytes, that every flow of this socket can carry in one packet.
    pub fn max_data_payload(&self) -> usize {
        self.max_data_payload
    }

    /// Waits for the next packet delivered by the background receive task.
    ///
    /// Concurrent callers are serialized by the internal mutex.
    ///
    /// # Errors
    /// Returns [`ClientSocketError::ChannelClosed`] when the incoming queue yields no more packets.
    pub async fn receive(&self) -> Result<DynamicByteBuffer, ClientSocketError> {
        let buf = self.incoming_rx.lock().await.recv().await.ok_or(ClientSocketError::ChannelClosed)?;
        Ok(buf)
    }

    /// Like [`Self::receive`], but copies the packet contents into an owned `Vec<u8>`.
    ///
    /// # Errors
    /// Same as [`Self::receive`].
    pub async fn receive_bytes(&self) -> Result<Vec<u8>, ClientSocketError> {
        let buffer = self.receive().await?;
        Ok(buffer.slice().to_vec())
    }
}