Skip to main content

koi_proxy/
lib.rs

1//! Koi Proxy - TLS-terminating reverse proxy (Phase 8).
2
3pub mod config;
4mod forwarder;
5pub mod http;
6mod listener;
7mod safety;
8
9use std::collections::HashMap;
10use std::sync::Arc;
11
12use tokio::sync::{broadcast, Mutex};
13use tokio_util::sync::CancellationToken;
14
15use koi_common::capability::{Capability, CapabilityStatus};
16
17pub use config::ProxyEntry;
18pub use safety::ensure_backend_allowed;
19
/// Maximum number of [`ProxyEvent`]s retained by the broadcast channel.
/// A receiver that falls further behind than this observes a "lagged"
/// error from `tokio::sync::broadcast` instead of the dropped events.
const BROADCAST_CHANNEL_CAPACITY: usize = 256;
22
23/// Events emitted by the proxy subsystem when entries change.
24#[derive(Debug, Clone)]
25pub enum ProxyEvent {
26    /// A proxy entry was added or updated.
27    EntryUpdated { entry: ProxyEntry },
28    /// A proxy entry was removed.
29    EntryRemoved { name: String },
30}
31
32#[derive(Debug, thiserror::Error)]
33pub enum ProxyError {
34    #[error("proxy config error: {0}")]
35    Config(String),
36
37    #[error("proxy io error: {0}")]
38    Io(String),
39
40    #[error("proxy invalid config: {0}")]
41    InvalidConfig(String),
42
43    #[error("proxy forward error: {0}")]
44    Forward(String),
45
46    #[error("proxy entry not found: {0}")]
47    NotFound(String),
48}
49
50#[derive(Debug, Clone, serde::Serialize)]
51pub struct ProxyStatus {
52    pub name: String,
53    pub listen_port: u16,
54    pub backend: String,
55    pub allow_remote: bool,
56    pub running: bool,
57}
58
59pub struct ProxyCore {
60    entries: Arc<Mutex<Vec<ProxyEntry>>>,
61    event_tx: broadcast::Sender<ProxyEvent>,
62    data_dir: Option<std::path::PathBuf>,
63}
64
65impl ProxyCore {
66    pub fn new() -> Result<Self, ProxyError> {
67        let entries = config::load_entries()?;
68        Ok(Self {
69            entries: Arc::new(Mutex::new(entries)),
70            event_tx: broadcast::channel(BROADCAST_CHANNEL_CAPACITY).0,
71            data_dir: None,
72        })
73    }
74
75    /// Create a ProxyCore that reads/writes config from a custom data directory.
76    pub fn with_data_dir(data_dir: &std::path::Path) -> Result<Self, ProxyError> {
77        let entries = config::load_entries_with_data_dir(Some(data_dir))?;
78        Ok(Self {
79            entries: Arc::new(Mutex::new(entries)),
80            event_tx: broadcast::channel(BROADCAST_CHANNEL_CAPACITY).0,
81            data_dir: Some(data_dir.to_path_buf()),
82        })
83    }
84
85    pub async fn entries(&self) -> Vec<ProxyEntry> {
86        self.entries.lock().await.clone()
87    }
88
89    pub async fn reload(&self) -> Result<Vec<ProxyEntry>, ProxyError> {
90        let data_dir = self.data_dir.clone();
91        let entries = tokio::task::spawn_blocking(move || {
92            config::load_entries_with_data_dir(data_dir.as_deref())
93        })
94        .await
95        .map_err(|e| ProxyError::Io(format!("config task: {e}")))??;
96        let mut guard = self.entries.lock().await;
97        *guard = entries.clone();
98        Ok(entries)
99    }
100
101    pub async fn upsert(&self, entry: ProxyEntry) -> Result<Vec<ProxyEntry>, ProxyError> {
102        let data_dir = self.data_dir.clone();
103        let entry_for_io = entry.clone();
104        let entries = tokio::task::spawn_blocking(move || {
105            config::upsert_entry_with_data_dir(entry_for_io, data_dir.as_deref())
106        })
107        .await
108        .map_err(|e| ProxyError::Io(format!("config task: {e}")))??;
109        let mut guard = self.entries.lock().await;
110        *guard = entries.clone();
111        let _ = self.event_tx.send(ProxyEvent::EntryUpdated { entry });
112        Ok(entries)
113    }
114
115    pub async fn remove(&self, name: &str) -> Result<Vec<ProxyEntry>, ProxyError> {
116        let data_dir = self.data_dir.clone();
117        let name_owned = name.to_string();
118        let entries = tokio::task::spawn_blocking(move || {
119            config::remove_entry_with_data_dir(&name_owned, data_dir.as_deref())
120        })
121        .await
122        .map_err(|e| ProxyError::Io(format!("config task: {e}")))??;
123        let mut guard = self.entries.lock().await;
124        *guard = entries.clone();
125        let _ = self.event_tx.send(ProxyEvent::EntryRemoved {
126            name: name.to_string(),
127        });
128        Ok(entries)
129    }
130
131    /// Subscribe to proxy events.
132    pub fn subscribe(&self) -> broadcast::Receiver<ProxyEvent> {
133        self.event_tx.subscribe()
134    }
135}
136
137impl Capability for ProxyCore {
138    fn name(&self) -> &str {
139        "proxy"
140    }
141
142    fn status(&self) -> CapabilityStatus {
143        CapabilityStatus {
144            name: "proxy".to_string(),
145            summary: "configured".to_string(),
146            healthy: true,
147        }
148    }
149}
150
151struct ProxyInstance {
152    entry: ProxyEntry,
153    cancel: CancellationToken,
154}
155
156/// Runtime controller for proxy listeners.
157pub struct ProxyRuntime {
158    core: Arc<ProxyCore>,
159    instances: Arc<Mutex<HashMap<String, ProxyInstance>>>,
160}
161
162impl ProxyRuntime {
163    pub fn new(core: Arc<ProxyCore>) -> Self {
164        Self {
165            core,
166            instances: Arc::new(Mutex::new(HashMap::new())),
167        }
168    }
169
170    pub fn core(&self) -> Arc<ProxyCore> {
171        Arc::clone(&self.core)
172    }
173
174    pub async fn start_all(&self) -> Result<(), ProxyError> {
175        let entries = self.core.entries().await;
176        self.apply_entries(entries).await
177    }
178
179    pub async fn reload(&self) -> Result<(), ProxyError> {
180        let entries = self.core.reload().await?;
181        self.apply_entries(entries).await
182    }
183
184    async fn apply_entries(&self, entries: Vec<ProxyEntry>) -> Result<(), ProxyError> {
185        let mut guard = self.instances.lock().await;
186        let mut seen = HashMap::new();
187
188        for entry in entries {
189            seen.insert(entry.name.clone(), entry.clone());
190            let entry_name = entry.name.clone();
191            let entry_name_for_task = entry_name.clone();
192            let needs_restart = match guard.get(&entry.name) {
193                Some(existing) => existing.entry != entry,
194                None => true,
195            };
196            if needs_restart {
197                if let Some(existing) = guard.remove(&entry.name) {
198                    existing.cancel.cancel();
199                }
200                let cancel = CancellationToken::new();
201                let mut listener =
202                    listener::ProxyListener::new(entry.clone(), cancel.clone()).await?;
203                let watch = listener.watch_certs().await;
204                if let Err(e) = watch {
205                    tracing::warn!(error = %e, name = %entry.name, "Failed to watch certs");
206                }
207                tokio::spawn(async move {
208                    if let Err(e) = listener.run().await {
209                        tracing::error!(error = %e, name = %entry_name_for_task, "Proxy listener failed");
210                    }
211                });
212                guard.insert(entry_name.clone(), ProxyInstance { entry, cancel });
213            }
214        }
215
216        let remove_names: Vec<String> = guard
217            .keys()
218            .filter(|name| !seen.contains_key(*name))
219            .cloned()
220            .collect();
221        for name in remove_names {
222            if let Some(instance) = guard.remove(&name) {
223                instance.cancel.cancel();
224            }
225        }
226
227        Ok(())
228    }
229
230    pub async fn stop_all(&self) {
231        let mut guard = self.instances.lock().await;
232        for instance in guard.values() {
233            instance.cancel.cancel();
234        }
235        guard.clear();
236    }
237
238    pub async fn status(&self) -> Vec<ProxyStatus> {
239        let guard = self.instances.lock().await;
240        guard
241            .values()
242            .map(|instance| ProxyStatus {
243                name: instance.entry.name.clone(),
244                listen_port: instance.entry.listen_port,
245                backend: instance.entry.backend.clone(),
246                allow_remote: instance.entry.allow_remote,
247                running: true,
248            })
249            .collect()
250    }
251}
252
253impl Clone for ProxyRuntime {
254    fn clone(&self) -> Self {
255        Self {
256            core: Arc::clone(&self.core),
257            instances: Arc::clone(&self.instances),
258        }
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ProxyCore` without touching the filesystem so the event
    /// plumbing (`subscribe` + `event_tx`) can be exercised directly.
    fn core_without_config() -> ProxyCore {
        ProxyCore {
            entries: Arc::new(Mutex::new(Vec::new())),
            event_tx: broadcast::channel(BROADCAST_CHANNEL_CAPACITY).0,
            data_dir: None,
        }
    }

    fn sample_entry() -> ProxyEntry {
        ProxyEntry {
            name: "test-svc".to_string(),
            listen_port: 9090,
            backend: "http://127.0.0.1:8080".to_string(),
            allow_remote: false,
        }
    }

    #[test]
    fn subscribe_receives_emitted_entry_updated() {
        let core = core_without_config();
        let mut rx = core.subscribe();

        core.event_tx
            .send(ProxyEvent::EntryUpdated {
                entry: sample_entry(),
            })
            .expect("one subscriber is listening");

        let event = rx.try_recv().expect("should receive event");
        match event {
            ProxyEvent::EntryUpdated { entry: received } => {
                assert_eq!(received, sample_entry());
                assert_eq!(received.name, "test-svc");
                assert_eq!(received.listen_port, 9090);
                assert_eq!(received.backend, "http://127.0.0.1:8080");
            }
            other => panic!("expected EntryUpdated, got {other:?}"),
        }
    }

    #[test]
    fn subscribe_receives_emitted_entry_removed() {
        let core = core_without_config();
        let mut rx = core.subscribe();

        core.event_tx
            .send(ProxyEvent::EntryRemoved {
                name: "rm-svc".to_string(),
            })
            .expect("one subscriber is listening");

        let event = rx.try_recv().expect("should receive event");
        match event {
            ProxyEvent::EntryRemoved { name } => {
                assert_eq!(name, "rm-svc");
            }
            other => panic!("expected EntryRemoved, got {other:?}"),
        }
    }

    #[test]
    fn multiple_subscribers_each_receive_event() {
        let core = core_without_config();
        let mut rx1 = core.subscribe();
        let mut rx2 = core.subscribe();

        core.event_tx
            .send(ProxyEvent::EntryRemoved {
                name: "multi".to_string(),
            })
            .expect("two subscribers are listening");

        for rx in [&mut rx1, &mut rx2] {
            match rx.try_recv().expect("each subscriber should receive the event") {
                ProxyEvent::EntryRemoved { name } => assert_eq!(name, "multi"),
                other => panic!("expected EntryRemoved, got {other:?}"),
            }
        }
    }

    #[test]
    fn subscriber_only_sees_events_sent_after_subscribing() {
        let core = core_without_config();

        // With no receivers the send fails; this is why ProxyCore ignores
        // the send result instead of propagating it.
        assert!(core
            .event_tx
            .send(ProxyEvent::EntryRemoved {
                name: "early".to_string(),
            })
            .is_err());

        let mut rx = core.subscribe();
        assert!(matches!(
            rx.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
    }
}
323}