1use anyhow::{Context, Result};
16use directories::ProjectDirs;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::env;
20use std::path::{Path, PathBuf};
21use tokio::fs;
22
23use crate::node::Node;
24
25#[derive(Debug, Serialize, Deserialize, Default, Clone)]
26pub struct Config {
27 #[serde(default)]
28 pub defaults: Defaults,
29
30 #[serde(default)]
31 pub clusters: HashMap<String, Cluster>,
32
33 #[serde(default)]
34 pub interactive: InteractiveConfig,
35}
36
37#[derive(Debug, Serialize, Deserialize, Default, Clone)]
38pub struct Defaults {
39 pub user: Option<String>,
40 pub port: Option<u16>,
41 pub ssh_key: Option<String>,
42 pub parallel: Option<usize>,
43 pub timeout: Option<u64>,
44}
45
46#[derive(Debug, Serialize, Deserialize, Default, Clone)]
47pub struct InteractiveConfig {
48 #[serde(default = "default_interactive_mode")]
49 pub default_mode: InteractiveMode,
50
51 #[serde(default = "default_prompt_format")]
52 pub prompt_format: String,
53
54 #[serde(default)]
55 pub history_file: Option<String>,
56
57 #[serde(default)]
58 pub colors: HashMap<String, String>,
59
60 #[serde(default)]
61 pub keybindings: KeyBindings,
62
63 #[serde(default)]
64 pub broadcast_prefix: Option<String>,
65
66 #[serde(default)]
67 pub node_switch_prefix: Option<String>,
68
69 #[serde(default)]
70 pub show_timestamps: bool,
71
72 #[serde(default)]
73 pub work_dir: Option<String>,
74}
75
76#[derive(Debug, Serialize, Deserialize, Clone)]
77#[serde(rename_all = "snake_case")]
78#[derive(Default)]
79pub enum InteractiveMode {
80 #[default]
81 SingleNode,
82 Multiplex,
83}
84
85fn default_interactive_mode() -> InteractiveMode {
86 InteractiveMode::SingleNode
87}
88
/// Serde default for `InteractiveConfig::prompt_format`.
fn default_prompt_format() -> String {
    String::from("[{node}:{user}@{host}:{pwd}]$ ")
}
92
93#[derive(Debug, Serialize, Deserialize, Default, Clone)]
94pub struct KeyBindings {
95 #[serde(default = "default_switch_node")]
96 pub switch_node: String,
97
98 #[serde(default = "default_broadcast_toggle")]
99 pub broadcast_toggle: String,
100
101 #[serde(default = "default_quit")]
102 pub quit: String,
103
104 #[serde(default)]
105 pub clear_screen: Option<String>,
106}
107
/// Serde default for `KeyBindings::switch_node`.
fn default_switch_node() -> String {
    String::from("Ctrl+N")
}
111
/// Serde default for `KeyBindings::broadcast_toggle`.
fn default_broadcast_toggle() -> String {
    String::from("Ctrl+B")
}
115
/// Serde default for `KeyBindings::quit`.
fn default_quit() -> String {
    String::from("Ctrl+Q")
}
119
120#[derive(Debug, Serialize, Deserialize, Clone)]
121pub struct Cluster {
122 pub nodes: Vec<NodeConfig>,
123
124 #[serde(flatten)]
125 pub defaults: ClusterDefaults,
126
127 #[serde(default)]
128 pub interactive: Option<InteractiveConfig>,
129}
130
131#[derive(Debug, Serialize, Deserialize, Default, Clone)]
132pub struct ClusterDefaults {
133 pub user: Option<String>,
134 pub port: Option<u16>,
135 pub ssh_key: Option<String>,
136 pub timeout: Option<u64>,
137}
138
139#[derive(Debug, Serialize, Deserialize, Clone)]
140#[serde(untagged)]
141pub enum NodeConfig {
142 Simple(String),
143 Detailed {
144 host: String,
145 #[serde(default)]
146 port: Option<u16>,
147 #[serde(default)]
148 user: Option<String>,
149 },
150}
151
152impl Config {
153 pub async fn load(path: &Path) -> Result<Self> {
154 let expanded_path = expand_tilde(path);
156
157 if !expanded_path.exists() {
158 tracing::debug!(
159 "Config file not found at {:?}, using defaults",
160 expanded_path
161 );
162 return Ok(Self::default());
163 }
164
165 let content = fs::read_to_string(&expanded_path)
166 .await
167 .with_context(|| format!("Failed to read configuration file at {}. Please check file permissions and ensure the file is accessible.", expanded_path.display()))?;
168
169 let config: Config =
170 serde_yaml::from_str(&content).with_context(|| format!("Failed to parse YAML configuration file at {}. Please check the YAML syntax is valid.\nCommon issues:\n - Incorrect indentation (use spaces, not tabs)\n - Missing colons after keys\n - Unquoted special characters", expanded_path.display()))?;
171
172 Ok(config)
173 }
174
175 pub fn from_backendai_env() -> Option<Cluster> {
177 let cluster_hosts = env::var("BACKENDAI_CLUSTER_HOSTS").ok()?;
178 let _current_host = env::var("BACKENDAI_CLUSTER_HOST").ok()?;
179 let cluster_role = env::var("BACKENDAI_CLUSTER_ROLE").ok();
180
181 let mut nodes = Vec::new();
183 for host in cluster_hosts.split(',') {
184 let host = host.trim();
185 if !host.is_empty() {
186 let default_user = env::var("USER")
188 .or_else(|_| env::var("USERNAME"))
189 .or_else(|_| env::var("LOGNAME"))
190 .unwrap_or_else(|_| {
191 #[cfg(unix)]
193 {
194 whoami::username()
195 }
196 #[cfg(not(unix))]
197 {
198 "user".to_string()
199 }
200 });
201
202 nodes.push(NodeConfig::Simple(format!("{default_user}@{host}:2200")));
204 }
205 }
206
207 if nodes.is_empty() {
208 return None;
209 }
210
211 let filtered_nodes = if let Some(role) = &cluster_role {
213 if role == "main" {
214 nodes
216 } else {
217 nodes.into_iter().skip(1).collect()
221 }
222 } else {
223 nodes
224 };
225
226 Some(Cluster {
227 nodes: filtered_nodes,
228 defaults: ClusterDefaults::default(),
229 interactive: None,
230 })
231 }
232
233 pub async fn load_with_priority(default_path: &Path) -> Result<Self> {
239 if let Some(backendai_cluster) = Self::from_backendai_env() {
241 let mut config = Self::default();
242 config
243 .clusters
244 .insert("backendai".to_string(), backendai_cluster);
245 return Ok(config);
246 }
247
248 let current_dir_config = PathBuf::from("config.yaml");
250 if current_dir_config.exists() {
251 if let Ok(config) = Self::load(¤t_dir_config).await {
252 return Ok(config);
253 }
254 }
255
256 if let Ok(xdg_config_home) = env::var("XDG_CONFIG_HOME") {
258 let xdg_config = PathBuf::from(xdg_config_home)
260 .join("bssh")
261 .join("config.yaml");
262 if xdg_config.exists() {
263 if let Ok(config) = Self::load(&xdg_config).await {
264 return Ok(config);
265 }
266 }
267 } else if let Some(proj_dirs) = ProjectDirs::from("", "", "bssh") {
268 let xdg_config = proj_dirs.config_dir().join("config.yaml");
270 if xdg_config.exists() {
271 if let Ok(config) = Self::load(&xdg_config).await {
272 return Ok(config);
273 }
274 }
275 }
276
277 Self::load(default_path).await
279 }
280
281 pub fn get_cluster(&self, name: &str) -> Option<&Cluster> {
282 self.clusters.get(name)
283 }
284
285 pub fn resolve_nodes(&self, cluster_name: &str) -> Result<Vec<Node>> {
286 let cluster = self
287 .get_cluster(cluster_name)
288 .ok_or_else(|| anyhow::anyhow!("Cluster '{}' not found in configuration.\nAvailable clusters: {}\nPlease check your configuration file or use 'bssh list' to see available clusters.", cluster_name, self.clusters.keys().cloned().collect::<Vec<_>>().join(", ")))?;
289
290 let mut nodes = Vec::new();
291
292 for node_config in &cluster.nodes {
293 let node = match node_config {
294 NodeConfig::Simple(host) => {
295 let expanded_host = expand_env_vars(host);
297
298 let default_user = cluster
299 .defaults
300 .user
301 .as_ref()
302 .or(self.defaults.user.as_ref())
303 .map(|u| expand_env_vars(u));
304
305 let default_port = cluster.defaults.port.or(self.defaults.port).unwrap_or(22);
306
307 Node::parse(&expanded_host, default_user.as_deref()).map(|mut n| {
308 if !expanded_host.contains(':') {
309 n.port = default_port;
310 }
311 n
312 })?
313 }
314 NodeConfig::Detailed { host, port, user } => {
315 let expanded_host = expand_env_vars(host);
317
318 let username = user
319 .as_ref()
320 .map(|u| expand_env_vars(u))
321 .or_else(|| cluster.defaults.user.as_ref().map(|u| expand_env_vars(u)))
322 .or_else(|| self.defaults.user.as_ref().map(|u| expand_env_vars(u)))
323 .unwrap_or_else(|| {
324 std::env::var("USER")
325 .or_else(|_| std::env::var("USERNAME"))
326 .or_else(|_| std::env::var("LOGNAME"))
327 .unwrap_or_else(|_| {
328 #[cfg(unix)]
330 {
331 whoami::username()
332 }
333 #[cfg(not(unix))]
334 {
335 "user".to_string()
336 }
337 })
338 });
339
340 let port = port
341 .or(cluster.defaults.port)
342 .or(self.defaults.port)
343 .unwrap_or(22);
344
345 Node::new(expanded_host, port, username)
346 }
347 };
348
349 nodes.push(node);
350 }
351
352 Ok(nodes)
353 }
354
355 pub fn get_ssh_key(&self, cluster_name: Option<&str>) -> Option<String> {
356 if let Some(cluster_name) = cluster_name {
357 if let Some(cluster) = self.get_cluster(cluster_name) {
358 if let Some(key) = &cluster.defaults.ssh_key {
359 return Some(key.clone());
360 }
361 }
362 }
363
364 self.defaults.ssh_key.clone()
365 }
366
367 pub fn get_timeout(&self, cluster_name: Option<&str>) -> Option<u64> {
368 if let Some(cluster_name) = cluster_name {
369 if let Some(cluster) = self.get_cluster(cluster_name) {
370 if let Some(timeout) = cluster.defaults.timeout {
371 return Some(timeout);
372 }
373 }
374 }
375
376 self.defaults.timeout
377 }
378
379 pub fn get_interactive_config(&self, cluster_name: Option<&str>) -> InteractiveConfig {
381 let mut config = self.interactive.clone();
382
383 if let Some(cluster_name) = cluster_name {
384 if let Some(cluster) = self.get_cluster(cluster_name) {
385 if let Some(ref cluster_interactive) = cluster.interactive {
386 config.default_mode = cluster_interactive.default_mode.clone();
389
390 if !cluster_interactive.prompt_format.is_empty() {
391 config.prompt_format = cluster_interactive.prompt_format.clone();
392 }
393
394 if cluster_interactive.history_file.is_some() {
395 config.history_file = cluster_interactive.history_file.clone();
396 }
397
398 if cluster_interactive.work_dir.is_some() {
399 config.work_dir = cluster_interactive.work_dir.clone();
400 }
401
402 if cluster_interactive.broadcast_prefix.is_some() {
403 config.broadcast_prefix = cluster_interactive.broadcast_prefix.clone();
404 }
405
406 if cluster_interactive.node_switch_prefix.is_some() {
407 config.node_switch_prefix = cluster_interactive.node_switch_prefix.clone();
408 }
409
410 config.show_timestamps = cluster_interactive.show_timestamps;
412
413 for (k, v) in &cluster_interactive.colors {
415 config.colors.insert(k.clone(), v.clone());
416 }
417
418 if !cluster_interactive.keybindings.switch_node.is_empty() {
420 config.keybindings.switch_node =
421 cluster_interactive.keybindings.switch_node.clone();
422 }
423 if !cluster_interactive.keybindings.broadcast_toggle.is_empty() {
424 config.keybindings.broadcast_toggle =
425 cluster_interactive.keybindings.broadcast_toggle.clone();
426 }
427 if !cluster_interactive.keybindings.quit.is_empty() {
428 config.keybindings.quit = cluster_interactive.keybindings.quit.clone();
429 }
430 if cluster_interactive.keybindings.clear_screen.is_some() {
431 config.keybindings.clear_screen =
432 cluster_interactive.keybindings.clear_screen.clone();
433 }
434 }
435 }
436 }
437
438 config
439 }
440
441 pub async fn save(&self, path: &Path) -> Result<()> {
443 let expanded_path = expand_tilde(path);
444
445 if let Some(parent) = expanded_path.parent() {
447 fs::create_dir_all(parent)
448 .await
449 .with_context(|| format!("Failed to create directory {parent:?}"))?;
450 }
451
452 let yaml =
453 serde_yaml::to_string(self).context("Failed to serialize configuration to YAML")?;
454
455 fs::write(&expanded_path, yaml)
456 .await
457 .with_context(|| format!("Failed to write configuration to {expanded_path:?}"))?;
458
459 Ok(())
460 }
461
462 pub async fn update_interactive_preferences(
464 &mut self,
465 cluster_name: Option<&str>,
466 updates: InteractiveConfigUpdate,
467 ) -> Result<()> {
468 let target_config = if let Some(cluster_name) = cluster_name {
469 if let Some(cluster) = self.clusters.get_mut(cluster_name) {
470 if cluster.interactive.is_none() {
472 cluster.interactive = Some(InteractiveConfig::default());
473 }
474 cluster.interactive.as_mut().unwrap()
475 } else {
476 &mut self.interactive
478 }
479 } else {
480 &mut self.interactive
482 };
483
484 if let Some(mode) = updates.default_mode {
486 target_config.default_mode = mode;
487 }
488 if let Some(prompt) = updates.prompt_format {
489 target_config.prompt_format = prompt;
490 }
491 if let Some(history) = updates.history_file {
492 target_config.history_file = Some(history);
493 }
494 if let Some(work_dir) = updates.work_dir {
495 target_config.work_dir = Some(work_dir);
496 }
497 if let Some(timestamps) = updates.show_timestamps {
498 target_config.show_timestamps = timestamps;
499 }
500 if let Some(colors) = updates.colors {
501 target_config.colors.extend(colors);
502 }
503
504 let config_path = self.get_config_path()?;
506 self.save(&config_path).await?;
507
508 Ok(())
509 }
510
511 fn get_config_path(&self) -> Result<PathBuf> {
513 let current_dir_config = PathBuf::from("config.yaml");
519 if current_dir_config.exists() {
520 return Ok(current_dir_config);
521 }
522
523 if let Ok(xdg_config_home) = env::var("XDG_CONFIG_HOME") {
525 let xdg_config = PathBuf::from(xdg_config_home)
526 .join("bssh")
527 .join("config.yaml");
528 return Ok(xdg_config);
529 } else if let Some(proj_dirs) = ProjectDirs::from("", "", "bssh") {
530 let xdg_config = proj_dirs.config_dir().join("config.yaml");
531 return Ok(xdg_config);
532 }
533
534 let home = env::var("HOME")
536 .or_else(|_| env::var("USERPROFILE"))
537 .context("Unable to determine home directory")?;
538 Ok(PathBuf::from(home).join(".bssh").join("config.yaml"))
539 }
540}
541
542#[derive(Debug, Default)]
544pub struct InteractiveConfigUpdate {
545 pub default_mode: Option<InteractiveMode>,
546 pub prompt_format: Option<String>,
547 pub history_file: Option<String>,
548 pub work_dir: Option<String>,
549 pub show_timestamps: Option<bool>,
550 pub colors: Option<HashMap<String, String>>,
551}
552
/// Expands a leading `~` in `path` to the current user's home directory.
///
/// Only a bare `~` or a `~/…` prefix is expanded. `~user/…` forms and paths
/// that are not valid UTF-8 are returned unchanged.
///
/// The home directory is read from `HOME`, then `USERPROFILE` (the same
/// order `Config::get_config_path` uses). An empty value counts as unset so
/// that `~/x` never silently turns into `/x`. When no home directory is
/// available the path is returned unchanged.
fn expand_tilde(path: &Path) -> PathBuf {
    let Some(path_str) = path.to_str() else {
        return path.to_path_buf();
    };

    // `rest` is either empty (bare `~`) or starts with `/`.
    let rest = match path_str.strip_prefix('~') {
        Some(rest) if rest.is_empty() || rest.starts_with('/') => rest,
        _ => return path.to_path_buf(),
    };

    let home = ["HOME", "USERPROFILE"]
        .iter()
        .find_map(|key| std::env::var(key).ok().filter(|value| !value.is_empty()));

    match home {
        // Plain concatenation keeps `rest` verbatim; `PathBuf::join` would
        // discard the home directory if `rest` looked absolute.
        Some(home) => PathBuf::from(format!("{home}{rest}")),
        None => path.to_path_buf(),
    }
}
563
564fn expand_env_vars(input: &str) -> String {
567 let mut result = input.to_string();
568 let mut processed = 0;
569
570 while processed < result.len() {
572 if let Some(start) = result[processed..].find("${") {
573 let abs_start = processed + start;
574 if let Some(end) = result[abs_start..].find('}') {
575 let var_name = &result[abs_start + 2..abs_start + end];
576 if !var_name.is_empty() && var_name.chars().all(|c| c.is_alphanumeric() || c == '_')
577 {
578 let replacement = std::env::var(var_name).unwrap_or_else(|_| {
579 tracing::debug!("Environment variable {} not found", var_name);
580 format!("${{{var_name}}}")
581 });
582 result.replace_range(abs_start..abs_start + end + 1, &replacement);
583 processed = abs_start + replacement.len();
584 } else {
585 processed = abs_start + end + 1;
586 }
587 } else {
588 break;
589 }
590 } else {
591 break;
592 }
593 }
594
595 let mut i = 0;
597 let bytes = result.as_bytes();
598 let mut new_result = String::new();
599
600 while i < bytes.len() {
601 if bytes[i] == b'$' && i + 1 < bytes.len() && bytes[i + 1] != b'{' {
602 let start = i;
603 i += 1;
604
605 while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
607 i += 1;
608 }
609
610 if i > start + 1 {
611 let var_name = std::str::from_utf8(&bytes[start + 1..i]).unwrap();
612 let replacement = std::env::var(var_name).unwrap_or_else(|_| {
613 tracing::debug!("Environment variable {} not found", var_name);
614 String::from_utf8(bytes[start..i].to_vec()).unwrap()
615 });
616 new_result.push_str(&replacement);
617 } else {
618 new_result.push('$');
619 }
620 } else {
621 new_result.push(bytes[i] as char);
622 i += 1;
623 }
624 }
625
626 new_result
627}
628
629#[cfg(test)]
630mod tests {
631 use super::*;
632
633 #[test]
634 fn test_expand_env_vars() {
635 unsafe {
636 std::env::set_var("TEST_VAR", "test_value");
637 std::env::set_var("TEST_USER", "testuser");
638 }
639
640 assert_eq!(expand_env_vars("Hello ${TEST_VAR}!"), "Hello test_value!");
642 assert_eq!(expand_env_vars("${TEST_USER}@host"), "testuser@host");
643
644 assert_eq!(expand_env_vars("Hello $TEST_VAR!"), "Hello test_value!");
646 assert_eq!(expand_env_vars("$TEST_USER@host"), "testuser@host");
647
648 assert_eq!(
650 expand_env_vars("${TEST_USER}:$TEST_VAR"),
651 "testuser:test_value"
652 );
653
654 assert_eq!(expand_env_vars("${NONEXISTENT}"), "${NONEXISTENT}");
656 assert_eq!(expand_env_vars("$NONEXISTENT"), "$NONEXISTENT");
657
658 assert_eq!(expand_env_vars("no variables here"), "no variables here");
660 }
661
662 #[test]
663 fn test_expand_tilde() {
664 unsafe {
665 std::env::set_var("HOME", "/home/user");
666 }
667 let path = Path::new("~/.ssh/config");
668 let expanded = expand_tilde(path);
669 assert_eq!(expanded, PathBuf::from("/home/user/.ssh/config"));
670 }
671
672 #[test]
673 fn test_config_parsing() {
674 let yaml = r#"
675defaults:
676 user: admin
677 port: 22
678 ssh_key: ~/.ssh/id_rsa
679
680interactive:
681 default_mode: multiplex
682 prompt_format: "[{node}] $ "
683 history_file: ~/.bssh_history
684 show_timestamps: true
685 colors:
686 node1: red
687 node2: blue
688 keybindings:
689 switch_node: "Ctrl+T"
690 broadcast_toggle: "Ctrl+A"
691
692clusters:
693 production:
694 nodes:
695 - web1.example.com
696 - web2.example.com:2222
697 - user@web3.example.com
698 ssh_key: ~/.ssh/prod_key
699 interactive:
700 default_mode: single_node
701 prompt_format: "prod> "
702
703 staging:
704 nodes:
705 - host: staging1.example.com
706 port: 2200
707 user: deploy
708 - staging2.example.com
709 user: staging_user
710"#;
711
712 let config: Config = serde_yaml::from_str(yaml).unwrap();
713 assert_eq!(config.defaults.user, Some("admin".to_string()));
714 assert_eq!(config.clusters.len(), 2);
715
716 assert!(matches!(
718 config.interactive.default_mode,
719 InteractiveMode::Multiplex
720 ));
721 assert_eq!(config.interactive.prompt_format, "[{node}] $ ");
722 assert_eq!(
723 config.interactive.history_file,
724 Some("~/.bssh_history".to_string())
725 );
726 assert!(config.interactive.show_timestamps);
727 assert_eq!(
728 config.interactive.colors.get("node1"),
729 Some(&"red".to_string())
730 );
731 assert_eq!(config.interactive.keybindings.switch_node, "Ctrl+T");
732
733 let prod_cluster = config.get_cluster("production").unwrap();
734 assert_eq!(prod_cluster.nodes.len(), 3);
735 assert_eq!(
736 prod_cluster.defaults.ssh_key,
737 Some("~/.ssh/prod_key".to_string())
738 );
739
740 let prod_interactive = prod_cluster.interactive.as_ref().unwrap();
742 assert!(matches!(
743 prod_interactive.default_mode,
744 InteractiveMode::SingleNode
745 ));
746 assert_eq!(prod_interactive.prompt_format, "prod> ");
747 }
748
749 #[test]
750 fn test_interactive_config_fallback() {
751 let yaml = r#"
752interactive:
753 default_mode: multiplex
754 prompt_format: "global> "
755 show_timestamps: true
756
757clusters:
758 with_override:
759 nodes:
760 - host1
761 interactive:
762 default_mode: multiplex
763 prompt_format: "override> "
764
765 without_override:
766 nodes:
767 - host2
768"#;
769
770 let config: Config = serde_yaml::from_str(yaml).unwrap();
771
772 let with_override = config.get_interactive_config(Some("with_override"));
774 assert_eq!(with_override.prompt_format, "override> ");
775 assert!(matches!(
776 with_override.default_mode,
777 InteractiveMode::Multiplex
778 ));
779 let without_override = config.get_interactive_config(Some("without_override"));
783 assert_eq!(without_override.prompt_format, "global> ");
784 assert!(matches!(
785 without_override.default_mode,
786 InteractiveMode::Multiplex
787 ));
788 assert!(without_override.show_timestamps);
789
790 let global = config.get_interactive_config(None);
792 assert_eq!(global.prompt_format, "global> ");
793 assert!(matches!(global.default_mode, InteractiveMode::Multiplex));
794 }
795
796 #[test]
797 fn test_backendai_env_parsing() {
798 unsafe {
800 std::env::set_var("BACKENDAI_CLUSTER_HOSTS", "sub1,main1");
801 std::env::set_var("BACKENDAI_CLUSTER_HOST", "main1");
802 std::env::set_var("BACKENDAI_CLUSTER_ROLE", "main");
803 std::env::set_var("USER", "testuser");
804 }
805
806 let cluster = Config::from_backendai_env().unwrap();
807
808 assert_eq!(cluster.nodes.len(), 2);
810
811 match &cluster.nodes[0] {
813 NodeConfig::Simple(host) => {
814 assert_eq!(host, "testuser@sub1:2200");
815 }
816 _ => panic!("Expected Simple node config"),
817 }
818
819 unsafe {
821 std::env::set_var("BACKENDAI_CLUSTER_ROLE", "sub");
822 }
823 let cluster = Config::from_backendai_env().unwrap();
824 assert_eq!(cluster.nodes.len(), 1);
825
826 match &cluster.nodes[0] {
827 NodeConfig::Simple(host) => {
828 assert_eq!(host, "testuser@main1:2200");
829 }
830 _ => panic!("Expected Simple node config"),
831 }
832
833 unsafe {
835 std::env::remove_var("BACKENDAI_CLUSTER_HOSTS");
836 std::env::remove_var("BACKENDAI_CLUSTER_HOST");
837 std::env::remove_var("BACKENDAI_CLUSTER_ROLE");
838 }
839 }
840}