use uuid::Uuid;
use crate::clixml::{PsObject, PsValue, parse_clixml, to_clixml};
use crate::error::{PsrpError, Result};
use crate::message::{MessageType, PsrpMessage};
pub const PROTOCOL_VERSION: &str = "2.3";
pub(crate) const PS_VERSION: &str = "2.0";
pub(crate) const SERIALIZATION_VERSION: &str = "1.1.0.1";
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RunspacePoolState {
BeforeOpen = 0,
Opening = 1,
Opened = 2,
Closed = 3,
Closing = 4,
Broken = 5,
Disconnecting = 6,
Disconnected = 7,
Connecting = 8,
NegotiationSent = 9,
NegotiationSucceeded = 10,
}
impl RunspacePoolState {
pub(crate) fn from_i32(v: i32) -> Self {
match v {
0 => Self::BeforeOpen,
1 => Self::Opening,
2 => Self::Opened,
3 => Self::Closed,
4 => Self::Closing,
5 => Self::Broken,
6 => Self::Disconnecting,
7 => Self::Disconnected,
8 => Self::Connecting,
9 => Self::NegotiationSent,
10 => Self::NegotiationSucceeded,
_ => Self::Broken,
}
}
}
#[derive(Debug, Clone)]
pub enum Action {
SendMessage {
message_type: MessageType,
body: String,
},
}
#[derive(Debug)]
pub struct RunspacePoolStateMachine {
state: RunspacePoolState,
rpid: Uuid,
min_runspaces: i32,
max_runspaces: i32,
}
impl RunspacePoolStateMachine {
pub fn new(rpid: Uuid, min_runspaces: i32, max_runspaces: i32) -> Result<Self> {
if min_runspaces < 1 || max_runspaces < min_runspaces {
return Err(PsrpError::protocol(format!(
"invalid runspace bounds: min={min_runspaces} max={max_runspaces}"
)));
}
Ok(Self {
state: RunspacePoolState::BeforeOpen,
rpid,
min_runspaces,
max_runspaces,
})
}
#[must_use]
pub fn state(&self) -> RunspacePoolState {
self.state
}
#[must_use]
pub fn rpid(&self) -> Uuid {
self.rpid
}
#[must_use]
pub fn min_runspaces(&self) -> i32 {
self.min_runspaces
}
#[must_use]
pub fn max_runspaces(&self) -> i32 {
self.max_runspaces
}
pub fn open(&mut self) -> Vec<Action> {
self.state = RunspacePoolState::Opening;
let actions = vec![
Action::SendMessage {
message_type: MessageType::SessionCapability,
body: session_capability_xml(),
},
Action::SendMessage {
message_type: MessageType::InitRunspacePool,
body: init_runspace_pool_xml(self.min_runspaces, self.max_runspaces),
},
];
self.state = RunspacePoolState::NegotiationSent;
actions
}
pub fn connect(&mut self) -> Vec<Action> {
self.state = RunspacePoolState::Connecting;
let actions = vec![
Action::SendMessage {
message_type: MessageType::SessionCapability,
body: session_capability_xml(),
},
Action::SendMessage {
message_type: MessageType::ConnectRunspacePool,
body: "<Obj RefId=\"0\"><MS/></Obj>".into(),
},
];
self.state = RunspacePoolState::NegotiationSent;
actions
}
pub fn on_message(&mut self, msg: &PsrpMessage) -> Result<()> {
match msg.message_type {
MessageType::RunspacePoolState => {
let new_state = extract_runspace_state(&msg.data)?;
self.state = new_state;
match new_state {
RunspacePoolState::Broken | RunspacePoolState::Closed => {
return Err(PsrpError::protocol(format!(
"runspace pool entered terminal state {new_state:?}"
)));
}
_ => {}
}
}
MessageType::SessionCapability
| MessageType::ApplicationPrivateData
| MessageType::RunspacePoolInitData
| MessageType::EncryptedSessionKey
| MessageType::PublicKeyRequest => {}
_ => {}
}
Ok(())
}
#[must_use]
pub fn is_opened(&self) -> bool {
self.state == RunspacePoolState::Opened
}
pub fn close(&mut self) -> Vec<Action> {
self.state = RunspacePoolState::Closing;
vec![Action::SendMessage {
message_type: MessageType::CloseRunspacePool,
body: "<Obj RefId=\"0\"><MS/></Obj>".into(),
}]
}
pub fn mark_closed(&mut self) {
self.state = RunspacePoolState::Closed;
}
}
pub(crate) fn session_capability_xml() -> String {
let obj = PsValue::Object(
PsObject::new()
.with("PSVersion", PsValue::Version(PS_VERSION.into()))
.with("protocolversion", PsValue::Version(PROTOCOL_VERSION.into()))
.with(
"SerializationVersion",
PsValue::Version(SERIALIZATION_VERSION.into()),
),
);
to_clixml(&obj)
}
pub(crate) fn init_runspace_pool_xml(min: i32, max: i32) -> String {
use crate::clixml::encode::{ps_enum, ps_host_info_null};
let obj = PsValue::Object(
PsObject::new()
.with("MinRunspaces", PsValue::I32(min))
.with("MaxRunspaces", PsValue::I32(max))
.with(
"PSThreadOptions",
ps_enum(
"System.Management.Automation.Runspaces.PSThreadOptions",
"Default",
0,
),
)
.with(
"ApartmentState",
ps_enum(
"System.Management.Automation.Runspaces.ApartmentState",
"UNKNOWN",
2,
),
)
.with("HostInfo", ps_host_info_null())
.with("ApplicationArguments", PsValue::Null),
);
to_clixml(&obj)
}
pub(crate) fn extract_runspace_state(xml: &str) -> Result<RunspacePoolState> {
let parsed = parse_clixml(xml)?;
for value in parsed {
if let PsValue::Object(obj) = value
&& let Some(PsValue::I32(code)) = obj.get("RunspaceState")
{
return Ok(RunspacePoolState::from_i32(*code));
}
}
Err(PsrpError::protocol(
"RunspacePoolState message missing RunspaceState property",
))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::message::Destination;
fn state_msg(state: RunspacePoolState) -> PsrpMessage {
let body = to_clixml(&PsValue::Object(
PsObject::new().with("RunspaceState", PsValue::I32(state as i32)),
));
PsrpMessage {
destination: Destination::Client,
message_type: MessageType::RunspacePoolState,
rpid: Uuid::nil(),
pid: Uuid::nil(),
data: body,
}
}
#[test]
fn new_rejects_bad_bounds() {
assert!(RunspacePoolStateMachine::new(Uuid::nil(), 0, 0).is_err());
assert!(RunspacePoolStateMachine::new(Uuid::nil(), 5, 3).is_err());
}
#[test]
fn new_accepts_valid_bounds() {
let m = RunspacePoolStateMachine::new(Uuid::nil(), 2, 5).unwrap();
assert_eq!(m.min_runspaces(), 2);
assert_eq!(m.max_runspaces(), 5);
assert_eq!(m.state(), RunspacePoolState::BeforeOpen);
}
#[test]
fn open_produces_two_messages_and_transitions() {
let mut m = RunspacePoolStateMachine::new(Uuid::nil(), 1, 1).unwrap();
let actions = m.open();
assert_eq!(actions.len(), 2);
match &actions[0] {
Action::SendMessage { message_type, body } => {
assert_eq!(*message_type, MessageType::SessionCapability);
assert!(body.contains(PROTOCOL_VERSION));
}
}
match &actions[1] {
Action::SendMessage { message_type, body } => {
assert_eq!(*message_type, MessageType::InitRunspacePool);
assert!(body.contains("MinRunspaces"));
assert!(body.contains("MaxRunspaces"));
}
}
assert_eq!(m.state(), RunspacePoolState::NegotiationSent);
}
#[test]
fn on_message_runspace_opened_sets_state() {
let mut m = RunspacePoolStateMachine::new(Uuid::nil(), 1, 1).unwrap();
m.open();
m.on_message(&state_msg(RunspacePoolState::Opened)).unwrap();
assert!(m.is_opened());
assert_eq!(m.state(), RunspacePoolState::Opened);
}
#[test]
fn on_message_broken_is_error() {
let mut m = RunspacePoolStateMachine::new(Uuid::nil(), 1, 1).unwrap();
m.open();
let err = m
.on_message(&state_msg(RunspacePoolState::Broken))
.unwrap_err();
assert!(matches!(err, PsrpError::Protocol(_)));
assert_eq!(m.state(), RunspacePoolState::Broken);
}
#[test]
fn on_message_closed_is_error() {
let mut m = RunspacePoolStateMachine::new(Uuid::nil(), 1, 1).unwrap();
let err = m
.on_message(&state_msg(RunspacePoolState::Closed))
.unwrap_err();
assert!(matches!(err, PsrpError::Protocol(_)));
}
#[test]
fn on_message_ignores_informational_types() {
let mut m = RunspacePoolStateMachine::new(Uuid::nil(), 1, 1).unwrap();
m.open();
for mt in [
MessageType::SessionCapability,
MessageType::ApplicationPrivateData,
MessageType::RunspacePoolInitData,
MessageType::EncryptedSessionKey,
MessageType::PublicKeyRequest,
MessageType::PipelineOutput,
] {
let msg = PsrpMessage {
destination: Destination::Client,
message_type: mt,
rpid: Uuid::nil(),
pid: Uuid::nil(),
data: "<Nil/>".into(),
};
m.on_message(&msg).unwrap();
}
assert_eq!(m.state(), RunspacePoolState::NegotiationSent);
}
#[test]
fn on_message_intermediate_state_keeps_machine_alive() {
let mut m = RunspacePoolStateMachine::new(Uuid::nil(), 1, 1).unwrap();
m.open();
m.on_message(&state_msg(RunspacePoolState::NegotiationSucceeded))
.unwrap();
assert_eq!(m.state(), RunspacePoolState::NegotiationSucceeded);
assert!(!m.is_opened());
}
#[test]
fn close_produces_action_and_mark_closed() {
let mut m = RunspacePoolStateMachine::new(Uuid::nil(), 1, 1).unwrap();
let actions = m.close();
assert_eq!(actions.len(), 1);
assert_eq!(m.state(), RunspacePoolState::Closing);
m.mark_closed();
assert_eq!(m.state(), RunspacePoolState::Closed);
}
#[test]
fn rpid_is_preserved() {
let id = Uuid::parse_str("11112222-3333-4444-5555-666677778888").unwrap();
let m = RunspacePoolStateMachine::new(id, 1, 1).unwrap();
assert_eq!(m.rpid(), id);
}
#[test]
fn state_from_i32_covers_all_known() {
for (code, expected) in [
(0, RunspacePoolState::BeforeOpen),
(1, RunspacePoolState::Opening),
(2, RunspacePoolState::Opened),
(3, RunspacePoolState::Closed),
(4, RunspacePoolState::Closing),
(5, RunspacePoolState::Broken),
(6, RunspacePoolState::Disconnecting),
(7, RunspacePoolState::Disconnected),
(8, RunspacePoolState::Connecting),
(9, RunspacePoolState::NegotiationSent),
(10, RunspacePoolState::NegotiationSucceeded),
(99, RunspacePoolState::Broken),
] {
assert_eq!(RunspacePoolState::from_i32(code), expected);
}
}
#[test]
fn extract_runspace_state_missing_property() {
assert!(extract_runspace_state("<Obj RefId=\"0\"><MS/></Obj>").is_err());
}
#[test]
fn extract_runspace_state_ok() {
let xml = to_clixml(&PsValue::Object(
PsObject::new().with("RunspaceState", PsValue::I32(2)),
));
assert_eq!(
extract_runspace_state(&xml).unwrap(),
RunspacePoolState::Opened
);
}
#[test]
fn session_capability_xml_has_protocol_version() {
let xml = session_capability_xml();
assert!(xml.contains(PROTOCOL_VERSION));
assert!(xml.contains(SERIALIZATION_VERSION));
}
#[test]
fn init_runspace_pool_xml_has_counts() {
let xml = init_runspace_pool_xml(2, 7);
assert!(xml.contains("<I32 N=\"MinRunspaces\">2</I32>"));
assert!(xml.contains("<I32 N=\"MaxRunspaces\">7</I32>"));
}
#[test]
fn init_runspace_pool_xml_has_full_enum_hierarchy() {
let xml = init_runspace_pool_xml(1, 1);
assert!(xml.contains("System.Management.Automation.Runspaces.PSThreadOptions"));
assert!(xml.contains("System.Enum"));
assert!(xml.contains("System.ValueType"));
assert!(xml.contains("System.Object"));
assert!(xml.contains("<ToString>Default</ToString>"));
assert!(xml.contains("System.Management.Automation.Runspaces.ApartmentState"));
assert!(xml.contains("<ToString>UNKNOWN</ToString>"));
}
#[test]
fn init_runspace_pool_xml_has_host_info_and_application_args() {
let xml = init_runspace_pool_xml(1, 1);
assert!(xml.contains("N=\"HostInfo\""));
assert!(xml.contains("_isHostNull"));
assert!(xml.contains("N=\"ApplicationArguments\""));
}
#[test]
fn session_capability_emits_version_tags() {
let xml = session_capability_xml();
assert!(xml.contains("<Version N=\"PSVersion\">"));
assert!(xml.contains("<Version N=\"protocolversion\">"));
assert!(xml.contains("<Version N=\"SerializationVersion\">"));
assert!(!xml.contains("<S N=\"PSVersion\">"));
}
}