krpc_client/
stream.rs

1use std::{collections::HashMap, marker::PhantomData, sync::Arc};
2#[cfg(not(feature = "tokio"))]
3use std::{
4    sync::{Condvar, Mutex},
5    time::Duration,
6};
7
8#[cfg(feature = "tokio")]
9use tokio::sync::{Mutex, Notify};
10
11use crate::{
12    client::Client,
13    error::RpcError,
14    schema::{DecodeUntagged, ProcedureCall, ProcedureResult},
15    services::krpc::KRPC,
16    RpcType,
17};
18
19/// A streaming procedure call.
20///
21/// `Stream<T>` is created by calling any procedure with the
22/// `_stream()` suffix. This will start the stream
23/// automatically.
24///
25/// This type provides access to the procedure's
26/// results of type `T` via [`get`][get]. Results are pushed
27/// by the server at the rate selected by
28/// [`set_rate`][set_rate]. And consumers may block until a
29/// stream's value has changed with [`wait`][wait].
30///
31/// The stream will attempt to remove itself when dropped.
32/// Otherwise, the server will remove remaining streams when
33/// the client disconnects.
34///
35/// [wait]: Stream::wait
36/// [set_rate]: Stream::set_rate
37/// [get]: Stream::get
38pub struct Stream<T: RpcType + Send> {
39    pub(crate) id: u64,
40    krpc: KRPC,
41    client: Arc<Client>,
42    phantom: PhantomData<T>,
43}
44
45#[cfg(not(feature = "tokio"))]
46type StreamEntry = Arc<(Mutex<ProcedureResult>, Condvar)>;
47#[cfg(feature = "tokio")]
48type StreamEntry = Arc<(Mutex<ProcedureResult>, Notify)>;
49#[derive(Default)]
50pub(crate) struct StreamWrangler {
51    streams: Mutex<HashMap<u64, StreamEntry>>,
52    #[cfg(feature = "tokio")]
53    refcounts: std::sync::Mutex<HashMap<u64, u32>>,
54}
55
56impl StreamWrangler {
57    #[cfg(feature = "tokio")]
58    pub fn increment_refcount(&self, id: u64) -> u32 {
59        let mut guard = self.refcounts.lock().unwrap();
60        let entry = guard.entry(id).or_insert(0);
61        *entry += 1;
62        *entry
63    }
64
65    #[cfg(feature = "tokio")]
66    pub fn decrement_refcount(&self, id: u64) -> u32 {
67        let mut guard = self.refcounts.lock().unwrap();
68        let Some(entry) = guard.get_mut(&id) else {
69            return 0;
70        };
71        *entry -= 1;
72
73        let result = *entry;
74        if result == 0 {
75            guard.remove(&id);
76        }
77
78        result
79    }
80
81    #[cfg(not(feature = "tokio"))]
82    pub fn insert(
83        &self,
84        id: u64,
85        procedure_result: ProcedureResult,
86    ) -> Result<(), RpcError> {
87        let mut map = self.streams.lock().unwrap();
88        let (lock, cvar) = { &*map.entry(id).or_default().clone() };
89
90        *lock.lock().unwrap() = procedure_result;
91        cvar.notify_one();
92
93        Ok(())
94    }
95
96    #[cfg(feature = "tokio")]
97    pub async fn insert(
98        &self,
99        id: u64,
100        procedure_result: ProcedureResult,
101    ) -> Result<(), RpcError> {
102        let mut map = self.streams.lock().await;
103        let (lock, cvar) =
104            { &*map.entry(id).or_insert_with(Default::default).clone() };
105
106        *lock.lock().await = procedure_result;
107        cvar.notify_one();
108
109        Ok(())
110    }
111
112    #[cfg(not(feature = "tokio"))]
113    pub fn wait(&self, id: u64) {
114        let (lock, cvar) = {
115            let mut map = self.streams.lock().unwrap();
116            &*map.entry(id).or_default().clone()
117        };
118        let result = lock.lock().unwrap();
119        let _result = cvar.wait(result).unwrap();
120    }
121
122    #[cfg(not(feature = "tokio"))]
123    pub fn wait_timeout(&self, id: u64, dur: Duration) {
124        let (lock, cvar) = {
125            let mut map = self.streams.lock().unwrap();
126            &*map.entry(id).or_default().clone()
127        };
128        let result = lock.lock().unwrap();
129        let _result = cvar.wait_timeout(result, dur).unwrap();
130    }
131
132    #[cfg(feature = "tokio")]
133    pub async fn wait(&self, id: u64) {
134        let (_lock, cvar) = {
135            let mut map = self.streams.lock().await;
136            &*map.entry(id).or_insert_with(Default::default).clone()
137        };
138        cvar.notified().await;
139    }
140
141    #[cfg(not(feature = "tokio"))]
142    pub fn remove(&self, id: u64) {
143        let mut map = self.streams.lock().unwrap();
144        map.remove(&id);
145    }
146
147    #[cfg(feature = "tokio")]
148    pub async fn remove(&self, id: u64) {
149        let mut map = self.streams.lock().await;
150        map.remove(&id);
151    }
152
153    #[cfg(not(feature = "tokio"))]
154    pub fn get<T: DecodeUntagged>(
155        &self,
156        client: Arc<Client>,
157        id: u64,
158    ) -> Result<T, RpcError> {
159        let mut map = self.streams.lock().unwrap();
160        let (lock, _) = { &*map.entry(id).or_default().clone() };
161        let result = lock.lock().unwrap();
162        T::decode_untagged(client, &result.value)
163    }
164
165    #[cfg(feature = "tokio")]
166    pub async fn get<T: DecodeUntagged>(
167        &self,
168        client: Arc<Client>,
169        id: u64,
170    ) -> Result<T, RpcError> {
171        let mut map = self.streams.lock().await;
172        let (lock, _) =
173            { &*map.entry(id).or_insert_with(Default::default).clone() };
174        let result = lock.lock().await;
175        T::decode_untagged(client, &result.value)
176    }
177}
178
179impl<T: RpcType + Send> Stream<T> {
180    #[cfg(not(feature = "tokio"))]
181    pub(crate) fn new(
182        client: Arc<Client>,
183        call: ProcedureCall,
184    ) -> Result<Self, RpcError> {
185        let krpc = KRPC::new(client.clone());
186        let stream = krpc.add_stream(call, true)?;
187        client.await_stream(stream.id);
188
189        Ok(Self {
190            id: stream.id,
191            krpc,
192            client,
193            phantom: PhantomData,
194        })
195    }
196
197    #[cfg(feature = "tokio")]
198    pub(crate) async fn new(
199        client: Arc<Client>,
200        call: ProcedureCall,
201    ) -> Result<Self, RpcError> {
202        let krpc = KRPC::new(client.clone());
203        let stream = krpc.add_stream(call, true).await?;
204        client.register_stream(stream.id);
205        client.await_stream(stream.id).await;
206
207        Ok(Self {
208            id: stream.id,
209            krpc,
210            client,
211            phantom: PhantomData,
212        })
213    }
214
215    /// Set the update rate for this streaming procedure.
216    #[cfg(not(feature = "tokio"))]
217    pub fn set_rate(&self, hz: f32) -> Result<(), RpcError> {
218        self.krpc.set_stream_rate(self.id, hz)
219    }
220
221    /// Set the update rate for this streaming procedure.
222    #[cfg(feature = "tokio")]
223    pub async fn set_rate(&self, hz: f32) -> Result<(), RpcError> {
224        self.krpc.set_stream_rate(self.id, hz).await
225    }
226
227    /// Retrieve the current result received for this
228    /// procedure. This value is not guaranteed to have
229    /// changed since the last call to [`get`][get]. Use
230    /// [`wait`][wait] to block until the value has changed.
231    ///
232    /// [wait]: Stream::wait
233    /// [get]: Stream::get
234    #[cfg(not(feature = "tokio"))]
235    pub fn get(&self) -> Result<T, RpcError> {
236        self.client.read_stream(self.id)
237    }
238
239    /// Retrieve the current result received for this
240    /// procedure. This value is not guaranteed to have
241    /// changed since the last call to [`get`][get]. Use
242    /// [`wait`][wait] to block until the value has changed.
243    ///
244    /// [wait]: Stream::wait
245    /// [get]: Stream::get
246    #[cfg(feature = "tokio")]
247    pub async fn get(&self) -> Result<T, RpcError> {
248        self.client.read_stream(self.id).await
249    }
250
251    /// Block the current thread of execution until this
252    /// stream receives an update from the server.
253    #[cfg(not(feature = "tokio"))]
254    pub fn wait(&self) {
255        self.client.await_stream(self.id);
256    }
257
258    /// Block the current thread of execution until this
259    /// stream receives an update from the server or the
260    /// timeout is reached.
261    #[cfg(not(feature = "tokio"))]
262    pub fn wait_timeout(&self, dur: Duration) {
263        self.client.await_stream_timeout(self.id, dur);
264    }
265
266    /// Block the current thread of execution until this
267    /// stream receives an update from the server.
268    #[cfg(feature = "tokio")]
269    pub async fn wait(&self) {
270        self.client.await_stream(self.id).await;
271    }
272}
273
274impl<T: RpcType + Send> Drop for Stream<T> {
275    // Try to remove the stream if it's dropped, but don't panic
276    // if unable.
277    #[cfg(not(feature = "tokio"))]
278    fn drop(&mut self) {
279        self.krpc.remove_stream(self.id).ok();
280        self.client.remove_stream(self.id).ok();
281    }
282
283    #[cfg(feature = "tokio")]
284    fn drop(&mut self) {
285        let krpc = self.krpc.clone();
286        let client = self.client.clone();
287        let id = self.id;
288        let refcount = client.release_stream(id);
289        if refcount == 0 {
290            tokio::task::spawn(async move {
291                krpc.remove_stream(id).await.ok();
292                client.remove_stream(id).await.ok();
293            });
294        }
295    }
296}