#[cfg(feature = "bincode")]
use bincode;
#[cfg(not(feature = "bincode"))]
use serde_json;
use log::{error, info, trace};
use serde::{Deserialize, Serialize};
use std::io::{BufReader, Read, Write};
use std::marker::PhantomData;
use std::os::unix::net::{UnixListener, UnixStream};
use std::path::PathBuf;
use std::sync::mpsc::{self, Sender};
pub use log;
/// Encodes `value` into a byte vector using the wire format selected at compile time.
///
/// With the `bincode` feature enabled the value is encoded with `bincode`;
/// otherwise it is encoded as JSON via `serde_json`. Any encoder error is
/// flattened into its `Display` text.
fn serialize_to_vec<T: Serialize>(value: &T) -> Result<Vec<u8>, String> {
    // Exactly one of the two bindings below survives conditional compilation.
    #[cfg(feature = "bincode")]
    let encoded = bincode::serialize(value);
    #[cfg(not(feature = "bincode"))]
    let encoded = serde_json::to_vec(value);

    encoded.map_err(|err| err.to_string())
}
/// Decodes a `T` from `bytes` using the wire format selected at compile time.
///
/// Mirrors `serialize_to_vec`: `bincode` when that feature is enabled,
/// `serde_json` otherwise. Any decoder error is flattened into its `Display`
/// text.
fn deserialize_from_slice<'de, T>(bytes: &'de [u8]) -> Result<T, String>
where
    T: Deserialize<'de>,
{
    // Exactly one of the two bindings below survives conditional compilation.
    #[cfg(feature = "bincode")]
    let decoded: Result<T, _> = bincode::deserialize(bytes);
    #[cfg(not(feature = "bincode"))]
    let decoded: Result<T, _> = serde_json::from_slice(bytes);

    decoded.map_err(|err| err.to_string())
}
/// Envelope for every frame the server writes back to a client.
///
/// The server-side writer wraps each response in `Data` and emits a single
/// `EndOfStream` frame once its response channel is drained; the client stops
/// reading when it decodes `EndOfStream`.
#[derive(Serialize, Deserialize, Debug)]
pub enum StreamResponse<T> {
/// One response payload.
Data(T),
/// Marker frame: no further responses follow on this connection.
EndOfStream,
}
pub(crate) fn handle<Res, Req>(mut stream: UnixStream) -> Option<(Req, Sender<Res>)>
where
Req: for<'de> Deserialize<'de> + Send + 'static + std::fmt::Debug,
Res: Serialize + Send + 'static + std::fmt::Debug,
{
let mut len_buf = [0u8; 4];
if stream.read_exact(&mut len_buf).is_err() {
error!("Failed to read request length");
return None;
}
let req_len = u32::from_le_bytes(len_buf) as usize;
let mut buf = vec![0u8; req_len];
if stream.read_exact(&mut buf).is_err() {
error!("Failed to read request data");
return None;
}
if let Ok(req) = deserialize_from_slice::<Req>(&buf) {
info!("Received request: {:?}", req);
let (tx, rx) = mpsc::channel();
std::thread::spawn(move || {
for response in rx {
trace!("Sending response: {:?}", response);
match serialize_to_vec(&StreamResponse::Data(response)) {
Ok(resp_buf) => {
let len_bytes = (resp_buf.len() as u32).to_le_bytes();
if stream.write_all(&len_bytes).is_err() {
error!("Failed to send response length");
break;
}
if stream.write_all(&resp_buf).is_err() {
error!("Failed to send response data");
break;
}
if stream.flush().is_err() {
error!("Failed to flush stream");
break;
}
}
Err(e) => {
error!("Failed to serialize response: {}", e);
break;
}
}
}
match serialize_to_vec(&StreamResponse::<Res>::EndOfStream) {
Ok(end_buf) => {
let len_bytes = (end_buf.len() as u32).to_le_bytes();
if stream.write_all(&len_bytes).is_err() {
error!("Failed to send EndOfStream length");
}
if stream.write_all(&end_buf).is_err() {
error!("Failed to send EndOfStream data");
}
let _ = stream.flush();
info!("Stream ended successfully");
}
Err(e) => error!("Failed to serialize EndOfStream: {}", e),
}
});
Some((req, tx))
} else {
error!("Failed to deserialize request");
None
}
}
pub fn start_server<Req, Res, F>(socket_name: impl ToString, handler: F) -> std::io::Result<()>
where
Req: for<'de> Deserialize<'de> + Send + 'static + std::fmt::Debug,
Res: Serialize + Send + 'static + std::fmt::Debug,
F: Fn(Req, Sender<Res>) + Send + Sync + Clone + 'static,
{
let socket_path = PathBuf::from(format!("/tmp/{}.sock", socket_name.to_string()));
let _ = std::fs::remove_file(&socket_path);
let listener = UnixListener::bind(&socket_path)?;
info!("Server started on {:?}", socket_path);
for stream in listener.incoming() {
match stream {
Ok(stream) => {
info!("Accepted connection");
let (req, tx) = handle(stream).ok_or(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Handshake error",
))?;
(handler.clone())(req, tx);
}
Err(e) => {
error!("Failed to accept connection: {}", e);
continue;
}
}
}
Ok(())
}
/// Asynchronous source of incoming requests, backed by a tokio Unix listener.
///
/// `Req` and `Res` are not stored; they only fix the request and response
/// types of the items this stream produces.
pub struct RequestStream<Req, Res> {
// Despite the name, this is the listening socket that connections are accepted from.
unix_stream: tokio::net::UnixListener,
// Zero-sized markers tying the otherwise unused type parameters to the struct.
_req: PhantomData<Req>,
_res: PhantomData<Res>,
}
impl<Req, Res> RequestStream<Req, Res>
where
    Req: for<'de> Deserialize<'de> + Send + 'static + std::fmt::Debug,
    Res: Serialize + Send + 'static + std::fmt::Debug,
{
    /// Binds a tokio Unix listener at `/tmp/<app>.sock` and wraps it in a `RequestStream`.
    ///
    /// Any existing file at that path is removed first; a failure to remove it
    /// is ignored. Returns the bind error if the listener cannot be created.
    pub async fn new(app: impl ToString) -> std::io::Result<Self> {
        let app_name = app.to_string();
        let socket_path = PathBuf::from(format!("/tmp/{}.sock", app_name));

        // Outcome deliberately discarded: the file may simply not exist.
        let _ = std::fs::remove_file(&socket_path);

        tokio::net::UnixListener::bind(&socket_path).map(|unix_stream| {
            info!("Server started on {:?}", socket_path);
            Self {
                unix_stream,
                _req: PhantomData,
                _res: PhantomData,
            }
        })
    }
}
/// Yields one `(request, response sender)` pair per accepted connection.
///
/// The stream never ends on its own: accept failures, conversion failures and
/// failed handshakes are all surfaced as `Some(Err(_))` items so the consumer
/// can decide whether to keep polling.
impl<Req, Res> futures::Stream for RequestStream<Req, Res>
where
    Req: for<'de> Deserialize<'de> + Send + 'static + std::fmt::Debug,
    Res: Serialize + Send + 'static + std::fmt::Debug,
{
    type Item = std::io::Result<(Req, Sender<Res>)>;

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        let accepted = match self.unix_stream.poll_accept(cx) {
            Poll::Ready(result) => result,
            Poll::Pending => return Poll::Pending,
        };

        let item = accepted
            // Propagate a failed conversion as an error item instead of panicking.
            .and_then(|(stream, _addr)| stream.into_std())
            .and_then(|stream| {
                // tokio hands back the std stream in non-blocking mode, but
                // `handle` and its writer thread use blocking reads/writes;
                // without this they can fail spuriously with `WouldBlock`.
                stream.set_nonblocking(false)?;
                // NOTE(review): the handshake read below blocks the polling
                // thread until the client has sent its request frame.
                handle::<Res, Req>(stream).ok_or_else(|| {
                    std::io::Error::new(std::io::ErrorKind::InvalidData, "Handshake error")
                })
            });

        Poll::Ready(Some(item))
    }
}
/// Creates a `RequestStream` listening on `/tmp/<socket_name>.sock`.
///
/// Thin wrapper around `RequestStream::new`. The type parameter `F` is not
/// used by the signature or the body; it is kept so existing call sites that
/// name three type arguments keep compiling.
pub async fn start_stream<Req, Res, F>(
    socket_name: impl ToString,
) -> std::io::Result<RequestStream<Req, Res>>
where
    Req: for<'de> Deserialize<'de> + Send + 'static + std::fmt::Debug,
    Res: Serialize + Send + 'static + std::fmt::Debug,
{
    let request_stream = RequestStream::<Req, Res>::new(socket_name).await?;
    Ok(request_stream)
}
pub fn send_command<App, Req, Res, H>(
app: App,
command: &Req,
handler: Option<H>,
) -> std::io::Result<()>
where
App: ToString,
Req: Serialize,
Res: for<'de> Deserialize<'de> + std::fmt::Debug, H: Fn(Res) + Send + 'static,
{
let socket_path = PathBuf::from(format!("/tmp/{}.sock", app.to_string()));
info!("Connecting to server at {:?}", socket_path);
let mut stream = UnixStream::connect(&socket_path)?;
let data = serialize_to_vec(command).expect("Serialization failed");
let len_bytes = (data.len() as u32).to_le_bytes();
stream.write_all(&len_bytes)?;
stream.write_all(&data)?;
stream.flush()?;
info!("Command sent");
let mut reader = BufReader::new(stream);
loop {
let mut len_buf = [0u8; 4];
if let Err(_) = reader.read_exact(&mut len_buf) {
error!("Failed to read response length");
break;
}
let response_len = u32::from_le_bytes(len_buf) as usize;
let mut buf = vec![0u8; response_len];
if let Err(_) = reader.read_exact(&mut buf) {
error!("Failed to read response data");
break;
}
match deserialize_from_slice::<StreamResponse<Res>>(&buf) {
Ok(StreamResponse::Data(response)) => {
info!("Received response: {:?}", response);
if let Some(ref handler) = handler {
handler(response);
}
}
Ok(StreamResponse::EndOfStream) => {
info!("End of stream received");
break;
}
Err(e) => {
error!("Failed to deserialize response: {}", e);
break;
}
}
}
Ok(())
}