1pub mod acme;
8mod forward;
9mod handler;
10pub mod rate_limit;
11mod routing;
12pub mod tls;
13mod websocket;
14
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::atomic::AtomicUsize;
18
19use hyper::Request;
20use hyper::body::Incoming;
21use hyper::server::conn::http1;
22use hyper::service::service_fn;
23use hyper_util::rt::TokioIo;
24use tokio::net::TcpListener;
25use tokio::sync::RwLock;
26use tracing::{debug, error, info, warn};
27
28use acme::AcmeManager;
29use handler::{handle_acme_challenge, handle_request};
30use rate_limit::RateLimiter;
31
/// One upstream backend that requests can be forwarded to.
///
/// Instances are stored in the proxy's route table, a
/// `HashMap<String, Vec<RouteTarget>>` shared behind `Arc<RwLock<..>>`.
/// NOTE(review): the map key is presumably the request host; confirm in `routing`.
#[derive(Clone, Debug)]
pub struct RouteTarget {
    /// Upstream address (NOTE(review): `host:port` vs. full URL is decided by
    /// the forwarding code; confirm).
    pub address: String,
    /// Name of the service this backend belongs to.
    pub service_name: String,
    /// Optional path pattern limiting which requests reach this target
    /// (presumably `None` means "any path"; confirm in `routing`).
    pub path_pattern: Option<String>,
    /// Relative weight, presumably used when choosing among several targets;
    /// confirm in `routing`.
    pub weight: u32,
}
47
/// Associates a request pattern with a WASM runtime that should handle it.
///
/// NOTE(review): the matching rules applied to `pattern` live outside this
/// module (see `handler`); confirm before relying on a specific syntax.
#[derive(Clone, Debug)]
pub struct WasmTrigger {
    /// Pattern that incoming requests are matched against.
    pub pattern: String,
    /// Identifier of the WASM runtime to invoke.
    pub runtime_id: String,
    /// Name of the service that owns this trigger.
    pub service_name: String,
}
58
/// Boxed, `Send` future produced by a [`WasmInvoker`]: resolves to the
/// invocation's `String` output, or a `String` error message.
pub type WasmInvokeFuture = std::pin::Pin<
    Box<dyn std::future::Future<Output = Result<String, String>> + Send>,
>;

/// Shareable callback that starts a WASM invocation from four `String`
/// arguments and returns a [`WasmInvokeFuture`].
///
/// NOTE(review): the meaning and order of the four arguments are defined by
/// the caller in `handler`, which is not visible here; document them there.
pub type WasmInvoker =
    Arc<dyn Fn(String, String, String, String) -> WasmInvokeFuture + Send + Sync>;
67
/// WASM trigger list shared across listeners and connection tasks.
///
/// `serve_loop` clones this `Arc` into every connection's service, so readers
/// take the async `RwLock` per use and updates are visible to new requests.
pub type SharedWasmTriggers = Arc<RwLock<Vec<WasmTrigger>>>;
70
71pub async fn run_proxy(
73 route_table: Arc<RwLock<HashMap<String, Vec<RouteTarget>>>>,
74 wasm_triggers: SharedWasmTriggers,
75 wasm_invoker: Option<WasmInvoker>,
76 port: u16,
77 tls_acceptor: Option<tokio_rustls::TlsAcceptor>,
78 acme_manager: Option<AcmeManager>,
79) -> anyhow::Result<()> {
80 let addr = format!("0.0.0.0:{port}");
81 let listener = TcpListener::bind(&addr).await?;
82 let proto = if tls_acceptor.is_some() {
83 "HTTPS"
84 } else {
85 "HTTP"
86 };
87 info!("Reverse proxy listening on {addr} ({proto})");
88
89 serve_loop(
90 listener,
91 route_table,
92 wasm_triggers,
93 wasm_invoker,
94 tls_acceptor,
95 acme_manager,
96 )
97 .await
98}
99
100pub type SharedCertResolver = Arc<acme::DynCertResolver>;
102
103pub async fn run_proxy_with_acme(
109 route_table: Arc<RwLock<HashMap<String, Vec<RouteTarget>>>>,
110 wasm_triggers: SharedWasmTriggers,
111 wasm_invoker: Option<WasmInvoker>,
112 acme_manager: AcmeManager,
113 domains: Vec<String>,
114) -> anyhow::Result<SharedCertResolver> {
115 let resolver = Arc::new(acme::DynCertResolver::new());
116
117 let acme_mgr = acme_manager.clone();
118 let routes_clone = route_table.clone();
119 let triggers_clone = wasm_triggers.clone();
120 let invoker_clone = wasm_invoker.clone();
121
122 let http_handle = tokio::spawn({
124 let acme = acme_mgr.clone();
125 let routes = routes_clone.clone();
126 let triggers = triggers_clone.clone();
127 let invoker = invoker_clone.clone();
128 async move {
129 if let Err(e) = run_proxy(routes, triggers, invoker, 80, None, Some(acme)).await {
130 error!("HTTP listener failed: {e}");
131 }
132 }
133 });
134
135 let resolver_clone = resolver.clone();
137 let https_handle = tokio::spawn(async move {
138 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
139
140 for domain in &domains {
142 if let Err(e) = acme_mgr
143 .ensure_cert_for_resolver(domain, &resolver_clone)
144 .await
145 {
146 error!(domain = %domain, error = %e, "Failed to provision cert");
147 }
148 }
149
150 let config = rustls::ServerConfig::builder()
152 .with_no_client_auth()
153 .with_cert_resolver(resolver_clone);
154
155 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
156 info!(
157 "Starting HTTPS with SNI resolver ({} domains)",
158 domains.len()
159 );
160
161 let routes = routes_clone;
162 let triggers = triggers_clone;
163 let invoker = invoker_clone;
164 if let Err(e) = run_proxy(
165 routes,
166 triggers,
167 invoker,
168 443,
169 Some(acceptor),
170 Some(acme_mgr),
171 )
172 .await
173 {
174 error!("HTTPS listener failed: {e}");
175 }
176 });
177
178 tokio::spawn(async move {
181 tokio::select! {
182 _ = http_handle => warn!("HTTP listener exited"),
183 _ = https_handle => warn!("HTTPS listener exited"),
184 }
185 });
186
187 Ok(resolver)
188}
189
190async fn serve_loop(
192 listener: TcpListener,
193 route_table: Arc<RwLock<HashMap<String, Vec<RouteTarget>>>>,
194 wasm_triggers: SharedWasmTriggers,
195 wasm_invoker: Option<WasmInvoker>,
196 tls_acceptor: Option<tokio_rustls::TlsAcceptor>,
197 acme_manager: Option<AcmeManager>,
198) -> anyhow::Result<()> {
199 let counter = Arc::new(AtomicUsize::new(0));
200 let client = Arc::new(
201 reqwest::Client::builder()
202 .no_proxy()
203 .build()
204 .expect("failed to build HTTP client"),
205 );
206 let acme = acme_manager.map(Arc::new);
207 let is_tls = tls_acceptor.is_some();
208 let rate_limiter = RateLimiter::new();
209
210 loop {
211 let (stream, peer) = match listener.accept().await {
212 Ok(conn) => conn,
213 Err(e) => {
214 warn!("Proxy accept error: {e}");
215 continue;
216 }
217 };
218
219 let routes = route_table.clone();
220 let triggers = wasm_triggers.clone();
221 let invoker = wasm_invoker.clone();
222 let counter = counter.clone();
223 let client = client.clone();
224 let acme = acme.clone();
225 let tls = tls_acceptor.clone();
226 let rl = rate_limiter.clone();
227
228 tokio::spawn(async move {
229 let service = service_fn(move |req: Request<Incoming>| {
230 let routes = routes.clone();
231 let triggers = triggers.clone();
232 let invoker = invoker.clone();
233 let counter = counter.clone();
234 let client = client.clone();
235 let acme = acme.clone();
236 let rl = rl.clone();
237 async move {
238 if let Some(resp) = handle_acme_challenge(&req, acme.as_deref()).await {
239 return Ok(resp);
240 }
241 handle_request(
242 req,
243 &routes,
244 &triggers,
245 invoker.as_ref(),
246 &counter,
247 &client,
248 is_tls,
249 &rl,
250 peer,
251 )
252 .await
253 }
254 });
255 if let Some(acceptor) = tls {
256 match acceptor.accept(stream).await {
257 Ok(tls_stream) => {
258 let io = TokioIo::new(tls_stream);
259 if let Err(e) = http1::Builder::new().serve_connection(io, service).await {
260 debug!("TLS proxy error from {peer}: {e}");
261 }
262 }
263 Err(e) => debug!("TLS handshake failed from {peer}: {e}"),
264 }
265 } else {
266 let io = TokioIo::new(stream);
267 if let Err(e) = http1::Builder::new().serve_connection(io, service).await {
268 debug!("Proxy connection error from {peer}: {e}");
269 }
270 }
271 });
272 }
273}