use super::tokio_client::ServerCheckMethod;
use directories::BaseDirs;
use std::path::PathBuf;
use std::str::FromStr;
/// Returns the per-user known_hosts location, `<home>/.ssh/known_hosts`.
///
/// Yields `None` when `BaseDirs` cannot determine the home directory.
pub fn get_default_known_hosts_path() -> Option<PathBuf> {
    let base_dirs = BaseDirs::new()?;
    let ssh_dir = base_dirs.home_dir().join(".ssh");
    Some(ssh_dir.join("known_hosts"))
}
pub fn get_check_method(strict_mode: StrictHostKeyChecking) -> ServerCheckMethod {
match strict_mode {
StrictHostKeyChecking::Yes => {
if let Some(known_hosts_path) = get_default_known_hosts_path() {
if known_hosts_path.exists() {
tracing::debug!(
"Using known_hosts file: {:?} (strict mode)",
known_hosts_path
);
ServerCheckMethod::DefaultKnownHostsFile
} else {
tracing::warn!(
"Known hosts file not found at {:?}, using NoCheck",
known_hosts_path
);
eprintln!(
"WARNING: Known hosts file not found. Host key verification disabled."
);
ServerCheckMethod::NoCheck
}
} else {
tracing::warn!("Could not determine known_hosts path, using NoCheck");
ServerCheckMethod::NoCheck
}
}
StrictHostKeyChecking::No => {
tracing::debug!("Host key checking disabled (strict mode = no)");
ServerCheckMethod::NoCheck
}
StrictHostKeyChecking::AcceptNew => {
if let Some(known_hosts_path) = get_default_known_hosts_path() {
if known_hosts_path.exists() {
tracing::debug!(
"Using known_hosts file: {:?} (accept-new mode)",
known_hosts_path
);
tracing::info!(
"Note: accept-new mode not fully supported, using relaxed checking"
);
ServerCheckMethod::NoCheck
} else {
if let Some(ssh_dir) = known_hosts_path.parent() {
let _ = std::fs::create_dir_all(ssh_dir);
}
let _ = std::fs::File::create(&known_hosts_path);
tracing::debug!("Created empty known_hosts file at {:?}", known_hosts_path);
ServerCheckMethod::NoCheck
}
} else {
tracing::warn!("Could not determine known_hosts path, using NoCheck");
ServerCheckMethod::NoCheck
}
}
}
}
/// Host key verification policy.
///
/// Parsed case-insensitively from strings via `FromStr`; defaults to
/// [`StrictHostKeyChecking::AcceptNew`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum StrictHostKeyChecking {
    /// Strict checking (parsed from `yes` / `true`).
    Yes,
    /// Host key checking disabled (parsed from `no` / `false`).
    No,
    /// Accept-new / trust-on-first-use policy (parsed from `accept-new` / `tofu`, and used
    /// as the fallback for unrecognised strings). See `get_check_method` for how much of
    /// this policy is currently enforced.
    #[default]
    AcceptNew,
}

impl StrictHostKeyChecking {
    /// Returns `true` only for [`StrictHostKeyChecking::Yes`]; `No` and `AcceptNew` both
    /// map to `false`.
    pub fn to_bool(&self) -> bool {
        matches!(self, Self::Yes)
    }
}
/// Parses a policy from a configuration or command-line string, ignoring case.
///
/// `yes`/`true` map to `Yes`, `no`/`false` map to `No`, and `accept-new`/`tofu` map to
/// `AcceptNew`. Any other input also yields `AcceptNew`, so parsing never returns `Err`.
impl FromStr for StrictHostKeyChecking {
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let mode = if matches!(lowered.as_str(), "yes" | "true") {
            Self::Yes
        } else if matches!(lowered.as_str(), "no" | "false") {
            Self::No
        } else {
            // "accept-new", "tofu", and every unrecognised value share the lenient default.
            Self::AcceptNew
        };
        Ok(mode)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    #[test]
    fn test_strict_host_key_checking_from_str() {
        let cases = [
            ("yes", StrictHostKeyChecking::Yes),
            ("true", StrictHostKeyChecking::Yes),
            ("no", StrictHostKeyChecking::No),
            ("false", StrictHostKeyChecking::No),
            ("accept-new", StrictHostKeyChecking::AcceptNew),
            ("tofu", StrictHostKeyChecking::AcceptNew),
            // Unrecognised values fall back to AcceptNew rather than returning Err.
            ("invalid", StrictHostKeyChecking::AcceptNew),
        ];
        for (input, expected) in cases {
            assert_eq!(
                StrictHostKeyChecking::from_str(input).unwrap(),
                expected,
                "input: {:?}",
                input
            );
        }
    }

    #[test]
    fn test_strict_host_key_checking_from_str_is_case_insensitive() {
        assert_eq!(
            "YES".parse::<StrictHostKeyChecking>(),
            Ok(StrictHostKeyChecking::Yes)
        );
        assert_eq!(
            "False".parse::<StrictHostKeyChecking>(),
            Ok(StrictHostKeyChecking::No)
        );
        assert_eq!(
            "Accept-New".parse::<StrictHostKeyChecking>(),
            Ok(StrictHostKeyChecking::AcceptNew)
        );
    }

    #[test]
    fn test_strict_host_key_checking_to_bool() {
        assert!(StrictHostKeyChecking::Yes.to_bool());
        assert!(!StrictHostKeyChecking::No.to_bool());
        assert!(!StrictHostKeyChecking::AcceptNew.to_bool());
    }

    #[test]
    fn test_strict_host_key_checking_default() {
        assert_eq!(
            StrictHostKeyChecking::default(),
            StrictHostKeyChecking::AcceptNew
        );
    }

    #[test]
    fn test_get_default_known_hosts_path() {
        let path = get_default_known_hosts_path()
            .expect("home directory should be resolvable in the test environment");
        // Compare path components rather than a `/`-separated substring so the check also
        // holds on Windows, where the separator is `\`.
        assert!(path.ends_with(Path::new(".ssh").join("known_hosts")));
    }

    #[test]
    fn test_get_check_method() {
        let method = get_check_method(StrictHostKeyChecking::No);
        assert!(matches!(method, ServerCheckMethod::NoCheck));

        // NOTE(review): when no known_hosts file exists, this call creates an empty
        // ~/.ssh/known_hosts in the real home directory as a side effect.
        let method = get_check_method(StrictHostKeyChecking::AcceptNew);
        assert!(matches!(method, ServerCheckMethod::NoCheck));

        // The result depends on the machine's known_hosts state, so either variant is accepted.
        let method = get_check_method(StrictHostKeyChecking::Yes);
        assert!(matches!(
            method,
            ServerCheckMethod::DefaultKnownHostsFile | ServerCheckMethod::NoCheck
        ));
    }
}