// krpc-client 0.6.1
//
// A stand-alone client for the Kerbal Space Program kRPC mod.
// Documentation
use std::{collections::HashMap, marker::PhantomData, sync::Arc};
#[cfg(not(feature = "tokio"))]
use std::{
    sync::{Condvar, Mutex},
    time::Duration,
};

#[cfg(feature = "tokio")]
use tokio::sync::{Mutex, Notify};

use crate::{
    client::Client,
    error::RpcError,
    schema::{DecodeUntagged, ProcedureCall, ProcedureResult},
    services::krpc::KRPC,
    RpcType,
};

/// A streaming procedure call.
///
/// `Stream<T>` is created by calling any procedure with the
/// `_stream()` suffix. This will start the stream
/// automatically.
///
/// This type provides access to the procedure's
/// results of type `T` via [`get`][get]. Results are pushed
/// by the server at the rate selected by
/// [`set_rate`][set_rate]. Consumers may wait until the
/// stream receives an update with [`wait`][wait].
///
/// The stream will attempt to remove itself when dropped.
/// Otherwise, the server will remove remaining streams when
/// the client disconnects.
///
/// [wait]: Stream::wait
/// [set_rate]: Stream::set_rate
/// [get]: Stream::get
pub struct Stream<T: RpcType + Send> {
    /// Stream id taken from the `add_stream` response; passed to
    /// every client and service call made on behalf of this stream.
    pub(crate) id: u64,
    /// `KRPC` service handle used to set the stream rate and to
    /// remove the stream.
    krpc: KRPC,
    /// Shared client used to read, wait on, and remove this
    /// stream's stored result.
    client: Arc<Client>,
    /// Ties the stream to its result type `T` without storing a
    /// value of that type.
    phantom: PhantomData<T>,
}

/// Shared per-stream slot: the most recently stored
/// `ProcedureResult` behind a mutex, paired with the `Condvar`
/// that `insert` signals after storing a result.
#[cfg(not(feature = "tokio"))]
type StreamEntry = Arc<(Mutex<ProcedureResult>, Condvar)>;
/// Shared per-stream slot: the most recently stored
/// `ProcedureResult` behind a mutex, paired with the `Notify`
/// that `insert` signals after storing a result.
#[cfg(feature = "tokio")]
type StreamEntry = Arc<(Mutex<ProcedureResult>, Notify)>;
/// Registry of per-stream entries, keyed by stream id.
///
/// `Mutex` here is `std::sync::Mutex` without the `tokio` feature
/// and `tokio::sync::Mutex` with it.
#[derive(Default)]
pub(crate) struct StreamWrangler {
    /// One entry per stream id. Entries are created on first
    /// access by `insert`, `wait` and `get`, and deleted by
    /// `remove`.
    streams: Mutex<HashMap<u64, StreamEntry>>,
    /// Reference count per stream id. Kept behind a
    /// `std::sync::Mutex` so the non-async refcount methods can
    /// lock it.
    #[cfg(feature = "tokio")]
    refcounts: std::sync::Mutex<HashMap<u64, u32>>,
}

impl StreamWrangler {
    /// Add one to the reference count stored for stream `id`.
    ///
    /// An id without an entry starts from zero. Returns the count
    /// after the increment.
    ///
    /// # Panics
    ///
    /// Panics if the refcount mutex is poisoned.
    #[cfg(feature = "tokio")]
    pub fn increment_refcount(&self, id: u64) -> u32 {
        let mut counts = self.refcounts.lock().unwrap();
        let count = counts.entry(id).or_default();
        *count += 1;
        *count
    }

    /// Subtract one from the reference count stored for stream `id`.
    ///
    /// Returns the count after the decrement; the entry is deleted
    /// once it reaches zero. An id without an entry yields `0` and
    /// leaves the map untouched.
    ///
    /// # Panics
    ///
    /// Panics if the refcount mutex is poisoned.
    #[cfg(feature = "tokio")]
    pub fn decrement_refcount(&self, id: u64) -> u32 {
        let mut counts = self.refcounts.lock().unwrap();
        let remaining = match counts.get_mut(&id) {
            Some(count) => {
                *count -= 1;
                *count
            }
            None => return 0,
        };

        // Forget ids that no longer have any references.
        if remaining == 0 {
            counts.remove(&id);
        }

        remaining
    }

    /// Store `procedure_result` as the latest result for stream
    /// `id` and wake every thread currently waiting on that stream.
    ///
    /// The entry is created if `id` has not been seen before. The
    /// map lock is released before the entry is touched, so an
    /// update to one stream never holds up access to the others.
    ///
    /// Always returns `Ok(())`.
    ///
    /// # Panics
    ///
    /// Panics if the map mutex or the entry mutex is poisoned.
    #[cfg(not(feature = "tokio"))]
    pub fn insert(
        &self,
        id: u64,
        procedure_result: ProcedureResult,
    ) -> Result<(), RpcError> {
        let entry = {
            let mut map = self.streams.lock().unwrap();
            map.entry(id).or_default().clone()
        };
        let (lock, cvar) = &*entry;

        *lock.lock().unwrap() = procedure_result;
        // Several threads may be blocked in `wait`/`wait_timeout`
        // on the same stream; `notify_one` would leave all but one
        // of them asleep until a later update.
        cvar.notify_all();

        Ok(())
    }

    /// Store `procedure_result` as the latest result for stream
    /// `id`, then signal the entry's `Notify` with `notify_one`.
    ///
    /// The entry is created if `id` has not been seen before. The
    /// map lock is held for the whole call.
    ///
    /// Always returns `Ok(())`.
    #[cfg(feature = "tokio")]
    pub async fn insert(
        &self,
        id: u64,
        procedure_result: ProcedureResult,
    ) -> Result<(), RpcError> {
        let mut streams = self.streams.lock().await;
        let entry = streams.entry(id).or_default().clone();
        let (value, notify) = &*entry;

        let mut slot = value.lock().await;
        *slot = procedure_result;
        // Release the entry lock before signalling.
        drop(slot);
        notify.notify_one();

        Ok(())
    }

    /// Block the calling thread on the condition variable of
    /// stream `id` until it is signalled.
    ///
    /// The entry is created if `id` has not been seen before. The
    /// map lock is released before blocking.
    ///
    /// NOTE(review): no predicate is checked around the wait, so a
    /// spurious condvar wakeup would also end it.
    ///
    /// # Panics
    ///
    /// Panics if the map mutex or the entry mutex is poisoned.
    #[cfg(not(feature = "tokio"))]
    pub fn wait(&self, id: u64) {
        let entry = {
            let mut streams = self.streams.lock().unwrap();
            streams.entry(id).or_default().clone()
        };
        let (value, updated) = &*entry;
        let guard = value.lock().unwrap();
        drop(updated.wait(guard).unwrap());
    }

    /// Block the calling thread on the condition variable of
    /// stream `id` until it is signalled or `dur` has elapsed.
    ///
    /// The entry is created if `id` has not been seen before. The
    /// map lock is released before blocking. The caller is not
    /// told whether the wait ended by signal or by timeout.
    ///
    /// # Panics
    ///
    /// Panics if the map mutex or the entry mutex is poisoned.
    #[cfg(not(feature = "tokio"))]
    pub fn wait_timeout(&self, id: u64, dur: Duration) {
        let entry = {
            let mut streams = self.streams.lock().unwrap();
            streams.entry(id).or_default().clone()
        };
        let (value, updated) = &*entry;
        let guard = value.lock().unwrap();
        drop(updated.wait_timeout(guard, dur).unwrap());
    }

    /// Suspend the calling task until the `Notify` of stream `id`
    /// is signalled.
    ///
    /// The entry is created if `id` has not been seen before. The
    /// map lock is released before waiting.
    #[cfg(feature = "tokio")]
    pub async fn wait(&self, id: u64) {
        let entry = {
            let mut streams = self.streams.lock().await;
            streams.entry(id).or_default().clone()
        };
        entry.1.notified().await;
    }

    /// Delete the entry stored for stream `id`, if there is one.
    ///
    /// # Panics
    ///
    /// Panics if the map mutex is poisoned.
    #[cfg(not(feature = "tokio"))]
    pub fn remove(&self, id: u64) {
        self.streams.lock().unwrap().remove(&id);
    }

    /// Delete the entry stored for stream `id`, if there is one.
    #[cfg(feature = "tokio")]
    pub async fn remove(&self, id: u64) {
        self.streams.lock().await.remove(&id);
    }

    /// Decode the latest result stored for stream `id` as a `T`.
    ///
    /// The entry is created (with a default `ProcedureResult`) if
    /// `id` has not been seen before. The map lock is released
    /// before the entry is locked and decoded, matching `wait`, so
    /// decoding one stream does not block reads and updates of
    /// every other stream.
    ///
    /// # Errors
    ///
    /// Returns whatever `T::decode_untagged` returns for the
    /// stored bytes.
    ///
    /// NOTE(review): only `result.value` is read; if
    /// `ProcedureResult` also carries a server-side error it is
    /// not surfaced here — confirm against the schema.
    ///
    /// # Panics
    ///
    /// Panics if the map mutex or the entry mutex is poisoned.
    #[cfg(not(feature = "tokio"))]
    pub fn get<T: DecodeUntagged>(
        &self,
        client: Arc<Client>,
        id: u64,
    ) -> Result<T, RpcError> {
        let entry = {
            let mut map = self.streams.lock().unwrap();
            map.entry(id).or_default().clone()
        };
        let result = entry.0.lock().unwrap();
        T::decode_untagged(client, &result.value)
    }

    /// Decode the latest result stored for stream `id` as a `T`.
    ///
    /// The entry is created (with a default `ProcedureResult`) if
    /// `id` has not been seen before. The map lock is released
    /// before awaiting the entry lock and decoding, matching
    /// `wait`, so reading one stream does not stall reads and
    /// updates of every other stream.
    ///
    /// # Errors
    ///
    /// Returns whatever `T::decode_untagged` returns for the
    /// stored bytes.
    ///
    /// NOTE(review): only `result.value` is read; if
    /// `ProcedureResult` also carries a server-side error it is
    /// not surfaced here — confirm against the schema.
    #[cfg(feature = "tokio")]
    pub async fn get<T: DecodeUntagged>(
        &self,
        client: Arc<Client>,
        id: u64,
    ) -> Result<T, RpcError> {
        let entry = {
            let mut map = self.streams.lock().await;
            map.entry(id).or_default().clone()
        };
        let result = entry.0.lock().await;
        T::decode_untagged(client, &result.value)
    }
}

impl<T: RpcType + Send> Stream<T> {
    /// Register `call` as a stream with the server and build the
    /// handle for it.
    ///
    /// The stream is added through `KRPC::add_stream` with its
    /// second argument set to `true`, then `Client::await_stream`
    /// is called for the new id before the handle is returned.
    ///
    /// # Errors
    ///
    /// Propagates the error from `add_stream`.
    #[cfg(not(feature = "tokio"))]
    pub(crate) fn new(
        client: Arc<Client>,
        call: ProcedureCall,
    ) -> Result<Self, RpcError> {
        let krpc = KRPC::new(Arc::clone(&client));
        let handle = krpc.add_stream(call, true)?;
        let id = handle.id;
        client.await_stream(id);

        Ok(Self {
            id,
            krpc,
            client,
            phantom: PhantomData,
        })
    }

    /// Register `call` as a stream with the server and build the
    /// handle for it.
    ///
    /// The stream is added through `KRPC::add_stream` with its
    /// second argument set to `true`. The new id is then passed to
    /// `Client::register_stream`, and `Client::await_stream` is
    /// awaited for it before the handle is returned.
    ///
    /// # Errors
    ///
    /// Propagates the error from `add_stream`.
    #[cfg(feature = "tokio")]
    pub(crate) async fn new(
        client: Arc<Client>,
        call: ProcedureCall,
    ) -> Result<Self, RpcError> {
        let krpc = KRPC::new(Arc::clone(&client));
        let handle = krpc.add_stream(call, true).await?;
        let id = handle.id;
        client.register_stream(id);
        client.await_stream(id).await;

        Ok(Self {
            id,
            krpc,
            client,
            phantom: PhantomData,
        })
    }

    /// Set the update rate for this streaming procedure.
    ///
    /// Forwards `hz` and this stream's id to
    /// `KRPC::set_stream_rate` and returns its result unchanged.
    /// `hz` is presumably the rate in hertz — confirm against the
    /// kRPC documentation.
    #[cfg(not(feature = "tokio"))]
    pub fn set_rate(&self, hz: f32) -> Result<(), RpcError> {
        self.krpc.set_stream_rate(self.id, hz)
    }

    /// Set the update rate for this streaming procedure.
    ///
    /// Forwards `hz` and this stream's id to
    /// `KRPC::set_stream_rate` and returns its awaited result
    /// unchanged. `hz` is presumably the rate in hertz — confirm
    /// against the kRPC documentation.
    #[cfg(feature = "tokio")]
    pub async fn set_rate(&self, hz: f32) -> Result<(), RpcError> {
        self.krpc.set_stream_rate(self.id, hz).await
    }

    /// Retrieve the current result received for this
    /// procedure. This value is not guaranteed to have
    /// changed since the last call to [`get`][get]. Use
    /// [`wait`][wait] to block until the stream receives an
    /// update.
    ///
    /// Delegates to `Client::read_stream` with this stream's id
    /// and returns its result unchanged.
    ///
    /// [wait]: Stream::wait
    /// [get]: Stream::get
    #[cfg(not(feature = "tokio"))]
    pub fn get(&self) -> Result<T, RpcError> {
        self.client.read_stream(self.id)
    }

    /// Retrieve the current result received for this
    /// procedure. This value is not guaranteed to have
    /// changed since the last call to [`get`][get]. Use
    /// [`wait`][wait] to wait until the stream receives an
    /// update.
    ///
    /// Delegates to `Client::read_stream` with this stream's id
    /// and returns its awaited result unchanged.
    ///
    /// [wait]: Stream::wait
    /// [get]: Stream::get
    #[cfg(feature = "tokio")]
    pub async fn get(&self) -> Result<T, RpcError> {
        self.client.read_stream(self.id).await
    }

    /// Block the current thread of execution until this
    /// stream receives an update from the server.
    ///
    /// Delegates to `Client::await_stream` with this stream's id.
    #[cfg(not(feature = "tokio"))]
    pub fn wait(&self) {
        self.client.await_stream(self.id);
    }

    /// Block the current thread of execution until this
    /// stream receives an update from the server or the
    /// timeout is reached.
    ///
    /// Delegates to `Client::await_stream_timeout` with this
    /// stream's id and `dur`. Nothing is returned, so the caller
    /// cannot tell from this call alone whether an update arrived
    /// or the timeout elapsed.
    #[cfg(not(feature = "tokio"))]
    pub fn wait_timeout(&self, dur: Duration) {
        self.client.await_stream_timeout(self.id, dur);
    }

    /// Wait asynchronously until this stream receives an update
    /// from the server.
    ///
    /// Awaits `Client::await_stream` with this stream's id; the
    /// calling task is suspended rather than the thread blocked.
    #[cfg(feature = "tokio")]
    pub async fn wait(&self) {
        self.client.await_stream(self.id).await;
    }
}

impl<T: RpcType + Send> Drop for Stream<T> {
    /// Best-effort removal of the stream when the handle is
    /// dropped: first on the server via `KRPC::remove_stream`,
    /// then locally via `Client::remove_stream`. Failures are
    /// discarded so that dropping never panics on an RPC error.
    #[cfg(not(feature = "tokio"))]
    fn drop(&mut self) {
        self.krpc.remove_stream(self.id).ok();
        self.client.remove_stream(self.id).ok();
    }

    /// Best-effort removal of the stream when the last handle is
    /// dropped.
    ///
    /// `Client::release_stream` reports the remaining reference
    /// count; only when it reaches zero is a background task
    /// spawned to remove the stream on the server and locally.
    /// Failures inside that task are discarded.
    ///
    /// If no Tokio runtime is active on the current thread the
    /// removal is skipped instead of panicking; the server still
    /// removes remaining streams when the client disconnects.
    #[cfg(feature = "tokio")]
    fn drop(&mut self) {
        let id = self.id;
        if self.client.release_stream(id) != 0 {
            return;
        }

        // `tokio::task::spawn` panics outside a runtime, and a
        // panic inside `drop` can abort the process. Look the
        // runtime up fallibly instead.
        let Ok(runtime) = tokio::runtime::Handle::try_current() else {
            return;
        };

        let krpc = self.krpc.clone();
        let client = self.client.clone();
        runtime.spawn(async move {
            krpc.remove_stream(id).await.ok();
            client.remove_stream(id).await.ok();
        });
    }
}