rustnetconf 0.11.0

An async-first NETCONF 1.0/1.1 client library for Rust
Documentation
//! Async connection pool for managing NETCONF sessions to multiple devices.
//!
//! ```text
//! DevicePool
//!   ├── Semaphore(max_connections) — global concurrency limit
//!   ├── devices: HashMap<name, DeviceConfig>
//!   └── connections: HashMap<name, Vec<Client>> — idle pool per device
//!
//!   checkout("spine-01") → PoolGuard(Client) — auto-checkin on drop
//! ```
//!
//! # Example
//!
//! ```rust,no_run
//! use rustnetconf::pool::{DevicePool, DeviceConfig};
//! use rustnetconf::transport::ssh::SshAuth;
//! use rustnetconf::Datastore;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let pool = DevicePool::builder()
//!     .max_connections(50)
//!     .add_device("spine-01", DeviceConfig {
//!         host: "10.0.0.1:830".into(),
//!         username: "admin".into(),
//!         auth: SshAuth::Password(zeroize::Zeroizing::new("secret".into())),
//!         vendor: None,
//!     })
//!     .build();
//!
//! let mut conn = pool.checkout("spine-01").await?;
//! let config = conn.get_config(Datastore::Running).await?;
//! // conn auto-returned to pool when dropped
//! # Ok(())
//! # }
//! ```

use std::collections::HashMap;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, Semaphore, SemaphorePermit};

use crate::client::Client;
use crate::error::NetconfError;
use crate::transport::ssh::SshAuth;
use crate::vendor::VendorProfile;

/// Per-device connection settings registered with a [`DevicePool`].
///
/// One `DeviceConfig` is stored for each named device; every new session the
/// pool opens for that device is built from these values.
pub struct DeviceConfig {
    /// Target address, either `host:port` or a bare `host` (port 830 is the default).
    pub host: String,
    /// Login name used for the SSH session.
    pub username: String,
    /// How the SSH session authenticates (password, key file, or agent).
    pub auth: SshAuth,
    /// Explicit vendor profile to use; `None` means auto-detect.
    ///
    /// TODO: not yet honoured when connecting. The pool only ever hands
    /// `connect_device()` a shared `&DeviceConfig` (the map lives behind an
    /// `Arc`), so the boxed profile cannot be moved out. Two ways to wire it up:
    /// store it as `Option<Arc<dyn VendorProfile>>` and add a matching
    /// `Session::set_vendor_profile_arc()`, or give each device its own mutex so
    /// the config can be consumed on the first connect.
    pub vendor: Option<Box<dyn VendorProfile>>,
}

/// Step-by-step configuration for a [`DevicePool`]; finish with `build()`.
pub struct DevicePoolBuilder {
    /// Devices registered so far, keyed by the name later passed to `checkout`.
    devices: HashMap<String, DeviceConfig>,
    /// Number of semaphore permits the built pool will start with.
    max_connections: usize,
    /// How long `checkout` may wait for a permit before failing.
    checkout_timeout: Duration,
}

impl DevicePoolBuilder {
    /// Cap the number of connections that may be checked out at once, summed
    /// over every device.
    ///
    /// This becomes the semaphore's permit count. With `0`, no checkout can
    /// ever obtain a permit, so every `checkout()` fails once its timeout expires.
    pub fn max_connections(self, max: usize) -> Self {
        Self {
            max_connections: max,
            ..self
        }
    }

    /// Limit how long `checkout()` waits for a free permit.
    ///
    /// When every permit is held and none is released within `timeout`,
    /// `checkout()` returns an error rather than waiting indefinitely.
    pub fn checkout_timeout(self, timeout: Duration) -> Self {
        Self {
            checkout_timeout: timeout,
            ..self
        }
    }

    /// Register `config` under `name`.
    ///
    /// A second registration with the same name replaces the earlier config.
    pub fn add_device(mut self, name: &str, config: DeviceConfig) -> Self {
        let key = name.to_owned();
        self.devices.insert(key, config);
        self
    }

    /// Consume the builder and produce a pool with no idle connections yet.
    ///
    /// # Panics
    ///
    /// tokio's `Semaphore::new` panics when the permit count exceeds
    /// `Semaphore::MAX_PERMITS`, so an absurdly large `max_connections` panics here.
    pub fn build(self) -> DevicePool {
        let Self {
            devices,
            max_connections,
            checkout_timeout,
        } = self;

        DevicePool {
            devices: Arc::new(devices),
            semaphore: Arc::new(Semaphore::new(max_connections)),
            connections: Arc::new(Mutex::new(HashMap::new())),
            checkout_timeout,
        }
    }
}

/// Async connection pool for multi-device NETCONF operations.
///
/// Construct one with [`DevicePool::builder`]. Checked-out connections are
/// limited by a shared semaphore, and healthy connections are kept per device
/// for reuse after their [`PoolGuard`] is dropped.
pub struct DevicePool {
    /// Read-only registry of device settings, keyed by device name.
    devices: Arc<HashMap<String, DeviceConfig>>,
    /// One permit per concurrently checked-out connection.
    semaphore: Arc<Semaphore>,
    /// Idle clients per device; shared with outstanding guards so they can check in.
    connections: Arc<Mutex<HashMap<String, Vec<Client>>>>,
    /// Maximum time `checkout` waits for a permit.
    checkout_timeout: Duration,
}

impl DevicePool {
    /// Create a new pool builder.
    pub fn builder() -> DevicePoolBuilder {
        DevicePoolBuilder {
            devices: HashMap::new(),
            max_connections: 10,
            checkout_timeout: Duration::from_secs(30),
        }
    }

    /// Check out a connection to a named device.
    ///
    /// Returns a `PoolGuard` that dereferences to `Client` and automatically
    /// returns the connection to the pool when dropped (if still healthy).
    ///
    /// Blocks until a semaphore permit is available, up to `checkout_timeout`.
    pub async fn checkout(&self, device_name: &str) -> Result<PoolGuard<'_>, NetconfError> {
        let config = self.devices.get(device_name).ok_or_else(|| {
            crate::error::TransportError::Connect(format!(
                "unknown device in pool: '{device_name}'"
            ))
        })?;

        // Acquire semaphore permit with timeout
        let permit = tokio::time::timeout(self.checkout_timeout, self.semaphore.acquire())
            .await
            .map_err(|_| {
                crate::error::TransportError::Connect(format!(
                    "pool checkout timeout after {:?} — all connections in use",
                    self.checkout_timeout,
                ))
            })?
            .map_err(|_| {
                crate::error::TransportError::Connect("pool semaphore closed".to_string())
            })?;

        // Try to reuse an idle connection, skipping any dead sessions
        {
            let mut conns = self.connections.lock().await;
            if let Some(pool) = conns.get_mut(device_name) {
                while let Some(client) = pool.pop() {
                    if client.session_alive() {
                        return Ok(PoolGuard {
                            client: Some(client),
                            device_name: device_name.to_string(),
                            connections: Arc::clone(&self.connections),
                            _permit: permit,
                        });
                    }
                    tracing::warn!(
                        device = device_name,
                        "discarding dead idle connection from pool"
                    );
                }
            }
        }

        // No idle connection — create a new one
        let client = connect_device(config).await?;

        Ok(PoolGuard {
            client: Some(client),
            device_name: device_name.to_string(),
            connections: Arc::clone(&self.connections),
            _permit: permit,
        })
    }

    /// Get the names of all devices in the pool.
    pub fn device_names(&self) -> Vec<String> {
        self.devices.keys().cloned().collect()
    }

    /// Get the number of available semaphore permits.
    pub fn available_connections(&self) -> usize {
        self.semaphore.available_permits()
    }
}

/// Exclusive handle to a connection borrowed from a [`DevicePool`].
///
/// Derefs to [`Client`] so RPCs can be issued directly. On drop, a connection
/// whose session is still alive is put back on the device's idle list, and a
/// dead one is thrown away. The held semaphore permit is released either way.
pub struct PoolGuard<'a> {
    /// The borrowed client; becomes `None` after `discard()` or during drop.
    client: Option<Client>,
    /// Key under which the client is returned to the idle map.
    device_name: String,
    /// Shared idle map of the owning pool.
    connections: Arc<Mutex<HashMap<String, Vec<Client>>>>,
    /// Keeps one pool permit occupied for as long as the guard exists.
    _permit: SemaphorePermit<'a>,
}

impl<'a> PoolGuard<'a> {
    /// Throw this connection away instead of returning it to the pool.
    ///
    /// The wrapped `Client` is dropped immediately, so nothing is checked back
    /// in when the guard itself is dropped.
    ///
    /// # Warning
    ///
    /// Once discarded, the guard no longer holds a client, and any further
    /// `Deref`/`DerefMut` access panics. Drop the guard right away and do not
    /// use it for more RPCs.
    pub fn discard(&mut self) {
        drop(self.client.take());
    }
}

impl<'a> Deref for PoolGuard<'a> {
    type Target = Client;

    /// Shared access to the borrowed client.
    ///
    /// Panics with "PoolGuard client already taken" if `discard()` was called.
    fn deref(&self) -> &Client {
        match &self.client {
            Some(client) => client,
            None => panic!("PoolGuard client already taken"),
        }
    }
}

impl<'a> DerefMut for PoolGuard<'a> {
    /// Mutable access to the borrowed client (needed for RPC calls).
    ///
    /// Panics with "PoolGuard client already taken" if `discard()` was called.
    fn deref_mut(&mut self) -> &mut Client {
        let Some(client) = self.client.as_mut() else {
            panic!("PoolGuard client already taken");
        };
        client
    }
}

impl<'a> Drop for PoolGuard<'a> {
    /// Check a healthy connection back into the pool; discard a dead one.
    ///
    /// The semaphore permit is a field of the guard, so it is released only
    /// after this method returns. The client is therefore pushed back
    /// synchronously via `try_lock()` whenever the idle map is uncontended.
    /// That way, a `checkout()` that picks up the freed permit immediately finds
    /// the connection, instead of racing a spawned task and opening a duplicate
    /// session. `try_lock()` needs no runtime, so the fast path also works
    /// outside Tokio.
    ///
    /// Only if the map is momentarily locked do we fall back to an async
    /// check-in on the current runtime. Without a runtime, the connection is
    /// dropped.
    fn drop(&mut self) {
        // `discard()` already released the client: nothing to return.
        let Some(client) = self.client.take() else {
            return;
        };

        if !client.session_alive() {
            tracing::warn!(
                device = %self.device_name,
                "dropping dead connection — not returning to pool"
            );
            return;
        }

        // The guard is going away, so its name can be moved out instead of cloned.
        let device_name = std::mem::take(&mut self.device_name);

        match self.connections.try_lock() {
            // Fast path: return the client before the permit is released.
            Ok(mut conns) => conns.entry(device_name).or_default().push(client),
            // Contended: hand the check-in to the runtime, if there is one.
            Err(_) => match tokio::runtime::Handle::try_current() {
                Ok(handle) => {
                    let connections = Arc::clone(&self.connections);
                    handle.spawn(async move {
                        let mut conns = connections.lock().await;
                        conns.entry(device_name).or_default().push(client);
                    });
                }
                Err(_) => {
                    tracing::warn!(
                        device = %device_name,
                        "no Tokio runtime on drop — connection not returned to pool"
                    );
                }
            },
        }
        // `_permit` is dropped after this body runs, releasing the pool slot.
    }
}

/// Open a new NETCONF session for one device.
///
/// Starts from the device's host and username, applies exactly one SSH
/// authentication method from `config.auth`, and then connects. The `vendor`
/// field is not consulted here.
async fn connect_device(config: &DeviceConfig) -> Result<Client, NetconfError> {
    let base = Client::connect(&config.host).username(&config.username);

    let authenticated = match &config.auth {
        SshAuth::Password(pass) => base.password(pass.as_str()),
        SshAuth::KeyFile { path, passphrase } => {
            let keyed = base.key_file(path);
            match passphrase {
                Some(pass) => keyed.key_passphrase(pass.as_str()),
                None => keyed,
            }
        }
        SshAuth::Agent => base.ssh_agent(),
    };

    authenticated.connect().await
}