use std::collections::HashMap;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, Semaphore, SemaphorePermit};
use crate::client::Client;
use crate::error::NetconfError;
use crate::transport::ssh::SshAuth;
use crate::vendor::VendorProfile;
/// Connection settings for one device registered with a [`DevicePool`].
///
/// The pool uses these settings to open a new client session whenever a
/// checkout finds no idle client for the device.
pub struct DeviceConfig {
/// Host handed to `Client::connect`.
pub host: String,
/// SSH user name handed to the client builder's `username`.
pub username: String,
/// SSH authentication method: password, key file (with optional passphrase) or agent.
pub auth: SshAuth,
/// Optional vendor profile for the device.
/// NOTE(review): not read when the pool opens connections in this file —
/// confirm whether it is meant to be applied to the client.
pub vendor: Option<Box<dyn VendorProfile>>,
}
/// Builder for [`DevicePool`]; obtain one with [`DevicePool::builder`].
pub struct DevicePoolBuilder {
/// Devices registered through `add_device`, keyed by the name used in `checkout`.
devices: HashMap<String, DeviceConfig>,
/// Number of semaphore permits, i.e. the cap on simultaneous checkouts.
max_connections: usize,
/// How long `checkout` waits for a free permit before failing.
checkout_timeout: Duration,
}
impl DevicePoolBuilder {
    /// Sets how many clients may be checked out at the same time, across all
    /// devices (the default from [`DevicePool::builder`] is 10).
    ///
    /// This bounds concurrent checkouts, not the number of idle clients the
    /// pool keeps around. A value of `0` makes every [`DevicePool::checkout`]
    /// wait for the full checkout timeout and then fail.
    #[must_use]
    pub fn max_connections(mut self, max: usize) -> Self {
        self.max_connections = max;
        self
    }

    /// Sets how long [`DevicePool::checkout`] waits for a free permit before
    /// giving up (the default from [`DevicePool::builder`] is 30 seconds).
    ///
    /// The timeout covers only waiting for a permit; opening a new connection
    /// after the permit is obtained is not bounded by it.
    #[must_use]
    pub fn checkout_timeout(mut self, timeout: Duration) -> Self {
        self.checkout_timeout = timeout;
        self
    }

    /// Registers `config` under `name`, the key later passed to
    /// [`DevicePool::checkout`].
    ///
    /// Registering the same name twice replaces the earlier configuration.
    #[must_use]
    pub fn add_device(mut self, name: &str, config: DeviceConfig) -> Self {
        self.devices.insert(name.to_string(), config);
        self
    }

    /// Consumes the builder and creates the pool.
    ///
    /// No connections are opened here; clients are created lazily by
    /// [`DevicePool::checkout`].
    ///
    /// # Panics
    ///
    /// Panics if `max_connections` exceeds `tokio::sync::Semaphore::MAX_PERMITS`
    /// (tokio's `Semaphore::new` panics in that case).
    #[must_use]
    pub fn build(self) -> DevicePool {
        DevicePool {
            devices: Arc::new(self.devices),
            semaphore: Arc::new(Semaphore::new(self.max_connections)),
            connections: Arc::new(Mutex::new(HashMap::new())),
            checkout_timeout: self.checkout_timeout,
        }
    }
}
/// A pool of [`Client`] sessions for a fixed set of named devices.
///
/// A semaphore limits how many clients are checked out at once (across all
/// devices). Clients handed back by a dropped [`PoolGuard`] are kept in a
/// per-device idle list and reused by later checkouts. The semaphore does not
/// bound how many idle clients are retained.
pub struct DevicePool {
/// Device registry, fixed when the pool is built.
devices: Arc<HashMap<String, DeviceConfig>>,
/// One permit per allowed concurrent checkout.
semaphore: Arc<Semaphore>,
/// Idle clients per device name; shared with outstanding guards so they can
/// hand their client back when dropped.
connections: Arc<Mutex<HashMap<String, Vec<Client>>>>,
/// Maximum time `checkout` waits for a permit.
checkout_timeout: Duration,
}
impl DevicePool {
pub fn builder() -> DevicePoolBuilder {
DevicePoolBuilder {
devices: HashMap::new(),
max_connections: 10,
checkout_timeout: Duration::from_secs(30),
}
}
pub async fn checkout(&self, device_name: &str) -> Result<PoolGuard<'_>, NetconfError> {
let config = self.devices.get(device_name).ok_or_else(|| {
crate::error::TransportError::Connect(format!(
"unknown device in pool: '{device_name}'"
))
})?;
let permit = tokio::time::timeout(self.checkout_timeout, self.semaphore.acquire())
.await
.map_err(|_| {
crate::error::TransportError::Connect(format!(
"pool checkout timeout after {:?} — all connections in use",
self.checkout_timeout,
))
})?
.map_err(|_| {
crate::error::TransportError::Connect("pool semaphore closed".to_string())
})?;
{
let mut conns = self.connections.lock().await;
if let Some(pool) = conns.get_mut(device_name) {
if let Some(client) = pool.pop() {
return Ok(PoolGuard {
client: Some(client),
device_name: device_name.to_string(),
connections: Arc::clone(&self.connections),
_permit: permit,
});
}
}
}
let client = connect_device(config).await?;
Ok(PoolGuard {
client: Some(client),
device_name: device_name.to_string(),
connections: Arc::clone(&self.connections),
_permit: permit,
})
}
pub fn device_names(&self) -> Vec<String> {
self.devices.keys().cloned().collect()
}
pub fn available_connections(&self) -> usize {
self.semaphore.available_permits()
}
}
/// A client checked out of a [`DevicePool`].
///
/// Dereferences to [`Client`]. While the guard is alive it holds one of the
/// pool's checkout permits. When dropped, the client (unless removed with
/// `discard`) is handed back to the pool's idle list for its device.
#[must_use = "if unused, the client is handed straight back to the pool"]
pub struct PoolGuard<'a> {
    /// The checked-out client; `None` after `discard`.
    client: Option<Client>,
    /// Idle-list key under which the client is handed back.
    device_name: String,
    /// Idle lists shared with the owning pool.
    connections: Arc<Mutex<HashMap<String, Vec<Client>>>>,
    /// Checkout permit; released when the guard's fields are dropped.
    _permit: SemaphorePermit<'a>,
}
impl<'a> PoolGuard<'a> {
    /// Drops the underlying client now so it is never handed back to the pool.
    ///
    /// The checkout permit is still held until the guard itself is dropped.
    /// Dereferencing the guard after this call panics.
    pub fn discard(&mut self) {
        drop(self.client.take());
    }
}
impl<'a> Deref for PoolGuard<'a> {
    type Target = Client;

    /// Borrows the checked-out client.
    ///
    /// # Panics
    ///
    /// Panics if the client has been removed with `discard`.
    fn deref(&self) -> &Client {
        match &self.client {
            Some(client) => client,
            None => panic!("PoolGuard client already taken"),
        }
    }
}
impl<'a> DerefMut for PoolGuard<'a> {
    /// Mutably borrows the checked-out client.
    ///
    /// # Panics
    ///
    /// Panics if the client has been removed with `discard`.
    fn deref_mut(&mut self) -> &mut Client {
        self.client
            .as_mut()
            .unwrap_or_else(|| panic!("PoolGuard client already taken"))
    }
}
impl<'a> Drop for PoolGuard<'a> {
    /// Hands the client (if not discarded) back to the pool's idle list.
    ///
    /// When the idle-list lock is free, the client is pushed synchronously. It is
    /// then back in the pool *before* the `_permit` field is dropped, which happens
    /// after this method returns. If the hand-back were deferred to a spawned
    /// task, a waiting `checkout` could take the freed permit and find no idle
    /// client. It would then open a redundant connection.
    ///
    /// Only when the lock is contended is the hand-back done on a spawned task,
    /// and only if a Tokio runtime is available. Without a runtime the client is
    /// simply dropped instead of panicking, which is what `tokio::spawn` does
    /// outside a runtime.
    fn drop(&mut self) {
        let client = match self.client.take() {
            Some(client) => client,
            // Discarded: nothing to hand back.
            None => return,
        };
        // The guard is going away, so its name can be moved out without cloning.
        let device_name = std::mem::take(&mut self.device_name);

        match self.connections.try_lock() {
            Ok(mut idle_lists) => idle_lists.entry(device_name).or_default().push(client),
            Err(_) => {
                // Someone else (e.g. a concurrent `checkout`) holds the lock. Fall
                // back to an async hand-back; the permit may be released before
                // this task runs.
                if let Ok(handle) = tokio::runtime::Handle::try_current() {
                    let connections = Arc::clone(&self.connections);
                    handle.spawn(async move {
                        connections
                            .lock()
                            .await
                            .entry(device_name)
                            .or_default()
                            .push(client);
                    });
                }
                // No runtime: `client` is dropped here rather than reused.
            }
        }
    }
}
/// Opens a new client session for `config`, using its host, user name and
/// SSH authentication settings.
///
/// `config.vendor` is not consulted here.
///
/// # Errors
///
/// Propagates whatever error the client builder's `connect` returns.
async fn connect_device(config: &DeviceConfig) -> Result<Client, NetconfError> {
    let base = Client::connect(&config.host).username(&config.username);
    let authenticated = match &config.auth {
        SshAuth::Password(pass) => base.password(pass),
        SshAuth::KeyFile { path, passphrase } => {
            let with_key = base.key_file(path);
            match passphrase {
                Some(pass) => with_key.key_passphrase(pass),
                None => with_key,
            }
        }
        SshAuth::Agent => base.ssh_agent(),
    };
    authenticated.connect().await
}