Skip to main content

psrp_rs/
shared.rs

1//! Thread-safe sharing of a [`RunspacePool`] across multiple callers.
2//!
3//! `RunspacePool` is `&mut self`-heavy because the underlying PSRP
4//! protocol serialises messages on a single WS-Management shell. That
5//! means you can't naively hand the same pool to two tasks.
6//!
7//! [`SharedRunspacePool`] wraps a pool in an `Arc<tokio::sync::Mutex<_>>`
8//! so multiple clones can coordinate access while still respecting the
9//! underlying serialization. The public API mirrors the most common
10//! pool methods and acquires the mutex internally for each call.
11//!
12//! True wire-level concurrency (multiple pipelines running in parallel
13//! with messages interleaved) is **not** supported here — that would
14//! require a background dispatcher task owning the transport, which in
15//! turn needs `winrm-rs` to expose an `Arc`-owned `WinrmClient`. The
16//! shared-pool pattern is still useful:
17//!
18//! * Multiple tasks can submit scripts to the same long-lived pool
19//!   without the caller juggling `&mut`.
20//! * The convenience methods take a `Clone`-able handle, friendly to
21//!   `tokio::spawn_local` and actor-style code.
22//! * Integrates cleanly with [`tokio_util::sync::CancellationToken`].
23
24use std::sync::Arc;
25
26use tokio::sync::Mutex;
27
28use crate::clixml::PsValue;
29use crate::error::Result;
30use crate::pipeline::{Pipeline, PipelineResult};
31use crate::runspace::RunspacePool;
32use crate::transport::PsrpTransport;
33
/// Clone-able handle to a [`RunspacePool`].
///
/// Every method acquires an internal mutex before calling into the
/// underlying pool, so callers can safely share a `SharedRunspacePool`
/// across tasks. Cloning the handle is cheap (an `Arc` refcount bump);
/// all clones refer to the same underlying pool.
pub struct SharedRunspacePool<T: PsrpTransport> {
    // The pool behind an async mutex so concurrent callers serialize
    // access; `Arc` lets `close()` detect outstanding handles via
    // `try_unwrap`/`strong_count`.
    inner: Arc<Mutex<RunspacePool<T>>>,
}
42
43impl<T: PsrpTransport> Clone for SharedRunspacePool<T> {
44    fn clone(&self) -> Self {
45        Self {
46            inner: Arc::clone(&self.inner),
47        }
48    }
49}
50
51impl<T: PsrpTransport> std::fmt::Debug for SharedRunspacePool<T> {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("SharedRunspacePool")
54            .field("strong_count", &Arc::strong_count(&self.inner))
55            .finish()
56    }
57}
58
59impl<T: PsrpTransport> SharedRunspacePool<T> {
60    /// Wrap an already-opened [`RunspacePool`].
61    #[must_use]
62    pub fn new(pool: RunspacePool<T>) -> Self {
63        Self {
64            inner: Arc::new(Mutex::new(pool)),
65        }
66    }
67
68    /// Number of outstanding handles to the same pool.
69    #[must_use]
70    pub fn handle_count(&self) -> usize {
71        Arc::strong_count(&self.inner)
72    }
73
74    /// Run a script. Acquires the mutex for the whole call.
75    pub async fn run_script(&self, script: &str) -> Result<Vec<PsValue>> {
76        let mut guard = self.inner.lock().await;
77        guard.run_script(script).await
78    }
79
80    /// Run a pre-built [`Pipeline`] and collect every stream.
81    pub async fn run_pipeline(&self, pipeline: Pipeline) -> Result<PipelineResult> {
82        let mut guard = self.inner.lock().await;
83        pipeline.run_all_streams(&mut guard).await
84    }
85
86    /// Run a script with a cancellation token.
87    pub async fn run_script_with_cancel(
88        &self,
89        script: &str,
90        cancel: tokio_util::sync::CancellationToken,
91    ) -> Result<Vec<PsValue>> {
92        let mut guard = self.inner.lock().await;
93        guard.run_script_with_cancel(script, cancel).await
94    }
95
96    /// Request a session key for `SecureString` transport.
97    pub async fn request_session_key(&self) -> Result<()> {
98        let mut guard = self.inner.lock().await;
99        guard.request_session_key().await
100    }
101
102    /// Close the pool. Only the last outstanding handle can actually
103    /// close — if other clones are still alive, returns an error
104    /// wrapping the still-contended pool.
105    pub async fn close(self) -> Result<()> {
106        match Arc::try_unwrap(self.inner) {
107            Ok(mutex) => mutex.into_inner().close().await,
108            Err(arc) => Err(crate::error::PsrpError::protocol(format!(
109                "cannot close SharedRunspacePool: {} handles still outstanding",
110                Arc::strong_count(&arc)
111            ))),
112        }
113    }
114
115    /// Acquire the pool mutex and run a closure with direct access to
116    /// the underlying [`RunspacePool`]. Useful when you need an API
117    /// that isn't surfaced on the shared wrapper.
118    pub async fn with_pool<F, R>(&self, f: F) -> R
119    where
120        F: for<'a> FnOnce(
121            &'a mut RunspacePool<T>,
122        )
123            -> std::pin::Pin<Box<dyn std::future::Future<Output = R> + Send + 'a>>,
124    {
125        let mut guard = self.inner.lock().await;
126        f(&mut guard).await
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::clixml::{PsObject, to_clixml};
    use crate::fragment::encode_message;
    use crate::message::{Destination, MessageType, PsrpMessage};
    use crate::pipeline::PipelineState;
    use crate::runspace::RunspacePoolState;
    use crate::transport::mock::MockTransport;
    use uuid::Uuid;

    // Build an encoded RUNSPACEPOOL_STATE message carrying `state`,
    // as a server would send it (nil rpid/pid are accepted by the
    // client-side parser in these tests).
    fn state_message(state: RunspacePoolState) -> Vec<u8> {
        let body = to_clixml(&PsValue::Object(
            PsObject::new().with("RunspaceState", PsValue::I32(state as i32)),
        ));
        PsrpMessage {
            destination: Destination::Client,
            message_type: MessageType::RunspacePoolState,
            rpid: Uuid::nil(),
            pid: Uuid::nil(),
            data: body,
        }
        .encode()
    }

    // Build an encoded PIPELINE_STATE message carrying `state`; used
    // to signal pipeline completion to the client under test.
    fn pipeline_state_message(state: PipelineState) -> Vec<u8> {
        let body = to_clixml(&PsValue::Object(
            PsObject::new().with("PipelineState", PsValue::I32(state as i32)),
        ));
        PsrpMessage {
            destination: Destination::Client,
            message_type: MessageType::PipelineState,
            rpid: Uuid::nil(),
            pid: Uuid::nil(),
            data: body,
        }
        .encode()
    }

    // Open a pool over a mock transport (seeded with an `Opened`
    // state message so `open_with_transport` completes) and wrap it.
    // Returns the transport too so tests can seed further replies.
    async fn opened_shared() -> (MockTransport, SharedRunspacePool<MockTransport>) {
        let t = MockTransport::new();
        t.push_incoming(encode_message(1, &state_message(RunspacePoolState::Opened)));
        let pool = RunspacePool::open_with_transport(t.clone()).await.unwrap();
        (t, SharedRunspacePool::new(pool))
    }

    // Happy path: run_script through the wrapper returns the pipeline
    // output, and a sole handle can close the pool.
    #[tokio::test]
    async fn shared_run_script_serialises_access() {
        let (t, shared) = opened_shared().await;
        // Seed one output record (fragment id 10)...
        t.push_incoming(encode_message(
            10,
            &PsrpMessage {
                destination: Destination::Client,
                message_type: MessageType::PipelineOutput,
                rpid: Uuid::nil(),
                pid: Uuid::nil(),
                data: "<I32>42</I32>".into(),
            }
            .encode(),
        ));
        // ...then a Completed state (fragment id 11) to end the run.
        t.push_incoming(encode_message(
            11,
            &pipeline_state_message(PipelineState::Completed),
        ));
        let out = shared.run_script("whatever").await.unwrap();
        assert_eq!(out, vec![PsValue::I32(42)]);
        // Count = 1, we can close.
        assert_eq!(shared.handle_count(), 1);
        shared.close().await.unwrap();
    }

    // close() on a handle while a clone is alive must fail with a
    // Protocol error; the surviving clone can then close cleanly.
    #[tokio::test]
    async fn shared_close_errors_with_outstanding_clones() {
        let (_t, shared) = opened_shared().await;
        let clone = shared.clone();
        assert_eq!(shared.handle_count(), 2);
        let err = shared.close().await.unwrap_err();
        assert!(matches!(err, crate::error::PsrpError::Protocol(_)));
        // Clean up through the surviving clone.
        clone.close().await.unwrap();
    }

    // with_pool exposes the raw pool under the lock; read its state.
    #[tokio::test]
    async fn shared_with_pool_direct_access() {
        let (_t, shared) = opened_shared().await;
        let state = shared
            .with_pool(|p| Box::pin(async move { p.state() }))
            .await;
        assert_eq!(state, RunspacePoolState::Opened);
        shared.close().await.unwrap();
    }

    // Debug output should name the type and show the handle count
    // field (the exact count value is not asserted here).
    #[tokio::test]
    async fn shared_debug_format_includes_strong_count() {
        let (_t, shared) = opened_shared().await;
        let s = format!("{shared:?}");
        assert!(s.contains("SharedRunspacePool"));
        assert!(s.contains("strong_count"));
        shared.close().await.unwrap();
    }

    // ---------- Phase D: additional shared pool coverage ----------

    // run_pipeline with a builder-constructed Pipeline delegates to
    // run_all_streams and collects the output stream.
    #[tokio::test]
    async fn shared_run_pipeline_with_builder() {
        let (t, shared) = opened_shared().await;
        t.push_incoming(encode_message(
            10,
            &PsrpMessage {
                destination: Destination::Client,
                message_type: MessageType::PipelineOutput,
                rpid: Uuid::nil(),
                pid: Uuid::nil(),
                data: "<S>ok</S>".into(),
            }
            .encode(),
        ));
        t.push_incoming(encode_message(
            11,
            &pipeline_state_message(PipelineState::Completed),
        ));
        let result = shared
            .run_pipeline(crate::pipeline::Pipeline::new("dummy"))
            .await
            .unwrap();
        assert_eq!(result.output, vec![PsValue::String("ok".into())]);
        shared.close().await.unwrap();
    }

    // A never-cancelled token: run_script_with_cancel should behave
    // exactly like run_script and return the seeded output.
    #[tokio::test]
    async fn shared_run_script_with_cancel_token() {
        let (t, shared) = opened_shared().await;
        t.push_incoming(encode_message(
            10,
            &PsrpMessage {
                destination: Destination::Client,
                message_type: MessageType::PipelineOutput,
                rpid: Uuid::nil(),
                pid: Uuid::nil(),
                data: "<I32>7</I32>".into(),
            }
            .encode(),
        ));
        t.push_incoming(encode_message(
            11,
            &pipeline_state_message(PipelineState::Completed),
        ));
        let token = tokio_util::sync::CancellationToken::new();
        let out = shared.run_script_with_cancel("x", token).await.unwrap();
        assert_eq!(out, vec![PsValue::I32(7)]);
        shared.close().await.unwrap();
    }

    #[tokio::test]
    async fn shared_request_session_key_delegates_and_fails() {
        // Without a server to answer, request_session_key hangs trying
        // to read the next message. We just exercise the delegation
        // path by pushing an EncryptedSessionKey response that will
        // fail to decrypt (random bytes) — the error path still
        // covers the code.
        let (t, shared) = opened_shared().await;
        // Seed a fake EncryptedSessionKey with garbage hex — decryption
        // will fail but the delegation + parse paths run.
        t.push_incoming(encode_message(
            9,
            &PsrpMessage {
                destination: Destination::Client,
                message_type: MessageType::EncryptedSessionKey,
                rpid: Uuid::nil(),
                pid: Uuid::nil(),
                data: to_clixml(&PsValue::Object(
                    PsObject::new().with("EncryptedSessionKey", PsValue::String("deadbeef".into())),
                )),
            }
            .encode(),
        ));
        let err = shared.request_session_key().await.unwrap_err();
        assert!(matches!(err, crate::error::PsrpError::Protocol(_)));
        shared.close().await.unwrap();
    }

    // handle_count tracks Arc strong count as clones come and go.
    #[tokio::test]
    async fn shared_handle_count_scales() {
        let (_t, shared) = opened_shared().await;
        assert_eq!(shared.handle_count(), 1);
        let h2 = shared.clone();
        let h3 = shared.clone();
        assert_eq!(shared.handle_count(), 3);
        drop(h3);
        drop(h2);
        assert_eq!(shared.handle_count(), 1);
        shared.close().await.unwrap();
    }
}
323}