use std::{
fs::{self, File},
io::Read,
sync::Arc,
};
use axum::{
extract::{ws::WebSocket, Query, WebSocketUpgrade},
http::{HeaderMap, StatusCode, Uri},
response::{Html, IntoResponse, Response},
routing::{get, post},
Extension, Router,
};
use reqwest::header;
use serde::Deserialize;
use tokio::sync::Mutex;
/// Contents of `../res/ws.js`, embedded into the binary at compile time.
///
/// `inject_websocket_into` replaces the `<address>` and `<port>` placeholders
/// in this text and inlines the result in a `<script>` tag of served HTML.
const WS_FILE: &str = include_str!("../res/ws.js");
/// Query-string parameters accepted by the `/ws/load` route.
#[derive(Deserialize)]
struct LoadParameter {
/// Value forwarded verbatim as a message to every subscribed WebSocket.
path: String,
}
/// State shared between the WebSocket routes through an `Extension` layer.
struct SocketState {
/// Sockets registered by `ws_handler`; `refresh_handler` and `load_handler`
/// send a message to each of them.
subscriber: Mutex<Vec<WebSocket>>,
}
impl SocketState {
    /// Creates a state with an empty subscriber list.
    fn new() -> Self {
        let subscriber = Mutex::new(Vec::new());
        Self { subscriber }
    }
}
/// Settings captured by `listen` and handed to the fallback (file-serving) handler.
#[derive(Clone)]
struct Args {
/// Prefix prepended to the completed request path to form the file path.
directory: String,
/// Address substituted for the `<address>` placeholder of the injected script.
address: String,
/// Port substituted for the `<port>` placeholder of the injected script.
port: u16,
/// Invoked with a message (currently `"not found"`) when a requested file
/// cannot be read.
err_callback: fn(err: String) -> ()
}
/// Handle returned by `listen`; it sends control requests to the server over HTTP.
pub struct Handler {
/// HTTP client used for every control request.
client: reqwest::Client,
/// Host part of the URLs built by `Handler::url`.
address: String,
/// Port part of the URLs built by `Handler::url`.
port: u16,
}
impl Handler {
    /// Builds an absolute `http://` URL for `path` on the configured address and port.
    ///
    /// `path` is appended as-is and is expected to start with `/`.
    fn url(&self, path: &str) -> String {
        format!("http://{}:{}{}", self.address, self.port, path)
    }

    /// Creates a handle that talks to the server at `address:port`.
    fn new(address: String, port: u16) -> Self {
        Self {
            client: reqwest::Client::new(),
            address,
            port,
        }
    }

    /// Asks the server to send the refresh message to every connected page.
    ///
    /// # Errors
    ///
    /// Returns the `reqwest::Error` produced while sending the request.
    pub async fn refresh(&self) -> Result<reqwest::Response, reqwest::Error> {
        self.client.post(self.url("/ws/refresh")).send().await
    }

    /// Asks the server to forward `path` to every connected page.
    ///
    /// # Errors
    ///
    /// Returns the `reqwest::Error` produced while sending the request.
    pub async fn load(&self, path: &str) -> Result<reqwest::Response, reqwest::Error> {
        // Let reqwest percent-encode the value: concatenating it into the URL
        // corrupts the query as soon as `path` contains `&`, `#`, `?`, `%` or
        // spaces, and the server decodes the parameter on its side anyway.
        self.client
            .post(self.url("/ws/load"))
            .query(&[("path", path)])
            .send()
            .await
    }
}
/// Starts the development server in a background Tokio task and returns a
/// [`Handler`] for triggering refresh/load events on connected pages.
///
/// * `directory` - prefix prepended to request paths to locate files.
/// * `address` - host to listen on; `localhost` is replaced by `127.0.0.1`
///   for binding.
/// * `port` - TCP port to listen on.
/// * `err_callback` - optional function called with an error message.
///
/// Failures to parse the address, bind the socket or run the server happen
/// inside the spawned task. They are passed to `err_callback`, or written to
/// stderr when no callback was supplied.
///
/// # Panics
///
/// Uses `tokio::spawn`, so it must be called from within a Tokio runtime.
pub fn listen(directory: &str, address: &str, port: u16, err_callback: Option<fn(err: String) -> ()>) -> Handler {
    let args = Args {
        directory: directory.to_string(),
        address: address.to_string(),
        port,
        err_callback: err_callback.unwrap_or(|_| {}),
    };

    // Keep startup failures visible even when the caller gave no callback.
    let report = move |message: String| match err_callback {
        Some(callback) => callback(message),
        None => eprintln!("{}", message),
    };

    let state = Arc::new(SocketState::new());
    let app: Router = Router::new()
        .route("/ws/connect", get(ws_handler))
        .route("/ws/refresh", post(refresh_handler))
        .route("/ws/load", post(load_handler))
        .layer(Extension(state))
        .fallback(get(|uri: Uri| async move {
            fallback_handler(uri, args).await
        }));

    let addr = format!("{}:{}", address.replace("localhost", "127.0.0.1"), port);
    tokio::spawn(async move {
        let socket_addr: std::net::SocketAddr = match addr.parse() {
            Ok(parsed) => parsed,
            Err(err) => {
                report(format!("invalid listen address {}: {}", addr, err));
                return;
            }
        };
        // `try_bind` returns the error (e.g. port already in use) instead of
        // panicking inside this detached task.
        let builder = match axum::Server::try_bind(&socket_addr) {
            Ok(builder) => builder,
            Err(err) => {
                report(format!("failed to bind {}: {}", socket_addr, err));
                return;
            }
        };
        if let Err(err) = builder.serve(app.into_make_service()).await {
            report(format!("server error: {}", err));
        }
    });

    Handler::new(address.to_string(), port)
}
/// Upgrades the request to a WebSocket and stores the resulting socket in the
/// shared subscriber list.
async fn ws_handler(
    event: WebSocketUpgrade,
    Extension(state): Extension<Arc<SocketState>>,
) -> axum::response::Response {
    /// Adds `socket` to the subscribers held by `state`.
    async fn register(socket: WebSocket, state: Arc<SocketState>) {
        state.subscriber.lock().await.push(socket);
    }

    event.on_upgrade(move |socket: WebSocket| register(socket, state))
}
/// Sends an empty text message to every subscribed WebSocket and returns
/// `"refreshing"`.
///
/// Subscribers whose send fails (for example because the page was closed) are
/// removed from the list, so one dead socket neither aborts the broadcast nor
/// breaks every later one.
async fn refresh_handler(Extension(state): Extension<Arc<SocketState>>) -> String {
    let mut subscribers = state.subscriber.lock().await;
    // Prune in place rather than draining into a new Vec: if this future is
    // dropped at an `.await`, the sockets not yet visited stay registered.
    let mut index = 0;
    while index < subscribers.len() {
        if subscribers[index].send("".into()).await.is_ok() {
            index += 1;
        } else {
            // Order of subscribers is irrelevant, so the O(1) removal is fine.
            drop(subscribers.swap_remove(index));
        }
    }
    "refreshing".to_string()
}
async fn load_handler(
Extension(state): Extension<Arc<SocketState>>,
paramerter: Query<LoadParameter>,
) -> String {
for subscriber in state.subscriber.lock().await.iter_mut() {
subscriber
.send(paramerter.path.clone().into())
.await
.unwrap();
}
"loading".to_string()
}
/// Serves a file from `args.directory` for any request not matched by a route.
///
/// The request path is completed by `transform_path`, then dispatched on the
/// text after the last `.`:
/// * image extensions (see `is_image`) are returned as raw bytes with an
///   `image/...` content type,
/// * `css` and `js` are returned as text with the matching content type,
/// * everything else is treated as HTML and gets the WebSocket client script
///   injected.
///
/// Any file that cannot be read, and any path containing a `..` segment,
/// yields `404 Not Found` after calling `args.err_callback` with `"not found"`.
async fn fallback_handler(uri: Uri, args: Args) -> Response {
    let report = args.err_callback;
    let not_found = move || {
        report("not found".to_string());
        StatusCode::NOT_FOUND.into_response()
    };

    let path = uri.path();
    // The path is joined to the directory by plain concatenation, so a `..`
    // segment would let a client read files outside the served directory.
    if path
        .split(|c| c == '/' || c == '\\')
        .any(|segment| segment == "..")
    {
        return not_found();
    }

    let transformed_path = transform_path(args.directory.as_str(), path);
    let extension = transformed_path.rsplit('.').next().unwrap_or_default();

    if is_image(extension) {
        // Read here so that a missing or unreadable image becomes a 404
        // instead of a panic inside the request task.
        let Ok(bytes) = std::fs::read(&transformed_path) else {
            return not_found();
        };
        // The file extension is not always the registered MIME subtype.
        let subtype = match extension {
            "jpg" => "jpeg",
            "svg" => "svg+xml",
            other => other,
        };
        let content_type = format!("image/{}", subtype);
        return (header_with_content_type(content_type.as_str()), bytes).into_response();
    }

    let Ok(content) = read_file(&transformed_path) else {
        return not_found();
    };
    match extension {
        "css" => (header_with_content_type("text/css"), content).into_response(),
        "js" => (header_with_content_type("text/javascript"), content).into_response(),
        _ => Html(inject_websocket_into(content, args)).into_response(),
    }
}
/// Builds a header map whose only entry is `Content-Type: <content_type>`.
///
/// # Panics
///
/// Panics if `content_type` cannot be parsed as a header value.
fn header_with_content_type(content_type: &str) -> HeaderMap {
    let value = content_type.parse().unwrap();
    std::iter::once((header::CONTENT_TYPE, value)).collect()
}
/// Maps a request path to a file path: completes it with `complete_path`,
/// then prefixes it with `directory` via `add_directory_path`.
fn transform_path(directory: &str, path: &str) -> String {
    add_directory_path(complete_path(path).as_str(), directory)
}
/// Completes a request path so that it names a concrete file.
///
/// * A path ending in `/` gets `index` appended.
/// * If the last path segment then contains no `.`, `.html` is appended.
///
/// Only the last segment is inspected, so a dot in a directory name
/// (`/v1.2/about`) does not suppress the default extension.
fn complete_path(url: &str) -> String {
    let mut completed = String::with_capacity(url.len() + "index.html".len());
    completed.push_str(url);
    if url.ends_with('/') {
        completed.push_str("index");
    }

    let has_extension = completed
        .rsplit('/')
        .next()
        .unwrap_or_default()
        .contains('.');
    if !has_extension {
        completed.push_str(".html");
    }
    completed
}
/// Returns `directory` immediately followed by `path`; no separator is inserted.
fn add_directory_path(path: &str, directory: &str) -> String {
    [directory, path].concat()
}
/// Reports whether `extension` (without the dot, compared case-sensitively)
/// is one of the image extensions served as binary data.
fn is_image(extension: &str) -> bool {
    matches!(
        extension,
        "png" | "jpg" | "jpeg" | "avif" | "gif" | "svg" | "webp"
    )
}
/// Reads the whole file at `path` into memory.
///
/// # Panics
///
/// Panics if the file cannot be opened or read.
fn read_bytes(path: &str) -> Vec<u8> {
    let mut file =
        File::open(path).unwrap_or_else(|err| panic!("failed to open {}: {}", path, err));

    // The size is only a capacity hint, taken from the open handle so it
    // describes the same file that is being read.
    let size_hint = file
        .metadata()
        .ok()
        .and_then(|metadata| usize::try_from(metadata.len()).ok())
        .unwrap_or(0);
    let mut buffer = Vec::with_capacity(size_hint);

    // A single `read` call may return fewer bytes than requested;
    // `read_to_end` keeps reading until end of file.
    file.read_to_end(&mut buffer)
        .unwrap_or_else(|err| panic!("failed to read {}: {}", path, err));
    buffer
}
/// Reads the file at `path` as UTF-8 text.
///
/// # Errors
///
/// Returns the `std::io::Error` from `std::fs::read_to_string`, e.g. when the
/// file does not exist or is not valid UTF-8.
fn read_file(path: &str) -> Result<String, std::io::Error> {
    std::fs::read_to_string(path)
}
/// Inlines the WebSocket client script into an HTML document.
///
/// The `<address>` and `<port>` placeholders of `WS_FILE` are replaced with
/// the values from `args`, and the result is wrapped in a `<script>` element.
/// The element is inserted right after the first literal `<head>` tag. If
/// the document has no such tag, it is appended at the end instead, so the
/// page still gets the script rather than the request panicking.
fn inject_websocket_into(mut content: String, args: Args) -> String {
    const HEAD_TAG: &str = "<head>";

    let injected_ws_file = WS_FILE
        .replace("<address>", args.address.as_str())
        .replace("<port>", args.port.to_string().as_str());
    let script = format!("<script>{}</script>", injected_ws_file);

    match content.find(HEAD_TAG) {
        Some(tag_start) => content.insert_str(tag_start + HEAD_TAG.len(), script.as_str()),
        // Appending keeps any leading doctype first in the document.
        None => content.push_str(script.as_str()),
    }
    content
}