1use axum::extract::ws::{Message, WebSocket};
2use futures_util::SinkExt;
3use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::{Arc, RwLock};
6use tokio::sync::broadcast;
7
/// Shared state for config WebSocket sessions.
pub struct ConfigState {
    /// Broadcast channel fed by the config file watcher: one `key=value` message per
    /// config entry followed by `ready` each time the config file changes.
    pub tx: broadcast::Sender<String>,
    /// Held around read-modify-write cycles of the config file so concurrent sessions
    /// do not overwrite each other's edits.
    pub write_lock: tokio::sync::Mutex<()>,
}
12
13impl Default for ConfigState {
14 fn default() -> Self {
15 Self::new()
16 }
17}
18
19impl ConfigState {
20 pub fn new() -> Self {
21 let (tx, _) = broadcast::channel::<String>(64);
22 spawn_watcher(tx.clone());
23 Self {
24 tx,
25 write_lock: tokio::sync::Mutex::new(()),
26 }
27 }
28}
29
/// Returns the directory holding blit's configuration files.
///
/// * Unix: `$XDG_CONFIG_HOME/blit`, falling back to `$HOME/.config/blit`
///   (and `/root/.config/blit` if `HOME` is unset).
/// * Windows: `%APPDATA%\blit`, falling back to `C:\ProgramData\blit`.
///
/// Empty or relative values of `XDG_CONFIG_HOME` / `APPDATA` are ignored. The XDG Base
/// Directory spec requires this, and using them would otherwise place the config
/// relative to the current working directory (an empty value yields just `blit`).
fn blit_config_dir() -> PathBuf {
    // `var_os` also accepts non-UTF-8 paths, which `var` would have rejected.
    let absolute_env = |name: &str| -> Option<PathBuf> {
        std::env::var_os(name)
            .map(PathBuf::from)
            .filter(|p| p.is_absolute())
    };
    #[cfg(unix)]
    let base = absolute_env("XDG_CONFIG_HOME").unwrap_or_else(|| {
        let home = std::env::var_os("HOME")
            .map(PathBuf::from)
            .unwrap_or_else(|| PathBuf::from("/root"));
        home.join(".config")
    });
    #[cfg(windows)]
    let base = absolute_env("APPDATA").unwrap_or_else(|| PathBuf::from(r"C:\ProgramData"));
    base.join("blit")
}
44
45pub fn config_path() -> PathBuf {
46 if let Ok(p) = std::env::var("BLIT_CONFIG") {
47 return PathBuf::from(p);
48 }
49 blit_config_dir().join("blit.conf")
50}
51
52pub fn remotes_path() -> PathBuf {
53 if let Ok(p) = std::env::var("BLIT_REMOTES") {
54 return PathBuf::from(p);
55 }
56 blit_config_dir().join("blit.remotes")
57}
58
59pub fn read_config() -> HashMap<String, String> {
60 let path = config_path();
61 let contents = match std::fs::read_to_string(&path) {
62 Ok(c) => c,
63 Err(_) => return HashMap::new(),
64 };
65 parse_config_str(&contents)
66}
67
68pub fn read_remotes() -> Vec<(String, String)> {
71 let path = remotes_path();
72 let contents = match std::fs::read_to_string(&path) {
73 Ok(c) => c,
74 Err(_) => {
75 let default = vec![("local".to_string(), "local".to_string())];
76 write_remotes(&default);
77 return default;
78 }
79 };
80 parse_remotes_str(&contents)
81}
82
/// Parses remotes-file text into `(name, uri)` pairs.
///
/// Each non-blank line that does not start with `#` is split at its first `=`. Both
/// sides are trimmed, and lines with an empty name or URI (or no `=`) are skipped.
/// A name that appears more than once keeps the position of its first occurrence and
/// the URI of its last.
pub fn parse_remotes_str(contents: &str) -> Vec<(String, String)> {
    let mut entries: Vec<(String, String)> = Vec::new();
    let mut slot_of: HashMap<String, usize> = HashMap::new();

    let pairs = contents
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        .filter_map(|line| line.split_once('='))
        .map(|(name, uri)| (name.trim(), uri.trim()))
        .filter(|(name, uri)| !name.is_empty() && !uri.is_empty());

    for (name, uri) in pairs {
        match slot_of.get(name) {
            Some(&idx) => entries[idx].1 = uri.to_string(),
            None => {
                slot_of.insert(name.to_string(), entries.len());
                entries.push((name.to_string(), uri.to_string()));
            }
        }
    }
    entries
}
115
/// Renders remotes as `name = uri` lines, each newline-terminated, in the given order.
fn serialize_remotes(entries: &[(String, String)]) -> String {
    entries
        .iter()
        .map(|(name, uri)| format!("{name} = {uri}\n"))
        .collect()
}
126
127pub fn write_remotes(entries: &[(String, String)]) {
129 let path = remotes_path();
130 if let Some(parent) = path.parent() {
131 let _ = std::fs::create_dir_all(parent);
132 }
133 let contents = serialize_remotes(entries);
134 write_secret_file(&path, &contents);
135}
136
/// Best-effort write of `contents` to `path`, which may hold secrets. All errors are
/// ignored.
///
/// On Unix the data is first written to a sibling `<file name>.tmp` created with mode
/// `0600`, flushed to disk, then renamed over `path`, so readers never see a partially
/// written file. If any step fails, the temporary file is removed and `path` is left
/// untouched. On other platforms the file is written directly.
fn write_secret_file(path: &std::path::Path, contents: &str) {
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        // Append `.tmp` to the whole file name rather than replacing the extension. With
        // `with_extension`, `blit.remotes` would use `blit.tmp` (which could clash with
        // another file), and a target already ending in `.tmp` would be its own tmp file.
        let Some(name) = path.file_name() else {
            return;
        };
        let mut tmp_name = name.to_os_string();
        tmp_name.push(".tmp");
        let tmp = path.with_file_name(tmp_name);

        let written = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(&tmp)
            .and_then(|mut f| {
                // `mode` only applies when the file is newly created. A stale tmp file
                // from an earlier run may carry looser permissions, which the rename
                // would transfer to `path`, so force 0600 explicitly.
                f.set_permissions(std::fs::Permissions::from_mode(0o600))?;
                f.write_all(contents.as_bytes())?;
                f.sync_all()
            });
        if written.and_then(|()| std::fs::rename(&tmp, path)).is_err() {
            let _ = std::fs::remove_file(&tmp);
        }
    }
    #[cfg(not(unix))]
    {
        let _ = std::fs::write(path, contents);
    }
}
167
/// Renders a config map as `key = value` lines sorted lexicographically by the whole
/// rendered line, each newline-terminated. An empty map renders as `""`.
fn serialize_config_str(map: &HashMap<String, String>) -> String {
    let mut rendered: Vec<String> = Vec::with_capacity(map.len());
    for (key, value) in map {
        rendered.push(format!("{key} = {value}"));
    }
    // Identical strings are indistinguishable, so an unstable sort yields the same text.
    rendered.sort_unstable();
    rendered.into_iter().map(|line| line + "\n").collect()
}
174
175pub fn write_config(map: &HashMap<String, String>) {
176 let path = config_path();
177 if let Some(parent) = path.parent() {
178 let _ = std::fs::create_dir_all(parent);
179 }
180 let _ = std::fs::write(&path, serialize_config_str(map));
181}
182
183fn spawn_file_watcher<F>(path: PathBuf, label: &'static str, on_change: F)
186where
187 F: Fn() + Send + 'static,
188{
189 use notify::{RecursiveMode, Watcher};
190
191 if let Some(parent) = path.parent() {
192 let _ = std::fs::create_dir_all(parent);
193 }
194
195 let watch_dir = path.parent().unwrap_or(&path).to_path_buf();
196 let file_name = path.file_name().map(|n| n.to_os_string());
197
198 std::thread::spawn(move || {
199 let (ntx, nrx) = std::sync::mpsc::channel();
200 let mut watcher = match notify::recommended_watcher(ntx) {
201 Ok(w) => w,
202 Err(e) => {
203 eprintln!("blit: {label} watcher failed: {e}");
204 return;
205 }
206 };
207 if let Err(e) = watcher.watch(&watch_dir, RecursiveMode::NonRecursive) {
208 eprintln!("blit: {label} watch failed: {e}");
209 return;
210 }
211 loop {
212 match nrx.recv() {
213 Ok(Ok(event)) => {
214 if matches!(event.kind, notify::EventKind::Access(_)) {
215 continue;
216 }
217 let matches = file_name
218 .as_ref()
219 .is_none_or(|name| event.paths.iter().any(|p| p.file_name() == Some(name)));
220 if matches {
221 on_change();
222 }
223 }
224 Ok(Err(_)) => continue,
225 Err(_) => break,
226 }
227 }
228 });
229}
230
231fn spawn_watcher(tx: broadcast::Sender<String>) {
232 let path = config_path();
233 spawn_file_watcher(path, "config", move || {
234 let map = read_config();
235 for (k, v) in &map {
236 let _ = tx.send(format!("{k}={v}"));
237 }
238 let _ = tx.send("ready".into());
239 });
240}
241
/// Cloneable handle to the shared remotes list. Clones share the same cached text and
/// broadcast channel through an `Arc`.
#[derive(Clone)]
pub struct RemotesState {
    inner: Arc<RemotesInner>,
}
256
/// Shared storage behind [`RemotesState`].
struct RemotesInner {
    /// Current remotes text (`name = uri` lines) as last loaded, reloaded or set.
    contents: RwLock<String>,
    /// Publishes the full remotes text after each reload or `set`.
    tx: broadcast::Sender<String>,
}
262
263impl RemotesState {
264 pub fn new() -> Self {
266 let (tx, _) = broadcast::channel(64);
267 let inner = Arc::new(RemotesInner {
268 contents: RwLock::new(serialize_remotes(&read_remotes())),
269 tx,
270 });
271 let watcher_inner = inner.clone();
272 spawn_file_watcher(remotes_path(), "remotes", move || {
273 let text = std::fs::read_to_string(remotes_path()).unwrap_or_default();
276 *watcher_inner.contents.write().unwrap() = text.clone();
277 let _ = watcher_inner.tx.send(text);
278 });
279 Self { inner }
280 }
281
282 pub fn ephemeral(initial: String) -> Self {
286 let (tx, _) = broadcast::channel(64);
287 Self {
288 inner: Arc::new(RemotesInner {
289 contents: RwLock::new(initial),
290 tx,
291 }),
292 }
293 }
294
295 pub fn get(&self) -> String {
297 self.inner.contents.read().unwrap().clone()
298 }
299
300 pub fn set(&self, entries: &[(String, String)]) {
302 write_remotes(entries);
303 let text = serialize_remotes(entries);
304 *self.inner.contents.write().unwrap() = text.clone();
305 let _ = self.inner.tx.send(text);
306 }
307
308 pub fn subscribe(&self) -> broadcast::Receiver<String> {
309 self.inner.tx.subscribe()
310 }
311}
312
313impl Default for RemotesState {
314 fn default() -> Self {
315 Self::new()
316 }
317}
318
/// Compares two byte slices in time independent of where they differ.
///
/// Slices of different length return `false` immediately, so only the length is
/// observable through timing. `black_box` discourages the optimizer from turning the
/// accumulated XOR into an early-exit comparison.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    let same_len = a.len() == b.len();
    if !same_len {
        return false;
    }
    let acc = a.iter().zip(b).fold(0u8, |acc, (&x, &y)| acc | (x ^ y));
    std::hint::black_box(acc) == 0
}
329
/// Parses config-file text into a key/value map.
///
/// Blank lines and lines starting with `#` (after trimming) are skipped. Other lines are
/// split at their first `=` with both sides trimmed, and lines without `=` are ignored.
/// Later duplicates of a key override earlier ones.
fn parse_config_str(contents: &str) -> HashMap<String, String> {
    contents
        .lines()
        .map(str::trim)
        .filter(|line| !(line.is_empty() || line.starts_with('#')))
        .filter_map(|line| line.split_once('='))
        .map(|(key, value)| (key.trim().to_owned(), value.trim().to_owned()))
        .collect()
}
343
344pub async fn handle_config_ws(
371 mut ws: WebSocket,
372 token: &str,
373 config: &ConfigState,
374 remotes: Option<&RemotesState>,
375 remotes_transform: Option<fn(&str) -> String>,
376) {
377 let authed = loop {
378 match ws.recv().await {
379 Some(Ok(Message::Text(pass))) => {
380 if constant_time_eq(pass.trim().as_bytes(), token.as_bytes()) {
381 let _ = ws.send(Message::Text("ok".into())).await;
382 break true;
383 } else {
384 let _ = ws.close().await;
385 break false;
386 }
387 }
388 Some(Ok(Message::Ping(d))) => {
389 let _ = ws.send(Message::Pong(d)).await;
390 }
391 _ => break false,
392 }
393 };
394 if !authed {
395 return;
396 }
397
398 let mut remotes_rx = remotes.map(|r| r.subscribe());
400
401 let remotes_text = remotes.map(|r| r.get()).unwrap_or_default();
404 let remotes_text = remotes_transform
405 .map(|f| f(&remotes_text))
406 .unwrap_or(remotes_text);
407 if ws
408 .send(Message::Text(format!("remotes:{remotes_text}").into()))
409 .await
410 .is_err()
411 {
412 return;
413 }
414
415 let map = read_config();
416 for (k, v) in &map {
417 if ws
418 .send(Message::Text(format!("{k}={v}").into()))
419 .await
420 .is_err()
421 {
422 return;
423 }
424 }
425 if ws.send(Message::Text("ready".into())).await.is_err() {
426 return;
427 }
428
429 let mut config_rx = config.tx.subscribe();
430
431 loop {
432 tokio::select! {
436 msg = ws.recv() => {
437 match msg {
438 Some(Ok(Message::Text(text))) => {
439 let text = text.trim();
440 if let Some(rest) = text.strip_prefix("set ")
441 && let Some((k, v)) = rest.split_once(' ') {
442 let _guard = config.write_lock.lock().await;
443 let mut map = read_config();
444 let k = k.trim().replace(['\n', '\r'], "");
445 let v = v.trim().replace(['\n', '\r'], "");
446 if k.is_empty() { continue; }
447 if v.is_empty() {
448 map.remove(&k);
449 } else {
450 map.insert(k, v);
451 }
452 write_config(&map);
453 } else if let Some(rest) = text.strip_prefix("remotes-add ") {
454 if let Some((raw_name, raw_uri)) = rest.split_once(' ') {
457 let name = raw_name.trim().replace(['\n', '\r'], "");
458 let uri = raw_uri.trim().replace(['\n', '\r'], "");
459 if !name.is_empty()
460 && !name.contains('=')
461 && !uri.is_empty()
462 && let Some(r) = remotes
463 {
464 let mut entries = parse_remotes_str(&r.get());
465 if let Some(pos) = entries.iter().position(|(n, _)| n == &name) {
466 entries[pos].1 = uri;
467 } else {
468 entries.push((name, uri));
469 }
470 r.set(&entries);
471 }
472 }
473 } else if let Some(name) = text.strip_prefix("remotes-remove ") {
474 let name = name.trim().replace(['\n', '\r'], "");
475 if !name.is_empty()
476 && let Some(r) = remotes
477 {
478 let mut entries = parse_remotes_str(&r.get());
479 entries.retain(|(n, _)| n != &name);
480 r.set(&entries);
481 }
482 } else if let Some(name) = text.strip_prefix("remotes-set-default ") {
483 let name = name.trim().replace(['\n', '\r'], "");
485 let _guard = config.write_lock.lock().await;
486 let mut map = read_config();
487 if name.is_empty() || name == "local" {
488 map.remove("blit.target");
489 } else {
490 map.insert("blit.target".into(), name);
491 }
492 write_config(&map);
493 } else if let Some(rest) = text.strip_prefix("remotes-reorder ") {
494 if let Some(r) = remotes {
497 let desired: Vec<String> = rest
498 .split_whitespace()
499 .map(|s| s.replace(['\n', '\r'], ""))
500 .filter(|s| !s.is_empty())
501 .collect();
502 if !desired.is_empty() {
503 let entries = parse_remotes_str(&r.get());
504 let map: std::collections::HashMap<&str, &str> = entries
506 .iter()
507 .map(|(n, u)| (n.as_str(), u.as_str()))
508 .collect();
509 let mut reordered: Vec<(String, String)> = desired
511 .iter()
512 .filter_map(|n| {
513 map.get(n.as_str())
514 .map(|u| (n.clone(), u.to_string()))
515 })
516 .collect();
517 let desired_set: std::collections::HashSet<&str> =
520 desired.iter().map(|s| s.as_str()).collect();
521 for (n, u) in &entries {
522 if !desired_set.contains(n.as_str()) {
523 reordered.push((n.clone(), u.clone()));
524 }
525 }
526 r.set(&reordered);
527 }
528 }
529 }
530 }
531 Some(Ok(Message::Close(_))) | None => break,
532 Some(Err(_)) => break,
533 _ => continue,
534 }
535 }
536 broadcast = config_rx.recv() => {
537 match broadcast {
538 Ok(line) => {
539 if ws.send(Message::Text(line.into())).await.is_err() {
540 break;
541 }
542 }
543 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
544 Err(_) => break,
545 }
546 }
547 remotes_update = async {
548 match remotes_rx.as_mut() {
549 Some(rx) => rx.recv().await,
550 None => std::future::pending().await,
551 }
552 } => {
553 match remotes_update {
554 Ok(text) => {
555 let text = remotes_transform
556 .map(|f| f(&text))
557 .unwrap_or(text);
558 if ws
559 .send(Message::Text(format!("remotes:{text}").into()))
560 .await
561 .is_err()
562 {
563 break;
564 }
565 }
566 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
567 if let Some(r) = remotes {
569 let text = r.get();
570 let text = remotes_transform
571 .map(|f| f(&text))
572 .unwrap_or(text);
573 if ws
574 .send(Message::Text(format!("remotes:{text}").into()))
575 .await
576 .is_err()
577 {
578 break;
579 }
580 }
581 }
582 Err(_) => break,
583 }
584 }
585 }
586 }
587}
588
#[cfg(test)]
mod tests {
    use super::*;

    /// Points `BLIT_REMOTES` and `BLIT_CONFIG` at a per-process temporary directory.
    ///
    /// `RemotesState::set` always calls `write_remotes`. Without this redirect, the
    /// remotes tests would overwrite the developer's real remotes file. Safe to call
    /// repeatedly; the variables are set only once.
    fn isolate_config_files() {
        static ONCE: std::sync::Once = std::sync::Once::new();
        ONCE.call_once(|| {
            let dir = std::env::temp_dir()
                .join(format!("blit-config-tests-{}", std::process::id()));
            // SAFETY: runs exactly once (guarded by `ONCE`), and tests in this module only
            // read the environment through `std::env`, which synchronizes with `set_var`.
            // NOTE(review): confirm no other test in this binary reads the environment
            // through non-std code (e.g. libc `getenv`) concurrently.
            unsafe {
                std::env::set_var("BLIT_REMOTES", dir.join("blit.remotes"));
                std::env::set_var("BLIT_CONFIG", dir.join("blit.conf"));
            }
        });
    }

    // ---- constant_time_eq ----

    #[test]
    fn ct_eq_equal_slices() {
        assert!(constant_time_eq(b"hello", b"hello"));
    }

    #[test]
    fn ct_eq_different_slices() {
        assert!(!constant_time_eq(b"hello", b"world"));
    }

    #[test]
    fn ct_eq_different_lengths() {
        assert!(!constant_time_eq(b"short", b"longer"));
    }

    #[test]
    fn ct_eq_empty_slices() {
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn ct_eq_single_bit_diff() {
        assert!(!constant_time_eq(b"\x00", b"\x01"));
    }

    #[test]
    fn ct_eq_one_empty_one_not() {
        assert!(!constant_time_eq(b"", b"x"));
    }

    // ---- parse_config_str ----

    #[test]
    fn parse_empty_string() {
        let map = parse_config_str("");
        assert!(map.is_empty());
    }

    #[test]
    fn parse_comments_and_blanks() {
        let map = parse_config_str("# comment\n\n # another\n");
        assert!(map.is_empty());
    }

    #[test]
    fn parse_key_value() {
        let map = parse_config_str("font = Menlo\ntheme = dark\n");
        assert_eq!(map.get("font").unwrap(), "Menlo");
        assert_eq!(map.get("theme").unwrap(), "dark");
    }

    #[test]
    fn parse_trims_whitespace() {
        let map = parse_config_str(" key = value ");
        assert_eq!(map.get("key").unwrap(), "value");
    }

    #[test]
    fn parse_line_without_equals() {
        let map = parse_config_str("no-equals-here\nkey=val");
        assert_eq!(map.len(), 1);
        assert_eq!(map.get("key").unwrap(), "val");
    }

    #[test]
    fn parse_equals_in_value() {
        let map = parse_config_str("cmd = a=b=c");
        assert_eq!(map.get("cmd").unwrap(), "a=b=c");
    }

    #[test]
    fn parse_duplicate_keys_last_wins() {
        let map = parse_config_str("key = first\nkey = second");
        assert_eq!(map.get("key").unwrap(), "second");
    }

    #[test]
    fn parse_mixed_content() {
        let input = "# header\nfont = FiraCode\n\n# size\nsize = 14\ntheme=light";
        let map = parse_config_str(input);
        assert_eq!(map.len(), 3);
        assert_eq!(map.get("font").unwrap(), "FiraCode");
        assert_eq!(map.get("size").unwrap(), "14");
        assert_eq!(map.get("theme").unwrap(), "light");
    }

    // ---- serialize_config_str ----

    #[test]
    fn serialize_config_produces_sorted_output() {
        let mut map: HashMap<String, String> = HashMap::new();
        map.insert("z".into(), "last".into());
        map.insert("a".into(), "first".into());
        let output = serialize_config_str(&map);
        assert!(output.starts_with("a = first"));
        assert!(output.contains("z = last"));
    }

    #[test]
    fn round_trip_parse_serialize() {
        let input = "alpha = 1\nbeta = 2\ngamma = 3";
        let map = parse_config_str(input);
        let serialized = serialize_config_str(&map);
        let reparsed = parse_config_str(&serialized);
        assert_eq!(map, reparsed);
    }

    // ---- parse_remotes_str / serialize_remotes ----

    #[test]
    fn parse_remotes_keeps_first_position_and_last_value() {
        let got = parse_remotes_str("b = 1\na = 2\nb = 3\n");
        assert_eq!(
            got,
            vec![
                ("b".to_string(), "3".to_string()),
                ("a".to_string(), "2".to_string()),
            ]
        );
    }

    #[test]
    fn parse_remotes_skips_comments_blanks_and_incomplete_lines() {
        let got = parse_remotes_str("# c\n\n= orphan\nempty =\nno-equals\n host = tcp:h:1 \n");
        assert_eq!(got, vec![("host".to_string(), "tcp:h:1".to_string())]);
    }

    #[test]
    fn remotes_serialize_parse_round_trip() {
        let entries = vec![
            ("z".to_string(), "ssh:z".to_string()),
            ("a".to_string(), "local".to_string()),
        ];
        assert_eq!(parse_remotes_str(&serialize_remotes(&entries)), entries);
    }

    // ---- RemotesState (writes go to the isolated temp directory) ----

    #[test]
    fn remotes_add_new_entry() {
        isolate_config_files();
        let state = RemotesState::ephemeral(String::new());
        let mut entries = parse_remotes_str(&state.get());
        entries.push(("rabbit".to_string(), "ssh:rabbit".to_string()));
        state.set(&entries);
        let got = parse_remotes_str(&state.get());
        assert_eq!(got.len(), 1);
        assert_eq!(got[0], ("rabbit".to_string(), "ssh:rabbit".to_string()));
    }

    #[test]
    fn remotes_add_updates_existing() {
        isolate_config_files();
        let initial = "rabbit = ssh:rabbit\n";
        let state = RemotesState::ephemeral(initial.to_string());
        let mut entries = parse_remotes_str(&state.get());
        if let Some(pos) = entries.iter().position(|(n, _)| n == "rabbit") {
            entries[pos].1 = "tcp:rabbit:3264".to_string();
        }
        state.set(&entries);
        let got = parse_remotes_str(&state.get());
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].1, "tcp:rabbit:3264");
    }

    #[test]
    fn remotes_remove_existing() {
        isolate_config_files();
        let initial = "rabbit = ssh:rabbit\nhound = ssh:hound\n";
        let state = RemotesState::ephemeral(initial.to_string());
        let mut entries = parse_remotes_str(&state.get());
        entries.retain(|(n, _)| n != "rabbit");
        state.set(&entries);
        let got = parse_remotes_str(&state.get());
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].0, "hound");
    }

    #[test]
    fn remotes_remove_nonexistent_is_noop() {
        let initial = "rabbit = ssh:rabbit\n";
        let state = RemotesState::ephemeral(initial.to_string());
        let mut entries = parse_remotes_str(&state.get());
        let before = entries.len();
        entries.retain(|(n, _)| n != "does-not-exist");
        assert_eq!(entries.len(), before);
    }

    // NOTE(review): the next two tests only restate the name predicate used by the
    // `remotes-add` handler; they do not exercise `handle_config_ws` itself.

    #[test]
    fn remotes_add_rejects_empty_name() {
        let name = "";
        assert!(name.is_empty() || name.contains('='));
    }

    #[test]
    fn remotes_add_rejects_name_with_equals() {
        let name = "foo=bar";
        assert!(name.contains('='));
    }

    // ---- blit.target handling ----

    #[test]
    fn set_default_inserts_target_key() {
        let mut map = parse_config_str("font = Mono\n");
        map.insert("blit.target".into(), "rabbit".into());
        let serialized = serialize_config_str(&map);
        let reparsed = parse_config_str(&serialized);
        assert_eq!(
            reparsed.get("blit.target").map(|s| s.as_str()),
            Some("rabbit")
        );
        assert_eq!(reparsed.get("font").map(|s| s.as_str()), Some("Mono"));
    }

    #[test]
    fn set_default_local_removes_target_key() {
        let mut map = parse_config_str("blit.target = rabbit\nfont = Mono\n");
        map.remove("blit.target");
        let serialized = serialize_config_str(&map);
        let reparsed = parse_config_str(&serialized);
        assert!(!reparsed.contains_key("blit.target"));
        assert_eq!(reparsed.get("font").map(|s| s.as_str()), Some("Mono"));
    }
}