use std::os::unix::fs::PermissionsExt;
use std::path::PathBuf;
use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
use std::sync::Arc;
use tokio::net::UnixListener;
use tokio::sync::{mpsc, Semaphore};
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use crate::ipc::message::{IpcFrame, IpcResponse};
use crate::ipc::ResponseRouter;
/// Settings that control how the IPC server binds its socket and serves connections.
#[derive(Debug, Clone)]
pub struct IpcServerConfig {
    /// Filesystem path of the Unix domain socket the server binds to.
    pub socket_path: PathBuf,
    /// Identifier of the agent this runtime serves; used to look up the expected
    /// handshake key and included in log output.
    pub agent_id: String,
    /// Number of permits in the semaphore that limits concurrently served connections.
    pub max_connections: usize,
    /// Capacity passed to each connection's bounded response channel.
    pub inbound_channel_capacity: usize,
}
impl IpcServerConfig {
    /// Derives the IPC server settings from the runtime configuration.
    ///
    /// The socket path is placed under `/tmp` and embeds the agent id, the
    /// connection limit is taken from `config.ipc_max_connections`, and the
    /// channel capacity is fixed at 256.
    pub fn from_runtime_config(config: &crate::config::RuntimeConfig) -> Self {
        let socket_file = format!("/tmp/aa-runtime-{}.sock", config.agent_id);
        let socket_path = PathBuf::from(socket_file);
        let agent_id = config.agent_id.clone();
        let max_connections = config.ipc_max_connections;

        Self {
            socket_path,
            agent_id,
            max_connections,
            inbound_channel_capacity: 256,
        }
    }
}
/// A Unix-domain-socket IPC server that has been bound to its socket path.
///
/// Constructed by `IpcServer::bind` and consumed by `IpcServer::run`, which
/// drives the accept loop.
pub struct IpcServer {
/// Configuration the server was bound with.
config: IpcServerConfig,
/// Listening socket created by `bind`.
listener: UnixListener,
}
impl IpcServer {
    /// Pause inserted after a failed `accept()` so that a persistent error
    /// (for example, the process running out of file descriptors) cannot turn
    /// the accept loop into a busy loop that floods the log.
    const ACCEPT_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_millis(50);

    /// Binds the Unix domain socket at `config.socket_path`.
    ///
    /// Whatever already exists at that path is removed first. The socket is
    /// created under a `0o077` umask and then explicitly set to mode `0o600`.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the stale file cannot be removed
    /// (other than because it does not exist), if binding fails, or if the
    /// permissions cannot be set.
    pub fn bind(config: IpcServerConfig) -> std::io::Result<Self> {
        let path = &config.socket_path;

        // Remove unconditionally instead of `exists()` + `remove_file()`: the
        // two-step form is racy and `exists()` reports `false` for a dangling
        // symlink, which would then make the bind fail.
        match std::fs::remove_file(path) {
            Ok(()) => tracing::info!(path = %path.display(), "removed stale socket file"),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => return Err(e),
        }

        let listener = {
            // SAFETY: `umask` only swaps the process file-mode creation mask and
            // has no memory-safety preconditions. The mask is process-wide, so
            // it is restored immediately after the bind call below.
            let prev_umask = unsafe { libc::umask(0o077) };
            let result = UnixListener::bind(path);
            // SAFETY: see above; this restores the mask that was in effect before.
            unsafe { libc::umask(prev_umask) };
            result?
        };

        // Enforce owner-only access explicitly, independent of the umask window above.
        std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?;

        debug_assert!(
            config.agent_id.is_empty() || path.to_string_lossy().contains(config.agent_id.as_str()),
            "IPC socket path must contain the configured agent id"
        );
        tracing::info!(
            path = %path.display(),
            agent_id = %config.agent_id,
            max_connections = config.max_connections,
            "IPC server bound"
        );
        Ok(Self { config, listener })
    }

    /// Runs the accept loop until `token` is cancelled, then removes the socket file.
    ///
    /// For every accepted stream the peer credentials are checked against the
    /// runtime UID and a connection permit is acquired; streams that fail
    /// either step are dropped. Accepted connections are handed to
    /// `spawn_connection`, which runs on `tracker`.
    ///
    /// * `tracker` - task tracker on which per-connection tasks are spawned.
    /// * `token` - cancelling it stops the accept loop; each connection gets a child token.
    /// * `inbound_tx` - channel on which decoded frames are delivered, tagged with the connection id.
    /// * `active_connections` - counter forwarded to each connection.
    /// * `response_router` - shared map forwarded to each connection.
    /// * `verified_identities` - shared store forwarded to each connection.
    pub async fn run(
        self,
        tracker: TaskTracker,
        token: CancellationToken,
        inbound_tx: mpsc::Sender<(u64, IpcFrame)>,
        active_connections: Arc<AtomicI64>,
        response_router: ResponseRouter,
        verified_identities: crate::ipc::VerifiedIdentityStore,
    ) {
        let semaphore = Arc::new(Semaphore::new(self.config.max_connections));
        let listener = self.listener;
        let socket_path = self.config.socket_path.clone();
        let inbound_channel_capacity = self.config.inbound_channel_capacity;
        let max_connections = self.config.max_connections;
        let next_conn_id = Arc::new(AtomicU64::new(0));
        let runtime_uid = crate::ipc::peercred::current_runtime_uid();
        let expected_key = crate::ipc::handshake::expected_verifying_key(&self.config.agent_id);

        tracing::info!("IPC server accept loop started");
        loop {
            tokio::select! {
                _ = token.cancelled() => {
                    tracing::info!("IPC server shutting down — cancellation received");
                    break;
                }
                result = listener.accept() => {
                    match result {
                        Err(e) => {
                            tracing::error!(error = %e, "accept error");
                            // Back off briefly before retrying, but stay
                            // responsive to shutdown while waiting.
                            tokio::select! {
                                _ = token.cancelled() => {}
                                _ = tokio::time::sleep(Self::ACCEPT_ERROR_BACKOFF) => {}
                            }
                            continue;
                        }
                        Ok((stream, _addr)) => {
                            // Only peers whose UID passes the peercred check may proceed.
                            match stream.peer_cred() {
                                Ok(cred) => {
                                    let peer_uid = cred.uid();
                                    if !crate::ipc::peercred::peer_uid_is_allowed(peer_uid, runtime_uid) {
                                        tracing::warn!(
                                            peer_uid,
                                            runtime_uid,
                                            "rejecting IPC connection — peer UID does not match runtime UID"
                                        );
                                        drop(stream);
                                        continue;
                                    }
                                }
                                Err(e) => {
                                    tracing::warn!(
                                        error = %e,
                                        "rejecting IPC connection — could not read peer credentials"
                                    );
                                    drop(stream);
                                    continue;
                                }
                            }

                            // The permit travels with the connection and is
                            // released when the connection task drops it.
                            let permit = match Arc::clone(&semaphore).try_acquire_owned() {
                                Ok(p) => p,
                                Err(_) => {
                                    tracing::warn!(
                                        max = max_connections,
                                        "connection limit reached — dropping new connection"
                                    );
                                    drop(stream);
                                    continue;
                                }
                            };

                            let connection_id = next_conn_id.fetch_add(1, Ordering::Relaxed);
                            let frame_tx = inbound_tx.clone();
                            let conn_token = token.child_token();
                            let conn_router = Arc::clone(&response_router);
                            let conn_verified = Arc::clone(&verified_identities);
                            let conn_active = Arc::clone(&active_connections);
                            spawn_connection(
                                &tracker,
                                stream,
                                frame_tx,
                                conn_token,
                                permit,
                                conn_active,
                                connection_id,
                                conn_router,
                                conn_verified,
                                expected_key,
                                inbound_channel_capacity,
                            );
                        }
                    }
                }
            }
        }

        if let Err(e) = std::fs::remove_file(&socket_path) {
            tracing::warn!(error = %e, "failed to remove socket file on shutdown");
        }
        tracing::info!("IPC server accept loop stopped");
    }
}
/// Maximum time a newly accepted connection is given to complete the
/// handshake before `spawn_connection` drops it.
const HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
/// Spawns the tasks that serve one accepted connection.
///
/// The spawned task first runs the handshake, bounded by `HANDSHAKE_TIMEOUT`
/// and aborted early if `token` is cancelled. On success it registers the
/// connection's response sender in `response_router`, records the verified
/// identity, increments `active_connections`, and spawns a reader task and a
/// writer task on the same tracker. On failure, timeout, or cancellation the
/// connection is dropped without registering anything.
///
/// `permit` is held by the handshake task and released when that task ends.
/// `inbound_channel_capacity` is the capacity of the per-connection response channel.
// Every argument is per-connection state supplied by the accept loop.
#[allow(clippy::too_many_arguments)]
pub(super) fn spawn_connection(
    tracker: &TaskTracker,
    stream: tokio::net::UnixStream,
    frame_tx: mpsc::Sender<(u64, IpcFrame)>,
    token: CancellationToken,
    permit: tokio::sync::OwnedSemaphorePermit,
    active_connections: Arc<AtomicI64>,
    connection_id: u64,
    response_router: ResponseRouter,
    verified_identities: crate::ipc::VerifiedIdentityStore,
    expected_key: ed25519_dalek::VerifyingKey,
    inbound_channel_capacity: usize,
) {
    let tracker_inner = tracker.clone();
    tracker.spawn(async move {
        let _permit = permit;
        let (mut read_half, mut write_half) = stream.into_split();

        // Race the handshake against cancellation so that shutdown does not
        // have to wait out the handshake timeout for idle or stalled peers.
        let handshake_outcome = tokio::select! {
            _ = token.cancelled() => {
                tracing::debug!(connection_id, "IPC handshake aborted — cancellation received");
                return;
            }
            outcome = tokio::time::timeout(
                HANDSHAKE_TIMEOUT,
                perform_handshake(&mut read_half, &mut write_half, &expected_key),
            ) => outcome,
        };

        let verified_identity = match handshake_outcome {
            Ok(Some(identity)) => {
                tracing::debug!(connection_id, "IPC handshake succeeded");
                identity
            }
            Ok(None) => {
                tracing::warn!(connection_id, "IPC handshake failed — dropping connection");
                return;
            }
            Err(_) => {
                tracing::warn!(connection_id, "IPC handshake timed out — dropping connection");
                return;
            }
        };

        // Register the connection before bumping the counter, so an observer
        // that sees the incremented count also finds both map entries.
        let (resp_tx, resp_rx) = mpsc::channel::<IpcResponse>(inbound_channel_capacity);
        response_router.write().await.insert(connection_id, resp_tx);
        verified_identities
            .write()
            .await
            .insert(connection_id, verified_identity);
        active_connections.fetch_add(1, Ordering::Relaxed);

        let reader_token = token.clone();
        tracker_inner.spawn(async move {
            run_reader(
                read_half,
                frame_tx,
                reader_token,
                active_connections,
                connection_id,
                response_router,
                verified_identities,
            )
            .await;
        });
        tracker_inner.spawn(async move {
            run_writer(write_half, resp_rx, token).await;
        });
    });
}
/// Runs the server side of the IPC handshake on a freshly accepted connection.
///
/// A nonce is generated and sent as a `HandshakeChallenge`; the first frame
/// read back must be a `HandshakeProof` that verifies against `expected_key`.
///
/// Returns the verified identity on success and `None` on any failure (write
/// error, read error, wrong first frame, or a proof that does not verify).
/// Every failure is logged at `warn` level.
pub(super) async fn perform_handshake(
    read_half: &mut tokio::net::unix::OwnedReadHalf,
    write_half: &mut tokio::net::unix::OwnedWriteHalf,
    expected_key: &ed25519_dalek::VerifyingKey,
) -> Option<aa_security::sdk_identity::VerifiedSdkIdentity> {
    use crate::ipc::handshake;
    use aa_proto::assembly::ipc::v1::HandshakeChallenge;

    // Step 1: challenge the peer with a fresh nonce.
    let nonce = handshake::generate_nonce();
    let challenge = IpcResponse::HandshakeChallenge(HandshakeChallenge { nonce: nonce.to_vec() });
    if let Err(e) = super::codec::write_response(write_half, challenge).await {
        tracing::warn!(error = %e, "failed to send handshake challenge");
        return None;
    }

    // Step 2: the very first inbound frame must be the proof.
    let proof = match super::codec::read_frame(read_half).await {
        Ok(IpcFrame::HandshakeProof(proof)) => proof,
        Ok(other) => {
            tracing::warn!(
                frame = ?std::mem::discriminant(&other),
                "first IPC frame was not a handshake proof — rejecting unauthenticated peer"
            );
            return None;
        }
        Err(e) => {
            tracing::warn!(error = %e, "failed to read handshake proof");
            return None;
        }
    };

    // Step 3: verify the proof against the nonce and the expected key.
    let verified = handshake::verify_proof(&nonce, &proof, expected_key);
    if verified.is_none() {
        tracing::warn!("handshake proof did not verify against the expected agent key");
    }
    verified
}
pub(super) async fn run_reader(
mut stream: tokio::net::unix::OwnedReadHalf,
frame_tx: mpsc::Sender<(u64, IpcFrame)>,
token: CancellationToken,
active_connections: Arc<AtomicI64>,
connection_id: u64,
response_router: ResponseRouter,
verified_identities: crate::ipc::VerifiedIdentityStore,
) {
loop {
tokio::select! {
_ = token.cancelled() => {
tracing::debug!("reader task cancelled");
break;
}
result = super::codec::read_frame(&mut stream) => {
match result {
Ok(frame) => {
if frame_tx.send((connection_id, frame)).await.is_err() {
tracing::debug!("inbound channel closed — reader exiting");
break;
}
}
Err(super::codec::CodecError::Io(e))
if e.kind() == std::io::ErrorKind::UnexpectedEof
|| e.kind() == std::io::ErrorKind::ConnectionReset =>
{
tracing::debug!("SDK client disconnected");
break;
}
Err(e) => {
tracing::warn!(error = %e, "frame decode error — closing connection");
break;
}
}
}
}
}
response_router.write().await.remove(&connection_id);
verified_identities.write().await.remove(&connection_id);
token.cancel(); active_connections.fetch_sub(1, Ordering::Relaxed);
}
/// Writes responses received on `resp_rx` to one connection.
///
/// The loop ends when `token` is cancelled, when the response channel is
/// closed, or when a write fails. A write failure also cancels `token`, so
/// the tasks sharing this connection's token are stopped as well and the
/// connection is actually closed, as the log message states.
pub(super) async fn run_writer(
    mut stream: tokio::net::unix::OwnedWriteHalf,
    mut resp_rx: mpsc::Receiver<IpcResponse>,
    token: CancellationToken,
) {
    loop {
        tokio::select! {
            _ = token.cancelled() => {
                tracing::debug!("writer task cancelled");
                break;
            }
            maybe_resp = resp_rx.recv() => {
                let response = match maybe_resp {
                    Some(response) => response,
                    None => {
                        tracing::debug!("response channel closed — writer exiting");
                        break;
                    }
                };
                if let Err(e) = super::codec::write_response(&mut stream, response).await {
                    tracing::warn!(error = %e, "failed to write response — closing connection");
                    // Without this the reader would keep the half-dead
                    // connection registered until the peer closes its side.
                    token.cancel();
                    break;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ipc::codec::{TAG_EVENT_REPORT, TAG_HEARTBEAT, TAG_POLICY_QUERY};
use crate::ipc::message::IpcFrame;
use aa_proto::assembly::audit::v1::AuditEvent;
use aa_proto::assembly::policy::v1::CheckActionRequest;
use prost::Message;
use std::time::Duration;
use tokio::net::UnixStream;
use tokio::sync::mpsc;
/// Agent id used by the test server configuration and the test clients.
const TEST_AGENT_ID: &str = "test-agent";

/// Returns the socket path for the test called `name`, located under `/tmp`
/// and containing both the test agent id and `name`.
fn temp_socket_path(name: &str) -> std::path::PathBuf {
    let raw = format!("/tmp/aa-runtime-{TEST_AGENT_ID}-{name}.sock");
    raw.into()
}
/// Plays the client side of the IPC handshake on `stream`.
///
/// Reads the server's challenge frame, signs `nonce || sdk_version` (with an
/// empty SDK version) using an Ed25519 key whose seed is the SHA-256 digest of
/// `agent_id`, and sends the resulting proof frame. Panics on any I/O or
/// decoding failure.
async fn do_client_handshake(stream: &mut UnixStream, agent_id: &str) {
    use crate::ipc::codec::{TAG_HANDSHAKE_CHALLENGE, TAG_HANDSHAKE_PROOF};
    use aa_proto::assembly::ipc::v1::{HandshakeChallenge, HandshakeProof};
    use ed25519_dalek::Signer;
    use sha2::{Digest, Sha256};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // Challenge frame layout: tag byte, varint length, protobuf payload.
    let challenge_tag = stream.read_u8().await.expect("read challenge tag");
    assert_eq!(challenge_tag, TAG_HANDSHAKE_CHALLENGE, "expected handshake challenge first");
    let challenge_len = read_varint_stream(stream).await;
    let mut challenge_bytes = vec![0u8; challenge_len];
    stream
        .read_exact(&mut challenge_bytes)
        .await
        .expect("read challenge payload");
    let challenge = HandshakeChallenge::decode(challenge_bytes.as_ref()).expect("decode challenge");

    // Signing key derived deterministically from the agent id.
    let seed: [u8; 32] = Sha256::digest(agent_id.as_bytes()).into();
    let signing_key = ed25519_dalek::SigningKey::from_bytes(&seed);

    // The signature covers the nonce followed by the SDK version bytes.
    let sdk_version = String::new();
    let mut message = challenge.nonce.clone();
    message.extend_from_slice(sdk_version.as_bytes());
    let signature = signing_key.sign(&message);

    let proof_bytes = HandshakeProof {
        agent_did: format!("did:key:{agent_id}"),
        public_key: hex::encode(signing_key.verifying_key().to_bytes()),
        signature: signature.to_bytes().to_vec(),
        sdk_version,
    }
    .encode_to_vec();

    stream.write_u8(TAG_HANDSHAKE_PROOF).await.unwrap();
    write_varint_stream(stream, proof_bytes.len() as u64).await;
    stream.write_all(&proof_bytes).await.unwrap();
    stream.flush().await.unwrap();
}
/// Reads a varint from `stream` (7 payload bits per byte, least-significant
/// group first, high bit set on every byte except the last) and returns it as
/// `usize`. Panics if a read fails.
async fn read_varint_stream(stream: &mut UnixStream) -> usize {
    use tokio::io::AsyncReadExt;

    let mut value: u64 = 0;
    let mut shift = 0u32;
    loop {
        let byte = stream.read_u8().await.unwrap();
        value |= ((byte & 0x7F) as u64) << shift;
        let is_last = byte & 0x80 == 0;
        if is_last {
            return value as usize;
        }
        shift += 7;
    }
}
/// Writes `value` to `stream` as a varint: 7 bits per byte, least-significant
/// group first, with the high bit set on every byte except the last. Panics
/// if a write fails.
async fn write_varint_stream(stream: &mut UnixStream, mut value: u64) {
    use tokio::io::AsyncWriteExt;

    loop {
        let low_bits = (value & 0x7F) as u8;
        value >>= 7;
        let is_last = value == 0;
        let encoded = if is_last { low_bits } else { low_bits | 0x80 };
        stream.write_u8(encoded).await.unwrap();
        if is_last {
            break;
        }
    }
}
/// Connects to the test server at `path` and completes the handshake as
/// `TEST_AGENT_ID`.
///
/// Connecting is attempted up to 20 times, 10 ms apart; the function panics
/// if every attempt fails.
async fn connect_client(path: &std::path::Path) -> UnixStream {
    for _attempt in 0..20 {
        match UnixStream::connect(path).await {
            Ok(mut stream) => {
                do_client_handshake(&mut stream, TEST_AGENT_ID).await;
                return stream;
            }
            Err(_) => tokio::time::sleep(Duration::from_millis(10)).await,
        }
    }
    panic!("could not connect to test IPC server at {}", path.display());
}
/// Binds an IPC server on `socket_path` and runs its accept loop on a spawned task.
///
/// Returns the receiving end of the inbound frame channel together with the
/// response router and verified-identity store handles shared with the server.
/// Panics if binding fails.
async fn start_server(
    socket_path: std::path::PathBuf,
    token: CancellationToken,
    active_connections: Arc<AtomicI64>,
) -> (
    mpsc::Receiver<(u64, IpcFrame)>,
    crate::ipc::ResponseRouter,
    crate::ipc::VerifiedIdentityStore,
) {
    let server = IpcServer::bind(IpcServerConfig {
        socket_path,
        agent_id: "test-agent".to_string(),
        max_connections: 64,
        inbound_channel_capacity: 16,
    })
    .expect("bind failed");

    let (inbound_tx, inbound_rx) = mpsc::channel(16);
    let router = crate::ipc::new_response_router();
    let verified = crate::ipc::new_verified_identity_store();

    // The server task gets its own handles; the originals go back to the caller.
    let server_router = Arc::clone(&router);
    let server_verified = Arc::clone(&verified);
    let tracker = TaskTracker::new();
    let server_tracker = tracker.clone();
    tracker.spawn(async move {
        server
            .run(
                server_tracker,
                token,
                inbound_tx,
                active_connections,
                server_router,
                server_verified,
            )
            .await;
    });

    (inbound_rx, router, verified)
}
/// Writes one frame to `stream`: the `tag` byte, the payload length as a
/// varint, then the payload itself, followed by a flush. Panics if a write fails.
async fn write_raw_frame(stream: &mut tokio::net::unix::OwnedWriteHalf, tag: u8, payload: &[u8]) {
    use tokio::io::AsyncWriteExt;

    stream.write_u8(tag).await.unwrap();

    // Varint length prefix: 7 bits per byte, high bit marks continuation.
    let mut remaining = payload.len() as u64;
    loop {
        let low_bits = (remaining & 0x7F) as u8;
        remaining >>= 7;
        let is_last = remaining == 0;
        let encoded = if is_last { low_bits } else { low_bits | 0x80 };
        stream.write_u8(encoded).await.unwrap();
        if is_last {
            break;
        }
    }

    stream.write_all(payload).await.unwrap();
    stream.flush().await.unwrap();
}
/// `bind` must create the socket with mode 0600 at a path that contains the agent id.
#[tokio::test]
async fn bind_creates_socket_with_0600_and_agent_scoped_path() {
    let socket_path = temp_socket_path("perm-check");
    // Start from a clean slate; a missing file is fine.
    let _ = std::fs::remove_file(&socket_path);

    let _server = IpcServer::bind(IpcServerConfig {
        socket_path: socket_path.clone(),
        agent_id: "perm-check".to_string(),
        max_connections: 8,
        inbound_channel_capacity: 8,
    })
    .expect("bind failed");

    let metadata = std::fs::metadata(&socket_path).unwrap();
    let mode = metadata.permissions().mode() & 0o777;
    assert_eq!(mode, 0o600, "socket must be owner-only (0600), got {mode:o}");

    let path_text = socket_path.to_string_lossy();
    assert!(
        path_text.contains("perm-check"),
        "socket path must be scoped to the agent id"
    );

    let _ = std::fs::remove_file(&socket_path);
}
/// A heartbeat tag written by a handshaken client must surface as
/// `IpcFrame::Heartbeat` on the inbound channel.
#[tokio::test]
async fn heartbeat_frame_arrives_on_inbound_channel() {
    use tokio::io::AsyncWriteExt;

    let socket_path = temp_socket_path("heartbeat");
    let token = CancellationToken::new();
    let counter = Arc::new(AtomicI64::new(0));
    let (mut inbound, _router, _verified) =
        start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;

    let client = connect_client(&socket_path).await;
    let (_, mut writer) = client.into_split();
    writer.write_u8(TAG_HEARTBEAT).await.unwrap();
    writer.flush().await.unwrap();

    let received = tokio::time::timeout(Duration::from_secs(2), inbound.recv())
        .await
        .expect("timed out waiting for frame")
        .expect("channel closed");
    let (_conn_id, frame) = received;
    assert!(matches!(frame, IpcFrame::Heartbeat));

    token.cancel();
}
/// A policy-query frame must arrive on the inbound channel decoded, with its
/// `trace_id` intact.
#[tokio::test]
async fn policy_query_arrives_decoded_on_inbound_channel() {
    let socket_path = temp_socket_path("policy-query");
    let token = CancellationToken::new();
    let counter = Arc::new(AtomicI64::new(0));
    let (mut inbound, _router, _verified) =
        start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;

    let client = connect_client(&socket_path).await;
    let (_, mut writer) = client.into_split();

    let encoded = CheckActionRequest {
        trace_id: "trace-xyz".to_string(),
        ..Default::default()
    }
    .encode_to_vec();
    write_raw_frame(&mut writer, TAG_POLICY_QUERY, &encoded).await;

    let (_conn_id, frame) = tokio::time::timeout(Duration::from_secs(2), inbound.recv())
        .await
        .expect("timed out")
        .expect("channel closed");
    match frame {
        IpcFrame::PolicyQuery(decoded) => assert_eq!(decoded.trace_id, "trace-xyz"),
        other => panic!("expected PolicyQuery, got {other:?}"),
    }

    token.cancel();
}
/// An event-report frame must arrive on the inbound channel decoded, with its
/// `event_id` intact.
#[tokio::test]
async fn event_report_arrives_decoded_on_inbound_channel() {
    let socket_path = temp_socket_path("event-report");
    let token = CancellationToken::new();
    let counter = Arc::new(AtomicI64::new(0));
    let (mut inbound, _router, _verified) =
        start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;

    let client = connect_client(&socket_path).await;
    let (_, mut writer) = client.into_split();

    let encoded = AuditEvent {
        event_id: "evt-456".to_string(),
        ..Default::default()
    }
    .encode_to_vec();
    write_raw_frame(&mut writer, TAG_EVENT_REPORT, &encoded).await;

    let (_conn_id, frame) = tokio::time::timeout(Duration::from_secs(2), inbound.recv())
        .await
        .expect("timed out")
        .expect("channel closed");
    match frame {
        IpcFrame::EventReport(decoded) => assert_eq!(decoded.event_id, "evt-456"),
        other => panic!("expected EventReport, got {other:?}"),
    }

    token.cancel();
}
/// Several clients must be able to connect and complete the handshake while
/// staying below the configured connection limit.
#[tokio::test]
async fn concurrent_connections_up_to_limit() {
    const CONN_COUNT: usize = 5;

    let socket_path = temp_socket_path("concurrent");
    let token = CancellationToken::new();
    let counter = Arc::new(AtomicI64::new(0));
    let (_rx, _router, _verified) =
        start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;

    // `connect_client` panics on failure, so reaching the assertion means
    // every connection was established.
    let mut clients = Vec::new();
    for _ in 0..CONN_COUNT {
        let client = connect_client(&socket_path).await;
        clients.push(client);
    }
    assert_eq!(clients.len(), CONN_COUNT);

    token.cancel();
}
/// Measures the average time from writing a heartbeat to receiving it on the
/// inbound channel over 1000 iterations and requires it to stay below 1 ms.
/// Ignored by default.
#[tokio::test]
#[ignore]
async fn round_trip_latency_under_1ms() {
    use tokio::io::AsyncWriteExt;

    const ITERATIONS: u32 = 1000;

    let socket_path = temp_socket_path("latency");
    let token = CancellationToken::new();
    let counter = Arc::new(AtomicI64::new(0));
    let (mut inbound, _router, _verified) =
        start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;

    let client = connect_client(&socket_path).await;
    let (_, mut writer) = client.into_split();

    let started = std::time::Instant::now();
    for _ in 0..ITERATIONS {
        writer.write_u8(TAG_HEARTBEAT).await.unwrap();
        writer.flush().await.unwrap();
        tokio::time::timeout(Duration::from_millis(100), inbound.recv())
            .await
            .expect("timed out")
            .expect("channel closed");
    }
    let elapsed = started.elapsed();

    let avg_us = elapsed.as_micros() / ITERATIONS as u128;
    println!("Average round-trip: {avg_us} µs");
    assert!(avg_us < 1000, "average round-trip {avg_us} µs exceeded 1ms threshold");

    token.cancel();
}
/// The active-connection counter must reach the number of handshaken clients.
#[tokio::test]
async fn active_connections_increments_on_accept() {
    const CONN_COUNT: usize = 3;

    let socket_path = temp_socket_path("counter-increment");
    let token = CancellationToken::new();
    let counter = Arc::new(AtomicI64::new(0));
    let (_rx, _router, _verified) =
        start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;

    let mut clients = Vec::new();
    for _ in 0..CONN_COUNT {
        clients.push(connect_client(&socket_path).await);
    }

    // The counter is bumped by the server task, so poll for up to ~500 ms.
    let expected = CONN_COUNT as i64;
    let mut seen = 0i64;
    for _ in 0..50 {
        seen = counter.load(Ordering::Relaxed);
        if seen == expected {
            break;
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    assert_eq!(
        seen, CONN_COUNT as i64,
        "counter should equal number of accepted connections"
    );

    token.cancel();
    drop(clients);
}
/// The active-connection counter must drop back to zero once the only client
/// disconnects.
#[tokio::test]
async fn active_connections_decrements_on_disconnect() {
    let socket_path = temp_socket_path("counter-decrement");
    let token = CancellationToken::new();
    let counter = Arc::new(AtomicI64::new(0));
    let (_rx, _router, _verified) =
        start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;

    let client = connect_client(&socket_path).await;

    // Wait (up to ~500 ms) for the server to register the connection.
    let mut polls = 0;
    while polls < 50 && counter.load(Ordering::Relaxed) != 1 {
        tokio::time::sleep(Duration::from_millis(10)).await;
        polls += 1;
    }
    assert_eq!(
        counter.load(Ordering::Relaxed),
        1,
        "counter should be 1 after one connection"
    );

    drop(client);

    // Wait (up to ~1 s) for the reader task to notice the disconnect.
    let mut seen = 1i64;
    for _ in 0..100 {
        seen = counter.load(Ordering::Relaxed);
        if seen == 0 {
            break;
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    assert_eq!(seen, 0, "counter should return to 0 after client disconnects");

    token.cancel();
}
/// After one client has connected, the response router must hold exactly one entry.
#[tokio::test]
async fn response_router_has_entry_after_accept() {
    let socket_path = temp_socket_path("router-insert");
    let token = CancellationToken::new();
    let counter = Arc::new(AtomicI64::new(0));
    let (_rx, router, _verified) =
        start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;

    let _client = connect_client(&socket_path).await;

    // Wait (up to ~500 ms) for the server to register the connection.
    let mut polls = 0;
    while polls < 50 && counter.load(Ordering::Relaxed) != 1 {
        tokio::time::sleep(Duration::from_millis(10)).await;
        polls += 1;
    }

    let routes = router.read().await;
    assert_eq!(routes.len(), 1, "router should contain one entry after one connection");

    token.cancel();
}
/// The response-router entry for a connection must disappear once the client
/// disconnects.
#[tokio::test]
async fn response_router_entry_removed_after_disconnect() {
    let socket_path = temp_socket_path("router-remove");
    let token = CancellationToken::new();
    let counter = Arc::new(AtomicI64::new(0));
    let (_rx, router, _verified) =
        start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;

    let client = connect_client(&socket_path).await;

    // Wait (up to ~500 ms) for the entry to appear.
    let mut polls = 0;
    while polls < 50 && router.read().await.len() != 1 {
        tokio::time::sleep(Duration::from_millis(10)).await;
        polls += 1;
    }
    assert_eq!(router.read().await.len(), 1);

    drop(client);

    // Wait (up to ~1 s) for the reader task to remove the entry.
    let mut entries = 1usize;
    for _ in 0..100 {
        entries = router.read().await.len();
        if entries == 0 {
            break;
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    assert_eq!(
        entries, 0,
        "router entry should be removed after client disconnects"
    );

    token.cancel();
}
/// A successful handshake must leave exactly one entry in the verified-identity store.
#[tokio::test]
async fn verified_identity_recorded_after_handshake() {
    let socket_path = temp_socket_path("verified-insert");
    let token = CancellationToken::new();
    let counter = Arc::new(AtomicI64::new(0));
    let (_rx, _router, verified) =
        start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;

    let _client = connect_client(&socket_path).await;

    // Wait (up to ~500 ms) for the server to register the connection.
    let mut polls = 0;
    while polls < 50 && counter.load(Ordering::Relaxed) != 1 {
        tokio::time::sleep(Duration::from_millis(10)).await;
        polls += 1;
    }

    let identities = verified.read().await;
    assert_eq!(
        identities.len(),
        1,
        "verified-identity store should hold one entry after a handshake"
    );

    token.cancel();
}
/// Performs the handshake by hand with SDK version "2.5.0" included in the
/// signed payload, and checks that the recorded identity reports that version.
#[tokio::test]
async fn verified_identity_carries_signed_sdk_version() {
    use crate::ipc::codec::{TAG_HANDSHAKE_CHALLENGE, TAG_HANDSHAKE_PROOF};
    use aa_proto::assembly::ipc::v1::{HandshakeChallenge, HandshakeProof};
    use ed25519_dalek::Signer;
    use sha2::{Digest, Sha256};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    let socket_path = temp_socket_path("verified-version");
    let token = CancellationToken::new();
    let counter = Arc::new(AtomicI64::new(0));
    let (_rx, _router, verified) =
        start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;

    // Connect without the helper so the handshake can be customised below.
    let mut stream = {
        let mut connected = None;
        for _ in 0..20 {
            if let Ok(candidate) = UnixStream::connect(&socket_path).await {
                connected = Some(candidate);
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        connected.expect("connect failed")
    };

    // Read the challenge frame: tag, varint length, payload.
    let challenge_tag = stream.read_u8().await.unwrap();
    assert_eq!(challenge_tag, TAG_HANDSHAKE_CHALLENGE);
    let challenge_len = read_varint_stream(&mut stream).await;
    let mut challenge_bytes = vec![0u8; challenge_len];
    stream.read_exact(&mut challenge_bytes).await.unwrap();
    let challenge = HandshakeChallenge::decode(challenge_bytes.as_ref()).unwrap();

    // Sign nonce || sdk_version with the key derived from the test agent id.
    let sdk_version = "2.5.0".to_string();
    let seed: [u8; 32] = Sha256::digest(TEST_AGENT_ID.as_bytes()).into();
    let signing_key = ed25519_dalek::SigningKey::from_bytes(&seed);
    let mut message = challenge.nonce.clone();
    message.extend_from_slice(sdk_version.as_bytes());

    let proof_bytes = HandshakeProof {
        agent_did: format!("did:key:{TEST_AGENT_ID}"),
        public_key: hex::encode(signing_key.verifying_key().to_bytes()),
        signature: signing_key.sign(&message).to_bytes().to_vec(),
        sdk_version: sdk_version.clone(),
    }
    .encode_to_vec();

    stream.write_u8(TAG_HANDSHAKE_PROOF).await.unwrap();
    write_varint_stream(&mut stream, proof_bytes.len() as u64).await;
    stream.write_all(&proof_bytes).await.unwrap();
    stream.flush().await.unwrap();

    // Wait (up to ~500 ms) for the server to register the connection.
    let mut polls = 0;
    while polls < 50 && counter.load(Ordering::Relaxed) != 1 {
        tokio::time::sleep(Duration::from_millis(10)).await;
        polls += 1;
    }

    let identities = verified.read().await;
    let identity = identities
        .values()
        .next()
        .expect("a verified identity must be recorded");
    assert_eq!(
        identity.version.as_deref(),
        Some("2.5.0"),
        "the signed SDK version must be recorded as the verified reference"
    );

    token.cancel();
}
/// The verified-identity entry for a connection must disappear once the
/// client disconnects.
#[tokio::test]
async fn verified_identity_removed_after_disconnect() {
    let socket_path = temp_socket_path("verified-remove");
    let token = CancellationToken::new();
    let counter = Arc::new(AtomicI64::new(0));
    let (_rx, _router, verified) =
        start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;

    let client = connect_client(&socket_path).await;

    // Wait (up to ~500 ms) for the entry to appear.
    let mut polls = 0;
    while polls < 50 && verified.read().await.len() != 1 {
        tokio::time::sleep(Duration::from_millis(10)).await;
        polls += 1;
    }
    assert_eq!(verified.read().await.len(), 1);

    drop(client);

    // Wait (up to ~1 s) for the reader task to remove the entry.
    let mut entries = 1usize;
    for _ in 0..100 {
        entries = verified.read().await.len();
        if entries == 0 {
            break;
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    assert_eq!(
        entries, 0,
        "verified-identity entry should be removed after client disconnects"
    );

    token.cancel();
}
#[tokio::test]
async fn violation_event_triggers_alert_within_100ms() {
use crate::ipc::codec::{TAG_EVENT_REPORT, TAG_VIOLATION_ALERT};
use crate::pipeline::{PipelineConfig, PipelineMetrics};
use aa_proto::assembly::audit::v1::{audit_event::Detail, PolicyViolation};
use prost::Message;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let socket_path = temp_socket_path("violation-alert");
let token = CancellationToken::new();
let counter = Arc::new(AtomicI64::new(0));
let (inbound_rx, router, verified) =
start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;
let pipeline_config = PipelineConfig {
input_buffer: 64,
batch_size: 100,
flush_interval: std::time::Duration::from_secs(60),
broadcast_capacity: 64,
agent_id: "test-agent".to_string(),
enforcement: crate::pipeline::enforcement::EnforcementConfig::default(),
gateway_fail_closed: true,
min_sdk_version: None,
};
let pipeline_metrics = Arc::new(PipelineMetrics::default());
let (broadcast_tx, _broadcast_rx) = tokio::sync::broadcast::channel::<crate::pipeline::PipelineEvent>(64);
let pipeline_router = Arc::clone(&router);
let pipeline_token = token.clone();
tokio::spawn(crate::pipeline::run(
inbound_rx,
broadcast_tx,
pipeline_config,
pipeline_metrics,
pipeline_token,
Arc::new(crate::policy::PolicyRules::default()),
pipeline_router,
crate::approval::ApprovalQueue::new(),
None,
crate::op_control::OpControlStore::new(),
Arc::new(std::sync::atomic::AtomicU64::new(0)),
Arc::clone(&verified),
));
let client = connect_client(&socket_path).await;
let (mut read_half, mut write_half) = client.into_split();
for _ in 0..50 {
if counter.load(Ordering::Relaxed) == 1 {
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
let violation = PolicyViolation {
policy_rule: "test-rule".to_string(),
blocked_action: "FILE_OPERATION".to_string(),
reason: "blocked".to_string(),
latency_ms: 0,
};
let event = AuditEvent {
detail: Some(Detail::Violation(violation)),
..Default::default()
};
let payload = event.encode_to_vec();
write_half.write_u8(TAG_EVENT_REPORT).await.unwrap();
let mut len = payload.len() as u64;
loop {
let byte = (len & 0x7F) as u8;
len >>= 7;
if len == 0 {
write_half.write_u8(byte).await.unwrap();
break;
} else {
write_half.write_u8(byte | 0x80).await.unwrap();
}
}
write_half.write_all(&payload).await.unwrap();
write_half.flush().await.unwrap();
let tag = tokio::time::timeout(Duration::from_millis(100), read_half.read_u8())
.await
.expect("ViolationAlert did not arrive within 100ms")
.expect("read error");
assert_eq!(tag, TAG_VIOLATION_ALERT, "expected ViolationAlert tag (4)");
token.cancel();
}
#[tokio::test]
async fn normal_event_produces_no_response() {
use crate::ipc::codec::TAG_EVENT_REPORT;
use crate::pipeline::{PipelineConfig, PipelineMetrics};
use prost::Message;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let socket_path = temp_socket_path("no-alert");
let token = CancellationToken::new();
let counter = Arc::new(AtomicI64::new(0));
let (inbound_rx, router, verified) =
start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;
let pipeline_config = PipelineConfig {
input_buffer: 64,
batch_size: 100,
flush_interval: std::time::Duration::from_secs(60),
broadcast_capacity: 64,
agent_id: "test-agent".to_string(),
enforcement: crate::pipeline::enforcement::EnforcementConfig::default(),
gateway_fail_closed: true,
min_sdk_version: None,
};
let pipeline_metrics = Arc::new(PipelineMetrics::default());
let (broadcast_tx, _broadcast_rx) = tokio::sync::broadcast::channel::<crate::pipeline::PipelineEvent>(64);
let pipeline_router = Arc::clone(&router);
let pipeline_token = token.clone();
tokio::spawn(crate::pipeline::run(
inbound_rx,
broadcast_tx,
pipeline_config,
pipeline_metrics,
pipeline_token,
Arc::new(crate::policy::PolicyRules::default()),
pipeline_router,
crate::approval::ApprovalQueue::new(),
None,
crate::op_control::OpControlStore::new(),
Arc::new(std::sync::atomic::AtomicU64::new(0)),
Arc::clone(&verified),
));
let client = connect_client(&socket_path).await;
let (mut read_half, mut write_half) = client.into_split();
for _ in 0..50 {
if counter.load(Ordering::Relaxed) == 1 {
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
let event = AuditEvent::default();
let payload = event.encode_to_vec();
write_half.write_u8(TAG_EVENT_REPORT).await.unwrap();
let mut len = payload.len() as u64;
loop {
let byte = (len & 0x7F) as u8;
len >>= 7;
if len == 0 {
write_half.write_u8(byte).await.unwrap();
break;
} else {
write_half.write_u8(byte | 0x80).await.unwrap();
}
}
write_half.write_all(&payload).await.unwrap();
write_half.flush().await.unwrap();
let result = tokio::time::timeout(Duration::from_millis(100), read_half.read_u8()).await;
assert!(
result.is_err(),
"expected no response for a normal event, but received one"
);
token.cancel();
}
/// End-to-end approval flow over the socket: a tool-call policy query under a
/// rule that requires approval yields a PENDING response carrying an approval
/// id; deciding that approval on the queue yields an approval-decision frame
/// on the same connection.
#[tokio::test]
async fn approval_round_trip_over_ipc_socket() {
    use crate::approval::ApprovalDecision as RuntimeApprovalDecision;
    use crate::ipc::codec::{TAG_APPROVAL_DECISION, TAG_POLICY_QUERY, TAG_POLICY_RESPONSE};
    use crate::pipeline::{PipelineConfig, PipelineMetrics};
    use crate::policy::{PolicyRule, PolicyRules};
    use aa_proto::assembly::common::v1::{ActionType, Decision};
    use aa_proto::assembly::event::v1::ApprovalDecision as ProtoApprovalDecision;
    use aa_proto::assembly::policy::v1::{CheckActionRequest, CheckActionResponse};
    use prost::Message;
    use std::sync::Arc;
    use tokio::io::AsyncReadExt;

    /// Reads one varint length prefix from the client's read half.
    async fn read_varint_len(reader: &mut tokio::net::unix::OwnedReadHalf) -> u64 {
        use tokio::io::AsyncReadExt;

        let mut value: u64 = 0;
        let mut shift = 0u32;
        loop {
            let byte = reader.read_u8().await.unwrap();
            value |= ((byte & 0x7F) as u64) << shift;
            if byte & 0x80 == 0 {
                return value;
            }
            shift += 7;
        }
    }

    let socket_path = temp_socket_path("approval-roundtrip");
    let token = CancellationToken::new();
    let counter = Arc::new(AtomicI64::new(0));
    let (inbound_rx, router, verified) =
        start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;

    // Policy: tool calls require approval.
    let policy = Arc::new(PolicyRules {
        rules: vec![PolicyRule {
            name: "approve-tool".to_string(),
            requires_approval_actions: vec![ActionType::ToolCall.as_str_name().to_string()],
            approval_timeout_secs: 60,
            ..Default::default()
        }],
    });
    let approval_queue = crate::approval::ApprovalQueue::new();
    let queue_ref = Arc::clone(&approval_queue);

    let pipeline_config = PipelineConfig {
        input_buffer: 64,
        batch_size: 100,
        flush_interval: std::time::Duration::from_secs(60),
        broadcast_capacity: 64,
        agent_id: "test-agent".to_string(),
        enforcement: crate::pipeline::enforcement::EnforcementConfig::default(),
        gateway_fail_closed: true,
        min_sdk_version: None,
    };
    let pipeline_metrics = Arc::new(PipelineMetrics::default());
    let (broadcast_tx, _broadcast_rx) = tokio::sync::broadcast::channel::<crate::pipeline::PipelineEvent>(64);
    let pipeline_router = Arc::clone(&router);
    let pipeline_token = token.clone();
    tokio::spawn(crate::pipeline::run(
        inbound_rx,
        broadcast_tx,
        pipeline_config,
        pipeline_metrics,
        pipeline_token,
        policy,
        pipeline_router,
        approval_queue,
        None,
        crate::op_control::OpControlStore::new(),
        Arc::new(std::sync::atomic::AtomicU64::new(0)),
        Arc::clone(&verified),
    ));

    let client = connect_client(&socket_path).await;
    let (mut read_half, mut write_half) = client.into_split();

    // Wait (up to ~500 ms) for the server to register the connection.
    for _ in 0..50 {
        if counter.load(Ordering::Relaxed) == 1 {
            break;
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
    }

    // Phase 1: the query must come back PENDING with an approval id.
    let request = CheckActionRequest {
        action_type: ActionType::ToolCall as i32,
        trace_id: "trace-approval-roundtrip".to_string(),
        ..Default::default()
    };
    let payload = request.encode_to_vec();
    write_raw_frame(&mut write_half, TAG_POLICY_QUERY, &payload).await;

    let response_tag = tokio::time::timeout(Duration::from_millis(200), read_half.read_u8())
        .await
        .expect("PENDING response timed out")
        .expect("read error");
    assert_eq!(response_tag, TAG_POLICY_RESPONSE, "expected PolicyResponse tag");

    let response_len = read_varint_len(&mut read_half).await;
    let mut response_bytes = vec![0u8; response_len as usize];
    read_half.read_exact(&mut response_bytes).await.unwrap();
    let pending_resp = CheckActionResponse::decode(response_bytes.as_ref()).unwrap();
    assert_eq!(pending_resp.decision, Decision::Pending as i32);
    assert!(!pending_resp.approval_id.is_empty(), "approval_id must be set");
    let approval_id = uuid::Uuid::parse_str(&pending_resp.approval_id).expect("invalid UUID in approval_id");

    // Phase 2: approve through the queue and expect the decision frame.
    queue_ref
        .decide(
            approval_id,
            RuntimeApprovalDecision::Approved {
                by: "cli-operator".to_string(),
                reason: Some("approved via IPC test".to_string()),
            },
        )
        .expect("decide should succeed");

    let decision_tag = tokio::time::timeout(Duration::from_millis(200), read_half.read_u8())
        .await
        .expect("ApprovalDecision response timed out")
        .expect("read error");
    assert_eq!(decision_tag, TAG_APPROVAL_DECISION, "expected ApprovalDecision tag");

    let decision_len = read_varint_len(&mut read_half).await;
    let mut decision_bytes = vec![0u8; decision_len as usize];
    read_half.read_exact(&mut decision_bytes).await.unwrap();
    let decision = ProtoApprovalDecision::decode(decision_bytes.as_ref()).unwrap();
    assert!(decision.approved, "decision should be approved");
    assert_eq!(decision.decided_by, "cli-operator");
    assert_eq!(decision.approval_id, approval_id.to_string());

    token.cancel();
}
/// Sends an `EventReport` frame over a freshly connected raw `UnixStream`
/// without performing any handshake exchange, then checks that nothing shows
/// up on the receiver returned by `start_server` within 300 ms.
#[tokio::test]
async fn event_without_handshake_is_not_dispatched() {
    let socket_path = temp_socket_path("no-handshake-event");
    let token = CancellationToken::new();
    let counter = Arc::new(AtomicI64::new(0));
    let (mut rx, _router, _verified) =
        start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;

    // Up to 20 connection attempts, pausing 10 ms after each failure, so the
    // test tolerates the listener not accepting connections immediately.
    let mut connected = None;
    let mut attempt = 0;
    while connected.is_none() && attempt < 20 {
        match UnixStream::connect(&socket_path).await {
            Ok(st) => connected = Some(st),
            Err(_) => tokio::time::sleep(Duration::from_millis(10)).await,
        }
        attempt += 1;
    }
    let stream = connected.expect("connect failed");

    // The read half is bound to a named variable so it lives until the end of
    // the test instead of being dropped right away.
    let (_read_half, mut write_half) = stream.into_split();

    let payload = AuditEvent {
        event_id: "forged-no-handshake".to_string(),
        ..Default::default()
    }
    .encode_to_vec();
    write_raw_frame(&mut write_half, TAG_EVENT_REPORT, &payload).await;

    // The timeout elapsing (`Err`) is the passing outcome: no frame arrived.
    let outcome = tokio::time::timeout(Duration::from_millis(300), rx.recv()).await;
    assert!(
        outcome.is_err(),
        "an un-handshaked peer's event must never be dispatched"
    );

    token.cancel();
}
/// Answers the server's handshake challenge with a proof whose signature has
/// been corrupted, follows up with an `EventReport` frame, and checks that
/// nothing shows up on the receiver returned by `start_server` within 300 ms.
#[tokio::test]
async fn forged_handshake_signature_is_rejected() {
    use crate::ipc::codec::TAG_HANDSHAKE_PROOF;
    use aa_proto::assembly::ipc::v1::HandshakeProof;
    use ed25519_dalek::Signer;
    use sha2::{Digest, Sha256};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    let socket_path = temp_socket_path("forged-handshake");
    let token = CancellationToken::new();
    let counter = Arc::new(AtomicI64::new(0));
    let (mut rx, _router, _verified) =
        start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;

    // Up to 20 connection attempts, pausing 10 ms after each failure.
    let mut connected = None;
    let mut attempt = 0;
    while connected.is_none() && attempt < 20 {
        match UnixStream::connect(&socket_path).await {
            Ok(st) => connected = Some(st),
            Err(_) => tokio::time::sleep(Duration::from_millis(10)).await,
        }
        attempt += 1;
    }
    let mut stream = connected.expect("connect failed");

    // Read the challenge frame: tag byte, varint length, protobuf body.
    let challenge_tag = stream.read_u8().await.unwrap();
    assert_eq!(challenge_tag, crate::ipc::codec::TAG_HANDSHAKE_CHALLENGE);
    let challenge_len = read_varint_stream(&mut stream).await;
    let mut challenge_bytes = vec![0u8; challenge_len];
    stream.read_exact(&mut challenge_bytes).await.unwrap();
    let challenge =
        aa_proto::assembly::ipc::v1::HandshakeChallenge::decode(challenge_bytes.as_slice())
            .unwrap();

    // Sign the nonce with a key seeded from SHA-256 of the test agent id, then
    // flip every bit of the first signature byte so the signature is invalid.
    let seed: [u8; 32] = Sha256::digest(TEST_AGENT_ID.as_bytes()).into();
    let sk = ed25519_dalek::SigningKey::from_bytes(&seed);
    let mut forged_sig = sk.sign(&challenge.nonce).to_bytes().to_vec();
    forged_sig[0] ^= 0xFF;

    let proof_bytes = HandshakeProof {
        agent_did: format!("did:key:{TEST_AGENT_ID}"),
        public_key: hex::encode(sk.verifying_key().to_bytes()),
        signature: forged_sig,
        sdk_version: String::new(),
    }
    .encode_to_vec();
    stream.write_u8(TAG_HANDSHAKE_PROOF).await.unwrap();
    write_varint_stream(&mut stream, proof_bytes.len() as u64).await;
    stream.write_all(&proof_bytes).await.unwrap();
    stream.flush().await.unwrap();

    // Keep the read half bound so it is not dropped before the test ends.
    let (_r, mut w) = stream.into_split();

    let event_bytes = AuditEvent {
        event_id: "forged-sig".to_string(),
        ..Default::default()
    }
    .encode_to_vec();

    // Every write below discards its result on purpose.
    // NOTE(review): presumably because the server may already have closed the
    // connection after rejecting the proof — confirm against the server code.
    let _ = w.write_u8(TAG_EVENT_REPORT).await;

    // Length prefix as a base-128 varint: 7 bits per byte, low group first,
    // with the high bit set on every byte except the last.
    let mut remaining = event_bytes.len() as u64;
    loop {
        let low7 = (remaining & 0x7F) as u8;
        remaining >>= 7;
        let is_last = remaining == 0;
        let _ = w.write_u8(if is_last { low7 } else { low7 | 0x80 }).await;
        if is_last {
            break;
        }
    }
    let _ = w.write_all(&event_bytes).await;
    let _ = w.flush().await;

    // The timeout elapsing (`Err`) is the passing outcome: no frame arrived.
    let outcome = tokio::time::timeout(Duration::from_millis(300), rx.recv()).await;
    assert!(outcome.is_err(), "a forged-signature peer must never be dispatched");

    token.cancel();
}
/// Connects through the `connect_client` helper, sends one `EventReport`
/// frame, and checks that the same event arrives on the receiver returned by
/// `start_server` within two seconds.
///
/// NOTE(review): `connect_client` presumably performs the handshake that the
/// test name refers to — confirm against the helper's definition.
#[tokio::test]
async fn valid_handshake_then_event_is_dispatched() {
    let socket_path = temp_socket_path("valid-handshake-event");
    let token = CancellationToken::new();
    let counter = Arc::new(AtomicI64::new(0));
    let (mut rx, _router, _verified) =
        start_server(socket_path.clone(), token.clone(), Arc::clone(&counter)).await;

    let client = connect_client(&socket_path).await;
    // Keep the read half bound so it is not dropped before the test ends.
    let (_read_half, mut write_half) = client.into_split();

    let payload = AuditEvent {
        event_id: "authenticated-event".to_string(),
        ..Default::default()
    }
    .encode_to_vec();
    write_raw_frame(&mut write_half, TAG_EVENT_REPORT, &payload).await;

    let received = tokio::time::timeout(Duration::from_secs(2), rx.recv()).await;
    let (_conn_id, frame) = received
        .expect("authenticated event timed out")
        .expect("channel closed");

    // Any frame kind other than an event report fails the test.
    let dispatched_id = match frame {
        IpcFrame::EventReport(decoded) => decoded.event_id,
        other => panic!("expected EventReport, got {other:?}"),
    };
    assert_eq!(dispatched_id, "authenticated-event");

    token.cancel();
}
}