use std::sync::Arc;
use tokio::sync::Mutex;
use crate::clixml::PsValue;
use crate::error::Result;
use crate::pipeline::{Pipeline, PipelineResult};
use crate::runspace::RunspacePool;
use crate::transport::PsrpTransport;
/// Cloneable handle to one [`RunspacePool`] kept behind an async mutex.
///
/// All handles produced by cloning refer to the same pool; each operation
/// locks the mutex for its duration, so callers take turns using the pool.
pub struct SharedRunspacePool<T>
where
    T: PsrpTransport,
{
    /// The shared pool: reference-counted, guarded by a `tokio` mutex.
    inner: Arc<Mutex<RunspacePool<T>>>,
}
// Implemented by hand so that cloning never requires `T: Clone`: only the
// `Arc` is cloned, never the pool or its transport.
impl<T> Clone for SharedRunspacePool<T>
where
    T: PsrpTransport,
{
    /// Returns another handle to the same underlying pool.
    fn clone(&self) -> Self {
        // Bumps the reference count; the pool itself is not duplicated.
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
// Formats without locking the mutex and without requiring `T: Debug`:
// the only thing reported is how many handles currently exist.
impl<T> std::fmt::Debug for SharedRunspacePool<T>
where
    T: PsrpTransport,
{
    /// Writes `SharedRunspacePool { strong_count: <n> }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let handles = Arc::strong_count(&self.inner);
        let mut out = f.debug_struct("SharedRunspacePool");
        out.field("strong_count", &handles);
        out.finish()
    }
}
impl<T: PsrpTransport> SharedRunspacePool<T> {
#[must_use]
pub fn new(pool: RunspacePool<T>) -> Self {
Self {
inner: Arc::new(Mutex::new(pool)),
}
}
#[must_use]
pub fn handle_count(&self) -> usize {
Arc::strong_count(&self.inner)
}
pub async fn run_script(&self, script: &str) -> Result<Vec<PsValue>> {
let mut guard = self.inner.lock().await;
guard.run_script(script).await
}
pub async fn run_pipeline(&self, pipeline: Pipeline) -> Result<PipelineResult> {
let mut guard = self.inner.lock().await;
pipeline.run_all_streams(&mut guard).await
}
pub async fn run_script_with_cancel(
&self,
script: &str,
cancel: tokio_util::sync::CancellationToken,
) -> Result<Vec<PsValue>> {
let mut guard = self.inner.lock().await;
guard.run_script_with_cancel(script, cancel).await
}
pub async fn request_session_key(&self) -> Result<()> {
let mut guard = self.inner.lock().await;
guard.request_session_key().await
}
pub async fn close(self) -> Result<()> {
match Arc::try_unwrap(self.inner) {
Ok(mutex) => mutex.into_inner().close().await,
Err(arc) => Err(crate::error::PsrpError::protocol(format!(
"cannot close SharedRunspacePool: {} handles still outstanding",
Arc::strong_count(&arc)
))),
}
}
pub async fn with_pool<F, R>(&self, f: F) -> R
where
F: for<'a> FnOnce(
&'a mut RunspacePool<T>,
)
-> std::pin::Pin<Box<dyn std::future::Future<Output = R> + Send + 'a>>,
{
let mut guard = self.inner.lock().await;
f(&mut guard).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::clixml::{PsObject, to_clixml};
    use crate::fragment::encode_message;
    use crate::message::{Destination, MessageType, PsrpMessage};
    use crate::pipeline::PipelineState;
    use crate::runspace::RunspacePoolState;
    use crate::transport::mock::MockTransport;
    use uuid::Uuid;

    /// Builds an encoded client-bound `RunspacePoolState` message for `state`.
    fn state_message(state: RunspacePoolState) -> Vec<u8> {
        let payload = PsObject::new().with("RunspaceState", PsValue::I32(state as i32));
        PsrpMessage {
            destination: Destination::Client,
            message_type: MessageType::RunspacePoolState,
            rpid: Uuid::nil(),
            pid: Uuid::nil(),
            data: to_clixml(&PsValue::Object(payload)),
        }
        .encode()
    }

    /// Builds an encoded client-bound `PipelineState` message for `state`.
    fn pipeline_state_message(state: PipelineState) -> Vec<u8> {
        let payload = PsObject::new().with("PipelineState", PsValue::I32(state as i32));
        PsrpMessage {
            destination: Destination::Client,
            message_type: MessageType::PipelineState,
            rpid: Uuid::nil(),
            pid: Uuid::nil(),
            data: to_clixml(&PsValue::Object(payload)),
        }
        .encode()
    }

    /// Queues the canned reply used by the script/pipeline tests: one
    /// `PipelineOutput` message carrying `clixml` (passed to `encode_message`
    /// with 10) followed by a `Completed` pipeline state (passed with 11).
    fn queue_completed_pipeline(transport: &MockTransport, clixml: &'static str) {
        let output = PsrpMessage {
            destination: Destination::Client,
            message_type: MessageType::PipelineOutput,
            rpid: Uuid::nil(),
            pid: Uuid::nil(),
            data: clixml.into(),
        }
        .encode();
        transport.push_incoming(encode_message(10, &output));

        let completed = pipeline_state_message(PipelineState::Completed);
        transport.push_incoming(encode_message(11, &completed));
    }

    /// Opens a pool over a mock transport primed with an `Opened` state
    /// message and returns the transport together with the shared wrapper.
    async fn opened_shared() -> (MockTransport, SharedRunspacePool<MockTransport>) {
        let transport = MockTransport::new();
        let opened = state_message(RunspacePoolState::Opened);
        transport.push_incoming(encode_message(1, &opened));
        let pool = RunspacePool::open_with_transport(transport.clone())
            .await
            .unwrap();
        (transport, SharedRunspacePool::new(pool))
    }

    // `run_script` through the wrapper yields the queued output value.
    #[tokio::test]
    async fn shared_run_script_serialises_access() {
        let (transport, shared) = opened_shared().await;
        queue_completed_pipeline(&transport, "<I32>42</I32>");

        let output = shared.run_script("whatever").await.unwrap();

        assert_eq!(output, vec![PsValue::I32(42)]);
        assert_eq!(shared.handle_count(), 1);
        shared.close().await.unwrap();
    }

    // Closing while a clone exists fails; closing the last handle succeeds.
    #[tokio::test]
    async fn shared_close_errors_with_outstanding_clones() {
        let (_transport, shared) = opened_shared().await;
        let survivor = shared.clone();
        assert_eq!(shared.handle_count(), 2);

        let failure = shared.close().await.unwrap_err();

        assert!(matches!(failure, crate::error::PsrpError::Protocol(_)));
        survivor.close().await.unwrap();
    }

    // `with_pool` hands the closure the underlying pool.
    #[tokio::test]
    async fn shared_with_pool_direct_access() {
        let (_transport, shared) = opened_shared().await;

        let observed = shared
            .with_pool(|pool| Box::pin(async move { pool.state() }))
            .await;

        assert_eq!(observed, RunspacePoolState::Opened);
        shared.close().await.unwrap();
    }

    // The `Debug` output names the type and the `strong_count` field.
    #[tokio::test]
    async fn shared_debug_format_includes_strong_count() {
        let (_transport, shared) = opened_shared().await;

        let rendered = format!("{shared:?}");

        assert!(rendered.contains("SharedRunspacePool"));
        assert!(rendered.contains("strong_count"));
        shared.close().await.unwrap();
    }

    // `run_pipeline` returns the queued output in `PipelineResult::output`.
    #[tokio::test]
    async fn shared_run_pipeline_with_builder() {
        let (transport, shared) = opened_shared().await;
        queue_completed_pipeline(&transport, "<S>ok</S>");

        let outcome = shared
            .run_pipeline(crate::pipeline::Pipeline::new("dummy"))
            .await
            .unwrap();

        assert_eq!(outcome.output, vec![PsValue::String("ok".into())]);
        shared.close().await.unwrap();
    }

    // With a token that is never cancelled the script output is returned.
    #[tokio::test]
    async fn shared_run_script_with_cancel_token() {
        let (transport, shared) = opened_shared().await;
        queue_completed_pipeline(&transport, "<I32>7</I32>");
        let cancel = tokio_util::sync::CancellationToken::new();

        let output = shared.run_script_with_cancel("x", cancel).await.unwrap();

        assert_eq!(output, vec![PsValue::I32(7)]);
        shared.close().await.unwrap();
    }

    // The wrapper forwards `request_session_key`; with this canned
    // `EncryptedSessionKey` reply the call is expected to return a
    // protocol error.
    #[tokio::test]
    async fn shared_request_session_key_delegates_and_fails() {
        let (transport, shared) = opened_shared().await;
        let key_payload =
            PsObject::new().with("EncryptedSessionKey", PsValue::String("deadbeef".into()));
        let key_reply = PsrpMessage {
            destination: Destination::Client,
            message_type: MessageType::EncryptedSessionKey,
            rpid: Uuid::nil(),
            pid: Uuid::nil(),
            data: to_clixml(&PsValue::Object(key_payload)),
        }
        .encode();
        transport.push_incoming(encode_message(9, &key_reply));

        let failure = shared.request_session_key().await.unwrap_err();

        assert!(matches!(failure, crate::error::PsrpError::Protocol(_)));
        shared.close().await.unwrap();
    }

    // `handle_count` follows clones being created and dropped.
    #[tokio::test]
    async fn shared_handle_count_scales() {
        let (_transport, shared) = opened_shared().await;
        assert_eq!(shared.handle_count(), 1);

        let second = shared.clone();
        let third = shared.clone();
        assert_eq!(shared.handle_count(), 3);

        drop(third);
        drop(second);
        assert_eq!(shared.handle_count(), 1);
        shared.close().await.unwrap();
    }
}