use std::collections::HashSet;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::{UnixListener, UnixStream};
use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
use tracing::{debug, error, info, warn};
use crate::v2::client::FlowState;
use crate::v2::pool::CHANNEL_BUFFER_SIZE;
use crate::v2::uds::{read_message, write_message, MessageType, UdsCapabilities};
use crate::v2::{AgentCapabilities, AgentPool, PROTOCOL_VERSION_2};
use crate::{AgentProtocolError, AgentResponse};
/// Settings for the listener that accepts agent-initiated ("reverse") connections.
#[derive(Debug, Clone)]
pub struct ReverseConnectionConfig {
/// Requested accept backlog.
/// NOTE(review): not read anywhere in this file — confirm it is applied elsewhere.
pub backlog: u32,
/// Maximum time to wait for the agent's registration message after accepting.
pub handshake_timeout: Duration,
/// Intended cap on simultaneous connections per agent.
/// NOTE(review): not enforced in this file — presumably checked by the pool; confirm.
pub max_connections_per_agent: usize,
/// Agent IDs allowed to register; an empty set disables the allow-list check.
pub allowed_agents: HashSet<String>,
/// When true, a registration that carries no `auth_token` is rejected.
pub require_auth: bool,
/// Per-request response timeout handed to each `ReverseConnectionClient`.
pub request_timeout: Duration,
}
impl Default for ReverseConnectionConfig {
fn default() -> Self {
Self {
backlog: 128,
handshake_timeout: Duration::from_secs(10),
max_connections_per_agent: 4,
allowed_agents: HashSet::new(),
require_auth: false,
request_timeout: Duration::from_secs(30),
}
}
}
/// First message an agent sends on a reverse connection; it is carried as
/// the JSON payload of a `HandshakeRequest` frame.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RegistrationRequest {
/// Protocol version the agent speaks; must equal `PROTOCOL_VERSION_2` to be accepted.
pub protocol_version: u32,
/// Identifier of the registering agent; must be non-empty.
pub agent_id: String,
/// Capabilities advertised by the agent (converted to `AgentCapabilities` on success).
pub capabilities: UdsCapabilities,
/// Optional authentication token. This file only checks for its presence
/// when `require_auth` is set; the value itself is not verified here.
pub auth_token: Option<String>,
/// Optional free-form metadata; not read anywhere in this file.
pub metadata: Option<serde_json::Value>,
}
/// Reply to a `RegistrationRequest`; it is carried as the JSON payload of a
/// `HandshakeResponse` frame.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RegistrationResponse {
/// Whether the registration was accepted.
pub success: bool,
/// Human-readable rejection reason; `None` when `success` is true.
pub error: Option<String>,
/// Identifier of the proxy (this file always sends `"zentinel-proxy"`).
pub proxy_id: String,
/// Proxy crate version, taken from `CARGO_PKG_VERSION` at build time.
pub proxy_version: String,
/// ID assigned to this connection; sent as an empty string on rejection.
pub connection_id: String,
}
/// Unix-domain-socket listener that accepts connections initiated by agents
/// and registers them with an `AgentPool`.
pub struct ReverseConnectionListener {
/// The bound socket that connections are accepted from.
listener: UnixListener,
/// Handshake, validation and timeout settings.
config: ReverseConnectionConfig,
/// Path the socket was bound at (lossy UTF-8 rendering); the file at this
/// path is removed when the listener is dropped.
socket_path: String,
}
impl ReverseConnectionListener {
/// Binds a Unix domain socket listener at `path`, replacing any stale file
/// left there by a previous run.
///
/// # Errors
///
/// Returns [`AgentProtocolError::ConnectionFailed`] when an existing file at
/// `path` cannot be removed or when binding the socket fails.
pub async fn bind_uds(
    path: impl AsRef<Path>,
    config: ReverseConnectionConfig,
) -> Result<Self, AgentProtocolError> {
    let path = path.as_ref();
    let socket_path = path.to_string_lossy().to_string();

    // Remove unconditionally and tolerate "not found" instead of checking
    // `exists()` first: that avoids the check-then-act race and also clears
    // a dangling symlink, which `exists()` reports as absent even though it
    // would still make `bind` fail.
    match std::fs::remove_file(path) {
        Ok(()) => {}
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
        Err(e) => {
            return Err(AgentProtocolError::ConnectionFailed(format!(
                "Failed to remove existing socket {}: {}",
                socket_path, e
            )));
        }
    }

    let listener = UnixListener::bind(path).map_err(|e| {
        AgentProtocolError::ConnectionFailed(format!(
            "Failed to bind to {}: {}",
            socket_path, e
        ))
    })?;
    info!(path = %socket_path, "Reverse connection listener bound");

    Ok(Self {
        listener,
        config,
        socket_path,
    })
}
/// Returns the filesystem path this listener was bound at.
pub fn socket_path(&self) -> &str {
    self.socket_path.as_str()
}
/// Accepts a single connection and runs the registration handshake on it.
///
/// Returns the ID of the agent that registered.
///
/// # Errors
///
/// Returns [`AgentProtocolError::ConnectionFailed`] if `accept` fails, or
/// whatever error the handshake itself produces.
pub async fn accept_one(&self, pool: &AgentPool) -> Result<String, AgentProtocolError> {
    let stream = match self.listener.accept().await {
        Ok((stream, _peer)) => stream,
        Err(e) => {
            return Err(AgentProtocolError::ConnectionFailed(format!(
                "Accept failed: {}",
                e
            )));
        }
    };
    debug!("Accepted reverse connection");
    self.handle_connection(stream, pool).await
}
/// Accepts connections forever, handling each one on its own spawned task.
///
/// Handshake outcomes are only logged. A failed `accept` is logged and
/// retried after a 100 ms pause so a persistent error cannot spin the loop.
pub async fn accept_loop(self: Arc<Self>, pool: Arc<AgentPool>) {
    info!(path = %self.socket_path, "Starting reverse connection accept loop");
    loop {
        let stream = match self.listener.accept().await {
            Ok((stream, _peer)) => stream,
            Err(e) => {
                error!(error = %e, "Accept failed");
                tokio::time::sleep(Duration::from_millis(100)).await;
                continue;
            }
        };

        // Each task owns its own handles so the loop can keep accepting.
        let this = Arc::clone(&self);
        let agents = Arc::clone(&pool);
        tokio::spawn(async move {
            match this.handle_connection(stream, &agents).await {
                Err(e) => {
                    warn!(error = %e, "Failed to handle reverse connection");
                }
                Ok(agent_id) => {
                    info!(agent_id = %agent_id, "Reverse connection registered");
                }
            }
        });
    }
}
async fn handle_connection(
&self,
stream: UnixStream,
pool: &AgentPool,
) -> Result<String, AgentProtocolError> {
let (read_half, write_half) = stream.into_split();
let mut reader = BufReader::new(read_half);
let mut writer = BufWriter::new(write_half);
let registration = tokio::time::timeout(
self.config.handshake_timeout,
self.read_registration(&mut reader),
)
.await
.map_err(|_| AgentProtocolError::Timeout(self.config.handshake_timeout))??;
let agent_id = registration.agent_id.clone();
if let Err(e) = self.validate_registration(®istration) {
let response = RegistrationResponse {
success: false,
error: Some(e.to_string()),
proxy_id: "zentinel-proxy".to_string(),
proxy_version: env!("CARGO_PKG_VERSION").to_string(),
connection_id: String::new(),
};
self.send_registration_response(&mut writer, &response)
.await?;
return Err(e);
}
let connection_id = format!(
"{}-{:x}",
agent_id,
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_millis())
.unwrap_or(0)
);
let response = RegistrationResponse {
success: true,
error: None,
proxy_id: "zentinel-proxy".to_string(),
proxy_version: env!("CARGO_PKG_VERSION").to_string(),
connection_id: connection_id.clone(),
};
self.send_registration_response(&mut writer, &response)
.await?;
info!(
agent_id = %agent_id,
connection_id = %connection_id,
"Agent registration successful"
);
let capabilities: AgentCapabilities = registration.capabilities.into();
let client = ReverseConnectionClient::new(
agent_id.clone(),
connection_id,
capabilities.clone(),
reader,
writer,
self.config.request_timeout,
)
.await;
pool.add_reverse_connection(&agent_id, client, capabilities)
.await?;
Ok(agent_id)
}
/// Reads one frame from `reader` and decodes it as a registration request.
///
/// # Errors
///
/// Propagates any error from `read_message`, and returns
/// [`AgentProtocolError::InvalidMessage`] when the frame is not a
/// `HandshakeRequest` or its payload is not valid registration JSON.
async fn read_registration<R: AsyncReadExt + Unpin>(
    &self,
    reader: &mut R,
) -> Result<RegistrationRequest, AgentProtocolError> {
    let (kind, body) = read_message(reader).await?;

    if kind == MessageType::HandshakeRequest {
        return serde_json::from_slice(&body)
            .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()));
    }

    Err(AgentProtocolError::InvalidMessage(format!(
        "Expected registration request (HandshakeRequest), got {:?}",
        kind
    )))
}
/// Serializes `response` to JSON and writes it as a `HandshakeResponse` frame.
///
/// # Errors
///
/// Returns [`AgentProtocolError::Serialization`] if encoding fails, or any
/// error reported by `write_message`.
async fn send_registration_response<W: AsyncWriteExt + Unpin>(
    &self,
    writer: &mut W,
    response: &RegistrationResponse,
) -> Result<(), AgentProtocolError> {
    let encoded = match serde_json::to_vec(response) {
        Ok(bytes) => bytes,
        Err(e) => return Err(AgentProtocolError::Serialization(e.to_string())),
    };
    write_message(writer, MessageType::HandshakeResponse, &encoded).await
}
/// Checks a registration request against the protocol version and the
/// listener's configuration.
///
/// Checks are applied in this order, returning on the first failure:
/// protocol version, non-empty agent ID, allow-list membership (skipped when
/// the allow-list is empty), and token presence when `require_auth` is set.
/// Only the presence of a token is checked; its value is not verified here.
///
/// # Errors
///
/// Returns [`AgentProtocolError::VersionMismatch`] for a wrong protocol
/// version and [`AgentProtocolError::InvalidMessage`] for every other
/// rejection.
fn validate_registration(
    &self,
    registration: &RegistrationRequest,
) -> Result<(), AgentProtocolError> {
    if registration.protocol_version != PROTOCOL_VERSION_2 {
        return Err(AgentProtocolError::VersionMismatch {
            expected: PROTOCOL_VERSION_2,
            actual: registration.protocol_version,
        });
    }

    if registration.agent_id.is_empty() {
        return Err(AgentProtocolError::InvalidMessage(
            "Agent ID cannot be empty".to_string(),
        ));
    }

    let allow_list = &self.config.allowed_agents;
    if !allow_list.is_empty() && !allow_list.contains(&registration.agent_id) {
        return Err(AgentProtocolError::InvalidMessage(format!(
            "Agent '{}' is not in the allowed list",
            registration.agent_id
        )));
    }

    // An empty or whitespace-only token carries no credential, so it is
    // treated the same as a missing one.
    let has_token = registration
        .auth_token
        .as_deref()
        .map_or(false, |token| !token.trim().is_empty());
    if self.config.require_auth && !has_token {
        return Err(AgentProtocolError::InvalidMessage(
            "Authentication required but no token provided".to_string(),
        ));
    }

    Ok(())
}
}
impl Drop for ReverseConnectionListener {
    /// Best-effort removal of the socket file; a failure is only logged at
    /// debug level.
    fn drop(&mut self) {
        match std::fs::remove_file(&self.socket_path) {
            Ok(()) => {}
            Err(e) => {
                debug!(path = %self.socket_path, error = %e, "Failed to remove socket file on drop");
            }
        }
    }
}
/// Proxy-side handle for one registered reverse connection to an agent.
///
/// Construction spawns a writer task (drains `outbound_tx`'s channel onto the
/// socket) and a reader task (routes agent responses to waiters in `pending`).
pub struct ReverseConnectionClient {
/// ID the agent supplied at registration.
agent_id: String,
/// ID assigned to this connection by the listener.
connection_id: String,
/// Capabilities advertised at registration; always `Some` as constructed in this file.
capabilities: RwLock<Option<AgentCapabilities>>,
/// Waiters keyed by correlation ID; the reader task completes the matching
/// sender when a response arrives.
pending: Arc<Mutex<std::collections::HashMap<String, oneshot::Sender<AgentResponse>>>>,
// The nested generic type is inherent to the channel payload, hence the lint allowance.
#[allow(clippy::type_complexity)]
/// Sending side of the queue feeding the writer task; set to `None` by `close`.
outbound_tx: Mutex<Option<mpsc::Sender<(MessageType, Vec<u8>)>>>,
/// True from construction until `close` is called. No code in this file
/// clears it when the socket itself drops.
connected: RwLock<bool>,
/// How long `send_event` waits for the agent's response.
timeout: Duration,
/// Number of requests sent and still awaiting a response.
in_flight: std::sync::atomic::AtomicU64,
/// Flow-control state; initialised to `FlowState::Normal`.
/// NOTE(review): no visible code in this file changes it afterwards — confirm where it is updated.
flow_state: Arc<RwLock<FlowState>>,
}
impl ReverseConnectionClient {
async fn new<R, W>(
agent_id: String,
connection_id: String,
capabilities: AgentCapabilities,
mut reader: BufReader<R>,
mut writer: BufWriter<W>,
timeout: Duration,
) -> Self
where
R: AsyncReadExt + Unpin + Send + 'static,
W: AsyncWriteExt + Unpin + Send + 'static,
{
let pending: Arc<Mutex<std::collections::HashMap<String, oneshot::Sender<AgentResponse>>>> =
Arc::new(Mutex::new(std::collections::HashMap::new()));
let (tx, mut rx) = mpsc::channel::<(MessageType, Vec<u8>)>(CHANNEL_BUFFER_SIZE);
let agent_id_clone = agent_id.clone();
tokio::spawn(async move {
while let Some((msg_type, payload)) = rx.recv().await {
if let Err(e) = write_message(&mut writer, msg_type, &payload).await {
error!(
agent_id = %agent_id_clone,
error = %e,
"Failed to write to reverse connection"
);
break;
}
}
debug!(agent_id = %agent_id_clone, "Reverse connection writer ended");
});
let pending_clone = Arc::clone(&pending);
let agent_id_clone = agent_id.clone();
tokio::spawn(async move {
loop {
match read_message(&mut reader).await {
Ok((msg_type, payload)) => {
if msg_type == MessageType::AgentResponse {
if let Ok(response) = serde_json::from_slice::<AgentResponse>(&payload)
{
let correlation_id = response
.audit
.custom
.get("correlation_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if let Some(sender) =
pending_clone.lock().await.remove(&correlation_id)
{
let _ = sender.send(response);
}
}
}
}
Err(e) => {
if !matches!(e, AgentProtocolError::ConnectionClosed) {
error!(
agent_id = %agent_id_clone,
error = %e,
"Error reading from reverse connection"
);
}
break;
}
}
}
debug!(agent_id = %agent_id_clone, "Reverse connection reader ended");
});
Self {
agent_id,
connection_id,
capabilities: RwLock::new(Some(capabilities)),
pending,
outbound_tx: Mutex::new(Some(tx)),
connected: RwLock::new(true),
timeout,
in_flight: std::sync::atomic::AtomicU64::new(0),
flow_state: Arc::new(RwLock::new(FlowState::Normal)),
}
}
/// Returns the ID the agent registered with.
pub fn agent_id(&self) -> &str {
    self.agent_id.as_str()
}
/// Returns the ID assigned to this connection at registration.
pub fn connection_id(&self) -> &str {
    self.connection_id.as_str()
}
/// Reports the `connected` flag: true until `close` has been called.
pub async fn is_connected(&self) -> bool {
    let connected = self.connected.read().await;
    *connected
}
/// Returns a copy of the capabilities stored for this agent, if any.
pub async fn capabilities(&self) -> Option<AgentCapabilities> {
    let stored = self.capabilities.read().await;
    stored.as_ref().cloned()
}
/// Returns true when the current flow state is `FlowState::Paused`.
pub async fn is_paused(&self) -> bool {
    let state = self.flow_state.read().await;
    matches!(*state, FlowState::Paused)
}
/// Returns true unless the connection is currently paused.
pub async fn can_accept_requests(&self) -> bool {
    let paused = self.is_paused().await;
    !paused
}
/// Sends a request-headers event to the agent and waits for its response.
///
/// # Errors
///
/// Same as `send_event`: serialization failure, closed connection, or timeout.
pub async fn send_request_headers(
    &self,
    correlation_id: &str,
    event: &crate::RequestHeadersEvent,
) -> Result<AgentResponse, AgentProtocolError> {
    let kind = MessageType::RequestHeaders;
    self.send_event(kind, correlation_id, event).await
}
/// Sends a request-body-chunk event to the agent and waits for its response.
///
/// # Errors
///
/// Same as `send_event`: serialization failure, closed connection, or timeout.
pub async fn send_request_body_chunk(
    &self,
    correlation_id: &str,
    event: &crate::RequestBodyChunkEvent,
) -> Result<AgentResponse, AgentProtocolError> {
    let kind = MessageType::RequestBodyChunk;
    self.send_event(kind, correlation_id, event).await
}
/// Sends a response-headers event to the agent and waits for its response.
///
/// # Errors
///
/// Same as `send_event`: serialization failure, closed connection, or timeout.
pub async fn send_response_headers(
    &self,
    correlation_id: &str,
    event: &crate::ResponseHeadersEvent,
) -> Result<AgentResponse, AgentProtocolError> {
    let kind = MessageType::ResponseHeaders;
    self.send_event(kind, correlation_id, event).await
}
/// Sends a response-body-chunk event to the agent and waits for its response.
///
/// # Errors
///
/// Same as `send_event`: serialization failure, closed connection, or timeout.
pub async fn send_response_body_chunk(
    &self,
    correlation_id: &str,
    event: &crate::ResponseBodyChunkEvent,
) -> Result<AgentResponse, AgentProtocolError> {
    let kind = MessageType::ResponseBodyChunk;
    self.send_event(kind, correlation_id, event).await
}
/// Serializes `event`, tags it with `correlation_id`, queues it for the
/// writer task, and waits up to `self.timeout` for the matching response.
///
/// The `correlation_id` field is injected only when the event serializes to
/// a JSON object.
///
/// Bookkeeping is released on every exit path: the pending waiter is removed
/// if the frame cannot be queued or the wait times out, and the in-flight
/// counter is decremented by a drop guard, which also covers the caller
/// dropping this future.
///
/// # Errors
///
/// * [`AgentProtocolError::Serialization`] if the event cannot be encoded.
/// * [`AgentProtocolError::ConnectionClosed`] if the client was closed, the
///   writer task has stopped, or the waiter was dropped before a response
///   arrived.
/// * [`AgentProtocolError::Timeout`] if no response arrives in time.
async fn send_event<T: serde::Serialize>(
    &self,
    msg_type: MessageType,
    correlation_id: &str,
    event: &T,
) -> Result<AgentResponse, AgentProtocolError> {
    // Decrements the in-flight counter when dropped, so success, error and
    // cancellation paths are all balanced.
    struct InFlightGuard<'a>(&'a std::sync::atomic::AtomicU64);
    impl Drop for InFlightGuard<'_> {
        fn drop(&mut self) {
            self.0
                .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
        }
    }

    // Encode before registering the waiter so a serialization failure
    // cannot strand an entry in `pending`.
    let mut payload = serde_json::to_value(event)
        .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
    if let Some(obj) = payload.as_object_mut() {
        obj.insert(
            "correlation_id".to_string(),
            serde_json::Value::String(correlation_id.to_string()),
        );
    }
    let payload_bytes = serde_json::to_vec(&payload)
        .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;

    // Clone the sender and release the mutex right away: a full queue would
    // otherwise block `close` and every other sender behind this lock.
    let sender = self.outbound_tx.lock().await.clone();
    let outbound = match sender {
        Some(outbound) => outbound,
        None => return Err(AgentProtocolError::ConnectionClosed),
    };

    // The waiter must be registered before the frame is queued, or a fast
    // response could arrive with nobody to receive it.
    let (tx, rx) = oneshot::channel();
    self.pending
        .lock()
        .await
        .insert(correlation_id.to_string(), tx);

    if outbound.send((msg_type, payload_bytes)).await.is_err() {
        self.pending.lock().await.remove(correlation_id);
        return Err(AgentProtocolError::ConnectionClosed);
    }

    self.in_flight
        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let _in_flight = InFlightGuard(&self.in_flight);

    match tokio::time::timeout(self.timeout, rx).await {
        Ok(Ok(response)) => Ok(response),
        // The waiter was dropped without a response being delivered.
        Ok(Err(_)) => Err(AgentProtocolError::ConnectionClosed),
        Err(_) => {
            // Take the lock properly: a best-effort `try_lock` could skip
            // the cleanup and leave the entry behind.
            self.pending.lock().await.remove(correlation_id);
            Err(AgentProtocolError::Timeout(self.timeout))
        }
    }
}
/// Asks the agent to cancel the request identified by `correlation_id` and
/// drops the local waiter for it.
///
/// The waiter is removed whether or not the cancel frame could be sent, so
/// a dead connection cannot leave the entry behind. If the client has
/// already been closed, no frame is sent and the call still succeeds.
///
/// # Errors
///
/// * [`AgentProtocolError::Serialization`] if the cancel message cannot be
///   encoded.
/// * [`AgentProtocolError::ConnectionClosed`] if the writer task has stopped.
pub async fn cancel_request(
    &self,
    correlation_id: &str,
    reason: super::client::CancelReason,
) -> Result<(), AgentProtocolError> {
    let cancel = serde_json::json!({
        "correlation_id": correlation_id,
        "reason": reason as i32,
        "timestamp_ms": now_ms(),
    });
    let payload = serde_json::to_vec(&cancel)
        .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;

    // Clone the sender so the mutex is not held while waiting for queue space.
    let sender = self.outbound_tx.lock().await.clone();
    let sent = match sender {
        Some(tx) => tx
            .send((MessageType::Cancel, payload))
            .await
            .map_err(|_| AgentProtocolError::ConnectionClosed),
        None => Ok(()),
    };

    // Release the waiter before reporting the send outcome.
    self.pending.lock().await.remove(correlation_id);
    sent
}
/// Cancels every request that is pending when the call starts.
///
/// Returns the number of cancellations attempted, not the number that
/// succeeded: failures of individual cancellations are ignored.
pub async fn cancel_all(
    &self,
    reason: super::client::CancelReason,
) -> Result<usize, AgentProtocolError> {
    // Snapshot the keys so the lock is released before cancelling.
    let snapshot: Vec<String> = {
        let waiters = self.pending.lock().await;
        waiters.keys().cloned().collect()
    };
    let attempted = snapshot.len();

    for id in &snapshot {
        let _ = self.cancel_request(id, reason).await;
    }

    Ok(attempted)
}
/// Marks the client as disconnected and drops its outbound sender.
///
/// Once every sender for the outbound queue is gone, the writer task's
/// receive loop ends. Pending requests are left untouched.
pub async fn close(&self) -> Result<(), AgentProtocolError> {
    {
        let mut connected = self.connected.write().await;
        *connected = false;
    }
    drop(self.outbound_tx.lock().await.take());
    Ok(())
}
/// Returns the current value of the in-flight request counter.
pub fn in_flight(&self) -> u64 {
    use std::sync::atomic::Ordering;
    self.in_flight.load(Ordering::Relaxed)
}
}
/// Milliseconds since the Unix epoch according to the system clock, or 0 if
/// the clock reads earlier than the epoch.
fn now_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the expected baseline values.
    #[test]
    fn test_config_default() {
        let ReverseConnectionConfig {
            backlog,
            max_connections_per_agent,
            require_auth,
            ..
        } = ReverseConnectionConfig::default();

        assert_eq!(backlog, 128);
        assert_eq!(max_connections_per_agent, 4);
        assert!(!require_auth);
    }

    /// A registration request survives a JSON round trip.
    #[test]
    fn test_registration_request_serialization() {
        let capabilities = UdsCapabilities {
            agent_id: "test-agent".to_string(),
            name: "Test Agent".to_string(),
            version: "1.0.0".to_string(),
            supported_events: vec![1, 2],
            features: Default::default(),
            limits: Default::default(),
        };
        let original = RegistrationRequest {
            protocol_version: 2,
            agent_id: "test-agent".to_string(),
            capabilities,
            auth_token: None,
            metadata: None,
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: RegistrationRequest = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.agent_id, "test-agent");
        assert_eq!(decoded.protocol_version, 2);
    }

    /// A registration response survives a JSON round trip.
    #[test]
    fn test_registration_response_serialization() {
        let original = RegistrationResponse {
            success: true,
            error: None,
            proxy_id: "zentinel".to_string(),
            proxy_version: "1.0.0".to_string(),
            connection_id: "conn-123".to_string(),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: RegistrationResponse = serde_json::from_str(&encoded).unwrap();

        assert!(decoded.success);
        assert_eq!(decoded.connection_id, "conn-123");
    }
}