// blit_webserver/config.rs

1use axum::extract::ws::{Message, WebSocket};
2use futures_util::SinkExt;
3use std::collections::HashMap;
4use std::path::PathBuf;
5use tokio::sync::broadcast;
6
7pub struct ConfigState {
8    pub tx: broadcast::Sender<String>,
9    pub write_lock: tokio::sync::Mutex<()>,
10}
11
12impl Default for ConfigState {
13    fn default() -> Self {
14        Self::new()
15    }
16}
17
18impl ConfigState {
19    pub fn new() -> Self {
20        let (tx, _) = broadcast::channel::<String>(64);
21        spawn_watcher(tx.clone());
22        Self {
23            tx,
24            write_lock: tokio::sync::Mutex::new(()),
25        }
26    }
27}
28
/// Returns the path of the blit config file.
///
/// Resolution order:
/// 1. `$BLIT_CONFIG`, if set and non-empty;
/// 2. `$XDG_CONFIG_HOME/blit/blit.conf`, if `XDG_CONFIG_HOME` is set to an absolute path;
/// 3. `$HOME/.config/blit/blit.conf`, with `/root` used when `HOME` is unset or empty.
///
/// Empty variables are treated as unset, and a relative `XDG_CONFIG_HOME` is ignored as the
/// XDG Base Directory specification requires. Without this, an empty variable would produce
/// a path relative to the server's working directory.
pub fn config_path() -> PathBuf {
    // `std::env::var` returns `Ok("")` for a variable that is set but empty.
    fn non_empty_var(name: &str) -> Option<String> {
        std::env::var(name).ok().filter(|v| !v.is_empty())
    }

    if let Some(p) = non_empty_var("BLIT_CONFIG") {
        return PathBuf::from(p);
    }
    let base = non_empty_var("XDG_CONFIG_HOME")
        .map(PathBuf::from)
        .filter(|p| p.is_absolute())
        .unwrap_or_else(|| {
            let home = non_empty_var("HOME").unwrap_or_else(|| "/root".into());
            PathBuf::from(home).join(".config")
        });
    base.join("blit").join("blit.conf")
}
41
42pub fn read_config() -> HashMap<String, String> {
43    let mut map = HashMap::new();
44    let path = config_path();
45    let contents = match std::fs::read_to_string(&path) {
46        Ok(c) => c,
47        Err(_) => return map,
48    };
49    for line in contents.lines() {
50        let line = line.trim();
51        if line.is_empty() || line.starts_with('#') {
52            continue;
53        }
54        if let Some((k, v)) = line.split_once('=') {
55            map.insert(k.trim().to_string(), v.trim().to_string());
56        }
57    }
58    map
59}
60
61pub fn write_config(map: &HashMap<String, String>) {
62    let path = config_path();
63    if let Some(parent) = path.parent() {
64        let _ = std::fs::create_dir_all(parent);
65    }
66    let mut lines: Vec<String> = map.iter().map(|(k, v)| format!("{k} = {v}")).collect();
67    lines.sort();
68    lines.push(String::new());
69    let _ = std::fs::write(&path, lines.join("\n"));
70}
71
72fn spawn_watcher(tx: broadcast::Sender<String>) {
73    use notify::{RecursiveMode, Watcher};
74
75    let path = config_path();
76    if let Some(parent) = path.parent() {
77        let _ = std::fs::create_dir_all(parent);
78    }
79
80    let watch_dir = path.parent().unwrap_or(&path).to_path_buf();
81    let file_name = path.file_name().map(|n| n.to_os_string());
82
83    std::thread::spawn(move || {
84        let (ntx, nrx) = std::sync::mpsc::channel();
85        let mut watcher = match notify::recommended_watcher(ntx) {
86            Ok(w) => w,
87            Err(e) => {
88                eprintln!("blit: config watcher failed: {e}");
89                return;
90            }
91        };
92        if let Err(e) = watcher.watch(&watch_dir, RecursiveMode::NonRecursive) {
93            eprintln!("blit: config watch failed: {e}");
94            return;
95        }
96        loop {
97            match nrx.recv() {
98                Ok(Ok(event)) => {
99                    let dominated = file_name.as_ref().is_none_or(|name| {
100                        event.paths.iter().any(|p| p.file_name() == Some(name))
101                    });
102                    if !dominated {
103                        continue;
104                    }
105                    let map = read_config();
106                    for (k, v) in &map {
107                        let _ = tx.send(format!("{k}={v}"));
108                    }
109                    let _ = tx.send("ready".into());
110                }
111                Ok(Err(_)) => continue,
112                Err(_) => break,
113            }
114        }
115    });
116}
117
/// Compares two byte strings without stopping at the first differing byte.
///
/// Because the scan never stops early, the running time does not reveal how long the
/// matching prefix is. Inputs of different lengths are rejected immediately, so the
/// length itself is not hidden.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
128
129pub async fn handle_config_ws(mut ws: WebSocket, token: &str, config: &ConfigState) {
130    let authed = loop {
131        match ws.recv().await {
132            Some(Ok(Message::Text(pass))) => {
133                if constant_time_eq(pass.trim().as_bytes(), token.as_bytes()) {
134                    let _ = ws.send(Message::Text("ok".into())).await;
135                    break true;
136                } else {
137                    let _ = ws.close().await;
138                    break false;
139                }
140            }
141            Some(Ok(Message::Ping(d))) => {
142                let _ = ws.send(Message::Pong(d)).await;
143            }
144            _ => break false,
145        }
146    };
147    if !authed {
148        return;
149    }
150
151    let map = read_config();
152    for (k, v) in &map {
153        if ws
154            .send(Message::Text(format!("{k}={v}").into()))
155            .await
156            .is_err()
157        {
158            return;
159        }
160    }
161    if ws.send(Message::Text("ready".into())).await.is_err() {
162        return;
163    }
164
165    let mut config_rx = config.tx.subscribe();
166
167    loop {
168        tokio::select! {
169            msg = ws.recv() => {
170                match msg {
171                    Some(Ok(Message::Text(text))) => {
172                        let text = text.trim();
173                        if let Some(rest) = text.strip_prefix("set ") {
174                            if let Some((k, v)) = rest.split_once(' ') {
175                                let _guard = config.write_lock.lock().await;
176                                let mut map = read_config();
177                                let k = k.trim().replace(['\n', '\r'], "");
178                                let v = v.trim().replace(['\n', '\r'], "");
179                                if k.is_empty() { continue; }
180                                if v.is_empty() {
181                                    map.remove(&k);
182                                } else {
183                                    map.insert(k, v);
184                                }
185                                write_config(&map);
186                            }
187                        }
188                    }
189                    Some(Ok(Message::Close(_))) | None => break,
190                    Some(Err(_)) => break,
191                    _ => continue,
192                }
193            }
194            broadcast = config_rx.recv() => {
195                match broadcast {
196                    Ok(line) => {
197                        if ws.send(Message::Text(line.into())).await.is_err() {
198                            break;
199                        }
200                    }
201                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
202                    Err(_) => break,
203                }
204            }
205        }
206    }
207}