1use axum::extract::ws::{Message, WebSocket};
2use futures_util::SinkExt;
3use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::{Arc, RwLock};
6use tokio::sync::broadcast;
7
8pub struct ConfigState {
9 pub tx: broadcast::Sender<String>,
10 pub write_lock: tokio::sync::Mutex<()>,
11}
12
13impl Default for ConfigState {
14 fn default() -> Self {
15 Self::new()
16 }
17}
18
19impl ConfigState {
20 pub fn new() -> Self {
21 let (tx, _) = broadcast::channel::<String>(64);
22 spawn_watcher(tx.clone());
23 Self {
24 tx,
25 write_lock: tokio::sync::Mutex::new(()),
26 }
27 }
28}
29
/// Returns the directory that holds blit's configuration files.
///
/// * Unix: `$XDG_CONFIG_HOME/blit` when that variable holds an absolute
///   path, otherwise `$HOME/.config/blit` (with `/root` standing in for a
///   missing or empty `HOME`).
/// * Windows: `%APPDATA%\blit`, falling back to `C:\ProgramData\blit`.
///
/// An empty or relative `XDG_CONFIG_HOME` is ignored, as the XDG Base
/// Directory specification requires; using it verbatim would place the
/// config relative to the process's working directory.
fn blit_config_dir() -> PathBuf {
    #[cfg(unix)]
    let base = std::env::var_os("XDG_CONFIG_HOME")
        .map(PathBuf::from)
        // `is_absolute` also rejects the empty string.
        .filter(|dir| dir.is_absolute())
        .unwrap_or_else(|| {
            let home = std::env::var_os("HOME")
                .filter(|home| !home.is_empty())
                .unwrap_or_else(|| "/root".into());
            PathBuf::from(home).join(".config")
        });
    #[cfg(windows)]
    let base = std::env::var_os("APPDATA")
        .filter(|dir| !dir.is_empty())
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from(r"C:\ProgramData"));
    base.join("blit")
}
44
45pub fn config_path() -> PathBuf {
46 if let Ok(p) = std::env::var("BLIT_CONFIG") {
47 return PathBuf::from(p);
48 }
49 blit_config_dir().join("blit.conf")
50}
51
52pub fn remotes_path() -> PathBuf {
53 if let Ok(p) = std::env::var("BLIT_REMOTES") {
54 return PathBuf::from(p);
55 }
56 blit_config_dir().join("blit.remotes")
57}
58
59pub fn read_config() -> HashMap<String, String> {
60 let path = config_path();
61 let contents = match std::fs::read_to_string(&path) {
62 Ok(c) => c,
63 Err(_) => return HashMap::new(),
64 };
65 parse_config_str(&contents)
66}
67
68pub fn read_remotes() -> Vec<(String, String)> {
71 let path = remotes_path();
72 let contents = match std::fs::read_to_string(&path) {
73 Ok(c) => c,
74 Err(_) => {
75 let default = vec![("local".to_string(), "local".to_string())];
76 write_remotes(&default);
77 return default;
78 }
79 };
80 parse_remotes_str(&contents)
81}
82
/// Parses remotes-file text into an ordered list of `(name, uri)` pairs.
///
/// Blank lines, lines starting with `#`, lines without `=`, and lines whose
/// name or uri is empty after trimming are skipped. The line is split at the
/// first `=`. When a name repeats, it keeps the position of its first
/// occurrence and takes the uri of its last.
pub fn parse_remotes_str(contents: &str) -> Vec<(String, String)> {
    let mut entries: Vec<(String, String)> = Vec::new();
    // Maps each name to its index in `entries`.
    let mut slot_of: HashMap<String, usize> = HashMap::new();

    for raw in contents.lines() {
        let trimmed = raw.trim();
        if trimmed.is_empty() || trimmed.starts_with('#') {
            continue;
        }
        let Some((name, uri)) = trimmed.split_once('=') else {
            continue;
        };
        let (name, uri) = (name.trim(), uri.trim());
        if name.is_empty() || uri.is_empty() {
            continue;
        }
        match slot_of.get(name).copied() {
            Some(slot) => entries[slot].1 = uri.to_string(),
            None => {
                slot_of.insert(name.to_string(), entries.len());
                entries.push((name.to_string(), uri.to_string()));
            }
        }
    }
    entries
}
115
/// Renders entries as `name = uri` lines, one per entry in the given order,
/// each terminated by a newline. An empty slice renders as an empty string.
fn serialize_remotes(entries: &[(String, String)]) -> String {
    entries
        .iter()
        .map(|(name, uri)| format!("{name} = {uri}\n"))
        .collect()
}
126
127pub fn write_remotes(entries: &[(String, String)]) {
129 let path = remotes_path();
130 if let Some(parent) = path.parent() {
131 let _ = std::fs::create_dir_all(parent);
132 }
133 let contents = serialize_remotes(entries);
134 write_secret_file(&path, &contents);
135}
136
/// Writes `contents` to `path`, keeping the file private on Unix.
///
/// On Unix the data is written to a sibling temp file (`path` with its
/// extension replaced by `tmp`) with mode `0o600`, then renamed over `path`,
/// so readers never observe a half-written file. On other platforms the
/// file is written in place.
///
/// Failures are reported on stderr; the function itself never fails.
fn write_secret_file(path: &std::path::Path, contents: &str) {
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        let tmp = path.with_extension("tmp");
        let result = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(&tmp)
            .and_then(|mut file| {
                // `mode` only takes effect when the file is created. A stale
                // temp file left by an earlier run keeps whatever mode it
                // had, so tighten it explicitly before any secret is written.
                file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
                file.write_all(contents.as_bytes())
            })
            .and_then(|()| std::fs::rename(&tmp, path));
        if let Err(err) = result {
            let _ = std::fs::remove_file(&tmp);
            eprintln!("blit: failed to write {}: {err}", path.display());
        }
    }
    #[cfg(not(unix))]
    {
        if let Err(err) = std::fs::write(path, contents) {
            eprintln!("blit: failed to write {}: {err}", path.display());
        }
    }
}
167
/// Renders a config map as `key = value` lines.
///
/// The rendered lines are sorted (as whole lines) so the output does not
/// depend on hash-map iteration order; every line ends with a newline and an
/// empty map renders as an empty string.
fn serialize_config_str(map: &HashMap<String, String>) -> String {
    let mut rendered: Vec<String> = Vec::with_capacity(map.len());
    for (key, value) in map {
        rendered.push(format!("{key} = {value}"));
    }
    rendered.sort();

    let mut out = String::new();
    for line in &rendered {
        out.push_str(line);
        out.push('\n');
    }
    out
}
174
175pub fn write_config(map: &HashMap<String, String>) {
176 let path = config_path();
177 if let Some(parent) = path.parent() {
178 let _ = std::fs::create_dir_all(parent);
179 }
180 let _ = std::fs::write(&path, serialize_config_str(map));
181}
182
183fn spawn_file_watcher<F>(path: PathBuf, label: &'static str, on_change: F)
186where
187 F: Fn() + Send + 'static,
188{
189 use notify::{RecursiveMode, Watcher};
190
191 if let Some(parent) = path.parent() {
192 let _ = std::fs::create_dir_all(parent);
193 }
194
195 let watch_dir = path.parent().unwrap_or(&path).to_path_buf();
196 let file_name = path.file_name().map(|n| n.to_os_string());
197
198 std::thread::spawn(move || {
199 let (ntx, nrx) = std::sync::mpsc::channel();
200 let mut watcher = match notify::recommended_watcher(ntx) {
201 Ok(w) => w,
202 Err(e) => {
203 eprintln!("blit: {label} watcher failed: {e}");
204 return;
205 }
206 };
207 if let Err(e) = watcher.watch(&watch_dir, RecursiveMode::NonRecursive) {
208 eprintln!("blit: {label} watch failed: {e}");
209 return;
210 }
211 loop {
212 match nrx.recv() {
213 Ok(Ok(event)) => {
214 if matches!(event.kind, notify::EventKind::Access(_)) {
215 continue;
216 }
217 let matches = file_name
218 .as_ref()
219 .is_none_or(|name| event.paths.iter().any(|p| p.file_name() == Some(name)));
220 if matches {
221 on_change();
222 }
223 }
224 Ok(Err(_)) => continue,
225 Err(_) => break,
226 }
227 }
228 });
229}
230
231fn spawn_watcher(tx: broadcast::Sender<String>) {
232 let path = config_path();
233 spawn_file_watcher(path, "config", move || {
234 let map = read_config();
235 for (k, v) in &map {
236 let _ = tx.send(format!("{k}={v}"));
237 }
238 let _ = tx.send("ready".into());
239 });
240}
241
242#[derive(Clone)]
253pub struct RemotesState {
254 inner: Arc<RemotesInner>,
255}
256
257struct RemotesInner {
258 contents: RwLock<String>,
260 tx: broadcast::Sender<String>,
261}
262
263impl RemotesState {
264 pub fn new() -> Self {
266 let (tx, _) = broadcast::channel(64);
267 let inner = Arc::new(RemotesInner {
268 contents: RwLock::new(serialize_remotes(&read_remotes())),
269 tx,
270 });
271 let watcher_inner = inner.clone();
272 spawn_file_watcher(remotes_path(), "remotes", move || {
273 let text = std::fs::read_to_string(remotes_path()).unwrap_or_default();
276 *watcher_inner.contents.write().unwrap() = text.clone();
277 let _ = watcher_inner.tx.send(text);
278 });
279 Self { inner }
280 }
281
282 pub fn ephemeral(initial: String) -> Self {
286 let (tx, _) = broadcast::channel(64);
287 Self {
288 inner: Arc::new(RemotesInner {
289 contents: RwLock::new(initial),
290 tx,
291 }),
292 }
293 }
294
295 pub fn get(&self) -> String {
297 self.inner.contents.read().unwrap().clone()
298 }
299
300 pub fn set(&self, entries: &[(String, String)]) {
302 write_remotes(entries);
303 let text = serialize_remotes(entries);
304 *self.inner.contents.write().unwrap() = text.clone();
305 let _ = self.inner.tx.send(text);
306 }
307
308 pub fn subscribe(&self) -> broadcast::Receiver<String> {
309 self.inner.tx.subscribe()
310 }
311}
312
313impl Default for RemotesState {
314 fn default() -> Self {
315 Self::new()
316 }
317}
318
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different length compare unequal immediately. Otherwise every
/// byte pair is XOR-ed and OR-ed into one accumulator, which is zero only if
/// all pairs matched. The accumulator goes through `black_box` as a hint to
/// the optimiser not to reintroduce an early exit.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let accumulated = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
    std::hint::black_box(accumulated) == 0
}
329
/// Parses config-file text into a key/value map.
///
/// Each line is trimmed; blank lines, lines starting with `#`, and lines
/// without `=` are ignored. The line is split at the first `=` and both
/// sides are trimmed. When a key repeats, the last value wins.
fn parse_config_str(contents: &str) -> HashMap<String, String> {
    contents
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        .filter_map(|line| line.split_once('='))
        .map(|(key, value)| (key.trim().to_string(), value.trim().to_string()))
        .collect()
}
343
344pub async fn handle_config_ws(
371 mut ws: WebSocket,
372 token: &str,
373 config: &ConfigState,
374 remotes: Option<&RemotesState>,
375) {
376 let authed = loop {
377 match ws.recv().await {
378 Some(Ok(Message::Text(pass))) => {
379 if constant_time_eq(pass.trim().as_bytes(), token.as_bytes()) {
380 let _ = ws.send(Message::Text("ok".into())).await;
381 break true;
382 } else {
383 let _ = ws.close().await;
384 break false;
385 }
386 }
387 Some(Ok(Message::Ping(d))) => {
388 let _ = ws.send(Message::Pong(d)).await;
389 }
390 _ => break false,
391 }
392 };
393 if !authed {
394 return;
395 }
396
397 let mut remotes_rx = remotes.map(|r| r.subscribe());
399
400 let remotes_text = remotes.map(|r| r.get()).unwrap_or_default();
403 if ws
404 .send(Message::Text(format!("remotes:{remotes_text}").into()))
405 .await
406 .is_err()
407 {
408 return;
409 }
410
411 let map = read_config();
412 for (k, v) in &map {
413 if ws
414 .send(Message::Text(format!("{k}={v}").into()))
415 .await
416 .is_err()
417 {
418 return;
419 }
420 }
421 if ws.send(Message::Text("ready".into())).await.is_err() {
422 return;
423 }
424
425 let mut config_rx = config.tx.subscribe();
426
427 loop {
428 tokio::select! {
432 msg = ws.recv() => {
433 match msg {
434 Some(Ok(Message::Text(text))) => {
435 let text = text.trim();
436 if let Some(rest) = text.strip_prefix("set ")
437 && let Some((k, v)) = rest.split_once(' ') {
438 let _guard = config.write_lock.lock().await;
439 let mut map = read_config();
440 let k = k.trim().replace(['\n', '\r'], "");
441 let v = v.trim().replace(['\n', '\r'], "");
442 if k.is_empty() { continue; }
443 if v.is_empty() {
444 map.remove(&k);
445 } else {
446 map.insert(k, v);
447 }
448 write_config(&map);
449 } else if let Some(rest) = text.strip_prefix("remotes-add ") {
450 if let Some((raw_name, raw_uri)) = rest.split_once(' ') {
453 let name = raw_name.trim().replace(['\n', '\r'], "");
454 let uri = raw_uri.trim().replace(['\n', '\r'], "");
455 if !name.is_empty()
456 && !name.contains('=')
457 && !uri.is_empty()
458 && let Some(r) = remotes
459 {
460 let mut entries = parse_remotes_str(&r.get());
461 if let Some(pos) = entries.iter().position(|(n, _)| n == &name) {
462 entries[pos].1 = uri;
463 } else {
464 entries.push((name, uri));
465 }
466 r.set(&entries);
467 }
468 }
469 } else if let Some(name) = text.strip_prefix("remotes-remove ") {
470 let name = name.trim().replace(['\n', '\r'], "");
471 if !name.is_empty()
472 && let Some(r) = remotes
473 {
474 let mut entries = parse_remotes_str(&r.get());
475 entries.retain(|(n, _)| n != &name);
476 r.set(&entries);
477 }
478 } else if let Some(name) = text.strip_prefix("remotes-set-default ") {
479 let name = name.trim().replace(['\n', '\r'], "");
481 let _guard = config.write_lock.lock().await;
482 let mut map = read_config();
483 if name.is_empty() || name == "local" {
484 map.remove("target");
485 } else {
486 map.insert("target".into(), name);
487 }
488 write_config(&map);
489 } else if let Some(rest) = text.strip_prefix("remotes-reorder ") {
490 if let Some(r) = remotes {
493 let desired: Vec<String> = rest
494 .split_whitespace()
495 .map(|s| s.replace(['\n', '\r'], ""))
496 .filter(|s| !s.is_empty())
497 .collect();
498 if !desired.is_empty() {
499 let entries = parse_remotes_str(&r.get());
500 let map: std::collections::HashMap<&str, &str> = entries
502 .iter()
503 .map(|(n, u)| (n.as_str(), u.as_str()))
504 .collect();
505 let mut reordered: Vec<(String, String)> = desired
507 .iter()
508 .filter_map(|n| {
509 map.get(n.as_str())
510 .map(|u| (n.clone(), u.to_string()))
511 })
512 .collect();
513 let desired_set: std::collections::HashSet<&str> =
516 desired.iter().map(|s| s.as_str()).collect();
517 for (n, u) in &entries {
518 if !desired_set.contains(n.as_str()) {
519 reordered.push((n.clone(), u.clone()));
520 }
521 }
522 r.set(&reordered);
523 }
524 }
525 }
526 }
527 Some(Ok(Message::Close(_))) | None => break,
528 Some(Err(_)) => break,
529 _ => continue,
530 }
531 }
532 broadcast = config_rx.recv() => {
533 match broadcast {
534 Ok(line) => {
535 if ws.send(Message::Text(line.into())).await.is_err() {
536 break;
537 }
538 }
539 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
540 Err(_) => break,
541 }
542 }
543 remotes_update = async {
544 match remotes_rx.as_mut() {
545 Some(rx) => rx.recv().await,
546 None => std::future::pending().await,
547 }
548 } => {
549 match remotes_update {
550 Ok(text) => {
551 if ws
552 .send(Message::Text(format!("remotes:{text}").into()))
553 .await
554 .is_err()
555 {
556 break;
557 }
558 }
559 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
560 if let Some(r) = remotes {
562 let text = r.get();
563 if ws
564 .send(Message::Text(format!("remotes:{text}").into()))
565 .await
566 .is_err()
567 {
568 break;
569 }
570 }
571 }
572 Err(_) => break,
573 }
574 }
575 }
576 }
577}
578
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned `(name, uri)` pair.
    fn pair(name: &str, uri: &str) -> (String, String) {
        (name.to_owned(), uri.to_owned())
    }

    /// Looks up `key` as a `&str` for terse comparisons.
    fn value_of<'a>(map: &'a HashMap<String, String>, key: &str) -> Option<&'a str> {
        map.get(key).map(String::as_str)
    }

    // ---- constant_time_eq ----

    #[test]
    fn ct_eq_equal_slices() {
        let matched = constant_time_eq(b"hello", b"hello");
        assert!(matched);
    }

    #[test]
    fn ct_eq_different_slices() {
        let matched = constant_time_eq(b"hello", b"world");
        assert!(!matched);
    }

    #[test]
    fn ct_eq_different_lengths() {
        let matched = constant_time_eq(b"short", b"longer");
        assert!(!matched);
    }

    #[test]
    fn ct_eq_empty_slices() {
        let matched = constant_time_eq(b"", b"");
        assert!(matched);
    }

    #[test]
    fn ct_eq_single_bit_diff() {
        let matched = constant_time_eq(b"\x00", b"\x01");
        assert!(!matched);
    }

    #[test]
    fn ct_eq_one_empty_one_not() {
        let matched = constant_time_eq(b"", b"x");
        assert!(!matched);
    }

    // ---- parse_config_str ----

    #[test]
    fn parse_empty_string() {
        assert!(parse_config_str("").is_empty());
    }

    #[test]
    fn parse_comments_and_blanks() {
        let parsed = parse_config_str("# comment\n\n  # another\n");
        assert!(parsed.is_empty());
    }

    #[test]
    fn parse_key_value() {
        let parsed = parse_config_str("font = Menlo\ntheme = dark\n");
        assert_eq!(parsed.get("font").unwrap(), "Menlo");
        assert_eq!(parsed.get("theme").unwrap(), "dark");
    }

    #[test]
    fn parse_trims_whitespace() {
        let parsed = parse_config_str("  key  =  value  ");
        assert_eq!(parsed.get("key").unwrap(), "value");
    }

    #[test]
    fn parse_line_without_equals() {
        let parsed = parse_config_str("no-equals-here\nkey=val");
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed.get("key").unwrap(), "val");
    }

    #[test]
    fn parse_equals_in_value() {
        let parsed = parse_config_str("cmd = a=b=c");
        assert_eq!(parsed.get("cmd").unwrap(), "a=b=c");
    }

    #[test]
    fn parse_duplicate_keys_last_wins() {
        let parsed = parse_config_str("key = first\nkey = second");
        assert_eq!(parsed.get("key").unwrap(), "second");
    }

    #[test]
    fn parse_mixed_content() {
        let parsed = parse_config_str("# header\nfont = FiraCode\n\n# size\nsize = 14\ntheme=light");
        assert_eq!(parsed.len(), 3);
        assert_eq!(parsed.get("font").unwrap(), "FiraCode");
        assert_eq!(parsed.get("size").unwrap(), "14");
        assert_eq!(parsed.get("theme").unwrap(), "light");
    }

    // ---- serialize_config_str ----

    #[test]
    fn serialize_config_produces_sorted_output() {
        let map: HashMap<String, String> = [("z", "last"), ("a", "first")]
            .into_iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect();
        let output = serialize_config_str(&map);
        assert!(output.starts_with("a = first"));
        assert!(output.contains("z = last"));
    }

    #[test]
    fn round_trip_parse_serialize() {
        let original = parse_config_str("alpha = 1\nbeta = 2\ngamma = 3");
        let reparsed = parse_config_str(&serialize_config_str(&original));
        assert_eq!(original, reparsed);
    }

    // ---- remotes list editing through RemotesState ----

    #[test]
    fn remotes_add_new_entry() {
        let state = RemotesState::ephemeral(String::new());
        let mut entries = parse_remotes_str(&state.get());
        entries.push(pair("rabbit", "ssh:rabbit"));
        state.set(&entries);

        let got = parse_remotes_str(&state.get());
        assert_eq!(got.len(), 1);
        assert_eq!(got[0], pair("rabbit", "ssh:rabbit"));
    }

    #[test]
    fn remotes_add_updates_existing() {
        let state = RemotesState::ephemeral("rabbit = ssh:rabbit\n".to_string());
        let mut entries = parse_remotes_str(&state.get());
        if let Some(pos) = entries.iter().position(|(n, _)| n == "rabbit") {
            entries[pos].1 = "tcp:rabbit:3264".to_string();
        }
        state.set(&entries);

        let got = parse_remotes_str(&state.get());
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].1, "tcp:rabbit:3264");
    }

    #[test]
    fn remotes_remove_existing() {
        let state = RemotesState::ephemeral("rabbit = ssh:rabbit\nhound = ssh:hound\n".to_string());
        let mut entries = parse_remotes_str(&state.get());
        entries.retain(|(n, _)| n != "rabbit");
        state.set(&entries);

        let got = parse_remotes_str(&state.get());
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].0, "hound");
    }

    #[test]
    fn remotes_remove_nonexistent_is_noop() {
        let state = RemotesState::ephemeral("rabbit = ssh:rabbit\n".to_string());
        let mut entries = parse_remotes_str(&state.get());
        let before = entries.len();
        entries.retain(|(n, _)| n != "does-not-exist");
        assert_eq!(entries.len(), before);
    }

    // The next two tests evaluate the rejection predicate on a literal
    // only; they do not drive the WebSocket handler itself.

    #[test]
    fn remotes_add_rejects_empty_name() {
        let name = "";
        let rejected = name.is_empty() || name.contains('=');
        assert!(rejected);
    }

    #[test]
    fn remotes_add_rejects_name_with_equals() {
        let name = "foo=bar";
        let rejected = name.contains('=');
        assert!(rejected);
    }

    // ---- default-target handling on the config map ----

    #[test]
    fn set_default_inserts_target_key() {
        let mut map = parse_config_str("font = Mono\n");
        map.insert("target".into(), "rabbit".into());

        let reparsed = parse_config_str(&serialize_config_str(&map));
        assert_eq!(value_of(&reparsed, "target"), Some("rabbit"));
        assert_eq!(value_of(&reparsed, "font"), Some("Mono"));
    }

    #[test]
    fn set_default_local_removes_target_key() {
        let mut map = parse_config_str("target = rabbit\nfont = Mono\n");
        map.remove("target");

        let reparsed = parse_config_str(&serialize_config_str(&map));
        assert!(!reparsed.contains_key("target"));
        assert_eq!(value_of(&reparsed, "font"), Some("Mono"));
    }
}