1#![feature(path_file_prefix)]
2#![feature(result_flattening)]
3#![feature(return_position_impl_trait_in_trait)]
6
7use std::{
8 error::Error,
9 fs::{read_to_string, write, File},
10 io::BufReader,
11 net::SocketAddr,
12 path::Path,
13 process::Stdio,
14 time::SystemTime,
15};
16
17use axum::Router;
18use bearer::BearerAuth;
19use clap::{Parser, Subcommand};
20use console::{listen_for_commands, send_args_to_remote, ExecutableArgs};
21use hyper::server::{accept::Accept, Builder, conn::AddrIncoming};
22use hyper_rustls::TlsAcceptor;
23#[cfg(feature = "collect-certs")]
24use lers::solver::Http01Solver;
25use log::{info, warn};
26#[cfg(feature = "python")]
27use py::load_py_into_router;
28#[cfg(feature = "python")]
29use pyo3_asyncio::TaskLocals;
30use regex::RegexSet;
31use rustls::{PrivateKey, Certificate};
32use serde::Deserialize;
33use tokio::io::{AsyncRead, AsyncWrite};
34use tower::ServiceBuilder;
35use tower_http::{
36 auth::AsyncRequireAuthorizationLayer, compression::CompressionLayer, cors::CorsLayer,
37 trace::TraceLayer,
38};
39
40use crate::console::does_remote_exist;
41
42mod bearer;
43pub mod console;
44#[cfg(feature = "python")]
45mod py;
/// One-second delay constant, only compiled when both the `hot-reload` and
/// `python` features are enabled.
// NOTE(review): presumably the debounce interval used by the hot-reload code
// in the `py` module — it is not referenced in this file; confirm.
#[cfg(all(feature = "hot-reload", feature = "python"))]
const SYNC_CHANGES_DELAY: std::time::Duration = std::time::Duration::from_secs(1);

/// Task locals wrapping the asyncio event loop; populated exactly once by the
/// background Python thread started in `auto_main_inner`.
#[cfg(feature = "python")]
static PY_TASK_LOCALS: std::sync::OnceLock<TaskLocals> = std::sync::OnceLock::new();
53
54pub fn load_scripts_into_router(router: Router, path: &Path) -> Router {
55 #[cfg(feature = "python")]
56 {
57 let mut router = router;
58 #[cfg(feature = "hot-reload")]
59 {
60 use notify::Watcher;
61 let async_runtime = tokio::runtime::Handle::current();
62 let working_dir = path.canonicalize().unwrap().parent().unwrap().to_owned();
63 let mut watcher =
64 notify::recommended_watcher(move |res: Result<notify::Event, _>| match res {
65 Ok(event) => {
66 let _guard = async_runtime.enter();
67 let event = std::sync::Arc::new(event);
68 py::py_handle_notify_event(event.clone(), working_dir.clone());
69 }
70 Err(event) => log::error!("File Watcher Error: {event:?}"),
71 })
72 .expect("Filesystem notification should be available");
73
74 watcher
75 .watch(path, notify::RecursiveMode::Recursive)
76 .expect("Scripts folder should be watchable");
77
78 Box::leak(Box::new(watcher));
79 }
80
81 for result in path
82 .read_dir()
83 .expect("Scripts directory should be readable")
84 {
85 let entry = result.expect("Script or sub-directory should be readable");
86 let path = entry.path();
87 let file_type = entry
88 .file_type()
89 .expect("File type of script or sub-directory should be accessible");
90
91 if file_type.is_dir() {
92 router = load_scripts_into_router(router, &path);
93 } else if file_type.is_file() {
94 match path.extension().map(std::ffi::OsStr::to_str).flatten() {
95 #[cfg(feature = "python")]
96 Some("py") => router = load_py_into_router(router, &path),
97 _ => {}
98 }
99 } else {
100 panic!("Failed to get the file type of {entry:?}");
101 }
102 }
103
104 router
105 }
106
107 #[cfg(not(feature = "python"))]
108 {
109 let _path = path;
110 router
111 }
112}
113
114pub fn setup_logger(log_file_path: &str, log_level: &str) {
115 let log_level = if log_level.is_empty() {
116 log::LevelFilter::Info
117 } else {
118 log_level.parse().expect("Log Level should be valid")
119 };
120
121 let mut dispatch = fern::Dispatch::new()
122 .format(|out, message, record| {
123 out.finish(format_args!(
124 "[{} {} {}] {}",
125 humantime::format_rfc3339_seconds(SystemTime::now()),
126 record.level(),
127 record.target(),
128 message
129 ))
130 })
131 .level(log_level)
132 .chain(std::io::stdout());
133
134 if !log_file_path.is_empty() {
135 dispatch =
136 dispatch.chain(fern::log_file(log_file_path).expect("Log File should be writable"))
137 }
138
139 dispatch
140 .apply()
141 .expect("Logger should have initialized successfully");
142}
143
/// Converts a raw `u16` into an HTTP status code.
///
/// `f` builds the panic message and is only invoked when `code` is not a
/// valid status code, so callers can pass a closure that formats an
/// expensive message without paying for it on the success path.
///
/// # Panics
///
/// Panics with the message produced by `f` (followed by the conversion
/// error) when `code` is rejected by `StatusCode::from_u16`.
#[cfg(feature = "python")]
#[inline]
fn u16_to_status(code: u16, f: impl Fn() -> String) -> axum::http::StatusCode {
    axum::http::StatusCode::from_u16(code).unwrap_or_else(|err| panic!("{}: {err:?}", f()))
}
149
150#[derive(Deserialize)]
151pub struct HyperDomeConfig {
152 #[serde(default)]
153 cors_methods: Vec<String>,
154 #[serde(default)]
155 cors_origins: Vec<String>,
156 #[serde(default)]
157 api_token: String,
158 bind_address: SocketAddr,
159 #[serde(default)]
160 public_paths: Vec<String>,
161 #[serde(default)]
162 cert_path: String,
163 #[serde(default)]
164 key_path: String,
165 #[serde(default)]
166 email: String,
167 #[serde(default)]
168 domain_name: String,
169 #[serde(default)]
170 log_file_path: String,
171 #[serde(default)]
172 log_level: String,
173}
174
175impl HyperDomeConfig {
176 pub fn from_toml_file(path: &Path) -> Self {
177 let txt = read_to_string(path).expect(&format!("{path:?} should be readable"));
178 toml::from_str(&txt).expect(&format!("{path:?} should be valid toml"))
179 }
180}
181
182#[inline]
183pub async fn async_run_router<P, I>(server: Builder<I>, mut router: Router, config: HyperDomeConfig)
184where
185 P: ExecutableArgs,
186 I: Accept,
187 I::Error: Into<Box<dyn Error + Send + Sync>>,
188 I::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
189{
190 router = load_scripts_into_router(router, "scripts".as_ref());
191
192 router = router.layer(
193 ServiceBuilder::new()
194 .layer(CompressionLayer::new())
195 .layer(TraceLayer::new_for_http())
196 .layer(
197 CorsLayer::new()
198 .allow_methods(
199 config
200 .cors_methods
201 .into_iter()
202 .map(|x| {
203 x.parse()
204 .expect("CORS Method should be a valid HTTP Method")
205 })
206 .collect::<Vec<_>>(),
207 )
208 .allow_origin(
209 config
210 .cors_origins
211 .into_iter()
212 .map(|x| x.parse().expect("CORS Origin should be a valid origin"))
213 .collect::<Vec<_>>(),
214 ),
215 ),
216 );
217
218 if !config.api_token.is_empty() {
219 router = router.layer(AsyncRequireAuthorizationLayer::new(BearerAuth::new(
220 config.api_token.parse().expect("msg"),
221 RegexSet::new(config.public_paths).expect("msg"),
222 )));
223 }
224
225 server
226 .serve(router.into_make_service())
227 .with_graceful_shutdown(listen_for_commands::<P>())
228 .await
229 .unwrap();
230}
231
232#[derive(Parser)]
233#[command(author, version, about, long_about = None)]
234struct Args {
235 #[command(subcommand)]
236 command: Commands,
237}
238
239#[derive(Subcommand)]
240enum Commands {
241 Run {
242 #[arg(short, long)]
243 detached: bool,
244 },
245}
246
247pub fn auto_main<P: ExecutableArgs>(router: impl Fn() -> Router) {
248 let Ok(args) = Args::try_parse_from(std::env::args_os()) else {
249 send_args_to_remote();
250 return;
251 };
252
253 match args.command {
254 Commands::Run { detached } => {
255 if let Some(id) = does_remote_exist() {
256 println!("Remote already exists with process id: {id}");
257 return;
258 }
259 if detached {
260 let id = std::process::Command::new(
261 std::env::current_exe().expect("Current EXE name should be accessible"),
262 )
263 .arg("run")
264 .stdin(Stdio::null())
265 .stdout(Stdio::null())
266 .stderr(Stdio::null())
267 .spawn()
268 .expect("Child process should have spawned successfully")
269 .id();
270 println!("Process has spawned successfully with id: {id}");
271 return;
272 }
273 }
274 }
275
276 auto_main_inner::<P>(router());
277}
278
279#[tokio::main]
280async fn auto_main_inner<P: ExecutableArgs>(router: Router) {
281 console_subscriber::init();
282 let config = HyperDomeConfig::from_toml_file("hypermangle.toml".as_ref());
283 setup_logger(&config.log_file_path, &config.log_level);
284
285 #[cfg(feature = "python")]
286 std::thread::spawn(|| {
287 pyo3::Python::with_gil(|py| {
288 let signal_module = py.import("signal").unwrap();
290 signal_module
291 .call_method1(
292 "signal",
293 (
294 signal_module.getattr("SIGINT").unwrap(),
295 signal_module.getattr("SIG_DFL").unwrap(),
296 ),
297 )
298 .unwrap();
299
300 let event_loop = py
301 .import("asyncio")
302 .unwrap()
303 .call_method0("new_event_loop")
304 .unwrap();
305 PY_TASK_LOCALS
306 .set(pyo3_asyncio::TaskLocals::new(event_loop))
307 .unwrap();
308 event_loop.call_method0("run_forever").unwrap();
309 })
310 });
311
312 if !config.cert_path.is_empty() && !config.key_path.is_empty() {
313 let cert_path: &Path = config.cert_path.as_ref();
314 let key_path: &Path = config.key_path.as_ref();
315
316 if cert_path.exists() && key_path.exists() {
317 info!("Loading HTTP Certificates");
318 let file = File::open(cert_path).expect("Cert path should be readable");
319 let mut reader = BufReader::new(file);
320 let certs = rustls_pemfile::certs(&mut reader).expect("Cert file should be valid");
321 let certs: Vec<_> = certs.into_iter().map(Certificate).collect();
322
323 let file = File::open(&key_path).expect("Key path should be readable");
324 let mut reader = BufReader::new(file);
325 let mut keys =
326 rustls_pemfile::pkcs8_private_keys(&mut reader).expect("Key file should be valid");
327
328 let key = match keys.len() {
329 0 => panic!("No PKCS8-encoded private key found in key file"),
330 1 => PrivateKey(keys.remove(0)),
331 _ => panic!("More than one PKCS8-encoded private key found in key file"),
332 };
333
334 info!("HTTP Certificates successfully loaded");
335 let incoming = AddrIncoming::bind(&config.bind_address).unwrap();
336 async_run_router::<P, _>(
337 axum::Server::builder(
338 TlsAcceptor::builder()
339 .with_single_cert(certs, key)
340 .unwrap()
341 .with_all_versions_alpn()
342 .with_incoming(incoming)
343 ),
344 router,
345 config,
346 )
347 .await;
348 return;
349 } else {
350 #[cfg(feature = "collect-certs")]
351 if !cert_path.exists() && !key_path.exists() {
352 warn!("Acquiring HTTP Certificates");
353 macro_rules! unwrap {
354 ($result: expr) => {
355 match $result {
356 Ok(x) => x,
357 Err(e) => {
358 panic!("Error running LERS: {e}");
359 }
360 }
361 };
362 }
363
364 #[cfg(not(debug_assertions))]
365 const URL: &str = lers::LETS_ENCRYPT_PRODUCTION_URL;
366 #[cfg(debug_assertions)]
367 const URL: &str = lers::LETS_ENCRYPT_STAGING_URL;
368
369 if config.email.is_empty() {
370 panic!("Email not provided!");
371 }
372
373 let mut bind_address = config.bind_address;
374 bind_address.set_port(80);
375 let solver = Http01Solver::new();
376 let handle = unwrap!(solver.start(&bind_address));
377
378 let directory = unwrap!(
379 lers::Directory::builder(URL)
380 .http01_solver(Box::new(solver))
381 .build()
382 .await
383 );
384
385 let account = unwrap!(
386 directory
387 .account()
388 .terms_of_service_agreed(true)
389 .contacts(vec![format!("mailto:{}", config.email)])
390 .create_if_not_exists()
391 .await
392 );
393
394 let certificate = unwrap!(
395 account
396 .certificate()
397 .add_domain(&config.domain_name)
398 .obtain()
399 .await
400 );
401
402 tokio::spawn(handle.stop());
403
404 let certs: Vec<_> = certificate
405 .x509_chain()
406 .iter()
407 .map(|x| Certificate(x.to_der().unwrap()))
408 .collect();
409 let key = PrivateKey(certificate.private_key_to_der().unwrap());
410
411 write(cert_path, certificate.fullchain_to_pem().unwrap())
412 .expect("Cert file should be writable");
413 write(key_path, certificate.private_key_to_pem().unwrap())
414 .expect("Key file should be writable");
415
416 info!("Certificates successfully downloaded");
417
418 let incoming = AddrIncoming::bind(&config.bind_address).unwrap();
419 async_run_router::<P, _>(
420 axum::Server::builder(
421 TlsAcceptor::builder()
422 .with_single_cert(certs, key)
423 .unwrap()
424 .with_all_versions_alpn()
425 .with_incoming(incoming)
426 ),
427 router,
428 config,
429 )
430 .await;
431 return;
432 }
433
434 if !cert_path.exists() {
435 panic!("Certificate does not exist at the given path");
436 } else {
437 panic!("Private Key does not exist at the given path");
438 }
439 }
440 }
441
442 async_run_router::<P, _>(axum::Server::bind(&config.bind_address), router, config).await;
443}