faucet_server/server/router/mod.rs

1use std::{
2    collections::HashSet, ffi::OsStr, net::SocketAddr, num::NonZeroUsize, path::PathBuf, pin::pin,
3    sync::Arc,
4};
5
6use hyper::{body::Incoming, server::conn::http1, service::service_fn, Request, Uri};
7use hyper_util::rt::TokioIo;
8use tokio::net::TcpListener;
9use tokio_tungstenite::tungstenite::{http::uri::PathAndQuery, protocol::WebSocketConfig};
10
11use super::{onion::Service, FaucetServerBuilder, FaucetServerService};
12use crate::{
13    client::{
14        load_balancing::{IpExtractor, Strategy},
15        worker::{WorkerConfigs, WorkerType},
16        ExclusiveBody,
17    },
18    error::{FaucetError, FaucetResult},
19    shutdown::ShutdownSignal,
20};
21
/// Serde default for a route's `workdir`: the current directory, `.`.
fn default_workdir() -> PathBuf {
    ".".into()
}
25
/// Per-route server settings deserialized from one router route entry.
///
/// Each field is handed to the matching `FaucetServerBuilder` setter when the
/// router builds that route's server.
#[derive(serde::Deserialize)]
struct ReducedServerConfig {
    /// Optional load-balancing strategy for this route's workers.
    pub strategy: Option<Strategy>,
    /// Working directory for the route's server; `.` when the key is omitted.
    #[serde(default = "default_workdir")]
    pub workdir: PathBuf,
    /// Optional application directory forwarded to the builder's `app_dir`.
    pub app_dir: Option<String>,
    /// Number of workers to start; `NonZeroUsize` makes a value of `0` a
    /// deserialization error.
    pub workers: NonZeroUsize,
    /// Kind of worker to run for this route.
    pub server_type: WorkerType,
    /// Optional path forwarded to the builder's `qmd` setting
    /// (presumably a Quarto document — confirm against the builder).
    pub qmd: Option<PathBuf>,
    /// Optional limit forwarded to the builder's `max_rps`
    /// (presumably maximum requests per second — confirm against the builder).
    pub max_rps: Option<f64>,
}
37
/// One entry of the router configuration: a route string plus the server
/// settings used to serve it.
#[derive(serde::Deserialize)]
struct RouteConfig {
    /// Route to match against the request path. A value ending in `/` matches
    /// every path beneath it, and that prefix is stripped before forwarding.
    /// Any other value matches only that exact path, which is forwarded as `/`.
    /// Must be unique across all entries.
    route: String,
    /// The remaining keys of the same entry, flattened into the server settings.
    #[serde(flatten)]
    config: ReducedServerConfig,
}
44
/// Configuration for the router: the list of routes to serve. Each route gets
/// its own server and workers.
#[derive(serde::Deserialize)]
pub struct RouterConfig {
    /// Routes, deserialized from the `route` key. Requests are matched against
    /// them in this order, and the first match wins.
    route: Vec<RouteConfig>,
}
49
/// HTTP service that forwards each request to one of several per-route
/// services, chosen by matching the request URI against the route strings.
///
/// `routes` and `clients` are parallel: a request matching `routes[i]` is
/// forwarded to `clients[i]`. Cloning is cheap — it copies a slice reference
/// and bumps an `Arc` refcount.
#[derive(Clone)]
struct RouterService {
    /// Route strings, tried in order; the first match wins.
    routes: &'static [String],
    /// Backend service for each route, at the same index as its route string.
    clients: Arc<[FaucetServerService]>,
}
55
56fn strip_prefix_exact(path_and_query: &PathAndQuery, prefix: &str) -> Option<PathAndQuery> {
57    if path_and_query.path() == prefix {
58        return Some(match path_and_query.query() {
59            Some(query) => format!("/?{query}").parse().unwrap(),
60            None => "/".parse().unwrap(),
61        });
62    }
63    None
64}
65
66fn strip_prefix_relative(path_and_query: &PathAndQuery, prefix: &str) -> Option<PathAndQuery> {
67    // Try to strip the prefix. It is fails we short-circuit.
68    let after_prefix = path_and_query.path().strip_prefix(prefix)?;
69
70    let start_slash = after_prefix.starts_with('/');
71
72    Some(match (start_slash, path_and_query.query()) {
73        (true, None) => after_prefix.parse().unwrap(),
74        (true, Some(query)) => format!("{after_prefix}?{query}").parse().unwrap(),
75        (false, None) => format!("/{after_prefix}").parse().unwrap(),
76        (false, Some(query)) => format!("/{after_prefix}?{query}").parse().unwrap(),
77    })
78}
79
80fn strip_prefix(uri: &Uri, prefix: &str) -> Option<Uri> {
81    let path_and_query = uri.path_and_query()?;
82
83    let new_path_and_query = match prefix.ends_with('/') {
84        true => strip_prefix_relative(path_and_query, prefix)?,
85        false => strip_prefix_exact(path_and_query, prefix)?,
86    };
87
88    let mut parts = uri.clone().into_parts();
89    parts.path_and_query = Some(new_path_and_query);
90
91    Some(Uri::from_parts(parts).unwrap())
92}
93
94impl Service<hyper::Request<Incoming>> for RouterService {
95    type Error = FaucetError;
96    type Response = hyper::Response<ExclusiveBody>;
97    async fn call(
98        &self,
99        mut req: hyper::Request<Incoming>,
100        ip_addr: Option<std::net::IpAddr>,
101    ) -> Result<Self::Response, Self::Error> {
102        let mut client = None;
103        for i in 0..self.routes.len() {
104            let route = &self.routes[i];
105            if let Some(new_uri) = strip_prefix(req.uri(), route) {
106                client = Some(&self.clients[i]);
107                *req.uri_mut() = new_uri;
108                break;
109            }
110        }
111        match client {
112            None => Ok(hyper::Response::builder()
113                .status(404)
114                .body(ExclusiveBody::plain_text("404 not found"))
115                .expect("Response should build")),
116            Some(client) => client.call(req, ip_addr).await,
117        }
118    }
119}
120
121impl RouterConfig {
122    async fn into_service(
123        self,
124        rscript: impl AsRef<OsStr>,
125        quarto: impl AsRef<OsStr>,
126        ip_from: IpExtractor,
127        shutdown: &'static ShutdownSignal,
128        websocket_config: &'static WebSocketConfig,
129    ) -> FaucetResult<(RouterService, Vec<WorkerConfigs>)> {
130        let mut all_workers = Vec::with_capacity(self.route.len());
131        let mut routes = Vec::with_capacity(self.route.len());
132        let mut clients = Vec::with_capacity(self.route.len());
133        let mut routes_set = HashSet::with_capacity(self.route.len());
134        for route_conf in self.route.into_iter() {
135            let route = route_conf.route;
136            if !routes_set.insert(route.clone()) {
137                return Err(FaucetError::DuplicateRoute(route));
138            }
139            let (client, workers) = FaucetServerBuilder::new()
140                .workdir(route_conf.config.workdir)
141                .server_type(route_conf.config.server_type)
142                .strategy(route_conf.config.strategy)
143                .rscript(&rscript)
144                .quarto(&quarto)
145                .qmd(route_conf.config.qmd)
146                .workers(route_conf.config.workers.get())
147                .extractor(ip_from)
148                .app_dir(route_conf.config.app_dir)
149                .route(route.clone())
150                .max_rps(route_conf.config.max_rps)
151                .build()?
152                .extract_service(shutdown, websocket_config)
153                .await?;
154            routes.push(route);
155            all_workers.push(workers);
156            clients.push(client);
157        }
158        let routes = routes.leak();
159        let clients = clients.into();
160        let service = RouterService { clients, routes };
161        Ok((service, all_workers))
162    }
163}
164
165impl RouterConfig {
166    pub async fn run(
167        self,
168        rscript: impl AsRef<OsStr>,
169        quarto: impl AsRef<OsStr>,
170        ip_from: IpExtractor,
171        addr: SocketAddr,
172        shutdown: &'static ShutdownSignal,
173        websocket_config: &'static WebSocketConfig,
174    ) -> FaucetResult<()> {
175        let (service, all_workers) = self
176            .into_service(rscript, quarto, ip_from, shutdown, websocket_config)
177            .await?;
178        // Bind to the port and listen for incoming TCP connections
179        let listener = TcpListener::bind(addr).await?;
180        log::info!(target: "faucet", "Listening on http://{addr}");
181        let main_loop = || async {
182            loop {
183                match listener.accept().await {
184                    Err(e) => {
185                        log::error!(target: "faucet", "Unable to accept TCP connection: {e}");
186                        return;
187                    }
188                    Ok((tcp, client_addr)) => {
189                        let tcp = TokioIo::new(tcp);
190                        log::debug!(target: "faucet", "Accepted TCP connection from {client_addr}");
191
192                        let service = service.clone();
193
194                        tokio::task::spawn(async move {
195                            let mut conn = http1::Builder::new()
196                                .serve_connection(
197                                    tcp,
198                                    service_fn(|req: Request<Incoming>| {
199                                        service.call(req, Some(client_addr.ip()))
200                                    }),
201                                )
202                                .with_upgrades();
203
204                            let conn = pin!(&mut conn);
205
206                            tokio::select! {
207                                result = conn => {
208                                    if let Err(e) = result {
209                                        log::error!(target: "faucet", "Connection error: {e:?}");
210                                    }
211                                }
212                                _ = shutdown.wait() => ()
213                            }
214                        });
215                    }
216                }
217            }
218        };
219
220        // Race the shutdown vs the main loop
221        tokio::select! {
222            _ = shutdown.wait() => (),
223            _ = main_loop() => (),
224        }
225
226        // Kill child process
227        for w in all_workers.iter().flat_map(|ws| &ws.workers) {
228            w.wait_until_done().await;
229        }
230
231        FaucetResult::Ok(())
232    }
233}