use bssh::executor::{ExecutionResult, ExitCodeStrategy, RankDetector};
use bssh::node::Node;
use bssh::ssh::client::CommandResult;
use serial_test::serial;
/// Builds an `ExecutionResult` for `host` whose command completed with exit status 0.
///
/// The node is created with port 22 and user `"user"`, stdout and stderr are
/// empty, and the result is not flagged as the main rank.
fn success_result(host: &str) -> ExecutionResult {
    let node = Node::new(host.to_string(), 22, "user".to_string());
    let command = CommandResult {
        host: host.to_string(),
        output: Vec::new(),
        stderr: Vec::new(),
        exit_status: 0,
    };

    ExecutionResult {
        node,
        result: Ok(command),
        is_main_rank: false,
    }
}
/// Builds an `ExecutionResult` for `host` whose command completed with `exit_code`.
///
/// The `result` is still `Ok`: the command ran and reported an exit status, as
/// opposed to an `Err` result. The node is created with port 22 and user
/// `"user"`, stdout and stderr are empty, and the result is not flagged as the
/// main rank.
fn failure_result(host: &str, exit_code: u32) -> ExecutionResult {
    let node = Node::new(host.to_string(), 22, "user".to_string());
    let command = CommandResult {
        host: host.to_string(),
        output: Vec::new(),
        stderr: Vec::new(),
        exit_status: exit_code,
    };

    ExecutionResult {
        node,
        result: Ok(command),
        is_main_rank: false,
    }
}
/// The `MainRank` strategy must hand back the main rank's exit code unchanged.
///
/// host1 (the first node, which is the default main rank when no Backend.AI
/// override is present) exits with 139 while the other nodes succeed, so the
/// overall exit code must be 139.
///
/// Runs under `#[serial]` because main-rank detection is steered by the
/// `BACKENDAI_CLUSTER_*` environment variables that other tests in this file
/// set; an unmarked test could run while those overrides are active.
#[test]
#[serial]
fn test_main_rank_strategy_preserves_exit_code() {
    // Drop any Backend.AI overrides so the default main-rank choice applies.
    std::env::remove_var("BACKENDAI_CLUSTER_ROLE");
    std::env::remove_var("BACKENDAI_CLUSTER_HOST");

    let nodes = vec![
        Node::new("host1".to_string(), 22, "user".to_string()),
        Node::new("host2".to_string(), 22, "user".to_string()),
        Node::new("host3".to_string(), 22, "user".to_string()),
    ];
    let results = vec![
        failure_result("host1", 139),
        success_result("host2"),
        success_result("host3"),
    ];

    let main_idx = RankDetector::identify_main_rank(&nodes);
    let exit_code = ExitCodeStrategy::MainRank.calculate(&results, main_idx);

    assert_eq!(exit_code, 139, "Should preserve SIGSEGV exit code");
}
/// `RequireAllSuccess` must collapse any node failure into exit code 1.
///
/// host1 succeeds and host2 exits with 137; the strategy is expected to
/// report 1 rather than the failing node's own code.
///
/// Runs under `#[serial]` because it calls main-rank detection, which reads
/// the `BACKENDAI_CLUSTER_*` environment variables that other tests in this
/// file mutate.
#[test]
#[serial]
fn test_require_all_success_any_failure() {
    // Drop any Backend.AI overrides so detection sees a clean environment.
    std::env::remove_var("BACKENDAI_CLUSTER_ROLE");
    std::env::remove_var("BACKENDAI_CLUSTER_HOST");

    let nodes = vec![
        Node::new("host1".to_string(), 22, "user".to_string()),
        Node::new("host2".to_string(), 22, "user".to_string()),
    ];
    let results = vec![success_result("host1"), failure_result("host2", 137)];

    let main_idx = RankDetector::identify_main_rank(&nodes);
    let exit_code = ExitCodeStrategy::RequireAllSuccess.calculate(&results, main_idx);

    assert_eq!(exit_code, 1, "Should return 1 when any node fails");
}
/// `MainRankWithFailureCheck` must report 1 when the main rank succeeds but
/// other nodes fail.
///
/// host1 succeeds while host2 and host3 both exit with 1.
///
/// Runs under `#[serial]` because it calls main-rank detection, which reads
/// the `BACKENDAI_CLUSTER_*` environment variables that other tests in this
/// file mutate.
#[test]
#[serial]
fn test_hybrid_strategy_main_ok_others_fail() {
    // Drop any Backend.AI overrides so the default main-rank choice applies.
    std::env::remove_var("BACKENDAI_CLUSTER_ROLE");
    std::env::remove_var("BACKENDAI_CLUSTER_HOST");

    let nodes = vec![
        Node::new("host1".to_string(), 22, "user".to_string()),
        Node::new("host2".to_string(), 22, "user".to_string()),
        Node::new("host3".to_string(), 22, "user".to_string()),
    ];
    let results = vec![
        success_result("host1"),
        failure_result("host2", 1),
        failure_result("host3", 1),
    ];

    let main_idx = RankDetector::identify_main_rank(&nodes);
    let exit_code = ExitCodeStrategy::MainRankWithFailureCheck.calculate(&results, main_idx);

    assert_eq!(
        exit_code, 1,
        "Should return 1 when main succeeds but others fail"
    );
}
/// `MainRankWithFailureCheck` must preserve the main rank's own exit code
/// when the main rank itself fails.
///
/// host1 (the first node, which is the default main rank when no Backend.AI
/// override is present) exits with 124 and host2 succeeds, so 124 is expected.
///
/// Runs under `#[serial]` because main-rank detection is steered by the
/// `BACKENDAI_CLUSTER_*` environment variables that other tests in this file
/// set; if host2 were picked as main rank the expected value would not hold.
#[test]
#[serial]
fn test_hybrid_strategy_main_fails() {
    // Drop any Backend.AI overrides so the default main-rank choice applies.
    std::env::remove_var("BACKENDAI_CLUSTER_ROLE");
    std::env::remove_var("BACKENDAI_CLUSTER_HOST");

    let nodes = vec![
        Node::new("host1".to_string(), 22, "user".to_string()),
        Node::new("host2".to_string(), 22, "user".to_string()),
    ];
    let results = vec![failure_result("host1", 124), success_result("host2")];

    let main_idx = RankDetector::identify_main_rank(&nodes);
    let exit_code = ExitCodeStrategy::MainRankWithFailureCheck.calculate(&results, main_idx);

    assert_eq!(exit_code, 124, "Should preserve main rank exit code");
}
/// Main-rank detection must honour the Backend.AI environment variables.
///
/// With `BACKENDAI_CLUSTER_ROLE=main` and `BACKENDAI_CLUSTER_HOST=host2`, the
/// detector is expected to select index 1 (host2) instead of the first node.
///
/// Runs under `#[serial]` because it mutates process-wide environment
/// variables.
#[test]
#[serial]
fn test_backendai_main_rank_detection() {
    std::env::set_var("BACKENDAI_CLUSTER_ROLE", "main");
    std::env::set_var("BACKENDAI_CLUSTER_HOST", "host2");

    let nodes = vec![
        Node::new("host1".to_string(), 22, "user".to_string()),
        Node::new("host2".to_string(), 22, "user".to_string()),
        Node::new("host3".to_string(), 22, "user".to_string()),
    ];
    let main_idx = RankDetector::identify_main_rank(&nodes);

    // Clean up before asserting: a failed assertion panics, and cleanup placed
    // after it would never run, leaking the overrides into later tests.
    std::env::remove_var("BACKENDAI_CLUSTER_ROLE");
    std::env::remove_var("BACKENDAI_CLUSTER_HOST");

    assert_eq!(
        main_idx,
        Some(1),
        "Should detect host2 as main rank via Backend.AI env"
    );
}
/// Table-driven check of every strategy against every main/other outcome.
///
/// Each case is `(strategy, main_ok, others_ok, main_exit, expected)`:
/// host1 plays the main rank (first node, the default when no Backend.AI
/// override is present) and either succeeds or exits with `main_exit`; host2
/// either succeeds or exits with 1; `expected` is the overall exit code.
///
/// Runs under `#[serial]` because main-rank detection is steered by the
/// `BACKENDAI_CLUSTER_*` environment variables that other tests in this file
/// set; with host2 picked as main rank most rows would not hold.
#[test]
#[serial]
fn test_exit_code_strategy_comprehensive() {
    // Drop any Backend.AI overrides so the default main-rank choice applies.
    std::env::remove_var("BACKENDAI_CLUSTER_ROLE");
    std::env::remove_var("BACKENDAI_CLUSTER_HOST");

    let test_cases = vec![
        // MainRank: the main rank's exit code is returned as-is.
        (ExitCodeStrategy::MainRank, true, true, 0, 0),
        (ExitCodeStrategy::MainRank, false, true, 139, 139),
        (ExitCodeStrategy::MainRank, true, false, 0, 0),
        (ExitCodeStrategy::MainRank, false, false, 137, 137),
        // RequireAllSuccess: any failure collapses to 1.
        (ExitCodeStrategy::RequireAllSuccess, true, true, 0, 0),
        (ExitCodeStrategy::RequireAllSuccess, false, true, 139, 1),
        (ExitCodeStrategy::RequireAllSuccess, true, false, 0, 1),
        (ExitCodeStrategy::RequireAllSuccess, false, false, 137, 1),
        // MainRankWithFailureCheck: main rank's code if it failed, else 1 if
        // any other node failed.
        (ExitCodeStrategy::MainRankWithFailureCheck, true, true, 0, 0),
        (
            ExitCodeStrategy::MainRankWithFailureCheck,
            false,
            true,
            139,
            139,
        ),
        (
            ExitCodeStrategy::MainRankWithFailureCheck,
            true,
            false,
            0,
            1,
        ),
        (
            ExitCodeStrategy::MainRankWithFailureCheck,
            false,
            false,
            137,
            137,
        ),
    ];

    // The node list and detected main rank are the same for every case.
    let nodes = vec![
        Node::new("host1".to_string(), 22, "user".to_string()),
        Node::new("host2".to_string(), 22, "user".to_string()),
    ];
    let main_idx = RankDetector::identify_main_rank(&nodes);

    for (strategy, main_ok, others_ok, main_exit, expected) in test_cases {
        let results = vec![
            if main_ok {
                success_result("host1")
            } else {
                failure_result("host1", main_exit)
            },
            if others_ok {
                success_result("host2")
            } else {
                failure_result("host2", 1)
            },
        ];

        let exit_code = strategy.calculate(&results, main_idx);
        assert_eq!(
            exit_code, expected,
            "Strategy {strategy:?}: main_ok={main_ok}, others_ok={others_ok}, main_exit={main_exit} → expected {expected}"
        );
    }
}
/// Without Backend.AI overrides the first node is the main rank, and only
/// its result gets flagged.
///
/// The test detects the main rank, sets `is_main_rank` on the matching
/// result, and checks that exactly the first of three results is flagged.
///
/// Runs under `#[serial]` because main-rank detection is steered by the
/// `BACKENDAI_CLUSTER_*` environment variables that other tests in this file
/// set; this test asserts the default, so it must not overlap with them.
#[test]
#[serial]
fn test_main_rank_marking_in_results() {
    // Drop any Backend.AI overrides so the default main-rank choice applies.
    std::env::remove_var("BACKENDAI_CLUSTER_ROLE");
    std::env::remove_var("BACKENDAI_CLUSTER_HOST");

    let nodes = vec![
        Node::new("host1".to_string(), 22, "user".to_string()),
        Node::new("host2".to_string(), 22, "user".to_string()),
        Node::new("host3".to_string(), 22, "user".to_string()),
    ];
    let mut results = vec![
        success_result("host1"),
        success_result("host2"),
        success_result("host3"),
    ];

    let main_idx = RankDetector::identify_main_rank(&nodes);
    assert_eq!(
        main_idx,
        Some(0),
        "First node should be main rank by default"
    );

    if let Some(idx) = main_idx {
        results[idx].is_main_rank = true;
    }

    assert!(
        results[0].is_main_rank,
        "First result should be marked as main rank"
    );
    assert!(
        !results[1].is_main_rank,
        "Second result should not be main rank"
    );
    assert!(
        !results[2].is_main_rank,
        "Third result should not be main rank"
    );
}
/// With Backend.AI overrides naming host3, only host3's result gets flagged
/// as main rank.
///
/// Sets `BACKENDAI_CLUSTER_ROLE=main` and `BACKENDAI_CLUSTER_HOST=host3`,
/// detects the main rank (expected index 2), flags the matching result, and
/// checks the flags on all three results.
///
/// Runs under `#[serial]` because it mutates process-wide environment
/// variables.
#[test]
#[serial]
fn test_main_rank_marking_with_backendai_env() {
    std::env::set_var("BACKENDAI_CLUSTER_ROLE", "main");
    std::env::set_var("BACKENDAI_CLUSTER_HOST", "host3");

    let nodes = vec![
        Node::new("host1".to_string(), 22, "user".to_string()),
        Node::new("host2".to_string(), 22, "user".to_string()),
        Node::new("host3".to_string(), 22, "user".to_string()),
    ];
    let mut results = vec![
        success_result("host1"),
        success_result("host2"),
        success_result("host3"),
    ];

    let main_idx = RankDetector::identify_main_rank(&nodes);

    // Detection is the only step that needs the overrides. Remove them before
    // any assertion so a failing assertion cannot leak them into later tests.
    std::env::remove_var("BACKENDAI_CLUSTER_ROLE");
    std::env::remove_var("BACKENDAI_CLUSTER_HOST");

    assert_eq!(
        main_idx,
        Some(2),
        "host3 should be identified as main rank via Backend.AI env"
    );

    if let Some(idx) = main_idx {
        results[idx].is_main_rank = true;
    }

    assert!(
        !results[0].is_main_rank,
        "host1 should not be marked as main rank"
    );
    assert!(
        !results[1].is_main_rank,
        "host2 should not be marked as main rank"
    );
    assert!(
        results[2].is_main_rank,
        "host3 should be marked as main rank"
    );
}
/// Every strategy must report exit code 1 when all nodes failed with an
/// `Err` result rather than a command exit status.
///
/// Both results carry an error, the first one is flagged as main rank, and
/// index 0 is passed explicitly as the main rank.
#[test]
fn test_strategy_with_all_connection_errors() {
    use anyhow::anyhow;

    let timed_out = ExecutionResult {
        node: Node::new("host1".to_string(), 22, "user".to_string()),
        result: Err(anyhow!("Connection timeout")),
        is_main_rank: true,
    };
    let refused = ExecutionResult {
        node: Node::new("host2".to_string(), 22, "user".to_string()),
        result: Err(anyhow!("Connection refused")),
        is_main_rank: false,
    };
    let results = vec![timed_out, refused];
    let main_idx = Some(0);

    // Each strategy paired with the message shown if it does not yield 1.
    let expectations = [
        (
            ExitCodeStrategy::MainRank,
            "Connection error should return exit code 1",
        ),
        (
            ExitCodeStrategy::RequireAllSuccess,
            "Any failure should return 1 in RequireAllSuccess",
        ),
        (
            ExitCodeStrategy::MainRankWithFailureCheck,
            "Main rank connection error should return 1 in hybrid mode",
        ),
    ];

    for (strategy, message) in expectations {
        let exit_code = strategy.calculate(&results, main_idx);
        assert_eq!(exit_code, 1, "{message}");
    }
}
/// Checks each strategy against a mix of outcomes: host1 succeeds, host2 has
/// an `Err` result, and host3 exits with 137.
///
/// Index 0 (the successful host1) is passed explicitly as the main rank, so
/// `MainRank` is expected to yield 0 while the other two strategies yield 1.
#[test]
fn test_strategy_with_mixed_errors() {
    use anyhow::anyhow;

    let unreachable = ExecutionResult {
        node: Node::new("host2".to_string(), 22, "user".to_string()),
        result: Err(anyhow!("Connection timeout")),
        is_main_rank: false,
    };
    let results = vec![
        success_result("host1"),
        unreachable,
        failure_result("host3", 137),
    ];
    let main_idx = Some(0);

    // (strategy, expected exit code, message shown on mismatch)
    let expectations = [
        (
            ExitCodeStrategy::MainRank,
            0,
            "Main rank succeeded, should return 0",
        ),
        (
            ExitCodeStrategy::RequireAllSuccess,
            1,
            "Some nodes failed, should return 1 in RequireAllSuccess",
        ),
        (
            ExitCodeStrategy::MainRankWithFailureCheck,
            1,
            "Main OK but others failed, should return 1 in hybrid mode",
        ),
    ];

    for (strategy, expected, message) in expectations {
        let exit_code = strategy.calculate(&results, main_idx);
        assert_eq!(exit_code, expected, "{message}");
    }
}
/// `MainRank` must work when the main rank is the last entry in the results.
///
/// The last of three results exits with 139 and its index is passed as the
/// main rank, so 139 is expected.
#[test]
fn test_main_rank_index_boundary() {
    let results = vec![
        success_result("host1"),
        failure_result("host2", 1),
        failure_result("host3", 139),
    ];

    // Point the main rank at the final element (index 2).
    let last_idx = results.len() - 1;
    let exit_code = ExitCodeStrategy::MainRank.calculate(&results, Some(last_idx));

    assert_eq!(
        exit_code, 139,
        "Should return exit code from last node (main rank)"
    );
}
/// Checks the precedence used to turn the two flags into a strategy:
/// `require_all_success` wins, then `check_all_nodes`, otherwise `MainRank`.
///
/// NOTE(review): the selection expression is defined locally in this test;
/// presumably it mirrors the CLI's flag handling — confirm the two stay in sync.
#[test]
fn test_strategy_selection_logic() {
    let select = |require_all_success: bool, check_all_nodes: bool| {
        match (require_all_success, check_all_nodes) {
            (true, _) => ExitCodeStrategy::RequireAllSuccess,
            (false, true) => ExitCodeStrategy::MainRankWithFailureCheck,
            (false, false) => ExitCodeStrategy::MainRank,
        }
    };

    // (require_all_success, check_all_nodes, expected strategy, message)
    let cases = [
        (
            false,
            false,
            ExitCodeStrategy::MainRank,
            "Default should be MainRank",
        ),
        (
            true,
            false,
            ExitCodeStrategy::RequireAllSuccess,
            "--require-all-success should select RequireAllSuccess",
        ),
        (
            false,
            true,
            ExitCodeStrategy::MainRankWithFailureCheck,
            "--check-all-nodes should select MainRankWithFailureCheck",
        ),
        (
            true,
            true,
            ExitCodeStrategy::RequireAllSuccess,
            "When both flags set, --require-all-success should take precedence",
        ),
    ];

    for (require_all_success, check_all_nodes, expected, message) in cases {
        let strategy = select(require_all_success, check_all_nodes);
        assert_eq!(strategy, expected, "{message}");
    }
}