use std::path::Path;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UnixListener;
use tokio::net::UnixStream;
/// Serializes `value` as JSON and writes the resulting bytes to `stream`.
///
/// No length prefix or delimiter is written: the bytes on the wire are exactly
/// the JSON document.
///
/// # Errors
///
/// Returns an [`std::io::ErrorKind::InvalidInput`] error if `value` cannot be
/// serialized to JSON, and propagates any I/O error from writing to `stream`.
pub async fn write_json<T>(stream: &mut UnixStream, value: &T) -> std::io::Result<()>
where
    T: serde::Serialize,
{
    // Serialization is fallible (e.g. map keys serde_json cannot turn into
    // strings, or a `Serialize` impl that reports an error); surface that as an
    // I/O error instead of panicking, consistent with `read_json`.
    let bytes = serde_json::to_vec(value)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
    stream.write_all(&bytes).await
}
/// Reads one JSON document from `stream` into `buffer` and deserializes it as `T`.
///
/// `buffer` is cleared first and, on return, holds the raw bytes that were read;
/// `T` may borrow from it. Reading continues until the buffered bytes form a
/// complete JSON document, the peer closes the connection, or the bytes are
/// found to be invalid JSON. A single read is not enough: a stream socket may
/// split a message across reads, and one `read_buf` call only fills the
/// vector's spare capacity.
///
/// # Errors
///
/// Propagates I/O errors from reading `stream`. Returns an
/// [`std::io::ErrorKind::InvalidInput`] error if the received bytes are not
/// valid JSON for `T` (including when the peer closes mid-document).
pub async fn read_json<'a, T>(
    stream: &mut UnixStream,
    buffer: &'a mut Vec<u8>,
) -> std::io::Result<T>
where
    T: serde::Deserialize<'a>,
{
    buffer.clear();

    loop {
        // `read_buf` writes only into spare capacity, so guarantee some room.
        buffer.reserve(4096);
        let n = stream.read_buf(buffer).await?;
        if n == 0 {
            // Peer closed the connection: parse whatever has arrived.
            break;
        }

        // Check completeness without committing to `T` (which would borrow
        // `buffer` for `'a` and block further reads). Only an "unexpected end
        // of input" error means more bytes may still arrive.
        match serde_json::from_slice::<serde::de::IgnoredAny>(buffer) {
            Err(e) if e.is_eof() => {}
            _ => break,
        }
    }

    serde_json::from_slice(buffer)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
}
/// Returns `Err` holding an I/O error of kind `Other` whose message is `msg`.
fn other_error<T>(msg: String) -> std::io::Result<T> {
    let error = std::io::Error::new(std::io::ErrorKind::Other, msg);
    Err(error)
}
/// Binds a [`UnixListener`] at `path`, first removing an existing socket file
/// found there.
///
/// # Errors
///
/// - `path` exists but is not a socket (nothing is removed in that case);
/// - the metadata lookup or the removal of the existing socket fails (the
///   underlying error is returned as-is);
/// - binding fails. This is reported with kind `Other`, with a dedicated message
///   when the cause is a missing parent directory.
pub fn create_unix_socket(path: &Path) -> std::io::Result<UnixListener> {
    if path.exists() {
        use std::os::unix::fs::FileTypeExt;

        let meta = std::fs::metadata(path)?;
        if !meta.file_type().is_socket() {
            return other_error(format!("path {path:?} exists but is not a socket"));
        }
        // Presumably a leftover from an earlier run; binding would fail while it exists.
        std::fs::remove_file(path)?;
    }

    let error = match UnixListener::bind(path) {
        Ok(listener) => return Ok(listener),
        Err(e) => e,
    };

    // For a bare file name such as "observe.sock", `parent()` is `Some("")`,
    // and `Path::new("").exists()` is false even though it denotes the current
    // directory. Only a non-empty parent can genuinely be missing.
    let parent_missing = matches!(
        path.parent(),
        Some(parent) if !parent.as_os_str().is_empty() && !parent.exists()
    );
    if parent_missing {
        return other_error(format!(
            "Could not create observe socket at {path:?} because its parent directory does not exist"
        ));
    }

    other_error(format!(
        "Could not create observe socket at {path:?}: {error:?}"
    ))
}
#[cfg(test)]
mod tests {
    use tokio::net::UnixListener;

    use super::*;

    /// Sets up a connected pair of Unix stream sockets and returns
    /// `(writer, reader)`: `writer` is the connecting end, `reader` the
    /// accepted end.
    ///
    /// The socket path lives in the temp directory and includes the process id,
    /// so concurrent test runs (e.g. from two checkouts) cannot delete each
    /// other's sockets. The socket file is removed once the connection is
    /// accepted; the established connection does not need it.
    async fn connected_pair(name: &str) -> (UnixStream, UnixStream) {
        let path = std::env::temp_dir().join(format!("{name}-{}", std::process::id()));
        if path.exists() {
            std::fs::remove_file(&path).unwrap();
        }

        let listener = UnixListener::bind(&path).unwrap();
        let writer = UnixStream::connect(&path).await.unwrap();
        let (reader, _) = listener.accept().await.unwrap();
        std::fs::remove_file(&path).unwrap();

        (writer, reader)
    }

    #[tokio::test]
    async fn write_then_read_is_identity() {
        let (mut writer, mut reader) = connected_pair("ntp-test-stream-1").await;

        let object = vec![0usize, 10];
        write_json(&mut writer, &object).await.unwrap();

        let mut buf = Vec::new();
        let output = read_json::<Vec<usize>>(&mut reader, &mut buf)
            .await
            .unwrap();

        assert_eq!(object, output);
        assert!(!buf.is_empty());
    }

    #[tokio::test]
    async fn invalid_input_is_io_error() {
        let (mut writer, mut reader) = connected_pair("ntp-test-stream-5").await;

        // NUL bytes are not valid JSON, so decoding must fail with InvalidInput.
        let data = [0; 24];
        writer.write_all(&data).await.unwrap();

        let mut buf = Vec::new();
        let output = read_json::<Vec<usize>>(&mut reader, &mut buf)
            .await
            .unwrap_err();

        assert_eq!(output.kind(), std::io::ErrorKind::InvalidInput);
        assert!(!buf.is_empty());
    }
}