use std::sync::Arc;
use uuid::Uuid;
use crate::clixml::{PsObject, PsValue, parse_clixml, to_clixml};
use crate::error::{PsrpError, Result};
use crate::fragment::{Reassembler, encode_message};
use crate::host::{HostMethodId, NoInteractionHost, PsHost, dispatch_host_call};
use crate::message::{Destination, MessageType, PsrpMessage};
use crate::transport::PsrpTransport;
use super::state::{Action, RunspacePoolState, RunspacePoolStateMachine};
pub struct RunspacePool<T: PsrpTransport> {
transport: T,
reassembler: Reassembler,
machine: RunspacePoolStateMachine,
next_object_id: u64,
closed: bool,
host: Arc<dyn PsHost>,
session_key: Option<crate::crypto::SessionKey>,
pending_messages: Vec<PsrpMessage>,
}
impl<T: PsrpTransport> RunspacePool<T> {
pub async fn open_with_transport(transport: T) -> Result<Self> {
Self::open_with_options(transport, 1, 1).await
}
pub async fn open_with_options(
transport: T,
min_runspaces: i32,
max_runspaces: i32,
) -> Result<Self> {
Self::open_with_options_and_host(
transport,
min_runspaces,
max_runspaces,
Arc::new(NoInteractionHost),
)
.await
}
pub async fn open_with_options_and_host(
transport: T,
min_runspaces: i32,
max_runspaces: i32,
host: Arc<dyn PsHost>,
) -> Result<Self> {
let machine = RunspacePoolStateMachine::new(Uuid::new_v4(), min_runspaces, max_runspaces)?;
let mut pool = Self {
transport,
reassembler: Reassembler::new(),
machine,
next_object_id: 1,
closed: false,
host,
session_key: None,
pending_messages: Vec::new(),
};
let actions = pool.machine.open();
pool.execute(actions).await?;
pool.drain_until_opened().await?;
Ok(pool)
}
pub fn build_creation_fragments(
min_runspaces: i32,
max_runspaces: i32,
) -> Result<(Uuid, Vec<u8>)> {
let mut machine =
RunspacePoolStateMachine::new(Uuid::new_v4(), min_runspaces, max_runspaces)?;
let actions = machine.open();
let rpid = machine.rpid();
let mut next_oid = 1u64;
let mut buf = Vec::new();
for action in actions {
match action {
Action::SendMessage { message_type, body } => {
let msg = PsrpMessage {
destination: Destination::Server,
message_type,
rpid,
pid: Uuid::nil(),
data: body,
};
let encoded = msg.encode();
buf.extend_from_slice(&encode_message(next_oid, &encoded));
next_oid += 1;
}
}
}
Ok((rpid, buf))
}
pub async fn open_from_transport(
transport: T,
rpid: Uuid,
min_runspaces: i32,
max_runspaces: i32,
) -> Result<Self> {
Self::open_from_transport_with_host(
transport,
rpid,
min_runspaces,
max_runspaces,
Arc::new(NoInteractionHost),
)
.await
}
pub async fn open_from_transport_with_host(
transport: T,
rpid: Uuid,
min_runspaces: i32,
max_runspaces: i32,
host: Arc<dyn PsHost>,
) -> Result<Self> {
let mut machine = RunspacePoolStateMachine::new(rpid, min_runspaces, max_runspaces)?;
let _ = machine.open(); let mut pool = Self {
transport,
reassembler: Reassembler::new(),
machine,
next_object_id: 3, closed: false,
host,
session_key: None,
pending_messages: Vec::new(),
};
pool.drain_until_opened().await?;
Ok(pool)
}
#[must_use]
pub fn state(&self) -> RunspacePoolState {
self.machine.state()
}
#[must_use]
pub fn min_runspaces(&self) -> i32 {
self.machine.min_runspaces()
}
#[must_use]
pub fn max_runspaces(&self) -> i32 {
self.machine.max_runspaces()
}
#[must_use]
pub fn id(&self) -> Uuid {
self.machine.rpid()
}
pub fn ensure_state(&self, expected: RunspacePoolState) -> Result<()> {
if self.machine.state() != expected {
return Err(PsrpError::BadState {
expected: format!("{expected:?}"),
actual: format!("{:?}", self.machine.state()),
});
}
Ok(())
}
pub async fn disconnect(mut self) -> Result<DisconnectedPool> {
if self.closed {
return Err(PsrpError::protocol("pool already closed"));
}
self.closed = true;
let shell_id = self.transport.disconnect_shell().await?;
Ok(DisconnectedPool {
shell_id,
rpid: self.machine.rpid(),
min_runspaces: self.machine.min_runspaces(),
max_runspaces: self.machine.max_runspaces(),
host: self.host.clone(),
session_key: self.session_key.clone(),
})
}
pub async fn close(mut self) -> Result<()> {
if self.closed {
return Ok(());
}
self.closed = true;
let actions = self.machine.close();
let send_result = self.execute(actions).await;
if let Err(e) = &send_result {
tracing::debug!("CloseRunspacePool send failed: {e}");
}
self.transport.close_shell().await?;
self.machine.mark_closed();
send_result
}
pub async fn run_script(&mut self, script: &str) -> Result<Vec<PsValue>> {
crate::pipeline::Pipeline::new(script).run(self).await
}
#[must_use]
pub fn session_key(&self) -> Option<&crate::crypto::SessionKey> {
self.session_key.as_ref()
}
pub fn set_session_key(&mut self, key: crate::crypto::SessionKey) {
self.session_key = Some(key);
}
pub fn encrypt_secure_string(&self, plaintext: &str) -> Result<String> {
let key = self
.session_key
.as_ref()
.ok_or_else(|| PsrpError::protocol("session key not negotiated yet"))?;
let bytes = key.encrypt_secure_string(plaintext);
Ok(crate::clixml::encode::base64_encode(&bytes))
}
pub fn decrypt_secure_string(&self, base64_body: &str) -> Result<String> {
let key = self
.session_key
.as_ref()
.ok_or_else(|| PsrpError::protocol("session key not negotiated yet"))?;
let bytes = crate::clixml::encode::base64_decode(base64_body)
.ok_or_else(|| PsrpError::protocol("secure string: invalid base64"))?;
key.decrypt_secure_string(&bytes)
}
pub async fn request_session_key(&mut self) -> Result<()> {
if self.session_key.is_some() {
return Ok(());
}
let client_key = crate::crypto::ClientSessionKey::generate()?;
let blob_hex = client_key.public_blob_hex();
let body = crate::clixml::to_clixml(&PsValue::Object(
PsObject::new().with("PublicKey", PsValue::String(blob_hex)),
));
self.send_psrp_message(MessageType::PublicKey, body).await?;
loop {
let msg = self.next_message().await?;
if msg.message_type == MessageType::EncryptedSessionKey {
let parsed = parse_clixml(&msg.data)?;
let hex = parsed
.into_iter()
.find_map(|v| match v {
PsValue::Object(obj) => obj
.get("EncryptedSessionKey")
.and_then(PsValue::as_str)
.map(str::to_string),
_ => None,
})
.ok_or_else(|| {
PsrpError::protocol(
"EncryptedSessionKey message missing EncryptedSessionKey property",
)
})?;
let bytes = hex_decode(&hex)?;
let raw = client_key.decrypt_session_key(&bytes)?;
self.session_key = Some(crate::crypto::SessionKey::from_bytes(raw));
return Ok(());
}
}
}
pub async fn run_script_with_cancel(
&mut self,
script: &str,
cancel: tokio_util::sync::CancellationToken,
) -> Result<Vec<PsValue>> {
crate::pipeline::Pipeline::new(script)
.run_with_cancel(self, cancel)
.await
}
pub(crate) async fn send_psrp_message(&mut self, mt: MessageType, body: String) -> Result<()> {
let msg = PsrpMessage {
destination: Destination::Server,
message_type: mt,
rpid: self.machine.rpid(),
pid: Uuid::nil(),
data: body,
};
let encoded = msg.encode();
let oid = self.next_object_id;
self.next_object_id += 1;
let frag_bytes = encode_message(oid, &encoded);
self.transport.send_fragment(&frag_bytes).await
}
pub(crate) async fn next_message(&mut self) -> Result<PsrpMessage> {
loop {
if let Some(msg) = self.pending_messages.pop() {
match msg.message_type {
MessageType::RunspacePoolHostCall | MessageType::PipelineHostCall => {
self.handle_host_call(&msg).await?;
continue;
}
_ => return Ok(msg),
}
}
let chunk = self.transport.recv_chunk().await?;
if chunk.is_empty() {
if self.reassembler.is_idle() && self.pending_messages.is_empty() {
return Ok(PsrpMessage {
destination: Destination::Client,
message_type: MessageType::PipelineState,
rpid: self.machine.rpid(),
pid: Uuid::nil(),
data: crate::clixml::to_clixml(&PsValue::Object(
PsObject::new().with("PipelineState", PsValue::I32(4)),
)),
});
}
}
let completed = self.reassembler.feed(&chunk)?;
let mut decoded: Vec<PsrpMessage> = Vec::new();
for payload in completed {
decoded.push(PsrpMessage::decode(&payload)?);
}
decoded.reverse();
self.pending_messages.extend(decoded);
}
}
async fn handle_host_call(&mut self, msg: &PsrpMessage) -> Result<()> {
let parsed = parse_clixml(&msg.data)?;
let root = parsed
.into_iter()
.next()
.ok_or_else(|| PsrpError::protocol("host call: empty body"))?;
let obj = match &root {
PsValue::Object(o) => o,
_ => return Err(PsrpError::protocol("host call: body is not an object")),
};
let ci = obj
.get("ci")
.and_then(PsValue::as_i64)
.ok_or_else(|| PsrpError::protocol("host call: missing ci"))?;
let mi_raw = obj
.get("mi")
.and_then(|v| v.as_i64().or_else(|| v.as_i32().map(i64::from)))
.ok_or_else(|| PsrpError::protocol("host call: missing mi"))?;
let mi = HostMethodId::from_i64(mi_raw);
let mp: Vec<PsValue> = match obj.get("mp") {
Some(PsValue::List(items)) => items.clone(),
Some(PsValue::Object(inner)) => match inner.get("_value") {
Some(PsValue::List(items)) => items.clone(),
_ => Vec::new(),
},
_ => Vec::new(),
};
let result = dispatch_host_call(self.host.as_ref(), mi, &mp).await;
let response_value = match result {
Ok(None) => return Ok(()),
Ok(Some(v)) => v,
Err(e) => {
let body = build_host_response_error_body(ci, mi_raw, &e.to_string());
let response_type = match msg.message_type {
MessageType::RunspacePoolHostCall => MessageType::RunspacePoolHostResponse,
_ => MessageType::PipelineHostResponse,
};
return self
.send_pipeline_message(response_type, msg.pid, body)
.await;
}
};
let body = build_host_response_body(ci, mi_raw, &response_value);
let response_type = match msg.message_type {
MessageType::RunspacePoolHostCall => MessageType::RunspacePoolHostResponse,
_ => MessageType::PipelineHostResponse,
};
self.send_pipeline_message(response_type, msg.pid, body)
.await
}
pub(crate) fn allocate_object_id(&mut self) -> u64 {
let id = self.next_object_id;
self.next_object_id += 1;
id
}
pub(crate) async fn signal_transport_stop(&mut self) -> Result<()> {
self.transport.signal_stop().await
}
pub(crate) async fn send_pipeline_message(
&mut self,
mt: MessageType,
pid: Uuid,
body: String,
) -> Result<()> {
let msg = PsrpMessage {
destination: Destination::Server,
message_type: mt,
rpid: self.machine.rpid(),
pid,
data: body,
};
let encoded = msg.encode();
let oid = self.allocate_object_id();
let frag_bytes = encode_message(oid, &encoded);
if mt == MessageType::CreatePipeline {
self.transport.execute_pipeline(&frag_bytes, pid).await
} else {
self.transport.send_fragment(&frag_bytes).await
}
}
async fn execute(&mut self, actions: Vec<Action>) -> Result<()> {
for action in actions {
match action {
Action::SendMessage { message_type, body } => {
self.send_psrp_message(message_type, body).await?;
}
}
}
Ok(())
}
async fn drain_until_opened(&mut self) -> Result<()> {
loop {
let msg = self.next_message().await?;
self.machine.on_message(&msg)?;
if self.machine.is_opened() {
return Ok(());
}
}
}
}
pub struct DisconnectedPool {
shell_id: String,
rpid: Uuid,
min_runspaces: i32,
max_runspaces: i32,
host: Arc<dyn PsHost>,
session_key: Option<crate::crypto::SessionKey>,
}
impl std::fmt::Debug for DisconnectedPool {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("DisconnectedPool")
.field("rpid", &self.rpid)
.field("shell_id", &self.shell_id)
.field("min_runspaces", &self.min_runspaces)
.field("max_runspaces", &self.max_runspaces)
.finish()
}
}
impl DisconnectedPool {
#[must_use]
pub fn rpid(&self) -> Uuid {
self.rpid
}
#[must_use]
pub fn shell_id(&self) -> &str {
&self.shell_id
}
pub async fn reconnect<T: PsrpTransport>(self, transport: T) -> Result<RunspacePool<T>> {
let mut machine =
RunspacePoolStateMachine::new(self.rpid, self.min_runspaces, self.max_runspaces)?;
let actions = machine.connect();
let mut pool = RunspacePool {
transport,
reassembler: Reassembler::new(),
machine,
next_object_id: 1,
closed: false,
host: self.host,
session_key: self.session_key,
pending_messages: Vec::new(),
};
pool.execute(actions).await?;
pool.drain_until_opened().await?;
Ok(pool)
}
}
fn hex_decode(s: &str) -> Result<Vec<u8>> {
let clean: String = s.chars().filter(|c| !c.is_whitespace()).collect();
if clean.len() % 2 != 0 {
return Err(PsrpError::protocol("hex: odd length"));
}
let mut out = Vec::with_capacity(clean.len() / 2);
for pair in clean.as_bytes().chunks_exact(2) {
let hi = hex_digit(pair[0])?;
let lo = hex_digit(pair[1])?;
out.push((hi << 4) | lo);
}
Ok(out)
}
fn hex_digit(b: u8) -> Result<u8> {
match b {
b'0'..=b'9' => Ok(b - b'0'),
b'a'..=b'f' => Ok(b - b'a' + 10),
b'A'..=b'F' => Ok(b - b'A' + 10),
_ => Err(PsrpError::protocol(format!("hex: invalid digit 0x{b:02X}"))),
}
}
fn build_host_response_body(ci: i64, mi: i64, value: &PsValue) -> String {
let obj = PsObject::new()
.with("ci", PsValue::I64(ci))
.with("mi", PsValue::I64(mi))
.with("mr", value.clone());
to_clixml(&PsValue::Object(obj))
}
fn build_host_response_error_body(ci: i64, mi: i64, message: &str) -> String {
let exception = PsObject::new()
.with("Message", PsValue::String(message.to_string()))
.with_type_names(["System.Exception"]);
let obj = PsObject::new()
.with("ci", PsValue::I64(ci))
.with("mi", PsValue::I64(mi))
.with("me", PsValue::Object(exception));
to_clixml(&PsValue::Object(obj))
}
impl<T: PsrpTransport> std::fmt::Debug for RunspacePool<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RunspacePool")
.field("state", &self.machine.state())
.field("rpid", &self.machine.rpid())
.field("min_runspaces", &self.machine.min_runspaces())
.field("max_runspaces", &self.machine.max_runspaces())
.field("next_object_id", &self.next_object_id)
.field("closed", &self.closed)
.field("reassembler_idle", &self.reassembler.is_idle())
.field("transport", &"..")
.finish()
}
}
impl<T: PsrpTransport> Drop for RunspacePool<T> {
fn drop(&mut self) {
if !self.closed {
tracing::warn!("RunspacePool dropped without close() — server-side shell may leak");
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::clixml::{PsObject, to_clixml};
use crate::transport::mock::MockTransport;
fn state_message_bytes(state: RunspacePoolState) -> Vec<u8> {
let body = to_clixml(&PsValue::Object(
PsObject::new().with("RunspaceState", PsValue::I32(state as i32)),
));
PsrpMessage {
destination: Destination::Client,
message_type: MessageType::RunspacePoolState,
rpid: Uuid::nil(),
pid: Uuid::nil(),
data: body,
}
.encode()
}
fn wire(oid: u64, payload: &[u8]) -> Vec<u8> {
encode_message(oid, payload)
}
#[tokio::test]
async fn open_reaches_opened_state() {
let t = MockTransport::new();
t.push_incoming(wire(100, &state_message_bytes(RunspacePoolState::Opened)));
let pool = RunspacePool::open_with_transport(t.clone()).await.unwrap();
assert_eq!(pool.state(), RunspacePoolState::Opened);
assert_eq!(t.sent().len(), 2); pool.close().await.unwrap();
assert!(*t.closed.lock().unwrap());
}
#[tokio::test]
async fn broken_state_during_open_errors() {
let t = MockTransport::new();
t.push_incoming(wire(1, &state_message_bytes(RunspacePoolState::Broken)));
let err = RunspacePool::open_with_transport(t).await.unwrap_err();
assert!(matches!(err, PsrpError::Protocol(_)));
}
#[tokio::test]
async fn invalid_runspace_bounds() {
let t = MockTransport::new();
let err = RunspacePool::open_with_options(t, 0, 0).await.unwrap_err();
assert!(matches!(err, PsrpError::Protocol(_)));
}
#[tokio::test]
async fn invalid_runspace_bounds_max_lt_min() {
let t = MockTransport::new();
let err = RunspacePool::open_with_options(t, 5, 3).await.unwrap_err();
assert!(matches!(err, PsrpError::Protocol(_)));
}
#[tokio::test]
async fn open_ignores_unrelated_messages() {
let t = MockTransport::new();
let app_data = PsrpMessage {
destination: Destination::Client,
message_type: MessageType::ApplicationPrivateData,
rpid: Uuid::nil(),
pid: Uuid::nil(),
data: "<Obj RefId=\"0\"><MS/></Obj>".into(),
}
.encode();
t.push_incoming(wire(1, &app_data));
t.push_incoming(wire(2, &state_message_bytes(RunspacePoolState::Opened)));
let pool = RunspacePool::open_with_transport(t).await.unwrap();
assert_eq!(pool.state(), RunspacePoolState::Opened);
let _ = pool.close().await;
}
#[tokio::test]
async fn ensure_state_works() {
let t = MockTransport::new();
t.push_incoming(wire(1, &state_message_bytes(RunspacePoolState::Opened)));
let pool = RunspacePool::open_with_transport(t).await.unwrap();
assert!(pool.ensure_state(RunspacePoolState::Opened).is_ok());
assert!(pool.ensure_state(RunspacePoolState::Broken).is_err());
let _ = pool.close().await;
}
#[tokio::test]
async fn close_is_idempotent_and_drop_warns() {
let t = MockTransport::new();
t.push_incoming(wire(1, &state_message_bytes(RunspacePoolState::Opened)));
let pool = RunspacePool::open_with_transport(t).await.unwrap();
pool.close().await.unwrap();
let t = MockTransport::new();
t.push_incoming(wire(1, &state_message_bytes(RunspacePoolState::Opened)));
let pool = RunspacePool::open_with_transport(t).await.unwrap();
drop(pool);
}
#[tokio::test]
async fn debug_impl_does_not_panic() {
let t = MockTransport::new();
t.push_incoming(wire(1, &state_message_bytes(RunspacePoolState::Opened)));
let pool = RunspacePool::open_with_transport(t).await.unwrap();
let s = format!("{pool:?}");
assert!(s.contains("RunspacePool"));
assert!(s.contains("Opened"));
let _ = pool.close().await;
}
#[tokio::test]
async fn min_max_runspaces_exposed() {
let t = MockTransport::new();
t.push_incoming(wire(1, &state_message_bytes(RunspacePoolState::Opened)));
let pool = RunspacePool::open_with_options(t, 2, 5).await.unwrap();
assert_eq!(pool.min_runspaces(), 2);
assert_eq!(pool.max_runspaces(), 5);
assert_ne!(pool.id(), Uuid::nil());
let _ = pool.close().await;
}
fn host_call_bytes(ci: i64, mi: i64, mp: PsValue) -> Vec<u8> {
let obj = PsObject::new()
.with("ci", PsValue::I64(ci))
.with("mi", PsValue::I64(mi))
.with("mp", mp);
let body = to_clixml(&PsValue::Object(obj));
PsrpMessage {
destination: Destination::Client,
message_type: MessageType::RunspacePoolHostCall,
rpid: Uuid::nil(),
pid: Uuid::nil(),
data: body,
}
.encode()
}
async fn opened_pool_with_host(
t: &MockTransport,
host: Arc<dyn PsHost>,
) -> RunspacePool<MockTransport> {
t.inbox
.lock()
.unwrap()
.push_front(wire(1, &state_message_bytes(RunspacePoolState::Opened)));
RunspacePool::open_with_options_and_host(t.clone(), 1, 1, host)
.await
.unwrap()
}
#[tokio::test]
async fn host_call_dispatches_write_line_and_swallows_message() {
let t = MockTransport::new();
let host = crate::host::BufferedHost::new();
t.push_incoming(wire(
2,
&host_call_bytes(
42,
crate::host::HostMethodId::WriteLine1.to_i64(),
PsValue::List(vec![PsValue::String("hello from host".into())]),
),
));
t.push_incoming(wire(
3,
&PsrpMessage {
destination: Destination::Client,
message_type: MessageType::PipelineOutput,
rpid: Uuid::nil(),
pid: Uuid::nil(),
data: "<I32>1</I32>".into(),
}
.encode(),
));
t.push_incoming(wire(
4,
&PsrpMessage {
destination: Destination::Client,
message_type: MessageType::PipelineState,
rpid: Uuid::nil(),
pid: Uuid::nil(),
data: to_clixml(&PsValue::Object(PsObject::new().with(
"PipelineState",
PsValue::I32(crate::pipeline::PipelineState::Completed as i32),
))),
}
.encode(),
));
let mut pool = opened_pool_with_host(&t, Arc::new(host.clone())).await;
let out = pool.run_script("irrelevant").await.unwrap();
assert_eq!(out, vec![PsValue::I32(1)]);
assert_eq!(host.lines(), vec!["hello from host".to_string()]);
assert_eq!(t.sent().len(), 3);
let _ = pool.close().await;
}
#[tokio::test]
async fn host_call_read_line_sends_error_response() {
let t = MockTransport::new();
t.push_incoming(wire(
2,
&host_call_bytes(
7,
crate::host::HostMethodId::ReadLine.to_i64(),
PsValue::List(Vec::new()),
),
));
t.push_incoming(wire(
3,
&PsrpMessage {
destination: Destination::Client,
message_type: MessageType::PipelineState,
rpid: Uuid::nil(),
pid: Uuid::nil(),
data: to_clixml(&PsValue::Object(PsObject::new().with(
"PipelineState",
PsValue::I32(crate::pipeline::PipelineState::Completed as i32),
))),
}
.encode(),
));
let mut pool = opened_pool_with_host(&t, Arc::new(crate::host::NoInteractionHost)).await;
let _ = pool.run_script("irrelevant").await;
assert_eq!(t.sent().len(), 4);
let _ = pool.close().await;
}
#[tokio::test]
async fn host_call_read_line_with_custom_host_sends_value_response() {
use async_trait::async_trait;
struct YesHost;
#[async_trait]
impl crate::host::PsHost for YesHost {
async fn read_line(&self) -> Result<String> {
Ok("alice".into())
}
}
let t = MockTransport::new();
t.push_incoming(wire(
2,
&host_call_bytes(
3,
crate::host::HostMethodId::ReadLine.to_i64(),
PsValue::List(Vec::new()),
),
));
t.push_incoming(wire(
3,
&PsrpMessage {
destination: Destination::Client,
message_type: MessageType::PipelineState,
rpid: Uuid::nil(),
pid: Uuid::nil(),
data: to_clixml(&PsValue::Object(PsObject::new().with(
"PipelineState",
PsValue::I32(crate::pipeline::PipelineState::Completed as i32),
))),
}
.encode(),
));
let mut pool = opened_pool_with_host(&t, Arc::new(YesHost)).await;
let _ = pool.run_script("irrelevant").await;
assert_eq!(t.sent().len(), 4);
let _ = pool.close().await;
}
#[tokio::test]
async fn encrypt_secure_string_without_key_errors() {
let t = MockTransport::new();
t.push_incoming(wire(1, &state_message_bytes(RunspacePoolState::Opened)));
let pool = RunspacePool::open_with_transport(t).await.unwrap();
let err = pool.encrypt_secure_string("x").unwrap_err();
assert!(matches!(err, PsrpError::Protocol(_)));
let err = pool.decrypt_secure_string("AAAA").unwrap_err();
assert!(matches!(err, PsrpError::Protocol(_)));
let _ = pool.close().await;
}
#[tokio::test]
async fn secure_string_roundtrip_via_set_session_key() {
let t = MockTransport::new();
t.push_incoming(wire(1, &state_message_bytes(RunspacePoolState::Opened)));
let mut pool = RunspacePool::open_with_transport(t).await.unwrap();
pool.set_session_key(crate::crypto::SessionKey::from_bytes([7u8; 32]));
assert!(pool.session_key().is_some());
let b64 = pool.encrypt_secure_string("héllo 🌍").unwrap();
let pt = pool.decrypt_secure_string(&b64).unwrap();
assert_eq!(pt, "héllo 🌍");
let _ = pool.close().await;
}
#[tokio::test]
async fn decrypt_secure_string_rejects_bad_base64() {
let t = MockTransport::new();
t.push_incoming(wire(1, &state_message_bytes(RunspacePoolState::Opened)));
let mut pool = RunspacePool::open_with_transport(t).await.unwrap();
pool.set_session_key(crate::crypto::SessionKey::from_bytes([0u8; 32]));
let err = pool.decrypt_secure_string("!!!not base64!!!").unwrap_err();
assert!(matches!(err, PsrpError::Protocol(_)));
let _ = pool.close().await;
}
#[tokio::test]
async fn disconnect_returns_disconnected_pool() {
let t = MockTransport::new();
t.push_incoming(wire(1, &state_message_bytes(RunspacePoolState::Opened)));
let pool = RunspacePool::open_with_options(t, 2, 5).await.unwrap();
let disconnected = pool.disconnect().await.unwrap();
assert_eq!(disconnected.shell_id(), "MOCK-SHELL-ID");
assert_ne!(disconnected.rpid(), Uuid::nil());
let s = format!("{disconnected:?}");
assert!(s.contains("DisconnectedPool"));
assert!(s.contains("MOCK-SHELL-ID"));
}
#[tokio::test]
async fn disconnect_then_reconnect_roundtrip() {
let t = MockTransport::new();
t.push_incoming(wire(1, &state_message_bytes(RunspacePoolState::Opened)));
let pool = RunspacePool::open_with_transport(t).await.unwrap();
let disconnected = pool.disconnect().await.unwrap();
let t2 = MockTransport::new();
t2.inbox
.lock()
.unwrap()
.push_front(wire(1, &state_message_bytes(RunspacePoolState::Opened)));
let pool = disconnected.reconnect(t2.clone()).await.unwrap();
assert_eq!(pool.state(), RunspacePoolState::Opened);
assert_eq!(t2.sent().len(), 2);
let _ = pool.close().await;
}
#[tokio::test]
async fn double_disconnect_errors() {
let t = MockTransport::new();
t.push_incoming(wire(1, &state_message_bytes(RunspacePoolState::Opened)));
let pool = RunspacePool::open_with_transport(t).await.unwrap();
let _disc = pool.disconnect().await.unwrap();
let t2 = MockTransport::new();
t2.push_incoming(wire(1, &state_message_bytes(RunspacePoolState::Opened)));
let pool2 = RunspacePool::open_with_transport(t2).await.unwrap();
pool2.close().await.unwrap();
}
#[tokio::test]
async fn hex_decode_roundtrip_and_errors() {
assert_eq!(
hex_decode("deadBEEF").unwrap(),
vec![0xde, 0xad, 0xbe, 0xef]
);
assert!(hex_decode("ABC").is_err()); assert!(hex_decode("GG").is_err()); }
#[test]
fn build_creation_fragments_produces_nonempty_bytes() {
let (rpid, bytes) = RunspacePool::<MockTransport>::build_creation_fragments(1, 1).unwrap();
assert!(!rpid.is_nil());
assert!(bytes.len() > 42); }
#[test]
fn build_creation_fragments_rejects_bad_bounds() {
let err = RunspacePool::<MockTransport>::build_creation_fragments(0, 0);
assert!(err.is_err());
}
#[tokio::test]
async fn open_from_transport_reaches_opened() {
let t = MockTransport::new();
let (rpid, _bytes) = RunspacePool::<MockTransport>::build_creation_fragments(1, 1).unwrap();
t.push_incoming(wire(100, &state_message_bytes(RunspacePoolState::Opened)));
let pool = RunspacePool::open_from_transport(t.clone(), rpid, 1, 1)
.await
.unwrap();
assert_eq!(pool.state(), RunspacePoolState::Opened);
pool.close().await.unwrap();
}
#[tokio::test]
async fn close_already_closed_is_noop() {
let t = MockTransport::new();
t.push_incoming(wire(100, &state_message_bytes(RunspacePoolState::Opened)));
let mut pool = RunspacePool::open_with_transport(t.clone()).await.unwrap();
pool.closed = true;
pool.close().await.unwrap(); }
#[tokio::test]
async fn disconnect_already_closed_errors() {
let t = MockTransport::new();
t.push_incoming(wire(100, &state_message_bytes(RunspacePoolState::Opened)));
let mut pool = RunspacePool::open_with_transport(t.clone()).await.unwrap();
pool.closed = true;
let err = pool.disconnect().await.unwrap_err();
assert!(err.to_string().contains("already closed"));
}
#[tokio::test]
async fn build_host_response_body_shapes() {
let body = build_host_response_body(3, 11, &PsValue::I32(0));
assert!(body.contains("<I64 N=\"ci\">3</I64>"));
assert!(body.contains("<I64 N=\"mi\">11</I64>"));
assert!(body.contains("N=\"mr\""));
let err = build_host_response_error_body(3, 51, "boom");
assert!(err.contains("<I64 N=\"ci\">3</I64>"));
assert!(err.contains("<I64 N=\"mi\">51</I64>"));
assert!(err.contains("boom"));
}
}