use std::{
collections::HashMap,
future::Future,
net::SocketAddr,
sync::{Arc, Mutex},
task::{Poll, Waker},
time::Duration,
};
use tokio::time::timeout;
use crate::{
core::{context::BluefinHost, error::BluefinError, packet::BluefinPacket},
utils::common::BluefinResult,
worker::{reader::ReaderRxChannel, writer::WriterHandler},
};
use super::{
build_and_start_ack_consumer_workers, build_and_start_conn_reader_tx_channels,
get_connected_udp_socket,
ordered_bytes::{ConsumeResult, OrderedBytes},
AckBuffer, ConnectionManagedBuffers,
};
/// Upper bound for connection buffer sizing.
///
/// NOTE(review): units (presumably bytes) are not established by this file — confirm with the
/// users of this constant.
pub const MAX_BUFFER_SIZE: usize = 2_000;
/// Upper bound for a single consume from a connection buffer.
///
/// NOTE(review): units (presumably bytes) are not established by this file — confirm with the
/// users of this constant.
pub const MAX_BUFFER_CONSUME: usize = 1_000;
/// Awaitable handle over a shared [`ConnectionBuffer`], used to wait for the single packet
/// (together with the peer address) that the buffer holds.
///
/// Deriving `Clone` only clones the inner `Arc`, so every clone observes the same buffer.
#[derive(Clone)]
pub(crate) struct HandshakeConnectionBuffer {
    conn_buff: Arc<Mutex<ConnectionBuffer>>,
}

impl HandshakeConnectionBuffer {
    /// Wraps an existing shared connection buffer.
    pub(crate) fn new(conn_buff: Arc<Mutex<ConnectionBuffer>>) -> Self {
        Self { conn_buff }
    }

    /// Waits, with no deadline, until the buffer holds both a packet and a peer address.
    ///
    /// Awaits a fresh clone of the handle, so `&self` is enough.
    #[inline]
    pub(crate) async fn read(&self) -> (BluefinPacket, SocketAddr) {
        self.clone().await
    }

    /// Same as [`Self::read`], but gives up after `timeout_duration`.
    ///
    /// # Errors
    ///
    /// Returns [`BluefinError::TimedOut`] if no packet/address pair becomes available in time.
    #[inline]
    pub(crate) async fn read_with_timeout(
        &self,
        timeout_duration: Duration,
    ) -> BluefinResult<(BluefinPacket, SocketAddr)> {
        timeout(timeout_duration, self.clone())
            .await
            .map_err(|_| {
                // Built lazily: the message is only formatted on the timeout path.
                BluefinError::TimedOut(format!(
                    "Failed to read from handshake connection buffer after {:?}",
                    timeout_duration
                ))
            })
    }
}
impl Future for HandshakeConnectionBuffer {
    type Output = (BluefinPacket, SocketAddr);

    /// Resolves once the shared buffer holds both a peer address and a packet.
    ///
    /// When ready, the packet is moved out of the buffer (its slot becomes `None`). The address
    /// is copied and left in place. Otherwise the current task's waker is stored in the buffer
    /// via `set_waker`, and `Poll::Pending` is returned.
    ///
    /// # Panics
    ///
    /// Panics if the buffer's mutex is poisoned.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let mut guard = self.conn_buff.lock().unwrap();
        // Check the address BEFORE taking the packet. Taking it first while the address
        // is still missing would remove the packet from the buffer and drop it on this
        // (pending) poll, losing it for good.
        if let Some(addr) = guard.addr {
            if let Some(packet) = guard.packet.take() {
                return Poll::Ready((packet, addr));
            }
        }
        guard.set_waker(cx.waker().clone());
        Poll::Pending
    }
}
/// Receive-side state for one connection.
///
/// It holds two independent pieces of state:
/// * a single-slot `packet` plus the peer `addr`, which `HandshakeConnectionBuffer` waits on;
/// * `ordered_bytes`, which buffers data packets and is drained via
///   [`ConnectionBuffer::consume`].
#[derive(Clone)]
pub(crate) struct ConnectionBuffer {
    /// Data-packet buffer; filled by `buffer_in_bytes`, drained by `consume`, checked by `peek`.
    ordered_bytes: OrderedBytes,
    /// Peer address; set at most once via `buffer_in_addr`.
    addr: Option<SocketAddr>,
    /// Waker stored via `set_waker` and exposed through `get_waker`.
    waker: Option<Waker>,
    /// Single pending-packet slot filled by `buffer_in_packet`.
    packet: Option<BluefinPacket>,
    /// Destination connection id; written by `set_dst_conn_id` (starts at 0).
    dst_conn_id: u32,
    /// Side of the connection this buffer belongs to; selects the start-packet-number offset
    /// applied in `buffer_in_packet`.
    host_type: BluefinHost,
    /// Whether `buffer_in_packet` has already initialised the ordered-bytes start number.
    set_start_packet_number: bool,
}

impl ConnectionBuffer {
    /// Creates an empty buffer with no address, packet, or waker, and `dst_conn_id` 0.
    pub(crate) fn new(src_conn_id: u32, host_type: BluefinHost) -> Self {
        Self {
            ordered_bytes: OrderedBytes::new(src_conn_id, 0x0),
            addr: None,
            waker: None,
            packet: None,
            dst_conn_id: 0,
            host_type,
            set_start_packet_number: false,
        }
    }

    /// Records the destination connection id.
    #[inline]
    pub(crate) fn set_dst_conn_id(&mut self, dst_conn_id: u32) {
        self.dst_conn_id = dst_conn_id;
    }

    /// Records the peer address. An already-recorded address is never overwritten.
    ///
    /// # Errors
    ///
    /// Returns [`BluefinError::Unexpected`] if an address was already recorded.
    #[inline]
    pub(crate) fn buffer_in_addr(&mut self, addr: SocketAddr) -> BluefinResult<()> {
        if self.addr.is_some() {
            return Err(BluefinError::Unexpected(
                "Address already exists".to_string(),
            ));
        }
        self.addr = Some(addr);
        Ok(())
    }

    /// Hands a data packet to the ordered-bytes buffer.
    ///
    /// # Errors
    ///
    /// Propagates any error from `OrderedBytes::buffer_in_packet`.
    #[inline]
    pub(crate) fn buffer_in_bytes(&mut self, packet: BluefinPacket) -> BluefinResult<()> {
        self.ordered_bytes.buffer_in_packet(packet)
    }

    /// Stores `packet` in the single packet slot.
    ///
    /// The first call also fixes where ordered data starts:
    /// * `PackLeader`: the packet number + 2;
    /// * `Client`: the packet number + 1;
    /// * any other host type: only the flag is set.
    ///
    /// NOTE(review): the offsets presumably skip the peer's remaining handshake packets. TODO
    /// confirm.
    ///
    /// # Errors
    ///
    /// Returns [`BluefinError::BufferFullError`] if the slot is already occupied. In that case
    /// nothing is modified.
    #[inline]
    pub(crate) fn buffer_in_packet(&mut self, packet: BluefinPacket) -> BluefinResult<()> {
        if self.packet.is_some() {
            return Err(BluefinError::BufferFullError(
                "Buffer already contains a packet. Could not buffer another packet.".to_string(),
            ));
        }
        let packet_num = packet.header.packet_number;
        self.packet = Some(packet);
        if !self.set_start_packet_number {
            if self.host_type == BluefinHost::PackLeader {
                self.ordered_bytes.set_start_packet_number(packet_num + 2);
            } else if self.host_type == BluefinHost::Client {
                self.ordered_bytes.set_start_packet_number(packet_num + 1);
            }
            self.set_start_packet_number = true;
        }
        Ok(())
    }

    /// Drains up to `bytes_to_read` bytes of ordered data into `buf`. Returns the consume
    /// result together with the peer address.
    ///
    /// # Errors
    ///
    /// * [`BluefinError::Unexpected`] if no peer address has been recorded. This is checked
    ///   before the byte buffer is touched.
    /// * Any error from `OrderedBytes::consume`.
    #[inline]
    pub(crate) fn consume(
        &mut self,
        bytes_to_read: usize,
        buf: &mut [u8],
    ) -> BluefinResult<(ConsumeResult, SocketAddr)> {
        let addr = self.require_addr()?;
        let consume_res = self.ordered_bytes.consume(bytes_to_read, buf)?;
        Ok((consume_res, addr))
    }

    /// Delegates to `OrderedBytes::peek` once a peer address is known.
    ///
    /// # Errors
    ///
    /// * [`BluefinError::Unexpected`] if no peer address has been recorded.
    /// * Otherwise, whatever `OrderedBytes::peek` returns.
    pub(crate) fn peek(&self) -> BluefinResult<()> {
        self.require_addr()?;
        self.ordered_bytes.peek()
    }

    /// Returns the stored waker, if any.
    #[inline]
    pub(crate) fn get_waker(&self) -> Option<&Waker> {
        self.waker.as_ref()
    }

    /// Stores `waker`, replacing any previous one.
    #[inline]
    pub(crate) fn set_waker(&mut self, waker: Waker) {
        self.waker = Some(waker);
    }

    /// Returns the recorded peer address. If none is recorded, returns the error shared by
    /// `consume` and `peek`.
    #[inline]
    fn require_addr(&self) -> BluefinResult<SocketAddr> {
        self.addr.ok_or_else(|| {
            BluefinError::Unexpected(
                "Cannot consume buffer because addr is field is none".to_string(),
            )
        })
    }
}
/// Registry of connections, mapping a caller-supplied string key to that connection's
/// [`ConnectionManagedBuffers`].
pub(crate) struct ConnectionManager {
    map: HashMap<String, ConnectionManagedBuffers>,
}

impl ConnectionManager {
    /// Builds an empty registry.
    pub(crate) fn new() -> Self {
        let map = HashMap::default();
        Self { map }
    }

    /// Registers `element` under `key`. An existing entry is never replaced.
    ///
    /// # Errors
    ///
    /// Returns [`BluefinError::ConnectionAlreadyExists`] if `key` is already registered. The
    /// existing entry is left untouched.
    #[inline]
    pub(crate) fn insert(
        &mut self,
        key: &str,
        element: ConnectionManagedBuffers,
    ) -> BluefinResult<()> {
        let already_present = self.map.contains_key(key);
        if already_present {
            Err(BluefinError::ConnectionAlreadyExists)
        } else {
            self.map.insert(key.to_owned(), element);
            Ok(())
        }
    }

    /// Returns a clone of the buffers registered under `key`, if any.
    #[inline]
    pub(crate) fn get(&self, key: &str) -> Option<ConnectionManagedBuffers> {
        self.map.get(key).cloned()
    }

    /// Unregisters `key`, dropping its buffers.
    ///
    /// # Errors
    ///
    /// Returns [`BluefinError::NoSuchConnectionError`] if `key` was not registered.
    #[inline]
    pub(crate) fn remove(&mut self, key: &str) -> BluefinResult<()> {
        match self.map.remove(key) {
            Some(_) => Ok(()),
            None => Err(BluefinError::NoSuchConnectionError),
        }
    }
}
/// A Bluefin connection.
///
/// Inbound data is read through a reader channel backed by the connection's
/// [`ConnectionBuffer`]. Outbound data goes through a [`WriterHandler`] bound to a connected UDP
/// socket.
#[derive(Clone)]
pub struct BluefinConnection {
    /// Source connection id (also handed to the writer).
    pub src_conn_id: u32,
    /// Destination connection id (also handed to the writer).
    pub dst_conn_id: u32,
    reader_rx: ReaderRxChannel,
    writer_handler: WriterHandler,
}

impl BluefinConnection {
    /// Wires up a connection. In order, it:
    /// 1. starts one ack-consumer worker over `ack_buffer`;
    /// 2. connects a UDP socket from `src_addr` to `dst_addr`;
    /// 3. starts a writer on that socket, beginning at `next_send_packet_num`;
    /// 4. starts the reader tx side over `conn_buffer` and `ack_buffer`;
    /// 5. builds the reader rx channel.
    ///
    /// # Panics
    ///
    /// Panics if the connected UDP socket cannot be created, or if the writer fails to start.
    pub(crate) fn new(
        src_conn_id: u32,
        dst_conn_id: u32,
        next_send_packet_num: u64,
        conn_buffer: Arc<Mutex<ConnectionBuffer>>,
        ack_buffer: Arc<Mutex<AckBuffer>>,
        dst_addr: SocketAddr,
        src_addr: SocketAddr,
    ) -> Self {
        build_and_start_ack_consumer_workers(1, Arc::clone(&ack_buffer));

        let conn_socket = Arc::new(
            get_connected_udp_socket(src_addr, dst_addr).unwrap_or_else(|e| {
                panic!("Failed to get connected sockets due to error: {:?}", e)
            }),
        );

        let mut writer_handler = WriterHandler::new(
            Arc::clone(&conn_socket),
            next_send_packet_num,
            src_conn_id,
            dst_conn_id,
        );
        if let Err(e) = writer_handler.start() {
            panic!("Cannot start connection due to error: {:?}", e);
        }

        // Last use of `ack_buffer`: move it rather than cloning.
        let conn_bufs = Arc::new(ConnectionManagedBuffers {
            conn_buff: Arc::clone(&conn_buffer),
            ack_buff: ack_buffer,
        });
        // NOTE(review): the return value is discarded. If this call can report a failure, the
        // connection would come up without its receive path and nobody would notice. Consider
        // handling it consistently with the writer-start failure above.
        let _ = build_and_start_conn_reader_tx_channels(conn_socket, conn_bufs);
        let reader_rx = ReaderRxChannel::new(conn_buffer, writer_handler.clone());

        Self {
            src_conn_id,
            dst_conn_id,
            reader_rx,
            writer_handler,
        }
    }

    /// Reads received data into `buf`, requesting `len` bytes from the reader channel.
    /// Returns the size the reader reports.
    ///
    /// # Errors
    ///
    /// Propagates any error from the reader channel.
    #[inline]
    pub async fn recv(&mut self, buf: &mut [u8], len: usize) -> BluefinResult<usize> {
        let (size, _) = self.reader_rx.read(len, buf).await?;
        // NOTE(review): `as` cast. Confirm the reader's size type cannot exceed `usize` on the
        // targets this crate supports.
        Ok(size as usize)
    }

    /// Hands `buf` to the writer and returns whatever `WriterHandler::send_data` reports.
    ///
    /// # Errors
    ///
    /// Propagates any error from the writer.
    #[inline]
    pub fn send(&mut self, buf: &[u8]) -> BluefinResult<usize> {
        self.writer_handler.send_data(buf)
    }
}