use super::server_client;
use crate::implementation::server;
use crate::{socket, steady_millis, trace};
use hyprwire_core::message::wire::{fatal_protocol_error, roundtrip_done};
use polling::AsSource;
use std::os::fd;
use std::os::fd::AsRawFd;
use std::os::unix::net;
use std::sync::atomic;
use std::{fs, io, path, sync, time};
/// Poller event key reserved for the listening socket.
///
/// Client streams are registered under their client id instead, and ids are
/// handed out starting at 1, so they do not collide with this key.
const LISTENER_KEY: usize = 0;
/// Server endpoint: an optional listening Unix socket, the set of connected
/// clients, and the poller that watches all of them for readability.
pub struct ServerSocket {
/// Poller the listener and every client stream are registered with.
poller: polling::Poller,
/// Listening socket; `None` for a socket created with `detached()`.
server: Option<net::UnixListener>,
/// Protocol implementations, shared (via `Arc`) with every client state.
impls: sync::Arc<sync::RwLock<Vec<Box<dyn server::ProtocolImplementations>>>>,
/// Currently tracked clients, looked up by their `id`.
clients: Vec<sync::Arc<server_client::ServerClientState>>,
/// Id assigned to the next client; also used as that client's poller key.
next_client_id: u32,
}
impl ServerSocket {
/// Binds a non-blocking listening Unix socket at `path` and registers it with
/// a freshly created poller under `LISTENER_KEY`.
///
/// If something already exists at `path`, a connection attempt decides what
/// happens next:
/// - the connect succeeds: another server is alive, `AddrInUse` is returned;
/// - the connect fails with `ConnectionRefused`: the file is treated as stale
///   and removed before binding;
/// - the connect fails with anything else: that error is returned.
///
/// `P` may be unsized, so `&Path` and `&str` can be passed directly in
/// addition to references to sized path types.
///
/// # Errors
/// Returns any I/O error raised while creating the poller, probing or removing
/// the existing path, binding, switching to non-blocking mode, or registering
/// the listener.
pub fn bind<P>(path: &P) -> crate::Result<Self>
where
    P: AsRef<path::Path> + ?Sized,
{
    let poller = polling::Poller::new()?;
    if fs::exists(path)? {
        match net::UnixStream::connect(path) {
            Ok(_) => {
                return Err(io::Error::new(io::ErrorKind::AddrInUse, "socket is alive").into());
            }
            Err(e) if e.kind() != io::ErrorKind::ConnectionRefused => {
                return Err(e.into());
            }
            // Nobody is listening behind the path: drop the stale file.
            Err(_) => fs::remove_file(path)?,
        }
    }
    let listener = net::UnixListener::bind(path)?;
    listener.set_nonblocking(true)?;
    // SAFETY: the listener is moved into the returned `ServerSocket` together
    // with the poller, and `poller` is declared before `server` in the struct,
    // so the poller is dropped first.
    // NOTE(review): confirm this matches the contract of `Poller::add`.
    unsafe { poller.add(&listener, polling::Event::readable(LISTENER_KEY))? };
    Ok(Self {
        poller,
        server: Some(listener),
        impls: sync::Arc::default(),
        clients: Vec::new(),
        next_client_id: 1,
    })
}
/// Creates a server without a listening socket.
///
/// Only the poller is set up; no listener is registered with it. Client ids
/// start at 1, exactly as for a bound server.
///
/// # Errors
/// Returns the error from creating the poller.
pub fn detached() -> crate::Result<Self> {
    let poller = polling::Poller::new()?;
    let socket = Self {
        poller,
        server: None,
        impls: sync::Arc::default(),
        clients: Vec::new(),
        next_client_id: 1,
    };
    Ok(socket)
}
/// Constructs protocol implementation `I` for `version` with `handler` and
/// appends it to the shared implementation list.
///
/// # Panics
/// Panics if the implementation list's lock is poisoned.
pub fn add_implementation<I, H>(&mut self, handler: &mut H, version: u32)
where
    I: server::Construct<H> + 'static,
{
    // Build the implementation before taking the write lock.
    let built = I::new(version, handler);
    let mut registry = self.impls.write().unwrap();
    registry.push(Box::new(built));
}
fn dispatch_client<D: 'static>(
client: &sync::Arc<server_client::ServerClientState>,
dispatch: &mut D,
) {
let state = sync::Arc::clone(&client.state);
let mut data = {
if let Ok(d) = socket::SocketRawParsedMessage::read_from_socket(&state.stream) {
d
} else {
state.send_message(&fatal_protocol_error::FatalProtocolError::new(
0,
u32::MAX,
"fatal: invalid message on wire",
));
state.error.store(true, atomic::Ordering::Relaxed);
let _ = state.stream.shutdown(std::net::Shutdown::Both);
return;
}
};
if data.data.is_empty() {
state.error.store(true, atomic::Ordering::Relaxed);
let _ = state.stream.shutdown(std::net::Shutdown::Both);
return;
}
if client.handle_message(&mut data, dispatch).is_err() {
state.send_message(&fatal_protocol_error::FatalProtocolError::new(
0,
u32::MAX,
"fatal: failed to handle message on wire",
));
state.error.store(true, atomic::Ordering::Relaxed);
let _ = state.stream.shutdown(std::net::Shutdown::Both);
return;
}
let scheduled_seq = client
.scheduled_roundtrip_seq
.load(atomic::Ordering::Relaxed);
if scheduled_seq > 0 {
state.send_message(&roundtrip_done::RoundtripDone::new(scheduled_seq));
client
.scheduled_roundtrip_seq
.store(0, atomic::Ordering::Relaxed);
}
}
fn accept_one(&mut self) -> crate::Result<bool> {
let Some(server) = self.server.as_ref() else {
return Ok(false);
};
let (stream, _addr) = match server.accept() {
Ok(conn) => conn,
Err(e) => {
crate::log_error!("failed to accept connection: {e}");
return Ok(false);
}
};
if stream.set_nonblocking(true).is_err() {
return Ok(false);
}
let state = sync::Arc::new(crate::ConnectionState::new(stream));
let client_id = self.next_client_id;
let client =
server_client::ServerClientState::new(client_id, state, sync::Arc::clone(&self.impls));
unsafe {
self.poller.add(
&client.state.stream,
polling::Event::readable(client_id as usize),
)?;
}
self.next_client_id += 1;
self.clients.push(client);
Ok(true)
}
/// Waits for poller events once and processes every event that was delivered.
///
/// With `block` set the wait has no timeout; otherwise it returns immediately.
/// A listener event accepts one connection; a client event dispatches that
/// client. Every source is re-armed after its event is handled, except clients
/// whose error flag is set: those are removed, their objects destroyed and
/// their stream deleted from the poller.
///
/// Returns `Ok(false)` when the wait delivered no events, `Ok(true)` otherwise.
///
/// # Errors
/// A failing wait is returned immediately. A failure while accepting or
/// re-arming does not stop the batch: the first such error is remembered, the
/// remaining events and the dead-client cleanup still run, and the error is
/// returned at the end. Bailing out mid-batch would leave the listener and the
/// remaining clients un-armed and flagged clients never cleaned up.
fn dispatch_pending<D: 'static>(
    &mut self,
    dispatch: &mut D,
    block: bool,
) -> crate::Result<bool> {
    let mut events = polling::Events::new();
    let timeout = if block {
        None
    } else {
        Some(time::Duration::ZERO)
    };
    self.poller.wait(&mut events, timeout)?;
    if events.is_empty() {
        return Ok(false);
    }

    // First error seen while processing the batch; reported after cleanup.
    let mut outcome: crate::Result<bool> = Ok(true);
    let mut dead: Vec<u32> = Vec::new();
    for ev in events.iter() {
        if ev.key == LISTENER_KEY {
            let accepted = self.accept_one();
            // Re-arm the listener even when accepting failed, otherwise no
            // further connection would ever be reported.
            let rearmed = match self.server.as_ref() {
                Some(server) => self
                    .poller
                    .modify(server, polling::Event::readable(LISTENER_KEY)),
                None => Ok(()),
            };
            if outcome.is_ok() {
                if let Err(e) = accepted {
                    outcome = Err(e);
                } else if let Err(e) = rearmed {
                    outcome = Err(e.into());
                }
            }
            continue;
        }

        let id = ev.key as u32;
        let Some(client) = self
            .clients
            .iter()
            .find(|c| c.id == id)
            .map(sync::Arc::clone)
        else {
            continue;
        };
        Self::dispatch_client(&client, dispatch);
        if client.state.error.load(atomic::Ordering::Relaxed) {
            dead.push(id);
        } else if let Err(e) = self
            .poller
            .modify(&client.state.stream, polling::Event::readable(id as usize))
        {
            if outcome.is_ok() {
                outcome = Err(e.into());
            }
        }
    }

    for id in dead {
        let Some(idx) = self.clients.iter().position(|c| c.id == id) else {
            continue;
        };
        let client = self.clients.remove(idx);
        client.destroy_objects_for_disconnect(dispatch);
        let _ = self.poller.delete(&client.state.stream);
        trace! {
            crate::log_debug!(
                "[hw] trace: [{} @ {:.3}] Dropping client",
                client.state.stream.as_raw_fd(),
                steady_millis(),
            )
        }
    }
    outcome
}
/// Processes events until a pass finds nothing to do.
///
/// Only the first pass may block (and only when `block` is set); every
/// following pass polls without waiting, so the call returns as soon as the
/// pending work is drained.
///
/// # Errors
/// Returns the first error reported by a dispatch pass.
pub fn dispatch_events<D: 'static>(&mut self, state: &mut D, block: bool) -> crate::Result<()> {
    let mut may_block = block;
    while self.dispatch_pending(state, may_block)? {
        // Later passes only drain what is already pending.
        may_block = false;
    }
    Ok(())
}
/// Adopts an already connected descriptor as a new client.
///
/// The descriptor is wrapped in a Unix stream, switched to non-blocking mode,
/// registered with the poller under the next client id and added to the
/// client list. The returned handle carries that id and a copy of the
/// client's credentials.
///
/// # Errors
/// Returns the error from switching the stream to non-blocking mode or from
/// registering it with the poller. In both cases no client is added and no id
/// is consumed.
pub fn add_client<F>(&mut self, fd: F) -> crate::Result<server_client::ServerClient>
where
    F: Into<fd::OwnedFd>,
{
    let stream = net::UnixStream::from(fd.into());
    // A stream left in blocking mode could stall the dispatch loop, so refuse
    // it here, as `accept_one` does for accepted connections.
    stream.set_nonblocking(true)?;
    let state = sync::Arc::new(crate::ConnectionState::new(stream));
    let client_id = self.next_client_id;
    let client =
        server_client::ServerClientState::new(client_id, state, sync::Arc::clone(&self.impls));
    // SAFETY: the stream is owned by `client`, which is pushed into
    // `self.clients` right after a successful registration.
    // NOTE(review): confirm this matches the contract of `Poller::add`.
    unsafe {
        self.poller.add(
            &client.state.stream,
            polling::Event::readable(client_id as usize),
        )?;
    }
    self.next_client_id += 1;
    let handle = server_client::ServerClient {
        id: client_id,
        creds: client.creds.clone(),
    };
    self.clients.push(client);
    Ok(handle)
}
/// Disconnects the client identified by `client` and forgets it.
///
/// Every tracked client with a matching id gets its error flag set, its stream
/// shut down in both directions, its objects destroyed and its stream deleted
/// from the poller; afterwards those entries are dropped from the client list.
///
/// Returns `Ok(true)` if at least one client was removed, `Ok(false)` if the
/// id was not tracked. This function never returns `Err`.
pub fn remove_client<D: 'static>(
    &mut self,
    client: &server_client::ServerClient,
    dispatch: &mut D,
) -> crate::Result<bool> {
    let mut removed_any = false;
    for entry in &self.clients {
        if entry.id != client.id() {
            continue;
        }
        entry.state.error.store(true, atomic::Ordering::Relaxed);
        let _ = entry.state.stream.shutdown(std::net::Shutdown::Both);
        entry.destroy_objects_for_disconnect(dispatch);
        let _ = self.poller.delete(&entry.state.stream);
        removed_any = true;
    }
    // Tear-down is finished; now drop the matching entries themselves.
    self.clients.retain(|entry| entry.id != client.id());
    Ok(removed_any)
}
/// Returns the poller's file descriptor, borrowed for the lifetime of `self`.
///
/// NOTE(review): presumably meant for watching this server from an outer
/// event loop — confirm against callers.
pub fn extract_loop_fd(&self) -> fd::BorrowedFd<'_> {
    let poller = &self.poller;
    poller.source()
}
}