use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use cellos_core::CloudEventV1;
#[cfg(target_os = "linux")]
use super::DnsProxyConfig;
use super::{DnsProxyStats, DnsQueryEmitter};
/// Read timeout applied to the proxy's UDP listener socket.
///
/// NOTE(review): presumably this bounds how long the receive loop can block
/// before it re-checks the shutdown flag (the "timeout-based shutdown"
/// fallback that `signal_proxy_shutdown` logs about) — confirm against
/// `run_one_shot`.
#[cfg(target_os = "linux")]
const LISTENER_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(200);
/// [`DnsQueryEmitter`] that forwards events to async event sinks by spawning
/// tasks onto a captured tokio runtime handle. This lets the synchronous
/// `emit` be called from plain OS threads such as the DNS proxy thread.
pub struct EventSinkEmitter {
/// Runtime onto which every emit is spawned as a detached task.
pub runtime_handle: tokio::runtime::Handle,
/// Primary destination; receives every event.
pub sink: Arc<dyn cellos_core::ports::EventSink>,
/// Optional second destination (named for JSONL output); when set, each
/// event is also emitted here on its own task.
pub jsonl_sink: Option<Arc<dyn cellos_core::ports::EventSink>>,
}
impl EventSinkEmitter {
    /// Builds an emitter bound to the tokio runtime of the calling context.
    ///
    /// `sink` receives every event; `jsonl_sink`, when present, receives a
    /// copy of each one as well.
    ///
    /// # Panics
    ///
    /// Panics if called outside a tokio runtime context, because there is no
    /// runtime handle to capture.
    pub fn capture_current(
        sink: Arc<dyn cellos_core::ports::EventSink>,
        jsonl_sink: Option<Arc<dyn cellos_core::ports::EventSink>>,
    ) -> Self {
        Self {
            runtime_handle: tokio::runtime::Handle::try_current().expect(
                "EventSinkEmitter::capture_current called outside a tokio runtime context",
            ),
            sink,
            jsonl_sink,
        }
    }
}
impl DnsQueryEmitter for EventSinkEmitter {
fn emit(&self, event: CloudEventV1) {
let sink = self.sink.clone();
let jsonl = self.jsonl_sink.clone();
let event_for_jsonl = event.clone();
self.runtime_handle.spawn(async move {
if let Err(e) = sink.emit(&event).await {
tracing::warn!(
target: "cellos.supervisor.dns_proxy",
error = %e,
"primary sink emit failed for dns_query event"
);
}
});
if let Some(j) = jsonl {
self.runtime_handle.spawn(async move {
if let Err(e) = j.emit(&event_for_jsonl).await {
tracing::warn!(
target: "cellos.supervisor.dns_proxy",
error = %e,
"jsonl sink emit failed for dns_query event"
);
}
});
}
}
}
/// Handle to a DNS proxy started by `spawn_proxy_in_netns`.
pub struct DnsProxyHandle {
/// Shutdown flag shared with the proxy's receive loops (it is passed to
/// `run_one_shot` / `run_tcp_one_shot`). Presumably storing `true` asks them
/// to stop — confirm against those helpers.
pub shutdown: Arc<AtomicBool>,
/// Address the UDP listener actually bound, as reported by `local_addr`.
/// If port 0 was requested, this holds the resolved port.
pub listen_addr: SocketAddr,
/// Proxy thread; joining it yields the UDP loop's stats. `join` takes it,
/// so this is `None` afterwards.
#[cfg(target_os = "linux")]
pub thread: Option<std::thread::JoinHandle<DnsProxyStats>>,
/// Copy of `DnsProxyConfig::upstream_resolver_id` for this proxy.
pub upstream_resolver_id: String,
}
impl DnsProxyHandle {
    /// Waits for the proxy thread to finish and returns its UDP stats.
    ///
    /// Returns `None` if there is no thread to join (already joined or never
    /// recorded) or if the thread panicked. A panic is logged at `warn`
    /// rather than propagated.
    #[cfg(target_os = "linux")]
    pub fn join(&mut self) -> Option<DnsProxyStats> {
        let worker = self.thread.take()?;
        worker.join().map_or_else(
            |_panic| {
                tracing::warn!(
                    target: "cellos.supervisor.dns_proxy",
                    "DNS proxy thread panicked on join"
                );
                None
            },
            Some,
        )
    }

    /// Off Linux no proxy thread is ever spawned, so there is nothing to join.
    #[cfg(not(target_os = "linux"))]
    pub fn join(&mut self) -> Option<DnsProxyStats> {
        None
    }
}
/// Best-effort nudge that wakes a proxy listener blocked on `listen_addr`.
///
/// Sends one zero-length UDP datagram to `listen_addr` from an ephemeral
/// loopback socket of the same address family. Any failure is logged at
/// `debug` and ignored; the listener's read timeout is then the only bound
/// on shutdown latency.
///
/// NOTE(review): `spawn_proxy_in_netns` binds its listener inside the child's
/// network namespace, but this datagram is sent from the calling thread's
/// namespace. Confirm the wake-up actually reaches that listener.
pub fn signal_proxy_shutdown(listen_addr: SocketAddr) {
    let loopback_any_port = match listen_addr {
        SocketAddr::V6(_) => "[::1]:0",
        SocketAddr::V4(_) => "127.0.0.1:0",
    };
    match std::net::UdpSocket::bind(loopback_any_port) {
        Err(e) => {
            tracing::debug!(
                target: "cellos.supervisor.dns_proxy",
                error = %e,
                "wake-up socket bind failed; falling back to timeout-based shutdown"
            );
        }
        Ok(waker) => {
            if let Err(e) = waker.send_to(&[], listen_addr) {
                tracing::debug!(
                    target: "cellos.supervisor.dns_proxy",
                    error = %e,
                    addr = %listen_addr,
                    "wake-up packet send failed; falling back to timeout-based shutdown"
                );
            }
        }
    }
}
/// Starts the DNS proxy on a dedicated OS thread inside the network namespace
/// of `child_pid`.
///
/// The thread does the following, in order:
/// 1. Joins `/proc/{child_pid}/ns/net` via `setns(2)`.
/// 2. Binds the UDP listener at `listen_addr` and an ephemeral UDP socket for
///    upstream traffic.
/// 3. Best-effort binds a TCP listener on the same address and port the UDP
///    listener actually got.
/// 4. Reports readiness and runs the UDP loop (`super::run_one_shot`).
///
/// If the TCP listener bound, a second thread runs `super::run_tcp_one_shot`.
/// That thread is joined before the proxy thread returns its UDP stats.
/// `upstream_addr` is only used here to pick the upstream socket's address
/// family.
///
/// # Errors
///
/// Fails if any of the following happens:
/// - the namespace file cannot be opened;
/// - the thread cannot be spawned;
/// - a mandatory setup step in the namespace fails (setns, UDP binds, read
///   timeout, `local_addr`);
/// - readiness is not reported within two seconds.
#[cfg(target_os = "linux")]
pub fn spawn_proxy_in_netns(
    child_pid: u32,
    cfg: DnsProxyConfig,
    listen_addr: SocketAddr,
    upstream_addr: SocketAddr,
    emitter: Arc<dyn DnsQueryEmitter>,
    shutdown: Arc<AtomicBool>,
) -> std::io::Result<DnsProxyHandle> {
    use std::fs::File;

    let netns_path = format!("/proc/{child_pid}/ns/net");
    let netns_file = File::open(&netns_path)
        .map_err(|e| std::io::Error::new(e.kind(), format!("open netns at {netns_path}: {e}")))?;
    let upstream_resolver_id = cfg.upstream_resolver_id.clone();
    let shutdown_for_thread = Arc::clone(&shutdown);
    let (ready_tx, ready_rx) = std::sync::mpsc::channel::<std::io::Result<SocketAddr>>();

    let thread = std::thread::Builder::new()
        .name(format!("cellos-dns-proxy-{child_pid}"))
        .spawn(move || {
            // `netns_file` is owned by this closure, so it stays open across
            // the setns call made by the helper.
            let NetnsSockets {
                listener,
                upstream,
                actual_listen,
                tcp_listener,
            } = match enter_netns_and_bind(&netns_file, child_pid, listen_addr, upstream_addr) {
                Ok(sockets) => sockets,
                Err(e) => {
                    let _ = ready_tx.send(Err(e));
                    return DnsProxyStats::default();
                }
            };
            // The receiver is gone only if the spawning call already returned
            // (e.g. its ready wait timed out); don't start serving then.
            if ready_tx.send(Ok(actual_listen)).is_err() {
                return DnsProxyStats::default();
            }

            let tcp_worker = tcp_listener.and_then(|listener_tcp| {
                spawn_tcp_worker(
                    child_pid,
                    &cfg,
                    listener_tcp,
                    &upstream,
                    &emitter,
                    &shutdown_for_thread,
                )
            });

            let udp_stats = super::run_one_shot(
                &cfg,
                &listener,
                &upstream,
                emitter.as_ref(),
                &shutdown_for_thread,
            )
            .unwrap_or_else(|e| {
                tracing::warn!(
                    target: "cellos.supervisor.dns_proxy",
                    error = %e,
                    "proxy recv loop returned with I/O error"
                );
                DnsProxyStats::default()
            });

            if let Some(worker) = tcp_worker {
                if worker.join().is_err() {
                    tracing::warn!(
                        target: "cellos.supervisor.dns_proxy",
                        "DNS proxy TCP thread panicked on join"
                    );
                }
            }
            udp_stats
        })?;

    // NOTE(review): on timeout the proxy thread is detached, not stopped. If it
    // reports readiness in the short window before `ready_rx` is dropped, it
    // keeps serving until `shutdown` is set — confirm callers set the flag
    // when this returns an error.
    let bound_addr = ready_rx
        .recv_timeout(std::time::Duration::from_secs(2))
        .map_err(|e| std::io::Error::other(format!("proxy thread ready timeout: {e}")))??;
    Ok(DnsProxyHandle {
        shutdown,
        listen_addr: bound_addr,
        thread: Some(thread),
        upstream_resolver_id,
    })
}

/// Sockets the proxy thread binds after entering the cell's network namespace.
#[cfg(target_os = "linux")]
struct NetnsSockets {
    /// UDP listener that clients query; has `LISTENER_READ_TIMEOUT` set.
    listener: std::net::UdpSocket,
    /// Ephemeral UDP socket used for upstream traffic.
    upstream: std::net::UdpSocket,
    /// Address `listener` actually bound (port resolved if 0 was requested).
    actual_listen: SocketAddr,
    /// TCP listener on `actual_listen`, or `None` if that bind failed.
    tcp_listener: Option<std::net::TcpListener>,
}

/// Moves the calling thread into the netns behind `netns_file` and binds the
/// proxy's sockets there. Every step is mandatory except the TCP bind.
#[cfg(target_os = "linux")]
fn enter_netns_and_bind(
    netns_file: &std::fs::File,
    child_pid: u32,
    listen_addr: SocketAddr,
    upstream_addr: SocketAddr,
) -> std::io::Result<NetnsSockets> {
    use std::os::unix::io::AsRawFd;

    // SAFETY: `netns_file` is borrowed for the whole call, so the fd passed to
    // setns(2) is open and valid. setns takes no pointers and only
    // reassociates the calling thread (the dedicated proxy thread) with the
    // target network namespace.
    let setns_rc = unsafe { libc::setns(netns_file.as_raw_fd(), libc::CLONE_NEWNET) };
    if setns_rc != 0 {
        let err = std::io::Error::last_os_error();
        return Err(std::io::Error::new(
            err.kind(),
            format!("setns(CLONE_NEWNET) for pid={child_pid}: {err}"),
        ));
    }

    let listener = std::net::UdpSocket::bind(listen_addr).map_err(|e| {
        std::io::Error::new(
            e.kind(),
            format!("bind listener at {listen_addr} in cell netns: {e}"),
        )
    })?;
    listener.set_read_timeout(Some(LISTENER_READ_TIMEOUT))?;

    let upstream_bind = if upstream_addr.is_ipv6() {
        "[::]:0"
    } else {
        "0.0.0.0:0"
    };
    let upstream = std::net::UdpSocket::bind(upstream_bind).map_err(|e| {
        std::io::Error::new(e.kind(), format!("bind upstream socket in cell netns: {e}"))
    })?;

    let actual_listen = listener.local_addr()?;

    // Bind TCP to the address the UDP listener actually got. If `listen_addr`
    // asked for an ephemeral port (0), binding `listen_addr` again would give
    // TCP a different port than the one reported back to the caller.
    let tcp_listener = match std::net::TcpListener::bind(actual_listen) {
        Ok(l) => Some(l),
        Err(e) => {
            tracing::warn!(
                target: "cellos.supervisor.dns_proxy",
                error = %e,
                addr = %actual_listen,
                "TCP listener bind FAILED in cell netns — continuing UDP-only"
            );
            None
        }
    };

    Ok(NetnsSockets {
        listener,
        upstream,
        actual_listen,
        tcp_listener,
    })
}

/// Starts the DNS-over-TCP accept loop (`super::run_tcp_one_shot`) on its own
/// thread. Returns `None`, after logging, if the upstream socket cannot be
/// cloned or the thread cannot be spawned; the proxy then serves UDP only.
#[cfg(target_os = "linux")]
fn spawn_tcp_worker(
    child_pid: u32,
    cfg: &DnsProxyConfig,
    listener_tcp: std::net::TcpListener,
    upstream_sock: &std::net::UdpSocket,
    emitter: &Arc<dyn DnsQueryEmitter>,
    shutdown: &Arc<AtomicBool>,
) -> Option<std::thread::JoinHandle<()>> {
    // NOTE(review): `try_clone` dups the UDP path's upstream socket, so both
    // loops read replies from the same socket and either may receive a reply
    // meant for the other — confirm the run_* helpers tolerate that.
    let upstream_clone = match upstream_sock.try_clone() {
        Ok(sock) => sock,
        Err(e) => {
            tracing::warn!(
                target: "cellos.supervisor.dns_proxy",
                error = %e,
                "TCP upstream socket clone FAILED — TCP path will not start"
            );
            return None;
        }
    };
    let cfg_tcp = cfg.clone();
    let emitter_tcp: Arc<dyn DnsQueryEmitter> = Arc::clone(emitter);
    let shutdown_tcp = Arc::clone(shutdown);
    let upstream_tcp = Arc::new(upstream_clone);
    let spawned = std::thread::Builder::new()
        .name(format!("cellos-dns-proxy-tcp-{child_pid}"))
        .spawn(move || {
            if let Err(e) = super::run_tcp_one_shot(
                &cfg_tcp,
                &listener_tcp,
                upstream_tcp,
                emitter_tcp,
                &shutdown_tcp,
            ) {
                tracing::warn!(
                    target: "cellos.supervisor.dns_proxy",
                    error = %e,
                    "TCP proxy accept loop returned with I/O error"
                );
            }
        });
    match spawned {
        Ok(handle) => Some(handle),
        Err(e) => {
            tracing::warn!(
                target: "cellos.supervisor.dns_proxy",
                error = %e,
                "TCP proxy thread spawn FAILED — continuing UDP-only"
            );
            None
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;
    use std::sync::Mutex;

    /// Test sink that only counts how many events it was handed.
    struct CountingSink {
        count: Mutex<u64>,
    }

    impl CountingSink {
        fn new() -> Arc<Self> {
            Arc::new(Self {
                count: Mutex::new(0),
            })
        }

        fn count(&self) -> u64 {
            *self.count.lock().unwrap()
        }
    }

    #[async_trait::async_trait]
    impl cellos_core::ports::EventSink for CountingSink {
        async fn emit(&self, _event: &CloudEventV1) -> Result<(), cellos_core::error::CellosError> {
            *self.count.lock().unwrap() += 1;
            Ok(())
        }
    }

    fn sample_event() -> CloudEventV1 {
        CloudEventV1 {
            specversion: "1.0".into(),
            id: "evt-1".into(),
            source: "test".into(),
            ty: "test.event".into(),
            datacontenttype: Some("application/json".into()),
            data: Some(serde_json::json!({"k": "v"})),
            time: Some(chrono::Utc::now().to_rfc3339()),
            traceparent: None,
        }
    }

    /// Polls `sink` for up to ~500 ms until it has seen at least `expected`
    /// events; the caller asserts on the final count.
    async fn wait_for_count(sink: &CountingSink, expected: u64) {
        for _ in 0..50 {
            if sink.count() >= expected {
                return;
            }
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn event_sink_emitter_dispatches_to_runtime() {
        let sink = CountingSink::new();
        let emitter = EventSinkEmitter::capture_current(sink.clone(), None);
        emitter.emit(sample_event());
        wait_for_count(&sink, 1).await;
        assert_eq!(
            sink.count(),
            1,
            "sink should have received one event via the runtime handle"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn event_sink_emitter_fans_out_to_jsonl_sink() {
        let primary = CountingSink::new();
        let jsonl = CountingSink::new();
        let emitter = EventSinkEmitter::capture_current(
            primary.clone(),
            Some(jsonl.clone() as Arc<dyn cellos_core::ports::EventSink>),
        );
        emitter.emit(sample_event());
        wait_for_count(&primary, 1).await;
        wait_for_count(&jsonl, 1).await;
        assert_eq!(primary.count(), 1, "primary sink should see exactly one event");
        assert_eq!(jsonl.count(), 1, "jsonl sink should see exactly one event");
    }

    #[test]
    fn signal_proxy_shutdown_does_not_panic_on_unbound_addr() {
        let addr: SocketAddr = "127.0.0.1:1".parse().unwrap();
        signal_proxy_shutdown(addr);
    }

    #[test]
    fn handle_join_returns_none_when_no_thread() {
        let mut h = DnsProxyHandle {
            shutdown: Arc::new(AtomicBool::new(false)),
            listen_addr: "127.0.0.1:0".parse().unwrap(),
            #[cfg(target_os = "linux")]
            thread: None,
            upstream_resolver_id: "test".into(),
        };
        assert!(h.join().is_none());
        h.shutdown.store(true, Ordering::SeqCst);
        assert!(h.shutdown.load(Ordering::SeqCst));
    }
}