refresh-server 0.7.0

Web server with a refresh function which reloads the page in the browser.
Documentation
use std::{
    fs::{self, File},
    io::Read,
    sync::Arc,
};

use axum::{
    extract::{ws::WebSocket, Query, WebSocketUpgrade},
    http::{HeaderMap, StatusCode, Uri},
    response::{Html, IntoResponse, Response},
    routing::{get, post},
    Extension, Router,
};
use reqwest::header;
use serde::Deserialize;
use tokio::sync::Mutex;

/// Client-side WebSocket script, embedded into the binary at compile time.
///
/// `inject_websocket_into` replaces its `<address>` and `<port>` placeholders and inlines
/// the result, wrapped in a `<script>` tag, into served HTML pages.
const WS_FILE: &str = include_str!("../res/ws.js");

/// Query string of `POST /ws/load`, i.e. `?path=...`.
#[derive(Deserialize)]
struct LoadParameter {
    /// Text forwarded unchanged to every subscribed WebSocket.
    path: String,
}

/// State shared by the `/ws/*` routes: the WebSockets registered by `/ws/connect`.
struct SocketState {
    subscriber: Mutex<Vec<WebSocket>>,
}

impl SocketState {
    /// Creates the state with no registered sockets.
    fn new() -> Self {
        let subscriber = Mutex::new(Vec::new());
        Self { subscriber }
    }
}

/// Settings captured by the fallback file handler (cloned into each request).
#[derive(Clone)]
struct Args {
    /// Directory that request paths are resolved against.
    directory: String,
    /// Substituted for `<address>` in the injected WebSocket script.
    address: String,
    /// Substituted for `<port>` in the injected WebSocket script.
    port: u16,
    /// Receives error messages, e.g. for requested files that were not found.
    err_callback: fn(String),
}

/// Client for steering a server started with [`listen`].
///
/// Each method sends an HTTP `POST` to one of the server's `/ws/*` control routes; the server
/// then forwards a message to every browser connected over its WebSocket.
pub struct Handler {
    client: reqwest::Client,
    address: String,
    port: u16,
}

impl Handler {
    /// Builds the absolute control URL for `path`, which must start with `/`.
    fn url(&self, path: &str) -> String {
        format!("http://{}:{}{}", self.address, self.port, path)
    }

    fn new(address: String, port: u16) -> Self {
        Self {
            client: reqwest::Client::new(),
            address,
            port,
        }
    }

    /// Asks every connected browser to reload its current page.
    ///
    /// # Errors
    /// Returns the `reqwest` error if the control request could not be sent.
    pub async fn refresh(&self) -> Result<reqwest::Response, reqwest::Error> {
        self.client.post(self.url("/ws/refresh")).send().await
    }

    /// Asks every connected browser to load `path`.
    ///
    /// `path` travels as the URL-encoded `path` query parameter. Splicing it into the URL
    /// verbatim would truncate or corrupt values containing `&`, `#`, `?`, spaces or
    /// non-ASCII characters.
    ///
    /// # Errors
    /// Returns the `reqwest` error if the control request could not be built or sent.
    pub async fn load(&self, path: &str) -> Result<reqwest::Response, reqwest::Error> {
        self.client
            .post(self.url("/ws/load"))
            .query(&[("path", path)])
            .send()
            .await
    }
}

/// Starts the live-reload web server in the background and returns a [`Handler`] to drive it.
///
/// Files are served from `directory` on `address:port`; a `localhost` in `address` is bound as
/// `127.0.0.1`. HTML pages get the WebSocket client script injected, so [`Handler::refresh`]
/// and [`Handler::load`] can reach every open page.
///
/// `err_callback` receives error messages: files that were not found, an unparsable
/// address, a failure to bind the port, or the server stopping with an error. With `None`,
/// errors are ignored.
///
/// # Panics
/// Must be called from within a Tokio runtime, because the server is started with
/// `tokio::spawn`.
pub fn listen(
    directory: &str,
    address: &str,
    port: u16,
    err_callback: Option<fn(String)>,
) -> Handler {
    let on_error: fn(String) = err_callback.unwrap_or(|_| {});
    let args = Args {
        directory: directory.to_string(),
        address: address.to_string(),
        port,
        err_callback: on_error,
    };
    let state = Arc::new(SocketState::new());
    let app: Router = Router::new()
        .route("/ws/connect", get(ws_handler))
        .route("/ws/refresh", post(refresh_handler))
        .route("/ws/load", post(load_handler))
        .layer(Extension(state))
        .fallback(get(|uri: Uri| async move { fallback_handler(uri, args).await }));

    let addr = format!("{}:{}", address.replace("localhost", "127.0.0.1"), port);

    // Report failures through the callback: a panic inside this detached task would only kill
    // the task, and the caller (already holding a `Handler`) would never find out.
    tokio::spawn(async move {
        let socket_addr: std::net::SocketAddr = match addr.parse() {
            Ok(socket_addr) => socket_addr,
            Err(err) => {
                on_error(format!("invalid address {}: {}", addr, err));
                return;
            }
        };
        let builder = match axum::Server::try_bind(&socket_addr) {
            Ok(builder) => builder,
            Err(err) => {
                on_error(format!("failed to bind {}: {}", socket_addr, err));
                return;
            }
        };
        if let Err(err) = builder.serve(app.into_make_service()).await {
            on_error(format!("server error: {}", err));
        }
    });

    Handler::new(address.to_string(), port)
}

/// Handles `GET /ws/connect` by upgrading the request to a WebSocket.
///
/// The socket is added to the shared state, so the refresh/load handlers can message it later.
async fn ws_handler(
    event: WebSocketUpgrade,
    Extension(state): Extension<Arc<SocketState>>,
) -> axum::response::Response {
    let register = move |socket: WebSocket| async move {
        state.subscriber.lock().await.push(socket);
    };
    event.on_upgrade(register)
}

/// Sends `text` to every registered WebSocket and drops the ones that can no longer be reached.
///
/// A page that reloaded or closed leaves its old socket behind, and sending to it fails.
/// Such sockets are discarded here rather than unwrapped. A single stale subscriber therefore
/// cannot panic the handler and starve every subscriber after it.
async fn broadcast(state: &SocketState, text: &str) {
    let mut subscribers = state.subscriber.lock().await;
    let mut alive = Vec::with_capacity(subscribers.len());
    for mut socket in std::mem::take(&mut *subscribers) {
        if socket.send(text.into()).await.is_ok() {
            alive.push(socket);
        }
    }
    *subscribers = alive;
}

/// `POST /ws/refresh`: broadcasts an empty text message to every connected page.
///
/// NOTE(review): the empty message is presumably the "reload" signal understood by
/// `res/ws.js`; that script is not visible here.
async fn refresh_handler(Extension(state): Extension<Arc<SocketState>>) -> String {
    broadcast(&state, "").await;
    "refreshing".to_string()
}

/// `POST /ws/load?path=...`: broadcasts `path` to every connected page.
async fn load_handler(
    Extension(state): Extension<Arc<SocketState>>,
    parameter: Query<LoadParameter>,
) -> String {
    broadcast(&state, &parameter.path).await;
    "loading".to_string()
}

async fn fallback_handler(uri: Uri, args: Args) -> Response {
    let path = uri.path();
    let transformed_path = transform_path(args.directory.as_str(), path);
    let extension = transformed_path
        .split(".")
        .collect::<Vec<&str>>()
        .last()
        .expect("no extension")
        .clone();

    if is_image(extension) {
        let content_type = format!("image/{}", extension);
        let bytes = read_bytes(transformed_path.as_str());
        (header_with_content_type(content_type.as_str()), bytes).into_response()
    } else if extension == "css" {
        let Ok(content) = read_file(&transformed_path) else {
            (args.err_callback)("not found".to_string());
            return StatusCode::NOT_FOUND.into_response()
        };

        (header_with_content_type("text/css"), content).into_response()
    } else if extension == "js" {
        let Ok(content) = read_file(&transformed_path) else {
            return StatusCode::NOT_FOUND.into_response()
        };

        (header_with_content_type("text/javascript"), content).into_response()
    } else {
        let Ok(content) = read_file(&transformed_path) else {
            (args.err_callback)("not found".to_string());
            return StatusCode::NOT_FOUND.into_response()
        };

        let injected_content = inject_websocket_into(content, args);
        Html(injected_content).into_response()
    }
}

/// Builds a header map whose only entry is `Content-Type: content_type`.
///
/// # Panics
/// Panics if `content_type` is not a valid header value. Callers in this file pass MIME
/// strings they build themselves.
fn header_with_content_type(content_type: &str) -> HeaderMap {
    let mut headers = HeaderMap::new();
    let value = content_type.parse().unwrap();
    headers.insert(header::CONTENT_TYPE, value);
    headers
}

/// Maps a request path to a file-system path inside `directory`.
///
/// The request path is first completed by `complete_path` (default `index` file and `.html`
/// extension), then prefixed with `directory`.
fn transform_path(directory: &str, path: &str) -> String {
    add_directory_path(&complete_path(path), directory)
}

/// Completes a request path into a file path relative to the served directory.
///
/// - A trailing `/` becomes `/index`.
/// - A final path segment without a `.` gets `.html` appended.
///
/// For example, `/` → `/index.html`, `/about` → `/about.html`, and `/style.css` is unchanged.
fn complete_path(url: &str) -> String {
    let mut completed = url.to_string();
    if completed.ends_with('/') {
        completed.push_str("index");
    }

    // Only the last segment decides: a dot in a directory name (`/v1.2/page`) is not an
    // extension of the requested file.
    let has_extension = completed
        .rsplit('/')
        .next()
        .map_or(false, |file_name| file_name.contains('.'));
    if !has_extension {
        completed.push_str(".html");
    }

    completed
}

/// Prefixes the request `path` with the served `directory` by plain concatenation.
///
/// No separator is inserted. A `directory` ending in `/` therefore produces a doubled `/`,
/// because request paths begin with one.
fn add_directory_path(path: &str, directory: &str) -> String {
    [directory, path].concat()
}

/// Returns `true` for the file extensions served as images (matched case-sensitively).
fn is_image(extension: &str) -> bool {
    matches!(
        extension,
        "png" | "jpg" | "jpeg" | "avif" | "gif" | "svg" | "webp"
    )
}

/// Reads the whole file at `path` into memory.
///
/// # Panics
/// Panics if the file cannot be opened or read.
fn read_bytes(path: &str) -> Vec<u8> {
    let mut file = File::open(path).unwrap();
    // The metadata length is only a capacity hint: a failed lookup, or truncation on 32-bit
    // targets, merely costs reallocations. `read_to_end` loops until EOF. A single `read` may
    // return fewer bytes than requested, which leaves the tail of a pre-sized buffer
    // zero-filled.
    let capacity = fs::metadata(path)
        .map(|meta| meta.len() as usize)
        .unwrap_or(0);
    let mut buffer = Vec::with_capacity(capacity);
    file.read_to_end(&mut buffer).expect("failed to read file");
    buffer
}

/// Reads the whole file at `path` as UTF-8 text.
///
/// Takes `&str` rather than `&String`; callers passing `&String` still work via deref coercion.
///
/// # Errors
/// Returns the I/O error if the file cannot be opened or read, or is not valid UTF-8.
fn read_file(path: &str) -> Result<String, std::io::Error> {
    std::fs::read_to_string(path)
}

/// Returns `content` with the WebSocket client script injected.
///
/// The script is built as follows:
/// - `<address>` and `<port>` in `WS_FILE` are replaced with `args.address` and `args.port`;
/// - the result is wrapped in a `<script>` tag;
/// - the tag is inserted right after the first literal `<head>`.
///
/// Documents without a literal `<head>` get the tag appended at the end instead of panicking.
/// This covers fragments and `<head ...>` with attributes.
fn inject_websocket_into(mut content: String, args: Args) -> String {
    const HEAD_TAG: &str = "<head>";

    let client_script = WS_FILE
        .replace("<address>", &args.address)
        .replace("<port>", &args.port.to_string());
    let script_tag = format!("<script>{}</script>", client_script);

    match content.find(HEAD_TAG) {
        Some(index) => content.insert_str(index + HEAD_TAG.len(), &script_tag),
        // Appending rather than prepending keeps a leading `<!DOCTYPE>` first in the document.
        None => content.push_str(&script_tag),
    }

    content
}