// refresh_server/lib.rs

1use std::{
2    fs::{self, File},
3    io::Read,
4    sync::Arc,
5};
6
7use axum::{
8    extract::{ws::WebSocket, Query, WebSocketUpgrade},
9    http::{HeaderMap, StatusCode, Uri},
10    response::{Html, IntoResponse, Response},
11    routing::{get, post},
12    Extension, Router,
13};
14use reqwest::header;
15use serde::Deserialize;
16use tokio::sync::Mutex;
17
18const WS_FILE: &str = include_str!("../res/ws.js");
19
20#[derive(Deserialize)]
21struct LoadParameter {
22    path: String,
23}
24
25struct SocketState {
26    subscriber: Mutex<Vec<WebSocket>>,
27}
28
29impl SocketState {
30    fn new() -> Self {
31        Self {
32            subscriber: Mutex::new(Vec::default()),
33        }
34    }
35}
36
/// Configuration captured by the static-file fallback handler.
#[derive(Clone)]
struct Args {
    /// Directory prefix prepended to every request path.
    directory: String,
    /// Host substituted into the injected live-reload script.
    address: String,
    /// Port substituted into the injected live-reload script.
    port: u16,
    /// Error reporter supplied to `listen`; called with a short failure description.
    err_callback: fn(String),
}
44
45pub struct Handler {
46    client: reqwest::Client,
47    address: String,
48    port: u16,
49}
50
51impl Handler {
52    fn url(&self, path: &str) -> String {
53        format!("http://{}:{}{}", self.address, self.port, path)
54    }
55
56    fn new(address: String, port: u16) -> Self {
57        return Self {
58            client: reqwest::Client::new(),
59            address,
60            port,
61        };
62    }
63
64    pub async fn refresh(&self) -> Result<reqwest::Response, reqwest::Error> {
65        return self.client.post(self.url("/ws/refresh")).send().await;
66    }
67
68    pub async fn load(&self, path: &str) -> Result<reqwest::Response, reqwest::Error> {
69        let url = format!("{}{}", self.url("/ws/load?path="), r#path);
70        return self.client.post(url).send().await;
71    }
72}
73
74pub fn listen(directory: &str, address: &str, port: u16, err_callback: Option<fn(err: String) -> ()>) -> Handler {
75    let args = Args {
76        directory: directory.to_string(),
77        port, address: address.to_string(),
78        err_callback: err_callback.unwrap_or(|_| {})  
79    };
80    let state = Arc::new(SocketState::new());
81    let app: Router = Router::new()
82        .route("/ws/connect", get(ws_handler))
83        .route("/ws/refresh", post(refresh_handler))
84        .route("/ws/load", post(load_handler))
85        .layer(Extension(state))
86        .fallback(get(|uri: Uri| async move {
87            fallback_handler(uri, args).await
88        }));
89
90    let addr = format!("{}:{}", address.replace("localhost", "127.0.0.1"), port);
91
92    tokio::spawn(async move {
93        axum::Server::bind(&addr.parse().unwrap())
94            .serve(app.into_make_service())
95            .await
96            .unwrap();
97    });
98
99    return Handler::new(address.to_string(), port);
100}
101
102async fn ws_handler(
103    event: WebSocketUpgrade,
104    Extension(state): Extension<Arc<SocketState>>,
105) -> axum::response::Response {
106    event.on_upgrade(|socket: WebSocket| async move {
107        let mut subscriber = state.subscriber.lock().await;
108        subscriber.push(socket);
109    })
110}
111
112async fn refresh_handler(Extension(state): Extension<Arc<SocketState>>) -> String {
113    for subscriber in state.subscriber.lock().await.iter_mut() {
114        subscriber.send("".into()).await.unwrap();
115    }
116
117    "refreshing".to_string()
118}
119
120async fn load_handler(
121    Extension(state): Extension<Arc<SocketState>>,
122    paramerter: Query<LoadParameter>,
123) -> String {
124    for subscriber in state.subscriber.lock().await.iter_mut() {
125        subscriber
126            .send(paramerter.path.clone().into())
127            .await
128            .unwrap();
129    }
130
131    "loading".to_string()
132}
133
134async fn fallback_handler(uri: Uri, args: Args) -> Response {
135    let path = uri.path();
136    let transformed_path = transform_path(args.directory.as_str(), path);
137    let extension = transformed_path
138        .split(".")
139        .collect::<Vec<&str>>()
140        .last()
141        .expect("no extension")
142        .clone();
143
144    if is_image(extension) {
145        let content_type = format!("image/{}", extension);
146        let bytes = read_bytes(transformed_path.as_str());
147        (header_with_content_type(content_type.as_str()), bytes).into_response()
148    } else if extension == "css" {
149        let Ok(content) = read_file(&transformed_path) else {
150            (args.err_callback)("not found".to_string());
151            return StatusCode::NOT_FOUND.into_response()
152        };
153
154        (header_with_content_type("text/css"), content).into_response()
155    } else if extension == "js" {
156        let Ok(content) = read_file(&transformed_path) else {
157            return StatusCode::NOT_FOUND.into_response()
158        };
159
160        (header_with_content_type("text/javascript"), content).into_response()
161    } else {
162        let Ok(content) = read_file(&transformed_path) else {
163            (args.err_callback)("not found".to_string());
164            return StatusCode::NOT_FOUND.into_response()
165        };
166
167        let injected_content = inject_websocket_into(content, args);
168        Html(injected_content).into_response()
169    }
170}
171
172fn header_with_content_type(content_type: &str) -> HeaderMap {
173    let mut header_map = HeaderMap::new();
174    header_map.insert(header::CONTENT_TYPE, content_type.parse().unwrap());
175    header_map
176}
177
178fn transform_path(directory: &str, path: &str) -> String {
179    let completed_path = complete_path(path);
180    add_directory_path(&completed_path, directory)
181}
182
/// Fills in the file name that a URL path implicitly refers to.
///
/// A trailing `/` gets `index` appended. A final path segment without a `.` gets
/// `.html` appended. Examples:
/// - `/` becomes `/index.html`
/// - `/about` becomes `/about.html`
/// - `/style.css` is unchanged
///
/// Only the last segment is inspected, so dotted directory names such as
/// `/v1.2/about` still receive the `.html` default.
fn complete_path(url: &str) -> String {
    let mut completed = url.to_string();
    if completed.ends_with('/') {
        completed.push_str("index");
    }

    let file_name = completed.rsplit('/').next().unwrap_or("");
    if !file_name.contains('.') {
        completed.push_str(".html");
    }

    completed
}
198
/// Prefixes `path` with `directory` by plain string concatenation; no separator is inserted.
fn add_directory_path(path: &str, directory: &str) -> String {
    [directory, path].concat()
}
202
/// Returns `true` for the image extensions png, jpg, jpeg, avif, gif, svg and webp.
/// The comparison is case-sensitive.
fn is_image(extension: &str) -> bool {
    matches!(
        extension,
        "png" | "jpg" | "jpeg" | "avif" | "gif" | "svg" | "webp"
    )
}
212
/// Reads the whole file at `path` into memory.
///
/// # Panics
/// Panics if the file cannot be read (missing, unreadable, ...).
fn read_bytes(path: &str) -> Vec<u8> {
    // `fs::read` keeps reading until EOF. A single `Read::read` call may return fewer
    // bytes than the file holds, which would leave the tail of a pre-sized buffer zeroed.
    fs::read(path).unwrap_or_else(|err| panic!("failed to read {path}: {err}"))
}
220
/// Reads the file at `path` as UTF-8 text.
///
/// Takes `&str`, so callers passing `&String` keep working through deref coercion.
///
/// # Errors
/// Returns the I/O error if the file is missing, unreadable, or not valid UTF-8.
fn read_file(path: &str) -> Result<String, std::io::Error> {
    fs::read_to_string(path)
}
224
225fn inject_websocket_into(mut content: String, args: Args) -> String {
226    let insert_index = content.find("<head>").unwrap() + 6;
227    let injected_ws_file = WS_FILE.replace("<address>", args.address.as_str())
228        .replace("<port>", args.port.to_string().as_str());
229
230    content.insert_str(
231        insert_index,
232        format!("<script>{}</script>", injected_ws_file).as_str(),
233    );
234
235    return content;
236}