//! rs-mock-server 0.6.15
//!
//! A simple, file-based mock API server that maps your directory structure to HTTP and GraphQL routes. Ideal for local development and testing.
use std::{cell::RefCell, ffi::OsString, io::Write, sync::{Arc, Mutex, RwLock}};

use axum::{
    middleware, response::IntoResponse, routing::{get, MethodRouter, Route}, Router
};
use fosk::Db;
use http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode};
use terminal_link::Link;
use tokio::net::TcpListener;
use tower::{layer::util::{Identity, Stack}, Layer, ServiceBuilder};
use tower_http::{cors::CorsLayer, normalize_path::NormalizePathLayer, services::ServeDir, trace::TraceLayer};

use crate::{
    handlers::{create_collections_routes, make_auth_middleware},
    pages::Pages,
    route_builder::{config::{Config, ServerConfig},
    route_manager::RouteManager, RouteGenerator, RouteRegistrator},
    upload_configuration::UploadConfiguration,
    DEFAULT_FOLDER, DEFAULT_PORT
};

/// Authentication settings shared process-wide through [`GLOBAL_SHARED_INFO`].
#[derive(Default)]
pub struct GlobalSharedInfo {
    /// Secret handed to the auth middleware.
    pub jwt_secret: String,
    /// Name of the database collection looked up for auth tokens.
    pub token_collection: String,
    /// Name of the cookie handed to the auth middleware.
    pub auth_cookie_name: String,
}

impl GlobalSharedInfo {
    /// All-empty settings, usable in a `static` initializer where `Default::default()`
    /// cannot be called.
    const fn empty() -> Self {
        Self {
            jwt_secret: String::new(),
            token_collection: String::new(),
            auth_cookie_name: String::new(),
        }
    }
}

/// Route prefix reserved for the mock server's own endpoints.
pub const MOCK_SERVER_ROUTE: &str = "/mock-server";

/// Process-wide auth settings; every field starts out empty.
pub static GLOBAL_SHARED_INFO: RwLock<GlobalSharedInfo> = RwLock::new(GlobalSharedInfo::empty());

/// State of the mock server: the axum `Router` being assembled, the index-page data,
/// the registered upload folders, the collection database and the configuration.
pub struct App {
    /// Router under construction. Methods move it out of the cell, transform it and put
    /// the result back, because axum's `Router` is a by-value builder.
    pub router: RefCell<Router>,
    /// Index-page data shared with the `/` handler; `route` adds links to it.
    pub pages: Arc<Mutex<Pages>>,
    /// Upload folders registered through `push_uploads_config`; `finish` calls
    /// `clean_upload_folder` on each of them.
    uploads_configurations: Vec<UploadConfiguration>,
    /// Database holding the collections. It is also queried for the auth token collection
    /// and used to infer references between collections.
    pub db: Arc<Db>,
    /// Configuration; its `server` section supplies folder, port and CORS settings.
    pub server_config: Config,
}

impl Default for App {
    /// Builds an app that serves `DEFAULT_FOLDER` on `DEFAULT_PORT`. Every other
    /// configuration option keeps its default value.
    fn default() -> Self {
        let defaults = Config {
            server: Some(ServerConfig {
                folder: Some(DEFAULT_FOLDER.into()),
                port: Some(DEFAULT_PORT),
                ..Default::default()
            }),
            ..Default::default()
        };

        // `new` creates the same empty router, pages, upload list and database.
        Self::new(defaults)
    }
}

impl App {

    pub fn new(server_config: Config) -> Self {
        let router = RefCell::new(Router::new());
        let pages = Arc::new(Mutex::new(Pages::new()));
        let uploads_configurations = vec![];
        let db = Db::new_arc();
        App { router, pages, uploads_configurations, db, server_config }
    }

    pub fn get_folder(&self) -> String {
        self.server_config.server
            .as_ref()
            .unwrap_or(&ServerConfig::default())
            .folder
            .clone()
            .unwrap_or(DEFAULT_FOLDER.to_string())
    }

    pub fn get_port(&self) -> u16 {
        self.server_config.server
            .as_ref()
            .unwrap_or(&ServerConfig::default())
            .port
            .unwrap_or(DEFAULT_PORT)
    }

    pub fn push_uploads_config(&mut self, uploads_path: String, clean_uploads: bool) {
        self.uploads_configurations.push(
            UploadConfiguration::new(uploads_path, clean_uploads)
        );
    }

    fn get_router(&self) -> Router {
        self.router.take()
    }

    fn replace_router(&mut self, new_router: Router) {
        // _old_route object will be dropped (Axum uses builder pattern)
        let _old_route = self.router.replace(new_router);
    }

    pub fn route(&mut self, path: &str, router: MethodRouter<()>, method: Option<&str>, options: Option<&[String]>) {
        let new_router = self.get_router().route(path, router);

        self.replace_router(new_router);

        if let Some(method) = method {
            self.pages.lock().unwrap().push_link(method.to_string(), path.to_string(), options.unwrap_or(&Vec::<String>::new()));
        }
    }

    pub fn try_add_auth_middleware_layer(&mut self, router: MethodRouter, is_protected: bool) -> MethodRouter {
        if !is_protected {
            return router;
        }

        let shared_info = GLOBAL_SHARED_INFO.read().unwrap();
        if let Some(token_collection) = &self.db.get(&shared_info.token_collection) {
            return router.layer(
                middleware::from_fn(make_auth_middleware(token_collection, &shared_info.jwt_secret, &shared_info.auth_cookie_name))
            )
        }
        router
    }

    fn build_dyn_routes(&mut self) {
        let dir = self.get_folder();
        RouteManager::from_dir(&dir, Some(self.server_config.clone()))
            .make_routes(self);
    }

    fn build_home_route(&mut self) {
        let pages = Arc::clone(&self.pages);

        self
            .route("/", get(|| async move {
                let body = pages.lock().unwrap().render_index();
                let mut headers = HeaderMap::new();
                headers.insert(CONTENT_TYPE, HeaderValue::from_str("text/html").unwrap());
                headers.insert("Cache-Control", HeaderValue::from_str("no-cache, no-store, must-revalidate").unwrap());
                headers.insert("Pragma", HeaderValue::from_str("no-cache").unwrap());
                headers.insert("Expires", HeaderValue::from_str("0").unwrap());

                (headers, body).into_response()
            }), None, None);
    }

    fn build_cors_layer<L>(
        &self,
        service_builder: ServiceBuilder<L>,
    ) -> ServiceBuilder<Stack<tower::util::Either<CorsLayer, Identity>, L>>
    where
        L: Layer<Route> + Clone + Send + Sync + 'static,
        Stack<CorsLayer, L>: Layer<Route> + Clone + Send + Sync + 'static,
    {
        let server_config = self.server_config.server.clone().unwrap_or_default();
        let enable = server_config.enable_cors.unwrap_or(true);
        let allowed_origin = server_config.allowed_origin;

        service_builder.option_layer(
            enable.then(|| {
                if let Some(allowed_origin) = allowed_origin {
                    CorsLayer::very_permissive()
                        .allow_origin(allowed_origin.parse::<HeaderValue>().unwrap())
                } else {
                    CorsLayer::very_permissive()
                }
            })
        )
    }

    fn build_middlewares(&mut self) {
        let service_builder = ServiceBuilder::new()
            .layer(TraceLayer::new_for_http());

        let service_builder = self.build_cors_layer(service_builder);

        let service_builder = service_builder
            .layer(NormalizePathLayer::trim_trailing_slash());

        let new_router = self.get_router().layer(service_builder);

        self.replace_router(new_router);
    }

    fn build_fallback(&mut self) {
        let new_router = self.get_router().fallback(Self::handler_404);
        self.replace_router(new_router);
    }

    async fn handler_404() -> impl IntoResponse {
        (StatusCode::NOT_FOUND, "nothing to see here")
    }

    pub fn build_public_router(&mut self, file_name: String, path: String) {
        let  public_end_point = if let Some((_, to)) = file_name.split_once('-') {
            to
        } else {
            "public"
        };

        let static_files = ServeDir::new(path);
        let new_router = self.router.take().nest_service(
            &format!("/{}", public_end_point),
            static_files
        );
        self.replace_router(new_router);
    }

    pub fn build_public_router_v2(&mut self, path: &OsString, route: &str) {
        let static_files = ServeDir::new(path);
        let new_router = self.router.take().nest_service(
            route,
            static_files
        );
        self.replace_router(new_router);
    }

    pub fn build_collections_route(&mut self) {
        create_collections_routes(self);
    }

    pub fn build_collections_references(&mut self) {
        let collections = self.db.list_collections();

        for i in 0..collections.len() - 1 {
            for j in i+1..collections.len() {
                self.db.infer_reference(&collections[i], &collections[j]);
                self.db.infer_reference(&collections[j], &collections[i]);
            }
        }
    }

    pub fn show_greetings() {
        let banner = r"
                                  ___     ___
                                 (o o)   (o o)
 _____                          (  V  ) (  V  )                         _____
( ___ )------------------------ /--m-m- /--m-m-------------------------( ___ )
 |   |                                                                  |   |
 |   |                                                                  |   |
 |   |     ░█▀▄░█▀▀░░░░░█▄█░█▀█░█▀▀░█░█░░░░░█▀▀░█▀▀░█▀▄░█░█░█▀▀░█▀▄     |   |
 |   |     ░█▀▄░▀▀█░▄▄▄░█░█░█░█░█░░░█▀▄░▄▄▄░▀▀█░█▀▀░█▀▄░▀▄▀░█▀▀░█▀▄     |   |
 |   |     ░▀░▀░▀▀▀░░░░░▀░▀░▀▀▀░▀▀▀░▀░▀░░░░░▀▀▀░▀▀▀░▀░▀░░▀░░▀▀▀░▀░▀     |   |
 |   |                                                                  |   |
 |   |                             {{{{}}}}                             |   |
 |___|                                                                  |___|
(_____)----------------------------------------------------------------(_____)

";

        let version = format!("v{}", env!("CARGO_PKG_VERSION"));
        let version = format!("{:^8}", version);
        let _ = std::io::stdout().write_all(banner.replace("{{{{}}}}", &version).as_bytes());
    }

    async fn start_server(&self) {
        let address = format!("0.0.0.0:{}", self.get_port());

        let listener = TcpListener::bind(address.clone()).await.unwrap();
        App::show_greetings();

        let link = format!("http://localhost:{}", self.get_port());
        let link = Link::new(&link, &link);
        println!("🚀 Listening on {}", link);

        axum::serve(listener, self.get_router()).await.unwrap();
    }

    pub async fn initialize(&mut self) {
        self.build_dyn_routes();
        self.build_home_route();
        self.build_collections_route();
        self.build_fallback();
        self.build_middlewares();
        self.build_collections_references();
        self.start_server().await;
    }

    pub fn finish(&mut self) {
        println!("\n");

        for upload_config in self.uploads_configurations.iter() {
            upload_config.clean_upload_folder();
        }

        self.router = RefCell::new(Router::new());
        self.pages = Arc::new(Mutex::new(Pages::new()));
        self.uploads_configurations = vec![];
        self.db.clear();

        println!("\n👋👋👋👋👋 Goodbye! 👋👋👋👋👋👋");
    }
}

impl RouteRegistrator for App {
    /// Registers `router` under `path`.
    ///
    /// When `is_protected` is set, the router is first passed through
    /// `try_add_auth_middleware_layer`. Index-page bookkeeping for `method` and `options`
    /// is handled by `route`.
    fn push_route(&mut self, path: &str, router: MethodRouter, method: Option<&str>, is_protected: bool, options: Option<&[String]>) {
        let guarded_router = self.try_add_auth_middleware_layer(router, is_protected);
        self.route(path, guarded_router, method, options)
    }
}