Skip to main content

koi_udp/
lib.rs

1//! UDP datagram bridging over HTTP/SSE.
2//!
3//! Containers cannot bind host UDP sockets directly. This crate exposes a
4//! lease-based HTTP API that lets a containerised process:
5//!
6//! 1. **Bind** a host UDP port (creating a `UdpBinding`).
7//! 2. **Receive** datagrams via an SSE stream (`GET /v1/udp/recv/{id}`).
8//! 3. **Send** datagrams through the bound socket (`POST /v1/udp/send/{id}`).
9//! 4. **Heartbeat** to extend the lease (`POST /v1/udp/heartbeat/{id}`).
10//!
11//! Bindings expire after `lease_secs` without a heartbeat, at which point the
12//! reaper closes the socket. This prevents resource leaks if a container dies.
13//!
14//! Follows the same Core/Runtime pattern as `koi-health` and `koi-dns`.
15
16mod binding;
17pub mod http;
18
19use std::collections::HashMap;
20use std::net::SocketAddr;
21use std::sync::Arc;
22
23use chrono::{DateTime, Utc};
24use tokio::sync::{broadcast, RwLock};
25use tokio_util::sync::CancellationToken;
26use uuid::Uuid;
27
28pub use binding::ActiveBinding;
29
30// ── Public types ────────────────────────────────────────────────────
31
/// A datagram received on a bound socket, ready to be relayed over SSE.
///
/// Subscribers obtain these through the `broadcast::Receiver` returned by
/// `UdpRuntime::subscribe`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, utoipa::ToSchema)]
pub struct UdpDatagram {
    /// Id of the binding that received the datagram (presumably the same value
    /// as `BindingInfo::id`; populated by the relay in `binding` — confirm there).
    pub binding_id: String,
    /// Sender address as a string. NOTE(review): presumably the `SocketAddr`
    /// display form (`ip:port`); set by the relay in `binding`.
    pub src: String,
    /// Base64-encoded payload.
    pub payload: String,
    /// Receive timestamp in UTC. NOTE(review): set by the relay in `binding`;
    /// confirm whether it is taken at socket receive time or at relay time.
    pub received_at: DateTime<Utc>,
}
41
/// Request to send a datagram through a bound socket.
#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)]
pub struct UdpSendRequest {
    /// Destination address in `host:port` form. It is parsed directly as a
    /// `std::net::SocketAddr`, so `host` must be an IP literal: `ip:port` for
    /// IPv4, `[ip]:port` for IPv6. Hostnames are not resolved.
    pub dest: String,
    /// Base64-encoded payload, decoded with the standard (`+`/`/`, padded)
    /// alphabet.
    pub payload: String,
}
50
/// Request body for creating a new binding.
///
/// Every field has a serde default, so `{}` is a valid request.
#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)]
pub struct UdpBindRequest {
    /// Port to bind on the host (0 = OS-assigned).
    #[serde(default)]
    pub port: u16,
    /// Bind address, given as an IP literal (hostnames are not resolved).
    /// Default `0.0.0.0`.
    #[serde(default = "default_bind_addr")]
    pub addr: String,
    /// Lease duration in seconds. Default 300.
    ///
    /// `UdpRuntime::bind` clamps this to `MAX_LEASE_SECS` (24 hours). A binding
    /// whose lease elapses without a heartbeat is removed by the reaper.
    /// NOTE(review): 0 is accepted and makes the binding reapable after ~1 s.
    #[serde(default = "default_lease")]
    pub lease_secs: u64,
}
64
/// Serde default for `UdpBindRequest::addr`: the IPv4 wildcard address,
/// i.e. listen on every local interface.
fn default_bind_addr() -> String {
    String::from("0.0.0.0")
}
68
/// Serde default for `UdpBindRequest::lease_secs`: five minutes, in seconds.
fn default_lease() -> u64 {
    5 * 60
}
72
/// Upper bound on a binding's lease, in seconds (24 hours).
///
/// `UdpRuntime::bind` clamps requested leases to this value, so a client
/// cannot keep a host socket reserved indefinitely between heartbeats.
const MAX_LEASE_SECS: u64 = 24 * 60 * 60;
75
/// Metadata for a live binding (returned by status endpoint).
///
/// Produced by both `UdpRuntime::bind` (for the new binding) and
/// `UdpRuntime::status` (one per live binding).
#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)]
pub struct BindingInfo {
    /// Binding id (a UUIDv7 string); the `{id}` path segment of the
    /// `/v1/udp/.../{id}` routes.
    pub id: String,
    /// Local address the socket is actually bound to. When port 0 was
    /// requested, this carries the OS-assigned port.
    pub local_addr: String,
    /// When the binding was created.
    pub created_at: DateTime<Utc>,
    /// Most recent heartbeat; equal to `created_at` for a freshly created
    /// binding.
    pub last_heartbeat: DateTime<Utc>,
    /// Effective lease in seconds, after clamping to `MAX_LEASE_SECS`.
    pub lease_secs: u64,
}
85
86// ── Error type ──────────────────────────────────────────────────────
87
/// Errors returned by `UdpRuntime` operations.
#[derive(Debug, thiserror::Error)]
pub enum UdpError {
    /// No live binding has the given id (never created, already unbound, or
    /// removed by the lease reaper).
    #[error("binding not found: {0}")]
    NotFound(String),
    /// Socket-level failure, e.g. binding the host port, reading the local
    /// address, or sending a datagram.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    /// A bind address or send destination could not be parsed.
    #[error("invalid address: {0}")]
    InvalidAddr(String),
    /// A send payload was not valid base64.
    #[error("base64 decode error: {0}")]
    Base64(#[from] base64::DecodeError),
}
99
100// ── UdpRuntime ──────────────────────────────────────────────────────
101
/// Manages UDP socket bindings, datagram relay, and lease reaping.
///
/// Created with `UdpRuntime::new` (which spawns the reaper) and torn down with
/// `UdpRuntime::shutdown`.
pub struct UdpRuntime {
    /// Live bindings keyed by binding id (a UUIDv7 string). Shared with the
    /// reaper task.
    bindings: Arc<RwLock<HashMap<String, ActiveBinding>>>,
    /// Token supplied by the caller of `new`; clones go to the reaper and to
    /// every binding. Cancelled by `shutdown`.
    cancel: CancellationToken,
    /// Handle of the spawned reaper task. Kept but never awaited; dropping a
    /// Tokio `JoinHandle` detaches the task rather than aborting it.
    _reaper_handle: tokio::task::JoinHandle<()>,
}
108
109impl UdpRuntime {
110    /// Create a new runtime. Spawns a lease reaper task.
111    pub fn new(cancel: CancellationToken) -> Self {
112        let bindings: Arc<RwLock<HashMap<String, ActiveBinding>>> =
113            Arc::new(RwLock::new(HashMap::new()));
114
115        let reaper_bindings = bindings.clone();
116        let reaper_cancel = cancel.clone();
117        let reaper_handle = tokio::spawn(async move {
118            Self::reaper_loop(reaper_bindings, reaper_cancel).await;
119        });
120
121        Self {
122            bindings,
123            cancel,
124            _reaper_handle: reaper_handle,
125        }
126    }
127
128    /// Create a new UDP binding. Binds a socket and starts a relay task.
129    pub async fn bind(&self, req: UdpBindRequest) -> Result<BindingInfo, UdpError> {
130        let bind_addr: SocketAddr = format!("{}:{}", req.addr, req.port)
131            .parse()
132            .map_err(|e| UdpError::InvalidAddr(format!("{}", e)))?;
133
134        let socket = tokio::net::UdpSocket::bind(bind_addr).await?;
135        let local_addr = socket.local_addr()?;
136        let id = Uuid::now_v7().to_string();
137        let now = Utc::now();
138
139        let lease_secs = req.lease_secs.min(MAX_LEASE_SECS);
140
141        let active = ActiveBinding::new(
142            id.clone(),
143            socket,
144            local_addr,
145            now,
146            lease_secs,
147            self.cancel.clone(),
148        );
149
150        let info = BindingInfo {
151            id: id.clone(),
152            local_addr: local_addr.to_string(),
153            created_at: now,
154            last_heartbeat: now,
155            lease_secs,
156        };
157
158        self.bindings.write().await.insert(id, active);
159
160        tracing::info!(binding = %info.id, addr = %info.local_addr, "UDP binding created");
161        Ok(info)
162    }
163
164    /// Remove a binding and close its socket.
165    pub async fn unbind(&self, id: &str) -> Result<(), UdpError> {
166        let binding = self
167            .bindings
168            .write()
169            .await
170            .remove(id)
171            .ok_or_else(|| UdpError::NotFound(id.to_string()))?;
172
173        binding.shutdown();
174        tracing::info!(binding = %id, "UDP binding removed");
175        Ok(())
176    }
177
178    /// Subscribe to incoming datagrams for a binding.
179    pub async fn subscribe(&self, id: &str) -> Result<broadcast::Receiver<UdpDatagram>, UdpError> {
180        let bindings = self.bindings.read().await;
181        let binding = bindings
182            .get(id)
183            .ok_or_else(|| UdpError::NotFound(id.to_string()))?;
184        Ok(binding.subscribe())
185    }
186
187    /// Send a datagram through a binding's socket.
188    pub async fn send(&self, id: &str, req: UdpSendRequest) -> Result<usize, UdpError> {
189        use base64::Engine;
190
191        let dest: SocketAddr = req
192            .dest
193            .parse()
194            .map_err(|e| UdpError::InvalidAddr(format!("{}", e)))?;
195
196        let payload = base64::engine::general_purpose::STANDARD.decode(&req.payload)?;
197
198        let bindings = self.bindings.read().await;
199        let binding = bindings
200            .get(id)
201            .ok_or_else(|| UdpError::NotFound(id.to_string()))?;
202
203        let sent = binding.send_to(&payload, dest).await?;
204        Ok(sent)
205    }
206
207    /// Extend a binding's lease.
208    pub async fn heartbeat(&self, id: &str) -> Result<(), UdpError> {
209        let bindings = self.bindings.read().await;
210        let binding = bindings
211            .get(id)
212            .ok_or_else(|| UdpError::NotFound(id.to_string()))?;
213        binding.touch();
214        Ok(())
215    }
216
217    /// List all active bindings.
218    pub async fn status(&self) -> Vec<BindingInfo> {
219        let bindings = self.bindings.read().await;
220        bindings
221            .values()
222            .map(|b| BindingInfo {
223                id: b.id().to_string(),
224                local_addr: b.local_addr().to_string(),
225                created_at: b.created_at(),
226                last_heartbeat: b.last_heartbeat(),
227                lease_secs: b.lease_secs(),
228            })
229            .collect()
230    }
231
232    /// Background task that reaps expired leases every 30 seconds.
233    async fn reaper_loop(
234        bindings: Arc<RwLock<HashMap<String, ActiveBinding>>>,
235        cancel: CancellationToken,
236    ) {
237        let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
238
239        loop {
240            tokio::select! {
241                _ = cancel.cancelled() => break,
242                _ = interval.tick() => {
243                    let now = Utc::now();
244                    let mut map = bindings.write().await;
245                    let expired: Vec<String> = map
246                        .iter()
247                        .filter(|(_, b)| {
248                            let elapsed = now
249                                .signed_duration_since(b.last_heartbeat())
250                                .num_seconds();
251                            elapsed > b.lease_secs() as i64
252                        })
253                        .map(|(id, _)| id.clone())
254                        .collect();
255
256                    for id in expired {
257                        if let Some(binding) = map.remove(&id) {
258                            binding.shutdown();
259                            tracing::info!(binding = %id, "Reaped expired UDP binding");
260                        }
261                    }
262                }
263            }
264        }
265    }
266
267    /// Shut down the runtime - cancel reaper + close all bindings.
268    pub async fn shutdown(&self) {
269        self.cancel.cancel();
270        let mut map = self.bindings.write().await;
271        for (_, binding) in map.drain() {
272            binding.shutdown();
273        }
274        tracing::debug!("UDP runtime shut down");
275    }
276}
277
278// ── Capability trait ────────────────────────────────────────────────
279
280impl koi_common::capability::Capability for UdpRuntime {
281    fn name(&self) -> &str {
282        "udp"
283    }
284
285    fn status(&self) -> koi_common::capability::CapabilityStatus {
286        // status() is async but trait is sync - use try_read for non-blocking check.
287        let count = self.bindings.try_read().map(|b| b.len()).unwrap_or(0);
288
289        let summary = if count == 0 {
290            "no bindings".to_string()
291        } else {
292            format!("{count} binding{}", if count == 1 { "" } else { "s" })
293        };
294
295        koi_common::capability::CapabilityStatus {
296            name: "udp".to_string(),
297            summary,
298            healthy: true,
299        }
300    }
301}