1use axum::extract::ws::{Message, WebSocket};
2use futures_util::SinkExt;
3use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::{Arc, RwLock};
6use tokio::sync::broadcast;
7
8pub struct ConfigState {
9 pub tx: broadcast::Sender<String>,
10}
11
12impl Default for ConfigState {
13 fn default() -> Self {
14 Self::new()
15 }
16}
17
18impl ConfigState {
19 pub fn new() -> Self {
20 let (tx, _) = broadcast::channel::<String>(64);
21 spawn_watcher(tx.clone());
22 Self { tx }
23 }
24}
25
/// Directory holding blit's configuration files: `<base>/blit`.
///
/// * Unix: `<base>` is `$XDG_CONFIG_HOME`, or `$HOME/.config` when that is
///   unset, empty or relative (per the XDG Base Directory spec). If `HOME` is
///   unset, `/root` is used.
/// * Windows: `<base>` is `%APPDATA%`, or `C:\ProgramData` when unset/empty.
///
/// Variables are read with `var_os`, so non-UTF-8 paths are honoured instead
/// of silently falling back to the default location.
fn blit_config_dir() -> PathBuf {
    /// Read `name` as a path, treating an empty value the same as an unset one.
    fn env_path(name: &str) -> Option<PathBuf> {
        std::env::var_os(name)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    }

    #[cfg(unix)]
    let base = env_path("XDG_CONFIG_HOME")
        // The XDG spec says relative values must be ignored.
        .filter(|p| p.is_absolute())
        .unwrap_or_else(|| {
            let home = std::env::var_os("HOME")
                .map(PathBuf::from)
                .unwrap_or_else(|| PathBuf::from("/root"));
            home.join(".config")
        });
    #[cfg(windows)]
    let base = env_path("APPDATA").unwrap_or_else(|| PathBuf::from(r"C:\ProgramData"));
    base.join("blit")
}
40
41pub fn config_path() -> PathBuf {
42 if let Ok(p) = std::env::var("BLIT_CONFIG") {
43 return PathBuf::from(p);
44 }
45 blit_config_dir().join("blit.conf")
46}
47
48pub fn remotes_path() -> PathBuf {
49 if let Ok(p) = std::env::var("BLIT_REMOTES") {
50 return PathBuf::from(p);
51 }
52 blit_config_dir().join("blit.remotes")
53}
54
55fn lock_config_dir() -> Option<std::fs::File> {
59 #[cfg(unix)]
60 {
61 use std::os::unix::fs::OpenOptionsExt;
62 let dir = blit_config_dir();
63 let _ = std::fs::create_dir_all(&dir);
64 let lock_path = dir.join("blit.lock");
65 if let Ok(f) = std::fs::OpenOptions::new()
66 .write(true)
67 .create(true)
68 .truncate(false)
69 .mode(0o600)
70 .open(&lock_path)
71 {
72 use std::os::unix::io::AsRawFd;
74 if unsafe { libc::flock(f.as_raw_fd(), libc::LOCK_EX) } == 0 {
75 return Some(f);
76 }
77 }
78 None
79 }
80 #[cfg(not(unix))]
81 {
82 None
83 }
84}
85
86pub fn read_config() -> HashMap<String, String> {
87 let path = config_path();
88 let contents = match std::fs::read_to_string(&path) {
89 Ok(c) => c,
90 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return HashMap::new(),
91 Err(e) => {
92 eprintln!("blit: could not read {}: {e}", path.display());
93 return HashMap::new();
94 }
95 };
96 parse_config_str(&contents)
97}
98
99pub fn read_remotes() -> Vec<(String, String)> {
102 let path = remotes_path();
103 let contents = match std::fs::read_to_string(&path) {
104 Ok(c) => c,
105 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
106 let default = vec![("local".to_string(), "local".to_string())];
107 write_remotes(&default);
108 return default;
109 }
110 Err(e) => {
111 eprintln!("blit: could not read {}: {e}", path.display());
112 return vec![];
113 }
114 };
115 parse_remotes_str(&contents)
116}
117
118pub fn modify_config(f: impl FnOnce(&mut HashMap<String, String>)) {
120 let _lock = lock_config_dir();
121 let mut map = read_config();
122 f(&mut map);
123 write_config(&map);
124}
125
126pub fn modify_remotes(f: impl FnOnce(&mut Vec<(String, String)>)) {
128 let _lock = lock_config_dir();
129 let mut entries = read_remotes();
130 f(&mut entries);
131 write_remotes(&entries);
132}
133
/// Parse `name = uri` lines into an ordered list.
///
/// Blank lines, `#` comments, lines without `=`, and entries whose name or uri
/// is empty after trimming are skipped. A name that appears more than once
/// keeps the position of its first occurrence and the uri of its last.
pub fn parse_remotes_str(contents: &str) -> Vec<(String, String)> {
    let mut entries: Vec<(String, String)> = Vec::new();
    // Index of each name within `entries`, so repeats update in place.
    let mut slot_of: HashMap<String, usize> = HashMap::new();
    let pairs = contents
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        .filter_map(|line| line.split_once('='))
        .map(|(name, uri)| (name.trim(), uri.trim()))
        .filter(|(name, uri)| !name.is_empty() && !uri.is_empty());
    for (name, uri) in pairs {
        match slot_of.get(name) {
            Some(&slot) => entries[slot].1 = uri.to_string(),
            None => {
                slot_of.insert(name.to_string(), entries.len());
                entries.push((name.to_string(), uri.to_string()));
            }
        }
    }
    entries
}
166
/// Render remotes as `name = uri` lines, each terminated by `\n`, in list order.
fn serialize_remotes(entries: &[(String, String)]) -> String {
    entries
        .iter()
        .map(|(name, uri)| format!("{name} = {uri}\n"))
        .collect()
}
177
178pub fn write_remotes(entries: &[(String, String)]) {
180 let path = remotes_path();
181 if let Some(parent) = path.parent() {
182 let _ = std::fs::create_dir_all(parent);
183 }
184 let contents = serialize_remotes(entries);
185 write_secret_file(&path, &contents);
186}
187
/// Write `contents` to `path` so that only the owner can read it.
///
/// On Unix the data is written to a uniquely named sibling temp file (mode
/// `0o600`), flushed to disk, then renamed over `path`. The temp file is
/// removed if any step fails. On other platforms the file is written directly
/// with default permissions.
///
/// Failures are reported on stderr and otherwise ignored; callers treat
/// persistence as best effort.
fn write_secret_file(path: &std::path::Path, contents: &str) {
    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::OpenOptionsExt;
        use std::sync::atomic::{AtomicU32, Ordering};

        // pid + per-process counter give every writer its own temp file.
        static COUNTER: AtomicU32 = AtomicU32::new(0);
        let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        let tmp = path.with_extension(format!("tmp.{pid}.{seq}"));
        let result = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(&tmp)
            .and_then(|mut f| {
                f.write_all(contents.as_bytes())?;
                // Make the data durable before the rename publishes it.
                f.sync_all()
            })
            .and_then(|()| std::fs::rename(&tmp, path));
        if let Err(e) = result {
            // Whichever step failed, do not leave a stray temp file behind.
            let _ = std::fs::remove_file(&tmp);
            eprintln!("blit: could not write {}: {e}", path.display());
        }
    }
    #[cfg(not(unix))]
    {
        if let Err(e) = std::fs::write(path, contents) {
            eprintln!("blit: could not write {}: {e}", path.display());
        }
    }
}
223
/// Render the config as `key = value` lines sorted by full line text, each
/// terminated by `\n`. An empty map renders as the empty string.
fn serialize_config_str(map: &HashMap<String, String>) -> String {
    let mut rendered: Vec<String> = Vec::with_capacity(map.len());
    for (key, value) in map {
        rendered.push(format!("{key} = {value}"));
    }
    // Keys are unique, so lines are unique and an unstable sort is deterministic.
    rendered.sort_unstable();
    let mut out = String::new();
    for line in &rendered {
        out.push_str(line);
        out.push('\n');
    }
    out
}
230
231pub fn write_config(map: &HashMap<String, String>) {
232 let path = config_path();
233 if let Some(parent) = path.parent() {
234 let _ = std::fs::create_dir_all(parent);
235 }
236 write_secret_file(&path, &serialize_config_str(map));
237}
238
239fn spawn_file_watcher<F>(path: PathBuf, label: &'static str, on_change: F)
242where
243 F: Fn() + Send + 'static,
244{
245 use notify::{RecursiveMode, Watcher};
246
247 if let Some(parent) = path.parent() {
248 let _ = std::fs::create_dir_all(parent);
249 }
250
251 let watch_dir = path.parent().unwrap_or(&path).to_path_buf();
252 let file_name = path.file_name().map(|n| n.to_os_string());
253
254 std::thread::Builder::new()
255 .name(format!("{label}-watcher"))
256 .spawn(move || {
257 let (ntx, nrx) = std::sync::mpsc::channel();
258 let mut watcher = match notify::recommended_watcher(ntx) {
259 Ok(w) => w,
260 Err(e) => {
261 eprintln!("blit: {label} watcher failed: {e}");
262 return;
263 }
264 };
265 if let Err(e) = watcher.watch(&watch_dir, RecursiveMode::NonRecursive) {
266 eprintln!("blit: {label} watch failed: {e}");
267 return;
268 }
269 loop {
270 match nrx.recv() {
271 Ok(Ok(event)) => {
272 if matches!(event.kind, notify::EventKind::Access(_)) {
273 continue;
274 }
275 let matches = file_name.as_ref().is_none_or(|name| {
276 event.paths.iter().any(|p| p.file_name() == Some(name))
277 });
278 if matches {
279 on_change();
280 }
281 }
282 Ok(Err(_)) => continue,
283 Err(_) => break,
284 }
285 }
286 })
287 .expect("failed to spawn file-watcher thread");
288}
289
290fn spawn_watcher(tx: broadcast::Sender<String>) {
291 let path = config_path();
292 spawn_file_watcher(path, "config", move || {
293 let map = read_config();
294 for (k, v) in &map {
295 let _ = tx.send(format!("{k}={v}"));
296 }
297 let _ = tx.send("ready".into());
298 });
299}
300
301#[derive(Clone)]
312pub struct RemotesState {
313 inner: Arc<RemotesInner>,
314}
315
316struct RemotesInner {
317 contents: RwLock<String>,
319 tx: broadcast::Sender<String>,
320}
321
322impl RemotesState {
323 pub fn new() -> Self {
325 let (tx, _) = broadcast::channel(64);
326 let inner = Arc::new(RemotesInner {
327 contents: RwLock::new(serialize_remotes(&read_remotes())),
328 tx,
329 });
330 let watcher_inner = inner.clone();
331 spawn_file_watcher(remotes_path(), "remotes", move || {
332 let text = std::fs::read_to_string(remotes_path()).unwrap_or_default();
335 *watcher_inner.contents.write().unwrap() = text.clone();
336 let _ = watcher_inner.tx.send(text);
337 });
338 Self { inner }
339 }
340
341 pub fn ephemeral(initial: String) -> Self {
345 let (tx, _) = broadcast::channel(64);
346 Self {
347 inner: Arc::new(RemotesInner {
348 contents: RwLock::new(initial),
349 tx,
350 }),
351 }
352 }
353
354 pub fn get(&self) -> String {
356 self.inner.contents.read().unwrap().clone()
357 }
358
359 pub fn set(&self, entries: &[(String, String)]) {
361 write_remotes(entries);
362 let text = serialize_remotes(entries);
363 *self.inner.contents.write().unwrap() = text.clone();
364 let _ = self.inner.tx.send(text);
365 }
366
367 pub fn modify(&self, f: impl FnOnce(&mut Vec<(String, String)>)) {
370 let _lock = lock_config_dir();
371 let mut entries = parse_remotes_str(&self.get());
372 f(&mut entries);
373 self.set(&entries);
374 }
375
376 pub fn subscribe(&self) -> broadcast::Receiver<String> {
377 self.inner.tx.subscribe()
378 }
379}
380
381impl Default for RemotesState {
382 fn default() -> Self {
383 Self::new()
384 }
385}
386
/// Compare two byte strings without short-circuiting on the first differing byte.
///
/// Only the content comparison is constant-time. A length mismatch returns
/// immediately, so the expected length can leak through timing.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR together the XOR of every byte pair; any difference leaves a bit set.
    let differences = a
        .iter()
        .zip(b.iter())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    // black_box keeps the optimizer from turning the fold into an early exit.
    std::hint::black_box(differences) == 0
}
397
/// Parse `key = value` lines into a map.
///
/// Blank lines, `#` comments and lines without `=` are skipped. Keys and
/// values are trimmed, and the value may itself contain `=`. For duplicate
/// keys the last occurrence wins.
fn parse_config_str(contents: &str) -> HashMap<String, String> {
    contents
        .lines()
        .map(str::trim)
        .filter(|line| !(line.is_empty() || line.starts_with('#')))
        .filter_map(|line| {
            let (key, value) = line.split_once('=')?;
            Some((key.trim().to_owned(), value.trim().to_owned()))
        })
        .collect()
}
411
412pub async fn handle_config_ws(
439 mut ws: WebSocket,
440 token: &str,
441 config: &ConfigState,
442 remotes: Option<&RemotesState>,
443 remotes_transform: Option<fn(&str) -> String>,
444 extra_init: &[String],
445) {
446 let authed = loop {
447 match ws.recv().await {
448 Some(Ok(Message::Text(pass))) => {
449 if constant_time_eq(pass.trim().as_bytes(), token.as_bytes()) {
450 let _ = ws.send(Message::Text("ok".into())).await;
451 break true;
452 } else {
453 let _ = ws.close().await;
454 break false;
455 }
456 }
457 Some(Ok(Message::Ping(d))) => {
458 let _ = ws.send(Message::Pong(d)).await;
459 }
460 _ => break false,
461 }
462 };
463 if !authed {
464 return;
465 }
466
467 let mut remotes_rx = remotes.map(|r| r.subscribe());
469
470 let remotes_text = remotes.map(|r| r.get()).unwrap_or_default();
473 let remotes_text = remotes_transform
474 .map(|f| f(&remotes_text))
475 .unwrap_or(remotes_text);
476 if ws
477 .send(Message::Text(format!("remotes:{remotes_text}").into()))
478 .await
479 .is_err()
480 {
481 return;
482 }
483
484 let map = read_config();
485 for (k, v) in &map {
486 if ws
487 .send(Message::Text(format!("{k}={v}").into()))
488 .await
489 .is_err()
490 {
491 return;
492 }
493 }
494 for msg in extra_init {
495 if ws.send(Message::Text(msg.clone().into())).await.is_err() {
496 return;
497 }
498 }
499 if ws.send(Message::Text("ready".into())).await.is_err() {
500 return;
501 }
502
503 let mut config_rx = config.tx.subscribe();
504
505 loop {
506 tokio::select! {
510 msg = ws.recv() => {
511 match msg {
512 Some(Ok(Message::Text(text))) => {
513 let text = text.trim();
514 if let Some(rest) = text.strip_prefix("set ")
515 && let Some((k, v)) = rest.split_once(' ') {
516 let k = k.trim().replace(['\n', '\r'], "");
517 let v = v.trim().replace(['\n', '\r'], "");
518 if k.is_empty() { continue; }
519 modify_config(|map| {
520 if v.is_empty() {
521 map.remove(&k);
522 } else {
523 map.insert(k, v);
524 }
525 });
526 } else if let Some(rest) = text.strip_prefix("remotes-add ") {
527 if let Some((raw_name, raw_uri)) = rest.split_once(' ') {
530 let name = raw_name.trim().replace(['\n', '\r'], "");
531 let uri = raw_uri.trim().replace(['\n', '\r'], "");
532 if !name.is_empty()
533 && !name.contains('=')
534 && !uri.is_empty()
535 && let Some(r) = remotes
536 {
537 r.modify(|entries| {
538 if let Some(pos) = entries.iter().position(|(n, _)| n == &name) {
539 entries[pos].1 = uri;
540 } else {
541 entries.push((name, uri));
542 }
543 });
544 }
545 }
546 } else if let Some(name) = text.strip_prefix("remotes-remove ") {
547 let name = name.trim().replace(['\n', '\r'], "");
548 if !name.is_empty()
549 && let Some(r) = remotes
550 {
551 r.modify(|entries| {
552 entries.retain(|(n, _)| n != &name);
553 });
554 }
555 } else if let Some(name) = text.strip_prefix("remotes-set-default ") {
556 let name = name.trim().replace(['\n', '\r'], "");
558 modify_config(|map| {
559 if name.is_empty() || name == "local" {
560 map.remove("blit.target");
561 } else {
562 map.insert("blit.target".into(), name);
563 }
564 });
565 } else if let Some(rest) = text.strip_prefix("remotes-reorder ") {
566 if let Some(r) = remotes {
569 let desired: Vec<String> = rest
570 .split_whitespace()
571 .map(|s| s.replace(['\n', '\r'], ""))
572 .filter(|s| !s.is_empty())
573 .collect();
574 if !desired.is_empty() {
575 r.modify(|entries| {
576 let map: std::collections::HashMap<&str, &str> = entries
577 .iter()
578 .map(|(n, u)| (n.as_str(), u.as_str()))
579 .collect();
580 let mut reordered: Vec<(String, String)> = desired
581 .iter()
582 .filter_map(|n| {
583 map.get(n.as_str())
584 .map(|u| (n.clone(), u.to_string()))
585 })
586 .collect();
587 let desired_set: std::collections::HashSet<&str> =
588 desired.iter().map(|s| s.as_str()).collect();
589 for (n, u) in entries.iter() {
590 if !desired_set.contains(n.as_str()) {
591 reordered.push((n.clone(), u.clone()));
592 }
593 }
594 *entries = reordered;
595 });
596 }
597 }
598 }
599 }
600 Some(Ok(Message::Close(_))) | None => break,
601 Some(Err(_)) => break,
602 _ => continue,
603 }
604 }
605 broadcast = config_rx.recv() => {
606 match broadcast {
607 Ok(line) => {
608 if ws.send(Message::Text(line.into())).await.is_err() {
609 break;
610 }
611 }
612 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
613 Err(_) => break,
614 }
615 }
616 remotes_update = async {
617 match remotes_rx.as_mut() {
618 Some(rx) => rx.recv().await,
619 None => std::future::pending().await,
620 }
621 } => {
622 match remotes_update {
623 Ok(text) => {
624 let text = remotes_transform
625 .map(|f| f(&text))
626 .unwrap_or(text);
627 if ws
628 .send(Message::Text(format!("remotes:{text}").into()))
629 .await
630 .is_err()
631 {
632 break;
633 }
634 }
635 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
636 if let Some(r) = remotes {
638 let text = r.get();
639 let text = remotes_transform
640 .map(|f| f(&text))
641 .unwrap_or(text);
642 if ws
643 .send(Message::Text(format!("remotes:{text}").into()))
644 .await
645 .is_err()
646 {
647 break;
648 }
649 }
650 }
651 Err(_) => break,
652 }
653 }
654 }
655 }
656}
657
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ct_eq_equal_slices() {
        assert!(constant_time_eq(b"hello", b"hello"));
    }

    #[test]
    fn ct_eq_different_slices() {
        assert!(!constant_time_eq(b"hello", b"world"));
    }

    #[test]
    fn ct_eq_different_lengths() {
        assert!(!constant_time_eq(b"short", b"longer"));
    }

    #[test]
    fn ct_eq_empty_slices() {
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn ct_eq_single_bit_diff() {
        assert!(!constant_time_eq(b"\x00", b"\x01"));
    }

    #[test]
    fn ct_eq_one_empty_one_not() {
        assert!(!constant_time_eq(b"", b"x"));
    }

    #[test]
    fn parse_empty_string() {
        let map = parse_config_str("");
        assert!(map.is_empty());
    }

    #[test]
    fn parse_comments_and_blanks() {
        let map = parse_config_str("# comment\n\n  # another\n");
        assert!(map.is_empty());
    }

    #[test]
    fn parse_key_value() {
        let map = parse_config_str("font = Menlo\ntheme = dark\n");
        assert_eq!(map.get("font").unwrap(), "Menlo");
        assert_eq!(map.get("theme").unwrap(), "dark");
    }

    #[test]
    fn parse_trims_whitespace() {
        let map = parse_config_str("  key  =  value  ");
        assert_eq!(map.get("key").unwrap(), "value");
    }

    #[test]
    fn parse_line_without_equals() {
        let map = parse_config_str("no-equals-here\nkey=val");
        assert_eq!(map.len(), 1);
        assert_eq!(map.get("key").unwrap(), "val");
    }

    #[test]
    fn parse_equals_in_value() {
        let map = parse_config_str("cmd = a=b=c");
        assert_eq!(map.get("cmd").unwrap(), "a=b=c");
    }

    #[test]
    fn parse_duplicate_keys_last_wins() {
        let map = parse_config_str("key = first\nkey = second");
        assert_eq!(map.get("key").unwrap(), "second");
    }

    #[test]
    fn parse_mixed_content() {
        let input = "# header\nfont = FiraCode\n\n# size\nsize = 14\ntheme=light";
        let map = parse_config_str(input);
        assert_eq!(map.len(), 3);
        assert_eq!(map.get("font").unwrap(), "FiraCode");
        assert_eq!(map.get("size").unwrap(), "14");
        assert_eq!(map.get("theme").unwrap(), "light");
    }

    #[test]
    fn serialize_config_produces_sorted_output() {
        let mut map: HashMap<String, String> = HashMap::new();
        map.insert("z".into(), "last".into());
        map.insert("a".into(), "first".into());
        let output = serialize_config_str(&map);
        assert!(output.starts_with("a = first"));
        assert!(output.contains("z = last"));
    }

    #[test]
    fn serialize_config_empty_map_is_empty_string() {
        assert_eq!(serialize_config_str(&HashMap::new()), "");
    }

    #[test]
    fn round_trip_parse_serialize() {
        let input = "alpha = 1\nbeta = 2\ngamma = 3";
        let map = parse_config_str(input);
        let serialized = serialize_config_str(&map);
        let reparsed = parse_config_str(&serialized);
        assert_eq!(map, reparsed);
    }

    #[test]
    fn remotes_add_new_entry() {
        let state = RemotesState::ephemeral(String::new());
        let mut entries = parse_remotes_str(&state.get());
        entries.push(("rabbit".to_string(), "ssh:rabbit".to_string()));
        state.set(&entries);
        let got = parse_remotes_str(&state.get());
        assert_eq!(got.len(), 1);
        assert_eq!(got[0], ("rabbit".to_string(), "ssh:rabbit".to_string()));
    }

    #[test]
    fn remotes_add_updates_existing() {
        let initial = "rabbit = ssh:rabbit\n";
        let state = RemotesState::ephemeral(initial.to_string());
        let mut entries = parse_remotes_str(&state.get());
        if let Some(pos) = entries.iter().position(|(n, _)| n == "rabbit") {
            entries[pos].1 = "tcp:rabbit:3264".to_string();
        }
        state.set(&entries);
        let got = parse_remotes_str(&state.get());
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].1, "tcp:rabbit:3264");
    }

    #[test]
    fn remotes_remove_existing() {
        let initial = "rabbit = ssh:rabbit\nhound = ssh:hound\n";
        let state = RemotesState::ephemeral(initial.to_string());
        let mut entries = parse_remotes_str(&state.get());
        entries.retain(|(n, _)| n != "rabbit");
        state.set(&entries);
        let got = parse_remotes_str(&state.get());
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].0, "hound");
    }

    #[test]
    fn remotes_remove_nonexistent_is_noop() {
        let initial = "rabbit = ssh:rabbit\n";
        let state = RemotesState::ephemeral(initial.to_string());
        let mut entries = parse_remotes_str(&state.get());
        let before = entries.len();
        entries.retain(|(n, _)| n != "does-not-exist");
        assert_eq!(entries.len(), before);
    }

    // The WebSocket handler rejects remote names that are empty or contain `=`.
    // The tests below pin the parser behaviour that makes such names unsafe to store.

    #[test]
    fn remotes_parse_skips_empty_name() {
        let got = parse_remotes_str(" = ssh:nowhere\nrabbit = ssh:rabbit\n");
        assert_eq!(
            got,
            vec![("rabbit".to_string(), "ssh:rabbit".to_string())]
        );
    }

    #[test]
    fn remotes_name_with_equals_does_not_round_trip() {
        let entries = vec![("foo=bar".to_string(), "ssh:foo".to_string())];
        let got = parse_remotes_str(&serialize_remotes(&entries));
        assert_eq!(
            got,
            vec![("foo".to_string(), "bar = ssh:foo".to_string())]
        );
    }

    #[test]
    fn remotes_parse_skips_hash_prefixed_lines() {
        assert!(parse_remotes_str("#rabbit = ssh:rabbit\n").is_empty());
    }

    #[test]
    fn remotes_parse_keeps_first_position_and_last_value() {
        let got = parse_remotes_str("a = 1\nb = 2\na = 3\n");
        assert_eq!(
            got,
            vec![
                ("a".to_string(), "3".to_string()),
                ("b".to_string(), "2".to_string()),
            ]
        );
    }

    #[test]
    fn set_default_inserts_target_key() {
        let mut map = parse_config_str("font = Mono\n");
        map.insert("blit.target".into(), "rabbit".into());
        let serialized = serialize_config_str(&map);
        let reparsed = parse_config_str(&serialized);
        assert_eq!(
            reparsed.get("blit.target").map(|s| s.as_str()),
            Some("rabbit")
        );
        assert_eq!(reparsed.get("font").map(|s| s.as_str()), Some("Mono"));
    }

    #[test]
    fn set_default_local_removes_target_key() {
        let mut map = parse_config_str("blit.target = rabbit\nfont = Mono\n");
        map.remove("blit.target");
        let serialized = serialize_config_str(&map);
        let reparsed = parse_config_str(&serialized);
        assert!(!reparsed.contains_key("blit.target"));
        assert_eq!(reparsed.get("font").map(|s| s.as_str()), Some("Mono"));
    }
}