use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::oneshot;
use tracing::{debug, info, warn};
use crate::pool::Pool;
use crate::proxy::StreamCounters;
/// Shared activity counters for a tunnel supervisor.
///
/// Every field is an `Arc<AtomicU64>`. Cloning a `SupervisorMetrics` therefore produces another
/// handle onto the *same* counters, not a copy of their current values. `Default` starts all
/// counters at zero.
#[derive(Clone, Debug, Default)]
pub struct SupervisorMetrics {
    /// Count of accepted inbound streams; `run` increments it once per `accept_bi` success.
    pub streams_total: Arc<AtomicU64>,
    /// Byte counter shared with stream proxies through `StreamCounters`
    /// (direction semantics are defined by `crate::proxy` — confirm there).
    pub bytes_in: Arc<AtomicU64>,
    /// Byte counter shared with stream proxies through `StreamCounters`
    /// (direction semantics are defined by `crate::proxy` — confirm there).
    pub bytes_out: Arc<AtomicU64>,
}
impl SupervisorMetrics {
fn stream_counters(&self) -> StreamCounters {
StreamCounters {
bytes_in: self.bytes_in.clone(),
bytes_out: self.bytes_out.clone(),
}
}
pub fn snapshot(&self) -> (u64, u64, u64) {
(
self.streams_total.load(Ordering::Relaxed),
self.bytes_in.load(Ordering::Relaxed),
self.bytes_out.load(Ordering::Relaxed),
)
}
}
/// Reason the tunnel supervisor loop in [`run`] stopped.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SupervisorExit {
    /// The shutdown channel fired (or its sender was dropped), and the connection was closed locally.
    Shutdown,
    /// Accepting the next stream failed, whether from a clean close or any other connection error.
    /// Presumably the caller re-establishes the tunnel ("supervisor cycling"); confirm with caller.
    ConnectionLost,
}
pub async fn run(
conn: quinn::Connection,
local_port: u16,
metrics: SupervisorMetrics,
pool: Arc<Pool>,
mut shutdown_rx: oneshot::Receiver<()>,
) -> SupervisorExit {
info!(local_port, "tunnel supervisor running");
let exit = loop {
tokio::select! {
biased;
_ = &mut shutdown_rx => {
debug!("supervisor: shutdown signal");
conn.close(0u32.into(), b"client shutdown");
break SupervisorExit::Shutdown;
}
accepted = conn.accept_bi() => {
match accepted {
Ok((send, recv)) => {
metrics.streams_total.fetch_add(1, Ordering::Relaxed);
let counters = metrics.stream_counters();
let pool = pool.clone();
tokio::spawn(async move {
if let Err(e) = crate::proxy::handle_inbound_stream(
local_port, send, recv, counters, pool,
)
.await
{
warn!(error = %e, "stream proxy failed");
}
});
}
Err(quinn::ConnectionError::ApplicationClosed(_))
| Err(quinn::ConnectionError::LocallyClosed) => {
debug!("connection closed cleanly");
break SupervisorExit::ConnectionLost;
}
Err(e) => {
warn!(error = %e, "accept_bi failed; supervisor cycling");
break SupervisorExit::ConnectionLost;
}
}
}
}
};
info!(?exit, "tunnel supervisor exited");
exit
}