Skip to main content

blit_webserver/
config.rs

1use axum::extract::ws::{Message, WebSocket};
2use futures_util::SinkExt;
3use std::collections::HashMap;
4use std::path::PathBuf;
5use tokio::sync::broadcast;
6
7pub struct ConfigState {
8    pub tx: broadcast::Sender<String>,
9    pub write_lock: tokio::sync::Mutex<()>,
10}
11
12impl ConfigState {
13    pub fn new() -> Self {
14        let (tx, _) = broadcast::channel::<String>(64);
15        spawn_watcher(tx.clone());
16        Self {
17            tx,
18            write_lock: tokio::sync::Mutex::new(()),
19        }
20    }
21}
22
/// Resolve the location of the blit config file.
///
/// Resolution order:
/// 1. `$BLIT_CONFIG`, used verbatim when set and non-empty.
/// 2. `$XDG_CONFIG_HOME/blit/blit.conf`, when that variable is set to an
///    absolute path.
/// 3. `$HOME/.config/blit/blit.conf`, with `/root` standing in for an unset or
///    empty `$HOME`.
pub fn config_path() -> PathBuf {
    // Unset and empty variables are both treated as absent. `var_os` also
    // accepts non-UTF-8 values, which `var` would reject.
    let env_path = |name: &str| {
        std::env::var_os(name)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    };

    if let Some(explicit) = env_path("BLIT_CONFIG") {
        return explicit;
    }

    // The XDG Base Directory spec says a relative XDG_CONFIG_HOME is invalid
    // and must be ignored in favour of the $HOME/.config default.
    let base = env_path("XDG_CONFIG_HOME")
        .filter(|dir| dir.is_absolute())
        .unwrap_or_else(|| {
            env_path("HOME")
                .unwrap_or_else(|| PathBuf::from("/root"))
                .join(".config")
        });
    base.join("blit").join("blit.conf")
}
35
36pub fn read_config() -> HashMap<String, String> {
37    let mut map = HashMap::new();
38    let path = config_path();
39    let contents = match std::fs::read_to_string(&path) {
40        Ok(c) => c,
41        Err(_) => return map,
42    };
43    for line in contents.lines() {
44        let line = line.trim();
45        if line.is_empty() || line.starts_with('#') {
46            continue;
47        }
48        if let Some((k, v)) = line.split_once('=') {
49            map.insert(k.trim().to_string(), v.trim().to_string());
50        }
51    }
52    map
53}
54
55pub fn write_config(map: &HashMap<String, String>) {
56    let path = config_path();
57    if let Some(parent) = path.parent() {
58        let _ = std::fs::create_dir_all(parent);
59    }
60    let mut lines: Vec<String> = map.iter().map(|(k, v)| format!("{k} = {v}")).collect();
61    lines.sort();
62    lines.push(String::new());
63    let _ = std::fs::write(&path, lines.join("\n"));
64}
65
66fn spawn_watcher(tx: broadcast::Sender<String>) {
67    use notify::{RecursiveMode, Watcher};
68
69    let path = config_path();
70    if let Some(parent) = path.parent() {
71        let _ = std::fs::create_dir_all(parent);
72    }
73
74    let watch_dir = path.parent().unwrap_or(&path).to_path_buf();
75    let file_name = path.file_name().map(|n| n.to_os_string());
76
77    std::thread::spawn(move || {
78        let (ntx, nrx) = std::sync::mpsc::channel();
79        let mut watcher = match notify::recommended_watcher(ntx) {
80            Ok(w) => w,
81            Err(e) => {
82                eprintln!("blit: config watcher failed: {e}");
83                return;
84            }
85        };
86        if let Err(e) = watcher.watch(&watch_dir, RecursiveMode::NonRecursive) {
87            eprintln!("blit: config watch failed: {e}");
88            return;
89        }
90        loop {
91            match nrx.recv() {
92                Ok(Ok(event)) => {
93                    let dominated = file_name.as_ref().map_or(true, |name| {
94                        event.paths.iter().any(|p| p.file_name() == Some(name))
95                    });
96                    if !dominated {
97                        continue;
98                    }
99                    let map = read_config();
100                    for (k, v) in &map {
101                        let _ = tx.send(format!("{k}={v}"));
102                    }
103                    let _ = tx.send("ready".into());
104                }
105                Ok(Err(_)) => continue,
106                Err(_) => break,
107            }
108        }
109    });
110}
111
/// Compare two byte slices for equality without stopping at the first
/// differing byte.
///
/// A length mismatch returns `false` immediately. Otherwise, every byte pair
/// is XOR-ed into a single accumulator, and the slices are equal exactly when
/// that accumulator is zero.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
122
123pub async fn handle_config_ws(mut ws: WebSocket, token: &str, config: &ConfigState) {
124    let authed = loop {
125        match ws.recv().await {
126            Some(Ok(Message::Text(pass))) => {
127                if constant_time_eq(pass.trim().as_bytes(), token.as_bytes()) {
128                    let _ = ws.send(Message::Text("ok".into())).await;
129                    break true;
130                } else {
131                    let _ = ws.close().await;
132                    break false;
133                }
134            }
135            Some(Ok(Message::Ping(d))) => {
136                let _ = ws.send(Message::Pong(d)).await;
137            }
138            _ => break false,
139        }
140    };
141    if !authed {
142        return;
143    }
144
145    let map = read_config();
146    for (k, v) in &map {
147        if ws
148            .send(Message::Text(format!("{k}={v}").into()))
149            .await
150            .is_err()
151        {
152            return;
153        }
154    }
155    if ws.send(Message::Text("ready".into())).await.is_err() {
156        return;
157    }
158
159    let mut config_rx = config.tx.subscribe();
160
161    loop {
162        tokio::select! {
163            msg = ws.recv() => {
164                match msg {
165                    Some(Ok(Message::Text(text))) => {
166                        let text = text.trim();
167                        if let Some(rest) = text.strip_prefix("set ") {
168                            if let Some((k, v)) = rest.split_once(' ') {
169                                let _guard = config.write_lock.lock().await;
170                                let mut map = read_config();
171                                let k = k.trim().replace(['\n', '\r'], "");
172                                let v = v.trim().replace(['\n', '\r'], "");
173                                if k.is_empty() { continue; }
174                                if v.is_empty() {
175                                    map.remove(&k);
176                                } else {
177                                    map.insert(k, v);
178                                }
179                                write_config(&map);
180                            }
181                        }
182                    }
183                    Some(Ok(Message::Close(_))) | None => break,
184                    Some(Err(_)) => break,
185                    _ => continue,
186                }
187            }
188            broadcast = config_rx.recv() => {
189                match broadcast {
190                    Ok(line) => {
191                        if ws.send(Message::Text(line.into())).await.is_err() {
192                            break;
193                        }
194                    }
195                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
196                    Err(_) => break,
197                }
198            }
199        }
200    }
201}