1use std::sync::Arc;
25
26use tokio::sync::Mutex;
27
28use crate::clixml::PsValue;
29use crate::error::Result;
30use crate::pipeline::{Pipeline, PipelineResult};
31use crate::runspace::RunspacePool;
32use crate::transport::PsrpTransport;
33
34pub struct SharedRunspacePool<T: PsrpTransport> {
40 inner: Arc<Mutex<RunspacePool<T>>>,
41}
42
43impl<T: PsrpTransport> Clone for SharedRunspacePool<T> {
44 fn clone(&self) -> Self {
45 Self {
46 inner: Arc::clone(&self.inner),
47 }
48 }
49}
50
51impl<T: PsrpTransport> std::fmt::Debug for SharedRunspacePool<T> {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("SharedRunspacePool")
54 .field("strong_count", &Arc::strong_count(&self.inner))
55 .finish()
56 }
57}
58
59impl<T: PsrpTransport> SharedRunspacePool<T> {
60 #[must_use]
62 pub fn new(pool: RunspacePool<T>) -> Self {
63 Self {
64 inner: Arc::new(Mutex::new(pool)),
65 }
66 }
67
68 #[must_use]
70 pub fn handle_count(&self) -> usize {
71 Arc::strong_count(&self.inner)
72 }
73
74 pub async fn run_script(&self, script: &str) -> Result<Vec<PsValue>> {
76 let mut guard = self.inner.lock().await;
77 guard.run_script(script).await
78 }
79
80 pub async fn run_pipeline(&self, pipeline: Pipeline) -> Result<PipelineResult> {
82 let mut guard = self.inner.lock().await;
83 pipeline.run_all_streams(&mut guard).await
84 }
85
86 pub async fn run_script_with_cancel(
88 &self,
89 script: &str,
90 cancel: tokio_util::sync::CancellationToken,
91 ) -> Result<Vec<PsValue>> {
92 let mut guard = self.inner.lock().await;
93 guard.run_script_with_cancel(script, cancel).await
94 }
95
96 pub async fn request_session_key(&self) -> Result<()> {
98 let mut guard = self.inner.lock().await;
99 guard.request_session_key().await
100 }
101
102 pub async fn close(self) -> Result<()> {
106 match Arc::try_unwrap(self.inner) {
107 Ok(mutex) => mutex.into_inner().close().await,
108 Err(arc) => Err(crate::error::PsrpError::protocol(format!(
109 "cannot close SharedRunspacePool: {} handles still outstanding",
110 Arc::strong_count(&arc)
111 ))),
112 }
113 }
114
115 pub async fn with_pool<F, R>(&self, f: F) -> R
119 where
120 F: for<'a> FnOnce(
121 &'a mut RunspacePool<T>,
122 )
123 -> std::pin::Pin<Box<dyn std::future::Future<Output = R> + Send + 'a>>,
124 {
125 let mut guard = self.inner.lock().await;
126 f(&mut guard).await
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::clixml::{PsObject, to_clixml};
    use crate::fragment::encode_message;
    use crate::message::{Destination, MessageType, PsrpMessage};
    use crate::pipeline::PipelineState;
    use crate::runspace::RunspacePoolState;
    use crate::transport::mock::MockTransport;
    use uuid::Uuid;

    /// Encodes a client-bound message with nil runspace/pipeline ids. Its body is
    /// an object holding the single `I32` property `key = value`.
    fn client_state_message(message_type: MessageType, key: &'static str, value: i32) -> Vec<u8> {
        let object = PsObject::new().with(key, PsValue::I32(value));
        PsrpMessage {
            destination: Destination::Client,
            message_type,
            rpid: Uuid::nil(),
            pid: Uuid::nil(),
            data: to_clixml(&PsValue::Object(object)),
        }
        .encode()
    }

    /// Encodes a `RunspacePoolState` message reporting `state`.
    fn state_message(state: RunspacePoolState) -> Vec<u8> {
        client_state_message(MessageType::RunspacePoolState, "RunspaceState", state as i32)
    }

    /// Encodes a `PipelineState` message reporting `state`.
    fn pipeline_state_message(state: PipelineState) -> Vec<u8> {
        client_state_message(MessageType::PipelineState, "PipelineState", state as i32)
    }

    /// Encodes a `PipelineOutput` message whose body is the raw CLIXML `data`.
    fn output_message(data: &'static str) -> Vec<u8> {
        PsrpMessage {
            destination: Destination::Client,
            message_type: MessageType::PipelineOutput,
            rpid: Uuid::nil(),
            pid: Uuid::nil(),
            data: data.into(),
        }
        .encode()
    }

    /// Queues one pipeline output carrying `data`, followed by a `Completed`
    /// pipeline state. They are encoded with `encode_message` ids 10 and 11.
    fn queue_output_then_completed(t: &MockTransport, data: &'static str) {
        t.push_incoming(encode_message(10, &output_message(data)));
        t.push_incoming(encode_message(
            11,
            &pipeline_state_message(PipelineState::Completed),
        ));
    }

    /// Opens a pool over a fresh mock transport that is pre-loaded with an
    /// `Opened` state message, then wraps the pool. The transport is returned
    /// so tests can queue further incoming messages.
    async fn opened_shared() -> (MockTransport, SharedRunspacePool<MockTransport>) {
        let transport = MockTransport::new();
        transport.push_incoming(encode_message(1, &state_message(RunspacePoolState::Opened)));
        let pool = RunspacePool::open_with_transport(transport.clone())
            .await
            .unwrap();
        (transport, SharedRunspacePool::new(pool))
    }

    #[tokio::test]
    async fn shared_run_script_serialises_access() {
        let (t, shared) = opened_shared().await;
        queue_output_then_completed(&t, "<I32>42</I32>");

        let out = shared.run_script("whatever").await.unwrap();

        assert_eq!(out, vec![PsValue::I32(42)]);
        assert_eq!(shared.handle_count(), 1);
        shared.close().await.unwrap();
    }

    #[tokio::test]
    async fn shared_close_errors_with_outstanding_clones() {
        let (_t, shared) = opened_shared().await;
        let other = shared.clone();
        assert_eq!(shared.handle_count(), 2);

        let err = shared
            .close()
            .await
            .expect_err("close must fail while another handle is alive");
        assert!(matches!(err, crate::error::PsrpError::Protocol(_)));

        // Once the first handle is gone, the remaining one is the last and may close.
        other.close().await.unwrap();
    }

    #[tokio::test]
    async fn shared_with_pool_direct_access() {
        let (_t, shared) = opened_shared().await;
        let observed = shared
            .with_pool(|pool| Box::pin(async move { pool.state() }))
            .await;
        assert_eq!(observed, RunspacePoolState::Opened);
        shared.close().await.unwrap();
    }

    #[tokio::test]
    async fn shared_debug_format_includes_strong_count() {
        let (_t, shared) = opened_shared().await;
        let rendered = format!("{shared:?}");
        for needle in ["SharedRunspacePool", "strong_count"] {
            assert!(rendered.contains(needle));
        }
        shared.close().await.unwrap();
    }

    #[tokio::test]
    async fn shared_run_pipeline_with_builder() {
        let (t, shared) = opened_shared().await;
        queue_output_then_completed(&t, "<S>ok</S>");

        let pipeline = crate::pipeline::Pipeline::new("dummy");
        let result = shared.run_pipeline(pipeline).await.unwrap();

        assert_eq!(result.output, vec![PsValue::String("ok".into())]);
        shared.close().await.unwrap();
    }

    #[tokio::test]
    async fn shared_run_script_with_cancel_token() {
        let (t, shared) = opened_shared().await;
        queue_output_then_completed(&t, "<I32>7</I32>");

        let token = tokio_util::sync::CancellationToken::new();
        let out = shared.run_script_with_cancel("x", token).await.unwrap();

        assert_eq!(out, vec![PsValue::I32(7)]);
        shared.close().await.unwrap();
    }

    #[tokio::test]
    async fn shared_request_session_key_delegates_and_fails() {
        let (t, shared) = opened_shared().await;

        // The queued server reply carries the bogus key "deadbeef"; the
        // request is expected to fail with a protocol error.
        let key_object =
            PsObject::new().with("EncryptedSessionKey", PsValue::String("deadbeef".into()));
        let reply = PsrpMessage {
            destination: Destination::Client,
            message_type: MessageType::EncryptedSessionKey,
            rpid: Uuid::nil(),
            pid: Uuid::nil(),
            data: to_clixml(&PsValue::Object(key_object)),
        }
        .encode();
        t.push_incoming(encode_message(9, &reply));

        let err = shared.request_session_key().await.unwrap_err();
        assert!(matches!(err, crate::error::PsrpError::Protocol(_)));
        shared.close().await.unwrap();
    }

    #[tokio::test]
    async fn shared_handle_count_scales() {
        let (_t, shared) = opened_shared().await;
        assert_eq!(shared.handle_count(), 1);

        let extra: Vec<_> = (0..2).map(|_| shared.clone()).collect();
        assert_eq!(shared.handle_count(), 3);

        drop(extra);
        assert_eq!(shared.handle_count(), 1);
        shared.close().await.unwrap();
    }
}