use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;

use axum::{
extract::{ws::WebSocket, WebSocketUpgrade, Query},
http::Uri,
response::Html,
routing::{get, post},
Router, Extension,
};
use html_editor::{
operation::{Editable, Htmlifiable, Selector},
Node,
};
use serde::Deserialize;
use tokio::sync::Mutex;
/// Query string accepted by `POST /ws/load`.
///
/// `path` is forwarded verbatim as a text message to every connected WebSocket subscriber.
#[derive(Deserialize)]
struct LoadParameter {
    /// Value of the `path` query parameter.
    path: String,
}
/// Shared server state: every WebSocket that has connected via `/ws/connect`.
struct SocketState {
    /// Open subscriber sockets. Guarded by Tokio's async mutex because the broadcast
    /// handlers await `send` while holding the lock.
    subscriber: Mutex<Vec<WebSocket>>,
}

impl SocketState {
    /// Creates state with no subscribers yet.
    fn new() -> Self {
        let subscriber = Mutex::new(Vec::new());
        Self { subscriber }
    }
}
pub struct Handler {
client: reqwest::Client
}
impl Handler {
fn new() -> Self {
return Self { client: reqwest::Client::new() }
}
pub async fn refresh(&self) -> Result<reqwest::Response, reqwest::Error> {
return self.client.post("http://localhost:3000/ws/refresh").send().await;
}
pub async fn load(&self, path: &str) -> Result<reqwest::Response, reqwest::Error> {
let url = format!("http://localhost:3000/ws/load?path={}", path);
dbg!(&url);
return self.client.post(url).send().await;
}
}
pub fn listen(directory: String, address: String, port: u32) -> Handler {
let state = Arc::new(SocketState::new());
let app: Router = Router::new()
.route("/ws/connect", get(ws_handler))
.route("/ws/refresh", post(refresh_handler))
.route("/ws/load", post(load_handler))
.layer(Extension(state))
.fallback(get(|uri: Uri| async move {
Html(
fetch_file(uri.path().to_string(), &directory).unwrap()
)
}));
let addr = format!("{}:{}", address.replace("localhost", "127.0.0.1"), port);
tokio::spawn(async move {
axum::Server::bind(&addr.parse().unwrap())
.serve(app.into_make_service())
.await
.unwrap();
});
return Handler::new()
}
/// `GET /ws/connect`: upgrades the request to a WebSocket and stores the socket
/// in the shared subscriber list.
async fn ws_handler(
    event: WebSocketUpgrade,
    Extension(state): Extension<Arc<SocketState>>,
) -> axum::response::Response {
    let register = move |socket: WebSocket| async move {
        state.subscriber.lock().await.push(socket);
    };
    event.on_upgrade(register)
}
async fn refresh_handler(Extension(state): Extension<Arc<SocketState>>) -> String {
for subscriber in state.subscriber.lock().await.iter_mut() {
subscriber.send("".into()).await.unwrap();
}
"refreshing".to_string()
}
async fn load_handler(Extension(state): Extension<Arc<SocketState>>, paramerter: Query<LoadParameter>) -> String {
for subscriber in state.subscriber.lock().await.iter_mut() {
subscriber.send(paramerter.path.clone().into()).await.unwrap();
}
"loading".to_string()
}
/// Resolves the request path `url` to a file under `directory` and returns its contents.
///
/// The live-reload script is injected only into HTML documents. Other assets (CSS, JS, …)
/// are returned untouched, so they are never run through the HTML parser.
///
/// # Errors
/// Returns `Err("not found")` in two cases:
/// * `url` contains a `..` segment, which could escape `directory`;
/// * the mapped file cannot be read as UTF-8 text.
fn fetch_file(url: String, directory: &str) -> Result<String, &'static str> {
    // Request paths come from the network: refuse parent-directory segments outright.
    if url
        .split(|c: char| c == '/' || c == '\\')
        .any(|segment| segment == "..")
    {
        return Err("not found");
    }
    let relative = url_to_path(&url);
    let path = format!("{directory}{relative}");
    let Ok(file) = read_file(path) else {
        return Err("not found");
    };
    let is_html = matches!(
        std::path::Path::new(&relative).extension().and_then(|ext| ext.to_str()),
        Some(ext) if ext.eq_ignore_ascii_case("html") || ext.eq_ignore_ascii_case("htm")
    );
    if is_html {
        Ok(inject_websocket_into(file))
    } else {
        Ok(file)
    }
}
/// Maps a request path to a file path relative to the served directory.
///
/// * A trailing `/` maps to that directory's `index` page (`/docs/` → `/docs/index.html`).
/// * A final segment without a `.` gets the `.html` extension (`/about` → `/about.html`).
///
/// Only the final segment is inspected for an extension, so dots in directory names
/// (e.g. `/v1.2/about`) do not suppress the `.html` default.
fn url_to_path(url: &str) -> String {
    let mut path = url.to_string();
    if path.ends_with('/') {
        path.push_str("index");
    }
    // `rsplit` always yields at least one item, even for an empty string.
    let file_name = path.rsplit('/').next().unwrap_or("");
    if !file_name.contains('.') {
        path.push_str(".html");
    }
    path
}
/// Reads the entire file at `path` into a `String`.
///
/// # Errors
/// Propagates any I/O error from `std::fs::read_to_string`, including `InvalidData`
/// when the file is not valid UTF-8.
fn read_file(path: String) -> Result<String, std::io::Error> {
    let contents = std::fs::read_to_string(&path)?;
    Ok(contents)
}
/// Returns `content` with an inline `<script>` holding `res/ws.js` inserted into `<body>`.
///
/// If `html_editor` cannot parse `content`, the page is returned unchanged. It is still
/// served, just without the live-reload script, rather than panicking the request.
fn inject_websocket_into(content: String) -> String {
    let ws_script = include_str!("../res/ws.js");
    let Ok(mut dom) = html_editor::parse(&content) else {
        return content;
    };
    let script_node = Node::new_element(
        "script",
        vec![],
        vec![Node::Text(ws_script.to_string())],
    );
    dom.insert_to(&Selector::from("body"), script_node);
    dom.html()
}