Skip to main content

blit_webserver/
config.rs

1use axum::extract::ws::{Message, WebSocket};
2use futures_util::SinkExt;
3use std::collections::HashMap;
4use std::path::PathBuf;
5use tokio::sync::broadcast;
6
7pub struct ConfigState {
8    pub tx: broadcast::Sender<String>,
9}
10
11impl ConfigState {
12    pub fn new() -> Self {
13        let (tx, _) = broadcast::channel::<String>(64);
14        spawn_watcher(tx.clone());
15        Self { tx }
16    }
17}
18
/// Returns the location of the blit config file.
///
/// Resolution order:
/// 1. `$BLIT_CONFIG`, used verbatim when set and non-empty.
/// 2. `$XDG_CONFIG_HOME/blit/blit.conf` when `XDG_CONFIG_HOME` is set to an
///    absolute path. The XDG Base Directory spec requires empty or relative
///    values to be ignored.
/// 3. `$HOME/.config/blit/blit.conf`, with `/root` used when `HOME` is unset
///    or empty.
///
/// Variables are read with `var_os`, so non-UTF-8 paths are honoured instead
/// of silently falling through to the next candidate.
pub fn config_path() -> PathBuf {
    /// Reads an environment variable, treating an empty value as unset.
    fn non_empty_var(name: &str) -> Option<std::ffi::OsString> {
        std::env::var_os(name).filter(|value| !value.is_empty())
    }

    if let Some(p) = non_empty_var("BLIT_CONFIG") {
        return PathBuf::from(p);
    }
    let base = non_empty_var("XDG_CONFIG_HOME")
        .map(PathBuf::from)
        .filter(|p| p.is_absolute())
        .unwrap_or_else(|| {
            let home = non_empty_var("HOME").unwrap_or_else(|| "/root".into());
            PathBuf::from(home).join(".config")
        });
    base.join("blit").join("blit.conf")
}
31
32pub fn read_config() -> HashMap<String, String> {
33    let mut map = HashMap::new();
34    let path = config_path();
35    let contents = match std::fs::read_to_string(&path) {
36        Ok(c) => c,
37        Err(_) => return map,
38    };
39    for line in contents.lines() {
40        let line = line.trim();
41        if line.is_empty() || line.starts_with('#') {
42            continue;
43        }
44        if let Some((k, v)) = line.split_once('=') {
45            map.insert(k.trim().to_string(), v.trim().to_string());
46        }
47    }
48    map
49}
50
51pub fn write_config(map: &HashMap<String, String>) {
52    let path = config_path();
53    if let Some(parent) = path.parent() {
54        let _ = std::fs::create_dir_all(parent);
55    }
56    let mut lines: Vec<String> = map.iter().map(|(k, v)| format!("{k} = {v}")).collect();
57    lines.sort();
58    lines.push(String::new());
59    let _ = std::fs::write(&path, lines.join("\n"));
60}
61
62fn spawn_watcher(tx: broadcast::Sender<String>) {
63    use notify::{RecursiveMode, Watcher};
64
65    let path = config_path();
66    if let Some(parent) = path.parent() {
67        let _ = std::fs::create_dir_all(parent);
68    }
69
70    let watch_dir = path.parent().unwrap_or(&path).to_path_buf();
71    let file_name = path.file_name().map(|n| n.to_os_string());
72
73    std::thread::spawn(move || {
74        let (ntx, nrx) = std::sync::mpsc::channel();
75        let mut watcher = match notify::recommended_watcher(ntx) {
76            Ok(w) => w,
77            Err(e) => {
78                eprintln!("blit: config watcher failed: {e}");
79                return;
80            }
81        };
82        if let Err(e) = watcher.watch(&watch_dir, RecursiveMode::NonRecursive) {
83            eprintln!("blit: config watch failed: {e}");
84            return;
85        }
86        loop {
87            match nrx.recv() {
88                Ok(Ok(event)) => {
89                    let dominated = file_name.as_ref().map_or(true, |name| {
90                        event.paths.iter().any(|p| p.file_name() == Some(name))
91                    });
92                    if !dominated {
93                        continue;
94                    }
95                    let map = read_config();
96                    for (k, v) in &map {
97                        let _ = tx.send(format!("{k}={v}"));
98                    }
99                    let _ = tx.send("ready".into());
100                }
101                Ok(Err(_)) => continue,
102                Err(_) => break,
103            }
104        }
105    });
106}
107
108pub async fn handle_config_ws(mut ws: WebSocket, token: &str, config: &ConfigState) {
109    let authed = loop {
110        match ws.recv().await {
111            Some(Ok(Message::Text(pass))) => {
112                if pass.trim() == token {
113                    let _ = ws.send(Message::Text("ok".into())).await;
114                    break true;
115                } else {
116                    let _ = ws.close().await;
117                    break false;
118                }
119            }
120            Some(Ok(Message::Ping(d))) => {
121                let _ = ws.send(Message::Pong(d)).await;
122            }
123            _ => break false,
124        }
125    };
126    if !authed {
127        return;
128    }
129
130    let map = read_config();
131    for (k, v) in &map {
132        if ws
133            .send(Message::Text(format!("{k}={v}").into()))
134            .await
135            .is_err()
136        {
137            return;
138        }
139    }
140    if ws.send(Message::Text("ready".into())).await.is_err() {
141        return;
142    }
143
144    let mut config_rx = config.tx.subscribe();
145
146    loop {
147        tokio::select! {
148            msg = ws.recv() => {
149                match msg {
150                    Some(Ok(Message::Text(text))) => {
151                        let text = text.trim();
152                        if let Some(rest) = text.strip_prefix("set ") {
153                            if let Some((k, v)) = rest.split_once(' ') {
154                                let mut map = read_config();
155                                let k = k.trim().replace(['\n', '\r'], "");
156                                let v = v.trim().replace(['\n', '\r'], "");
157                                if k.is_empty() { continue; }
158                                if v.is_empty() {
159                                    map.remove(&k);
160                                } else {
161                                    map.insert(k, v);
162                                }
163                                write_config(&map);
164                            }
165                        }
166                    }
167                    Some(Ok(Message::Close(_))) | None => break,
168                    Some(Err(_)) => break,
169                    _ => continue,
170                }
171            }
172            broadcast = config_rx.recv() => {
173                match broadcast {
174                    Ok(line) => {
175                        if ws.send(Message::Text(line.into())).await.is_err() {
176                            break;
177                        }
178                    }
179                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
180                    Err(_) => break,
181                }
182            }
183        }
184    }
185}