1use std::{
2 fs::{self, File},
3 io::Read,
4 sync::Arc,
5};
6
7use axum::{
8 extract::{ws::WebSocket, Query, WebSocketUpgrade},
9 http::{HeaderMap, StatusCode, Uri},
10 response::{Html, IntoResponse, Response},
11 routing::{get, post},
12 Extension, Router,
13};
14use reqwest::header;
15use serde::Deserialize;
16use tokio::sync::Mutex;
17
/// Client-side WebSocket script embedded at compile time from `res/ws.js`.
///
/// `inject_websocket_into` replaces its `<address>` and `<port>` placeholders
/// (presumably present in the file — confirm against `res/ws.js`) and inserts
/// the result into every served HTML page.
const WS_FILE: &str = include_str!("../res/ws.js");
19
/// Query string accepted by `POST /ws/load`, e.g. `/ws/load?path=/about`.
#[derive(Deserialize)]
struct LoadParameter {
    /// Value forwarded unchanged as a text message to every WebSocket subscriber.
    path: String,
}
24
25struct SocketState {
26 subscriber: Mutex<Vec<WebSocket>>,
27}
28
29impl SocketState {
30 fn new() -> Self {
31 Self {
32 subscriber: Mutex::new(Vec::default()),
33 }
34 }
35}
36
/// Configuration handed to the fallback (static file) handler for each request.
#[derive(Clone)]
struct Args {
    /// Prefix prepended to every requested URL path to locate the file on disk.
    directory: String,
    /// Substituted for `<address>` in the injected WebSocket script.
    address: String,
    /// Substituted for `<port>` in the injected WebSocket script.
    port: u16,
    /// Called with an error message when a requested file cannot be served.
    err_callback: fn(String),
}
44
45pub struct Handler {
46 client: reqwest::Client,
47 address: String,
48 port: u16,
49}
50
51impl Handler {
52 fn url(&self, path: &str) -> String {
53 format!("http://{}:{}{}", self.address, self.port, path)
54 }
55
56 fn new(address: String, port: u16) -> Self {
57 return Self {
58 client: reqwest::Client::new(),
59 address,
60 port,
61 };
62 }
63
64 pub async fn refresh(&self) -> Result<reqwest::Response, reqwest::Error> {
65 return self.client.post(self.url("/ws/refresh")).send().await;
66 }
67
68 pub async fn load(&self, path: &str) -> Result<reqwest::Response, reqwest::Error> {
69 let url = format!("{}{}", self.url("/ws/load?path="), r#path);
70 return self.client.post(url).send().await;
71 }
72}
73
74pub fn listen(directory: &str, address: &str, port: u16, err_callback: Option<fn(err: String) -> ()>) -> Handler {
75 let args = Args {
76 directory: directory.to_string(),
77 port, address: address.to_string(),
78 err_callback: err_callback.unwrap_or(|_| {})
79 };
80 let state = Arc::new(SocketState::new());
81 let app: Router = Router::new()
82 .route("/ws/connect", get(ws_handler))
83 .route("/ws/refresh", post(refresh_handler))
84 .route("/ws/load", post(load_handler))
85 .layer(Extension(state))
86 .fallback(get(|uri: Uri| async move {
87 fallback_handler(uri, args).await
88 }));
89
90 let addr = format!("{}:{}", address.replace("localhost", "127.0.0.1"), port);
91
92 tokio::spawn(async move {
93 axum::Server::bind(&addr.parse().unwrap())
94 .serve(app.into_make_service())
95 .await
96 .unwrap();
97 });
98
99 return Handler::new(address.to_string(), port);
100}
101
102async fn ws_handler(
103 event: WebSocketUpgrade,
104 Extension(state): Extension<Arc<SocketState>>,
105) -> axum::response::Response {
106 event.on_upgrade(|socket: WebSocket| async move {
107 let mut subscriber = state.subscriber.lock().await;
108 subscriber.push(socket);
109 })
110}
111
112async fn refresh_handler(Extension(state): Extension<Arc<SocketState>>) -> String {
113 for subscriber in state.subscriber.lock().await.iter_mut() {
114 subscriber.send("".into()).await.unwrap();
115 }
116
117 "refreshing".to_string()
118}
119
120async fn load_handler(
121 Extension(state): Extension<Arc<SocketState>>,
122 paramerter: Query<LoadParameter>,
123) -> String {
124 for subscriber in state.subscriber.lock().await.iter_mut() {
125 subscriber
126 .send(paramerter.path.clone().into())
127 .await
128 .unwrap();
129 }
130
131 "loading".to_string()
132}
133
134async fn fallback_handler(uri: Uri, args: Args) -> Response {
135 let path = uri.path();
136 let transformed_path = transform_path(args.directory.as_str(), path);
137 let extension = transformed_path
138 .split(".")
139 .collect::<Vec<&str>>()
140 .last()
141 .expect("no extension")
142 .clone();
143
144 if is_image(extension) {
145 let content_type = format!("image/{}", extension);
146 let bytes = read_bytes(transformed_path.as_str());
147 (header_with_content_type(content_type.as_str()), bytes).into_response()
148 } else if extension == "css" {
149 let Ok(content) = read_file(&transformed_path) else {
150 (args.err_callback)("not found".to_string());
151 return StatusCode::NOT_FOUND.into_response()
152 };
153
154 (header_with_content_type("text/css"), content).into_response()
155 } else if extension == "js" {
156 let Ok(content) = read_file(&transformed_path) else {
157 return StatusCode::NOT_FOUND.into_response()
158 };
159
160 (header_with_content_type("text/javascript"), content).into_response()
161 } else {
162 let Ok(content) = read_file(&transformed_path) else {
163 (args.err_callback)("not found".to_string());
164 return StatusCode::NOT_FOUND.into_response()
165 };
166
167 let injected_content = inject_websocket_into(content, args);
168 Html(injected_content).into_response()
169 }
170}
171
172fn header_with_content_type(content_type: &str) -> HeaderMap {
173 let mut header_map = HeaderMap::new();
174 header_map.insert(header::CONTENT_TYPE, content_type.parse().unwrap());
175 header_map
176}
177
178fn transform_path(directory: &str, path: &str) -> String {
179 let completed_path = complete_path(path);
180 add_directory_path(&completed_path, directory)
181}
182
/// Fills in the defaults for a request URL path.
///
/// A trailing `/` gets `index` appended. A last path segment without a `.`
/// gets `.html` appended. Only the last segment is inspected, so dots in
/// directory names (`/v1.2/page`, `/assets.v2/`) no longer suppress the
/// `.html` default.
///
/// ```text
/// "/"           -> "/index.html"
/// "/about"      -> "/about.html"
/// "/style.css"  -> "/style.css"
/// "/v1.2/page"  -> "/v1.2/page.html"
/// ```
fn complete_path(url: &str) -> String {
    let mut completed = if url.ends_with('/') {
        format!("{}index", url)
    } else {
        url.to_string()
    };

    let file_name = completed.rsplit('/').next().unwrap_or("");
    if !file_name.contains('.') {
        completed.push_str(".html");
    }

    completed
}
198
/// Prepends `directory` to `path` verbatim (no separator is inserted).
fn add_directory_path(path: &str, directory: &str) -> String {
    let mut joined = String::with_capacity(directory.len() + path.len());
    joined.push_str(directory);
    joined.push_str(path);
    joined
}
202
/// Reports whether `extension` (without the dot) is one of the image formats
/// served as raw bytes. Matching is exact and case-sensitive.
fn is_image(extension: &str) -> bool {
    matches!(
        extension,
        "png" | "jpg" | "jpeg" | "avif" | "gif" | "svg" | "webp"
    )
}
212
/// Reads the entire file at `path` into memory.
///
/// A single `Read::read` call may return fewer bytes than the buffer holds.
/// The previous version could therefore truncate the file or zero-pad it.
/// `read_to_end` loops until EOF and also sizes the buffer itself, so the
/// separate metadata lookup is gone.
///
/// # Panics
/// Panics if the file cannot be opened or read.
fn read_bytes(path: &str) -> Vec<u8> {
    let mut file =
        File::open(path).unwrap_or_else(|err| panic!("failed to open {}: {}", path, err));
    let mut buffer = Vec::new();
    file.read_to_end(&mut buffer)
        .unwrap_or_else(|err| panic!("failed to read {}: {}", path, err));
    buffer
}
220
/// Reads the file at `path` as UTF-8 text.
///
/// Takes `&str` rather than `&String`; existing `&String` arguments still
/// coerce.
///
/// # Errors
/// Returns the I/O error if the file is missing or unreadable, or if it is not
/// valid UTF-8.
fn read_file(path: &str) -> Result<String, std::io::Error> {
    fs::read_to_string(path)
}
224
225fn inject_websocket_into(mut content: String, args: Args) -> String {
226 let insert_index = content.find("<head>").unwrap() + 6;
227 let injected_ws_file = WS_FILE.replace("<address>", args.address.as_str())
228 .replace("<port>", args.port.to_string().as_str());
229
230 content.insert_str(
231 insert_index,
232 format!("<script>{}</script>", injected_ws_file).as_str(),
233 );
234
235 return content;
236}