use std::io::Result as IoResult;
use std::io::ErrorKind;
use std::path::{
Path,
PathBuf
};
use std::env;
use data_encoding::BASE64URL_NOPAD;
use futures::stream::{
FuturesUnordered,
StreamExt
};
use rand::Rng;
use tokio::net;
use tokio::task;
use tokio::time::Duration;
use tracing::{
info_span,
trace_span,
Instrument,
Span
};
use super::{
AnyRequest,
Command,
Response,
command::request
};
use crate::util::*;
/// Capacity, in bytes, of the buffers that receive a single datagram (64 KiB).
/// Used by both the control socket and the client socket.
const BUFFER_SIZE: usize = 64 * 1024;
/// How long a client waits for the server's reply before giving up.
const CLIENT_RECEIVE_TIMEOUT: Duration = Duration::from_millis(3_000);
/// Server side of the control channel.
///
/// A Unix datagram socket bound at `path`. It decodes JSON `Command`s sent by
/// clients, forwards them to the daemon through `sender`, and sends each
/// JSON-encoded response back to the client address the datagram came from.
pub struct ControlSocket {
/// Filesystem path the socket is bound to; unlinked again on drop.
path: PathBuf,
/// The bound (not connected) datagram socket; replies use `send_to`.
sock: net::UnixDatagram,
/// Reused receive buffer, `BUFFER_SIZE` bytes long.
receive_buf: Vec<u8>,
/// In-flight query tasks, each yielding a reply address and encoded reply.
response_tasks: FuturesUnordered<ResponseAwaiter>,
/// Channel to the daemon; once it reports closed, `run` shuts down.
sender: Sender<AnyRequest>,
/// Tracing span of this socket; per-query spans are created beneath it.
span: Span
}
/// Client side of the control channel.
///
/// A datagram socket bound at a randomly named temp-directory path and
/// connected to the server's control socket.
#[derive(Debug)]
pub struct ClientSocket {
/// Connected datagram socket used by `send` and `receive`.
sock: net::UnixDatagram,
/// This client's own bound path (the server replies to it); unlinked on drop.
path: PathBuf,
/// Tracing span recording the client and server socket paths.
span: Span
}
/// Handle to a spawned query task. It resolves to the requesting client's
/// address and the JSON-encoded response destined for it.
type ResponseAwaiter = task::JoinHandle<(UnixPathAddr, Vec<u8>)>;
/// A Unix socket address that is known to carry a filesystem pathname.
///
/// In this file, values are only built by `UnixPathAddr::new`, which rejects
/// unnamed and other non-path addresses. That is why the pathname accessors
/// below can treat the pathname as always present.
struct UnixPathAddr {
/// The wrapped tokio socket address (always has a pathname).
inner: net::unix::SocketAddr
}
impl UnixPathAddr {
    /// Accepts `addr` only if it is bound to a filesystem path.
    ///
    /// Unnamed (anonymous) addresses and any other address without a
    /// pathname are logged as warnings and rejected with `None`. Replies are
    /// addressed by path, so such peers could not be answered.
    pub fn new(addr: net::unix::SocketAddr) -> Option<Self> {
        let has_path = addr.as_pathname().is_some();
        if has_path {
            return Some(Self { inner: addr });
        }
        if addr.is_unnamed() {
            warn!(?addr, "invalid anonymous address");
        } else {
            warn!(?addr, "invalid address");
        }
        None
    }
}
impl AsRef<Path> for UnixPathAddr {
    /// Borrows the filesystem path of the wrapped address.
    ///
    /// # Panics
    /// Only if the type's invariant is broken: `UnixPathAddr::new` refuses
    /// addresses without a pathname, so this should never fire.
    fn as_ref(&self) -> &Path {
        self.inner
            .as_pathname()
            .expect("UnixPathAddr invariant violated: address has no pathname")
    }
}
impl std::fmt::Display for UnixPathAddr {
    /// Formats the address as its filesystem path. Formatter options such as
    /// width and fill are passed through.
    ///
    /// # Panics
    /// Only if the type's invariant is broken: `UnixPathAddr::new` refuses
    /// addresses without a pathname, so this should never fire.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let path = self
            .inner
            .as_pathname()
            .expect("UnixPathAddr invariant violated: address has no pathname");
        std::fmt::Display::fmt(&path.display(), f)
    }
}
impl Drop for ControlSocket {
    /// Best-effort cleanup: unlinks the socket file this listener was bound
    /// to and logs the outcome at debug level. The removal happens before
    /// the log call, so it runs regardless of the active log level.
    fn drop(&mut self) {
        let socket_file = self.path.as_path();
        let remove_result = std::fs::remove_file(socket_file);
        debug!(@self, path = ?socket_file, ?remove_result);
    }
}
impl Drop for ClientSocket {
    /// Best-effort cleanup: unlinks this client's bound socket file and logs
    /// the outcome at debug level. The removal happens before the log call,
    /// so it runs regardless of the active log level.
    fn drop(&mut self) {
        let socket_file = self.path.as_path();
        let remove_result = std::fs::remove_file(socket_file);
        debug!(@self, path = ?socket_file, ?remove_result);
    }
}
impl ControlSocket {
    /// Default location of the daemon's control socket:
    /// `cnsprcy_control.sock` inside the platform temporary directory.
    pub fn default_path() -> PathBuf {
        env::temp_dir().join("cnsprcy_control.sock")
    }

    /// Binds a datagram socket at `path`, first removing any file already
    /// present there.
    ///
    /// Commands decoded from incoming datagrams are forwarded through
    /// `sender`. The socket's tracing span is created while `parent_span` is
    /// entered, so it nests under that span.
    ///
    /// Returns `None` (after logging) if a stale file cannot be removed or
    /// the bind fails.
    pub fn open(
        sender: Sender<AnyRequest>,
        path: PathBuf,
        parent_span: &Span,
    ) -> Option<Self> {
        if path.exists() {
            warn!(?path, "removing existing socket");
            std::fs::remove_file(&path)
                .map_err(|err| error!(%err, "failed to remove existing socket"))
                .ok()?;
        }
        let sock = net::UnixDatagram::bind(&path)
            .map_err(|err| {
                error!(?path, %err, "failed to open socket");
                let _ = std::fs::remove_file(&path);
            })
            .ok()?;
        info!(?path, "opened socket");
        Some(Self {
            sock,
            path,
            receive_buf: vec![0u8; BUFFER_SIZE],
            response_tasks: FuturesUnordered::new(),
            sender,
            span: parent_span.in_scope(|| info_span!("control_socket")),
        })
    }

    /// Serves requests until the daemon channel closes or the socket fails.
    ///
    /// Each valid datagram spawns a task that queries the daemon. Finished
    /// replies are sent back to the requesting client. After the loop ends,
    /// replies that are still outstanding are awaited and delivered before
    /// returning.
    pub async fn run(mut self) {
        info!(@self, "listening");
        loop {
            tokio::select! {
                r = self.sock.readable() => match r.and_then(|()| self.receive()) {
                    Ok(Some(handle)) => self.response_tasks.push(handle),
                    Ok(None) => continue,
                    // readable() may report readiness spuriously; wait again.
                    Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
                        trace!(@self, "socket readable() false positive");
                        continue
                    }
                    Err(io_error) => {
                        error!(@self, %io_error, "socket error");
                        break;
                    }
                },
                // When no tasks are pending, `next()` yields `None`; the
                // pattern then fails and select! disables this branch.
                Some(resp) = self.response_tasks.next() => self.deliver(resp).await,
                _ = self.sender.closed() => {
                    info!(@self, "received termination signal, shutting down!");
                    break;
                }
            }
        }
        // Deliver replies for requests that were already accepted.
        while let Some(resp) = self.response_tasks.next().await {
            self.deliver(resp).await;
        }
        info!(@self, "stopped");
    }

    /// Sends the reply produced by a finished query task. If the task failed
    /// (its join result is an error), logs that instead.
    async fn deliver(&self, resp: Result<(UnixPathAddr, Vec<u8>), task::JoinError>) {
        match resp {
            Ok((to, data)) => {
                self.send(&to, &data).await;
            }
            Err(err) => debug!(@self, %err, "response task panicked"),
        }
    }

    /// Decodes the first `amt` bytes of the receive buffer as a JSON
    /// `Command`. Logs and returns `None` if decoding fails.
    fn parse(&self, amt: usize) -> Option<Command> {
        serde_json::from_slice::<Command>(&self.receive_buf[..amt])
            .map_err(|error| warn!(@self, %error, "unable to decode datagram"))
            .ok()
    }

    /// Sends `data` as one datagram to the client at `to`.
    ///
    /// Returns `true` on success. Failures are logged, including the
    /// destination, and reported as `false`.
    async fn send(&self, to: &UnixPathAddr, data: &[u8]) -> bool {
        match self.sock.send_to(data, to).await {
            Ok(amt) => {
                trace!(@self, %to, amt, "sent datagram");
                true
            }
            Err(err) => {
                warn!(@self, %to, %err, "failed to send datagram");
                false
            }
        }
    }

    /// Spawns a task that forwards `cmd` to the daemon and JSON-encodes the
    /// response for `from`.
    ///
    /// The daemon request runs inside a `query_daemon` span, created as a
    /// child of this socket's span, so events emitted while serving the
    /// request are attributed to this query.
    fn make_query(&self, from: UnixPathAddr, cmd: Command) -> ResponseAwaiter {
        let span = self.span.in_scope(||
            trace_span!("query_daemon", ?cmd)
        );
        let response_fut = request(&self.sender, cmd);
        task::spawn(async move {
            // A freshly spawned task has no current span, so
            // `in_current_span()` here would attach nothing. Attach the
            // query span explicitly.
            let response = response_fut
                .instrument(span.clone())
                .await;
            let json = serde_json::to_vec(&response)
                .map_err(|json_error|
                    error!(:span, %json_error, ?response, "encoding failed")
                )
                // The failure is logged above; the panic surfaces through
                // the JoinHandle and is reported by `run`.
                .expect("failed to JSON-encode daemon response");
            #[cfg(debug_assertions)]
            trace!(:span, to = %from, ?response, "sending response along");
            (from, json)
        })
    }

    /// Reads one queued datagram without waiting.
    ///
    /// Possible results:
    /// - `Ok(Some(task))`: the sender had a path address and the payload
    ///   decoded as a `Command`.
    /// - `Ok(None)`: the datagram was rejected; this is already logged.
    /// - `Err`: a socket error, including `WouldBlock` when the readiness
    ///   signal was spurious.
    fn receive(&mut self) -> IoResult<Option<ResponseAwaiter>> {
        let _g = self.span.enter();
        let (amt, src) = self.sock.try_recv_from(&mut self.receive_buf)?;
        Ok(if let Some(from) = UnixPathAddr::new(src) {
            if let Some(cmd) = self.parse(amt) {
                trace!(%from, amt, "received datagram");
                Some(self.make_query(from, cmd))
            } else {
                warn!(%from, amt, "received invalid datagram");
                None
            }
        } else {
            warn!(amt, "received datagram from invalid address");
            None
        })
    }
}
impl ClientSocket {
fn new_path() -> PathBuf {
env::temp_dir().join(format!(
"cnsprcy.{}.sock",
BASE64URL_NOPAD.encode(&rand::rng().random::<[u8; 8]>())
))
}
pub fn connect(server_path: &Path) -> Option<Self> {
let path = Self::new_path();
let span = info_span!("client_socket", ?path, ?server_path);
let _g = span.enter();
if path.exists() {
warn!(?path, "removing existing socket");
std::fs::remove_file(&path)
.map_err(|err| error!(%err, "failed to remove existing socket"))
.ok()?;
}
let sock = net::UnixDatagram::bind(&path)
.map_err(|err| {
error!(?path, %err, "failed to open socket");
let _ = std::fs::remove_file(&path);
})
.ok()?;
drop(_g);
sock
.connect(server_path)
.map_err(|socket_error|
error!(:span, %socket_error, "failed to connect to server")
)
.map(|()| Self {sock, path, span})
.ok()
}
pub async fn send(&self, cmd: &Command) -> Option<usize> {
let data = serde_json::to_vec(cmd)
.map_err(|error| error!(@self, %error, ?cmd, "JSON encoding error"))
.ok()?;
self.sock.send(&data).await
.inspect(|amt| trace!(@self, amt, "sent command"))
.map_err(|error| error!(@self, %error, ?cmd, "failed to send"))
.ok()
}
pub async fn receive(&self) -> Option<Response> {
let mut buf: [u8; BUFFER_SIZE] = [0; BUFFER_SIZE];
let amt = self.sock
.recv(&mut buf)
.timeout(CLIENT_RECEIVE_TIMEOUT)
.in_current_span()
.await
.map_err(|timeout| error!(@self, %timeout, "no response received"))
.ok()?
.map_err(|io_error| error!(@self, %io_error, "socket IO error"))
.ok()?;
serde_json::from_slice(&buf[..amt])
.map_err(|error| error!(@self, %error, "JSON decoding error"))
.ok()
}
}