lb-rs 26.5.22

The Rust library for interacting with your lockbook.
Documentation
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::AtomicU64;

use serde::de::DeserializeOwned;
#[cfg(unix)]
use tokio::net::unix;
use tokio::sync::broadcast;
use tokio::sync::{Mutex, oneshot};
use tokio::task::JoinHandle;

use crate::ipc::protocol::Request;
use crate::model::account::Account;
use crate::model::errors::{LbErrKind, LbResult};
use crate::service::events::Event;

#[cfg(unix)]
use {
    crate::ipc::protocol::Frame, std::io, std::path::Path, std::sync::atomic::Ordering,
    tokio::net::UnixStream, tokio::net::unix::OwnedWriteHalf,
};

/// Requests awaiting a response, keyed by request sequence number. Each
/// value is the one-shot sender used to hand the raw response bytes back to
/// the caller waiting in `try_call`. Shared between callers and the reader
/// task.
type InFlight = Arc<Mutex<HashMap<u64, oneshot::Sender<Vec<u8>>>>>;

/// Capacity passed to `broadcast::channel` when `RemoteLb::subscribe` creates
/// the event channel.
const EVENT_CHANNEL_CAPACITY: usize = 10_000;

/// Client-side handle to a host reached over a Unix socket.
///
/// Requests are written as `Frame::Request` values tagged with a sequence
/// number; a background reader task routes each `Frame::Response` back to the
/// caller waiting on that sequence number and forwards `Frame::Event` bodies
/// to the broadcast channel once one has been created.
#[cfg_attr(not(unix), allow(dead_code))]
pub struct RemoteLb {
    /// Account stored by `cache_account`; empty until the first successful set.
    account: OnceLock<Account>,
    /// Event broadcast sender, created lazily by the first `subscribe` call and
    /// shared with the reader task, which publishes incoming events to it.
    events: Arc<OnceLock<broadcast::Sender<Event>>>,
    /// Write half of the socket; locked for the duration of each frame write.
    #[cfg(unix)]
    writer: Mutex<OwnedWriteHalf>,
    /// Source of sequence numbers for outgoing requests.
    seq: AtomicU64,
    /// Pending requests awaiting their response (see `InFlight`).
    in_flight: InFlight,
    /// Handle to the spawned `reader_loop` task; aborted when this value drops.
    reader_task: JoinHandle<()>,
}

impl Drop for RemoteLb {
    /// Stops the background reader task once the handle goes away, so the
    /// task does not keep running without an owner.
    fn drop(&mut self) {
        let Self { reader_task, .. } = self;
        reader_task.abort();
    }
}

impl RemoteLb {
    /// Connects to the host listening on the Unix socket at `socket`.
    ///
    /// Splits the stream, spawns the background reader task, and then makes a
    /// best-effort `Request::GetAccount` call, caching the account when that
    /// call succeeds.
    ///
    /// # Errors
    ///
    /// Returns the `io::Error` produced by `UnixStream::connect`. A failed
    /// `GetAccount` call is deliberately not treated as an error; the account
    /// is simply left uncached.
    #[cfg(unix)]
    pub async fn connect(socket: &Path) -> io::Result<Arc<Self>> {
        let stream = UnixStream::connect(socket).await?;
        let (read_half, write_half) = stream.into_split();
        let in_flight: InFlight = Arc::new(Mutex::new(HashMap::new()));
        let events: Arc<OnceLock<broadcast::Sender<Event>>> = Arc::new(OnceLock::new());
        let reader_task =
            tokio::spawn(reader_loop(read_half, Arc::clone(&in_flight), Arc::clone(&events)));

        let me = Arc::new(Self {
            account: OnceLock::new(),
            events,
            writer: Mutex::new(write_half),
            seq: AtomicU64::new(0),
            in_flight,
            reader_task,
        });

        // Best effort: if this fails, `get_account` reports
        // `AccountNonexistent` until `cache_account` is called.
        if let Ok(account) = me.try_call::<Account>(Request::GetAccount).await {
            me.cache_account(account);
        }

        Ok(me)
    }

    /// Returns the cached account.
    ///
    /// # Errors
    ///
    /// Returns `LbErrKind::AccountNonexistent` when no account has been cached.
    pub fn get_account(&self) -> LbResult<&Account> {
        self.account
            .get()
            .ok_or_else(|| LbErrKind::AccountNonexistent.into())
    }

    /// Caches `account` for later `get_account` calls.
    ///
    /// Only the first call has an effect; once an account is cached, later
    /// values are silently discarded.
    pub fn cache_account(&self, account: Account) {
        let _ = self.account.set(account);
    }

    /// Returns a receiver for events forwarded by the reader task.
    ///
    /// The first call creates the broadcast channel (capacity
    /// `EVENT_CHANNEL_CAPACITY`) and spawns a task that sends
    /// `Request::Subscribe` to the host; the outcome of that request is
    /// ignored. Later calls only add another receiver to the same channel.
    ///
    /// # Panics
    ///
    /// The first call uses `tokio::spawn`, which panics when invoked outside
    /// a Tokio runtime.
    pub fn subscribe(self: &Arc<Self>) -> broadcast::Receiver<Event> {
        let tx = self.events.get_or_init(|| {
            let (tx, _) = broadcast::channel::<Event>(EVENT_CHANNEL_CAPACITY);
            let me = Arc::clone(self);
            tokio::spawn(async move {
                let _ = me.try_call::<()>(Request::Subscribe).await;
            });
            tx
        });
        tx.subscribe()
    }

    /// Sends `req` to the host and waits for the matching response.
    ///
    /// The response payload is decoded with `bincode` as an `LbResult<Out>`.
    ///
    /// # Errors
    ///
    /// * `RemoteCallError::HostUnavailable` when the reader task has already
    ///   exited, the request cannot be written, or the response channel is
    ///   closed before a response arrives.
    /// * `RemoteCallError::Other` when the payload cannot be deserialized or
    ///   the decoded result is itself an error.
    pub(crate) async fn try_call<Out>(&self, req: Request) -> Result<Out, RemoteCallError>
    where
        Out: DeserializeOwned,
    {
        #[cfg(not(unix))]
        {
            let _ = req;
            unreachable!("RemoteLb cannot be constructed on non-unix targets")
        }
        #[cfg(unix)]
        {
            let seq = self.seq.fetch_add(1, Ordering::Relaxed);
            let (tx, rx) = oneshot::channel();
            // Register before writing so the response cannot arrive before
            // the reader knows about this sequence number.
            self.in_flight.lock().await.insert(seq, tx);

            // If the reader task has already exited, nothing will ever
            // complete or drop the sender registered above, so awaiting the
            // response would hang. Best-effort check: bail out instead.
            if self.reader_task.is_finished() {
                self.in_flight.lock().await.remove(&seq);
                return Err(RemoteCallError::HostUnavailable);
            }

            let frame = Frame::Request { seq, body: req };
            let written = {
                let mut writer = self.writer.lock().await;
                frame.write(&mut *writer).await
            };
            if written.is_err() {
                // The request never reached the host, so no response will
                // come; drop the pending entry instead of leaking it.
                self.in_flight.lock().await.remove(&seq);
                return Err(RemoteCallError::HostUnavailable);
            }

            // Fails when the sender is dropped without a response being sent.
            let output_bytes = rx.await.map_err(|_| RemoteCallError::HostUnavailable)?;

            let result: LbResult<Out> = bincode::deserialize(&output_bytes).map_err(|e| {
                RemoteCallError::Other(
                    LbErrKind::Unexpected(format!("ipc: deserialize response: {e}")).into(),
                )
            })?;
            result.map_err(RemoteCallError::Other)
        }
    }
}

/// Failure modes of `RemoteLb::try_call`.
#[cfg_attr(not(unix), allow(dead_code))]
pub(crate) enum RemoteCallError {
    /// The request could not be written to the socket, or the response
    /// channel closed before a response was delivered.
    HostUnavailable,
    /// The response could not be deserialized, or it decoded to an error
    /// result.
    Other(crate::model::errors::LbErr),
}

/// Background task that reads frames from the host until the stream ends,
/// a read fails, or the host violates the protocol.
///
/// * `Frame::Response` payloads are handed to the caller registered in
///   `in_flight` under the same sequence number.
/// * `Frame::Event` bodies are published on the broadcast channel, if one has
///   been created; otherwise they are discarded.
/// * `Frame::EventEnd` is only logged.
/// * `Frame::Request` is unexpected from the host and ends the loop.
///
/// On exit every pending entry is dropped from `in_flight`.
#[cfg(unix)]
async fn reader_loop(
    mut reader: unix::OwnedReadHalf, in_flight: InFlight,
    events: Arc<OnceLock<broadcast::Sender<Event>>>,
) {
    loop {
        let frame = match Frame::read(&mut reader).await {
            Ok(frame) => frame,
            // UnexpectedEof ends the loop quietly (presumably the host closing
            // the socket); any other read error is logged first.
            Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => break,
            Err(err) => {
                tracing::warn!(?err, "ipc reader: read failed");
                break;
            }
        };

        match frame {
            Frame::Request { .. } => {
                tracing::warn!("ipc reader: host sent a Request frame; protocol violation");
                break;
            }
            Frame::EventEnd { stream_seq } => {
                tracing::debug!(stream_seq, "ipc: host closed event stream");
            }
            Frame::Event { body, .. } => {
                // Events that arrive before anyone subscribed are dropped; a
                // send error (no live receivers) is ignored as well.
                if let Some(sender) = events.get() {
                    let _ = sender.send(body);
                }
            }
            Frame::Response { seq, output } => {
                let waiter = in_flight.lock().await.remove(&seq);
                match waiter {
                    // The caller may have stopped waiting; a failed send is fine.
                    Some(reply) => {
                        let _ = reply.send(output);
                    }
                    None => {
                        tracing::warn!(seq, "ipc reader: response for unknown seq");
                    }
                }
            }
        }
    }

    // Dropping the pending senders makes every caller still awaiting a
    // response observe a closed channel instead of waiting indefinitely.
    in_flight.lock().await.clear();
}