#[cfg(feature = "bincode")]
use bincode;
#[cfg(not(feature = "bincode"))]
use serde_json;
use log::{error, info, trace};
use serde::{Deserialize, Serialize};
use std::io::{BufReader, Read, Write};
use std::marker::PhantomData;
use std::os::unix::net::{UnixListener, UnixStream};
use std::path::PathBuf;
use std::sync::mpsc::{self, Sender};
pub use log;
/// Encodes `value` with the wire codec chosen at compile time.
///
/// With the `bincode` feature enabled this uses bincode; otherwise it uses JSON via
/// `serde_json`. Any codec error is flattened to its `Display` text.
fn serialize_to_vec<T: Serialize>(value: &T) -> Result<Vec<u8>, String> {
    #[cfg(feature = "bincode")]
    let encoded = bincode::serialize(value);
    #[cfg(not(feature = "bincode"))]
    let encoded = serde_json::to_vec(value);

    encoded.map_err(|err| err.to_string())
}
/// Decodes a `T` from `bytes` using the same compile-time codec as `serialize_to_vec`.
///
/// With the `bincode` feature this uses bincode; otherwise it uses JSON via
/// `serde_json`. Any codec error is flattened to its `Display` text.
fn deserialize_from_slice<'de, T>(bytes: &'de [u8]) -> Result<T, String>
where
    T: Deserialize<'de>,
{
    #[cfg(feature = "bincode")]
    let decoded: Result<T, _> = bincode::deserialize(bytes);
    #[cfg(not(feature = "bincode"))]
    let decoded: Result<T, _> = serde_json::from_slice(bytes);

    decoded.map_err(|err| err.to_string())
}
/// One frame of the server-to-client reply stream.
///
/// The server sends one `Data` frame per value pushed into the handler's `Sender`,
/// then a single `EndOfStream` frame once its response loop finishes. The client
/// (`send_command`) stops reading when it receives `EndOfStream`.
///
/// NOTE(review): with the `bincode` feature the variant order presumably determines
/// the encoded variant index, so reordering variants would likely break wire
/// compatibility between client and server builds.
#[derive(Serialize, Deserialize, Debug)]
pub enum StreamResponse<T> {
    /// A single response value.
    Data(T),
    /// Marks the end of the response stream; no further frames follow.
    EndOfStream,
}
pub(crate) fn process<Res, Req>(mut stream: UnixStream) -> Option<(Req, Sender<Res>)>
where
Req: for<'de> Deserialize<'de> + Send + 'static + std::fmt::Debug,
Res: Serialize + Send + 'static + std::fmt::Debug,
{
let mut len_buf = [0u8; 4];
if stream.read_exact(&mut len_buf).is_err() {
error!("Failed to read request length");
return None;
}
trace!("Length bytes: {:?}", len_buf);
let req_len = u32::from_le_bytes(len_buf) as usize;
let mut buf = vec![0u8; req_len];
if stream.read_exact(&mut buf).is_err() {
error!("Failed to read request data");
return None;
}
if let Ok(req) = deserialize_from_slice::<Req>(&buf) {
info!("Received request: {:?}", req);
let (tx, rx) = mpsc::channel();
std::thread::spawn(move || {
for response in rx {
trace!("Sending response: {:?}", response);
match serialize_to_vec(&StreamResponse::Data(response)) {
Ok(resp_buf) => {
let len_bytes = (resp_buf.len() as u32).to_le_bytes();
if stream.write_all(&len_bytes).is_err() {
error!("Failed to send response length");
break;
}
if stream.write_all(&resp_buf).is_err() {
error!("Failed to send response data");
break;
}
if stream.flush().is_err() {
error!("Failed to flush stream");
break;
}
}
Err(e) => {
error!("Failed to serialize response: {}", e);
break;
}
}
}
match serialize_to_vec(&StreamResponse::<Res>::EndOfStream) {
Ok(end_buf) => {
let len_bytes = (end_buf.len() as u32).to_le_bytes();
if stream.write_all(&len_bytes).is_err() {
error!("Failed to send EndOfStream length");
}
if stream.write_all(&end_buf).is_err() {
error!("Failed to send EndOfStream data");
}
let _ = stream.flush();
info!("Stream ended successfully");
}
Err(e) => error!("Failed to serialize EndOfStream: {}", e),
}
});
Some((req, tx))
} else {
error!("Failed to deserialize request");
None
}
}
/// Runs a blocking request server on the Unix socket `/tmp/<socket_name>.sock`.
///
/// Any existing file at that path is removed first (errors ignored), then the socket
/// is bound. Each accepted connection's request is read and decoded by `process`. On
/// success, a clone of `handler` is called with the request and the `Sender` used to
/// stream responses back. Accept failures are logged and skipped.
///
/// NOTE(review): `handler` runs on the accept loop's thread, so the next connection
/// is not accepted until it returns. Handlers doing long work may want to spawn a
/// thread themselves.
///
/// # Errors
/// Returns an error only if binding the socket fails. Otherwise the accept loop never
/// ends under normal operation.
pub fn start_server<Req, Res, F>(socket_name: impl ToString, handler: F) -> std::io::Result<()>
where
    Req: for<'de> Deserialize<'de> + Send + 'static + std::fmt::Debug,
    Res: Serialize + Send + 'static + std::fmt::Debug,
    F: Fn(Req, Sender<Res>) + Send + Sync + Clone + 'static,
{
    let name = socket_name.to_string();
    let socket_path = PathBuf::from(format!("/tmp/{}.sock", name));
    let _ = std::fs::remove_file(&socket_path);

    let listener = UnixListener::bind(&socket_path)?;
    info!("Server started on {:?}", socket_path);

    for incoming in listener.incoming() {
        let conn = match incoming {
            Ok(conn) => conn,
            Err(e) => {
                error!("Failed to accept connection: {}", e);
                continue;
            }
        };
        info!("Accepted connection");
        if let Some((req, tx)) = process(conn) {
            let per_connection = handler.clone();
            per_connection(req, tx);
        }
    }
    Ok(())
}
/// An async [`futures::Stream`] of incoming requests on the Unix socket
/// `/tmp/<app>.sock`, created with `RequestStream::new`.
///
/// Each item pairs a decoded request with a `Sender` for streaming responses back
/// to that client.
pub struct RequestStream<Req, Res> {
    // The bound tokio listener. Despite the field name, this is a listener, not a
    // connected stream.
    unix_stream: tokio::net::UnixListener,
    // Type markers only: no `Req`/`Res` values are stored.
    _req: PhantomData<Req>,
    _res: PhantomData<Res>,
}
impl<Req, Res> RequestStream<Req, Res>
where
    Req: for<'de> Deserialize<'de> + Send + 'static + std::fmt::Debug,
    Res: Serialize + Send + 'static + std::fmt::Debug,
{
    /// Binds a tokio Unix listener at `/tmp/<app>.sock` and wraps it as a request
    /// stream.
    ///
    /// Any existing file at that path is removed first (errors ignored), presumably
    /// to clear a socket left by a previous run that would otherwise make `bind` fail.
    ///
    /// # Errors
    /// Returns the error from binding the listener.
    pub async fn new(app: impl ToString) -> std::io::Result<Self> {
        let name = app.to_string();
        let socket_path = PathBuf::from(format!("/tmp/{}.sock", name));
        let _ = std::fs::remove_file(&socket_path);

        let unix_stream = tokio::net::UnixListener::bind(&socket_path)?;
        info!("Server started on {:?}", socket_path);

        Ok(Self {
            unix_stream,
            _req: PhantomData,
            _res: PhantomData,
        })
    }
}
impl<Req, Res> futures::Stream for RequestStream<Req, Res>
where
    Req: for<'de> Deserialize<'de> + Send + 'static + std::fmt::Debug,
    Res: Serialize + Send + 'static + std::fmt::Debug,
{
    /// One accepted connection: its decoded request paired with the sender used to
    /// stream responses back, or the I/O error that prevented that.
    type Item = std::io::Result<(Req, Sender<Res>)>;

    /// Polls the listener for the next connection and performs the request handshake.
    ///
    /// Accept errors are yielded as `Some(Err(_))` rather than ending the stream. A
    /// transient failure (e.g. running out of file descriptors) therefore does not
    /// permanently stop the server.
    ///
    /// NOTE(review): `process` performs *blocking* reads on the accepted socket, so
    /// this blocks the executor thread until the client has sent its full request.
    /// Consider moving the handshake onto a blocking task if clients may be slow.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        let accepted = match self.unix_stream.poll_accept(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))),
            Poll::Ready(Ok((stream, _addr))) => stream,
        };

        let handshake = accepted
            .into_std()
            .and_then(|std_stream| {
                // tokio returns the socket in non-blocking mode. `process` and its
                // writer thread use blocking `read_exact`/`write_all`. Without this,
                // they fail with `WouldBlock` whenever data is not already buffered.
                std_stream.set_nonblocking(false)?;
                Ok(std_stream)
            })
            .and_then(|std_stream| {
                process(std_stream).ok_or_else(|| {
                    std::io::Error::new(std::io::ErrorKind::InvalidData, "Handshake error")
                })
            });

        Poll::Ready(Some(handshake))
    }
}
/// Convenience wrapper: binds `/tmp/<socket_name>.sock` and returns the resulting
/// [`RequestStream`].
///
/// NOTE(review): the type parameter `F` is not used by the signature or the body.
/// The compiler cannot infer it, so callers must name it explicitly (for example
/// `start_stream::<Req, Res, ()>(..)`). Removing it would break callers that pass
/// three type arguments.
///
/// # Errors
/// Returns the error from binding the listener.
pub async fn start_stream<Req, Res, F>(socket_name: impl ToString) -> std::io::Result<RequestStream<Req, Res>>
where
    Req: for<'de> Deserialize<'de> + Send + 'static + std::fmt::Debug,
    Res: Serialize + Send + 'static + std::fmt::Debug,
{
    let requests = RequestStream::<Req, Res>::new(socket_name).await?;
    Ok(requests)
}
/// Sends `command` to the server on `/tmp/<app>.sock` and consumes its reply stream.
///
/// The command is written as one frame: a 4-byte little-endian `u32` length, then the
/// payload. The function then reads response frames until `EndOfStream`, passing each
/// `Data` payload to `handler` (if one is given).
///
/// # Errors
/// - Connecting to or writing to the socket fails.
/// - `command` fails to serialize (`InvalidInput`).
/// - The serialized command exceeds `u32::MAX` bytes (`InvalidInput`).
///
/// Failures while *reading* responses are logged and end the loop, but the function
/// still returns `Ok(())`.
pub fn send_command<App, Req, Res, H>(app: App, command: &Req, handler: Option<H>) -> std::io::Result<()>
where
    App: ToString,
    Req: Serialize,
    Res: for<'de> Deserialize<'de> + std::fmt::Debug,
    H: Fn(Res) + Send + 'static,
{
    // Serialize before connecting so a bad command neither panics nor leaves the
    // server reading a connection that will never carry a request.
    let data = serialize_to_vec(command)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
    // `as u32` alone would silently truncate and desynchronise the framing.
    if data.len() > u32::MAX as usize {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "command exceeds u32::MAX bytes",
        ));
    }
    // Lossless: bounded by the check above.
    let len_bytes = (data.len() as u32).to_le_bytes();

    let socket_path = PathBuf::from(format!("/tmp/{}.sock", app.to_string()));
    info!("Connecting to server at {:?}", socket_path);
    let mut stream = UnixStream::connect(&socket_path)?;
    stream.write_all(&len_bytes)?;
    stream.write_all(&data)?;
    stream.flush()?;
    info!("Command sent");

    let mut reader = BufReader::new(stream);
    loop {
        let mut len_buf = [0u8; 4];
        if reader.read_exact(&mut len_buf).is_err() {
            error!("Failed to read response length");
            break;
        }
        let response_len = u32::from_le_bytes(len_buf) as usize;
        let mut buf = vec![0u8; response_len];
        if reader.read_exact(&mut buf).is_err() {
            error!("Failed to read response data");
            break;
        }
        match deserialize_from_slice::<StreamResponse<Res>>(&buf) {
            Ok(StreamResponse::Data(response)) => {
                info!("Received response: {:?}", response);
                if let Some(ref handler) = handler {
                    handler(response);
                }
            }
            Ok(StreamResponse::EndOfStream) => {
                info!("End of stream received");
                break;
            }
            Err(e) => {
                error!("Failed to deserialize response: {}", e);
                break;
            }
        }
    }
    Ok(())
}