use async_trait::async_trait;
use tracing::{debug, warn};
use winrm_rs::{RESOURCE_URI_PSRP, Shell, SoapError, WinrmClient, WinrmError};
use crate::error::{PsrpError, Result};
/// Byte-level transport that carries PSRP fragments to and from a remote shell.
///
/// The four required methods cover sending, receiving, stopping and closing.
/// `execute_pipeline` and `disconnect_shell` have default implementations that
/// a transport may override.
#[async_trait]
pub trait PsrpTransport: Send {
    /// Sends one buffer of outgoing fragment bytes.
    async fn send_fragment(&self, bytes: &[u8]) -> Result<()>;

    /// Receives the next chunk of incoming bytes.
    async fn recv_chunk(&mut self) -> Result<Vec<u8>>;

    /// Asks the remote side to stop the work currently in progress.
    async fn signal_stop(&self) -> Result<()>;

    /// Closes the underlying shell.
    async fn close_shell(&mut self) -> Result<()>;

    /// Starts a pipeline from already-encoded fragment bytes.
    ///
    /// The default implementation ignores `_pipeline_id` and simply forwards
    /// `fragment_bytes` to [`PsrpTransport::send_fragment`].
    async fn execute_pipeline(
        &mut self,
        fragment_bytes: &[u8],
        _pipeline_id: uuid::Uuid,
    ) -> Result<()> {
        self.send_fragment(fragment_bytes).await?;
        Ok(())
    }

    /// Disconnects from the shell and returns an identifier string on success.
    ///
    /// # Errors
    /// The default implementation always fails with a protocol error, because
    /// a transport has to opt in to disconnect support by overriding this.
    async fn disconnect_shell(&mut self) -> Result<String> {
        let unsupported =
            PsrpError::protocol("this transport does not implement disconnect_shell");
        Err(unsupported)
    }
}
/// PSRP transport that talks to the remote side through a WinRM [`Shell`].
///
/// The lifetime `'c` is the borrow of the [`WinrmClient`] the shell was
/// obtained from.
pub struct WinrmPsrpTransport<'c> {
/// The shell in use. `Some` from construction until `close_shell` or
/// `disconnect_shell` takes it out, after which the transport reports
/// "transport closed".
shell: Option<Shell<'c>>,
/// Identifier passed to the shell's send/receive/signal calls. It is
/// initialised to the shell id and replaced by the command id once a
/// command has been started.
command_id: String,
/// Set to `true` by `recv_chunk` when a receive response reports `done`.
done: bool,
/// Set to `true` once a command has been started through
/// `start_pipeline_command` or `execute_pipeline`.
has_command: bool,
}
impl std::fmt::Debug for WinrmPsrpTransport<'_> {
    /// Formats the transport's bookkeeping state for diagnostics.
    ///
    /// The shell handle itself is not printed; `has_shell` only reports
    /// whether one is still held. `has_command` is included because
    /// `command_id` holds the shell id until a command has been started, so
    /// the id alone does not say which of the two it is.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WinrmPsrpTransport")
            .field("command_id", &self.command_id)
            .field("has_shell", &self.shell.is_some())
            .field("has_command", &self.has_command)
            .field("done", &self.done)
            .finish()
    }
}
impl<'c> WinrmPsrpTransport<'c> {
    /// Opens a PSRP shell on `host` and wraps it in a transport.
    ///
    /// `creation_fragments` is base64-encoded and handed to
    /// `WinrmClient::open_psrp_shell` together with `RESOURCE_URI_PSRP`.
    /// No command is started here; until one is, `command_id` holds the
    /// shell id.
    ///
    /// # Errors
    /// Propagates any error returned by `open_psrp_shell`.
    pub async fn open(
        client: &'c WinrmClient,
        host: &str,
        creation_fragments: &[u8],
    ) -> Result<Self> {
        let creation_b64 = crate::clixml::encode::base64_encode(creation_fragments);
        let shell = client
            .open_psrp_shell(host, &creation_b64, RESOURCE_URI_PSRP)
            .await?;
        let command_id = shell.shell_id().to_string();
        debug!(command_id, "PSRP transport started (no command yet)");
        Ok(Self {
            shell: Some(shell),
            command_id,
            done: false,
            has_command: false,
        })
    }

    /// Re-attaches to an existing shell identified by `shell_id` on `host`.
    ///
    /// The resulting transport starts in the same state as one produced by
    /// [`WinrmPsrpTransport::open`]: no command, `command_id` set to the
    /// shell id.
    ///
    /// # Errors
    /// Propagates any error returned by `WinrmClient::reconnect_shell`.
    pub async fn reconnect(client: &'c WinrmClient, host: &str, shell_id: &str) -> Result<Self> {
        let shell = client
            .reconnect_shell(host, shell_id, RESOURCE_URI_PSRP)
            .await?;
        let command_id = shell.shell_id().to_string();
        debug!(command_id, "PSRP transport reconnected");
        Ok(Self {
            shell: Some(shell),
            command_id,
            done: false,
            has_command: false,
        })
    }

    /// Starts an empty command in the shell and makes it the current command.
    ///
    /// # Errors
    /// Fails with a protocol error if the shell has already been closed or
    /// disconnected, and propagates any error from `Shell::start_command`.
    pub async fn start_pipeline_command(&mut self) -> Result<()> {
        let cmd_id = self.shell()?.start_command("", &[]).await?;
        debug!(cmd_id, "PSRP pipeline command started");
        self.command_id = cmd_id;
        self.has_command = true;
        // A freshly started command has not finished. Clear any `done` flag
        // left over from a previous command so `recv_chunk` keeps polling on
        // empty responses instead of returning empty chunks straight away.
        self.done = false;
        Ok(())
    }

    /// Borrows the shell, or fails with "transport closed" once it is gone.
    fn shell(&self) -> Result<&Shell<'c>> {
        match self.shell.as_ref() {
            Some(shell) => Ok(shell),
            None => Err(PsrpError::protocol("transport closed")),
        }
    }
}
#[async_trait]
impl PsrpTransport for WinrmPsrpTransport<'_> {
    /// Sends `bytes` as input addressed to the current `command_id`.
    ///
    /// # Errors
    /// Fails with "transport closed" if the shell is gone, and propagates
    /// any error from `Shell::send_input`.
    async fn send_fragment(&self, bytes: &[u8]) -> Result<()> {
        let shell = self.shell()?;
        shell.send_input(&self.command_id, bytes, false).await?;
        Ok(())
    }

    /// Polls the shell until there is output to return or the command is done.
    ///
    /// Timeouts (`WinrmError::Timeout` and SOAP faults whose code contains
    /// `"TimedOut"`) are retried. Empty output is skipped while the command
    /// has not reported `done`; once it has, whatever was received (possibly
    /// empty) is returned.
    ///
    /// # Errors
    /// Fails with "transport closed" if the shell is gone. Any other SOAP
    /// fault or WinRM error is returned wrapped in `PsrpError::Winrm`.
    async fn recv_chunk(&mut self) -> Result<Vec<u8>> {
        loop {
            let shell = self.shell()?;
            let out = match shell.receive_next(&self.command_id).await {
                Ok(out) => out,
                Err(WinrmError::Timeout(_)) => continue,
                Err(WinrmError::Soap(SoapError::Fault { code, reason })) => {
                    if code.contains("TimedOut") {
                        debug!(%code, %reason, "PSRP recv_chunk: w:TimedOut — retrying");
                        continue;
                    }
                    warn!(%code, %reason, "PSRP transport SOAP fault — shell likely dead");
                    return Err(PsrpError::Winrm(WinrmError::Soap(SoapError::Fault {
                        code,
                        reason,
                    })));
                }
                Err(e) => return Err(PsrpError::Winrm(e)),
            };

            if out.done {
                self.done = true;
            }
            // Nothing to hand back yet and the command is still running: poll again.
            if out.stdout.is_empty() && !self.done {
                continue;
            }
            return Ok(out.stdout);
        }
    }

    /// Starts a command whose id is `pipeline_id` and whose single argument
    /// is the base64 encoding of `fragment_bytes`, then makes it the current
    /// command.
    ///
    /// # Errors
    /// Fails with "transport closed" if the shell is gone, and propagates
    /// any error from `Shell::start_command_with_id`.
    async fn execute_pipeline(
        &mut self,
        fragment_bytes: &[u8],
        pipeline_id: uuid::Uuid,
    ) -> Result<()> {
        let shell = self.shell()?;
        let b64 = crate::clixml::encode::base64_encode(fragment_bytes);
        let requested_id = pipeline_id.hyphenated().to_string();
        let cmd_id = shell
            .start_command_with_id("", &[&b64], &requested_id)
            .await?;
        debug!(cmd_id, "PSRP pipeline Execute started");
        self.command_id = cmd_id;
        self.has_command = true;
        // The new command has not finished. Clear a `done` flag left over
        // from an earlier pipeline so `recv_chunk` waits for real output
        // instead of returning empty chunks immediately.
        self.done = false;
        Ok(())
    }

    /// Sends a Ctrl-C signal addressed to the current `command_id`.
    ///
    /// # Errors
    /// Fails with "transport closed" if the shell is gone, and propagates
    /// any error from `Shell::signal_ctrl_c`.
    async fn signal_stop(&self) -> Result<()> {
        let shell = self.shell()?;
        shell.signal_ctrl_c(&self.command_id).await?;
        Ok(())
    }

    /// Closes the shell if one is still held; a no-op otherwise.
    ///
    /// The shell is taken out of the transport before closing, so the
    /// transport counts as closed even if the close call itself fails.
    async fn close_shell(&mut self) -> Result<()> {
        match self.shell.take() {
            Some(shell) => {
                shell.close().await?;
                Ok(())
            }
            None => Ok(()),
        }
    }

    /// Disconnects from the shell and returns the id reported by
    /// `Shell::disconnect`.
    ///
    /// # Errors
    /// Fails with "transport closed" if the shell is already gone, and
    /// propagates any error from `Shell::disconnect`.
    async fn disconnect_shell(&mut self) -> Result<String> {
        let shell = match self.shell.take() {
            Some(shell) => shell,
            None => return Err(PsrpError::protocol("transport closed")),
        };
        let id = shell.disconnect().await?;
        Ok(id)
    }
}
impl Drop for WinrmPsrpTransport<'_> {
    /// Logs a warning when the transport is dropped while still holding a
    /// shell, i.e. without `close_shell` or `disconnect_shell` having taken it.
    fn drop(&mut self) {
        if self.shell.is_none() {
            // The shell was already taken out; nothing to report.
            return;
        }
        warn!("WinrmPsrpTransport dropped without close — shell leaked server-side");
    }
}
/// In-memory [`PsrpTransport`] used by unit tests.
#[cfg(test)]
pub(crate) mod mock {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// Scriptable transport double.
    ///
    /// Every field is behind an `Arc`, so clones of a `MockTransport` share
    /// the same queues and flags.
    #[derive(Clone, Default)]
    pub struct MockTransport {
        /// Chunks that `recv_chunk` hands out, oldest first.
        pub inbox: Arc<Mutex<VecDeque<Vec<u8>>>>,
        /// Every buffer successfully passed to `send_fragment`, in order.
        pub outbox: Arc<Mutex<Vec<Vec<u8>>>>,
        /// Set to `true` by `signal_stop`.
        pub stopped: Arc<Mutex<bool>>,
        /// Set to `true` by `close_shell` and `disconnect_shell`.
        pub closed: Arc<Mutex<bool>>,
        /// While `true`, `send_fragment` fails instead of recording.
        pub fail_send: Arc<Mutex<bool>>,
        /// One-shot error: the next `recv_chunk` takes and returns it.
        pub fail_recv: Arc<Mutex<Option<PsrpError>>>,
    }

    impl MockTransport {
        /// Creates a mock with empty queues and all flags cleared.
        pub fn new() -> Self {
            Self::default()
        }

        /// Queues `bytes` to be returned by a later `recv_chunk`.
        pub fn push_incoming(&self, bytes: Vec<u8>) {
            let mut queue = self.inbox.lock().unwrap();
            queue.push_back(bytes);
        }

        /// Returns a copy of everything sent so far.
        pub fn sent(&self) -> Vec<Vec<u8>> {
            let recorded = self.outbox.lock().unwrap();
            recorded.clone()
        }
    }

    #[async_trait]
    impl PsrpTransport for MockTransport {
        /// Records `bytes` in the outbox, or fails when `fail_send` is set.
        async fn send_fragment(&self, bytes: &[u8]) -> Result<()> {
            let should_fail = *self.fail_send.lock().unwrap();
            if should_fail {
                return Err(PsrpError::protocol("mock send failure"));
            }
            self.outbox.lock().unwrap().push(bytes.to_vec());
            Ok(())
        }

        /// Returns the injected error if one is pending, otherwise the next
        /// queued chunk, otherwise a "mock inbox empty" protocol error.
        async fn recv_chunk(&mut self) -> Result<Vec<u8>> {
            let injected = self.fail_recv.lock().unwrap().take();
            if let Some(err) = injected {
                return Err(err);
            }
            let next = self.inbox.lock().unwrap().pop_front();
            match next {
                Some(bytes) => Ok(bytes),
                None => Err(PsrpError::protocol("mock inbox empty")),
            }
        }

        /// Marks the mock as stopped.
        async fn signal_stop(&self) -> Result<()> {
            let mut stopped = self.stopped.lock().unwrap();
            *stopped = true;
            Ok(())
        }

        /// Marks the mock as closed.
        async fn close_shell(&mut self) -> Result<()> {
            let mut closed = self.closed.lock().unwrap();
            *closed = true;
            Ok(())
        }

        /// Marks the mock as closed and returns a fixed shell id.
        async fn disconnect_shell(&mut self) -> Result<String> {
            let mut closed = self.closed.lock().unwrap();
            *closed = true;
            Ok("MOCK-SHELL-ID".into())
        }
    }

    /// Send, receive, stop and close all behave as scripted.
    #[tokio::test]
    async fn mock_roundtrip() {
        let mut transport = MockTransport::new();

        transport.send_fragment(b"hello").await.unwrap();
        assert_eq!(transport.sent(), vec![b"hello".to_vec()]);

        transport.push_incoming(b"world".to_vec());
        let received = transport.recv_chunk().await.unwrap();
        assert_eq!(received, b"world");

        transport.signal_stop().await.unwrap();
        transport.close_shell().await.unwrap();
        assert!(*transport.stopped.lock().unwrap());
        assert!(*transport.closed.lock().unwrap());
    }

    /// An injected receive error is surfaced by `recv_chunk`.
    #[tokio::test]
    async fn mock_recv_failure() {
        let mut transport = MockTransport::new();
        *transport.fail_recv.lock().unwrap() = Some(PsrpError::protocol("boom"));
        let outcome = transport.recv_chunk().await;
        assert!(outcome.is_err());
    }

    /// With `fail_send` set, `send_fragment` reports an error.
    #[tokio::test]
    async fn mock_send_failure() {
        let transport = MockTransport::new();
        *transport.fail_send.lock().unwrap() = true;
        let outcome = transport.send_fragment(b"x").await;
        assert!(outcome.is_err());
    }
}