use crate::protocol::{Role, perform_server_handshake, read_join_frame, read_raw_frame_into};
use crate::session::Sessions;
use crate::{BUFFER_SIZE, error, info, trace, warn};
use anyhow::Result;
use bytes::BytesMut;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::broadcast::error::RecvError;
/// Relay server state: the listening socket, the shared session registry and
/// the limits applied to every accepted connection.
pub struct HydraServer {
    /// Number of simultaneous connections `run` admits before it pauses accepting.
    max_connections: usize,
    /// Payload limit handed to the frame reader; also sizes each connection's buffer.
    max_payload_length: usize,
    /// Capacity passed to the session registry when a producer registers.
    broadcast_capacity: usize,
    /// Socket on which client connections are accepted.
    listener: TcpListener,
    /// Session registry shared by every connection task.
    sessions: Arc<Sessions>,
    /// Connection slots currently reserved or in use.
    connections: Arc<AtomicUsize>,
}
impl HydraServer {
pub async fn bind_default() -> Result<(Self, SocketAddr)> {
let addr = &"127.0.0.1:0".parse::<SocketAddr>()?;
let server = HydraServer::bind(addr, 64 * 1024 * 1024, 32, 256).await?;
let local_addr = server.listener.local_addr()?;
Ok((server, local_addr))
}
pub async fn bind(
addr: &SocketAddr,
max_payload_length: usize,
max_connections: usize,
broadcast_capacity: usize,
) -> Result<Self> {
let listener = TcpListener::bind(addr).await?;
Ok(Self {
listener,
sessions: Arc::new(Sessions::init()),
connections: Arc::new(AtomicUsize::new(0)),
max_payload_length,
max_connections,
broadcast_capacity,
})
}
pub async fn run(self, connections_timeout_ms: u64) -> Result<()> {
loop {
if self.connections.fetch_add(1, Ordering::Relaxed) >= self.max_connections {
self.connections.fetch_sub(1, Ordering::Relaxed);
warn!(
"Max connections reached: {}, waiting {} ms before retrying",
self.max_connections, connections_timeout_ms
);
tokio::time::sleep(std::time::Duration::from_millis(connections_timeout_ms)).await;
continue;
}
match self.listener.accept().await {
Ok((stream, peer_addr)) => {
stream.set_nodelay(true).ok();
let sessions = Arc::clone(&self.sessions);
let connections = Arc::clone(&self.connections);
tokio::spawn(async move {
trace!("Accepted connection from: {}", peer_addr);
if let Err(e) = Self::handle_connection(
stream,
sessions,
self.max_payload_length,
self.broadcast_capacity,
)
.await
{
error!("Connection handling error: {} from: {}", e, peer_addr);
}
connections.fetch_sub(1, Ordering::Release);
});
}
Err(e) => {
self.connections.fetch_sub(1, Ordering::Release);
error!("Connection accepting error: {}", e);
}
}
}
}
async fn handle_connection(
mut stream: TcpStream,
sessions: Arc<Sessions>,
max_payload_length: usize,
broadcast_capacity: usize,
) -> Result<()> {
stream.set_nodelay(true)?;
let mut mem_pool = BytesMut::with_capacity(max_payload_length + 4); let peer_addr = stream.peer_addr()?;
let (read_h, mut writer_raw) = stream.split();
let mut reader = BufReader::with_capacity(BUFFER_SIZE, read_h);
let transport_key = perform_server_handshake(&mut reader, &mut writer_raw).await?;
let (role, session_id) =
read_join_frame(&mut reader, &transport_key, &mut mem_pool).await?;
match role {
Role::Producer => {
info!(
"Producer addr: {} joined session: {}",
peer_addr,
hex::encode(session_id)
);
Self::run_producer(
&mut reader,
sessions,
session_id,
&peer_addr,
mem_pool,
max_payload_length,
broadcast_capacity,
)
.await
}
Role::Consumer => {
info!(
"Consumer addr: {} joined session: {}",
peer_addr,
hex::encode(session_id)
);
Self::run_consumer(
&mut reader,
&mut writer_raw,
sessions,
session_id,
&peer_addr,
)
.await
}
Role::Admin => Ok(()), }
}
async fn run_producer<R: AsyncReadExt + Unpin>(
reader: &mut R,
sessions: Arc<Sessions>,
session_id: [u8; 64],
client_addr: &SocketAddr,
mut mem_pool: BytesMut,
max_payload_length: usize,
broadcast_capacity: usize,
) -> Result<()> {
let tx = sessions.try_register_producer(session_id, broadcast_capacity)?;
loop {
let n = match read_raw_frame_into(reader, &mut mem_pool, max_payload_length).await {
Ok(n) => n,
Err(e) => {
tx.closed().await;
error!(
"Producer addr: {} session: {} read: {e}",
client_addr,
hex::encode(session_id)
);
break;
}
};
if let Err(e) = tx.send(mem_pool.split_to(n).freeze()) {
tx.closed().await; warn!(
"Producer addr: {} session: {} broadcast: {e}",
client_addr,
hex::encode(session_id)
);
break;
}
}
sessions.unregister_producer(session_id);
Ok(())
}
async fn run_consumer<R: AsyncReadExt + Unpin, W: AsyncWriteExt + Unpin>(
reader: &mut R,
writer: &mut W,
sessions: Arc<Sessions>,
session_id: [u8; 64],
client_addr: &SocketAddr,
) -> Result<()> {
let tx = sessions
.get_for_consumer(session_id)
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
let mut rx = tx.subscribe();
let mut peek = [0u8; 1];
loop {
tokio::select! {
result = rx.recv() => {
match result {
Ok(data) => {
if let Err(e) = writer.write_all(&data).await {
let _ = writer.shutdown().await;
error!("Consumer addr: {} session: {} write: {e}", client_addr, hex::encode(session_id));
break;
}
}
Err(RecvError::Lagged(n)) => {
let _ = writer.flush().await; let _ = writer.shutdown().await;
warn!("Consumer addr: {} session: {} lagged by {n} messages", client_addr, hex::encode(session_id));
break;
}
Err(RecvError::Closed) => {
let _ = writer.flush().await; let _ = writer.shutdown().await;
info!("Producer for session: {} closed, consumer addr: {}", hex::encode(session_id), client_addr);
break;
},
}
}
result = reader.read(&mut peek) => {
match result {
Ok(0) => break, Err(e) => {
error!("Consumer addr: {} session: {} read: {e}", client_addr, hex::encode(session_id));
break;
}
_ => {}
}
}
}
}
Ok(())
}
}