use std::{future::Future, ops::Deref, sync::Arc};
use tokio::{
sync::{OwnedSemaphorePermit, Semaphore},
task::JoinHandle,
};
use crate::io::{IOError, Listener, Stream, StreamPool};
/// Errors produced while starting or running a [`SocketServer`].
#[derive(Debug)]
pub enum SocketServerError {
    /// An I/O error coming from the stream pool or a listener.
    IOError(IOError),
    /// The stream pool is invalid.
    ///
    /// NOTE(review): this variant is not constructed in this file; confirm the
    /// exact condition with the code that returns it.
    PoolInvalid,
}

impl std::fmt::Display for SocketServerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // `IOError` is only known to implement `Debug`, so format it that way.
            Self::IOError(err) => write!(f, "socket server I/O error: {err:?}"),
            Self::PoolInvalid => f.write_str("socket server stream pool is invalid"),
        }
    }
}

/// Allows `SocketServerError` to be used as `Box<dyn std::error::Error>` and
/// with other error-handling utilities built on the standard trait.
impl std::error::Error for SocketServerError {}
impl From<IOError> for SocketServerError {
fn from(err: IOError) -> Self {
Self::IOError(err)
}
}
/// Turns one request payload into one response payload.
///
/// The server calls `process` from tasks started with `tokio::spawn`, which is
/// why implementors must be `Send + Sync` and the returned future must be
/// `Send`.
pub trait RequestProcessor: Send + Sync {
    /// Handles `request` and resolves to the bytes sent back to the client.
    fn process(&self, request: &[u8]) -> impl Future<Output = Vec<u8>> + Send;
}
/// Makes every pointer-like type whose `Deref` target is a processor (for
/// example `Arc<MyProcessor>` or `Box<MyProcessor>`) a processor itself, by
/// forwarding each request to the pointee.
impl<T, U> RequestProcessor for T
where
    T: Deref<Target = U> + Send + Sync,
    U: RequestProcessor + 'static,
{
    fn process(&self, request: &[u8]) -> impl Future<Output = Vec<u8>> + Send {
        // `self: &T` deref-coerces to `&U` at this call site.
        <U as RequestProcessor>::process(self, request)
    }
}
/// A server that accepts connections on the listeners of a [`StreamPool`] and
/// answers each request through a [`RequestProcessor`].
pub struct SocketServer {
    /// Pool the server's listeners are opened on.
    pub pool: StreamPool,
    /// Handles of the per-listener accept-loop tasks; every one of them is
    /// aborted when the server is dropped.
    pub tasks: Vec<JoinHandle<Result<(), SocketServerError>>>,
    /// Upper bound on concurrently served connections. It is applied
    /// separately to each listener, because every accept loop creates its own
    /// semaphore.
    pub max_connections: usize,
}
impl SocketServer {
    /// Opens listeners for the whole `pool` and spawns one accept loop per
    /// listener. Each loop answers requests with its own clone of `processor`.
    ///
    /// `max_connections` is enforced per listener: every accept loop owns a
    /// separate connection semaphore.
    ///
    /// # Errors
    ///
    /// Propagates the error from `pool.listen()` (converted into
    /// [`SocketServerError`]). No task is spawned in that case.
    pub fn listen_all<P>(
        pool: StreamPool,
        processor: P,
        max_connections: usize,
    ) -> Result<Self, SocketServerError>
    where
        P: RequestProcessor + Clone + 'static,
    {
        println!("`SocketServer` listening on pool size {}", pool.len());
        let tasks =
            Self::spawn_tasks_for_listeners(pool.listen()?, processor, max_connections);
        Ok(Self {
            pool,
            tasks,
            max_connections,
        })
    }

    /// Calls `self.pool.listen_to(pool_size)` and spawns an accept loop for
    /// every listener it returns. The new handles are appended to `self.tasks`.
    ///
    /// The new loops reuse this server's `max_connections`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `StreamPool::listen_to`. No task is spawned
    /// in that case.
    pub fn listen_to<P>(&mut self, pool_size: u8, processor: P) -> Result<(), IOError>
    where
        P: RequestProcessor + Clone + 'static,
    {
        let listeners = self.pool.listen_to(pool_size)?;
        let limit = self.max_connections;
        let spawned = Self::spawn_tasks_for_listeners(listeners, processor, limit);
        self.tasks.extend(spawned);
        Ok(())
    }

    /// Spawns one `accept_loop` task per listener, each with its own clone of
    /// `processor`, and returns their handles in listener order.
    fn spawn_tasks_for_listeners<P>(
        listeners: Vec<Listener>,
        processor: P,
        max_connections: usize,
    ) -> Vec<JoinHandle<Result<(), SocketServerError>>>
    where
        P: RequestProcessor + Clone + 'static,
    {
        listeners
            .into_iter()
            .map(move |listener| {
                // Each spawned task must own its processor, since it is `'static`.
                let processor = processor.clone();
                tokio::spawn(async move {
                    accept_loop(listener, &processor, max_connections).await
                })
            })
            .collect()
    }
}
impl Drop for SocketServer {
    /// Aborts every accept-loop task this server spawned.
    ///
    /// NOTE(review): per-connection tasks are spawned detached inside each
    /// accept loop and are not tracked here. Aborting the accept loops stops
    /// new connections but does not end connections already being served.
    fn drop(&mut self) {
        self.tasks.iter().for_each(|handle| handle.abort());
    }
}
/// An accepted [`Stream`] together with the semaphore permit that was acquired
/// before accepting it.
///
/// The permit is never read. Holding it keeps one connection slot occupied
/// until this value is dropped.
pub struct PermittedStream {
    // Fields drop in declaration order: the permit is released just before
    // the stream itself is dropped.
    _permit: OwnedSemaphorePermit,
    stream: Stream,
}
impl PermittedStream {
    /// Waits for a free connection slot in `connections`, then accepts the
    /// next stream from `listener`.
    ///
    /// The slot is claimed *before* accepting. While every permit is held, no
    /// new connection is taken from the listener.
    ///
    /// # Errors
    ///
    /// - `IOError::UnknownError` if the semaphore has been closed.
    /// - Any error returned by `listener.accept()`. The permit acquired for
    ///   this attempt is dropped, and so released, on that path.
    pub async fn accept(
        listener: &Listener,
        connections: Arc<Semaphore>,
    ) -> Result<Self, IOError> {
        let Ok(permit) = connections.acquire_owned().await else {
            return Err(IOError::UnknownError);
        };
        let stream = listener.accept().await?;
        Ok(Self {
            _permit: permit,
            stream,
        })
    }

    /// Forwards `value` to `Stream::send` on the wrapped stream.
    pub async fn send(&mut self, value: &[u8]) -> Result<(), IOError> {
        self.stream().send(value).await
    }

    /// Forwards to `Stream::recv` on the wrapped stream.
    pub async fn recv(&mut self) -> Result<Vec<u8>, IOError> {
        self.stream().recv().await
    }

    /// Mutable access to the wrapped stream. The permit stays held.
    pub fn stream(&mut self) -> &mut Stream {
        let Self { stream, .. } = self;
        stream
    }
}
async fn accept_loop<P>(
listener: Listener,
processor: &P,
max_connections: usize,
) -> Result<(), SocketServerError>
where
P: RequestProcessor + Clone + 'static,
{
let connections = Arc::new(Semaphore::const_new(max_connections));
loop {
let mut stream =
match PermittedStream::accept(&listener, connections.clone()).await
{
Ok(stream) => stream,
Err(err) => {
eprintln!("SocketServer: error on accept {err:?}");
continue;
}
};
let processor = processor.clone();
tokio::spawn(async move {
loop {
match stream.recv().await {
Ok(payload) => {
let response = processor.process(&payload).await;
match stream.send(&response).await {
Ok(()) => {}
Err(err) => {
eprintln!(
"SocketServer: error sending reply {err:?}, re-accepting"
);
break;
}
}
}
Err(IOError::RecvConnectionClosed) => break,
Err(err) => {
eprintln!(
"SocketServer: error receiving request {err:?}, re-accepting"
);
break;
}
}
}
});
}
}