1pub mod config;
4mod forwarder;
5pub mod http;
6mod listener;
7mod safety;
8
9use std::collections::HashMap;
10use std::sync::Arc;
11
12use tokio::sync::Mutex;
13use tokio_util::sync::CancellationToken;
14
15use koi_common::capability::{Capability, CapabilityStatus};
16
17pub use config::ProxyEntry;
18pub use safety::ensure_backend_allowed;
19
/// Errors produced by the proxy capability.
///
/// Every variant carries a message string; the `Display` text (generated by
/// `thiserror`) is a fixed per-variant prefix followed by that message.
#[derive(Debug, thiserror::Error)]
pub enum ProxyError {
    /// Configuration error; displayed as `proxy config error: <msg>`.
    #[error("proxy config error: {0}")]
    Config(String),

    /// I/O error; displayed as `proxy io error: <msg>`.
    #[error("proxy io error: {0}")]
    Io(String),

    /// Invalid configuration value; displayed as `proxy invalid config: <msg>`.
    #[error("proxy invalid config: {0}")]
    InvalidConfig(String),

    /// Forwarding error; displayed as `proxy forward error: <msg>`.
    #[error("proxy forward error: {0}")]
    Forward(String),

    /// Entry lookup failed; displayed as `proxy entry not found: <msg>`.
    /// NOTE(review): the payload is presumably the entry name — confirm at call sites.
    #[error("proxy entry not found: {0}")]
    NotFound(String),
}
37
/// Snapshot of one proxy listener tracked by [`ProxyRuntime`], as returned by
/// [`ProxyRuntime::status`].
///
/// Derives `serde::Serialize`; field declaration order is the serialized field order.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ProxyStatus {
    /// Name of the proxy entry.
    pub name: String,
    /// Port the entry's listener was configured to listen on.
    pub listen_port: u16,
    /// Backend configured for the entry.
    pub backend: String,
    /// The entry's `allow_remote` setting.
    pub allow_remote: bool,
    /// Whether the runtime considers this listener to be running.
    pub running: bool,
}
46
/// Owns the cached list of proxy entries loaded through the `config` module.
///
/// The cache is read with `entries()` and replaced by `reload`, `upsert` and
/// `remove`, each of which goes through the `config` module first.
pub struct ProxyCore {
    /// Cached entries, shared behind an async mutex.
    entries: Arc<Mutex<Vec<ProxyEntry>>>,
}
50
51impl ProxyCore {
52 pub fn new() -> Result<Self, ProxyError> {
53 let entries = config::load_entries()?;
54 Ok(Self {
55 entries: Arc::new(Mutex::new(entries)),
56 })
57 }
58
59 pub async fn entries(&self) -> Vec<ProxyEntry> {
60 self.entries.lock().await.clone()
61 }
62
63 pub async fn reload(&self) -> Result<Vec<ProxyEntry>, ProxyError> {
64 let entries = config::load_entries()?;
65 let mut guard = self.entries.lock().await;
66 *guard = entries.clone();
67 Ok(entries)
68 }
69
70 pub async fn upsert(&self, entry: ProxyEntry) -> Result<Vec<ProxyEntry>, ProxyError> {
71 let entries = config::upsert_entry(entry)?;
72 let mut guard = self.entries.lock().await;
73 *guard = entries.clone();
74 Ok(entries)
75 }
76
77 pub async fn remove(&self, name: &str) -> Result<Vec<ProxyEntry>, ProxyError> {
78 let entries = config::remove_entry(name)?;
79 let mut guard = self.entries.lock().await;
80 *guard = entries.clone();
81 Ok(entries)
82 }
83}
84
85impl Capability for ProxyCore {
86 fn name(&self) -> &str {
87 "proxy"
88 }
89
90 fn status(&self) -> CapabilityStatus {
91 CapabilityStatus {
92 name: "proxy".to_string(),
93 summary: "configured".to_string(),
94 healthy: true,
95 }
96 }
97}
98
/// A listener started by [`ProxyRuntime`].
struct ProxyInstance {
    /// Entry the listener was started with; compared against incoming entries
    /// to decide whether the listener must be restarted.
    entry: ProxyEntry,
    /// Token handed to the listener; cancelling it is how the runtime stops it.
    cancel: CancellationToken,
}
103
/// Runs one proxy listener per [`ProxyEntry`] and keeps them in sync with the
/// entries held by a [`ProxyCore`].
///
/// Clones share the same core and the same set of running instances.
pub struct ProxyRuntime {
    /// Source of the entries to run.
    core: Arc<ProxyCore>,
    /// Running listeners, keyed by entry name.
    instances: Arc<Mutex<HashMap<String, ProxyInstance>>>,
}
109
110impl ProxyRuntime {
111 pub fn new(core: Arc<ProxyCore>) -> Self {
112 Self {
113 core,
114 instances: Arc::new(Mutex::new(HashMap::new())),
115 }
116 }
117
118 pub fn core(&self) -> Arc<ProxyCore> {
119 Arc::clone(&self.core)
120 }
121
122 pub async fn start_all(&self) -> Result<(), ProxyError> {
123 let entries = self.core.entries().await;
124 self.apply_entries(entries).await
125 }
126
127 pub async fn reload(&self) -> Result<(), ProxyError> {
128 let entries = self.core.reload().await?;
129 self.apply_entries(entries).await
130 }
131
132 async fn apply_entries(&self, entries: Vec<ProxyEntry>) -> Result<(), ProxyError> {
133 let mut guard = self.instances.lock().await;
134 let mut seen = HashMap::new();
135
136 for entry in entries {
137 seen.insert(entry.name.clone(), entry.clone());
138 let entry_name = entry.name.clone();
139 let entry_name_for_task = entry_name.clone();
140 let needs_restart = match guard.get(&entry.name) {
141 Some(existing) => existing.entry != entry,
142 None => true,
143 };
144 if needs_restart {
145 if let Some(existing) = guard.remove(&entry.name) {
146 existing.cancel.cancel();
147 }
148 let cancel = CancellationToken::new();
149 let mut listener =
150 listener::ProxyListener::new(entry.clone(), cancel.clone()).await?;
151 let watch = listener.watch_certs().await;
152 if let Err(e) = watch {
153 tracing::warn!(error = %e, name = %entry.name, "Failed to watch certs");
154 }
155 tokio::spawn(async move {
156 if let Err(e) = listener.run().await {
157 tracing::error!(error = %e, name = %entry_name_for_task, "Proxy listener failed");
158 }
159 });
160 guard.insert(entry_name.clone(), ProxyInstance { entry, cancel });
161 }
162 }
163
164 let remove_names: Vec<String> = guard
165 .keys()
166 .filter(|name| !seen.contains_key(*name))
167 .cloned()
168 .collect();
169 for name in remove_names {
170 if let Some(instance) = guard.remove(&name) {
171 instance.cancel.cancel();
172 }
173 }
174
175 Ok(())
176 }
177
178 pub async fn stop_all(&self) {
179 let mut guard = self.instances.lock().await;
180 for instance in guard.values() {
181 instance.cancel.cancel();
182 }
183 guard.clear();
184 }
185
186 pub async fn status(&self) -> Vec<ProxyStatus> {
187 let guard = self.instances.lock().await;
188 guard
189 .values()
190 .map(|instance| ProxyStatus {
191 name: instance.entry.name.clone(),
192 listen_port: instance.entry.listen_port,
193 backend: instance.entry.backend.clone(),
194 allow_remote: instance.entry.allow_remote,
195 running: true,
196 })
197 .collect()
198 }
199}
200
201impl Clone for ProxyRuntime {
202 fn clone(&self) -> Self {
203 Self {
204 core: Arc::clone(&self.core),
205 instances: Arc::clone(&self.instances),
206 }
207 }
208}