use std::sync::Arc;
use tokio::sync::Mutex;
/// Common constructor for the shared-ownership wrappers in this module.
///
/// Every implementor must be `Clone + Send + Sync + 'static`, so its handles
/// can be duplicated and moved into spawned tasks.
pub trait Shareable<T>
where
    Self: Clone + Send + Sync + 'static,
{
    /// Wraps `inner` in a brand-new shared container and returns the first handle to it.
    fn new(inner: T) -> Self;
}
/// A clonable handle to a `T` guarded by an async-aware [`Mutex`].
///
/// All clones point at the same allocation, so a change made through one
/// handle is observed by every other handle.
#[derive(Debug)]
pub struct Shared<T> {
    /// The single shared allocation; cloning a handle only bumps this `Arc`'s count.
    inner: Arc<Mutex<T>>,
}
impl<T> Shared<T>
where
    T: Send + 'static,
{
    /// Runs `f` with a shared reference to the wrapped value and returns its result.
    ///
    /// Waits asynchronously until the lock is available. The lock is held only
    /// while `f` runs and is released before this future completes.
    pub async fn with<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&T) -> R + Send,
    {
        let guard = self.inner.lock().await;
        f(&*guard)
    }

    /// Runs `f` with an exclusive reference to the wrapped value and returns its result.
    ///
    /// Waits asynchronously until the lock is available. Because every clone
    /// shares the same allocation, mutations are visible through all handles.
    pub async fn with_mut<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut T) -> R + Send,
    {
        let mut guard = self.inner.lock().await;
        f(&mut *guard)
    }

    /// Calls `f` with a shared reference and awaits the future it returns,
    /// keeping the lock held until that future completes.
    ///
    /// `Fut` is a single type that cannot depend on the lifetime of the `&T`
    /// handed to `f`, so the returned future cannot borrow from it. Read or
    /// clone what you need inside `f` and move it into the future.
    ///
    /// The future must not try to lock this same value again through this
    /// handle or any clone. tokio's `Mutex` is not reentrant, so that would
    /// never complete.
    pub async fn with_async<F, Fut, R>(&self, f: F) -> R
    where
        F: FnOnce(&T) -> Fut + Send,
        Fut: std::future::Future<Output = R> + Send,
    {
        let guard = self.inner.lock().await;
        f(&*guard).await
    }

    /// Exclusive-access counterpart of [`Shared::with_async`].
    ///
    /// The same caveats apply: the returned future cannot borrow the
    /// `&mut T`, and it must not re-lock this value.
    pub async fn with_mut_async<F, Fut, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut T) -> Fut + Send,
        Fut: std::future::Future<Output = R> + Send,
    {
        let mut guard = self.inner.lock().await;
        f(&mut *guard).await
    }

    /// Non-blocking variant of [`Shared::with`].
    ///
    /// Returns `None` without waiting if the lock is currently held elsewhere.
    /// `f` runs synchronously on the caller's stack and is never stored in a
    /// future, so it is not required to be `Send`.
    pub fn try_with<F, R>(&self, f: F) -> Option<R>
    where
        F: FnOnce(&T) -> R,
    {
        let guard = self.inner.try_lock().ok()?;
        Some(f(&*guard))
    }

    /// Non-blocking variant of [`Shared::with_mut`].
    ///
    /// Returns `None` without waiting if the lock is currently held elsewhere.
    /// As with [`Shared::try_with`], `f` runs synchronously and need not be `Send`.
    pub fn try_with_mut<F, R>(&self, f: F) -> Option<R>
    where
        F: FnOnce(&mut T) -> R,
    {
        let mut guard = self.inner.try_lock().ok()?;
        Some(f(&mut *guard))
    }
}
impl<T> Shareable<T> for Shared<T>
where
T: Send + 'static,
{
fn new(inner: T) -> Self {
Self {
inner: Arc::new(Mutex::new(inner)),
}
}
}
/// Cloning yields another handle to the *same* value: only the `Arc`
/// reference count is bumped, and `T` itself is never cloned.
///
/// This is implemented by hand rather than with `#[derive(Clone)]`, so no
/// bound on `T` (not even `T: Clone`) is required.
impl<T> Clone for Shared<T> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}
/// Like [`Shared`], but the value can be taken out exactly once with
/// [`ConsumableShared::consume`]. Afterwards every handle reports
/// [`SharedError::Consumed`].
#[derive(Debug)]
pub struct ConsumableShared<T> {
    /// `Some` while the value is still available, `None` once it has been consumed.
    inner: Arc<Mutex<Option<T>>>,
}
impl<T> ConsumableShared<T>
where
    T: Send + 'static,
{
    /// Runs `f` against a shared reference to the value, provided it is still present.
    ///
    /// # Errors
    ///
    /// Returns [`SharedError::Consumed`] if some handle has already taken the
    /// value via [`ConsumableShared::consume`]. `f` is not called in that case.
    pub async fn with<F, R>(&self, f: F) -> Result<R, SharedError>
    where
        F: FnOnce(&T) -> R + Send,
    {
        let slot = self.inner.lock().await;
        slot.as_ref().map(f).ok_or(SharedError::Consumed)
    }

    /// Runs `f` against an exclusive reference to the value, provided it is still present.
    ///
    /// # Errors
    ///
    /// Returns [`SharedError::Consumed`] once the value has been consumed.
    /// `f` is not called in that case.
    pub async fn with_mut<F, R>(&self, f: F) -> Result<R, SharedError>
    where
        F: FnOnce(&mut T) -> R + Send,
    {
        let mut slot = self.inner.lock().await;
        slot.as_mut().map(f).ok_or(SharedError::Consumed)
    }

    /// Calls `f` with a shared reference and awaits its future while the lock is held.
    ///
    /// The returned future cannot borrow the `&T` argument, because `Fut` is
    /// fixed independently of that borrow.
    ///
    /// # Errors
    ///
    /// Returns [`SharedError::Consumed`] once the value has been consumed.
    pub async fn with_async<F, Fut, R>(&self, f: F) -> Result<R, SharedError>
    where
        F: FnOnce(&T) -> Fut + Send,
        Fut: std::future::Future<Output = R> + Send,
    {
        let slot = self.inner.lock().await;
        let value = match slot.as_ref() {
            Some(value) => value,
            None => return Err(SharedError::Consumed),
        };
        Ok(f(value).await)
    }

    /// Calls `f` with an exclusive reference and awaits its future while the lock is held.
    ///
    /// # Errors
    ///
    /// Returns [`SharedError::Consumed`] once the value has been consumed.
    pub async fn with_mut_async<F, Fut, R>(&self, f: F) -> Result<R, SharedError>
    where
        F: FnOnce(&mut T) -> Fut + Send,
        Fut: std::future::Future<Output = R> + Send,
    {
        let mut slot = self.inner.lock().await;
        let value = match slot.as_mut() {
            Some(value) => value,
            None => return Err(SharedError::Consumed),
        };
        Ok(f(value).await)
    }

    /// Takes the value out, leaving the shared slot empty for every handle, including other clones.
    ///
    /// # Errors
    ///
    /// Returns [`SharedError::Consumed`] if the value was already taken.
    pub async fn consume(self) -> Result<T, SharedError> {
        let taken = self.inner.lock().await.take();
        taken.ok_or(SharedError::Consumed)
    }

    /// Reports whether the value is still present (i.e. has not been consumed).
    pub async fn is_available(&self) -> bool {
        let slot = self.inner.lock().await;
        slot.is_some()
    }
}
impl<T> Shareable<T> for ConsumableShared<T>
where
T: Send + 'static,
{
fn new(inner: T) -> Self {
Self {
inner: Arc::new(Mutex::new(Some(inner))),
}
}
}
/// Cloning yields another handle to the *same* slot: only the `Arc` reference
/// count is bumped. Consuming through any handle empties the slot for all of them.
///
/// This is implemented by hand rather than with `#[derive(Clone)]`, so no
/// bound on `T` (not even `T: Clone`) is required.
impl<T> Clone for ConsumableShared<T> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum SharedError {
#[error("The shared value has been consumed")]
Consumed,
}
impl From<SharedError> for crate::McpError {
    /// Maps a [`SharedError`] onto the crate-wide error type as an
    /// invalid-params error tagged with the `"shared_wrapper"` component.
    fn from(err: SharedError) -> Self {
        let message = match err {
            SharedError::Consumed => "Shared value has already been consumed",
        };
        crate::McpError::invalid_params(message).with_component("shared_wrapper")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal non-`Clone` payload used to observe mutations through the wrappers.
    #[derive(Debug)]
    struct TestCounter {
        value: u64,
    }

    impl TestCounter {
        fn new() -> Self {
            Self { value: 0 }
        }

        fn increment(&mut self) {
            self.value += 1;
        }

        fn get(&self) -> u64 {
            self.value
        }
    }

    #[tokio::test]
    async fn test_shared_basic_operations() {
        let counter = TestCounter::new();
        let shared = Shared::new(counter);
        let value = shared.with(|c| c.get()).await;
        assert_eq!(value, 0);
        shared.with_mut(|c| c.increment()).await;
        let value = shared.with(|c| c.get()).await;
        assert_eq!(value, 1);
    }

    /// Exercises `with_async` / `with_mut_async`. The futures returned by the
    /// closures cannot borrow their argument, so values are read out eagerly
    /// and wrapped in `std::future::ready`.
    #[tokio::test]
    async fn test_shared_async_operations() {
        let shared = Shared::new(TestCounter::new());
        shared
            .with_mut_async(|c| {
                c.increment();
                std::future::ready(())
            })
            .await;
        let value = shared.with_async(|c| std::future::ready(c.get())).await;
        assert_eq!(value, 1);
    }

    #[tokio::test]
    async fn test_shared_cloning() {
        let counter = TestCounter::new();
        let shared = Shared::new(counter);
        let clones: Vec<_> = (0..10).map(|_| shared.clone()).collect();
        assert_eq!(clones.len(), 10);
        for (i, shared_clone) in clones.into_iter().enumerate() {
            shared_clone.with_mut(|c| c.increment()).await;
            let value = shared_clone.with(|c| c.get()).await;
            assert_eq!(value, i as u64 + 1);
        }
    }

    #[tokio::test]
    async fn test_shared_concurrent_access() {
        let counter = TestCounter::new();
        let shared = Shared::new(counter);
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let shared_clone = shared.clone();
                tokio::spawn(async move {
                    shared_clone.with_mut(|c| c.increment()).await;
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }
        let value = shared.with(|c| c.get()).await;
        assert_eq!(value, 10);
    }

    #[tokio::test]
    async fn test_consumable_shared() {
        let counter = TestCounter::new();
        let shared = ConsumableShared::new(counter);
        let shared_clone = shared.clone();
        assert!(shared.is_available().await);
        let value = shared.with(|c| c.get()).await.unwrap();
        assert_eq!(value, 0);
        shared.with_mut(|c| c.increment()).await.unwrap();
        let value = shared.with(|c| c.get()).await.unwrap();
        assert_eq!(value, 1);
        let counter = shared.consume().await.unwrap();
        assert_eq!(counter.get(), 1);
        assert!(!shared_clone.is_available().await);
        assert!(matches!(
            shared_clone.with(|c| c.get()).await,
            Err(SharedError::Consumed)
        ));
    }

    #[tokio::test]
    async fn test_consumable_shared_cloning() {
        let counter = TestCounter::new();
        let shared = ConsumableShared::new(counter);
        let shared_clone = shared.clone();
        assert!(shared.is_available().await);
        assert!(shared_clone.is_available().await);
        let _counter = shared.consume().await.unwrap();
        assert!(!shared_clone.is_available().await);
        assert!(matches!(
            shared_clone.with(|c| c.get()).await,
            Err(SharedError::Consumed)
        ));
    }

    /// Async accessors work before consumption; after it, every accessor
    /// (including a second `consume`) reports `Consumed`.
    #[tokio::test]
    async fn test_consumable_shared_async_and_double_consume() {
        let shared = ConsumableShared::new(TestCounter::new());
        let shared_clone = shared.clone();
        shared
            .with_mut_async(|c| {
                c.increment();
                std::future::ready(())
            })
            .await
            .unwrap();
        let value = shared
            .with_async(|c| std::future::ready(c.get()))
            .await
            .unwrap();
        assert_eq!(value, 1);
        let counter = shared.consume().await.unwrap();
        assert_eq!(counter.get(), 1);
        assert!(matches!(
            shared_clone.with_mut(|c| c.increment()).await,
            Err(SharedError::Consumed)
        ));
        assert!(matches!(
            shared_clone.consume().await,
            Err(SharedError::Consumed)
        ));
    }

    #[tokio::test]
    async fn test_try_operations() {
        let counter = TestCounter::new();
        let shared = Shared::new(counter);
        let value = shared.try_with(|c| c.get()).unwrap();
        assert_eq!(value, 0);
        shared.try_with_mut(|c| c.increment()).unwrap();
        let value = shared.try_with(|c| c.get()).unwrap();
        assert_eq!(value, 1);
    }

    /// While `with` holds the lock, the non-blocking accessors must return
    /// `None` instead of waiting, and must not touch the value.
    #[tokio::test]
    async fn test_try_operations_while_locked() {
        let shared = Shared::new(TestCounter::new());
        let (read, write) = shared
            .with(|_| {
                (
                    shared.try_with(|c| c.get()),
                    shared.try_with_mut(|c| c.increment()),
                )
            })
            .await;
        assert!(read.is_none());
        assert!(write.is_none());
        assert_eq!(shared.with(|c| c.get()).await, 0);
    }
}