use crate::PluginManager;
use crate::plugin_types::{PluginConnection, Plugins};
use async_trait::async_trait;
use genja_core::inventory::{
Connection, ConnectionFactory, ConnectionKey, ResolvedConnectionParams,
};
use std::sync::Arc;
use tokio::sync::Mutex;
/// Adapts a plugin-provided [`PluginConnection`] to the core [`Connection`] trait.
///
/// The adapter keeps its own `alive` flag. It starts out `false` and is updated by the
/// adapter's `open` and `close` methods.
#[derive(Debug)]
#[doc(hidden)]
pub struct PluginConnectionAdapter {
    /// The wrapped plugin connection that calls are forwarded to.
    inner: Box<dyn PluginConnection>,
    /// Liveness as tracked by this adapter (not queried from `inner`).
    alive: bool,
}

impl PluginConnectionAdapter {
    /// Wraps `inner` in an adapter that is initially reported as not alive.
    fn new(inner: Box<dyn PluginConnection>) -> Self {
        let alive = false;
        Self { inner, alive }
    }

    /// Borrows the wrapped plugin connection.
    #[doc(hidden)]
    pub fn inner_plugin_connection(&self) -> &dyn PluginConnection {
        &*self.inner
    }
}
#[async_trait]
impl Connection for PluginConnectionAdapter {
    /// Asks the wrapped plugin to create a connection for `key` and wraps the result in a
    /// fresh adapter. The new adapter starts out not alive.
    fn create(&self, key: &ConnectionKey) -> Box<dyn Connection> {
        let instance = self.inner.create(key);
        Box::new(PluginConnectionAdapter::new(instance))
    }

    /// Reports the adapter's liveness flag. The flag reflects the most recent `open`
    /// outcome and is cleared by `close`.
    ///
    /// NOTE(review): this does not consult `self.inner.is_alive()`. If the plugin's own
    /// liveness turns false (e.g. the remote drops the session), the adapter still reports
    /// alive until `close` or a failed `open`. Confirm whether delegating is desired.
    fn is_alive(&self) -> bool {
        self.alive
    }

    /// Opens the wrapped plugin connection with `params`.
    ///
    /// # Errors
    ///
    /// Returns the plugin's error string unchanged if opening fails. In that case the
    /// adapter is marked not alive.
    async fn open(&mut self, params: &ResolvedConnectionParams) -> Result<(), String> {
        let result = self.inner.open(params).await;
        // Record the outcome of *this* open. Only setting the flag on success would let a
        // failed re-open keep a stale `alive == true` from an earlier successful open.
        self.alive = result.is_ok();
        result
    }

    /// Forwards `command` to the wrapped plugin connection and returns its result as-is.
    async fn execute_command(&mut self, command: &str) -> Result<String, String> {
        self.inner.execute_command(command).await
    }

    /// Closes the wrapped plugin connection, marks the adapter not alive, and returns the
    /// key reported by the plugin's `close`.
    fn close(&mut self) -> ConnectionKey {
        let key = self.inner.close();
        self.alive = false;
        key
    }
}
/// Builds a [`ConnectionFactory`] backed by the plugins registered in `plugins`.
///
/// For each [`ConnectionKey`], the factory looks up the plugin named by `key.plugin_name`.
///
/// * It yields `None` when no such plugin is registered.
/// * It yields `None` when the plugin is not a [`Plugins::Connection`].
/// * Otherwise it asks the plugin to create a connection for `key` and returns it. The
///   connection is wrapped in a [`PluginConnectionAdapter`] behind a Tokio [`Mutex`].
///
/// The returned connection has not been opened.
pub fn build_connection_factory(plugins: Arc<PluginManager>) -> Arc<ConnectionFactory> {
    Arc::new(move |key: &ConnectionKey| {
        let plugin = plugins.get_plugin(&key.plugin_name)?;
        let Plugins::Connection(connection) = plugin else {
            // Registered, but not a connection plugin (e.g. a runner).
            return None;
        };
        let adapter = PluginConnectionAdapter::new(connection.create(key));
        Some(Arc::new(Mutex::new(adapter)) as Arc<Mutex<dyn Connection>>)
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugin_types::{Plugin, PluginConnection, PluginRunner};
    use genja_core::inventory::Connection;
    use genja_core::inventory::ConnectionManager;
    use genja_core::task::Tasks;
    use std::future::Future;
    use tokio::runtime::Builder;

    /// Drives `future` to completion on a fresh single-threaded Tokio runtime.
    fn run_async<F: Future>(future: F) -> F::Output {
        Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("test runtime should build")
            .block_on(future)
    }

    /// Connection plugin test double.
    ///
    /// It tracks liveness and returns the key it was created with from `close`.
    #[derive(Debug)]
    struct TestConnection {
        name: &'static str,
        key: ConnectionKey,
        alive: bool,
    }

    impl TestConnection {
        fn new(name: &'static str, key: ConnectionKey) -> Self {
            Self {
                name,
                key,
                alive: false,
            }
        }
    }

    impl Plugin for TestConnection {
        fn name(&self) -> String {
            self.name.to_string()
        }
    }

    #[async_trait]
    impl PluginConnection for TestConnection {
        fn create(&self, key: &ConnectionKey) -> Box<dyn PluginConnection> {
            Box::new(Self::new(self.name, key.clone()))
        }

        async fn open(&mut self, _params: &ResolvedConnectionParams) -> Result<(), String> {
            self.alive = true;
            Ok(())
        }

        fn close(&mut self) -> ConnectionKey {
            self.alive = false;
            self.key.clone()
        }

        fn is_alive(&self) -> bool {
            self.alive
        }
    }

    /// Runner plugin used to check that non-connection plugins are rejected.
    #[derive(Debug)]
    struct DummyRunner {
        name: &'static str,
    }

    impl Plugin for DummyRunner {
        fn name(&self) -> String {
            self.name.to_string()
        }
    }

    #[async_trait]
    impl PluginRunner for DummyRunner {
        async fn run_task(
            &self,
            _task: &genja_core::task::TaskDefinition,
            _hosts: &genja_core::inventory::Hosts,
            _connection_resolver: Option<
                std::sync::Arc<dyn genja_core::task::TaskConnectionResolver>,
            >,
            _runner_config: &genja_core::settings::RunnerConfig,
            _max_depth: usize,
        ) -> Result<genja_core::task::TaskResults, genja_core::GenjaError> {
            Ok(genja_core::task::TaskResults::new("runner"))
        }

        async fn run_tasks(
            &self,
            _tasks: &Tasks,
            _hosts: &genja_core::inventory::Hosts,
            _connection_resolver: Option<
                std::sync::Arc<dyn genja_core::task::TaskConnectionResolver>,
            >,
            _runner_config: &genja_core::settings::RunnerConfig,
            _max_depth: usize,
        ) -> Result<Vec<genja_core::task::TaskResults>, genja_core::GenjaError> {
            Ok(Vec::new())
        }
    }

    /// Connection parameters shared by the tests below.
    fn default_params() -> ResolvedConnectionParams {
        ResolvedConnectionParams {
            hostname: "host1".to_string(),
            port: Some(22),
            username: Some("user".to_string()),
            password: Some("pass".to_string()),
            platform: Some("linux".to_string()),
            extras: None,
        }
    }

    #[test]
    fn adapter_open_close_updates_alive_and_returns_key() {
        let key = ConnectionKey::new("host1", "ssh");
        let plugin = TestConnection::new("ssh", key.clone());
        let mut adapter = PluginConnectionAdapter::new(Box::new(plugin));
        assert!(!adapter.is_alive());
        run_async(adapter.open(&default_params())).unwrap();
        assert!(adapter.is_alive());
        let closed_key = adapter.close();
        assert_eq!(closed_key, key);
        assert!(!adapter.is_alive());
    }

    #[test]
    fn adapter_create_uses_plugin_create_and_starts_dead() {
        let key = ConnectionKey::new("host1", "ssh");
        let plugin = TestConnection::new("ssh", key.clone());
        let adapter = PluginConnectionAdapter::new(Box::new(plugin));
        let new_key = ConnectionKey::new("host2", "ssh");
        let new_conn = adapter.create(&new_key);
        assert!(!new_conn.is_alive());
    }

    #[test]
    fn factory_returns_none_for_missing_or_non_connection_plugins() {
        let manager = Arc::new(PluginManager::new());
        let factory = build_connection_factory(Arc::clone(&manager));
        let key = ConnectionKey::new("host1", "ssh");
        assert!(factory(&key).is_none());
        let mut manager = PluginManager::new();
        manager.register_plugin(Plugins::Runner(Box::new(DummyRunner { name: "runner" })));
        let factory = build_connection_factory(Arc::new(manager));
        let key = ConnectionKey::new("host1", "runner");
        assert!(factory(&key).is_none());
    }

    #[test]
    fn factory_returns_adapter_for_connection_plugins() {
        let key = ConnectionKey::new("host1", "ssh");
        let plugin = TestConnection::new("ssh", key.clone());
        let mut manager = PluginManager::new();
        manager.register_plugin(Plugins::Connection(Box::new(plugin)));
        let factory = build_connection_factory(Arc::new(manager));
        let connection = factory(&key).expect("expected connection plugin");
        {
            let mut guard = run_async(connection.lock());
            assert!(!guard.is_alive());
            run_async(guard.open(&default_params())).unwrap();
            assert!(guard.is_alive());
            let closed_key = guard.close();
            assert_eq!(closed_key, key);
            assert!(!guard.is_alive());
        }
    }

    #[test]
    fn manager_counters_increment_on_open_and_close() {
        let key = ConnectionKey::new("host1", "ssh");
        let plugin = TestConnection::new("ssh", key.clone());
        let mut manager = PluginManager::new();
        manager.register_plugin(Plugins::Connection(Box::new(plugin)));
        let factory = build_connection_factory(Arc::new(manager));
        let connection_manager = ConnectionManager::with_connection_factory(factory);
        let params = default_params();
        let connection = run_async(connection_manager.open_connection(&key, &params))
            .unwrap()
            .unwrap();
        let counters = connection_manager.connection_counters_for("ssh").unwrap();
        assert_eq!(counters.create_calls, 1);
        assert_eq!(counters.open_calls, 1);
        assert_eq!(counters.close_calls, 0);
        drop(connection);
        connection_manager.close_connection(&key);
        let counters = connection_manager.connection_counters_for("ssh").unwrap();
        assert_eq!(counters.create_calls, 1);
        assert_eq!(counters.open_calls, 1);
        assert_eq!(counters.close_calls, 1);
    }

    #[test]
    fn open_connection_twice_does_not_double_count_open() {
        let key = ConnectionKey::new("host1", "ssh");
        let plugin = TestConnection::new("ssh", key.clone());
        let mut manager = PluginManager::new();
        manager.register_plugin(Plugins::Connection(Box::new(plugin)));
        let factory = build_connection_factory(Arc::new(manager));
        let connection_manager = ConnectionManager::with_connection_factory(factory);
        let params = default_params();
        run_async(connection_manager.open_connection(&key, &params))
            .unwrap()
            .unwrap();
        let counters_after_first = connection_manager.connection_counters_for("ssh").unwrap();
        assert_eq!(counters_after_first.create_calls, 1);
        assert_eq!(counters_after_first.open_calls, 1);
        run_async(connection_manager.open_connection(&key, &params))
            .unwrap()
            .unwrap();
        let counters_after_second = connection_manager.connection_counters_for("ssh").unwrap();
        assert_eq!(counters_after_second.create_calls, 1);
        assert_eq!(counters_after_second.open_calls, 1);
    }

    #[test]
    fn open_connection_errors_when_factory_missing() {
        let connection_manager = ConnectionManager::default();
        let key = ConnectionKey::new("host1", "ssh");
        let params = default_params();
        let err = run_async(connection_manager.open_connection(&key, &params)).unwrap_err();
        assert_eq!(err, "connection factory not set");
    }

    #[test]
    fn connection_counters_snapshot_tracks_multiple_types() {
        let key_ssh = ConnectionKey::new("host1", "ssh");
        let key_telnet = ConnectionKey::new("host2", "telnet");
        let plugin_ssh = TestConnection::new("ssh", key_ssh.clone());
        let plugin_telnet = TestConnection::new("telnet", key_telnet.clone());
        let mut manager = PluginManager::new();
        manager.register_plugin(Plugins::Connection(Box::new(plugin_ssh)));
        manager.register_plugin(Plugins::Connection(Box::new(plugin_telnet)));
        let factory = build_connection_factory(Arc::new(manager));
        let connection_manager = ConnectionManager::with_connection_factory(factory);
        let params = default_params();
        run_async(connection_manager.open_connection(&key_ssh, &params))
            .unwrap()
            .unwrap();
        run_async(connection_manager.open_connection(&key_telnet, &params))
            .unwrap()
            .unwrap();
        let snapshot = connection_manager.connection_counters_snapshot();
        let ssh = snapshot.get("ssh").copied().unwrap();
        let telnet = snapshot.get("telnet").copied().unwrap();
        assert_eq!(ssh.create_calls, 1);
        assert_eq!(ssh.open_calls, 1);
        assert_eq!(ssh.close_calls, 0);
        assert_eq!(telnet.create_calls, 1);
        assert_eq!(telnet.open_calls, 1);
        assert_eq!(telnet.close_calls, 0);
    }

    #[test]
    fn factory_connection_is_thread_safe() {
        let key = ConnectionKey::new("host1", "ssh");
        let plugin = TestConnection::new("ssh", key.clone());
        let mut manager = PluginManager::new();
        manager.register_plugin(Plugins::Connection(Box::new(plugin)));
        let factory = build_connection_factory(Arc::new(manager));
        let manager = Arc::new(ConnectionManager::with_connection_factory(factory));
        // Two worker threads plus the main thread rendezvous so both opens race.
        let barrier = Arc::new(std::sync::Barrier::new(3));
        let params = Arc::new(default_params());

        let barrier_a = Arc::clone(&barrier);
        let params_a = Arc::clone(&params);
        let manager_a = Arc::clone(&manager);
        let key_a = key.clone();
        let thread_a = std::thread::spawn(move || {
            barrier_a.wait();
            run_async(manager_a.open_connection(&key_a, &params_a))
                .unwrap()
                .unwrap();
        });

        let barrier_b = Arc::clone(&barrier);
        let params_b = Arc::clone(&params);
        let manager_b = Arc::clone(&manager);
        let key_b = key.clone();
        let thread_b = std::thread::spawn(move || {
            barrier_b.wait();
            run_async(manager_b.open_connection(&key_b, &params_b))
                .unwrap()
                .unwrap();
        });

        barrier.wait();
        thread_a.join().unwrap();
        thread_b.join().unwrap();
        manager.close_connection(&key);
        let counters = manager.connection_counters_for("ssh").unwrap();
        assert_eq!(counters.create_calls, 1);
        assert_eq!(counters.open_calls, 1);
        assert_eq!(counters.close_calls, 1);
    }

    #[test]
    fn adapter_open_error_keeps_alive_false() {
        #[derive(Debug)]
        struct FailingConnection;

        impl Plugin for FailingConnection {
            fn name(&self) -> String {
                "fail".to_string()
            }
        }

        #[async_trait]
        impl PluginConnection for FailingConnection {
            fn create(&self, _key: &ConnectionKey) -> Box<dyn PluginConnection> {
                Box::new(Self)
            }

            async fn open(&mut self, _params: &ResolvedConnectionParams) -> Result<(), String> {
                Err("boom".to_string())
            }

            fn close(&mut self) -> ConnectionKey {
                ConnectionKey::new("host1", "fail")
            }

            fn is_alive(&self) -> bool {
                false
            }
        }

        let mut adapter = PluginConnectionAdapter::new(Box::new(FailingConnection));
        assert!(!adapter.is_alive());
        let err = run_async(adapter.open(&default_params())).unwrap_err();
        assert_eq!(err, "boom");
        assert!(!adapter.is_alive());
    }

    #[test]
    fn adapter_create_can_be_called_multiple_times() {
        let key = ConnectionKey::new("host1", "ssh");
        let plugin = TestConnection::new("ssh", key);
        let adapter = PluginConnectionAdapter::new(Box::new(plugin));
        let key_a = ConnectionKey::new("host-a", "ssh");
        let key_b = ConnectionKey::new("host-b", "ssh");
        let conn_a = adapter.create(&key_a);
        let conn_b = adapter.create(&key_b);
        assert!(!conn_a.is_alive());
        assert!(!conn_b.is_alive());
    }
}