use std::time::Duration;
use crate::capability::Capabilities;
use crate::error::NetconfError;
use crate::facts::Facts;
use crate::notification::Notification;
use crate::session::Session;
use crate::transport::Transport;
use crate::transport::ssh::{HostKeyVerification, SshAuth, SshConfig, SshTransport};
#[cfg(feature = "tls")]
use crate::transport::tls::{TlsConfig, TlsTransport};
use crate::rpc::RpcErrorInfo;
use crate::types::{Datastore, DefaultOperation, ErrorOption, LoadAction, LoadFormat, OpenConfigurationMode, TestOption};
use crate::vendor::VendorProfile;
/// Transport parameters a [`Client`] keeps after connecting, so that
/// `Client::reconnect` can dial the same device over the same transport.
#[derive(Clone)]
enum TransportConfig {
    /// NETCONF over TLS; only present when the `tls` feature is enabled.
    #[cfg(feature = "tls")]
    Tls(TlsConfig),
    /// NETCONF over SSH.
    Ssh(SshConfig),
}
/// Builder for an SSH-based [`Client`], obtained from [`Client::connect`].
///
/// Nothing is sent over the network until [`ClientBuilder::connect`] is
/// awaited. When several authentication methods are configured, `connect`
/// picks the SSH agent first (if `ssh_agent` was called), then `key_file`,
/// then `password`.
#[must_use = "a ClientBuilder does nothing until `.connect().await` is called"]
pub struct ClientBuilder {
    /// Host part of the address passed to `Client::connect`.
    host: String,
    /// Port part of the address, or the default when none was given.
    port: u16,
    /// SSH login name; `connect` fails if this is never set.
    username: Option<String>,
    /// Password for password authentication.
    password: Option<String>,
    /// Path to a private key for key-file authentication.
    key_file: Option<String>,
    /// Passphrase for `key_file`; only consulted when a key file is used.
    key_passphrase: Option<String>,
    /// Authenticate through the SSH agent instead of key file / password.
    use_agent: bool,
    /// Explicit vendor profile to install on the session before `establish`.
    vendor_profile: Option<Box<dyn VendorProfile>>,
    /// Whether to gather device facts right after the session is established.
    gather_facts: bool,
    /// Optional keepalive interval handed to the session.
    keepalive_interval: Option<Duration>,
    /// Host-key checking policy for the SSH transport.
    host_key_verification: HostKeyVerification,
}
impl ClientBuilder {
    /// Sets the SSH login name (mandatory for `connect`).
    pub fn username(self, username: &str) -> Self {
        Self {
            username: Some(username.to_owned()),
            ..self
        }
    }

    /// Configures password authentication.
    pub fn password(self, password: &str) -> Self {
        Self {
            password: Some(password.to_owned()),
            ..self
        }
    }

    /// Configures private-key authentication using the key at `path`.
    pub fn key_file(self, path: &str) -> Self {
        Self {
            key_file: Some(path.to_owned()),
            ..self
        }
    }

    /// Sets the passphrase used to unlock the key given to `key_file`.
    pub fn key_passphrase(self, passphrase: &str) -> Self {
        Self {
            key_passphrase: Some(passphrase.to_owned()),
            ..self
        }
    }

    /// Authenticates through the SSH agent; this takes precedence over any
    /// key file or password that is also configured.
    pub fn ssh_agent(self) -> Self {
        Self {
            use_agent: true,
            ..self
        }
    }

    /// Installs an explicit vendor profile on the session before it is
    /// established.
    pub fn vendor_profile(self, profile: Box<dyn VendorProfile>) -> Self {
        Self {
            vendor_profile: Some(profile),
            ..self
        }
    }

    /// Chooses whether facts are gathered immediately after connecting.
    pub fn gather_facts(self, gather: bool) -> Self {
        Self {
            gather_facts: gather,
            ..self
        }
    }

    /// Sets the keepalive interval handed to the session.
    pub fn keepalive_interval(self, interval: Duration) -> Self {
        Self {
            keepalive_interval: Some(interval),
            ..self
        }
    }

    /// Sets the SSH host-key verification policy.
    pub fn host_key_verification(self, policy: HostKeyVerification) -> Self {
        Self {
            host_key_verification: policy,
            ..self
        }
    }

    /// Opens the SSH transport, establishes the NETCONF session and, unless
    /// disabled, gathers facts.
    ///
    /// # Errors
    ///
    /// Returns a transport authentication error when no username or no
    /// authentication method was configured, and propagates any error from
    /// connecting, establishing the session, or gathering facts.
    pub async fn connect(self) -> Result<Client, NetconfError> {
        let Self {
            host,
            port,
            username,
            password,
            key_file,
            key_passphrase,
            use_agent,
            vendor_profile,
            gather_facts,
            keepalive_interval,
            host_key_verification,
        } = self;

        let username = match username {
            Some(name) => name,
            None => {
                return Err(crate::error::TransportError::Auth(
                    "username is required".to_string(),
                )
                .into());
            }
        };

        // Precedence: agent, then key file, then password.
        let auth = match (use_agent, key_file, password) {
            (true, _, _) => SshAuth::Agent,
            (false, Some(path), _) => SshAuth::KeyFile {
                path,
                passphrase: key_passphrase,
            },
            (false, None, Some(secret)) => SshAuth::Password(secret),
            (false, None, None) => {
                return Err(crate::error::TransportError::Auth(
                    "no authentication method specified (password, key_file, or ssh_agent)"
                        .to_string(),
                )
                .into());
            }
        };

        let ssh_config = SshConfig {
            host,
            port,
            username,
            auth,
            host_key_verification,
        };

        // The config is cloned so the client can keep a copy for reconnects.
        let mut session = Session::new(Box::new(SshTransport::connect(ssh_config.clone()).await?));
        if let Some(every) = keepalive_interval {
            session.set_keepalive_interval(every);
        }
        if let Some(profile) = vendor_profile {
            session.set_vendor_profile(profile);
        }
        session.establish().await?;
        if gather_facts {
            session.gather_facts().await?;
        }

        Ok(Client {
            session,
            transport_config: TransportConfig::Ssh(ssh_config),
            gather_facts,
            keepalive_interval,
        })
    }
}
/// A connected NETCONF client wrapping a single [`Session`].
///
/// Besides the live session it keeps the settings needed to rebuild an
/// equivalent session in `reconnect`.
pub struct Client {
    /// The live NETCONF session every operation is delegated to.
    session: Session,
    /// Transport parameters used to dial the device again on reconnect.
    transport_config: TransportConfig,
    /// Keepalive interval re-applied to a reconnected session.
    keepalive_interval: Option<Duration>,
    /// Whether a reconnected session should gather facts again.
    gather_facts: bool,
}
impl Client {
    /// Starts building an SSH-based NETCONF client for `address`.
    ///
    /// `address` may carry a `:port` suffix; without one, port 830 is used.
    /// The returned builder gathers facts by default, sets no keepalive, uses
    /// the default host-key verification policy and has no credentials. Set a
    /// username and one authentication method before awaiting
    /// [`ClientBuilder::connect`]. No network I/O happens here.
    pub fn connect(address: &str) -> ClientBuilder {
        let (host, port) = parse_address(address);
        ClientBuilder {
            host,
            port,
            username: None,
            password: None,
            key_file: None,
            key_passphrase: None,
            use_agent: false,
            vendor_profile: None,
            gather_facts: true,
            keepalive_interval: None,
            host_key_verification: HostKeyVerification::default(),
        }
    }

    /// Starts building a NETCONF-over-TLS client from `config` (requires the
    /// `tls` feature). Facts are gathered by default and no keepalive is set.
    #[cfg(feature = "tls")]
    pub fn connect_tls(config: TlsConfig) -> TlsClientBuilder {
        TlsClientBuilder {
            tls_config: config,
            vendor_profile: None,
            gather_facts: true,
            keepalive_interval: None,
        }
    }

    /// Sends `rpc_content` as a raw RPC through the session and returns the
    /// session's reply text.
    pub async fn rpc(&mut self, rpc_content: &str) -> Result<String, NetconfError> {
        self.session.rpc(rpc_content).await
    }

    /// Reports whether the session supports the capability `capability_uri`.
    pub fn supports(&self, capability_uri: &str) -> bool {
        self.session.supports(capability_uri)
    }

    /// Name of the vendor the session is operating as.
    pub fn vendor_name(&self) -> &str {
        self.session.vendor_name()
    }

    /// Capabilities known to the session, if any have been recorded yet.
    pub fn capabilities(&self) -> Option<&Capabilities> {
        self.session.capabilities()
    }

    /// Device facts currently held by the session.
    pub fn facts(&self) -> &Facts {
        self.session.facts()
    }

    /// (Re-)gathers device facts on the current session.
    pub async fn gather_facts(&mut self) -> Result<(), NetconfError> {
        self.session.gather_facts().await
    }

    /// Session liveness as tracked locally by the session; performs no I/O.
    pub fn session_alive(&self) -> bool {
        self.session.is_alive()
    }

    /// Actively probes the session and reports whether it responded.
    pub async fn probe_session(&mut self) -> bool {
        self.session.probe().await
    }

    /// Replaces the current session with a freshly established one.
    ///
    /// The existing session is first closed on a best-effort basis. That
    /// close is expected to fail when the old session is already dead, so a
    /// failure is only logged at debug level. A new transport is then dialled
    /// with the configuration saved at connect time, and the keepalive
    /// interval and fact-gathering settings are re-applied.
    ///
    /// Note: a vendor profile given to the builder's `vendor_profile` is
    /// *not* re-applied, because the client does not retain it. The new
    /// session is built without an explicit profile.
    ///
    /// # Errors
    ///
    /// Returns the first error from dialling, session establishment or fact
    /// gathering. In that case `self` keeps the previous (already closed)
    /// session.
    pub async fn reconnect(&mut self) -> Result<(), NetconfError> {
        if let Err(err) = self.session.close_session().await {
            tracing::debug!(
                error = ?err,
                "closing previous NETCONF session before reconnect failed; continuing"
            );
        }
        let transport: Box<dyn Transport> = match &self.transport_config {
            TransportConfig::Ssh(config) => {
                Box::new(SshTransport::connect(config.clone()).await?)
            }
            #[cfg(feature = "tls")]
            TransportConfig::Tls(config) => {
                Box::new(TlsTransport::connect(config).await?)
            }
        };
        let mut session = Session::new(transport);
        if let Some(interval) = self.keepalive_interval {
            session.set_keepalive_interval(interval);
        }
        session.establish().await?;
        if self.gather_facts {
            session.gather_facts().await?;
        }
        // Swap only once the new session is fully usable.
        self.session = session;
        tracing::info!("NETCONF session reconnected");
        Ok(())
    }

    /// Retrieves the full configuration of `source` (no filter).
    pub async fn get_config(&mut self, source: Datastore) -> Result<String, NetconfError> {
        self.session.get_config(source, None).await
    }

    /// Retrieves the configuration of `source`, restricted by `filter`.
    pub async fn get_config_filtered(
        &mut self,
        source: Datastore,
        filter: &str,
    ) -> Result<String, NetconfError> {
        self.session.get_config(source, Some(filter)).await
    }

    /// Retrieves data through the session's `get` operation, optionally
    /// filtered.
    pub async fn get(&mut self, filter: Option<&str>) -> Result<String, NetconfError> {
        self.session.get(filter).await
    }

    /// Starts an edit of `target`. Supply the payload with `.config(..)`,
    /// optionally set options, then call `.send().await`.
    pub fn edit_config(&mut self, target: Datastore) -> EditConfigBuilder<'_> {
        EditConfigBuilder {
            session: &mut self.session,
            target,
            config: None,
            default_operation: None,
            test_option: None,
            error_option: None,
        }
    }

    /// Locks the `target` datastore.
    pub async fn lock(&mut self, target: Datastore) -> Result<(), NetconfError> {
        self.session.lock(target).await
    }

    /// Unlocks the `target` datastore.
    pub async fn unlock(&mut self, target: Datastore) -> Result<(), NetconfError> {
        self.session.unlock(target).await
    }

    /// Discards uncommitted changes via the session.
    pub async fn discard_changes(&mut self) -> Result<(), NetconfError> {
        self.session.discard_changes().await
    }

    /// Commits pending changes via the session.
    pub async fn commit(&mut self) -> Result<(), NetconfError> {
        self.session.commit().await
    }

    /// Validates the contents of `source`.
    pub async fn validate(&mut self, source: Datastore) -> Result<(), NetconfError> {
        self.session.validate(source).await
    }

    /// Closes the current session.
    pub async fn close_session(&mut self) -> Result<(), NetconfError> {
        self.session.close_session().await
    }

    /// Asks the server to terminate the session identified by `session_id`.
    pub async fn kill_session(&mut self, session_id: u32) -> Result<(), NetconfError> {
        self.session.kill_session(session_id).await
    }

    /// Issues a confirmed commit with the given `confirm_timeout`.
    ///
    /// The timeout is passed through unchanged; its unit is defined by the
    /// session.
    pub async fn confirmed_commit(&mut self, confirm_timeout: u32) -> Result<(), NetconfError> {
        self.session.confirmed_commit(confirm_timeout).await
    }

    /// Confirms a previously issued confirmed commit.
    pub async fn confirming_commit(&mut self) -> Result<(), NetconfError> {
        self.session.confirming_commit().await
    }

    /// Locks `target`, delegating stale-lock handling to the session.
    ///
    /// Returns whatever session id the session reports (presumably the id of
    /// a killed stale session; `None` otherwise). TODO confirm against
    /// `Session::lock_or_kill_stale`.
    pub async fn lock_or_kill_stale(
        &mut self,
        target: Datastore,
    ) -> Result<Option<u32>, NetconfError> {
        self.session.lock_or_kill_stale(target).await
    }

    /// Like [`Client::rpc`], but also returns the `RpcErrorInfo` entries the
    /// session reports alongside a successful reply.
    pub async fn rpc_with_warnings(
        &mut self,
        rpc_content: &str,
    ) -> Result<(String, Vec<RpcErrorInfo>), NetconfError> {
        self.session.rpc_with_warnings(rpc_content).await
    }

    /// Opens a configuration context in the given `mode` (vendor-specific).
    pub async fn open_configuration(
        &mut self,
        mode: OpenConfigurationMode,
    ) -> Result<(), NetconfError> {
        self.session.open_configuration(mode).await
    }

    /// Closes a configuration context opened with `open_configuration`.
    pub async fn close_configuration(&mut self) -> Result<(), NetconfError> {
        self.session.close_configuration().await
    }

    /// Commits using the session's vendor-specific commit-configuration call.
    pub async fn commit_configuration(&mut self) -> Result<(), NetconfError> {
        self.session.commit_configuration().await
    }

    /// Rolls the configuration back to rollback index `rollback`.
    pub async fn rollback_configuration(&mut self, rollback: u32) -> Result<(), NetconfError> {
        self.session.rollback_configuration(rollback).await
    }

    /// Returns the configuration comparison against rollback index `rollback`.
    pub async fn get_configuration_compare(
        &mut self,
        rollback: u32,
    ) -> Result<String, NetconfError> {
        self.session.get_configuration_compare(rollback).await
    }

    /// Loads `config`, given in `format`, using the given `action`, and
    /// returns the session's reply text.
    pub async fn load_configuration(
        &mut self,
        action: LoadAction,
        format: LoadFormat,
        config: &str,
    ) -> Result<String, NetconfError> {
        self.session.load_configuration(action, format, config).await
    }

    /// Reports whether the session requires `open_configuration` before
    /// configuration changes.
    pub fn requires_open_configuration(&self) -> bool {
        self.session.requires_open_configuration()
    }

    /// Creates a notification subscription. Every argument is optional and
    /// passed through to the session unchanged.
    pub async fn create_subscription(
        &mut self,
        stream: Option<&str>,
        filter: Option<&str>,
        start_time: Option<&str>,
        stop_time: Option<&str>,
    ) -> Result<(), NetconfError> {
        self.session
            .create_subscription(stream, filter, start_time, stop_time)
            .await
    }

    /// Takes every notification currently buffered by the session.
    pub fn drain_notifications(&mut self) -> Vec<Notification> {
        self.session.drain_notifications()
    }

    /// Receives the next notification from the session, if any.
    pub async fn recv_notification(&mut self) -> Result<Option<Notification>, NetconfError> {
        self.session.recv_notification().await
    }

    /// Reports whether the session has buffered notifications.
    pub fn has_notifications(&self) -> bool {
        self.session.has_notifications()
    }

    /// Reports whether the session has an active subscription.
    pub fn has_subscription(&self) -> bool {
        self.session.has_subscription()
    }
}
/// Builder for an edit-config operation, obtained from
/// [`Client::edit_config`].
///
/// Nothing is sent until [`EditConfigBuilder::send`] is awaited, and `send`
/// fails if no payload was supplied with `config`.
#[must_use = "an edit-config is not sent until `.send().await` is called"]
pub struct EditConfigBuilder<'a> {
    /// Session the edit will be sent on, borrowed from the client.
    session: &'a mut Session,
    /// Datastore to edit.
    target: Datastore,
    /// Configuration payload; required by `send`.
    config: Option<String>,
    /// Optional default operation, forwarded as-is.
    default_operation: Option<DefaultOperation>,
    /// Optional test option, forwarded as-is.
    test_option: Option<TestOption>,
    /// Optional error option, forwarded as-is.
    error_option: Option<ErrorOption>,
}
impl<'a> EditConfigBuilder<'a> {
pub fn config(mut self, config: &str) -> Self {
self.config = Some(config.to_string());
self
}
pub fn default_operation(mut self, op: DefaultOperation) -> Self {
self.default_operation = Some(op);
self
}
pub fn test_option(mut self, opt: TestOption) -> Self {
self.test_option = Some(opt);
self
}
pub fn error_option(mut self, opt: ErrorOption) -> Self {
self.error_option = Some(opt);
self
}
pub async fn send(self) -> Result<(), NetconfError> {
let config = self.config.ok_or_else(|| {
crate::error::ProtocolError::Xml("edit-config requires a config payload".to_string())
})?;
self.session
.edit_config(
self.target,
&config,
self.default_operation,
self.test_option,
self.error_option,
)
.await
}
}
#[cfg(feature = "tls")]
/// Builder for a NETCONF-over-TLS [`Client`], obtained from
/// `Client::connect_tls`. Nothing is sent until `connect` is awaited.
#[must_use = "a TlsClientBuilder does nothing until `.connect().await` is called"]
pub struct TlsClientBuilder {
    /// TLS parameters; also kept by the client for reconnects.
    tls_config: TlsConfig,
    /// Explicit vendor profile to install before the session is established.
    vendor_profile: Option<Box<dyn VendorProfile>>,
    /// Whether to gather device facts right after the session is established.
    gather_facts: bool,
    /// Optional keepalive interval handed to the session.
    keepalive_interval: Option<Duration>,
}
#[cfg(feature = "tls")]
impl TlsClientBuilder {
    /// Installs an explicit vendor profile on the session before it is
    /// established.
    pub fn vendor_profile(self, profile: Box<dyn VendorProfile>) -> Self {
        Self {
            vendor_profile: Some(profile),
            ..self
        }
    }

    /// Chooses whether facts are gathered immediately after connecting.
    pub fn gather_facts(self, gather: bool) -> Self {
        Self {
            gather_facts: gather,
            ..self
        }
    }

    /// Sets the keepalive interval handed to the session.
    pub fn keepalive_interval(self, interval: Duration) -> Self {
        Self {
            keepalive_interval: Some(interval),
            ..self
        }
    }

    /// Opens the TLS transport, establishes the NETCONF session and, unless
    /// disabled, gathers facts.
    ///
    /// # Errors
    ///
    /// Propagates any error from connecting, establishing the session, or
    /// gathering facts.
    pub async fn connect(self) -> Result<Client, NetconfError> {
        let Self {
            tls_config,
            vendor_profile,
            gather_facts,
            keepalive_interval,
        } = self;

        let mut session = Session::new(Box::new(TlsTransport::connect(&tls_config).await?));
        if let Some(every) = keepalive_interval {
            session.set_keepalive_interval(every);
        }
        if let Some(profile) = vendor_profile {
            session.set_vendor_profile(profile);
        }
        session.establish().await?;
        if gather_facts {
            session.gather_facts().await?;
        }

        Ok(Client {
            session,
            transport_config: TransportConfig::Tls(tls_config),
            gather_facts,
            keepalive_interval,
        })
    }
}
/// Splits `address` into `(host, port)`, defaulting the port to 830 (the
/// NETCONF-over-SSH port from RFC 6242).
///
/// Accepted forms:
/// - `host` or `host:port` (hostname or IPv4);
/// - `[v6addr]` or `[v6addr]:port`, the RFC 3986 bracketed IPv6 form; the
///   brackets are stripped from the returned host;
/// - a bare IPv6 literal such as `::1` or `fe80::1`. Because it contains more
///   than one `:`, it is treated as a host without a port instead of having
///   its last group mistaken for one.
///
/// Input that does not fit these forms (e.g. a non-numeric or out-of-range
/// port) is returned unchanged as the host, with the default port.
fn parse_address(address: &str) -> (String, u16) {
    const DEFAULT_NETCONF_PORT: u16 = 830;

    if let Some(inner) = address.strip_prefix('[') {
        if let Some((host, tail)) = inner.split_once(']') {
            if tail.is_empty() {
                return (host.to_string(), DEFAULT_NETCONF_PORT);
            }
            if let Some(port) = tail.strip_prefix(':').and_then(|p| p.parse::<u16>().ok()) {
                return (host.to_string(), port);
            }
        }
        return (address.to_string(), DEFAULT_NETCONF_PORT);
    }

    if let Some((host, port_str)) = address.rsplit_once(':') {
        // More than one colon means an unbracketed IPv6 literal: no port.
        if !host.contains(':') {
            if let Ok(port) = port_str.parse::<u16>() {
                return (host.to_string(), port);
            }
        }
    }
    (address.to_string(), DEFAULT_NETCONF_PORT)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `input` is split into exactly `host` and `port`.
    fn expect_split(input: &str, host: &str, port: u16) {
        assert_eq!(
            parse_address(input),
            (host.to_string(), port),
            "unexpected split for {:?}",
            input
        );
    }

    #[test]
    fn test_parse_address_with_port() {
        expect_split("10.0.0.1:830", "10.0.0.1", 830);
    }

    #[test]
    fn test_parse_address_without_port() {
        expect_split("10.0.0.1", "10.0.0.1", 830);
    }

    #[test]
    fn test_parse_address_hostname() {
        expect_split("router.example.com:22830", "router.example.com", 22830);
    }
}