use crate::error::Result;
use serde::{Deserialize, Serialize};
use std::io::{BufRead, BufReader, Read};
use std::net::{TcpStream, ToSocketAddrs};
use std::time::{Duration, Instant};
/// Per-connection timeout used by `ProbeConfig::default()`.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Port checked for NETCONF; the probe expects an SSH identification line here.
pub const NETCONF_PORT: u16 = 830;
/// Port checked for gNMI with a plain TCP connect.
pub const GNMI_PORT: u16 = 9_339;
/// Port checked for RESTCONF with a plain TCP connect (no TLS or HTTP exchange).
pub const RESTCONF_PORT: u16 = 443;
/// Port checked for a plain SSH identification line.
pub const SSH_PORT: u16 = 22;
/// Settings that control how [`probe_device`] talks to a device.
#[derive(Debug, Clone)]
pub struct ProbeConfig {
    /// Budget for each connection attempt; also installed as the socket's
    /// read and write timeout once connected.
    pub timeout: Duration,
    /// Names of probes to leave out, matched ignoring ASCII case.
    /// The names checked are `"ssh"`, `"netconf"`, `"gnmi"` and `"restconf"`.
    pub skip: Vec<String>,
}

impl Default for ProbeConfig {
    /// Uses [`DEFAULT_TIMEOUT`] and skips no probes.
    fn default() -> Self {
        ProbeConfig {
            skip: vec![],
            timeout: DEFAULT_TIMEOUT,
        }
    }
}

impl ProbeConfig {
    /// Returns `true` when `name` appears in [`ProbeConfig::skip`], ignoring ASCII case.
    fn skipped(&self, name: &str) -> bool {
        self.skip
            .iter()
            .map(String::as_str)
            .any(|entry| name.eq_ignore_ascii_case(entry))
    }
}
/// Outcome of a [`probe_device`] run against a single host.
///
/// Each `*_available` field is `None` when the corresponding probe was skipped
/// via [`ProbeConfig::skip`], and `Some(result)` when it ran.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ProbeReport {
/// Host string exactly as passed to `probe_device`.
pub host: String,
/// Vendor label exactly as passed to `probe_device`; the probes only copy it.
pub vendor: String,
/// `Some(true)` when an SSH identification line was read on `NETCONF_PORT`.
pub netconf_available: Option<bool>,
/// `Some(true)` when a TCP connection to `GNMI_PORT` could be opened.
pub gnmi_available: Option<bool>,
/// `Some(true)` when a TCP connection to `RESTCONF_PORT` could be opened.
pub restconf_available: Option<bool>,
/// SSH identification line read from `SSH_PORT`, if one was received.
pub ssh_banner: Option<String>,
/// Not populated by `probe_device`. NOTE(review): presumably filled in by a
/// later, vendor-specific step — confirm with callers.
pub firmware: Option<String>,
/// One `"<probe>:<port> <detail>"` entry for each probe that failed or was
/// inconclusive, in the order the probes ran.
pub diagnostics: Vec<String>,
/// Wall-clock duration of the whole `probe_device` call, in milliseconds.
pub elapsed_ms: u128,
}
/// Probes `host` for SSH, NETCONF, gNMI and RESTCONF endpoints.
///
/// The probes run one after another in that order. Each connection attempt is
/// bounded by `cfg.timeout`. A probe named in `cfg.skip` does not run, and its
/// report field stays `None`.
///
/// - SSH (`SSH_PORT`): the server's identification line, if any, is stored in
///   `ssh_banner`.
/// - NETCONF (`NETCONF_PORT`): `Some(true)` only when an SSH identification
///   line is read on that port.
/// - gNMI (`GNMI_PORT`) and RESTCONF (`RESTCONF_PORT`): `Some(true)` when a
///   TCP connection opens; no protocol exchange is attempted.
///
/// Every failed or inconclusive probe appends a `"<probe>:<port> <detail>"`
/// entry to `diagnostics`. `host` and `vendor` are copied into the report.
///
/// # Errors
/// Probe failures are reported through `diagnostics`, not as `Err`. As
/// written, this function always returns `Ok`.
pub fn probe_device(host: &str, vendor: &str, cfg: &ProbeConfig) -> Result<ProbeReport> {
    /// Collapses a bare TCP-connect outcome into yes/no, noting the failure reason.
    fn tcp_verdict(
        notes: &mut Vec<String>,
        label: &str,
        port: u16,
        outcome: std::io::Result<()>,
    ) -> bool {
        match outcome {
            Ok(()) => true,
            Err(err) => {
                notes.push(format!("{label}:{port} {err}"));
                false
            }
        }
    }

    let started = Instant::now();
    let wait = cfg.timeout;
    let mut report = ProbeReport {
        host: host.to_owned(),
        vendor: vendor.to_owned(),
        ..Default::default()
    };

    if !cfg.skipped("ssh") {
        let note = match probe_ssh_banner(host, SSH_PORT, wait) {
            Ok(Some(banner)) => {
                report.ssh_banner = Some(banner);
                None
            }
            Ok(None) => Some(format!("ssh:{SSH_PORT} reachable but no banner read")),
            Err(err) => Some(format!("ssh:{SSH_PORT} {err}")),
        };
        report.diagnostics.extend(note);
    }

    if !cfg.skipped("netconf") {
        // Availability is judged by an SSH identification line on the NETCONF port.
        let (available, note) = match probe_ssh_banner(host, NETCONF_PORT, wait) {
            Ok(Some(_)) => (true, None),
            Ok(None) => (
                false,
                Some(format!("netconf:{NETCONF_PORT} reachable but no SSH banner")),
            ),
            Err(err) => (false, Some(format!("netconf:{NETCONF_PORT} {err}"))),
        };
        report.diagnostics.extend(note);
        report.netconf_available = Some(available);
    }

    if !cfg.skipped("gnmi") {
        let outcome = probe_tcp_open(host, GNMI_PORT, wait);
        report.gnmi_available = Some(tcp_verdict(
            &mut report.diagnostics,
            "gnmi",
            GNMI_PORT,
            outcome,
        ));
    }

    if !cfg.skipped("restconf") {
        let outcome = probe_tcp_open(host, RESTCONF_PORT, wait);
        report.restconf_available = Some(tcp_verdict(
            &mut report.diagnostics,
            "restconf",
            RESTCONF_PORT,
            outcome,
        ));
    }

    report.elapsed_ms = started.elapsed().as_millis();
    Ok(report)
}
fn connect(host: &str, port: u16, timeout: Duration) -> std::io::Result<TcpStream> {
let addrs: Vec<_> = (host, port)
.to_socket_addrs()
.map_err(|e| {
std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, format!("DNS: {e}"))
})?
.collect();
if addrs.is_empty() {
return Err(std::io::Error::new(
std::io::ErrorKind::AddrNotAvailable,
"no addresses resolved",
));
}
let mut last_err = std::io::Error::other("no addresses tried");
for addr in addrs {
match TcpStream::connect_timeout(&addr, timeout) {
Ok(stream) => {
stream.set_read_timeout(Some(timeout))?;
stream.set_write_timeout(Some(timeout))?;
return Ok(stream);
}
Err(e) => last_err = e,
}
}
Err(last_err)
}
/// Checks whether a TCP connection to `host:port` can be opened within `timeout`.
///
/// The stream is dropped (closed) immediately; nothing is sent or read.
/// Any error from `connect` is returned unchanged.
fn probe_tcp_open(host: &str, port: u16, timeout: Duration) -> std::io::Result<()> {
    connect(host, port, timeout).map(|_stream| ())
}
fn probe_ssh_banner(host: &str, port: u16, timeout: Duration) -> std::io::Result<Option<String>> {
let mut stream = connect(host, port, timeout)?;
let mut buf = [0u8; 256];
let mut total = 0;
while total < buf.len() {
let n = match stream.read(&mut buf[total..]) {
Ok(0) => break,
Ok(n) => n,
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
Err(e) => return Err(e),
};
total += n;
if buf[..total].contains(&b'\n') {
break;
}
}
if total == 0 {
return Ok(None);
}
let line: String = buf[..total]
.iter()
.take_while(|&&b| b != b'\n' && b != b'\r')
.map(|&b| b as char)
.collect();
if line.starts_with("SSH-") {
Ok(Some(line))
} else {
Ok(None)
}
}
// Unit tests for probe configuration, report defaults and probe_device orchestration.
#[cfg(test)]
mod tests {
use super::*;
// Skip-list entries must match probe names regardless of the case the caller used.
#[test]
fn config_skip_is_case_insensitive() {
let cfg = ProbeConfig {
skip: vec!["SSH".into(), "Netconf".into()],
..Default::default()
};
assert!(cfg.skipped("ssh"));
assert!(cfg.skipped("netconf"));
assert!(!cfg.skipped("gnmi"));
}
// A default report means "nothing probed yet": every optional result is None.
#[test]
fn report_default_has_no_results() {
let r = ProbeReport::default();
assert!(r.netconf_available.is_none());
assert!(r.gnmi_available.is_none());
assert!(r.restconf_available.is_none());
assert!(r.ssh_banner.is_none());
assert_eq!(r.elapsed_ms, 0);
}
// 198.51.100.1 lies in TEST-NET-2 (RFC 5737), a documentation-only range, so
// every probe is expected to fail and be reported via diagnostics, not Err.
// NOTE(review): this touches the real network stack and can take roughly four
// connect timeouts (~1s); results depend on how the host's network treats that range.
#[test]
fn probe_unreachable_host_gracefully_reports() {
let cfg = ProbeConfig {
timeout: Duration::from_millis(250),
..Default::default()
};
let report = probe_device("198.51.100.1", "cisco_ios", &cfg)
.expect("probe never errors at lib level");
assert_eq!(report.host, "198.51.100.1");
assert_eq!(report.vendor, "cisco_ios");
assert!(report.netconf_available != Some(true));
assert!(report.gnmi_available != Some(true));
assert!(report.restconf_available != Some(true));
assert!(report.ssh_banner.is_none());
assert!(
!report.diagnostics.is_empty(),
"expected at least one diagnostic for unreachable host"
);
}
// With every probe skipped, probe_device performs no network I/O and leaves
// all result fields and the diagnostics list at their defaults.
#[test]
fn probe_respects_skip_list() {
let cfg = ProbeConfig {
timeout: Duration::from_millis(250),
skip: vec![
"ssh".into(),
"netconf".into(),
"gnmi".into(),
"restconf".into(),
],
};
let report = probe_device("198.51.100.1", "cisco_ios", &cfg).unwrap();
assert_eq!(report.netconf_available, None);
assert_eq!(report.gnmi_available, None);
assert_eq!(report.restconf_available, None);
assert_eq!(report.ssh_banner, None);
assert!(
report.diagnostics.is_empty(),
"no probes ran, no diagnostics expected"
);
}
}