// Private submodules. Those carrying a `#[path]` attribute have their source
// files under the `socket-utils/` directory; the hyphen in that directory name
// cannot appear in a Rust module identifier, so each file is mapped explicitly.
#[path = "socket-utils/connect_socket.rs"]
mod connect_socket;
#[path = "socket-utils/event_handler.rs"]
mod event_handler;
#[path = "socket-utils/event_sender.rs"]
mod event_sender;
mod logger;
mod os;
mod socket;
#[path = "socket-utils/socket_async_read.rs"]
mod socket_async_read;
#[path = "socket-utils/socket_async_write.rs"]
mod socket_async_write;
mod target;
// Public API: the modules above are private, so these re-exports are the only
// way callers outside the crate can name the items below.
pub use os::{OsDetectionError, OsKind};
pub use socket::{socket_init, DetectionError, Socket, SocketError};
pub use target::{TargetDetectionError, TargetKind};
#[cfg(test)]
mod tests {
use std::env;
use super::*;
/// Serializes every environment mutation performed through [`with_env`].
///
/// Environment variables are process-global while the test harness runs tests on
/// parallel threads; without this lock two tests could overwrite each other's
/// `target`/`os` values between the moment they are set and the moment they are read.
static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Applies `vars` to the process environment and returns a guard that undoes the change.
///
/// Each entry is `(name, value)`: `Some(v)` sets the variable to `v`, `None` removes it.
/// The returned [`EnvGuard`] keeps [`ENV_LOCK`] held until it is dropped, so tests that
/// use this helper never observe one another's environment.
///
/// The lock is not reentrant: calling `with_env` again on the same thread while a
/// previous guard is still alive will block or panic. Drop the first guard before
/// creating another.
fn with_env(vars: &[(&str, Option<&str>)]) -> EnvGuard {
    // A test that panicked while holding the lock poisons it; the protected data is `()`,
    // so it is always safe to keep going instead of failing every later test.
    let lock = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);

    // Snapshot the current values *before* touching anything so `Drop` can restore them.
    let saved: Vec<(String, Option<std::ffi::OsString>)> = vars
        .iter()
        .map(|(key, _)| ((*key).to_string(), env::var_os(key)))
        .collect();

    for (key, value) in vars {
        match value {
            // SAFETY: every environment mutation made through this helper happens while
            // `ENV_LOCK` is held, so helper users never race with each other.
            Some(v) => unsafe { env::set_var(key, v) },
            // SAFETY: as above.
            None => unsafe { env::remove_var(key) },
        }
    }

    EnvGuard { saved, _lock: lock }
}

/// RAII guard returned by [`with_env`].
///
/// On drop it puts every touched variable back to the value it had when the guard was
/// created (removing it if it was absent) and only then releases [`ENV_LOCK`].
struct EnvGuard {
    /// Variable names paired with their value from before `with_env` modified them.
    saved: Vec<(String, Option<std::ffi::OsString>)>,
    /// Held for the guard's whole lifetime. Fields are dropped after `Drop::drop` runs,
    /// so the lock is released only once the environment has been restored.
    _lock: std::sync::MutexGuard<'static, ()>,
}

impl Drop for EnvGuard {
    fn drop(&mut self) {
        for (key, previous) in &self.saved {
            match previous {
                // SAFETY: `ENV_LOCK` is still held via `self._lock`.
                Some(value) => unsafe { env::set_var(key, value) },
                // SAFETY: as above.
                None => unsafe { env::remove_var(key) },
            }
        }
    }
}
/// With neither `target` nor `os` set, initialising a socket from a plain filesystem
/// path (not a URL) must fail with `SocketError::Detection`.
#[test]
fn connect_non_url_with_env_unset_returns_detection_error() {
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();
    let _env = with_env(&[("target", None), ("os", None)]);

    let Err(err) = runtime.block_on(socket_init("/tmp/foo.sock")) else {
        panic!("expected Detection error, got Ok");
    };
    assert!(matches!(err, SocketError::Detection(_)), "expected Detection error");
}
/// A `ws://` URL must not be rejected with `SocketError::Detection` even when `target`
/// and `os` are both unset. Any other outcome (success or a different error) is accepted.
#[test]
fn connect_ws_url_does_not_require_env() {
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();
    let _env = with_env(&[("target", None), ("os", None)]);

    let outcome = runtime.block_on(socket_init("ws://127.0.0.1:0/"));
    let is_detection = matches!(outcome, Err(SocketError::Detection(_)));
    assert!(
        !is_detection,
        "socket_init(ws://...) should not return Detection when env is unset"
    );
}
/// With `target=local` and `os=windows` set, a local endpoint name must get past
/// detection: whatever `socket_init` returns, it must not be `SocketError::Detection`.
#[test]
fn connect_local_endpoint_with_env_set_does_not_return_detection_error() {
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();
    let _env = with_env(&[("target", Some("local")), ("os", Some("windows"))]);

    let outcome = runtime.block_on(socket_init("mypipe"));
    if let Err(SocketError::Detection(_)) = outcome {
        panic!("socket_init(local endpoint) with env set should not return Detection");
    }
}
/// The server accepts the client and immediately drops the stream. The client's message
/// handler must then be handed a `SocketError::Io` whose kind is `ConnectionReset` and
/// whose text contains "connection lost".
#[cfg(unix)]
#[test]
fn polled_unix_socket_returns_connection_lost_after_peer_drops() {
    use std::io::ErrorKind;
    use tokio::net::UnixListener;
    use tokio::sync::mpsc;

    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();

    runtime.block_on(async {
        let sock_path = std::env::temp_dir().join("ipcez_poll_test.sock");
        // Best-effort removal of a socket file left over from an earlier run.
        let _ = std::fs::remove_file(&sock_path);
        let listener = UnixListener::bind(&sock_path).unwrap();
        let endpoint = sock_path.to_str().unwrap().to_owned();

        // Peer side: accept a single connection and hang up straight away.
        let peer = tokio::spawn(async move {
            let (conn, _) = listener.accept().await.unwrap();
            drop(conn);
        });

        let _env = with_env(&[("target", Some("local")), ("os", Some("linux"))]);
        let client = socket_init(&endpoint).await.unwrap();

        // Forward everything the handler receives into a channel the test can await.
        let (events_tx, mut events_rx) = mpsc::unbounded_channel();
        client.message_handler(move |event| {
            let _ = events_tx.send(event);
            async {}
        });

        peer.await.unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(15)).await;

        let event = events_rx
            .recv()
            .await
            .expect("message_handler delivers error");
        let err = event.expect_err("expected connection lost error");
        let SocketError::Io(io_err) = &err else {
            panic!("expected Io error, got {:?}", err);
        };
        assert_eq!(io_err.kind(), ErrorKind::ConnectionReset);
        assert!(
            io_err.to_string().contains("connection lost"),
            "expected 'connection lost', got {}",
            io_err
        );

        let _ = std::fs::remove_file(&sock_path);
    });
}
/// Round-trips one framed message over a Unix socket. The server reads a frame (4-byte
/// big-endian length followed by that many payload bytes), checks the payload is `ping`,
/// and answers with a `pong` frame in the same format; the client's message handler must
/// receive `pong`.
#[cfg(unix)]
#[test]
fn local_unix_framed_message_round_trip() {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::sync::mpsc;

    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();

    runtime.block_on(async {
        let sock_path = std::env::temp_dir().join("ipcez_framed_test.sock");
        // Best-effort removal of a socket file left over from an earlier run.
        let _ = std::fs::remove_file(&sock_path);
        let listener = tokio::net::UnixListener::bind(&sock_path).unwrap();
        let endpoint = sock_path.to_str().unwrap().to_owned();

        // Peer side: read one frame, validate it, reply with one frame, clean up the file.
        let peer = tokio::spawn(async move {
            let (mut conn, _) = listener.accept().await.unwrap();

            let mut header = [0u8; 4];
            conn.read_exact(&mut header).await.unwrap();
            let body_len = u32::from_be_bytes(header) as usize;
            let mut body = vec![0u8; body_len];
            conn.read_exact(&mut body).await.unwrap();
            assert_eq!(body.as_slice(), b"ping");

            let reply = b"pong";
            conn.write_all(&(reply.len() as u32).to_be_bytes())
                .await
                .unwrap();
            conn.write_all(reply).await.unwrap();
            conn.flush().await.unwrap();

            let _ = std::fs::remove_file(&sock_path);
        });

        let _env = with_env(&[("target", Some("local")), ("os", Some("linux"))]);
        let client = socket_init(&endpoint).await.unwrap();

        // Forward everything the handler receives into a channel the test can await.
        let (events_tx, mut events_rx) = mpsc::unbounded_channel();
        client.message_handler(move |event| {
            let _ = events_tx.send(event);
            async {}
        });

        client.send_message(b"ping").await.unwrap();

        let event = events_rx.recv().await.expect("one message");
        let payload = event.expect("no error");
        assert_eq!(payload.as_slice(), b"pong");

        peer.await.unwrap();
    });
}
/// The server reads one length-prefixed frame and never writes anything back, so
/// `send_message` must fail with `SocketError::RecipientAckTimeout`.
#[cfg(unix)]
#[test]
fn local_unix_send_message_timeout_when_recipient_does_not_ack() {
    use tokio::io::AsyncReadExt;

    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();

    runtime.block_on(async {
        let sock_path = std::env::temp_dir().join("ipcez_ack_timeout_test.sock");
        // Best-effort removal of a socket file left over from an earlier run.
        let _ = std::fs::remove_file(&sock_path);
        let listener = tokio::net::UnixListener::bind(&sock_path).unwrap();
        let endpoint = sock_path.to_str().unwrap().to_owned();

        // Peer side: consume one frame (4-byte big-endian length + payload), send nothing.
        let peer = tokio::spawn(async move {
            let (mut conn, _) = listener.accept().await.unwrap();

            let mut header = [0u8; 4];
            conn.read_exact(&mut header).await.unwrap();
            let body_len = u32::from_be_bytes(header) as usize;
            let mut body = vec![0u8; body_len];
            conn.read_exact(&mut body).await.unwrap();

            let _ = std::fs::remove_file(&sock_path);
        });

        let _env = with_env(&[("target", Some("local")), ("os", Some("linux"))]);
        let client = socket_init(&endpoint).await.unwrap();

        let err = client
            .send_message(b"hello")
            .await
            .expect_err("expected RecipientAckTimeout");
        assert!(
            matches!(err, SocketError::RecipientAckTimeout),
            "expected RecipientAckTimeout, got {:?}",
            err
        );

        peer.await.unwrap();
    });
}
/// With `target=local` but `os` unset, `socket_init` must fail with
/// `SocketError::Detection`, and the error text must mention `os`.
#[test]
fn connect_local_with_os_unset_returns_detection_error() {
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();
    let _env = with_env(&[("os", None), ("target", Some("local"))]);

    let Err(err) = runtime.block_on(socket_init("mypipe")) else {
        panic!("expected Detection error");
    };
    assert!(matches!(err, SocketError::Detection(_)), "expected Detection error");

    let message = err.to_string();
    assert!(
        message.contains("os"),
        "error should notify about missing/invalid 'os'; got: {}",
        err
    );
}
/// With `os=macos` and `target=local`, `socket_init` must fail with
/// `SocketError::Detection`, and the error text must contain both `os` and `linux`.
#[test]
fn connect_local_with_os_invalid_returns_detection_error() {
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();
    let _env = with_env(&[("os", Some("macos")), ("target", Some("local"))]);

    let Err(err) = runtime.block_on(socket_init("mypipe")) else {
        panic!("expected Detection error");
    };
    assert!(matches!(err, SocketError::Detection(_)));

    let message = err.to_string();
    assert!(message.contains("os") && message.contains("linux"));
}
/// With `os=windows` but `target` unset, `socket_init` must fail with
/// `SocketError::Detection`, and the error text must mention `target`.
#[test]
fn connect_local_with_target_unset_returns_detection_error() {
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();
    let _env = with_env(&[("os", Some("windows")), ("target", None)]);

    let Err(err) = runtime.block_on(socket_init("mypipe")) else {
        panic!("expected Detection error");
    };
    assert!(matches!(err, SocketError::Detection(_)));

    let message = err.to_string();
    assert!(message.contains("target"));
}
/// With `target=cloud` and `os=windows`, `socket_init` must fail with
/// `SocketError::Detection`, and the error text must contain both `target` and `local`.
#[test]
fn connect_local_with_target_invalid_returns_detection_error() {
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();
    let _env = with_env(&[("target", Some("cloud")), ("os", Some("windows"))]);

    let Err(err) = runtime.block_on(socket_init("mypipe")) else {
        panic!("expected Detection error");
    };
    assert!(matches!(err, SocketError::Detection(_)));

    let message = err.to_string();
    assert!(message.contains("target") && message.contains("local"));
}
/// With `target=local` and `os=windows`, whatever `socket_init("mypipe")` returns must
/// not be `SocketError::Detection`.
#[test]
fn connect_local_with_valid_env_does_not_return_detection() {
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();
    let _env = with_env(&[("target", Some("local")), ("os", Some("windows"))]);

    let outcome = runtime.block_on(socket_init("mypipe"));
    let is_detection = matches!(outcome, Err(SocketError::Detection(_)));
    assert!(!is_detection);
}
}