1use std::{collections::HashMap, marker::PhantomData, sync::Arc};
2#[cfg(not(feature = "tokio"))]
3use std::{
4 sync::{Condvar, Mutex},
5 time::Duration,
6};
7
8#[cfg(feature = "tokio")]
9use tokio::sync::{Mutex, Notify};
10
11use crate::{
12 client::Client,
13 error::RpcError,
14 schema::{DecodeUntagged, ProcedureCall, ProcedureResult},
15 services::krpc::KRPC,
16 RpcType,
17};
18
19pub struct Stream<T: RpcType + Send> {
39 pub(crate) id: u64,
40 krpc: KRPC,
41 client: Arc<Client>,
42 phantom: PhantomData<T>,
43}
44
45#[cfg(not(feature = "tokio"))]
46type StreamEntry = Arc<(Mutex<ProcedureResult>, Condvar)>;
47#[cfg(feature = "tokio")]
48type StreamEntry = Arc<(Mutex<ProcedureResult>, Notify)>;
49#[derive(Default)]
50pub(crate) struct StreamWrangler {
51 streams: Mutex<HashMap<u64, StreamEntry>>,
52 #[cfg(feature = "tokio")]
53 refcounts: std::sync::Mutex<HashMap<u64, u32>>,
54}
55
56impl StreamWrangler {
57 #[cfg(feature = "tokio")]
58 pub fn increment_refcount(&self, id: u64) -> u32 {
59 let mut guard = self.refcounts.lock().unwrap();
60 let entry = guard.entry(id).or_insert(0);
61 *entry += 1;
62 *entry
63 }
64
65 #[cfg(feature = "tokio")]
66 pub fn decrement_refcount(&self, id: u64) -> u32 {
67 let mut guard = self.refcounts.lock().unwrap();
68 let Some(entry) = guard.get_mut(&id) else {
69 return 0;
70 };
71 *entry -= 1;
72
73 let result = *entry;
74 if result == 0 {
75 guard.remove(&id);
76 }
77
78 result
79 }
80
81 #[cfg(not(feature = "tokio"))]
82 pub fn insert(
83 &self,
84 id: u64,
85 procedure_result: ProcedureResult,
86 ) -> Result<(), RpcError> {
87 let mut map = self.streams.lock().unwrap();
88 let (lock, cvar) = { &*map.entry(id).or_default().clone() };
89
90 *lock.lock().unwrap() = procedure_result;
91 cvar.notify_one();
92
93 Ok(())
94 }
95
96 #[cfg(feature = "tokio")]
97 pub async fn insert(
98 &self,
99 id: u64,
100 procedure_result: ProcedureResult,
101 ) -> Result<(), RpcError> {
102 let mut map = self.streams.lock().await;
103 let (lock, cvar) =
104 { &*map.entry(id).or_insert_with(Default::default).clone() };
105
106 *lock.lock().await = procedure_result;
107 cvar.notify_one();
108
109 Ok(())
110 }
111
112 #[cfg(not(feature = "tokio"))]
113 pub fn wait(&self, id: u64) {
114 let (lock, cvar) = {
115 let mut map = self.streams.lock().unwrap();
116 &*map.entry(id).or_default().clone()
117 };
118 let result = lock.lock().unwrap();
119 let _result = cvar.wait(result).unwrap();
120 }
121
122 #[cfg(not(feature = "tokio"))]
123 pub fn wait_timeout(&self, id: u64, dur: Duration) {
124 let (lock, cvar) = {
125 let mut map = self.streams.lock().unwrap();
126 &*map.entry(id).or_default().clone()
127 };
128 let result = lock.lock().unwrap();
129 let _result = cvar.wait_timeout(result, dur).unwrap();
130 }
131
132 #[cfg(feature = "tokio")]
133 pub async fn wait(&self, id: u64) {
134 let (_lock, cvar) = {
135 let mut map = self.streams.lock().await;
136 &*map.entry(id).or_insert_with(Default::default).clone()
137 };
138 cvar.notified().await;
139 }
140
141 #[cfg(not(feature = "tokio"))]
142 pub fn remove(&self, id: u64) {
143 let mut map = self.streams.lock().unwrap();
144 map.remove(&id);
145 }
146
147 #[cfg(feature = "tokio")]
148 pub async fn remove(&self, id: u64) {
149 let mut map = self.streams.lock().await;
150 map.remove(&id);
151 }
152
153 #[cfg(not(feature = "tokio"))]
154 pub fn get<T: DecodeUntagged>(
155 &self,
156 client: Arc<Client>,
157 id: u64,
158 ) -> Result<T, RpcError> {
159 let mut map = self.streams.lock().unwrap();
160 let (lock, _) = { &*map.entry(id).or_default().clone() };
161 let result = lock.lock().unwrap();
162 T::decode_untagged(client, &result.value)
163 }
164
165 #[cfg(feature = "tokio")]
166 pub async fn get<T: DecodeUntagged>(
167 &self,
168 client: Arc<Client>,
169 id: u64,
170 ) -> Result<T, RpcError> {
171 let mut map = self.streams.lock().await;
172 let (lock, _) =
173 { &*map.entry(id).or_insert_with(Default::default).clone() };
174 let result = lock.lock().await;
175 T::decode_untagged(client, &result.value)
176 }
177}
178
179impl<T: RpcType + Send> Stream<T> {
180 #[cfg(not(feature = "tokio"))]
181 pub(crate) fn new(
182 client: Arc<Client>,
183 call: ProcedureCall,
184 ) -> Result<Self, RpcError> {
185 let krpc = KRPC::new(client.clone());
186 let stream = krpc.add_stream(call, true)?;
187 client.await_stream(stream.id);
188
189 Ok(Self {
190 id: stream.id,
191 krpc,
192 client,
193 phantom: PhantomData,
194 })
195 }
196
197 #[cfg(feature = "tokio")]
198 pub(crate) async fn new(
199 client: Arc<Client>,
200 call: ProcedureCall,
201 ) -> Result<Self, RpcError> {
202 let krpc = KRPC::new(client.clone());
203 let stream = krpc.add_stream(call, true).await?;
204 client.register_stream(stream.id);
205 client.await_stream(stream.id).await;
206
207 Ok(Self {
208 id: stream.id,
209 krpc,
210 client,
211 phantom: PhantomData,
212 })
213 }
214
215 #[cfg(not(feature = "tokio"))]
217 pub fn set_rate(&self, hz: f32) -> Result<(), RpcError> {
218 self.krpc.set_stream_rate(self.id, hz)
219 }
220
221 #[cfg(feature = "tokio")]
223 pub async fn set_rate(&self, hz: f32) -> Result<(), RpcError> {
224 self.krpc.set_stream_rate(self.id, hz).await
225 }
226
227 #[cfg(not(feature = "tokio"))]
235 pub fn get(&self) -> Result<T, RpcError> {
236 self.client.read_stream(self.id)
237 }
238
239 #[cfg(feature = "tokio")]
247 pub async fn get(&self) -> Result<T, RpcError> {
248 self.client.read_stream(self.id).await
249 }
250
251 #[cfg(not(feature = "tokio"))]
254 pub fn wait(&self) {
255 self.client.await_stream(self.id);
256 }
257
258 #[cfg(not(feature = "tokio"))]
262 pub fn wait_timeout(&self, dur: Duration) {
263 self.client.await_stream_timeout(self.id, dur);
264 }
265
266 #[cfg(feature = "tokio")]
269 pub async fn wait(&self) {
270 self.client.await_stream(self.id).await;
271 }
272}
273
274impl<T: RpcType + Send> Drop for Stream<T> {
275 #[cfg(not(feature = "tokio"))]
278 fn drop(&mut self) {
279 self.krpc.remove_stream(self.id).ok();
280 self.client.remove_stream(self.id).ok();
281 }
282
283 #[cfg(feature = "tokio")]
284 fn drop(&mut self) {
285 let krpc = self.krpc.clone();
286 let client = self.client.clone();
287 let id = self.id;
288 let refcount = client.release_stream(id);
289 if refcount == 0 {
290 tokio::task::spawn(async move {
291 krpc.remove_stream(id).await.ok();
292 client.remove_stream(id).await.ok();
293 });
294 }
295 }
296}