use std::collections::BTreeMap;
use std::os::unix::net; use std::path::PathBuf;
use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
#[cfg(feature = "python")]
use pyo3::prelude::*;
use tracing::{info, warn};
use crate::py_json_methods;
use super::*;
use deimos_shared::peripherals::PeripheralId;
/// Controller-side unix datagram socket used to exchange packets with peripherals.
///
/// Only `name` takes part in (de)serialization; every other field is runtime
/// state marked `#[serde(skip)]` and starts out at its `Default` value.
#[derive(Default, Serialize, Deserialize)]
#[cfg_attr(feature = "python", pyclass)]
pub struct UnixSocket {
    /// File name of the controller socket inside the `sock` folder of the op dir.
    name: String,

    /// The bound datagram socket; `None` while the socket is not open.
    #[serde(skip)]
    socket: Option<net::UnixDatagram>,

    /// Peripheral ID -> socket path, consulted when sending to a peripheral.
    #[serde(skip)]
    addrs: BTreeMap<PeripheralId, PathBuf>,

    /// Socket path -> peripheral ID, consulted to label received packets.
    #[serde(skip)]
    pids: BTreeMap<PathBuf, PeripheralId>,

    /// Socket path -> address token handed out for that source path.
    #[serde(skip)]
    addr_tokens: BTreeMap<PathBuf, SocketAddrToken>,

    /// Address token -> socket path (reverse of `addr_tokens`).
    #[serde(skip)]
    token_addrs: BTreeMap<SocketAddrToken, PathBuf>,

    /// Token that will be assigned to the next previously unseen source path.
    #[serde(skip)]
    next_addr_token: SocketAddrToken,

    /// Copy of the controller context stored by `open`; its `op_dir` anchors
    /// all socket paths.
    #[serde(skip)]
    ctx: ControllerCtx,
}
impl UnixSocket {
    /// Builds an unopened socket called `name` with empty address maps and a
    /// default controller context.
    pub fn new(name: &str) -> Self {
        let ctx = ControllerCtx::default();
        Self {
            name: name.to_string(),
            socket: None,
            addrs: BTreeMap::new(),
            pids: BTreeMap::new(),
            addr_tokens: BTreeMap::new(),
            token_addrs: BTreeMap::new(),
            next_addr_token: 0,
            ctx,
        }
    }

    /// The socket's name, which doubles as its file name on disk.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Filesystem location of the controller socket: `<op_dir>/sock/<name>`.
    pub fn path(&self) -> PathBuf {
        let mut location = self.ctx.op_dir.join("sock");
        location.push(&self.name);
        location
    }

    /// The controller socket's address built from [`UnixSocket::path`].
    ///
    /// # Errors
    /// Returns a description of the failure when the path cannot be turned
    /// into a unix socket address.
    pub fn addr(&self) -> Result<net::SocketAddr, String> {
        match net::SocketAddr::from_pathname(self.path()) {
            Ok(address) => Ok(address),
            Err(e) => Err(format!(
                "Unable to form socket address for `{}`: {}",
                self.name, e
            )),
        }
    }

    /// Folder scanned for peripheral sockets when broadcasting:
    /// `<op_dir>/sock/per`.
    pub fn peripheral_socket_dir(&self) -> PathBuf {
        let mut folder = self.ctx.op_dir.join("sock");
        folder.push("per");
        folder
    }
}
// Invokes the crate's `py_json_methods!` macro for `UnixSocket`, passing the
// `Socket` trait name and a `#[new]` constructor that forwards to
// `UnixSocket::new`.
// NOTE(review): the macro is defined elsewhere in the crate; presumably it
// generates the Python-facing JSON helper methods — confirm against its
// definition before relying on that.
py_json_methods!(
UnixSocket,
Socket,
#[new]
fn py_new(name: &str) -> PyResult<Self> {
Ok(Self::new(name))
}
);
#[typetag::serde]
impl Socket for UnixSocket {
    /// Returns `true` while the controller socket is bound.
    fn is_open(&self) -> bool {
        self.socket.is_some()
    }

    /// Binds the controller datagram socket at [`UnixSocket::path`].
    ///
    /// Stores a copy of `ctx`, creates the `sock` and peripheral socket
    /// folders, removes any stale socket file left at the target path and
    /// binds a fresh socket there.
    ///
    /// # Errors
    /// Returns an error if the socket is already open, if the folders cannot
    /// be created, or if binding fails.
    fn open(&mut self, ctx: &ControllerCtx) -> Result<(), String> {
        if self.socket.is_some() {
            return Err("Controller unix socket already open".to_string());
        }

        self.ctx = ctx.clone();

        std::fs::create_dir_all(self.ctx.op_dir.join("sock"))
            .map_err(|e| format!("Unable to create socket folders: {e}"))?;
        std::fs::create_dir_all(self.peripheral_socket_dir())
            .map_err(|e| format!("Unable to create socket folders: {e}"))?;

        // Remove whatever a previous run left behind. Removing directly (rather
        // than checking `exists()` first) avoids a check-then-act race and also
        // clears dangling symlinks, which `exists()` reports as absent.
        // A failure here is not fatal by itself: if the path is still occupied,
        // the bind below reports the error.
        let path = self.path();
        match std::fs::remove_file(&path) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => warn!("Unable to remove stale unix socket file {path:?}: {e}"),
        }

        let socket = net::UnixDatagram::bind(&path)
            .map_err(|e| format!("Unable to bind unix socket: {e}"))?;
        self.socket = Some(socket);
        info!("Opened controller unix socket at {path:?}");
        Ok(())
    }

    /// Drops the socket, resets all address bookkeeping and the stored
    /// context, and removes the socket file.
    ///
    /// The filesystem cleanup only runs when a socket was actually bound, so
    /// closing a never-opened socket cannot delete an unrelated file located
    /// under the default context's directory.
    fn close(&mut self) {
        // The path depends on `self.ctx`, so capture it before the reset below.
        let path = self.path();
        let was_open = self.socket.take().is_some();

        self.addrs.clear();
        self.pids.clear();
        self.addr_tokens.clear();
        self.token_addrs.clear();
        self.next_addr_token = 0;
        self.ctx = ControllerCtx::default();

        if !was_open {
            return;
        }

        info!("Closed controller unix socket at {path:?}");
        let file_remove_status = std::fs::remove_file(&path);
        if file_remove_status.is_err() {
            warn!("Failed to remove unix socket file: {file_remove_status:?}");
        }
    }

    /// Sends `msg` to the peripheral registered under `id`.
    ///
    /// # Errors
    /// Returns an error if `id` has no known address, if the socket is not
    /// bound, or if the underlying send fails.
    fn send(&mut self, id: PeripheralId, msg: &[u8]) -> Result<(), String> {
        let addr = self
            .addrs
            .get(&id)
            .ok_or_else(|| format!("Peripheral not present in address map: {id:?}"))?;
        let sock = self
            .socket
            .as_ref()
            .ok_or_else(|| "Unable to send before socket is bound".to_string())?;
        sock.send_to(msg, addr)
            .map_err(|e| format!("Failed to send packet: {e}"))?;
        Ok(())
    }

    /// Waits up to `timeout` for one datagram and writes its payload to `buf`.
    ///
    /// Returns `None` when the socket is not bound, when nothing arrives in
    /// time, when the receive fails, or when the sender has no pathname
    /// address (there would be nowhere to reply to). Otherwise returns the
    /// packet metadata: the peripheral ID if the source path is registered,
    /// the address token for the source path (allocated on first sight), the
    /// arrival time and the payload size.
    fn recv(&mut self, buf: &mut [u8], timeout: Duration) -> Option<SocketPacketMeta> {
        let sock = self.socket.as_ref()?;

        // `set_read_timeout` rejects a zero duration, so substitute the
        // smallest non-zero one.
        let timeout = if timeout.is_zero() {
            Duration::from_nanos(1)
        } else {
            timeout
        };
        // Without a valid timeout the receive below could block for longer
        // than requested, so give up on this call instead.
        if let Err(e) = sock.set_read_timeout(Some(timeout)) {
            warn!("Unable to set unix socket read timeout: {e}");
            return None;
        }

        let (size, addr) = match sock.recv_from(buf) {
            Ok(packet) => packet,
            Err(e) => {
                // Timeouts and interrupted calls are routine; anything else
                // is worth surfacing even though the caller only sees `None`.
                let routine = matches!(
                    e.kind(),
                    std::io::ErrorKind::WouldBlock
                        | std::io::ErrorKind::TimedOut
                        | std::io::ErrorKind::Interrupted
                );
                if !routine {
                    warn!("Failed to receive unix socket packet: {e}");
                }
                return None;
            }
        };
        let time = Instant::now();
        let src_path = addr.as_pathname()?.to_owned();

        let pid = self.pids.get(&src_path).copied();
        let token = match self.addr_tokens.get(&src_path).copied() {
            Some(token) => token,
            None => {
                let token = self.next_addr_token;
                self.next_addr_token = token.wrapping_add(1);
                self.addr_tokens.insert(src_path.clone(), token);
                self.token_addrs.insert(token, src_path);
                token
            }
        };

        Some(SocketPacketMeta {
            pid,
            token,
            time,
            size,
        })
    }

    /// Sends `msg` to every socket found in the peripheral socket folder.
    ///
    /// Every destination is attempted even if an earlier one fails, so a
    /// single stale socket file cannot prevent delivery to the remaining
    /// peripherals. Entries that are not sockets are skipped.
    ///
    /// # Errors
    /// Returns an error if the socket is not bound, if the folder cannot be
    /// read, or if one or more sends failed (all failures are listed).
    fn broadcast(&mut self, msg: &[u8]) -> Result<(), String> {
        let dir = self.peripheral_socket_dir();
        let sock = self
            .socket
            .as_ref()
            .ok_or_else(|| "Unable to send before socket is bound".to_string())?;

        if !dir.exists() {
            return Ok(());
        }

        let entries = std::fs::read_dir(&dir)
            .map_err(|e| format!("Unable to read peripheral socket dir: {e}"))?;
        let files: Vec<PathBuf> = entries
            .filter_map(Result::ok)
            .map(|entry| entry.path())
            .filter(|p| {
                // `metadata` follows symlinks, so a link to a socket still counts.
                std::fs::metadata(p)
                    .map(|meta| std::os::unix::fs::FileTypeExt::is_socket(&meta.file_type()))
                    .unwrap_or(false)
            })
            .collect();
        info!("Unix socket broadcasting to {files:?}");

        let failures: Vec<String> = files
            .iter()
            .filter_map(|f| sock.send_to(msg, f).err().map(|e| format!("{f:?}: {e}")))
            .collect();
        if failures.is_empty() {
            Ok(())
        } else {
            Err(format!(
                "Failed to send unix socket packet: {}",
                failures.join("; ")
            ))
        }
    }

    /// Associates peripheral `id` with the address previously seen under `token`.
    ///
    /// The mapping is validated before anything is written: if `id` is
    /// already bound to a different address, or the address is already bound
    /// to a different ID, the call fails and both maps are left untouched.
    /// Re-registering an existing pair is a no-op.
    ///
    /// # Errors
    /// Returns an error for an unknown `token` or a conflicting mapping.
    fn update_map(&mut self, id: PeripheralId, token: SocketAddrToken) -> Result<(), String> {
        let addr = self
            .token_addrs
            .get(&token)
            .ok_or_else(|| format!("Unknown address token {token}"))?;

        let id_taken = self.addrs.get(&id).map_or(false, |known| known != addr);
        let addr_taken = self.pids.get(addr).map_or(false, |known| *known != id);
        if id_taken || addr_taken {
            return Err(format!(
                "Duplicate addresses or peripheral IDs detected.\nAddress map: {:?}\nPeripheral ID map: {:?}",
                &self.addrs, &self.pids
            ));
        }

        self.addrs.insert(id, addr.clone());
        self.pids.insert(addr.clone(), id);
        Ok(())
    }
}