use std::collections::VecDeque;
use std::net;
use std::sync::{Arc, RwLock};
use std::time;
use super::endpoint;
use super::frame;
use super::socket;
use super::ErrorKind;
use super::SendMode;
/// Maximum size, in bytes, of a single frame sent or received by the client.
///
/// NOTE(review): 1472 equals a 1500-byte Ethernet MTU minus a 20-byte IPv4 header and
/// an 8-byte UDP header; presumably chosen for that reason — confirm.
const FRAME_SIZE_MAX: usize = 1472;
/// Default value of `Config::handshake_timeout_ms`.
const HANDSHAKE_TIMEOUT_DEFAULT_MS: u64 = 10_000;
/// Smallest `handshake_timeout_ms` accepted by `Config::validate`.
const HANDSHAKE_TIMEOUT_MIN_MS: u64 = 2_000;
/// Interval between retransmissions of the current handshake frame.
const HANDSHAKE_RESEND_TIMEOUT_MS: u64 = 2_000;
/// Default value of `Config::connection_timeout_ms`.
const CONNECTION_TIMEOUT_DEFAULT_MS: u64 = 10_000;
/// Smallest `connection_timeout_ms` accepted by `Config::validate`.
const CONNECTION_TIMEOUT_MIN_MS: u64 = 2_000;
/// Tunable parameters for a [`Client`] connection.
///
/// Validated when passed to `Client::connect_with_config`; out-of-range values cause a
/// panic there.
#[derive(Clone)]
pub struct Config {
/// Time, in milliseconds, each handshake phase may take before the connection attempt
/// fails with `ErrorKind::Timeout`. Must be at least `HANDSHAKE_TIMEOUT_MIN_MS`.
pub handshake_timeout_ms: u64,
/// Connection timeout, in milliseconds, handed to the endpoint (as
/// `endpoint::Config::timeout_time_ms`) once the handshake completes. Must be at least
/// `CONNECTION_TIMEOUT_MIN_MS`.
pub connection_timeout_ms: u64,
/// Channel prioritization settings, converted into the endpoint's `prio_config`.
pub channel_balance: endpoint::ChannelBalanceConfig,
}
impl Default for Config {
fn default() -> Self {
Self {
handshake_timeout_ms: HANDSHAKE_TIMEOUT_DEFAULT_MS,
connection_timeout_ms: CONNECTION_TIMEOUT_DEFAULT_MS,
channel_balance: Default::default(),
}
}
}
impl Config {
fn validate(&self) {
assert!(
self.handshake_timeout_ms >= HANDSHAKE_TIMEOUT_MIN_MS,
"invalid client configuration: handshake_timeout_ms < {HANDSHAKE_TIMEOUT_MIN_MS}"
);
assert!(
self.connection_timeout_ms >= CONNECTION_TIMEOUT_MIN_MS,
"invalid client configuration: connection_timeout_ms < {CONNECTION_TIMEOUT_MIN_MS}"
);
self.channel_balance.validate();
}
}
/// Shared, lockable handle to a connection endpoint (held by `ActiveState`).
type EndpointRef = Arc<RwLock<endpoint::Endpoint>>;
/// Which handshake frame the client is currently (re)sending.
enum HandshakePhase {
/// Alpha frame sent; waiting for the server's alpha-ack.
Alpha,
/// Beta frame sent; waiting for the server's beta-ack.
Beta,
}
/// Progress of an in-flight handshake.
struct HandshakeState {
/// Which handshake frame is currently being (re)sent.
phase: HandshakePhase,
/// Absolute deadline (milliseconds since `ClientCore::time_ref`) after which the
/// handshake fails with `ErrorKind::Timeout`; restarted on entering the beta phase.
timeout_time_ms: u64,
/// Configured handshake timeout (`Config::handshake_timeout_ms`), used to compute the
/// deadline.
timeout_ms: u64,
/// Packets submitted via `send`/`enqueue` before the connection is active, with their
/// send modes; handed to the endpoint once the handshake succeeds.
packet_buffer: Vec<(Box<[u8]>, SendMode)>,
/// Serialized frame for the current phase, retransmitted on each RTO expiry.
frame: Box<[u8]>,
/// Random nonce sent in the alpha frame; acks that do not echo it are ignored.
local_nonce: u32,
}
/// State of an established connection.
struct ActiveState {
/// The endpoint handling the connection. Callers clone this `Arc` before locking it so
/// that the client can be lent to an `EndpointContext` while the endpoint is locked.
endpoint_ref: EndpointRef,
}
/// Connection lifecycle of a client.
enum State {
/// Exchanging handshake frames with the server; outgoing packets are buffered.
Handshake(HandshakeState),
/// Handshake complete; traffic is handled by the endpoint.
Active(ActiveState),
/// Reached after a handshake failure or endpoint termination; every operation is a
/// no-op. No code in this file transitions out of it.
Quiescent,
}
/// A single optional deadline.
struct Timer {
/// Expiry time in milliseconds since `ClientCore::time_ref`, or `None` when disarmed.
timeout_ms: Option<u64>,
}
/// Notification returned by `Client::poll_event`, `Client::wait_event`, and
/// `Client::wait_event_timeout`.
#[derive(Debug)]
pub enum Event {
/// Reported by the endpoint through `HostContext::on_connect`, presumably once the
/// connection is established.
Connect,
/// Reported by the endpoint through `HostContext::on_disconnect`.
Disconnect,
/// A packet delivered by the endpoint.
Receive(Box<[u8]>),
/// The handshake or connection failed: handshake timeout or rejection (`Timeout`,
/// `Capacity`, `Parameter`), or an endpoint timeout (`Timeout`).
Error(ErrorKind),
}
/// Adapter that lets the endpoint call back into a [`ClientCore`] (it implements
/// `endpoint::HostContext`); one is created per endpoint operation.
struct EndpointContext<'a> {
client: &'a mut ClientCore,
}
/// All client state except the socket receiver.
///
/// Kept separate from [`Client`] so that `Client` can pass the core and `socket_rx` to
/// `handle_frames`/`handle_frames_wait` as two independent mutable borrows.
struct ClientCore {
/// Configuration the client was created with.
config: Config,
/// Time origin: every `*_ms` timestamp in the client counts milliseconds from here.
time_ref: time::Instant,
/// Sending half of the connected socket.
socket_tx: socket::ConnectedSocketTx,
/// Current protocol state.
state: State,
/// Drives handshake-frame retransmission during the handshake; once active it is
/// armed and disarmed by the endpoint via `set_rto_timer`/`unset_rto_timer`.
rto_timer: Timer,
/// Events waiting to be returned to the application.
events: VecDeque<Event>,
}
/// A connection to a single server.
///
/// Created with `Client::connect`/`Client::connect_with_config`; progress and incoming
/// data are reported as [`Event`]s.
pub struct Client {
/// Protocol state and sending side.
core: ClientCore,
/// Receiving half of the connected socket, kept outside `core` so both can be borrowed
/// mutably at the same time.
socket_rx: socket::ConnectedSocketRx,
}
impl<'a> EndpointContext<'a> {
fn new(client: &'a mut ClientCore) -> Self {
Self { client }
}
}
impl<'a> endpoint::HostContext for EndpointContext<'a> {
fn send_frame(&mut self, frame_bytes: &[u8]) {
self.client.socket_tx.send(frame_bytes);
}
fn set_rto_timer(&mut self, time_ms: u64) {
self.client.rto_timer.timeout_ms = Some(time_ms);
}
fn unset_rto_timer(&mut self) {
self.client.rto_timer.timeout_ms = None;
}
fn on_connect(&mut self) {
self.client.events.push_back(Event::Connect);
}
fn on_disconnect(&mut self) {
self.client.events.push_back(Event::Disconnect);
}
fn on_receive(&mut self, packet_bytes: Box<[u8]>) {
self.client.events.push_back(Event::Receive(packet_bytes))
}
fn on_timeout(&mut self) {
self.client
.events
.push_back(Event::Error(ErrorKind::Timeout));
}
}
/// Returns `true` when each side's maximum incoming packet size is at least the other
/// side's maximum outgoing packet size.
fn connection_params_compatible(a: &frame::ConnectionParams, b: &frame::ConnectionParams) -> bool {
    let b_accepts_a = b.packet_size_in_max >= a.packet_size_out_max;
    let a_accepts_b = a.packet_size_in_max >= b.packet_size_out_max;
    a_accepts_b && b_accepts_a
}
/// Protocol state machine and I/O plumbing behind the public [`Client`] methods.
///
/// Methods that need the endpoint all follow the same steps:
/// 1. Clone the `Arc` out of `State::Active`.
/// 2. Lock the endpoint.
/// 3. Lend `self` to an [`EndpointContext`] so the endpoint can call back into the
///    client while it is locked.
impl ClientCore {
    /// Milliseconds elapsed since `time_ref`, the client's time origin.
    ///
    /// Every deadline the client stores (handshake deadline, RTO timer) uses this scale.
    fn time_now_ms(&self) -> u64 {
        // `as_millis` returns u128; exceeding u64 would take ~584 million years.
        self.time_ref.elapsed().as_millis() as u64
    }

    /// Returns how long until the RTO timer expires, or `None` if it is not armed.
    ///
    /// An already-expired timer yields a zero duration.
    pub fn next_timer_timeout(&self) -> Option<time::Duration> {
        let deadline_ms = self.rto_timer.timeout_ms?;
        let remaining_ms = deadline_ms.saturating_sub(self.time_now_ms());
        Some(time::Duration::from_millis(remaining_ms))
    }

    /// Reacts to an expired RTO timer at time `now_ms`.
    ///
    /// - Handshake: once the handshake deadline has passed, queues
    ///   `Event::Error(ErrorKind::Timeout)` and goes quiescent. Otherwise it
    ///   retransmits the current handshake frame and re-arms the timer, never past the
    ///   deadline.
    /// - Active: forwards the expiry to the endpoint, going quiescent if it asks to
    ///   terminate.
    /// - Quiescent: nothing to do.
    fn process_timeout(&mut self, now_ms: u64) {
        match &mut self.state {
            State::Handshake(state) => {
                if now_ms >= state.timeout_time_ms {
                    self.events.push_back(Event::Error(ErrorKind::Timeout));
                    self.state = State::Quiescent;
                } else {
                    // Clamp the next resend to the handshake deadline so the timeout
                    // fires at the configured time instead of being rounded up to the
                    // next multiple of HANDSHAKE_RESEND_TIMEOUT_MS.
                    let resend_time_ms = now_ms + HANDSHAKE_RESEND_TIMEOUT_MS;
                    self.rto_timer.timeout_ms = Some(resend_time_ms.min(state.timeout_time_ms));
                    // `state.frame` always holds the current phase's frame (alpha before
                    // the alpha-ack, beta afterwards), so both phases resend it the same way.
                    self.socket_tx.send(&state.frame);
                }
            }
            State::Active(state) => {
                let endpoint_ref = Arc::clone(&state.endpoint_ref);
                let mut endpoint = endpoint_ref.write().unwrap();
                let host_ctx = &mut EndpointContext::new(self);
                match endpoint.handle_rto_timer(now_ms, host_ctx) {
                    endpoint::TimeoutAction::Continue => (),
                    endpoint::TimeoutAction::Terminate => {
                        self.state = State::Quiescent;
                    }
                }
            }
            State::Quiescent => {}
        }
    }

    /// Fires the RTO timer if it has expired, disarming it before handling.
    pub fn process_timeouts(&mut self) {
        let now_ms = self.time_now_ms();
        let expired = matches!(
            self.rto_timer.timeout_ms,
            Some(deadline_ms) if now_ms >= deadline_ms
        );
        if expired {
            // Disarm first: `process_timeout` (or the endpoint) may re-arm it.
            self.rto_timer.timeout_ms = None;
            self.process_timeout(now_ms);
        }
    }

    /// Handles a handshake alpha-ack payload.
    ///
    /// The ack is accepted only when all of these hold:
    /// - the client is in the alpha phase;
    /// - the ack echoes our nonce;
    /// - the server's connection parameters are compatible with ours.
    ///
    /// On acceptance, switches to the beta phase, restarts the handshake deadline, and
    /// sends the beta frame (which `process_timeout` retransmits until acknowledged).
    /// Anything else is ignored.
    fn handle_handshake_alpha_ack(&mut self, payload_bytes: &[u8], now_ms: u64) {
        use frame::serial::SimpleFrameWrite;
        use frame::serial::SimplePayloadRead;
        let state = match &mut self.state {
            State::Handshake(state) if matches!(state.phase, HandshakePhase::Alpha) => state,
            _ => return,
        };
        let ack = match frame::HandshakeAlphaAckFrame::read(payload_bytes) {
            Some(ack) => ack,
            None => return,
        };
        // The client advertises the largest representable packet sizes in both directions.
        let client_params = frame::ConnectionParams {
            packet_size_in_max: u32::MAX,
            packet_size_out_max: u32::MAX,
        };
        let nonce_valid = ack.client_nonce == state.local_nonce;
        let params_compatible = connection_params_compatible(&client_params, &ack.server_params);
        // NOTE(review): incompatible parameters are silently ignored, so such an attempt
        // ends with ErrorKind::Timeout rather than ErrorKind::Parameter.
        if !(nonce_valid && params_compatible) {
            return;
        }
        state.phase = HandshakePhase::Beta;
        state.timeout_time_ms = now_ms + state.timeout_ms;
        self.rto_timer.timeout_ms = Some(now_ms + HANDSHAKE_RESEND_TIMEOUT_MS);
        state.frame = frame::HandshakeBetaFrame {
            client_params,
            client_nonce: ack.client_nonce,
            server_nonce: ack.server_nonce,
            server_timestamp: ack.server_timestamp,
            server_mac: ack.server_mac,
        }
        .write_boxed();
        self.socket_tx.send(&state.frame);
    }

    /// Handles a handshake beta-ack payload.
    ///
    /// Accepted only in the beta phase and only if it echoes our nonce.
    /// - Success: creates the endpoint, switches to `State::Active`, and hands over the
    ///   packets buffered during the handshake.
    /// - Rejection: queues the matching `Event::Error` and goes quiescent.
    ///
    /// Either outcome stops handshake retransmission. Anything else is ignored.
    fn handle_handshake_beta_ack(&mut self, payload_bytes: &[u8], now_ms: u64) {
        use frame::serial::SimplePayloadRead;
        let state = match &mut self.state {
            State::Handshake(state) if matches!(state.phase, HandshakePhase::Beta) => state,
            _ => return,
        };
        let ack = match frame::HandshakeBetaAckFrame::read(payload_bytes) {
            Some(ack) => ack,
            None => return,
        };
        if ack.client_nonce != state.local_nonce {
            return;
        }
        // The handshake is over either way; previously a rejection left this timer
        // armed, causing a spurious wake-up after the client had gone quiescent.
        self.rto_timer.timeout_ms = None;
        match ack.error {
            None => {
                let packet_buffer = std::mem::take(&mut state.packet_buffer);
                let endpoint_config = endpoint::Config {
                    timeout_time_ms: self.config.connection_timeout_ms,
                    prio_config: self.config.channel_balance.clone().into(),
                };
                let endpoint = endpoint::Endpoint::new(endpoint_config);
                let endpoint_ref = Arc::new(RwLock::new(endpoint));
                self.state = State::Active(ActiveState {
                    endpoint_ref: Arc::clone(&endpoint_ref),
                });
                let mut endpoint = endpoint_ref.write().unwrap();
                let host_ctx = &mut EndpointContext::new(self);
                endpoint.init(now_ms, host_ctx);
                for (packet, mode) in packet_buffer {
                    endpoint.enqueue(packet, mode, now_ms);
                }
                endpoint.flush(now_ms, host_ctx);
            }
            Some(kind) => {
                let kind = match kind {
                    frame::HandshakeErrorKind::Capacity => ErrorKind::Capacity,
                    frame::HandshakeErrorKind::Parameter => ErrorKind::Parameter,
                };
                self.events.push_back(Event::Error(kind));
                self.state = State::Quiescent;
            }
        }
    }

    /// Forwards a non-handshake frame to the endpoint; ignored unless the connection
    /// is active.
    fn handle_frame_other(
        &mut self,
        frame_type: frame::FrameType,
        payload_bytes: &[u8],
        now_ms: u64,
    ) {
        let endpoint_ref = match &self.state {
            State::Active(state) => Arc::clone(&state.endpoint_ref),
            State::Handshake(_) | State::Quiescent => return,
        };
        let mut endpoint = endpoint_ref.write().unwrap();
        endpoint.handle_frame(frame_type, payload_bytes, now_ms, &mut EndpointContext::new(self));
    }

    /// Validates one received frame and dispatches it by frame type.
    ///
    /// Frames are dropped if they fail the minimum-size check, if `read_type` yields no
    /// type, or if they fail the CRC check. Handshake acks go to their phase handlers;
    /// everything else is forwarded to the endpoint.
    fn handle_frame(&mut self, frame_bytes: &[u8]) {
        if !frame::serial::verify_minimum_size(frame_bytes) {
            return;
        }
        let frame_type = match frame::serial::read_type(frame_bytes) {
            Some(frame_type) => frame_type,
            None => return,
        };
        if !frame::serial::verify_crc(frame_bytes) {
            return;
        }
        let payload_bytes = frame::serial::payload(frame_bytes);
        let now_ms = self.time_now_ms();
        match frame_type {
            frame::FrameType::HandshakeAlphaAck => {
                self.handle_handshake_alpha_ack(payload_bytes, now_ms);
            }
            frame::FrameType::HandshakeBetaAck => {
                self.handle_handshake_beta_ack(payload_bytes, now_ms);
            }
            other => self.handle_frame_other(other, payload_bytes, now_ms),
        }
    }

    /// Handles every frame already waiting on the socket without blocking.
    ///
    /// Stops at the first read that yields no frame or an error; read errors are
    /// discarded.
    pub fn handle_frames(&mut self, socket_rx: &mut socket::ConnectedSocketRx) {
        while let Ok(Some(received)) = socket_rx.try_read_frame() {
            self.handle_frame(received);
        }
    }

    /// Waits up to `wait_timeout` (`None` passes no limit to `wait_for_frame`) for a
    /// frame. If one arrives, handles it and then drains any further frames without
    /// blocking. Timeouts and read errors are discarded.
    pub fn handle_frames_wait(
        &mut self,
        socket_rx: &mut socket::ConnectedSocketRx,
        wait_timeout: Option<time::Duration>,
    ) {
        let first_frame = match socket_rx.wait_for_frame(wait_timeout) {
            Ok(Some(frame_bytes)) => frame_bytes,
            _ => return,
        };
        self.handle_frame(first_frame);
        self.handle_frames(socket_rx);
    }

    /// Enqueues a packet and immediately flushes the endpoint.
    ///
    /// During the handshake the packet is buffered until the connection is active;
    /// once quiescent it is dropped.
    pub fn send(&mut self, packet_bytes: Box<[u8]>, mode: SendMode) {
        let endpoint_ref = match &mut self.state {
            State::Handshake(state) => {
                state.packet_buffer.push((packet_bytes, mode));
                return;
            }
            State::Active(state) => Arc::clone(&state.endpoint_ref),
            State::Quiescent => return,
        };
        let mut endpoint = endpoint_ref.write().unwrap();
        let now_ms = self.time_now_ms();
        endpoint.enqueue(packet_bytes, mode, now_ms);
        endpoint.flush(now_ms, &mut EndpointContext::new(self));
    }

    /// Enqueues a packet without flushing; see `send` for per-state behavior.
    pub fn enqueue(&mut self, packet_bytes: Box<[u8]>, mode: SendMode) {
        let endpoint_ref = match &mut self.state {
            State::Handshake(state) => {
                state.packet_buffer.push((packet_bytes, mode));
                return;
            }
            State::Active(state) => Arc::clone(&state.endpoint_ref),
            State::Quiescent => return,
        };
        let mut endpoint = endpoint_ref.write().unwrap();
        let now_ms = self.time_now_ms();
        endpoint.enqueue(packet_bytes, mode, now_ms);
    }

    /// Flushes the endpoint's queued packets. Does nothing unless the connection is
    /// active (packets buffered during the handshake are flushed on connect).
    pub fn flush(&mut self) {
        let endpoint_ref = match &self.state {
            State::Active(state) => Arc::clone(&state.endpoint_ref),
            State::Handshake(_) | State::Quiescent => return,
        };
        let mut endpoint = endpoint_ref.write().unwrap();
        let now_ms = self.time_now_ms();
        endpoint.flush(now_ms, &mut EndpointContext::new(self));
    }

    /// Asks the endpoint to disconnect. Does nothing unless the connection is active.
    pub fn disconnect(&mut self) {
        let endpoint_ref = match &self.state {
            State::Active(state) => Arc::clone(&state.endpoint_ref),
            // NOTE(review): a disconnect requested mid-handshake is ignored; the
            // handshake continues and may still connect.
            State::Handshake(_) | State::Quiescent => return,
        };
        let mut endpoint = endpoint_ref.write().unwrap();
        let now_ms = self.time_now_ms();
        endpoint.disconnect(now_ms, &mut EndpointContext::new(self));
    }
}
impl Client {
pub fn connect<A>(server_addr: A) -> std::io::Result<Self>
where
A: net::ToSocketAddrs,
{
Self::connect_with_config(server_addr, Default::default())
}
pub fn connect_with_config<A>(server_addr: A, config: Config) -> std::io::Result<Self>
where
A: net::ToSocketAddrs,
{
config.validate();
let bind_address = (std::net::Ipv4Addr::UNSPECIFIED, 0);
let (socket_tx, socket_rx) =
socket::new_connected(bind_address, server_addr, FRAME_SIZE_MAX)?;
let local_nonce = rand::random::<u32>();
let handshake_frame = frame::HandshakeAlphaFrame {
protocol_id: frame::serial::PROTOCOL_ID,
client_nonce: local_nonce,
};
let handshake_frame =
frame::serial::write_handshake_alpha(&handshake_frame, FRAME_SIZE_MAX);
socket_tx.send(&handshake_frame);
let state = State::Handshake(HandshakeState {
phase: HandshakePhase::Alpha,
timeout_time_ms: config.handshake_timeout_ms,
timeout_ms: config.handshake_timeout_ms,
packet_buffer: Vec::new(),
frame: handshake_frame,
local_nonce,
});
let rto_timer = Timer {
timeout_ms: Some(HANDSHAKE_RESEND_TIMEOUT_MS),
};
let core = ClientCore {
config,
time_ref: time::Instant::now(),
socket_tx,
state,
rto_timer,
events: VecDeque::new(),
};
Ok(Self { core, socket_rx })
}
pub fn poll_event(&mut self) -> Option<Event> {
let core = &mut self.core;
if core.events.is_empty() {
core.handle_frames(&mut self.socket_rx);
core.process_timeouts();
}
core.events.pop_front()
}
pub fn wait_event(&mut self) -> Event {
let core = &mut self.core;
loop {
let wait_timeout = core.next_timer_timeout();
core.handle_frames_wait(&mut self.socket_rx, wait_timeout);
core.process_timeouts();
if let Some(event) = core.events.pop_front() {
return event;
}
}
}
pub fn wait_event_timeout(&mut self, timeout: time::Duration) -> Option<Event> {
let core = &mut self.core;
if core.events.is_empty() {
let mut remaining_timeout = timeout;
let mut wait_begin = time::Instant::now();
loop {
let wait_timeout = if let Some(timer_timeout) = core.next_timer_timeout() {
remaining_timeout.min(timer_timeout)
} else {
remaining_timeout
};
core.handle_frames_wait(&mut self.socket_rx, Some(wait_timeout));
core.process_timeouts();
if !core.events.is_empty() {
break;
}
let now = time::Instant::now();
let elapsed_time = now - wait_begin;
if elapsed_time >= remaining_timeout {
break;
}
remaining_timeout -= elapsed_time;
wait_begin = now;
}
}
core.events.pop_front()
}
pub fn send(&mut self, packet_bytes: Box<[u8]>, mode: SendMode) {
self.core.send(packet_bytes, mode);
}
pub fn enqueue(&mut self, packet_bytes: Box<[u8]>, mode: SendMode) {
self.core.enqueue(packet_bytes, mode);
}
pub fn flush(&mut self) {
self.core.flush();
}
pub fn disconnect(&mut self) {
self.core.disconnect();
}
pub fn local_addr(&self) -> net::SocketAddr {
self.socket_rx.local_addr()
}
pub fn server_addr(&self) -> net::SocketAddr {
self.socket_rx.peer_addr()
}
}