// bssh/config.rs

// Copyright 2025 Lablup Inc. and Jeongkyu Shin
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use anyhow::{Context, Result};
use directories::ProjectDirs;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
use std::io::ErrorKind;
use std::path::{Path, PathBuf};
use tokio::fs;

use crate::node::Node;

25#[derive(Debug, Serialize, Deserialize, Default, Clone)]
26pub struct Config {
27    #[serde(default)]
28    pub defaults: Defaults,
29
30    #[serde(default)]
31    pub clusters: HashMap<String, Cluster>,
32
33    #[serde(default)]
34    pub interactive: InteractiveConfig,
35}
36
37#[derive(Debug, Serialize, Deserialize, Default, Clone)]
38pub struct Defaults {
39    pub user: Option<String>,
40    pub port: Option<u16>,
41    pub ssh_key: Option<String>,
42    pub parallel: Option<usize>,
43    pub timeout: Option<u64>,
44}
45
46#[derive(Debug, Serialize, Deserialize, Default, Clone)]
47pub struct InteractiveConfig {
48    #[serde(default = "default_interactive_mode")]
49    pub default_mode: InteractiveMode,
50
51    #[serde(default = "default_prompt_format")]
52    pub prompt_format: String,
53
54    #[serde(default)]
55    pub history_file: Option<String>,
56
57    #[serde(default)]
58    pub colors: HashMap<String, String>,
59
60    #[serde(default)]
61    pub keybindings: KeyBindings,
62
63    #[serde(default)]
64    pub broadcast_prefix: Option<String>,
65
66    #[serde(default)]
67    pub node_switch_prefix: Option<String>,
68
69    #[serde(default)]
70    pub show_timestamps: bool,
71
72    #[serde(default)]
73    pub work_dir: Option<String>,
74}
75
76#[derive(Debug, Serialize, Deserialize, Clone)]
77#[serde(rename_all = "snake_case")]
78#[derive(Default)]
79pub enum InteractiveMode {
80    #[default]
81    SingleNode,
82    Multiplex,
83}
84
85fn default_interactive_mode() -> InteractiveMode {
86    InteractiveMode::SingleNode
87}
88
/// Serde default for `InteractiveConfig::prompt_format`.
///
/// The braced names (`{node}`, `{user}`, `{host}`, `{pwd}`) are placeholders,
/// presumably substituted by the interactive shell when the prompt is rendered.
fn default_prompt_format() -> String {
    String::from("[{node}:{user}@{host}:{pwd}]$ ")
}

93#[derive(Debug, Serialize, Deserialize, Default, Clone)]
94pub struct KeyBindings {
95    #[serde(default = "default_switch_node")]
96    pub switch_node: String,
97
98    #[serde(default = "default_broadcast_toggle")]
99    pub broadcast_toggle: String,
100
101    #[serde(default = "default_quit")]
102    pub quit: String,
103
104    #[serde(default)]
105    pub clear_screen: Option<String>,
106}
107
/// Serde default for `KeyBindings::switch_node`.
fn default_switch_node() -> String {
    "Ctrl+N".to_owned()
}

/// Serde default for `KeyBindings::broadcast_toggle`.
fn default_broadcast_toggle() -> String {
    "Ctrl+B".to_owned()
}

/// Serde default for `KeyBindings::quit`.
fn default_quit() -> String {
    "Ctrl+Q".to_owned()
}

120#[derive(Debug, Serialize, Deserialize, Clone)]
121pub struct Cluster {
122    pub nodes: Vec<NodeConfig>,
123
124    #[serde(flatten)]
125    pub defaults: ClusterDefaults,
126
127    #[serde(default)]
128    pub interactive: Option<InteractiveConfig>,
129}
130
131#[derive(Debug, Serialize, Deserialize, Default, Clone)]
132pub struct ClusterDefaults {
133    pub user: Option<String>,
134    pub port: Option<u16>,
135    pub ssh_key: Option<String>,
136    pub timeout: Option<u64>,
137}
138
139#[derive(Debug, Serialize, Deserialize, Clone)]
140#[serde(untagged)]
141pub enum NodeConfig {
142    Simple(String),
143    Detailed {
144        host: String,
145        #[serde(default)]
146        port: Option<u16>,
147        #[serde(default)]
148        user: Option<String>,
149    },
150}
151
152impl Config {
153    pub async fn load(path: &Path) -> Result<Self> {
154        // Expand tilde in path
155        let expanded_path = expand_tilde(path);
156
157        if !expanded_path.exists() {
158            tracing::debug!(
159                "Config file not found at {:?}, using defaults",
160                expanded_path
161            );
162            return Ok(Self::default());
163        }
164
165        let content = fs::read_to_string(&expanded_path)
166            .await
167            .with_context(|| format!("Failed to read configuration file at {}. Please check file permissions and ensure the file is accessible.", expanded_path.display()))?;
168
169        let config: Config =
170            serde_yaml::from_str(&content).with_context(|| format!("Failed to parse YAML configuration file at {}. Please check the YAML syntax is valid.\nCommon issues:\n  - Incorrect indentation (use spaces, not tabs)\n  - Missing colons after keys\n  - Unquoted special characters", expanded_path.display()))?;
171
172        Ok(config)
173    }
174
175    /// Create a cluster configuration from Backend.AI environment variables
176    pub fn from_backendai_env() -> Option<Cluster> {
177        let cluster_hosts = env::var("BACKENDAI_CLUSTER_HOSTS").ok()?;
178        let _current_host = env::var("BACKENDAI_CLUSTER_HOST").ok()?;
179        let cluster_role = env::var("BACKENDAI_CLUSTER_ROLE").ok();
180
181        // Parse the hosts into nodes
182        let mut nodes = Vec::new();
183        for host in cluster_hosts.split(',') {
184            let host = host.trim();
185            if !host.is_empty() {
186                // Get current user as default
187                let default_user = env::var("USER")
188                    .or_else(|_| env::var("USERNAME"))
189                    .or_else(|_| env::var("LOGNAME"))
190                    .unwrap_or_else(|_| {
191                        // Try to get current user from system
192                        #[cfg(unix)]
193                        {
194                            whoami::username()
195                        }
196                        #[cfg(not(unix))]
197                        {
198                            "user".to_string()
199                        }
200                    });
201
202                // Backend.AI multi-node clusters use port 2200 by default
203                nodes.push(NodeConfig::Simple(format!("{default_user}@{host}:2200")));
204            }
205        }
206
207        if nodes.is_empty() {
208            return None;
209        }
210
211        // Check if we should filter nodes based on role
212        let filtered_nodes = if let Some(role) = &cluster_role {
213            if role == "main" {
214                // If current node is main, execute on all nodes
215                nodes
216            } else {
217                // If current node is sub, only execute on sub nodes
218                // We need to identify which nodes are sub nodes
219                // For now, we'll execute on all nodes except the main (first) node
220                nodes.into_iter().skip(1).collect()
221            }
222        } else {
223            nodes
224        };
225
226        Some(Cluster {
227            nodes: filtered_nodes,
228            defaults: ClusterDefaults::default(),
229            interactive: None,
230        })
231    }
232
233    /// Load configuration with priority order:
234    /// 1. Backend.AI environment variables
235    /// 2. Current directory config.yaml
236    /// 3. XDG config directory ($XDG_CONFIG_HOME/bssh/config.yaml or ~/.config/bssh/config.yaml)
237    /// 4. Default path (from CLI argument)
238    pub async fn load_with_priority(default_path: &Path) -> Result<Self> {
239        // Try Backend.AI environment first
240        if let Some(backendai_cluster) = Self::from_backendai_env() {
241            let mut config = Self::default();
242            config
243                .clusters
244                .insert("backendai".to_string(), backendai_cluster);
245            return Ok(config);
246        }
247
248        // Try current directory config.yaml
249        let current_dir_config = PathBuf::from("config.yaml");
250        if current_dir_config.exists() {
251            if let Ok(config) = Self::load(&current_dir_config).await {
252                return Ok(config);
253            }
254        }
255
256        // Try XDG config directory
257        if let Ok(xdg_config_home) = env::var("XDG_CONFIG_HOME") {
258            // Use XDG_CONFIG_HOME if set
259            let xdg_config = PathBuf::from(xdg_config_home)
260                .join("bssh")
261                .join("config.yaml");
262            if xdg_config.exists() {
263                if let Ok(config) = Self::load(&xdg_config).await {
264                    return Ok(config);
265                }
266            }
267        } else if let Some(proj_dirs) = ProjectDirs::from("", "", "bssh") {
268            // Use directories crate for standard XDG path
269            let xdg_config = proj_dirs.config_dir().join("config.yaml");
270            if xdg_config.exists() {
271                if let Ok(config) = Self::load(&xdg_config).await {
272                    return Ok(config);
273                }
274            }
275        }
276
277        // Finally, try the default path from CLI (will create if needed)
278        Self::load(default_path).await
279    }
280
281    pub fn get_cluster(&self, name: &str) -> Option<&Cluster> {
282        self.clusters.get(name)
283    }
284
285    pub fn resolve_nodes(&self, cluster_name: &str) -> Result<Vec<Node>> {
286        let cluster = self
287            .get_cluster(cluster_name)
288            .ok_or_else(|| anyhow::anyhow!("Cluster '{}' not found in configuration.\nAvailable clusters: {}\nPlease check your configuration file or use 'bssh list' to see available clusters.", cluster_name, self.clusters.keys().cloned().collect::<Vec<_>>().join(", ")))?;
289
290        let mut nodes = Vec::new();
291
292        for node_config in &cluster.nodes {
293            let node = match node_config {
294                NodeConfig::Simple(host) => {
295                    // Expand environment variables in host
296                    let expanded_host = expand_env_vars(host);
297
298                    let default_user = cluster
299                        .defaults
300                        .user
301                        .as_ref()
302                        .or(self.defaults.user.as_ref())
303                        .map(|u| expand_env_vars(u));
304
305                    let default_port = cluster.defaults.port.or(self.defaults.port).unwrap_or(22);
306
307                    Node::parse(&expanded_host, default_user.as_deref()).map(|mut n| {
308                        if !expanded_host.contains(':') {
309                            n.port = default_port;
310                        }
311                        n
312                    })?
313                }
314                NodeConfig::Detailed { host, port, user } => {
315                    // Expand environment variables
316                    let expanded_host = expand_env_vars(host);
317
318                    let username = user
319                        .as_ref()
320                        .map(|u| expand_env_vars(u))
321                        .or_else(|| cluster.defaults.user.as_ref().map(|u| expand_env_vars(u)))
322                        .or_else(|| self.defaults.user.as_ref().map(|u| expand_env_vars(u)))
323                        .unwrap_or_else(|| {
324                            std::env::var("USER")
325                                .or_else(|_| std::env::var("USERNAME"))
326                                .or_else(|_| std::env::var("LOGNAME"))
327                                .unwrap_or_else(|_| {
328                                    // Try to get current user from system
329                                    #[cfg(unix)]
330                                    {
331                                        whoami::username()
332                                    }
333                                    #[cfg(not(unix))]
334                                    {
335                                        "user".to_string()
336                                    }
337                                })
338                        });
339
340                    let port = port
341                        .or(cluster.defaults.port)
342                        .or(self.defaults.port)
343                        .unwrap_or(22);
344
345                    Node::new(expanded_host, port, username)
346                }
347            };
348
349            nodes.push(node);
350        }
351
352        Ok(nodes)
353    }
354
355    pub fn get_ssh_key(&self, cluster_name: Option<&str>) -> Option<String> {
356        if let Some(cluster_name) = cluster_name {
357            if let Some(cluster) = self.get_cluster(cluster_name) {
358                if let Some(key) = &cluster.defaults.ssh_key {
359                    return Some(key.clone());
360                }
361            }
362        }
363
364        self.defaults.ssh_key.clone()
365    }
366
367    pub fn get_timeout(&self, cluster_name: Option<&str>) -> Option<u64> {
368        if let Some(cluster_name) = cluster_name {
369            if let Some(cluster) = self.get_cluster(cluster_name) {
370                if let Some(timeout) = cluster.defaults.timeout {
371                    return Some(timeout);
372                }
373            }
374        }
375
376        self.defaults.timeout
377    }
378
379    /// Get interactive configuration for a cluster (with fallback to global)
380    pub fn get_interactive_config(&self, cluster_name: Option<&str>) -> InteractiveConfig {
381        let mut config = self.interactive.clone();
382
383        if let Some(cluster_name) = cluster_name {
384            if let Some(cluster) = self.get_cluster(cluster_name) {
385                if let Some(ref cluster_interactive) = cluster.interactive {
386                    // Merge cluster-specific overrides with global config
387                    // Cluster settings take precedence where specified
388                    config.default_mode = cluster_interactive.default_mode.clone();
389
390                    if !cluster_interactive.prompt_format.is_empty() {
391                        config.prompt_format = cluster_interactive.prompt_format.clone();
392                    }
393
394                    if cluster_interactive.history_file.is_some() {
395                        config.history_file = cluster_interactive.history_file.clone();
396                    }
397
398                    if cluster_interactive.work_dir.is_some() {
399                        config.work_dir = cluster_interactive.work_dir.clone();
400                    }
401
402                    if cluster_interactive.broadcast_prefix.is_some() {
403                        config.broadcast_prefix = cluster_interactive.broadcast_prefix.clone();
404                    }
405
406                    if cluster_interactive.node_switch_prefix.is_some() {
407                        config.node_switch_prefix = cluster_interactive.node_switch_prefix.clone();
408                    }
409
410                    // Note: For booleans, we always use the cluster value since there's no "unset" state
411                    config.show_timestamps = cluster_interactive.show_timestamps;
412
413                    // Merge colors (cluster colors override global ones)
414                    for (k, v) in &cluster_interactive.colors {
415                        config.colors.insert(k.clone(), v.clone());
416                    }
417
418                    // Merge keybindings
419                    if !cluster_interactive.keybindings.switch_node.is_empty() {
420                        config.keybindings.switch_node =
421                            cluster_interactive.keybindings.switch_node.clone();
422                    }
423                    if !cluster_interactive.keybindings.broadcast_toggle.is_empty() {
424                        config.keybindings.broadcast_toggle =
425                            cluster_interactive.keybindings.broadcast_toggle.clone();
426                    }
427                    if !cluster_interactive.keybindings.quit.is_empty() {
428                        config.keybindings.quit = cluster_interactive.keybindings.quit.clone();
429                    }
430                    if cluster_interactive.keybindings.clear_screen.is_some() {
431                        config.keybindings.clear_screen =
432                            cluster_interactive.keybindings.clear_screen.clone();
433                    }
434                }
435            }
436        }
437
438        config
439    }
440
441    /// Save the configuration to a file
442    pub async fn save(&self, path: &Path) -> Result<()> {
443        let expanded_path = expand_tilde(path);
444
445        // Ensure parent directory exists
446        if let Some(parent) = expanded_path.parent() {
447            fs::create_dir_all(parent)
448                .await
449                .with_context(|| format!("Failed to create directory {parent:?}"))?;
450        }
451
452        let yaml =
453            serde_yaml::to_string(self).context("Failed to serialize configuration to YAML")?;
454
455        fs::write(&expanded_path, yaml)
456            .await
457            .with_context(|| format!("Failed to write configuration to {expanded_path:?}"))?;
458
459        Ok(())
460    }
461
462    /// Update interactive preferences and save to the default config file
463    pub async fn update_interactive_preferences(
464        &mut self,
465        cluster_name: Option<&str>,
466        updates: InteractiveConfigUpdate,
467    ) -> Result<()> {
468        let target_config = if let Some(cluster_name) = cluster_name {
469            if let Some(cluster) = self.clusters.get_mut(cluster_name) {
470                // Update cluster-specific config
471                if cluster.interactive.is_none() {
472                    cluster.interactive = Some(InteractiveConfig::default());
473                }
474                cluster.interactive.as_mut().unwrap()
475            } else {
476                // Update global config
477                &mut self.interactive
478            }
479        } else {
480            // Update global config
481            &mut self.interactive
482        };
483
484        // Apply updates
485        if let Some(mode) = updates.default_mode {
486            target_config.default_mode = mode;
487        }
488        if let Some(prompt) = updates.prompt_format {
489            target_config.prompt_format = prompt;
490        }
491        if let Some(history) = updates.history_file {
492            target_config.history_file = Some(history);
493        }
494        if let Some(work_dir) = updates.work_dir {
495            target_config.work_dir = Some(work_dir);
496        }
497        if let Some(timestamps) = updates.show_timestamps {
498            target_config.show_timestamps = timestamps;
499        }
500        if let Some(colors) = updates.colors {
501            target_config.colors.extend(colors);
502        }
503
504        // Save to the appropriate config file
505        let config_path = self.get_config_path()?;
506        self.save(&config_path).await?;
507
508        Ok(())
509    }
510
511    /// Get the path to the configuration file (for saving)
512    fn get_config_path(&self) -> Result<PathBuf> {
513        // Priority order for determining config file path:
514        // 1. Current directory config.yaml (if it exists)
515        // 2. XDG config directory
516        // 3. Default ~/.bssh/config.yaml
517
518        let current_dir_config = PathBuf::from("config.yaml");
519        if current_dir_config.exists() {
520            return Ok(current_dir_config);
521        }
522
523        // Try XDG config directory
524        if let Ok(xdg_config_home) = env::var("XDG_CONFIG_HOME") {
525            let xdg_config = PathBuf::from(xdg_config_home)
526                .join("bssh")
527                .join("config.yaml");
528            return Ok(xdg_config);
529        } else if let Some(proj_dirs) = ProjectDirs::from("", "", "bssh") {
530            let xdg_config = proj_dirs.config_dir().join("config.yaml");
531            return Ok(xdg_config);
532        }
533
534        // Default to ~/.bssh/config.yaml
535        let home = env::var("HOME")
536            .or_else(|_| env::var("USERPROFILE"))
537            .context("Unable to determine home directory")?;
538        Ok(PathBuf::from(home).join(".bssh").join("config.yaml"))
539    }
540}
541
542/// Structure for updating interactive configuration preferences
543#[derive(Debug, Default)]
544pub struct InteractiveConfigUpdate {
545    pub default_mode: Option<InteractiveMode>,
546    pub prompt_format: Option<String>,
547    pub history_file: Option<String>,
548    pub work_dir: Option<String>,
549    pub show_timestamps: Option<bool>,
550    pub colors: Option<HashMap<String, String>>,
551}
552
/// Expand a leading `~` path component to the user's home directory.
///
/// `~` on its own and paths whose first component is `~` (e.g. `~/.ssh/config`)
/// are expanded. The home directory comes from `HOME`, falling back to
/// `USERPROFILE`; this is the same lookup order `Config::get_config_path` uses.
/// Empty values are ignored. `~user` forms are not supported and are returned
/// unchanged, as is every path when no home directory can be determined.
fn expand_tilde(path: &Path) -> PathBuf {
    // `Path::strip_prefix` compares whole components, so `~foo` does not match,
    // and non-UTF-8 paths are handled too.
    let Ok(rest) = path.strip_prefix("~") else {
        return path.to_path_buf();
    };

    let home = env::var_os("HOME")
        .filter(|home| !home.is_empty())
        .or_else(|| env::var_os("USERPROFILE").filter(|home| !home.is_empty()));
    let Some(home) = home else {
        return path.to_path_buf();
    };

    let home = PathBuf::from(home);
    // Avoid `join("")`, which would append a trailing separator.
    if rest.as_os_str().is_empty() {
        home
    } else {
        home.join(rest)
    }
}

564/// Expand environment variables in a string
565/// Supports ${VAR} and $VAR syntax
566fn expand_env_vars(input: &str) -> String {
567    let mut result = input.to_string();
568    let mut processed = 0;
569
570    // Handle ${VAR} syntax
571    while processed < result.len() {
572        if let Some(start) = result[processed..].find("${") {
573            let abs_start = processed + start;
574            if let Some(end) = result[abs_start..].find('}') {
575                let var_name = &result[abs_start + 2..abs_start + end];
576                if !var_name.is_empty() && var_name.chars().all(|c| c.is_alphanumeric() || c == '_')
577                {
578                    let replacement = std::env::var(var_name).unwrap_or_else(|_| {
579                        tracing::debug!("Environment variable {} not found", var_name);
580                        format!("${{{var_name}}}")
581                    });
582                    result.replace_range(abs_start..abs_start + end + 1, &replacement);
583                    processed = abs_start + replacement.len();
584                } else {
585                    processed = abs_start + end + 1;
586                }
587            } else {
588                break;
589            }
590        } else {
591            break;
592        }
593    }
594
595    // Handle $VAR syntax (but be careful not to expand ${} again)
596    let mut i = 0;
597    let bytes = result.as_bytes();
598    let mut new_result = String::new();
599
600    while i < bytes.len() {
601        if bytes[i] == b'$' && i + 1 < bytes.len() && bytes[i + 1] != b'{' {
602            let start = i;
603            i += 1;
604
605            // Find the end of the variable name
606            while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
607                i += 1;
608            }
609
610            if i > start + 1 {
611                let var_name = std::str::from_utf8(&bytes[start + 1..i]).unwrap();
612                let replacement = std::env::var(var_name).unwrap_or_else(|_| {
613                    tracing::debug!("Environment variable {} not found", var_name);
614                    String::from_utf8(bytes[start..i].to_vec()).unwrap()
615                });
616                new_result.push_str(&replacement);
617            } else {
618                new_result.push('$');
619            }
620        } else {
621            new_result.push(bytes[i] as char);
622            i += 1;
623        }
624    }
625
626    new_result
627}
628
#[cfg(test)]
mod tests {
    use super::*;

    /// Host string of a `NodeConfig::Simple`; panics on any other variant.
    fn simple_host(node: &NodeConfig) -> &str {
        match node {
            NodeConfig::Simple(host) => host.as_str(),
            NodeConfig::Detailed { .. } => panic!("Expected Simple node config"),
        }
    }

    #[test]
    fn test_expand_env_vars() {
        unsafe {
            std::env::set_var("TEST_VAR", "test_value");
            std::env::set_var("TEST_USER", "testuser");
        }

        let cases: &[(&str, &str)] = &[
            // ${VAR} syntax
            ("Hello ${TEST_VAR}!", "Hello test_value!"),
            ("${TEST_USER}@host", "testuser@host"),
            // $VAR syntax
            ("Hello $TEST_VAR!", "Hello test_value!"),
            ("$TEST_USER@host", "testuser@host"),
            // Both syntaxes in one string
            ("${TEST_USER}:$TEST_VAR", "testuser:test_value"),
            // Unknown variables are left as written
            ("${NONEXISTENT}", "${NONEXISTENT}"),
            ("$NONEXISTENT", "$NONEXISTENT"),
            // Nothing to expand
            ("no variables here", "no variables here"),
        ];

        for &(input, expected) in cases {
            assert_eq!(expand_env_vars(input), expected, "expanding {input:?}");
        }
    }

    #[test]
    fn test_expand_tilde() {
        unsafe {
            std::env::set_var("HOME", "/home/user");
        }
        assert_eq!(
            expand_tilde(Path::new("~/.ssh/config")),
            PathBuf::from("/home/user/.ssh/config")
        );
    }

    #[test]
    fn test_config_parsing() {
        let yaml = r#"
defaults:
  user: admin
  port: 22
  ssh_key: ~/.ssh/id_rsa

interactive:
  default_mode: multiplex
  prompt_format: "[{node}] $ "
  history_file: ~/.bssh_history
  show_timestamps: true
  colors:
    node1: red
    node2: blue
  keybindings:
    switch_node: "Ctrl+T"
    broadcast_toggle: "Ctrl+A"

clusters:
  production:
    nodes:
      - web1.example.com
      - web2.example.com:2222
      - user@web3.example.com
    ssh_key: ~/.ssh/prod_key
    interactive:
      default_mode: single_node
      prompt_format: "prod> "
  
  staging:
    nodes:
      - host: staging1.example.com
        port: 2200
        user: deploy
      - staging2.example.com
    user: staging_user
"#;

        let config: Config = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.defaults.user.as_deref(), Some("admin"));
        assert_eq!(config.clusters.len(), 2);

        // Global interactive section
        let global = &config.interactive;
        assert!(matches!(global.default_mode, InteractiveMode::Multiplex));
        assert_eq!(global.prompt_format, "[{node}] $ ");
        assert_eq!(global.history_file.as_deref(), Some("~/.bssh_history"));
        assert!(global.show_timestamps);
        assert_eq!(
            global.colors.get("node1").map(String::as_str),
            Some("red")
        );
        assert_eq!(global.keybindings.switch_node, "Ctrl+T");

        // Production cluster
        let prod = config.get_cluster("production").unwrap();
        assert_eq!(prod.nodes.len(), 3);
        assert_eq!(prod.defaults.ssh_key.as_deref(), Some("~/.ssh/prod_key"));

        // Production's own interactive section
        let prod_interactive = prod.interactive.as_ref().unwrap();
        assert!(matches!(
            prod_interactive.default_mode,
            InteractiveMode::SingleNode
        ));
        assert_eq!(prod_interactive.prompt_format, "prod> ");
    }

    #[test]
    fn test_interactive_config_fallback() {
        let yaml = r#"
interactive:
  default_mode: multiplex
  prompt_format: "global> "
  show_timestamps: true

clusters:
  with_override:
    nodes:
      - host1
    interactive:
      default_mode: multiplex
      prompt_format: "override> "
  
  without_override:
    nodes:
      - host2
"#;

        let config: Config = serde_yaml::from_str(yaml).unwrap();

        // Cluster with its own section: merged over the global one. show_timestamps
        // comes from the cluster (false by default), because an unset bool cannot be
        // told apart from an explicit `false`.
        let merged = config.get_interactive_config(Some("with_override"));
        assert_eq!(merged.prompt_format, "override> ");
        assert!(matches!(merged.default_mode, InteractiveMode::Multiplex));

        // Cluster without a section: the global values apply unchanged.
        let inherited = config.get_interactive_config(Some("without_override"));
        assert_eq!(inherited.prompt_format, "global> ");
        assert!(matches!(inherited.default_mode, InteractiveMode::Multiplex));
        assert!(inherited.show_timestamps);

        // No cluster requested: the global section.
        let global = config.get_interactive_config(None);
        assert_eq!(global.prompt_format, "global> ");
        assert!(matches!(global.default_mode, InteractiveMode::Multiplex));
    }

    #[test]
    fn test_backendai_env_parsing() {
        // Simulate a two-host Backend.AI session on the main node.
        unsafe {
            std::env::set_var("BACKENDAI_CLUSTER_HOSTS", "sub1,main1");
            std::env::set_var("BACKENDAI_CLUSTER_HOST", "main1");
            std::env::set_var("BACKENDAI_CLUSTER_ROLE", "main");
            std::env::set_var("USER", "testuser");
        }

        // Main role: every host is targeted, each on port 2200.
        let main_cluster = Config::from_backendai_env().unwrap();
        assert_eq!(main_cluster.nodes.len(), 2);
        assert_eq!(simple_host(&main_cluster.nodes[0]), "testuser@sub1:2200");

        // Sub role: the first host is skipped.
        unsafe {
            std::env::set_var("BACKENDAI_CLUSTER_ROLE", "sub");
        }
        let sub_cluster = Config::from_backendai_env().unwrap();
        assert_eq!(sub_cluster.nodes.len(), 1);
        assert_eq!(simple_host(&sub_cluster.nodes[0]), "testuser@main1:2200");

        // Clean up the Backend.AI variables.
        unsafe {
            for key in [
                "BACKENDAI_CLUSTER_HOSTS",
                "BACKENDAI_CLUSTER_HOST",
                "BACKENDAI_CLUSTER_ROLE",
            ] {
                std::env::remove_var(key);
            }
        }
    }
}