1use axum::extract::ws::{Message, WebSocket};
2use futures_util::SinkExt;
3use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::{Arc, RwLock};
6use tokio::sync::broadcast;
7
/// Shared state behind the config WebSocket endpoint (`handle_config_ws`).
pub struct ConfigState {
    /// Config lines for connected clients: each time the config-file watcher
    /// fires, `spawn_watcher` sends one `key=value` message per entry followed
    /// by `ready`, and `handle_config_ws` forwards them to its client.
    pub tx: broadcast::Sender<String>,
    /// Held by `handle_config_ws` around each read-modify-write of the config
    /// file so concurrent commands do not overwrite each other's changes.
    pub write_lock: tokio::sync::Mutex<()>,
}
12
13impl Default for ConfigState {
14 fn default() -> Self {
15 Self::new()
16 }
17}
18
19impl ConfigState {
20 pub fn new() -> Self {
21 let (tx, _) = broadcast::channel::<String>(64);
22 spawn_watcher(tx.clone());
23 Self {
24 tx,
25 write_lock: tokio::sync::Mutex::new(()),
26 }
27 }
28}
29
/// Directory that holds blit's config files: `<base>/blit`.
///
/// On Unix, `<base>` is `$XDG_CONFIG_HOME` when it is set to an absolute path.
/// An unset, empty, or relative value is ignored, as the XDG Base Directory
/// specification requires, and `$HOME/.config` is used instead
/// (`/root/.config` when `HOME` is unset). On Windows, `<base>` is
/// `%APPDATA%`, falling back to `C:\ProgramData` when it is unset or empty.
fn blit_config_dir() -> PathBuf {
    #[cfg(unix)]
    let base = non_empty_env_path("XDG_CONFIG_HOME")
        .filter(|dir| dir.is_absolute())
        .unwrap_or_else(|| {
            // `var_os` keeps a non-UTF-8 home directory instead of silently
            // falling back to /root.
            let home = std::env::var_os("HOME")
                .map(PathBuf::from)
                .unwrap_or_else(|| PathBuf::from("/root"));
            home.join(".config")
        });
    #[cfg(windows)]
    let base =
        non_empty_env_path("APPDATA").unwrap_or_else(|| PathBuf::from(r"C:\ProgramData"));
    base.join("blit")
}

/// Reads environment variable `var` as a path.
///
/// Returns `None` when the variable is unset or empty, so an accidental
/// `VAR=` does not become an empty, current-directory-relative path.
/// Non-UTF-8 values are accepted because paths need not be UTF-8.
fn non_empty_env_path(var: &str) -> Option<PathBuf> {
    std::env::var_os(var)
        .filter(|value| !value.is_empty())
        .map(PathBuf::from)
}

/// Path of the config file: `$BLIT_CONFIG` if set and non-empty, otherwise
/// `blit.conf` inside [`blit_config_dir`].
pub fn config_path() -> PathBuf {
    non_empty_env_path("BLIT_CONFIG").unwrap_or_else(|| blit_config_dir().join("blit.conf"))
}

/// Path of the remotes file: `$BLIT_REMOTES` if set and non-empty, otherwise
/// `blit.remotes` inside [`blit_config_dir`].
pub fn remotes_path() -> PathBuf {
    non_empty_env_path("BLIT_REMOTES")
        .unwrap_or_else(|| blit_config_dir().join("blit.remotes"))
}
58
59pub fn read_config() -> HashMap<String, String> {
60 let path = config_path();
61 let contents = match std::fs::read_to_string(&path) {
62 Ok(c) => c,
63 Err(_) => return HashMap::new(),
64 };
65 parse_config_str(&contents)
66}
67
68pub fn read_remotes() -> Vec<(String, String)> {
71 let path = remotes_path();
72 let contents = match std::fs::read_to_string(&path) {
73 Ok(c) => c,
74 Err(_) => {
75 let default = vec![("local".to_string(), "local".to_string())];
76 write_remotes(&default);
77 return default;
78 }
79 };
80 parse_remotes_str(&contents)
81}
82
/// Parses remotes text (one `name = uri` per line) into an ordered list.
///
/// Blank lines and `#` comments are skipped, as are lines without `=` or
/// with an empty name or URI. A name that appears more than once keeps the
/// position of its first occurrence but takes the URI of its last.
pub fn parse_remotes_str(contents: &str) -> Vec<(String, String)> {
    let mut entries: Vec<(String, String)> = Vec::new();
    // name -> index into `entries`, so a repeated name updates in place.
    let mut index_of: HashMap<String, usize> = HashMap::new();
    for raw in contents.lines() {
        let line = raw.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        let Some((name, uri)) = line.split_once('=') else {
            continue;
        };
        let (name, uri) = (name.trim(), uri.trim());
        if name.is_empty() || uri.is_empty() {
            continue;
        }
        match index_of.get(name).copied() {
            Some(slot) => entries[slot].1 = uri.to_string(),
            None => {
                index_of.insert(name.to_string(), entries.len());
                entries.push((name.to_string(), uri.to_string()));
            }
        }
    }
    entries
}
115
/// Renders remotes in file form: one `name = uri` line per entry, each ending
/// in `\n`, in the given order. An empty list renders as an empty string.
fn serialize_remotes(entries: &[(String, String)]) -> String {
    entries
        .iter()
        .map(|(name, uri)| format!("{name} = {uri}\n"))
        .collect()
}
126
127pub fn write_remotes(entries: &[(String, String)]) {
129 let path = remotes_path();
130 if let Some(parent) = path.parent() {
131 let _ = std::fs::create_dir_all(parent);
132 }
133 let contents = serialize_remotes(entries);
134 write_secret_file(&path, &contents);
135}
136
/// Replaces `path` with `contents`, keeping the file private on Unix.
///
/// On Unix the data goes to a sibling `<path>.tmp`. That file is created
/// fresh with mode `0o600`, flushed to disk, and then renamed over `path`, so
/// readers never see a half-written file. Any leftover temp file from an
/// earlier interrupted write is removed first, because `mode` only applies
/// when a file is *created*. Reusing a stale temp file could otherwise publish
/// it with looser permissions. `create_new` also refuses to open through a
/// pre-existing symlink. If any step fails, the temp file is removed and the
/// previous `path` is left as it was.
///
/// On other platforms the file is written in place with default permissions.
/// Errors are ignored in both cases (best effort).
fn write_secret_file(path: &std::path::Path, contents: &str) {
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::OpenOptionsExt;

        // Append ".tmp" to the whole file name (not `with_extension`), so
        // e.g. `blit.conf` and `blit.remotes` never share a temp file.
        let mut tmp = path.as_os_str().to_os_string();
        tmp.push(".tmp");
        let tmp = PathBuf::from(tmp);

        let _ = std::fs::remove_file(&tmp);
        let result = std::fs::OpenOptions::new()
            .write(true)
            .create_new(true)
            .mode(0o600)
            .open(&tmp)
            .and_then(|mut file| {
                file.write_all(contents.as_bytes())?;
                file.sync_all()
            })
            .and_then(|()| std::fs::rename(&tmp, path));
        if result.is_err() {
            let _ = std::fs::remove_file(&tmp);
        }
    }
    #[cfg(not(unix))]
    {
        let _ = std::fs::write(path, contents);
    }
}
167
/// Renders config as `key = value` lines sorted lexicographically, each
/// ending in `\n`. An empty map renders as an empty string.
fn serialize_config_str(map: &HashMap<String, String>) -> String {
    let mut rendered: Vec<String> = Vec::with_capacity(map.len());
    for (key, value) in map {
        rendered.push(format!("{key} = {value}"));
    }
    rendered.sort();
    let mut out = String::new();
    for line in &rendered {
        out.push_str(line);
        out.push('\n');
    }
    out
}
174
175pub fn write_config(map: &HashMap<String, String>) {
176 let path = config_path();
177 if let Some(parent) = path.parent() {
178 let _ = std::fs::create_dir_all(parent);
179 }
180 let _ = std::fs::write(&path, serialize_config_str(map));
181}
182
183fn spawn_file_watcher<F>(path: PathBuf, label: &'static str, on_change: F)
186where
187 F: Fn() + Send + 'static,
188{
189 use notify::{RecursiveMode, Watcher};
190
191 if let Some(parent) = path.parent() {
192 let _ = std::fs::create_dir_all(parent);
193 }
194
195 let watch_dir = path.parent().unwrap_or(&path).to_path_buf();
196 let file_name = path.file_name().map(|n| n.to_os_string());
197
198 std::thread::Builder::new()
199 .name(format!("{label}-watcher"))
200 .spawn(move || {
201 let (ntx, nrx) = std::sync::mpsc::channel();
202 let mut watcher = match notify::recommended_watcher(ntx) {
203 Ok(w) => w,
204 Err(e) => {
205 eprintln!("blit: {label} watcher failed: {e}");
206 return;
207 }
208 };
209 if let Err(e) = watcher.watch(&watch_dir, RecursiveMode::NonRecursive) {
210 eprintln!("blit: {label} watch failed: {e}");
211 return;
212 }
213 loop {
214 match nrx.recv() {
215 Ok(Ok(event)) => {
216 if matches!(event.kind, notify::EventKind::Access(_)) {
217 continue;
218 }
219 let matches = file_name.as_ref().is_none_or(|name| {
220 event.paths.iter().any(|p| p.file_name() == Some(name))
221 });
222 if matches {
223 on_change();
224 }
225 }
226 Ok(Err(_)) => continue,
227 Err(_) => break,
228 }
229 }
230 })
231 .expect("failed to spawn file-watcher thread");
232}
233
234fn spawn_watcher(tx: broadcast::Sender<String>) {
235 let path = config_path();
236 spawn_file_watcher(path, "config", move || {
237 let map = read_config();
238 for (k, v) in &map {
239 let _ = tx.send(format!("{k}={v}"));
240 }
241 let _ = tx.send("ready".into());
242 });
243}
244
245#[derive(Clone)]
256pub struct RemotesState {
257 inner: Arc<RemotesInner>,
258}
259
260struct RemotesInner {
261 contents: RwLock<String>,
263 tx: broadcast::Sender<String>,
264}
265
266impl RemotesState {
267 pub fn new() -> Self {
269 let (tx, _) = broadcast::channel(64);
270 let inner = Arc::new(RemotesInner {
271 contents: RwLock::new(serialize_remotes(&read_remotes())),
272 tx,
273 });
274 let watcher_inner = inner.clone();
275 spawn_file_watcher(remotes_path(), "remotes", move || {
276 let text = std::fs::read_to_string(remotes_path()).unwrap_or_default();
279 *watcher_inner.contents.write().unwrap() = text.clone();
280 let _ = watcher_inner.tx.send(text);
281 });
282 Self { inner }
283 }
284
285 pub fn ephemeral(initial: String) -> Self {
289 let (tx, _) = broadcast::channel(64);
290 Self {
291 inner: Arc::new(RemotesInner {
292 contents: RwLock::new(initial),
293 tx,
294 }),
295 }
296 }
297
298 pub fn get(&self) -> String {
300 self.inner.contents.read().unwrap().clone()
301 }
302
303 pub fn set(&self, entries: &[(String, String)]) {
305 write_remotes(entries);
306 let text = serialize_remotes(entries);
307 *self.inner.contents.write().unwrap() = text.clone();
308 let _ = self.inner.tx.send(text);
309 }
310
311 pub fn subscribe(&self) -> broadcast::Receiver<String> {
312 self.inner.tx.subscribe()
313 }
314}
315
316impl Default for RemotesState {
317 fn default() -> Self {
318 Self::new()
319 }
320}
321
/// Compares two byte strings without stopping at the first mismatch. The
/// intent is that the time taken does not reveal how long a matching prefix
/// is.
///
/// Slices of different lengths are rejected immediately, so only the length
/// can leak through timing.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR together the XOR of every byte pair: zero iff every pair is equal.
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    // Best-effort hint against the optimizer specialising the comparison.
    std::hint::black_box(diff) == 0
}
332
/// Parses `key = value` lines into a map.
///
/// Blank lines, `#` comments, and lines without `=` are skipped. Keys and
/// values are trimmed, only the first `=` separates them, and a repeated key
/// keeps its last value.
fn parse_config_str(contents: &str) -> HashMap<String, String> {
    contents
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        .filter_map(|line| line.split_once('='))
        .map(|(key, value)| (key.trim().to_string(), value.trim().to_string()))
        .collect()
}
346
347pub async fn handle_config_ws(
374 mut ws: WebSocket,
375 token: &str,
376 config: &ConfigState,
377 remotes: Option<&RemotesState>,
378 remotes_transform: Option<fn(&str) -> String>,
379) {
380 let authed = loop {
381 match ws.recv().await {
382 Some(Ok(Message::Text(pass))) => {
383 if constant_time_eq(pass.trim().as_bytes(), token.as_bytes()) {
384 let _ = ws.send(Message::Text("ok".into())).await;
385 break true;
386 } else {
387 let _ = ws.close().await;
388 break false;
389 }
390 }
391 Some(Ok(Message::Ping(d))) => {
392 let _ = ws.send(Message::Pong(d)).await;
393 }
394 _ => break false,
395 }
396 };
397 if !authed {
398 return;
399 }
400
401 let mut remotes_rx = remotes.map(|r| r.subscribe());
403
404 let remotes_text = remotes.map(|r| r.get()).unwrap_or_default();
407 let remotes_text = remotes_transform
408 .map(|f| f(&remotes_text))
409 .unwrap_or(remotes_text);
410 if ws
411 .send(Message::Text(format!("remotes:{remotes_text}").into()))
412 .await
413 .is_err()
414 {
415 return;
416 }
417
418 let map = read_config();
419 for (k, v) in &map {
420 if ws
421 .send(Message::Text(format!("{k}={v}").into()))
422 .await
423 .is_err()
424 {
425 return;
426 }
427 }
428 if ws.send(Message::Text("ready".into())).await.is_err() {
429 return;
430 }
431
432 let mut config_rx = config.tx.subscribe();
433
434 loop {
435 tokio::select! {
439 msg = ws.recv() => {
440 match msg {
441 Some(Ok(Message::Text(text))) => {
442 let text = text.trim();
443 if let Some(rest) = text.strip_prefix("set ")
444 && let Some((k, v)) = rest.split_once(' ') {
445 let _guard = config.write_lock.lock().await;
446 let mut map = read_config();
447 let k = k.trim().replace(['\n', '\r'], "");
448 let v = v.trim().replace(['\n', '\r'], "");
449 if k.is_empty() { continue; }
450 if v.is_empty() {
451 map.remove(&k);
452 } else {
453 map.insert(k, v);
454 }
455 write_config(&map);
456 } else if let Some(rest) = text.strip_prefix("remotes-add ") {
457 if let Some((raw_name, raw_uri)) = rest.split_once(' ') {
460 let name = raw_name.trim().replace(['\n', '\r'], "");
461 let uri = raw_uri.trim().replace(['\n', '\r'], "");
462 if !name.is_empty()
463 && !name.contains('=')
464 && !uri.is_empty()
465 && let Some(r) = remotes
466 {
467 let mut entries = parse_remotes_str(&r.get());
468 if let Some(pos) = entries.iter().position(|(n, _)| n == &name) {
469 entries[pos].1 = uri;
470 } else {
471 entries.push((name, uri));
472 }
473 r.set(&entries);
474 }
475 }
476 } else if let Some(name) = text.strip_prefix("remotes-remove ") {
477 let name = name.trim().replace(['\n', '\r'], "");
478 if !name.is_empty()
479 && let Some(r) = remotes
480 {
481 let mut entries = parse_remotes_str(&r.get());
482 entries.retain(|(n, _)| n != &name);
483 r.set(&entries);
484 }
485 } else if let Some(name) = text.strip_prefix("remotes-set-default ") {
486 let name = name.trim().replace(['\n', '\r'], "");
488 let _guard = config.write_lock.lock().await;
489 let mut map = read_config();
490 if name.is_empty() || name == "local" {
491 map.remove("blit.target");
492 } else {
493 map.insert("blit.target".into(), name);
494 }
495 write_config(&map);
496 } else if let Some(rest) = text.strip_prefix("remotes-reorder ") {
497 if let Some(r) = remotes {
500 let desired: Vec<String> = rest
501 .split_whitespace()
502 .map(|s| s.replace(['\n', '\r'], ""))
503 .filter(|s| !s.is_empty())
504 .collect();
505 if !desired.is_empty() {
506 let entries = parse_remotes_str(&r.get());
507 let map: std::collections::HashMap<&str, &str> = entries
509 .iter()
510 .map(|(n, u)| (n.as_str(), u.as_str()))
511 .collect();
512 let mut reordered: Vec<(String, String)> = desired
514 .iter()
515 .filter_map(|n| {
516 map.get(n.as_str())
517 .map(|u| (n.clone(), u.to_string()))
518 })
519 .collect();
520 let desired_set: std::collections::HashSet<&str> =
523 desired.iter().map(|s| s.as_str()).collect();
524 for (n, u) in &entries {
525 if !desired_set.contains(n.as_str()) {
526 reordered.push((n.clone(), u.clone()));
527 }
528 }
529 r.set(&reordered);
530 }
531 }
532 }
533 }
534 Some(Ok(Message::Close(_))) | None => break,
535 Some(Err(_)) => break,
536 _ => continue,
537 }
538 }
539 broadcast = config_rx.recv() => {
540 match broadcast {
541 Ok(line) => {
542 if ws.send(Message::Text(line.into())).await.is_err() {
543 break;
544 }
545 }
546 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
547 Err(_) => break,
548 }
549 }
550 remotes_update = async {
551 match remotes_rx.as_mut() {
552 Some(rx) => rx.recv().await,
553 None => std::future::pending().await,
554 }
555 } => {
556 match remotes_update {
557 Ok(text) => {
558 let text = remotes_transform
559 .map(|f| f(&text))
560 .unwrap_or(text);
561 if ws
562 .send(Message::Text(format!("remotes:{text}").into()))
563 .await
564 .is_err()
565 {
566 break;
567 }
568 }
569 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
570 if let Some(r) = remotes {
572 let text = r.get();
573 let text = remotes_transform
574 .map(|f| f(&text))
575 .unwrap_or(text);
576 if ws
577 .send(Message::Text(format!("remotes:{text}").into()))
578 .await
579 .is_err()
580 {
581 break;
582 }
583 }
584 }
585 Err(_) => break,
586 }
587 }
588 }
589 }
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    // ---- constant_time_eq ----

    #[test]
    fn ct_eq_equal_slices() {
        assert!(constant_time_eq(b"hello", b"hello"));
    }

    #[test]
    fn ct_eq_different_slices() {
        assert!(!constant_time_eq(b"hello", b"world"));
    }

    #[test]
    fn ct_eq_different_lengths() {
        assert!(!constant_time_eq(b"short", b"longer"));
    }

    #[test]
    fn ct_eq_empty_slices() {
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn ct_eq_single_bit_diff() {
        assert!(!constant_time_eq(b"\x00", b"\x01"));
    }

    #[test]
    fn ct_eq_one_empty_one_not() {
        assert!(!constant_time_eq(b"", b"x"));
    }

    // ---- parse_config_str ----

    #[test]
    fn parse_empty_string() {
        let map = parse_config_str("");
        assert!(map.is_empty());
    }

    #[test]
    fn parse_comments_and_blanks() {
        let map = parse_config_str("# comment\n\n  # another\n");
        assert!(map.is_empty());
    }

    #[test]
    fn parse_key_value() {
        let map = parse_config_str("font = Menlo\ntheme = dark\n");
        assert_eq!(map.get("font").unwrap(), "Menlo");
        assert_eq!(map.get("theme").unwrap(), "dark");
    }

    #[test]
    fn parse_trims_whitespace() {
        let map = parse_config_str("  key  =  value  ");
        assert_eq!(map.get("key").unwrap(), "value");
    }

    #[test]
    fn parse_line_without_equals() {
        let map = parse_config_str("no-equals-here\nkey=val");
        assert_eq!(map.len(), 1);
        assert_eq!(map.get("key").unwrap(), "val");
    }

    #[test]
    fn parse_equals_in_value() {
        let map = parse_config_str("cmd = a=b=c");
        assert_eq!(map.get("cmd").unwrap(), "a=b=c");
    }

    #[test]
    fn parse_duplicate_keys_last_wins() {
        let map = parse_config_str("key = first\nkey = second");
        assert_eq!(map.get("key").unwrap(), "second");
    }

    #[test]
    fn parse_mixed_content() {
        let input = "# header\nfont = FiraCode\n\n# size\nsize = 14\ntheme=light";
        let map = parse_config_str(input);
        assert_eq!(map.len(), 3);
        assert_eq!(map.get("font").unwrap(), "FiraCode");
        assert_eq!(map.get("size").unwrap(), "14");
        assert_eq!(map.get("theme").unwrap(), "light");
    }

    // ---- serialize_config_str ----

    #[test]
    fn serialize_config_produces_sorted_output() {
        let mut map: HashMap<String, String> = HashMap::new();
        map.insert("z".into(), "last".into());
        map.insert("a".into(), "first".into());
        let output = serialize_config_str(&map);
        assert!(output.starts_with("a = first"));
        assert!(output.contains("z = last"));
    }

    #[test]
    fn serialize_config_empty_map_is_empty() {
        assert_eq!(serialize_config_str(&HashMap::new()), "");
    }

    #[test]
    fn round_trip_parse_serialize() {
        let input = "alpha = 1\nbeta = 2\ngamma = 3";
        let map = parse_config_str(input);
        let serialized = serialize_config_str(&map);
        let reparsed = parse_config_str(&serialized);
        assert_eq!(map, reparsed);
    }

    // ---- remotes ----
    //
    // These tests use the pure `parse_remotes_str` / `serialize_remotes`
    // helpers instead of `RemotesState::set`. `set` can write through to the
    // real remotes file at `remotes_path()`, and unit tests must never modify
    // the user's configuration.

    /// Serializes `entries` and parses them back, as a persisted update would.
    fn round_trip(entries: &[(String, String)]) -> Vec<(String, String)> {
        parse_remotes_str(&serialize_remotes(entries))
    }

    fn entry(name: &str, uri: &str) -> (String, String) {
        (name.to_string(), uri.to_string())
    }

    #[test]
    fn remotes_add_new_entry() {
        let mut entries = parse_remotes_str("");
        entries.push(entry("rabbit", "ssh:rabbit"));
        assert_eq!(round_trip(&entries), vec![entry("rabbit", "ssh:rabbit")]);
    }

    #[test]
    fn remotes_add_updates_existing() {
        let mut entries = parse_remotes_str("rabbit = ssh:rabbit\n");
        if let Some(pos) = entries.iter().position(|(n, _)| n == "rabbit") {
            entries[pos].1 = "tcp:rabbit:3264".to_string();
        }
        assert_eq!(round_trip(&entries), vec![entry("rabbit", "tcp:rabbit:3264")]);
    }

    #[test]
    fn remotes_remove_existing() {
        let mut entries = parse_remotes_str("rabbit = ssh:rabbit\nhound = ssh:hound\n");
        entries.retain(|(n, _)| n != "rabbit");
        assert_eq!(round_trip(&entries), vec![entry("hound", "ssh:hound")]);
    }

    #[test]
    fn remotes_remove_nonexistent_is_noop() {
        let initial = "rabbit = ssh:rabbit\n";
        let mut entries = parse_remotes_str(initial);
        entries.retain(|(n, _)| n != "does-not-exist");
        assert_eq!(serialize_remotes(&entries), initial);
    }

    #[test]
    fn remotes_parse_keeps_first_position_and_last_value() {
        let got = parse_remotes_str("a = 1\nb = 2\na = 3\n");
        assert_eq!(got, vec![entry("a", "3"), entry("b", "2")]);
    }

    #[test]
    fn remotes_parse_skips_comments_and_incomplete_lines() {
        let got = parse_remotes_str("# c = d\n = nameless\nempty =\nnoequals\nok = 1\n");
        assert_eq!(got, vec![entry("ok", "1")]);
    }

    #[test]
    fn remotes_name_with_equals_does_not_round_trip() {
        // This is why `remotes-add` rejects names containing '=': the first
        // '=' on a line always ends the name.
        let got = round_trip(&[entry("foo=bar", "ssh:x")]);
        assert_eq!(got, vec![entry("foo", "bar = ssh:x")]);
    }

    #[test]
    fn remotes_name_starting_with_hash_is_read_as_comment() {
        assert!(round_trip(&[entry("#foo", "ssh:x")]).is_empty());
    }

    #[test]
    fn remotes_ephemeral_get_returns_initial_text() {
        let state = RemotesState::ephemeral("rabbit = ssh:rabbit\n".to_string());
        assert_eq!(state.get(), "rabbit = ssh:rabbit\n");
    }

    // ---- blit.target ----

    #[test]
    fn set_default_inserts_target_key() {
        let mut map = parse_config_str("font = Mono\n");
        map.insert("blit.target".into(), "rabbit".into());
        let serialized = serialize_config_str(&map);
        let reparsed = parse_config_str(&serialized);
        assert_eq!(
            reparsed.get("blit.target").map(|s| s.as_str()),
            Some("rabbit")
        );
        assert_eq!(reparsed.get("font").map(|s| s.as_str()), Some("Mono"));
    }

    #[test]
    fn set_default_local_removes_target_key() {
        let mut map = parse_config_str("blit.target = rabbit\nfont = Mono\n");
        map.remove("blit.target");
        let serialized = serialize_config_str(&map);
        let reparsed = parse_config_str(&serialized);
        assert!(!reparsed.contains_key("blit.target"));
        assert_eq!(reparsed.get("font").map(|s| s.as_str()), Some("Mono"));
    }

    // ---- write_secret_file ----

    #[cfg(unix)]
    #[test]
    fn write_secret_file_writes_owner_only_file() {
        use std::os::unix::fs::PermissionsExt;
        let dir = std::env::temp_dir().join(format!("blit-config-test-{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        let target = dir.join("blit.remotes");
        write_secret_file(&target, "local = local\n");
        assert_eq!(std::fs::read_to_string(&target).unwrap(), "local = local\n");
        let mode = std::fs::metadata(&target).unwrap().permissions().mode();
        assert_eq!(mode & 0o777, 0o600);
        let _ = std::fs::remove_dir_all(&dir);
    }
}