use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
use breaker_machines::CircuitBreaker;
use chrono_machines::{BackoffStrategy, ExponentialBackoff};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UnixStream;
use tokio::sync::{RwLock, mpsc};
use tracing::{debug, error, info, warn};
use super::{
Boarding, Cargo, Disembark, MessageType, Moored, VERSION, decode_cargo, decode_header,
encode_boarding, encode_cargo, encode_disembark, encode_moored,
};
/// Backing counter for [`next_conn_id`]; the first id handed out is 1.
static NEXT_CONN_ID: AtomicU32 = AtomicU32::new(1);

/// Allocates a connection id that is unique within this process.
///
/// Ids start at 1 and increase by one per call. `Relaxed` ordering is enough because
/// only the atomicity of the increment matters, not its ordering relative to other
/// memory operations. `fetch_add` wraps on `u32` overflow, so uniqueness only holds
/// for the first `u32::MAX` allocations.
pub fn next_conn_id() -> u32 {
    AtomicU32::fetch_add(&NEXT_CONN_ID, 1, Ordering::Relaxed)
}
/// Upper bound for opening the Unix-socket connection to a ship.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Upper bound for the DOCK/MOORED handshake once the socket is connected.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(10_000);
#[derive(Debug, thiserror::Error)]
pub enum DockingError {
#[error("Connection timeout after {0:?}")]
ConnectTimeout(Duration),
#[error("Handshake timeout after {0:?}")]
HandshakeTimeout(Duration),
#[error("Circuit breaker open for ship {0}")]
CircuitOpen(String),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Protocol error: {0}")]
Protocol(String),
}
/// Outbound messages queued for the background writer of a `DockingConnector`.
#[derive(Debug)]
pub enum ToShip {
    /// Written to the socket via `encode_boarding`.
    Boarding(Boarding),
    /// Written to the socket via `encode_cargo`.
    Cargo(Cargo),
    /// Written to the socket via `encode_disembark`.
    Disembark(Disembark),
}
/// Client side of a docking session with a single ship over a Unix socket.
///
/// Outbound messages go through `tx` to a background writer task. Inbound CARGO
/// frames are routed by a background reader task to the per-connection channels in
/// `cargo_receivers`.
pub struct DockingConnector {
    /// Name of the ship this connector docked with (used for logging).
    ship_name: String,
    /// Path of the Unix socket that was connected to.
    socket_path: PathBuf,
    /// Queue feeding the background writer task.
    tx: mpsc::Sender<ToShip>,
    /// Per-connection CARGO channels, keyed by connection id.
    cargo_receivers: Arc<RwLock<HashMap<u32, mpsc::Sender<Vec<u8>>>>>,
    /// Circuit breaker used while docking. It is only stored so it lives as long as the
    /// connector; nothing in this file reads it afterwards, hence the `dead_code` allowance.
    #[allow(dead_code)]
    circuit: Arc<RwLock<CircuitBreaker>>,
}
impl DockingConnector {
    /// Number of header bytes that precede every frame payload on the wire; the
    /// payload of a frame starts at this offset.
    const HEADER_LEN: usize = 5;

    /// Builds the circuit breaker that guards docking attempts for `ship_name`.
    ///
    /// Configured with `failure_threshold(5)`, `failure_window_secs(60.0)`,
    /// `half_open_timeout_secs(30.0)` and `success_threshold(2)` (see `breaker_machines`
    /// for the exact semantics). It logs at `warn` when it opens and `info` when it closes.
    fn create_circuit(ship_name: &str) -> CircuitBreaker {
        CircuitBreaker::builder(ship_name)
            .failure_threshold(5)
            .failure_window_secs(60.0)
            .half_open_timeout_secs(30.0)
            .success_threshold(2)
            .on_open({
                let ship = ship_name.to_string();
                move |_| {
                    warn!(ship = %ship, "Circuit breaker opened for docking");
                }
            })
            .on_close({
                let ship = ship_name.to_string();
                move |_| {
                    info!(ship = %ship, "Circuit breaker closed for docking");
                }
            })
            .build()
    }

    /// Connects to the Unix socket at `socket_path`.
    ///
    /// # Errors
    /// `ConnectTimeout` if connecting takes longer than `CONNECT_TIMEOUT`, `Io` if the
    /// connect itself fails.
    async fn connect_socket(socket_path: &std::path::Path) -> Result<UnixStream, DockingError> {
        tokio::time::timeout(CONNECT_TIMEOUT, UnixStream::connect(socket_path))
            .await
            .map_err(|_| DockingError::ConnectTimeout(CONNECT_TIMEOUT))?
            .map_err(DockingError::Io)
    }

    /// Runs the docking handshake on `stream`, bounded by `HANDSHAKE_TIMEOUT`.
    ///
    /// Returns any bytes that arrived after the DOCK frame, so the caller can keep
    /// parsing them as regular traffic.
    ///
    /// # Errors
    /// `HandshakeTimeout` if the exchange does not finish in time; otherwise whatever
    /// the exchange itself reports (`Io` or `Protocol`).
    async fn perform_handshake(
        stream: &mut UnixStream,
        ship_name: &str,
        config: HashMap<String, String>,
    ) -> Result<Vec<u8>, DockingError> {
        tokio::time::timeout(
            HANDSHAKE_TIMEOUT,
            Self::exchange_dock_moored(stream, ship_name, config),
        )
        .await
        .map_err(|_| DockingError::HandshakeTimeout(HANDSHAKE_TIMEOUT))?
    }

    /// Untimed part of the handshake. It waits for one complete DOCK frame, answers
    /// with MOORED (carrying `config`) and returns the bytes buffered after that frame.
    ///
    /// # Errors
    /// - `Protocol` if the ship closes the connection, sends a bad header, sends a
    ///   non-DOCK frame first, or sends a DOCK payload that is not valid JSON.
    /// - `Io` on read/write failure.
    async fn exchange_dock_moored(
        stream: &mut UnixStream,
        ship_name: &str,
        config: HashMap<String, String>,
    ) -> Result<Vec<u8>, DockingError> {
        let mut buf = vec![0u8; 4096];
        let mut pending = Vec::new();
        loop {
            let n = stream.read(&mut buf).await?;
            if n == 0 {
                return Err(DockingError::Protocol(
                    "Ship closed connection during docking".to_string(),
                ));
            }
            pending.extend_from_slice(&buf[..n]);

            // Keep reading until a full header, then a full frame, is buffered.
            if pending.len() < Self::HEADER_LEN {
                continue;
            }
            let (msg_type, payload_len) = decode_header(&pending)
                .map_err(|e| DockingError::Protocol(format!("Header decode error: {}", e)))?;
            let total_len = Self::HEADER_LEN + payload_len;
            if pending.len() < total_len {
                continue;
            }

            if msg_type != MessageType::Dock {
                return Err(DockingError::Protocol(format!(
                    "Expected DOCK, got {:?}",
                    msg_type
                )));
            }
            let dock: super::Dock =
                serde_json::from_slice(&pending[Self::HEADER_LEN..total_len])
                    .map_err(|e| DockingError::Protocol(format!("DOCK parse error: {}", e)))?;
            info!(
                ship = %ship_name,
                version = dock.version,
                ship_reported = %dock.ship,
                "Received DOCK"
            );

            let moored = Moored {
                version: VERSION,
                config,
            };
            stream.write_all(&encode_moored(&moored)).await?;
            info!(
                ship = %ship_name,
                config_keys = ?moored.config.keys().collect::<Vec<_>>(),
                "Sent MOORED - docking complete"
            );

            // Whatever follows the DOCK frame belongs to the regular message stream.
            pending.drain(..total_len);
            return Ok(pending);
        }
    }

    /// Connects to the ship at `socket_path`, performs the docking handshake and
    /// spawns the background reader and writer tasks.
    ///
    /// Timeouts and I/O errors are retried with exponential backoff (200 ms base,
    /// 2000 ms cap, at most 3 attempts). Protocol errors fail immediately. Must be
    /// called from within a Tokio runtime, since it uses `tokio::spawn`.
    ///
    /// # Errors
    /// - `DockingError::CircuitOpen` if the circuit breaker is open.
    /// - Otherwise the error of the last failed attempt.
    pub async fn connect(
        ship_name: &str,
        socket_path: PathBuf,
        config: HashMap<String, String>,
    ) -> anyhow::Result<Self> {
        info!(ship = %ship_name, socket = %socket_path.display(), "Connecting to docking ship");

        // NOTE(review): the breaker is created fresh for each `connect` call and only
        // sees the (at most 3) attempts below. With failure_threshold(5) it presumably
        // can never open here, so the `is_open` check below is effectively inert unless
        // the breaker is later shared across calls. Confirm the intent.
        let circuit = Arc::new(RwLock::new(Self::create_circuit(ship_name)));
        let backoff = ExponentialBackoff::default()
            .max_attempts(3)
            .base_delay_ms(200)
            .max_delay_ms(2000);

        let mut last_error = None;
        let mut docked = None;

        for attempt in 1..=backoff.max_attempts {
            if circuit.read().await.is_open() {
                error!(ship = %ship_name, attempt, "Circuit breaker is open, refusing connection");
                return Err(anyhow::anyhow!(DockingError::CircuitOpen(
                    ship_name.to_string()
                )));
            }

            let result = async {
                let mut stream = Self::connect_socket(&socket_path).await?;
                info!(ship = %ship_name, attempt, "Connected to ship socket");
                let pending =
                    Self::perform_handshake(&mut stream, ship_name, config.clone()).await?;
                Ok::<(UnixStream, Vec<u8>), DockingError>((stream, pending))
            }
            .await;

            // Feed the outcome to the breaker; its own return value is not needed.
            {
                let mut breaker = circuit.write().await;
                if result.is_ok() {
                    breaker.call(|| Ok::<_, ()>(())).ok();
                } else {
                    breaker.call(|| Err::<(), _>(())).ok();
                }
            }

            match result {
                Ok(session) => {
                    docked = Some(session);
                    break;
                }
                Err(err) => {
                    let retryable = matches!(
                        err,
                        DockingError::ConnectTimeout(_)
                            | DockingError::HandshakeTimeout(_)
                            | DockingError::Io(_)
                    );
                    let give_up = !retryable || attempt >= backoff.max_attempts;
                    if !give_up {
                        // `ThreadRng` is not `Send`; keep it scoped to this block so it
                        // is not held across the sleep below.
                        let delay_ms = {
                            let mut rng = rand::rng();
                            backoff.delay(attempt, &mut rng)
                        };
                        if let Some(delay_ms) = delay_ms {
                            warn!(
                                ship = %ship_name,
                                attempt,
                                next_delay_ms = delay_ms,
                                error = %err,
                                "Docking attempt failed, retrying"
                            );
                            tokio::time::sleep(Duration::from_millis(delay_ms)).await;
                        }
                    }
                    last_error = Some(err);
                    if give_up {
                        break;
                    }
                }
            }
        }

        let Some((stream, pending)) = docked else {
            let err = last_error.unwrap_or_else(|| {
                DockingError::Protocol("Unknown error during connection".to_string())
            });
            error!(ship = %ship_name, error = %err, "All docking attempts failed");
            return Err(anyhow::anyhow!(err));
        };

        let (reader, writer) = stream.into_split();
        let (tx, rx) = mpsc::channel::<ToShip>(1024);
        let cargo_receivers: Arc<RwLock<HashMap<u32, mpsc::Sender<Vec<u8>>>>> =
            Arc::new(RwLock::new(HashMap::new()));

        tokio::spawn(Self::writer_task(ship_name.to_string(), writer, rx));
        tokio::spawn(Self::reader_task(
            ship_name.to_string(),
            reader,
            pending,
            Arc::clone(&cargo_receivers),
        ));

        Ok(Self {
            ship_name: ship_name.to_string(),
            socket_path,
            tx,
            cargo_receivers,
            circuit,
        })
    }

    /// Encodes each queued message and writes it to the ship.
    ///
    /// Ends when every sender of `rx` has been dropped or when a write fails.
    async fn writer_task(
        ship_name: String,
        mut writer: tokio::net::unix::OwnedWriteHalf,
        mut rx: mpsc::Receiver<ToShip>,
    ) {
        while let Some(msg) = rx.recv().await {
            let encoded = match msg {
                ToShip::Boarding(boarding) => encode_boarding(&boarding),
                ToShip::Cargo(cargo) => encode_cargo(&cargo),
                ToShip::Disembark(disembark) => encode_disembark(&disembark),
            };
            if let Err(e) = writer.write_all(&encoded).await {
                error!(ship = %ship_name, error = %e, "Failed to write to ship");
                break;
            }
        }
        debug!(ship = %ship_name, "Writer task ended");
    }

    /// Reads frames from the ship and dispatches them until EOF, a read error, or an
    /// undecodable header.
    ///
    /// `pending` holds bytes already received during the handshake; any complete frames
    /// in it are dispatched before the first read. When the task ends, every registered
    /// per-connection sender is dropped, so their receivers observe the channel closing
    /// instead of waiting forever.
    async fn reader_task(
        ship_name: String,
        mut reader: tokio::net::unix::OwnedReadHalf,
        mut pending: Vec<u8>,
        cargo_receivers: Arc<RwLock<HashMap<u32, mpsc::Sender<Vec<u8>>>>>,
    ) {
        let mut buf = vec![0u8; 64 * 1024];
        'session: loop {
            // Dispatch every complete frame already buffered before blocking on the socket.
            while pending.len() >= Self::HEADER_LEN {
                let (msg_type, payload_len) = match decode_header(&pending) {
                    Ok(header) => header,
                    Err(e) => {
                        // Frame boundaries are lost; anything further would be misparsed,
                        // and the undrained bytes would fail again on every read.
                        error!(ship = %ship_name, error = %e, "Failed to decode header");
                        break 'session;
                    }
                };
                let total_len = Self::HEADER_LEN + payload_len;
                if pending.len() < total_len {
                    // Wait for the rest of this frame.
                    break;
                }
                Self::dispatch_frame(
                    &ship_name,
                    msg_type,
                    &pending[Self::HEADER_LEN..total_len],
                    &cargo_receivers,
                )
                .await;
                pending.drain(..total_len);
            }

            match reader.read(&mut buf).await {
                Ok(0) => {
                    info!(ship = %ship_name, "Ship disconnected");
                    break;
                }
                Ok(n) => pending.extend_from_slice(&buf[..n]),
                Err(e) => {
                    error!(ship = %ship_name, error = %e, "Read error from ship");
                    break;
                }
            }
        }

        // Nothing will feed these channels any more. Dropping the senders lets each
        // connection's `recv()` return `None` rather than hang.
        cargo_receivers.write().await.clear();
        debug!(ship = %ship_name, "Reader task ended");
    }

    /// Routes one complete frame received from the ship.
    ///
    /// CARGO payloads are forwarded to the channel registered for their `conn_id`.
    /// Any other message type is logged and ignored.
    async fn dispatch_frame(
        ship_name: &str,
        msg_type: MessageType,
        payload: &[u8],
        cargo_receivers: &RwLock<HashMap<u32, mpsc::Sender<Vec<u8>>>>,
    ) {
        match msg_type {
            MessageType::Cargo => match decode_cargo(payload) {
                Ok(cargo) => {
                    // Clone the sender and release the map lock *before* awaiting `send`.
                    // The send can wait on a full channel, and holding the read guard
                    // meanwhile would block (un)registration; tokio's RwLock is
                    // write-preferring, so it would then block further readers as well.
                    let target = cargo_receivers.read().await.get(&cargo.conn_id).cloned();
                    match target {
                        Some(tx) => {
                            if let Err(e) = tx.send(cargo.data).await {
                                debug!(
                                    ship = %ship_name,
                                    conn_id = cargo.conn_id,
                                    error = %e,
                                    "Failed to forward cargo to connection"
                                );
                            }
                        }
                        None => warn!(
                            ship = %ship_name,
                            conn_id = cargo.conn_id,
                            "Received cargo for unknown connection"
                        ),
                    }
                }
                Err(e) => {
                    error!(ship = %ship_name, error = %e, "Failed to decode cargo");
                }
            },
            _ => {
                warn!(ship = %ship_name, msg_type = ?msg_type, "Unexpected message from ship");
            }
        }
    }

    /// Registers a channel (capacity 256) that receives CARGO payloads addressed to
    /// `conn_id`, and returns its receiving end.
    ///
    /// Registering an id that is already present replaces the previous sender. The
    /// returned receiver yields `None` after `unregister_connection(conn_id)` or after
    /// the reader task has ended.
    pub async fn register_connection(&self, conn_id: u32) -> mpsc::Receiver<Vec<u8>> {
        let (tx, rx) = mpsc::channel(256);
        self.cargo_receivers.write().await.insert(conn_id, tx);
        rx
    }

    /// Removes the CARGO channel for `conn_id`, if any.
    pub async fn unregister_connection(&self, conn_id: u32) {
        self.cargo_receivers.write().await.remove(&conn_id);
    }

    /// Queues a BOARDING message for the writer task, waiting if the queue
    /// (capacity 1024) is full.
    ///
    /// # Errors
    /// Fails if the writer task has stopped.
    pub async fn send_boarding(&self, boarding: Boarding) -> anyhow::Result<()> {
        self.tx.send(ToShip::Boarding(boarding)).await?;
        Ok(())
    }

    /// Queues a CARGO message for the writer task, waiting if the queue is full.
    ///
    /// # Errors
    /// Fails if the writer task has stopped.
    pub async fn send_cargo(&self, cargo: Cargo) -> anyhow::Result<()> {
        self.tx.send(ToShip::Cargo(cargo)).await?;
        Ok(())
    }

    /// Queues a DISEMBARK message for the writer task, waiting if the queue is full.
    ///
    /// # Errors
    /// Fails if the writer task has stopped.
    pub async fn send_disembark(&self, disembark: Disembark) -> anyhow::Result<()> {
        self.tx.send(ToShip::Disembark(disembark)).await?;
        Ok(())
    }

    /// Name of the ship this connector docked with.
    pub fn ship_name(&self) -> &str {
        &self.ship_name
    }

    /// Path of the Unix socket this connector connected to.
    pub fn socket_path(&self) -> &PathBuf {
        &self.socket_path
    }
}
impl std::fmt::Debug for DockingConnector {
    // Only the identifying fields are shown; the channels and the circuit breaker are
    // left out on purpose.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("DockingConnector");
        out.field("ship_name", &self.ship_name);
        out.field("socket_path", &self.socket_path);
        out.finish()
    }
}