Skip to main content

blit_webserver/
config.rs

1use axum::extract::ws::{Message, WebSocket};
2use futures_util::SinkExt;
3use std::collections::HashMap;
4use std::path::PathBuf;
5use tokio::sync::broadcast;
6
7pub struct ConfigState {
8    pub tx: broadcast::Sender<String>,
9    pub write_lock: tokio::sync::Mutex<()>,
10}
11
12impl Default for ConfigState {
13    fn default() -> Self {
14        Self::new()
15    }
16}
17
18impl ConfigState {
19    pub fn new() -> Self {
20        let (tx, _) = broadcast::channel::<String>(64);
21        spawn_watcher(tx.clone());
22        Self {
23            tx,
24            write_lock: tokio::sync::Mutex::new(()),
25        }
26    }
27}
28
/// Returns the location of the blit config file.
///
/// Resolution order:
/// 1. `$BLIT_CONFIG`, used verbatim when set and non-empty.
/// 2. On Windows: `%APPDATA%\blit\blit.conf`, falling back to
///    `C:\ProgramData` when `APPDATA` is unset or empty.
/// 3. Elsewhere: `$XDG_CONFIG_HOME/blit/blit.conf`, falling back to
///    `$HOME/.config` (or `/root/.config` when `HOME` is unset or empty).
///
/// Following the XDG Base Directory spec, an empty or relative
/// `XDG_CONFIG_HOME` is ignored rather than producing a path relative to the
/// current working directory. Environment values are read as `OsString`, so
/// non-UTF-8 paths are honoured.
pub fn config_path() -> PathBuf {
    if let Some(p) = std::env::var_os("BLIT_CONFIG").filter(|p| !p.is_empty()) {
        return PathBuf::from(p);
    }
    #[cfg(not(windows))]
    let base = std::env::var_os("XDG_CONFIG_HOME")
        .map(PathBuf::from)
        // An empty path is not absolute, so this also drops the empty case.
        .filter(|p| p.is_absolute())
        .unwrap_or_else(|| {
            let home = std::env::var_os("HOME")
                .filter(|h| !h.is_empty())
                .unwrap_or_else(|| "/root".into());
            PathBuf::from(home).join(".config")
        });
    #[cfg(windows)]
    let base = std::env::var_os("APPDATA")
        .filter(|p| !p.is_empty())
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from(r"C:\ProgramData"));
    base.join("blit").join("blit.conf")
}
46
47pub fn read_config() -> HashMap<String, String> {
48    let mut map = HashMap::new();
49    let path = config_path();
50    let contents = match std::fs::read_to_string(&path) {
51        Ok(c) => c,
52        Err(_) => return map,
53    };
54    for line in contents.lines() {
55        let line = line.trim();
56        if line.is_empty() || line.starts_with('#') {
57            continue;
58        }
59        if let Some((k, v)) = line.split_once('=') {
60            map.insert(k.trim().to_string(), v.trim().to_string());
61        }
62    }
63    map
64}
65
66pub fn write_config(map: &HashMap<String, String>) {
67    let path = config_path();
68    if let Some(parent) = path.parent() {
69        let _ = std::fs::create_dir_all(parent);
70    }
71    let mut lines: Vec<String> = map.iter().map(|(k, v)| format!("{k} = {v}")).collect();
72    lines.sort();
73    lines.push(String::new());
74    let _ = std::fs::write(&path, lines.join("\n"));
75}
76
77fn spawn_watcher(tx: broadcast::Sender<String>) {
78    use notify::{RecursiveMode, Watcher};
79
80    let path = config_path();
81    if let Some(parent) = path.parent() {
82        let _ = std::fs::create_dir_all(parent);
83    }
84
85    let watch_dir = path.parent().unwrap_or(&path).to_path_buf();
86    let file_name = path.file_name().map(|n| n.to_os_string());
87
88    std::thread::spawn(move || {
89        let (ntx, nrx) = std::sync::mpsc::channel();
90        let mut watcher = match notify::recommended_watcher(ntx) {
91            Ok(w) => w,
92            Err(e) => {
93                eprintln!("blit: config watcher failed: {e}");
94                return;
95            }
96        };
97        if let Err(e) = watcher.watch(&watch_dir, RecursiveMode::NonRecursive) {
98            eprintln!("blit: config watch failed: {e}");
99            return;
100        }
101        loop {
102            match nrx.recv() {
103                Ok(Ok(event)) => {
104                    let dominated = file_name
105                        .as_ref()
106                        .is_none_or(|name| event.paths.iter().any(|p| p.file_name() == Some(name)));
107                    if !dominated {
108                        continue;
109                    }
110                    let map = read_config();
111                    for (k, v) in &map {
112                        let _ = tx.send(format!("{k}={v}"));
113                    }
114                    let _ = tx.send("ready".into());
115                }
116                Ok(Err(_)) => continue,
117                Err(_) => break,
118            }
119        }
120    });
121}
122
/// Compares two byte strings without stopping at the first differing byte.
///
/// Inputs of different lengths are rejected straight away, so timing can
/// reveal whether the lengths match, but not where equal-length inputs differ.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        // OR together the XOR of every byte pair; any difference leaves a bit set.
        let mismatch = a.iter().zip(b).fold(0u8, |acc, (&x, &y)| acc | (x ^ y));
        mismatch == 0
    } else {
        false
    }
}
133
134pub async fn handle_config_ws(mut ws: WebSocket, token: &str, config: &ConfigState) {
135    let authed = loop {
136        match ws.recv().await {
137            Some(Ok(Message::Text(pass))) => {
138                if constant_time_eq(pass.trim().as_bytes(), token.as_bytes()) {
139                    let _ = ws.send(Message::Text("ok".into())).await;
140                    break true;
141                } else {
142                    let _ = ws.close().await;
143                    break false;
144                }
145            }
146            Some(Ok(Message::Ping(d))) => {
147                let _ = ws.send(Message::Pong(d)).await;
148            }
149            _ => break false,
150        }
151    };
152    if !authed {
153        return;
154    }
155
156    let map = read_config();
157    for (k, v) in &map {
158        if ws
159            .send(Message::Text(format!("{k}={v}").into()))
160            .await
161            .is_err()
162        {
163            return;
164        }
165    }
166    if ws.send(Message::Text("ready".into())).await.is_err() {
167        return;
168    }
169
170    let mut config_rx = config.tx.subscribe();
171
172    loop {
173        tokio::select! {
174            msg = ws.recv() => {
175                match msg {
176                    Some(Ok(Message::Text(text))) => {
177                        let text = text.trim();
178                        if let Some(rest) = text.strip_prefix("set ")
179                            && let Some((k, v)) = rest.split_once(' ') {
180                                let _guard = config.write_lock.lock().await;
181                                let mut map = read_config();
182                                let k = k.trim().replace(['\n', '\r'], "");
183                                let v = v.trim().replace(['\n', '\r'], "");
184                                if k.is_empty() { continue; }
185                                if v.is_empty() {
186                                    map.remove(&k);
187                                } else {
188                                    map.insert(k, v);
189                                }
190                                write_config(&map);
191                            }
192                    }
193                    Some(Ok(Message::Close(_))) | None => break,
194                    Some(Err(_)) => break,
195                    _ => continue,
196                }
197            }
198            broadcast = config_rx.recv() => {
199                match broadcast {
200                    Ok(line) => {
201                        if ws.send(Message::Text(line.into())).await.is_err() {
202                            break;
203                        }
204                    }
205                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
206                    Err(_) => break,
207                }
208            }
209        }
210    }
211}