//! rs-mock-server 0.6.16
//!
//! A simple, file-based mock API server that maps your directory structure to HTTP and GraphQL routes. Ideal for local development and testing.
use std::{
    cell::RefCell,
    ffi::OsString,
    io::Write,
    sync::{Arc, Mutex, RwLock},
};

use axum::{
    Router, middleware,
    response::IntoResponse,
    routing::{MethodRouter, Route, get},
};
use fosk::Db;
use http::{HeaderMap, HeaderValue, StatusCode, header::CONTENT_TYPE};
use terminal_link::Link;
use tokio::net::TcpListener;
use tower::{
    Layer, ServiceBuilder,
    layer::util::{Identity, Stack},
};
use tower_http::{
    cors::CorsLayer, normalize_path::NormalizePathLayer, services::ServeDir, trace::TraceLayer,
};

use crate::{
    DEFAULT_FOLDER, DEFAULT_PORT,
    handlers::{create_collections_routes, make_auth_middleware},
    pages::Pages,
    route_builder::{
        RouteGenerator, RouteRegistrator,
        config::{Config, ServerConfig},
        route_manager::RouteManager,
    },
    upload_configuration::UploadConfiguration,
};

/// Authentication-related settings shared across the whole process.
///
/// In this file the values are read when a protected route is registered
/// (see `App::try_add_auth_middleware_layer`).
#[derive(Default)]
pub struct GlobalSharedInfo {
    /// Secret handed to the auth middleware factory.
    pub jwt_secret: String,
    /// Name used to look the token collection up in the database.
    pub token_collection: String,
    /// Cookie name handed to the auth middleware factory.
    pub auth_cookie_name: String,
}

/// Path constant for the mock server's own routes.
// NOTE(review): not referenced in this part of the file — presumably used by
// the handlers/pages modules; confirm.
pub const MOCK_SERVER_ROUTE: &str = "/mock-server";
/// Process-wide [`GlobalSharedInfo`], guarded by an `RwLock`.
///
/// Starts out with every field empty; the code that fills it in is not part of
/// this file section.
pub static GLOBAL_SHARED_INFO: RwLock<GlobalSharedInfo> = RwLock::new(GlobalSharedInfo {
    jwt_secret: String::new(),
    token_collection: String::new(),
    auth_cookie_name: String::new(),
});

/// The mock server application: router under construction plus the state the
/// route builders and handlers share.
pub struct App {
    /// Router being assembled. Axum's builder methods consume the router, so it
    /// lives in a `RefCell` and is taken out and put back on every change.
    pub router: RefCell<Router>,
    /// Page registry; receives a link for every documented route and renders
    /// the index page served at `/`.
    pub pages: Arc<Mutex<Pages>>,
    /// Upload folders registered so far; each one is asked to clean itself up
    /// in `finish`.
    uploads_configurations: Vec<UploadConfiguration>,
    /// Shared database handle used for collections and token lookup.
    pub db: Arc<Db>,
    /// Configuration the application was created with.
    pub server_config: Config,
}

impl Default for App {
    fn default() -> Self {
        let router = RefCell::new(Router::new());
        let pages = Arc::new(Mutex::new(Pages::new()));
        let uploads_configurations = vec![];
        let db = Db::new_arc();
        let server_config = Config {
            server: Some(ServerConfig {
                folder: Some(DEFAULT_FOLDER.into()),
                port: Some(DEFAULT_PORT),
                ..Default::default()
            }),
            ..Default::default()
        };
        App {
            router,
            pages,
            uploads_configurations,
            db,
            server_config,
        }
    }
}

impl App {
    pub fn new(server_config: Config) -> Self {
        let router = RefCell::new(Router::new());
        let pages = Arc::new(Mutex::new(Pages::new()));
        let uploads_configurations = vec![];
        let db = Db::new_arc();
        App {
            router,
            pages,
            uploads_configurations,
            db,
            server_config,
        }
    }

    /// Returns the folder the routes are built from.
    ///
    /// Uses the `folder` of the `server` section when present; a missing
    /// section is treated as `ServerConfig::default()`, and a missing folder
    /// falls back to `DEFAULT_FOLDER`.
    pub fn get_folder(&self) -> String {
        let configured = match &self.server_config.server {
            Some(server) => server.folder.clone(),
            None => ServerConfig::default().folder,
        };

        configured.unwrap_or_else(|| DEFAULT_FOLDER.to_string())
    }

    /// Returns the TCP port to listen on.
    ///
    /// Uses the `port` of the `server` section when present; a missing section
    /// is treated as `ServerConfig::default()`, and a missing port falls back
    /// to `DEFAULT_PORT`.
    pub fn get_port(&self) -> u16 {
        let configured = match &self.server_config.server {
            Some(server) => server.port,
            None => ServerConfig::default().port,
        };

        configured.unwrap_or(DEFAULT_PORT)
    }

    /// Registers an upload folder so that `finish` can ask it to clean up.
    ///
    /// `uploads_path` and `clean_uploads` are forwarded unchanged to
    /// `UploadConfiguration::new`.
    pub fn push_uploads_config(&mut self, uploads_path: String, clean_uploads: bool) {
        let configuration = UploadConfiguration::new(uploads_path, clean_uploads);
        self.uploads_configurations.push(configuration);
    }

    fn get_router(&self) -> Router {
        self.router.take()
    }

    /// Stores `new_router` as the current router.
    fn replace_router(&mut self, new_router: Router) {
        // Axum's builder methods consume the router they are called on, so
        // whatever is still sitting in the cell is discarded here.
        drop(self.router.replace(new_router));
    }

    /// Mounts `router` at `path` and, when `method` is given, records a link
    /// for it in the page registry.
    ///
    /// * `path` - route path passed to the axum router.
    /// * `router` - method router handling the path.
    /// * `method` - label stored with the link; `None` registers the route
    ///   without adding a link.
    /// * `options` - extra strings stored with the link; `None` is treated as
    ///   an empty list.
    ///
    /// # Panics
    ///
    /// Panics if the `pages` mutex is poisoned.
    pub fn route(
        &mut self,
        path: &str,
        router: MethodRouter<()>,
        method: Option<&str>,
        options: Option<&[String]>,
    ) {
        let updated = self.get_router().route(path, router);
        self.replace_router(updated);

        // Routes without a method label are served but not listed.
        let Some(method) = method else {
            return;
        };

        self.pages.lock().unwrap().push_link(
            method.to_string(),
            path.to_string(),
            options.unwrap_or_default(),
        );
    }

    /// Wraps `router` in the auth middleware when the route is protected.
    ///
    /// Returns `router` untouched when `is_protected` is `false`. For protected
    /// routes the token collection named in `GLOBAL_SHARED_INFO` is looked up
    /// in the database; if it exists the middleware is layered on top,
    /// otherwise `router` is returned as is — i.e. the route ends up
    /// unprotected when the collection is missing.
    ///
    /// # Panics
    ///
    /// Panics if the `GLOBAL_SHARED_INFO` lock is poisoned.
    pub fn try_add_auth_middleware_layer(
        &mut self,
        router: MethodRouter,
        is_protected: bool,
    ) -> MethodRouter {
        if !is_protected {
            return router;
        }

        let shared_info = GLOBAL_SHARED_INFO.read().unwrap();

        let Some(token_collection) = self.db.get(&shared_info.token_collection) else {
            // No token collection registered: nothing to authenticate against.
            return router;
        };

        let auth = make_auth_middleware(
            &token_collection,
            &shared_info.jwt_secret,
            &shared_info.auth_cookie_name,
        );
        router.layer(middleware::from_fn(auth))
    }

    /// Builds the routes described by the configured folder.
    ///
    /// A `RouteManager` is created from the folder and a copy of the server
    /// configuration, then asked to register its routes on this application.
    fn build_dyn_routes(&mut self) {
        let folder = self.get_folder();
        let config = self.server_config.clone();

        RouteManager::from_dir(&folder, Some(config)).make_routes(self);
    }

    /// Registers `GET /`, which serves the index page rendered by `Pages`.
    ///
    /// The response is sent as `text/html` with headers that disable caching.
    /// The route is registered without a method label, so it is not listed as
    /// a link itself.
    fn build_home_route(&mut self) {
        let pages = Arc::clone(&self.pages);

        let index_handler = || async move {
            let body = pages.lock().unwrap().render_index();

            let mut headers = HeaderMap::new();
            headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/html"));
            // Keep clients from caching the index page.
            headers.insert(
                "Cache-Control",
                HeaderValue::from_static("no-cache, no-store, must-revalidate"),
            );
            headers.insert("Pragma", HeaderValue::from_static("no-cache"));
            headers.insert("Expires", HeaderValue::from_static("0"));

            (headers, body).into_response()
        };

        self.route("/", get(index_handler), None, None);
    }

    fn build_cors_layer<L>(
        &self,
        service_builder: ServiceBuilder<L>,
    ) -> ServiceBuilder<Stack<tower::util::Either<CorsLayer, Identity>, L>>
    where
        L: Layer<Route> + Clone + Send + Sync + 'static,
        Stack<CorsLayer, L>: Layer<Route> + Clone + Send + Sync + 'static,
    {
        let server_config = self.server_config.server.clone().unwrap_or_default();
        let enable = server_config.enable_cors.unwrap_or(true);
        let allowed_origin = server_config.allowed_origin;

        service_builder.option_layer(enable.then(|| {
            if let Some(allowed_origin) = allowed_origin {
                CorsLayer::very_permissive()
                    .allow_origin(allowed_origin.parse::<HeaderValue>().unwrap())
            } else {
                CorsLayer::very_permissive()
            }
        }))
    }

    /// Wraps the current router in the middleware stack.
    ///
    /// Layers are added to the builder in this order: HTTP tracing, the
    /// optional CORS layer, trailing-slash normalization.
    // NOTE(review): axum documents that middleware which rewrites the request
    // URI has no effect on routing when added through `Router::layer` —
    // confirm that trailing-slash normalization actually takes effect here.
    fn build_middlewares(&mut self) {
        let stack = self
            .build_cors_layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()))
            .layer(NormalizePathLayer::trim_trailing_slash());

        let layered = self.get_router().layer(stack);
        self.replace_router(layered);
    }

    /// Installs `handler_404` as the router's fallback handler.
    fn build_fallback(&mut self) {
        let with_fallback = self.get_router().fallback(Self::handler_404);
        self.replace_router(with_fallback);
    }

    /// Fallback handler: answers with `404 Not Found` and a short text body.
    async fn handler_404() -> impl IntoResponse {
        let status = StatusCode::NOT_FOUND;
        let body = "nothing to see here";

        (status, body)
    }

    /// Serves the directory at `path` as static files.
    ///
    /// The mount point is derived from `file_name`: everything after the first
    /// `-` becomes the route (`public-assets` → `/assets`); a name without `-`
    /// is mounted at `/public`.
    pub fn build_public_router(&mut self, file_name: String, path: String) {
        let end_point = file_name
            .split_once('-')
            .map_or("public", |(_, suffix)| suffix);
        let mount_point = format!("/{end_point}");

        let static_files = ServeDir::new(path);
        let nested = self.get_router().nest_service(&mount_point, static_files);
        self.replace_router(nested);
    }

    /// Serves the directory at `path` as static files under `route`.
    ///
    /// Unlike `build_public_router`, the mount point is given explicitly
    /// instead of being derived from a file name.
    pub fn build_public_router_v2(&mut self, path: &OsString, route: &str) {
        let nested = self
            .get_router()
            .nest_service(route, ServeDir::new(path));
        self.replace_router(nested);
    }

    /// Registers the collection routes by delegating to
    /// `handlers::create_collections_routes`, which receives this application
    /// mutably.
    pub fn build_collections_route(&mut self) {
        create_collections_routes(self);
    }

    /// Asks the database to infer references between every pair of distinct
    /// collections, in both directions.
    ///
    /// Does nothing when fewer than two collections exist.
    pub fn build_collections_references(&mut self) {
        let collections = self.db.list_collections();

        for (index, left) in collections.iter().enumerate() {
            // Pair `left` only with the collections after it, so each
            // unordered pair is visited exactly once.
            for right in &collections[index + 1..] {
                self.db.infer_reference(left, right);
                self.db.infer_reference(right, left);
            }
        }
    }

    /// Prints the start-up banner to stdout.
    ///
    /// The `{{{{}}}}` placeholder inside the banner is replaced by the crate
    /// version (`v<CARGO_PKG_VERSION>`), centered in eight columns. Write
    /// errors are ignored.
    pub fn show_greetings() {
        // Eight literal characters reserved in the banner for the version.
        const VERSION_SLOT: &str = "{{{{}}}}";

        let banner = r"
                                  ___     ___
                                 (o o)   (o o)
 _____                          (  V  ) (  V  )                         _____
( ___ )------------------------ /--m-m- /--m-m-------------------------( ___ )
 |   |                                                                  |   |
 |   |                                                                  |   |
 |   |     ░█▀▄░█▀▀░░░░░█▄█░█▀█░█▀▀░█░█░░░░░█▀▀░█▀▀░█▀▄░█░█░█▀▀░█▀▄     |   |
 |   |     ░█▀▄░▀▀█░▄▄▄░█░█░█░█░█░░░█▀▄░▄▄▄░▀▀█░█▀▀░█▀▄░▀▄▀░█▀▀░█▀▄     |   |
 |   |     ░▀░▀░▀▀▀░░░░░▀░▀░▀▀▀░▀▀▀░▀░▀░░░░░▀▀▀░▀▀▀░▀░▀░░▀░░▀▀▀░▀░▀     |   |
 |   |                                                                  |   |
 |   |                             {{{{}}}}                             |   |
 |___|                                                                  |___|
(_____)----------------------------------------------------------------(_____)

";

        let version_label = format!("{:^8}", format!("v{}", env!("CARGO_PKG_VERSION")));
        let rendered = banner.replace(VERSION_SLOT, &version_label);

        let _ = std::io::stdout().write_all(rendered.as_bytes());
    }

    /// Binds the listening socket and serves the router until the server stops.
    ///
    /// Listens on `0.0.0.0:<port>`, prints the banner and a clickable
    /// `http://localhost:<port>` link, then hands the router (taken out of the
    /// cell) to `axum::serve`.
    ///
    /// # Panics
    ///
    /// Panics if the address cannot be bound (for example when the port is
    /// already in use) or if the server terminates with an error. The bind
    /// failure message names the address that was attempted.
    async fn start_server(&self) {
        let port = self.get_port();
        let address = format!("0.0.0.0:{port}");

        let listener = TcpListener::bind(&address)
            .await
            .unwrap_or_else(|err| panic!("failed to bind {address}: {err}"));
        App::show_greetings();

        let url = format!("http://localhost:{port}");
        let link = Link::new(&url, &url);
        println!("🚀 Listening on {}", link);

        axum::serve(listener, self.get_router())
            .await
            .expect("HTTP server stopped with an error");
    }

    /// Builds every route and layer, then starts serving.
    ///
    /// The steps run in a fixed order and the future only completes when the
    /// server stops. Panics raised by the individual steps (see
    /// `start_server`) propagate to the caller.
    pub async fn initialize(&mut self) {
        // Routes generated from the configured folder.
        self.build_dyn_routes();
        // Index page at `/`.
        self.build_home_route();
        // Routes created by `create_collections_routes`.
        self.build_collections_route();
        // 404 fallback for everything that did not match a route.
        self.build_fallback();
        // Middleware is layered last, on top of the router assembled so far.
        self.build_middlewares();
        // Pairwise reference inference between the database collections.
        self.build_collections_references();
        self.start_server().await;
    }

    /// Shuts the application down and resets its state.
    ///
    /// Every registered upload configuration is asked to clean its folder,
    /// then the router, page registry, upload list and database are reset and
    /// a goodbye message is printed.
    pub fn finish(&mut self) {
        println!("\n");

        for upload_config in &self.uploads_configurations {
            upload_config.clean_upload_folder();
        }

        // Back to the same empty state a freshly created `App` starts with.
        self.router = RefCell::default();
        self.pages = Arc::new(Mutex::new(Pages::new()));
        self.uploads_configurations = Vec::new();
        self.db.clear();

        println!("\n👋👋👋👋👋 Goodbye! 👋👋👋👋👋👋");
    }
}

/// Lets the route builders register their routes directly on [`App`].
impl RouteRegistrator for App {
    /// Registers `router` at `path`, adding the auth middleware first when
    /// `is_protected` is set.
    ///
    /// `method` and `options` are forwarded to [`App::route`], which uses them
    /// for the link shown on the index page.
    fn push_route(
        &mut self,
        path: &str,
        router: MethodRouter,
        method: Option<&str>,
        is_protected: bool,
        options: Option<&[String]>,
    ) {
        let guarded = self.try_add_auth_middleware_layer(router, is_protected);
        self.route(path, guarded, method, options);
    }
}