hypermangle_core/
lib.rs

1#![feature(path_file_prefix)]
2#![feature(result_flattening)]
3// #![feature(never_type)]
4// #![feature(os_str_bytes)]
5#![feature(return_position_impl_trait_in_trait)]
6
7use std::{
8    error::Error,
9    fs::{read_to_string, write, File},
10    io::BufReader,
11    net::SocketAddr,
12    path::Path,
13    process::Stdio,
14    time::SystemTime,
15};
16
17use axum::Router;
18use bearer::BearerAuth;
19use clap::{Parser, Subcommand};
20use console::{listen_for_commands, send_args_to_remote, ExecutableArgs};
21use hyper::server::{accept::Accept, Builder, conn::AddrIncoming};
22use hyper_rustls::TlsAcceptor;
23#[cfg(feature = "collect-certs")]
24use lers::solver::Http01Solver;
25use log::{info, warn};
26#[cfg(feature = "python")]
27use py::load_py_into_router;
28#[cfg(feature = "python")]
29use pyo3_asyncio::TaskLocals;
30use regex::RegexSet;
31use rustls::{PrivateKey, Certificate};
32use serde::Deserialize;
33use tokio::io::{AsyncRead, AsyncWrite};
34use tower::ServiceBuilder;
35use tower_http::{
36    auth::AsyncRequireAuthorizationLayer, compression::CompressionLayer, cors::CorsLayer,
37    trace::TraceLayer,
38};
39
40use crate::console::does_remote_exist;
41
42mod bearer;
43pub mod console;
44#[cfg(feature = "python")]
45mod py;
46// mod tls;
47
/// One-second delay tied to syncing script changes when both `hot-reload` and
/// `python` are enabled.
///
/// NOTE(review): not referenced in this file; presumably consumed by the `py`
/// module — confirm its exact role there.
#[cfg(all(feature = "hot-reload", feature = "python"))]
const SYNC_CHANGES_DELAY: std::time::Duration = std::time::Duration::from_secs(1);

/// Task locals wrapping the asyncio event loop created on the dedicated Python
/// thread.
///
/// Empty until that thread calls `set` on it (once) right after creating the
/// loop; presumably read by the `py` module when scheduling coroutines.
#[cfg(feature = "python")]
static PY_TASK_LOCALS: std::sync::OnceLock<TaskLocals> = std::sync::OnceLock::new();
53
54pub fn load_scripts_into_router(router: Router, path: &Path) -> Router {
55    #[cfg(feature = "python")]
56    {
57        let mut router = router;
58        #[cfg(feature = "hot-reload")]
59        {
60            use notify::Watcher;
61            let async_runtime = tokio::runtime::Handle::current();
62            let working_dir = path.canonicalize().unwrap().parent().unwrap().to_owned();
63            let mut watcher =
64                notify::recommended_watcher(move |res: Result<notify::Event, _>| match res {
65                    Ok(event) => {
66                        let _guard = async_runtime.enter();
67                        let event = std::sync::Arc::new(event);
68                        py::py_handle_notify_event(event.clone(), working_dir.clone());
69                    }
70                    Err(event) => log::error!("File Watcher Error: {event:?}"),
71                })
72                .expect("Filesystem notification should be available");
73
74            watcher
75                .watch(path, notify::RecursiveMode::Recursive)
76                .expect("Scripts folder should be watchable");
77
78            Box::leak(Box::new(watcher));
79        }
80
81        for result in path
82            .read_dir()
83            .expect("Scripts directory should be readable")
84        {
85            let entry = result.expect("Script or sub-directory should be readable");
86            let path = entry.path();
87            let file_type = entry
88                .file_type()
89                .expect("File type of script or sub-directory should be accessible");
90
91            if file_type.is_dir() {
92                router = load_scripts_into_router(router, &path);
93            } else if file_type.is_file() {
94                match path.extension().map(std::ffi::OsStr::to_str).flatten() {
95                    #[cfg(feature = "python")]
96                    Some("py") => router = load_py_into_router(router, &path),
97                    _ => {}
98                }
99            } else {
100                panic!("Failed to get the file type of {entry:?}");
101            }
102        }
103
104        router
105    }
106
107    #[cfg(not(feature = "python"))]
108    {
109        let _path = path;
110        router
111    }
112}
113
114pub fn setup_logger(log_file_path: &str, log_level: &str) {
115    let log_level = if log_level.is_empty() {
116        log::LevelFilter::Info
117    } else {
118        log_level.parse().expect("Log Level should be valid")
119    };
120
121    let mut dispatch = fern::Dispatch::new()
122        .format(|out, message, record| {
123            out.finish(format_args!(
124                "[{} {} {}] {}",
125                humantime::format_rfc3339_seconds(SystemTime::now()),
126                record.level(),
127                record.target(),
128                message
129            ))
130        })
131        .level(log_level)
132        .chain(std::io::stdout());
133
134    if !log_file_path.is_empty() {
135        dispatch =
136            dispatch.chain(fern::log_file(log_file_path).expect("Log File should be writable"))
137    }
138
139    dispatch
140        .apply()
141        .expect("Logger should have initialized successfully");
142}
143
/// Converts a numeric HTTP status code into an `axum::http::StatusCode`.
///
/// `f` builds the panic message. It is invoked only when the conversion fails,
/// so callers pay for formatting only on the error path. The original
/// `expect(&f())` ran `f` on every call.
///
/// # Panics
///
/// Panics with `"<f()>: <error:?>"` (the same shape `expect` produces) when
/// `StatusCode::from_u16` rejects `code`.
#[cfg(feature = "python")]
#[inline]
fn u16_to_status(code: u16, f: impl Fn() -> String) -> axum::http::StatusCode {
    axum::http::StatusCode::from_u16(code).unwrap_or_else(|err| panic!("{}: {err:?}", f()))
}
149
150#[derive(Deserialize)]
151pub struct HyperDomeConfig {
152    #[serde(default)]
153    cors_methods: Vec<String>,
154    #[serde(default)]
155    cors_origins: Vec<String>,
156    #[serde(default)]
157    api_token: String,
158    bind_address: SocketAddr,
159    #[serde(default)]
160    public_paths: Vec<String>,
161    #[serde(default)]
162    cert_path: String,
163    #[serde(default)]
164    key_path: String,
165    #[serde(default)]
166    email: String,
167    #[serde(default)]
168    domain_name: String,
169    #[serde(default)]
170    log_file_path: String,
171    #[serde(default)]
172    log_level: String,
173}
174
175impl HyperDomeConfig {
176    pub fn from_toml_file(path: &Path) -> Self {
177        let txt = read_to_string(path).expect(&format!("{path:?} should be readable"));
178        toml::from_str(&txt).expect(&format!("{path:?} should be valid toml"))
179    }
180}
181
182#[inline]
183pub async fn async_run_router<P, I>(server: Builder<I>, mut router: Router, config: HyperDomeConfig)
184where
185    P: ExecutableArgs,
186    I: Accept,
187    I::Error: Into<Box<dyn Error + Send + Sync>>,
188    I::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
189{
190    router = load_scripts_into_router(router, "scripts".as_ref());
191
192    router = router.layer(
193        ServiceBuilder::new()
194            .layer(CompressionLayer::new())
195            .layer(TraceLayer::new_for_http())
196            .layer(
197                CorsLayer::new()
198                    .allow_methods(
199                        config
200                            .cors_methods
201                            .into_iter()
202                            .map(|x| {
203                                x.parse()
204                                    .expect("CORS Method should be a valid HTTP Method")
205                            })
206                            .collect::<Vec<_>>(),
207                    )
208                    .allow_origin(
209                        config
210                            .cors_origins
211                            .into_iter()
212                            .map(|x| x.parse().expect("CORS Origin should be a valid origin"))
213                            .collect::<Vec<_>>(),
214                    ),
215            ),
216    );
217
218    if !config.api_token.is_empty() {
219        router = router.layer(AsyncRequireAuthorizationLayer::new(BearerAuth::new(
220            config.api_token.parse().expect("msg"),
221            RegexSet::new(config.public_paths).expect("msg"),
222        )));
223    }
224
225    server
226        .serve(router.into_make_service())
227        .with_graceful_shutdown(listen_for_commands::<P>())
228        .await
229        .unwrap();
230}
231
// Top-level command line of the executable, parsed by `auto_main`.
//
// These are plain `//` comments on purpose. clap's derive turns `///` doc
// comments into `--help` text, and these notes are not meant to be part of it.
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
    // The subcommand to execute (see `Commands`).
    #[command(subcommand)]
    command: Commands,
}

// Subcommands handled locally. When `Args` fails to parse, `auto_main` calls
// `send_args_to_remote()` instead of using these.
#[derive(Subcommand)]
enum Commands {
    // `run`: start the server in this process, unless an instance is already
    // running.
    Run {
        // `-d` / `--detached`: re-launch this executable as `<exe> run` with
        // null stdio, print the child's process id, and return immediately.
        #[arg(short, long)]
        detached: bool,
    },
}
246
247pub fn auto_main<P: ExecutableArgs>(router: impl Fn() -> Router) {
248    let Ok(args) = Args::try_parse_from(std::env::args_os()) else {
249        send_args_to_remote();
250        return;
251    };
252
253    match args.command {
254        Commands::Run { detached } => {
255            if let Some(id) = does_remote_exist() {
256                println!("Remote already exists with process id: {id}");
257                return;
258            }
259            if detached {
260                let id = std::process::Command::new(
261                    std::env::current_exe().expect("Current EXE name should be accessible"),
262                )
263                .arg("run")
264                .stdin(Stdio::null())
265                .stdout(Stdio::null())
266                .stderr(Stdio::null())
267                .spawn()
268                .expect("Child process should have spawned successfully")
269                .id();
270                println!("Process has spawned successfully with id: {id}");
271                return;
272            }
273        }
274    }
275
276    auto_main_inner::<P>(router());
277}
278
279#[tokio::main]
280async fn auto_main_inner<P: ExecutableArgs>(router: Router) {
281    console_subscriber::init();
282    let config = HyperDomeConfig::from_toml_file("hypermangle.toml".as_ref());
283    setup_logger(&config.log_file_path, &config.log_level);
284
285    #[cfg(feature = "python")]
286    std::thread::spawn(|| {
287        pyo3::Python::with_gil(|py| {
288            // Disable Ctrl-C handling
289            let signal_module = py.import("signal").unwrap();
290            signal_module
291                .call_method1(
292                    "signal",
293                    (
294                        signal_module.getattr("SIGINT").unwrap(),
295                        signal_module.getattr("SIG_DFL").unwrap(),
296                    ),
297                )
298                .unwrap();
299
300            let event_loop = py
301                .import("asyncio")
302                .unwrap()
303                .call_method0("new_event_loop")
304                .unwrap();
305            PY_TASK_LOCALS
306                .set(pyo3_asyncio::TaskLocals::new(event_loop))
307                .unwrap();
308            event_loop.call_method0("run_forever").unwrap();
309        })
310    });
311
312    if !config.cert_path.is_empty() && !config.key_path.is_empty() {
313        let cert_path: &Path = config.cert_path.as_ref();
314        let key_path: &Path = config.key_path.as_ref();
315
316        if cert_path.exists() && key_path.exists() {
317            info!("Loading HTTP Certificates");
318            let file = File::open(cert_path).expect("Cert path should be readable");
319            let mut reader = BufReader::new(file);
320            let certs = rustls_pemfile::certs(&mut reader).expect("Cert file should be valid");
321            let certs: Vec<_> = certs.into_iter().map(Certificate).collect();
322
323            let file = File::open(&key_path).expect("Key path should be readable");
324            let mut reader = BufReader::new(file);
325            let mut keys =
326                rustls_pemfile::pkcs8_private_keys(&mut reader).expect("Key file should be valid");
327
328            let key = match keys.len() {
329                0 => panic!("No PKCS8-encoded private key found in key file"),
330                1 => PrivateKey(keys.remove(0)),
331                _ => panic!("More than one PKCS8-encoded private key found in key file"),
332            };
333
334            info!("HTTP Certificates successfully loaded");
335            let incoming = AddrIncoming::bind(&config.bind_address).unwrap();
336            async_run_router::<P, _>(
337                axum::Server::builder(
338                    TlsAcceptor::builder()
339                        .with_single_cert(certs, key)
340                        .unwrap()
341                        .with_all_versions_alpn()
342                        .with_incoming(incoming)
343                ),
344                router,
345                config,
346            )
347            .await;
348            return;
349        } else {
350            #[cfg(feature = "collect-certs")]
351            if !cert_path.exists() && !key_path.exists() {
352                warn!("Acquiring HTTP Certificates");
353                macro_rules! unwrap {
354                    ($result: expr) => {
355                        match $result {
356                            Ok(x) => x,
357                            Err(e) => {
358                                panic!("Error running LERS: {e}");
359                            }
360                        }
361                    };
362                }
363
364                #[cfg(not(debug_assertions))]
365                const URL: &str = lers::LETS_ENCRYPT_PRODUCTION_URL;
366                #[cfg(debug_assertions)]
367                const URL: &str = lers::LETS_ENCRYPT_STAGING_URL;
368
369                if config.email.is_empty() {
370                    panic!("Email not provided!");
371                }
372
373                let mut bind_address = config.bind_address;
374                bind_address.set_port(80);
375                let solver = Http01Solver::new();
376                let handle = unwrap!(solver.start(&bind_address));
377
378                let directory = unwrap!(
379                    lers::Directory::builder(URL)
380                        .http01_solver(Box::new(solver))
381                        .build()
382                        .await
383                );
384
385                let account = unwrap!(
386                    directory
387                        .account()
388                        .terms_of_service_agreed(true)
389                        .contacts(vec![format!("mailto:{}", config.email)])
390                        .create_if_not_exists()
391                        .await
392                );
393
394                let certificate = unwrap!(
395                    account
396                        .certificate()
397                        .add_domain(&config.domain_name)
398                        .obtain()
399                        .await
400                );
401
402                tokio::spawn(handle.stop());
403
404                let certs: Vec<_> = certificate
405                    .x509_chain()
406                    .iter()
407                    .map(|x| Certificate(x.to_der().unwrap()))
408                    .collect();
409                let key = PrivateKey(certificate.private_key_to_der().unwrap());
410
411                write(cert_path, certificate.fullchain_to_pem().unwrap())
412                    .expect("Cert file should be writable");
413                write(key_path, certificate.private_key_to_pem().unwrap())
414                    .expect("Key file should be writable");
415
416                info!("Certificates successfully downloaded");
417
418                let incoming = AddrIncoming::bind(&config.bind_address).unwrap();
419                async_run_router::<P, _>(
420                    axum::Server::builder(
421                        TlsAcceptor::builder()
422                            .with_single_cert(certs, key)
423                            .unwrap()
424                            .with_all_versions_alpn()
425                            .with_incoming(incoming)
426                    ),
427                    router,
428                    config,
429                )
430                .await;
431                return;
432            }
433
434            if !cert_path.exists() {
435                panic!("Certificate does not exist at the given path");
436            } else {
437                panic!("Private Key does not exist at the given path");
438            }
439        }
440    }
441
442    async_run_router::<P, _>(axum::Server::bind(&config.bind_address), router, config).await;
443}