// rmp_ipc/ipc/context.rs

1use crate::error::{Error, Result};
2use crate::event::Event;
3use crate::ipc::stream_emitter::StreamEmitter;
4use std::collections::HashMap;
5use std::mem;
6use std::ops::{Deref, DerefMut};
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::Arc;
9use tokio::sync::oneshot::Sender;
10use tokio::sync::{oneshot, Mutex, RwLock};
11use typemap_rev::TypeMap;
12
13pub(crate) type ReplyListeners = Arc<Mutex<HashMap<u64, oneshot::Sender<Event>>>>;
14
15/// An object provided to each callback function.
16/// Currently it only holds the event emitter to emit response events in event callbacks.
17/// ```rust
18/// use rmp_ipc::prelude::*;
19///
20/// async fn my_callback(ctx: &Context, _event: Event) -> IPCResult<()> {
21///     // use the emitter on the context object to emit events
22///     // inside callbacks
23///     ctx.emitter.emit("ping", ()).await?;
24///     Ok(())
25/// }
26/// ```
27#[derive(Clone)]
28pub struct Context {
29    /// The event emitter
30    pub emitter: StreamEmitter,
31
32    /// Field to store additional context data
33    pub data: Arc<RwLock<TypeMap>>,
34
35    stop_sender: Arc<Mutex<Option<Sender<()>>>>,
36
37    reply_listeners: ReplyListeners,
38}
39
40impl Context {
41    pub(crate) fn new(
42        emitter: StreamEmitter,
43        data: Arc<RwLock<TypeMap>>,
44        stop_sender: Option<Sender<()>>,
45        reply_listeners: ReplyListeners,
46    ) -> Self {
47        Self {
48            emitter,
49            reply_listeners,
50            data,
51            stop_sender: Arc::new(Mutex::new(stop_sender)),
52        }
53    }
54
55    /// Waits for a reply to the given message ID
56    #[tracing::instrument(level = "debug", skip(self))]
57    pub async fn await_reply(&self, message_id: u64) -> Result<Event> {
58        let (rx, tx) = oneshot::channel();
59        {
60            let mut listeners = self.reply_listeners.lock().await;
61            listeners.insert(message_id, rx);
62        }
63        let event = tx.await?;
64
65        Ok(event)
66    }
67
68    /// Stops the listener and closes the connection
69    #[tracing::instrument(level = "debug", skip(self))]
70    pub async fn stop(self) -> Result<()> {
71        let mut sender = self.stop_sender.lock().await;
72        if let Some(sender) = mem::take(&mut *sender) {
73            sender.send(()).map_err(|_| Error::SendError)?;
74        }
75
76        Ok(())
77    }
78
79    /// Returns the channel for a reply to the given message id
80    pub(crate) async fn get_reply_sender(&self, ref_id: u64) -> Option<oneshot::Sender<Event>> {
81        let mut listeners = self.reply_listeners.lock().await;
82        listeners.remove(&ref_id)
83    }
84}
85
86#[derive(Clone)]
87pub struct PooledContext {
88    contexts: Vec<PoolGuard<Context>>,
89}
90
/// A handle to a pooled value that tracks how many handles are in use.
///
/// It dereferences to its own copy of the value. The usage counter is shared by
/// every clone of a guard:
/// - it starts at zero when the guard is created;
/// - each clone increments it;
/// - each drop decrements it.
pub struct PoolGuard<T: Clone> {
    /// This guard's copy of the pooled value.
    inner: T,
    /// Usage counter shared by every clone of this guard.
    count: Arc<AtomicUsize>,
}
98
99impl<T> Deref for PoolGuard<T>
100where
101    T: Clone,
102{
103    type Target = T;
104
105    fn deref(&self) -> &Self::Target {
106        &self.inner
107    }
108}
109
110impl<T> DerefMut for PoolGuard<T>
111where
112    T: Clone,
113{
114    fn deref_mut(&mut self) -> &mut Self::Target {
115        &mut self.inner
116    }
117}
118
119impl<T> Clone for PoolGuard<T>
120where
121    T: Clone,
122{
123    fn clone(&self) -> Self {
124        self.acquire();
125
126        Self {
127            inner: self.inner.clone(),
128            count: Arc::clone(&self.count),
129        }
130    }
131}
132
133impl<T> Drop for PoolGuard<T>
134where
135    T: Clone,
136{
137    fn drop(&mut self) {
138        self.release();
139    }
140}
141
142impl<T> PoolGuard<T>
143where
144    T: Clone,
145{
146    pub(crate) fn new(inner: T) -> Self {
147        Self {
148            inner,
149            count: Arc::new(AtomicUsize::new(0)),
150        }
151    }
152
153    /// Acquires the context by adding 1 to the count
154    #[tracing::instrument(level = "trace", skip_all)]
155    pub(crate) fn acquire(&self) {
156        let count = self.count.fetch_add(1, Ordering::Relaxed);
157        tracing::trace!(count);
158    }
159
160    /// Releases the connection by subtracting from the stored count
161    #[tracing::instrument(level = "trace", skip_all)]
162    pub(crate) fn release(&self) {
163        let count = self.count.fetch_sub(1, Ordering::Relaxed);
164        tracing::trace!(count);
165    }
166
167    pub(crate) fn count(&self) -> usize {
168        self.count.load(Ordering::Relaxed)
169    }
170}
171
172impl PooledContext {
173    /// Creates a new pooled context from a list of contexts
174    pub(crate) fn new(contexts: Vec<Context>) -> Self {
175        Self {
176            contexts: contexts.into_iter().map(PoolGuard::new).collect(),
177        }
178    }
179
180    /// Acquires a context from the pool
181    /// It always chooses the one that is used the least
182    #[tracing::instrument(level = "trace", skip_all)]
183    pub fn acquire(&self) -> PoolGuard<Context> {
184        self.contexts
185            .iter()
186            .min_by_key(|c| c.count())
187            .unwrap()
188            .clone()
189    }
190}