use super::*;
impl SshConnectionManager {
pub fn new() -> Self {
let cache = Cache::builder()
.max_capacity(100)
.time_to_idle(Duration::from_secs(5 * 60)) .build();
Self { cache }
}
pub async fn get_with_context(
&self,
request: ConnectionRequest,
context: ExecutionContext,
) -> Result<mpsc::Sender<CmdJob>, ConnectError> {
self.get_with_request_and_recording(request, context.security_options, None)
.await
}
pub async fn execute_command_with_context(
&self,
request: ConnectionRequest,
command: Command,
context: ExecutionContext,
) -> Result<Output, ConnectError> {
let result = self
.execute_operation_with_context(request, SessionOperation::from(command), context)
.await
.map_err(|err| err.into_parts().0)?;
match result.steps.len() {
1 => Ok(result
.steps
.into_iter()
.next()
.expect("single step output should exist")
.into_output()),
count => Err(ConnectError::InternalServerError(format!(
"expected one output for command execution, got {count}"
))),
}
}
pub async fn execute_operation_with_context(
&self,
request: ConnectionRequest,
operation: SessionOperation,
context: ExecutionContext,
) -> Result<SessionOperationOutput, SessionOperationExecutionError> {
let device_addr = request.device_addr();
let sys = context.sys.clone();
self.get_with_request_and_recording(request, context.security_options, None)
.await
.map_err(|err| {
SessionOperationExecutionError::new(
err,
SessionOperationOutput {
success: false,
steps: Vec::new(),
},
)
})?;
let (_sender, client) = self.cache.get(&device_addr).await.ok_or_else(|| {
SessionOperationExecutionError::new(
ConnectError::InternalServerError("connection cache miss".to_string()),
SessionOperationOutput {
success: false,
steps: Vec::new(),
},
)
})?;
let mut client_guard = client.write().await;
client_guard
.execute_operation_detailed(&operation, sys.as_ref())
.await
.map_err(|err| {
let (error, partial_output) = err.into_parts();
SessionOperationExecutionError::new(error, partial_output)
})
}
pub async fn execute_command_flow_with_context(
&self,
request: ConnectionRequest,
flow: CommandFlow,
context: ExecutionContext,
) -> Result<CommandFlowOutput, ConnectError> {
self.execute_operation_with_context(request, SessionOperation::from(flow), context)
.await
.map(|output| output.into_command_flow_output())
.map_err(|err| err.into_parts().0)
}
pub async fn execute_tx_block_with_context(
&self,
request: ConnectionRequest,
block: TxBlock,
context: ExecutionContext,
) -> Result<TxResult, ConnectError> {
let device_addr = request.device_addr();
let sys = context.sys.clone();
self.get_with_request_and_recording(request, context.security_options, None)
.await?;
let (_sender, client) = self.cache.get(&device_addr).await.ok_or_else(|| {
ConnectError::InternalServerError("connection cache miss".to_string())
})?;
let mut client_guard = client.write().await;
client_guard.execute_tx_block(&block, sys.as_ref()).await
}
pub async fn execute_tx_workflow_with_context(
&self,
request: ConnectionRequest,
workflow: TxWorkflow,
context: ExecutionContext,
) -> Result<TxWorkflowResult, ConnectError> {
let device_addr = request.device_addr();
let sys = context.sys.clone();
self.get_with_request_and_recording(request, context.security_options, None)
.await?;
let (_sender, client) = self.cache.get(&device_addr).await.ok_or_else(|| {
ConnectError::InternalServerError("connection cache miss".to_string())
})?;
let mut client_guard = client.write().await;
client_guard
.execute_tx_workflow(&workflow, sys.as_ref())
.await
}
pub async fn upload_file_with_context(
&self,
request: ConnectionRequest,
upload: FileUploadRequest,
context: ExecutionContext,
) -> Result<(), ConnectError> {
let device_addr = request.device_addr();
self.get_with_request_and_recording(request, context.security_options, None)
.await?;
let (_sender, client) = self.cache.get(&device_addr).await.ok_or_else(|| {
ConnectError::InternalServerError("connection cache miss".to_string())
})?;
let mut client_guard = client.write().await;
client_guard.upload_file(&upload).await
}
pub async fn get_with_recording_and_context(
&self,
request: ConnectionRequest,
context: ExecutionContext,
) -> Result<(mpsc::Sender<CmdJob>, SessionRecorder), ConnectError> {
self.get_with_recording_level_and_context(request, context, SessionRecordLevel::Full)
.await
}
pub async fn get_with_recording_level_and_context(
&self,
request: ConnectionRequest,
context: ExecutionContext,
level: SessionRecordLevel,
) -> Result<(mpsc::Sender<CmdJob>, SessionRecorder), ConnectError> {
let recorder = SessionRecorder::new(level);
let sender = self
.get_with_request_and_recording(
request,
context.security_options,
Some(recorder.clone()),
)
.await?;
Ok((sender, recorder))
}
async fn get_with_request_and_recording(
&self,
request: ConnectionRequest,
security_options: ConnectionSecurityOptions,
recorder: Option<SessionRecorder>,
) -> Result<mpsc::Sender<CmdJob>, ConnectError> {
let device_addr = request.device_addr();
let ConnectionRequest {
user,
addr,
port,
password,
enable_password,
handler,
} = request;
if let Some((sender, client)) = self.cache.get(&device_addr).await {
debug!("Cache hit: {}", device_addr);
let client_guard = client.read().await;
if client_guard.is_connected() {
if client_guard.matches_connection_params(
&password,
&enable_password,
&handler,
&security_options,
) {
debug!("Cached connection params match, reusing: {}", device_addr);
if recorder.is_some() {
drop(client_guard);
let mut client_guard = client.write().await;
client_guard.recorder = recorder.clone();
}
return Ok(sender);
} else {
debug!(
"Cached connection params mismatch, recreating: {}",
device_addr
);
drop(client_guard);
match self
.safely_disconnect_cached_connection(&device_addr, client.clone())
.await
{
Ok(_) => debug!("Old connection safely disconnected: {}", device_addr),
Err(e) => debug!(
"Error disconnecting old connection: {} - {}",
device_addr, e
),
}
self.cache.invalidate(&device_addr).await;
}
} else {
debug!("Cached connection {} is closed. Removing.", device_addr);
self.cache.invalidate(&device_addr).await;
}
} else {
debug!("Cache miss, creating new connection for {}...", device_addr);
}
let ssh_client = SharedSshClient::new(
user,
addr,
port,
password,
enable_password,
handler,
security_options,
recorder,
)
.await?;
let client_arc = Arc::new(RwLock::new(ssh_client));
let (tx, mut rx) = mpsc::channel::<CmdJob>(32);
let client_clone = client_arc.clone();
let worker_device_addr = device_addr.clone();
tokio::spawn(async move {
loop {
if let Some(job) = rx.recv().await {
if !client_clone.read().await.is_connected() {
let _ = job.responder.send(Err(ConnectError::ConnectClosedError));
break;
}
let res = {
let mut client_guard = client_clone.write().await;
let Command {
mode,
command,
timeout,
dyn_params,
interaction,
} = job.data;
let timeout = Duration::from_secs(timeout.unwrap_or(60));
client_guard
.write_with_mode_and_timeout_using_command(
&command,
&mode,
job.sys.as_ref(),
timeout,
&dyn_params,
&interaction,
)
.await
};
let _ = job.responder.send(res);
} else {
debug!(
"Command channel closed for {}, stopping worker.",
worker_device_addr
);
break;
}
}
});
self.cache
.insert(device_addr.clone(), (tx.clone(), client_arc))
.await;
debug!("New connection for {} has been cached.", device_addr);
Ok(tx)
}
async fn safely_disconnect_cached_connection(
&self,
device_addr: &str,
client_arc: Arc<RwLock<SharedSshClient>>,
) -> Result<(), ConnectError> {
debug!("Safely disconnecting cached connection: {}", device_addr);
let mut client_guard = client_arc.write().await;
if !client_guard.is_connected() {
debug!("Connection {} already disconnected, skipping", device_addr);
return Ok(());
}
match client_guard.close().await {
Ok(_) => {
debug!("Connection {} safely closed", device_addr);
Ok(())
}
Err(e) => {
debug!("Error closing connection {}: {}", device_addr, e);
Ok(())
}
}
}
}
impl Default for SshConnectionManager {
fn default() -> Self {
Self::new()
}
}