1pub mod config;
4mod forwarder;
5pub mod http;
6mod listener;
7mod safety;
8
9use std::collections::HashMap;
10use std::sync::Arc;
11
12use tokio::sync::{broadcast, Mutex};
13use tokio_util::sync::CancellationToken;
14
15use koi_common::capability::{Capability, CapabilityStatus};
16
17pub use config::ProxyEntry;
18pub use safety::ensure_backend_allowed;
19
/// Capacity of the [`ProxyEvent`] broadcast channel created by each `ProxyCore`.
///
/// Per tokio's `broadcast` semantics, a receiver that falls more than this many events
/// behind observes a `Lagged` error on its next receive instead of the dropped events.
const BROADCAST_CHANNEL_CAPACITY: usize = 256;
22
23#[derive(Debug, Clone)]
25pub enum ProxyEvent {
26 EntryUpdated { entry: ProxyEntry },
28 EntryRemoved { name: String },
30}
31
32#[derive(Debug, thiserror::Error)]
33pub enum ProxyError {
34 #[error("proxy config error: {0}")]
35 Config(String),
36
37 #[error("proxy io error: {0}")]
38 Io(String),
39
40 #[error("proxy invalid config: {0}")]
41 InvalidConfig(String),
42
43 #[error("proxy forward error: {0}")]
44 Forward(String),
45
46 #[error("proxy entry not found: {0}")]
47 NotFound(String),
48}
49
50#[derive(Debug, Clone, serde::Serialize)]
51pub struct ProxyStatus {
52 pub name: String,
53 pub listen_port: u16,
54 pub backend: String,
55 pub allow_remote: bool,
56 pub running: bool,
57}
58
/// Store for proxy entries: an in-memory cache of the persisted configuration plus a
/// broadcast channel that announces changes made through this core.
pub struct ProxyCore {
    /// Cached entries, replaced wholesale after every load or write through `config`.
    entries: Arc<Mutex<Vec<ProxyEntry>>>,
    /// Sending half of the change-event channel; receivers come from `subscribe`.
    event_tx: broadcast::Sender<ProxyEvent>,
    /// Data directory handed to the `config::*_with_data_dir` functions; `None` when the
    /// core was built with `ProxyCore::new`.
    data_dir: Option<std::path::PathBuf>,
}
64
65impl ProxyCore {
66 pub fn new() -> Result<Self, ProxyError> {
67 let entries = config::load_entries()?;
68 Ok(Self {
69 entries: Arc::new(Mutex::new(entries)),
70 event_tx: broadcast::channel(BROADCAST_CHANNEL_CAPACITY).0,
71 data_dir: None,
72 })
73 }
74
75 pub fn with_data_dir(data_dir: &std::path::Path) -> Result<Self, ProxyError> {
77 let entries = config::load_entries_with_data_dir(Some(data_dir))?;
78 Ok(Self {
79 entries: Arc::new(Mutex::new(entries)),
80 event_tx: broadcast::channel(BROADCAST_CHANNEL_CAPACITY).0,
81 data_dir: Some(data_dir.to_path_buf()),
82 })
83 }
84
85 pub async fn entries(&self) -> Vec<ProxyEntry> {
86 self.entries.lock().await.clone()
87 }
88
89 pub async fn reload(&self) -> Result<Vec<ProxyEntry>, ProxyError> {
90 let data_dir = self.data_dir.clone();
91 let entries = tokio::task::spawn_blocking(move || {
92 config::load_entries_with_data_dir(data_dir.as_deref())
93 })
94 .await
95 .map_err(|e| ProxyError::Io(format!("config task: {e}")))??;
96 let mut guard = self.entries.lock().await;
97 *guard = entries.clone();
98 Ok(entries)
99 }
100
101 pub async fn upsert(&self, entry: ProxyEntry) -> Result<Vec<ProxyEntry>, ProxyError> {
102 let data_dir = self.data_dir.clone();
103 let entry_for_io = entry.clone();
104 let entries = tokio::task::spawn_blocking(move || {
105 config::upsert_entry_with_data_dir(entry_for_io, data_dir.as_deref())
106 })
107 .await
108 .map_err(|e| ProxyError::Io(format!("config task: {e}")))??;
109 let mut guard = self.entries.lock().await;
110 *guard = entries.clone();
111 let _ = self.event_tx.send(ProxyEvent::EntryUpdated { entry });
112 Ok(entries)
113 }
114
115 pub async fn remove(&self, name: &str) -> Result<Vec<ProxyEntry>, ProxyError> {
116 let data_dir = self.data_dir.clone();
117 let name_owned = name.to_string();
118 let entries = tokio::task::spawn_blocking(move || {
119 config::remove_entry_with_data_dir(&name_owned, data_dir.as_deref())
120 })
121 .await
122 .map_err(|e| ProxyError::Io(format!("config task: {e}")))??;
123 let mut guard = self.entries.lock().await;
124 *guard = entries.clone();
125 let _ = self.event_tx.send(ProxyEvent::EntryRemoved {
126 name: name.to_string(),
127 });
128 Ok(entries)
129 }
130
131 pub fn subscribe(&self) -> broadcast::Receiver<ProxyEvent> {
133 self.event_tx.subscribe()
134 }
135}
136
137impl Capability for ProxyCore {
138 fn name(&self) -> &str {
139 "proxy"
140 }
141
142 fn status(&self) -> CapabilityStatus {
143 CapabilityStatus {
144 name: "proxy".to_string(),
145 summary: "configured".to_string(),
146 healthy: true,
147 }
148 }
149}
150
/// Bookkeeping for one started proxy listener.
struct ProxyInstance {
    /// Entry the listener was started from; compared against incoming config to decide
    /// whether a restart is needed.
    entry: ProxyEntry,
    /// Token passed to the listener at construction; cancelled when the instance is
    /// replaced, removed, or stopped.
    cancel: CancellationToken,
}
155
/// Runs one proxy listener per entry and keeps that set in sync with a `ProxyCore`.
/// Clones share the same core and the same instance table.
pub struct ProxyRuntime {
    /// Source of the entries to run.
    core: Arc<ProxyCore>,
    /// Started instances keyed by entry name.
    instances: Arc<Mutex<HashMap<String, ProxyInstance>>>,
}
161
162impl ProxyRuntime {
163 pub fn new(core: Arc<ProxyCore>) -> Self {
164 Self {
165 core,
166 instances: Arc::new(Mutex::new(HashMap::new())),
167 }
168 }
169
170 pub fn core(&self) -> Arc<ProxyCore> {
171 Arc::clone(&self.core)
172 }
173
174 pub async fn start_all(&self) -> Result<(), ProxyError> {
175 let entries = self.core.entries().await;
176 self.apply_entries(entries).await
177 }
178
179 pub async fn reload(&self) -> Result<(), ProxyError> {
180 let entries = self.core.reload().await?;
181 self.apply_entries(entries).await
182 }
183
184 async fn apply_entries(&self, entries: Vec<ProxyEntry>) -> Result<(), ProxyError> {
185 let mut guard = self.instances.lock().await;
186 let mut seen = HashMap::new();
187
188 for entry in entries {
189 seen.insert(entry.name.clone(), entry.clone());
190 let entry_name = entry.name.clone();
191 let entry_name_for_task = entry_name.clone();
192 let needs_restart = match guard.get(&entry.name) {
193 Some(existing) => existing.entry != entry,
194 None => true,
195 };
196 if needs_restart {
197 if let Some(existing) = guard.remove(&entry.name) {
198 existing.cancel.cancel();
199 }
200 let cancel = CancellationToken::new();
201 let mut listener =
202 listener::ProxyListener::new(entry.clone(), cancel.clone()).await?;
203 let watch = listener.watch_certs().await;
204 if let Err(e) = watch {
205 tracing::warn!(error = %e, name = %entry.name, "Failed to watch certs");
206 }
207 tokio::spawn(async move {
208 if let Err(e) = listener.run().await {
209 tracing::error!(error = %e, name = %entry_name_for_task, "Proxy listener failed");
210 }
211 });
212 guard.insert(entry_name.clone(), ProxyInstance { entry, cancel });
213 }
214 }
215
216 let remove_names: Vec<String> = guard
217 .keys()
218 .filter(|name| !seen.contains_key(*name))
219 .cloned()
220 .collect();
221 for name in remove_names {
222 if let Some(instance) = guard.remove(&name) {
223 instance.cancel.cancel();
224 }
225 }
226
227 Ok(())
228 }
229
230 pub async fn stop_all(&self) {
231 let mut guard = self.instances.lock().await;
232 for instance in guard.values() {
233 instance.cancel.cancel();
234 }
235 guard.clear();
236 }
237
238 pub async fn status(&self) -> Vec<ProxyStatus> {
239 let guard = self.instances.lock().await;
240 guard
241 .values()
242 .map(|instance| ProxyStatus {
243 name: instance.entry.name.clone(),
244 listen_port: instance.entry.listen_port,
245 backend: instance.entry.backend.clone(),
246 allow_remote: instance.entry.allow_remote,
247 running: true,
248 })
249 .collect()
250 }
251}
252
253impl Clone for ProxyRuntime {
254 fn clone(&self) -> Self {
255 Self {
256 core: Arc::clone(&self.core),
257 instances: Arc::clone(&self.instances),
258 }
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a core directly, bypassing config I/O, so the tests exercise the real
    /// `ProxyCore::subscribe` / `event_tx` pair instead of a standalone channel.
    fn core_without_config() -> ProxyCore {
        ProxyCore {
            entries: Arc::new(Mutex::new(Vec::new())),
            event_tx: broadcast::channel(BROADCAST_CHANNEL_CAPACITY).0,
            data_dir: None,
        }
    }

    fn sample_entry() -> ProxyEntry {
        ProxyEntry {
            name: "test-svc".to_string(),
            listen_port: 9090,
            backend: "http://127.0.0.1:8080".to_string(),
            allow_remote: false,
        }
    }

    #[test]
    fn subscribe_receives_emitted_entry_updated() {
        let core = core_without_config();
        let mut rx = core.subscribe();

        core.event_tx
            .send(ProxyEvent::EntryUpdated {
                entry: sample_entry(),
            })
            .expect("a subscriber exists");

        match rx.try_recv().expect("should receive event") {
            ProxyEvent::EntryUpdated { entry: received } => {
                assert_eq!(received.name, "test-svc");
                assert_eq!(received.listen_port, 9090);
                assert_eq!(received.backend, "http://127.0.0.1:8080");
                assert!(!received.allow_remote);
            }
            other => panic!("expected EntryUpdated, got {other:?}"),
        }
    }

    #[test]
    fn subscribe_receives_emitted_entry_removed() {
        let core = core_without_config();
        let mut rx = core.subscribe();

        core.event_tx
            .send(ProxyEvent::EntryRemoved {
                name: "rm-svc".to_string(),
            })
            .expect("a subscriber exists");

        match rx.try_recv().expect("should receive event") {
            ProxyEvent::EntryRemoved { name } => assert_eq!(name, "rm-svc"),
            other => panic!("expected EntryRemoved, got {other:?}"),
        }
    }

    #[test]
    fn multiple_subscribers_each_receive_event() {
        let core = core_without_config();
        let mut rx1 = core.subscribe();
        let mut rx2 = core.subscribe();

        let delivered = core
            .event_tx
            .send(ProxyEvent::EntryRemoved {
                name: "multi".to_string(),
            })
            .expect("subscribers exist");
        assert_eq!(delivered, 2);

        assert!(matches!(
            rx1.try_recv(),
            Ok(ProxyEvent::EntryRemoved { ref name }) if name == "multi"
        ));
        assert!(matches!(
            rx2.try_recv(),
            Ok(ProxyEvent::EntryRemoved { ref name }) if name == "multi"
        ));
    }

    #[test]
    fn subscriber_only_sees_events_sent_after_subscribing() {
        let core = core_without_config();
        let _early = core.subscribe();
        let _ = core.event_tx.send(ProxyEvent::EntryRemoved {
            name: "before".to_string(),
        });

        let mut late = core.subscribe();
        assert!(matches!(
            late.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
    }

    #[test]
    fn error_messages_carry_prefix_and_detail() {
        assert_eq!(
            ProxyError::NotFound("svc".to_string()).to_string(),
            "proxy entry not found: svc"
        );
        assert_eq!(
            ProxyError::Config("bad".to_string()).to_string(),
            "proxy config error: bad"
        );
        assert_eq!(
            ProxyError::Io("config task: x".to_string()).to_string(),
            "proxy io error: config task: x"
        );
    }
}