use std::{collections::HashMap, marker::PhantomData, sync::Arc};
#[cfg(not(feature = "tokio"))]
use std::{
sync::{Condvar, Mutex},
time::Duration,
};
#[cfg(feature = "tokio")]
use tokio::sync::{Mutex, Notify};
use crate::{
client::Client,
error::RpcError,
schema::{DecodeUntagged, ProcedureCall, ProcedureResult},
services::krpc::KRPC,
RpcType,
};
/// Typed handle to a kRPC stream created through `KRPC::add_stream`.
///
/// Values are read and awaited through the owning [`Client`]. Dropping a handle
/// triggers removal of the stream; with the `tokio` feature, removal happens only
/// once the client reports that no other handle for the same id remains.
pub struct Stream<T: RpcType + Send> {
/// Stream identifier returned by `KRPC::add_stream`.
pub(crate) id: u64,
/// `KRPC` service wrapper used to change the rate of, and remove, this stream.
krpc: KRPC,
/// Shared client used to read values and wait for updates.
client: Arc<Client>,
/// Records the decoded value type `T` without storing a value.
phantom: PhantomData<T>,
}
/// Per-stream slot shared between the updater and readers/waiters: the mutex
/// holds the most recently stored `ProcedureResult`, and the condition variable
/// is signalled whenever a new result is stored.
#[cfg(not(feature = "tokio"))]
type StreamEntry = Arc<(Mutex<ProcedureResult>, Condvar)>;
/// Async variant of the per-stream slot: a `tokio::sync::Mutex` holding the
/// latest result plus a `Notify` used to wake tasks waiting for an update.
#[cfg(feature = "tokio")]
type StreamEntry = Arc<(Mutex<ProcedureResult>, Notify)>;
/// Client-side registry of stream slots, keyed by stream id.
#[derive(Default)]
pub(crate) struct StreamWrangler {
/// Latest result and wake-up primitive for each stream id.
streams: Mutex<HashMap<u64, StreamEntry>>,
/// Reference counts per stream id, maintained by `increment_refcount` /
/// `decrement_refcount`. This is a blocking `std::sync::Mutex` (not the async
/// one): those methods are synchronous and never hold it across an `.await`.
#[cfg(feature = "tokio")]
refcounts: std::sync::Mutex<HashMap<u64, u32>>,
}
impl StreamWrangler {
#[cfg(feature = "tokio")]
pub fn increment_refcount(&self, id: u64) -> u32 {
let mut guard = self.refcounts.lock().unwrap();
let entry = guard.entry(id).or_insert(0);
*entry += 1;
*entry
}
#[cfg(feature = "tokio")]
pub fn decrement_refcount(&self, id: u64) -> u32 {
let mut guard = self.refcounts.lock().unwrap();
let Some(entry) = guard.get_mut(&id) else {
return 0;
};
*entry -= 1;
let result = *entry;
if result == 0 {
guard.remove(&id);
}
result
}
#[cfg(not(feature = "tokio"))]
pub fn insert(
&self,
id: u64,
procedure_result: ProcedureResult,
) -> Result<(), RpcError> {
let mut map = self.streams.lock().unwrap();
let (lock, cvar) = { &*map.entry(id).or_default().clone() };
*lock.lock().unwrap() = procedure_result;
cvar.notify_one();
Ok(())
}
#[cfg(feature = "tokio")]
pub async fn insert(
&self,
id: u64,
procedure_result: ProcedureResult,
) -> Result<(), RpcError> {
let mut map = self.streams.lock().await;
let (lock, cvar) =
{ &*map.entry(id).or_insert_with(Default::default).clone() };
*lock.lock().await = procedure_result;
cvar.notify_one();
Ok(())
}
#[cfg(not(feature = "tokio"))]
pub fn wait(&self, id: u64) {
let (lock, cvar) = {
let mut map = self.streams.lock().unwrap();
&*map.entry(id).or_default().clone()
};
let result = lock.lock().unwrap();
let _result = cvar.wait(result).unwrap();
}
#[cfg(not(feature = "tokio"))]
pub fn wait_timeout(&self, id: u64, dur: Duration) {
let (lock, cvar) = {
let mut map = self.streams.lock().unwrap();
&*map.entry(id).or_default().clone()
};
let result = lock.lock().unwrap();
let _result = cvar.wait_timeout(result, dur).unwrap();
}
#[cfg(feature = "tokio")]
pub async fn wait(&self, id: u64) {
let (_lock, cvar) = {
let mut map = self.streams.lock().await;
&*map.entry(id).or_insert_with(Default::default).clone()
};
cvar.notified().await;
}
#[cfg(not(feature = "tokio"))]
pub fn remove(&self, id: u64) {
let mut map = self.streams.lock().unwrap();
map.remove(&id);
}
#[cfg(feature = "tokio")]
pub async fn remove(&self, id: u64) {
let mut map = self.streams.lock().await;
map.remove(&id);
}
#[cfg(not(feature = "tokio"))]
pub fn get<T: DecodeUntagged>(
&self,
client: Arc<Client>,
id: u64,
) -> Result<T, RpcError> {
let mut map = self.streams.lock().unwrap();
let (lock, _) = { &*map.entry(id).or_default().clone() };
let result = lock.lock().unwrap();
T::decode_untagged(client, &result.value)
}
#[cfg(feature = "tokio")]
pub async fn get<T: DecodeUntagged>(
&self,
client: Arc<Client>,
id: u64,
) -> Result<T, RpcError> {
let mut map = self.streams.lock().await;
let (lock, _) =
{ &*map.entry(id).or_insert_with(Default::default).clone() };
let result = lock.lock().await;
T::decode_untagged(client, &result.value)
}
}
impl<T: RpcType + Send> Stream<T> {
    /// Adds `call` as a stream on the server, then waits on `Client::await_stream`
    /// for the new stream's id before returning the handle.
    ///
    /// # Errors
    /// Returns the [`RpcError`] from `KRPC::add_stream` if the stream cannot be added.
    #[cfg(not(feature = "tokio"))]
    pub(crate) fn new(
        client: Arc<Client>,
        call: ProcedureCall,
    ) -> Result<Self, RpcError> {
        let krpc = KRPC::new(client.clone());
        let stream = krpc.add_stream(call, true)?;
        client.await_stream(stream.id);
        Ok(Self {
            id: stream.id,
            krpc,
            client,
            phantom: PhantomData,
        })
    }

    /// Adds `call` as a stream on the server, registers this handle with the client,
    /// and waits on `Client::await_stream` for the stream's id before returning.
    ///
    /// Cancellation-safe after registration: if this future is dropped while waiting,
    /// the already-built handle is dropped too, which releases the registration.
    ///
    /// # Errors
    /// Returns the [`RpcError`] from `KRPC::add_stream` if the stream cannot be added.
    #[cfg(feature = "tokio")]
    pub(crate) async fn new(
        client: Arc<Client>,
        call: ProcedureCall,
    ) -> Result<Self, RpcError> {
        let krpc = KRPC::new(client.clone());
        let stream = krpc.add_stream(call, true).await?;
        client.register_stream(stream.id);
        // Build the handle before awaiting: if the caller drops this future mid-wait
        // (e.g. under a timeout), `Drop` still pairs `register_stream` with
        // `release_stream` and removes the stream instead of leaking it.
        let this = Self {
            id: stream.id,
            krpc,
            client,
            phantom: PhantomData,
        };
        this.client.await_stream(this.id).await;
        Ok(this)
    }

    /// Changes this stream's update rate to `hz` via `KRPC::set_stream_rate`.
    ///
    /// # Errors
    /// Propagates the [`RpcError`] from the RPC.
    #[cfg(not(feature = "tokio"))]
    pub fn set_rate(&self, hz: f32) -> Result<(), RpcError> {
        self.krpc.set_stream_rate(self.id, hz)
    }

    /// Changes this stream's update rate to `hz` via `KRPC::set_stream_rate`.
    ///
    /// # Errors
    /// Propagates the [`RpcError`] from the RPC.
    #[cfg(feature = "tokio")]
    pub async fn set_rate(&self, hz: f32) -> Result<(), RpcError> {
        self.krpc.set_stream_rate(self.id, hz).await
    }

    /// Reads the latest value of this stream via `Client::read_stream`.
    ///
    /// # Errors
    /// Propagates the [`RpcError`] from reading/decoding the value.
    #[cfg(not(feature = "tokio"))]
    pub fn get(&self) -> Result<T, RpcError> {
        self.client.read_stream(self.id)
    }

    /// Reads the latest value of this stream via `Client::read_stream`.
    ///
    /// # Errors
    /// Propagates the [`RpcError`] from reading/decoding the value.
    #[cfg(feature = "tokio")]
    pub async fn get(&self) -> Result<T, RpcError> {
        self.client.read_stream(self.id).await
    }

    /// Blocks until the client signals an update for this stream.
    #[cfg(not(feature = "tokio"))]
    pub fn wait(&self) {
        self.client.await_stream(self.id);
    }

    /// Like `wait`, but gives up after `dur` has elapsed.
    #[cfg(not(feature = "tokio"))]
    pub fn wait_timeout(&self, dur: Duration) {
        self.client.await_stream_timeout(self.id, dur);
    }

    /// Waits until the client signals an update for this stream.
    #[cfg(feature = "tokio")]
    pub async fn wait(&self) {
        self.client.await_stream(self.id).await;
    }
}
impl<T: RpcType + Send> Drop for Stream<T> {
    /// Best-effort removal of the stream from the server and from the client.
    /// Errors are ignored because `drop` cannot report them.
    #[cfg(not(feature = "tokio"))]
    fn drop(&mut self) {
        self.krpc.remove_stream(self.id).ok();
        self.client.remove_stream(self.id).ok();
    }

    /// Releases this handle's registration. If it was the last one, removal of the
    /// stream is spawned on the current Tokio runtime. Removal is skipped when no
    /// runtime is available, and errors are ignored.
    #[cfg(feature = "tokio")]
    fn drop(&mut self) {
        let id = self.id;
        if self.client.release_stream(id) != 0 {
            // Other handles are still registered for this stream id.
            return;
        }
        // `tokio::task::spawn` panics outside a runtime (e.g. after shutdown or on a
        // plain thread), and panicking in `drop` while unwinding aborts the process.
        let Ok(runtime) = tokio::runtime::Handle::try_current() else {
            return;
        };
        let krpc = self.krpc.clone();
        let client = Arc::clone(&self.client);
        runtime.spawn(async move {
            krpc.remove_stream(id).await.ok();
            client.remove_stream(id).await.ok();
        });
    }
}