Skip to main content

psrp_rs/runspace/
state.rs

//! Pure state machine for a runspace pool.
//!
//! This module is fully sync and does no I/O: it just maps `(state, event)`
//! to `(new state, actions)`. The async driver (`pool.rs`) owns the
//! transport and executes the actions.
//!
//! Modelling the lifecycle this way makes every transition trivially
//! testable from a plain unit test (no `tokio::test`, no mock transport).
9
10use uuid::Uuid;
11
12use crate::clixml::{PsObject, PsValue, parse_clixml, to_clixml};
13use crate::error::{PsrpError, Result};
14use crate::message::{MessageType, PsrpMessage};
15
/// PSRP protocol version advertised as `protocolversion` in the client's
/// `SessionCapability` message (`2.3` — matches Windows PowerShell 5.1 and
/// PowerShell 7+).
pub const PROTOCOL_VERSION: &str = "2.3";
/// Value advertised as `PSVersion` in the client's `SessionCapability` message.
pub(crate) const PS_VERSION: &str = "2.0";
/// Value advertised as `SerializationVersion` in the client's
/// `SessionCapability` message.
pub(crate) const SERIALIZATION_VERSION: &str = "1.1.0.1";
21
/// Lifecycle states of a runspace pool (MS-PSRP §2.2.3.4).
///
/// The discriminants are the integer codes carried in the `RunspaceState`
/// property of a `RunspacePoolState` message. `#[repr(i32)]` pins the
/// discriminant type to the wire's `I32`, so `state as i32` is exact.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum RunspacePoolState {
    /// Initial state of a freshly constructed machine.
    BeforeOpen = 0,
    Opening = 1,
    /// The pool is usable.
    Opened = 2,
    /// Terminal: reported as an error when received from the server.
    Closed = 3,
    /// Set locally when a close has been requested.
    Closing = 4,
    /// Terminal: reported as an error when received from the server; also
    /// the fallback for unrecognised codes.
    Broken = 5,
    Disconnecting = 6,
    Disconnected = 7,
    Connecting = 8,
    /// Set locally once the open/connect handshake messages have been produced.
    ///
    /// NOTE(review): codes 9 and 10 do not appear to be part of the .NET
    /// `RunspacePoolState` enum — confirm whether a server ever sends them.
    NegotiationSent = 9,
    NegotiationSucceeded = 10,
}

impl RunspacePoolState {
    /// Decode a `RunspaceState` integer into a state.
    ///
    /// Codes outside `0..=10` map to [`RunspacePoolState::Broken`], so an
    /// unexpected value from the server is treated as a terminal failure.
    pub(crate) fn from_i32(v: i32) -> Self {
        // Indexed by discriminant; must stay in sync with the enum above.
        const BY_CODE: [RunspacePoolState; 11] = [
            RunspacePoolState::BeforeOpen,
            RunspacePoolState::Opening,
            RunspacePoolState::Opened,
            RunspacePoolState::Closed,
            RunspacePoolState::Closing,
            RunspacePoolState::Broken,
            RunspacePoolState::Disconnecting,
            RunspacePoolState::Disconnected,
            RunspacePoolState::Connecting,
            RunspacePoolState::NegotiationSent,
            RunspacePoolState::NegotiationSucceeded,
        ];
        // Negative codes fail `try_from`; large ones fail `get`.
        usize::try_from(v)
            .ok()
            .and_then(|idx| BY_CODE.get(idx).copied())
            .unwrap_or(Self::Broken)
    }
}
56
57/// An action that the state machine wants the async driver to perform.
58///
59/// Currently only `SendMessage` is used; this stays open for future
60/// primitives like `SignalStop` or `CloseTransport`.
61#[derive(Debug, Clone)]
62pub enum Action {
63    SendMessage {
64        message_type: MessageType,
65        body: String,
66    },
67}
68
69/// Pure state machine for the runspace pool lifecycle.
70///
71/// The driver calls [`open`](Self::open) once to start the handshake, then
72/// feeds every server-originated message through
73/// [`on_message`](Self::on_message). The machine produces a list of
74/// [`Action`]s to execute and transitions its internal state accordingly.
75#[derive(Debug)]
76pub struct RunspacePoolStateMachine {
77    state: RunspacePoolState,
78    rpid: Uuid,
79    min_runspaces: i32,
80    max_runspaces: i32,
81}
82
83impl RunspacePoolStateMachine {
84    /// Build a new machine. `rpid` should typically be a freshly-generated v4 UUID.
85    pub fn new(rpid: Uuid, min_runspaces: i32, max_runspaces: i32) -> Result<Self> {
86        if min_runspaces < 1 || max_runspaces < min_runspaces {
87            return Err(PsrpError::protocol(format!(
88                "invalid runspace bounds: min={min_runspaces} max={max_runspaces}"
89            )));
90        }
91        Ok(Self {
92            state: RunspacePoolState::BeforeOpen,
93            rpid,
94            min_runspaces,
95            max_runspaces,
96        })
97    }
98
99    /// Current lifecycle state.
100    #[must_use]
101    pub fn state(&self) -> RunspacePoolState {
102        self.state
103    }
104
105    /// Runspace pool identifier.
106    #[must_use]
107    pub fn rpid(&self) -> Uuid {
108        self.rpid
109    }
110
111    /// Configured minimum runspaces.
112    #[must_use]
113    pub fn min_runspaces(&self) -> i32 {
114        self.min_runspaces
115    }
116
117    /// Configured maximum runspaces.
118    #[must_use]
119    pub fn max_runspaces(&self) -> i32 {
120        self.max_runspaces
121    }
122
123    /// Produce the actions required to start the opening handshake.
124    ///
125    /// Transitions the state from `BeforeOpen` to `NegotiationSent`.
126    pub fn open(&mut self) -> Vec<Action> {
127        self.state = RunspacePoolState::Opening;
128        let actions = vec![
129            Action::SendMessage {
130                message_type: MessageType::SessionCapability,
131                body: session_capability_xml(),
132            },
133            Action::SendMessage {
134                message_type: MessageType::InitRunspacePool,
135                body: init_runspace_pool_xml(self.min_runspaces, self.max_runspaces),
136            },
137        ];
138        self.state = RunspacePoolState::NegotiationSent;
139        actions
140    }
141
142    /// Produce the actions required to **reconnect** to a previously
143    /// disconnected runspace pool.
144    ///
145    /// PSRP §3.1.4.1 — the client sends a `ConnectRunspacePool` (`0x0002_100B`)
146    /// message and the server responds with the current pool state. We
147    /// re-emit the `SessionCapability` first to renegotiate protocol
148    /// versions.
149    pub fn connect(&mut self) -> Vec<Action> {
150        self.state = RunspacePoolState::Connecting;
151        let actions = vec![
152            Action::SendMessage {
153                message_type: MessageType::SessionCapability,
154                body: session_capability_xml(),
155            },
156            Action::SendMessage {
157                message_type: MessageType::ConnectRunspacePool,
158                body: "<Obj RefId=\"0\"><MS/></Obj>".into(),
159            },
160        ];
161        self.state = RunspacePoolState::NegotiationSent;
162        actions
163    }
164
165    /// Feed a server-originated message into the machine.
166    ///
167    /// Returns `Ok(())` on a valid transition. Unknown message types are
168    /// silently ignored. A `RunspacePoolState=Broken/Closed` received
169    /// during the opening handshake is reported as a protocol error.
170    pub fn on_message(&mut self, msg: &PsrpMessage) -> Result<()> {
171        match msg.message_type {
172            MessageType::RunspacePoolState => {
173                let new_state = extract_runspace_state(&msg.data)?;
174                self.state = new_state;
175                match new_state {
176                    RunspacePoolState::Broken | RunspacePoolState::Closed => {
177                        return Err(PsrpError::protocol(format!(
178                            "runspace pool entered terminal state {new_state:?}"
179                        )));
180                    }
181                    _ => {}
182                }
183            }
184            // These are informational during handshake — ignored silently.
185            MessageType::SessionCapability
186            | MessageType::ApplicationPrivateData
187            | MessageType::RunspacePoolInitData
188            | MessageType::EncryptedSessionKey
189            | MessageType::PublicKeyRequest => {}
190            _ => {}
191        }
192        Ok(())
193    }
194
195    /// True once the machine has reached [`RunspacePoolState::Opened`].
196    #[must_use]
197    pub fn is_opened(&self) -> bool {
198        self.state == RunspacePoolState::Opened
199    }
200
201    /// Produce the actions required to close the pool.
202    pub fn close(&mut self) -> Vec<Action> {
203        self.state = RunspacePoolState::Closing;
204        vec![Action::SendMessage {
205            message_type: MessageType::CloseRunspacePool,
206            body: "<Obj RefId=\"0\"><MS/></Obj>".into(),
207        }]
208    }
209
210    /// Mark the machine as fully closed (after the transport has torn down).
211    pub fn mark_closed(&mut self) {
212        self.state = RunspacePoolState::Closed;
213    }
214}
215
216pub(crate) fn session_capability_xml() -> String {
217    // PSRP §2.2.2.1: PSVersion / protocolversion / SerializationVersion
218    // are typed `Version` on the wire, NOT plain strings. Some servers
219    // accept the looser <S> form but the strict ones reject it.
220    let obj = PsValue::Object(
221        PsObject::new()
222            .with("PSVersion", PsValue::Version(PS_VERSION.into()))
223            .with("protocolversion", PsValue::Version(PROTOCOL_VERSION.into()))
224            .with(
225                "SerializationVersion",
226                PsValue::Version(SERIALIZATION_VERSION.into()),
227            ),
228    );
229    to_clixml(&obj)
230}
231
232pub(crate) fn init_runspace_pool_xml(min: i32, max: i32) -> String {
233    // PSRP §2.2.2.2 — InitRunspacePool requires the following member set:
234    //   MinRunspaces (I32), MaxRunspaces (I32),
235    //   PSThreadOptions (enum), ApartmentState (enum),
236    //   HostInfo (Obj), ApplicationArguments (DCT or Nil)
237    //
238    // Strict server-side deserialisers reject messages that omit any
239    // of these. We emit each enum with its full .NET type hierarchy via
240    // `crate::clixml::encode::ps_enum`.
241    use crate::clixml::encode::{ps_enum, ps_host_info_null};
242
243    let obj = PsValue::Object(
244        PsObject::new()
245            .with("MinRunspaces", PsValue::I32(min))
246            .with("MaxRunspaces", PsValue::I32(max))
247            .with(
248                "PSThreadOptions",
249                ps_enum(
250                    "System.Management.Automation.Runspaces.PSThreadOptions",
251                    "Default",
252                    0,
253                ),
254            )
255            .with(
256                "ApartmentState",
257                ps_enum(
258                    "System.Management.Automation.Runspaces.ApartmentState",
259                    "UNKNOWN",
260                    2,
261                ),
262            )
263            .with("HostInfo", ps_host_info_null())
264            .with("ApplicationArguments", PsValue::Null),
265    );
266    to_clixml(&obj)
267}
268
269pub(crate) fn extract_runspace_state(xml: &str) -> Result<RunspacePoolState> {
270    let parsed = parse_clixml(xml)?;
271    for value in parsed {
272        if let PsValue::Object(obj) = value
273            && let Some(PsValue::I32(code)) = obj.get("RunspaceState")
274        {
275            return Ok(RunspacePoolState::from_i32(*code));
276        }
277    }
278    Err(PsrpError::protocol(
279        "RunspacePoolState message missing RunspaceState property",
280    ))
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::Destination;

    /// Build a server `RunspacePoolState` message reporting `state`.
    fn state_msg(state: RunspacePoolState) -> PsrpMessage {
        let body = to_clixml(&PsValue::Object(
            PsObject::new().with("RunspaceState", PsValue::I32(state as i32)),
        ));
        PsrpMessage {
            destination: Destination::Client,
            message_type: MessageType::RunspacePoolState,
            rpid: Uuid::nil(),
            pid: Uuid::nil(),
            data: body,
        }
    }

    #[test]
    fn new_rejects_bad_bounds() {
        assert!(RunspacePoolStateMachine::new(Uuid::nil(), 0, 0).is_err());
        assert!(RunspacePoolStateMachine::new(Uuid::nil(), 5, 3).is_err());
    }

    #[test]
    fn new_accepts_valid_bounds() {
        let m = RunspacePoolStateMachine::new(Uuid::nil(), 2, 5).unwrap();
        assert_eq!(m.min_runspaces(), 2);
        assert_eq!(m.max_runspaces(), 5);
        assert_eq!(m.state(), RunspacePoolState::BeforeOpen);
    }

    #[test]
    fn open_produces_two_messages_and_transitions() {
        let mut m = RunspacePoolStateMachine::new(Uuid::nil(), 1, 1).unwrap();
        let actions = m.open();
        assert_eq!(actions.len(), 2);
        match &actions[0] {
            Action::SendMessage { message_type, body } => {
                assert_eq!(*message_type, MessageType::SessionCapability);
                assert!(body.contains(PROTOCOL_VERSION));
            }
        }
        match &actions[1] {
            Action::SendMessage { message_type, body } => {
                assert_eq!(*message_type, MessageType::InitRunspacePool);
                assert!(body.contains("MinRunspaces"));
                assert!(body.contains("MaxRunspaces"));
            }
        }
        assert_eq!(m.state(), RunspacePoolState::NegotiationSent);
    }

    #[test]
    fn connect_produces_two_messages_and_transitions() {
        let mut m = RunspacePoolStateMachine::new(Uuid::nil(), 1, 1).unwrap();
        let actions = m.connect();
        assert_eq!(actions.len(), 2);
        match &actions[0] {
            Action::SendMessage { message_type, body } => {
                assert_eq!(*message_type, MessageType::SessionCapability);
                assert!(body.contains(PROTOCOL_VERSION));
            }
        }
        match &actions[1] {
            Action::SendMessage { message_type, .. } => {
                assert_eq!(*message_type, MessageType::ConnectRunspacePool);
            }
        }
        assert_eq!(m.state(), RunspacePoolState::NegotiationSent);
    }

    #[test]
    fn on_message_runspace_opened_sets_state() {
        let mut m = RunspacePoolStateMachine::new(Uuid::nil(), 1, 1).unwrap();
        m.open();
        m.on_message(&state_msg(RunspacePoolState::Opened)).unwrap();
        assert!(m.is_opened());
        assert_eq!(m.state(), RunspacePoolState::Opened);
    }

    #[test]
    fn on_message_broken_is_error() {
        let mut m = RunspacePoolStateMachine::new(Uuid::nil(), 1, 1).unwrap();
        m.open();
        let err = m
            .on_message(&state_msg(RunspacePoolState::Broken))
            .unwrap_err();
        assert!(matches!(err, PsrpError::Protocol(_)));
        assert_eq!(m.state(), RunspacePoolState::Broken);
    }

    #[test]
    fn on_message_closed_is_error() {
        let mut m = RunspacePoolStateMachine::new(Uuid::nil(), 1, 1).unwrap();
        let err = m
            .on_message(&state_msg(RunspacePoolState::Closed))
            .unwrap_err();
        assert!(matches!(err, PsrpError::Protocol(_)));
        // The terminal state is recorded even though an error is returned.
        assert_eq!(m.state(), RunspacePoolState::Closed);
    }

    #[test]
    fn on_message_state_without_property_is_error_and_keeps_state() {
        let mut m = RunspacePoolStateMachine::new(Uuid::nil(), 1, 1).unwrap();
        m.open();
        let msg = PsrpMessage {
            destination: Destination::Client,
            message_type: MessageType::RunspacePoolState,
            rpid: Uuid::nil(),
            pid: Uuid::nil(),
            data: "<Obj RefId=\"0\"><MS/></Obj>".into(),
        };
        assert!(m.on_message(&msg).is_err());
        assert_eq!(m.state(), RunspacePoolState::NegotiationSent);
    }

    #[test]
    fn on_message_ignores_informational_types() {
        let mut m = RunspacePoolStateMachine::new(Uuid::nil(), 1, 1).unwrap();
        m.open();
        for mt in [
            MessageType::SessionCapability,
            MessageType::ApplicationPrivateData,
            MessageType::RunspacePoolInitData,
            MessageType::EncryptedSessionKey,
            MessageType::PublicKeyRequest,
            MessageType::PipelineOutput,
        ] {
            let msg = PsrpMessage {
                destination: Destination::Client,
                message_type: mt,
                rpid: Uuid::nil(),
                pid: Uuid::nil(),
                data: "<Nil/>".into(),
            };
            m.on_message(&msg).unwrap();
        }
        // Still in NegotiationSent, none of the above transition.
        assert_eq!(m.state(), RunspacePoolState::NegotiationSent);
    }

    #[test]
    fn on_message_intermediate_state_keeps_machine_alive() {
        let mut m = RunspacePoolStateMachine::new(Uuid::nil(), 1, 1).unwrap();
        m.open();
        m.on_message(&state_msg(RunspacePoolState::NegotiationSucceeded))
            .unwrap();
        assert_eq!(m.state(), RunspacePoolState::NegotiationSucceeded);
        assert!(!m.is_opened());
    }

    #[test]
    fn close_produces_action_and_mark_closed() {
        let mut m = RunspacePoolStateMachine::new(Uuid::nil(), 1, 1).unwrap();
        let actions = m.close();
        assert_eq!(actions.len(), 1);
        assert_eq!(m.state(), RunspacePoolState::Closing);
        m.mark_closed();
        assert_eq!(m.state(), RunspacePoolState::Closed);
    }

    #[test]
    fn rpid_is_preserved() {
        let id = Uuid::parse_str("11112222-3333-4444-5555-666677778888").unwrap();
        let m = RunspacePoolStateMachine::new(id, 1, 1).unwrap();
        assert_eq!(m.rpid(), id);
    }

    #[test]
    fn state_from_i32_covers_all_known() {
        for (code, expected) in [
            (0, RunspacePoolState::BeforeOpen),
            (1, RunspacePoolState::Opening),
            (2, RunspacePoolState::Opened),
            (3, RunspacePoolState::Closed),
            (4, RunspacePoolState::Closing),
            (5, RunspacePoolState::Broken),
            (6, RunspacePoolState::Disconnecting),
            (7, RunspacePoolState::Disconnected),
            (8, RunspacePoolState::Connecting),
            (9, RunspacePoolState::NegotiationSent),
            (10, RunspacePoolState::NegotiationSucceeded),
            (99, RunspacePoolState::Broken),
        ] {
            assert_eq!(RunspacePoolState::from_i32(code), expected);
        }
    }

    #[test]
    fn extract_runspace_state_missing_property() {
        assert!(extract_runspace_state("<Obj RefId=\"0\"><MS/></Obj>").is_err());
    }

    #[test]
    fn extract_runspace_state_ok() {
        let xml = to_clixml(&PsValue::Object(
            PsObject::new().with("RunspaceState", PsValue::I32(2)),
        ));
        assert_eq!(
            extract_runspace_state(&xml).unwrap(),
            RunspacePoolState::Opened
        );
    }

    #[test]
    fn session_capability_xml_has_protocol_version() {
        let xml = session_capability_xml();
        assert!(xml.contains(PROTOCOL_VERSION));
        assert!(xml.contains(SERIALIZATION_VERSION));
    }

    #[test]
    fn init_runspace_pool_xml_has_counts() {
        let xml = init_runspace_pool_xml(2, 7);
        assert!(xml.contains("<I32 N=\"MinRunspaces\">2</I32>"));
        assert!(xml.contains("<I32 N=\"MaxRunspaces\">7</I32>"));
    }

    #[test]
    fn init_runspace_pool_xml_has_full_enum_hierarchy() {
        let xml = init_runspace_pool_xml(1, 1);
        // PSThreadOptions enum with full type chain
        assert!(xml.contains("System.Management.Automation.Runspaces.PSThreadOptions"));
        assert!(xml.contains("System.Enum"));
        assert!(xml.contains("System.ValueType"));
        assert!(xml.contains("System.Object"));
        assert!(xml.contains("<ToString>Default</ToString>"));
        // ApartmentState enum value (Unknown == 2)
        assert!(xml.contains("System.Management.Automation.Runspaces.ApartmentState"));
        assert!(xml.contains("<ToString>UNKNOWN</ToString>"));
    }

    #[test]
    fn init_runspace_pool_xml_has_host_info_and_application_args() {
        let xml = init_runspace_pool_xml(1, 1);
        assert!(xml.contains("N=\"HostInfo\""));
        assert!(xml.contains("_isHostNull"));
        assert!(xml.contains("N=\"ApplicationArguments\""));
    }

    #[test]
    fn session_capability_emits_version_tags() {
        let xml = session_capability_xml();
        // Plain `<Version>` rather than `<S>`.
        assert!(xml.contains("<Version N=\"PSVersion\">"));
        assert!(xml.contains("<Version N=\"protocolversion\">"));
        assert!(xml.contains("<Version N=\"SerializationVersion\">"));
        assert!(!xml.contains("<S N=\"PSVersion\">"));
    }
}
494        assert!(xml.contains("<Version N=\"protocolversion\">"));
495        assert!(xml.contains("<Version N=\"SerializationVersion\">"));
496        assert!(!xml.contains("<S N=\"PSVersion\">"));
497    }
498}