Skip to main content

lb_rs/ipc/
client.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::OnceLock;
4use std::sync::atomic::AtomicU64;
5
6use serde::de::DeserializeOwned;
7#[cfg(unix)]
8use tokio::net::unix;
9use tokio::sync::broadcast;
10use tokio::sync::{Mutex, oneshot};
11use tokio::task::JoinHandle;
12
13use crate::ipc::protocol::Request;
14use crate::model::account::Account;
15use crate::model::errors::{LbErrKind, LbResult};
16use crate::service::events::Event;
17
18#[cfg(unix)]
19use {
20    crate::ipc::protocol::Frame, std::io, std::path::Path, std::sync::atomic::Ordering,
21    tokio::net::UnixStream, tokio::net::unix::OwnedWriteHalf,
22};
23
/// Requests that have been sent and are awaiting a response, keyed by sequence number.
/// Each entry holds the oneshot sender used to hand the raw response bytes back to the
/// caller waiting in `RemoteLb::try_call`. Shared between the client and its reader task.
type InFlight = Arc<Mutex<HashMap<u64, oneshot::Sender<Vec<u8>>>>>;

/// Capacity given to the broadcast channel created on the first `RemoteLb::subscribe` call.
const EVENT_CHANNEL_CAPACITY: usize = 10_000;
27
/// Client side of the IPC connection to a host process.
///
/// Requests are written as frames on the write half of a unix socket; a background reader
/// task consumes the read half and routes responses and events back to callers.
// The only constructor (`connect`) is unix-only, so on other targets the fields are never read.
#[cfg_attr(not(unix), allow(dead_code))]
pub struct RemoteLb {
    /// Account cached by `cache_account`; set at most once.
    account: OnceLock<Account>,
    /// Broadcast sender for host events, created lazily on the first `subscribe` call.
    /// Shared with the reader task, which forwards `Event` frames into it once it exists.
    events: Arc<OnceLock<broadcast::Sender<Event>>>,
    /// Write half of the socket. The mutex is held for a whole frame write, so frames from
    /// concurrent calls are written one at a time.
    #[cfg(unix)]
    writer: Mutex<OwnedWriteHalf>,
    /// Source of sequence numbers for outgoing requests.
    seq: AtomicU64,
    /// Requests awaiting a response; shared with the reader task.
    in_flight: InFlight,
    /// Handle of the background reader task; aborted when the client is dropped.
    reader_task: JoinHandle<()>,
}
38
/// Stops the background reader task when the client goes away, so the task does not keep
/// running after the last handle to the client is dropped.
impl Drop for RemoteLb {
    fn drop(&mut self) {
        self.reader_task.abort();
    }
}
44
45impl RemoteLb {
46    #[cfg(unix)]
47    pub async fn connect(socket: &Path) -> io::Result<Arc<Self>> {
48        let stream = UnixStream::connect(socket).await?;
49        let (read_half, write_half) = stream.into_split();
50        let in_flight: InFlight = Arc::new(Mutex::new(HashMap::new()));
51        let events: Arc<OnceLock<broadcast::Sender<Event>>> = Arc::new(OnceLock::new());
52        let reader_task =
53            tokio::spawn(reader_loop(read_half, Arc::clone(&in_flight), Arc::clone(&events)));
54
55        let me = Arc::new(Self {
56            account: OnceLock::new(),
57            events,
58            writer: Mutex::new(write_half),
59            seq: AtomicU64::new(0),
60            in_flight,
61            reader_task,
62        });
63
64        if let Ok(account) = me.try_call::<Account>(Request::GetAccount).await {
65            me.cache_account(account);
66        }
67
68        Ok(me)
69    }
70
71    pub fn get_account(&self) -> LbResult<&Account> {
72        self.account
73            .get()
74            .ok_or_else(|| LbErrKind::AccountNonexistent.into())
75    }
76
77    pub fn cache_account(&self, account: Account) {
78        let _ = self.account.set(account);
79    }
80
81    pub fn subscribe(self: &Arc<Self>) -> broadcast::Receiver<Event> {
82        let tx = self.events.get_or_init(|| {
83            let (tx, _) = broadcast::channel::<Event>(EVENT_CHANNEL_CAPACITY);
84            let me = Arc::clone(self);
85            tokio::spawn(async move {
86                let _ = me.try_call::<()>(Request::Subscribe).await;
87            });
88            tx
89        });
90        tx.subscribe()
91    }
92
93    pub(crate) async fn try_call<Out>(&self, req: Request) -> Result<Out, RemoteCallError>
94    where
95        Out: DeserializeOwned,
96    {
97        #[cfg(not(unix))]
98        {
99            let _ = req;
100            unreachable!("RemoteLb cannot be constructed on non-unix targets")
101        }
102        #[cfg(unix)]
103        {
104            let seq = self.seq.fetch_add(1, Ordering::Relaxed);
105            let (tx, rx) = oneshot::channel();
106            self.in_flight.lock().await.insert(seq, tx);
107
108            let frame = Frame::Request { seq, body: req };
109            {
110                let mut writer = self.writer.lock().await;
111                frame
112                    .write(&mut *writer)
113                    .await
114                    .map_err(|_| RemoteCallError::HostUnavailable)?;
115            }
116
117            let output_bytes = rx.await.map_err(|_| RemoteCallError::HostUnavailable)?;
118
119            let result: LbResult<Out> = bincode::deserialize(&output_bytes).map_err(|e| {
120                RemoteCallError::Other(
121                    LbErrKind::Unexpected(format!("ipc: deserialize response: {e}")).into(),
122                )
123            })?;
124            result.map_err(RemoteCallError::Other)
125        }
126    }
127}
128
/// Ways a call made through `RemoteLb::try_call` can fail.
// `try_call` never returns normally on non-unix targets, so the variants are unused there.
#[cfg_attr(not(unix), allow(dead_code))]
pub(crate) enum RemoteCallError {
    /// The request could not be written to the host, or the response channel was closed
    /// before a reply arrived.
    HostUnavailable,
    /// The response could not be decoded, or it decoded to an error returned by the host.
    Other(crate::model::errors::LbErr),
}
134
135#[cfg(unix)]
136async fn reader_loop(
137    mut reader: unix::OwnedReadHalf, in_flight: InFlight,
138    events: Arc<OnceLock<broadcast::Sender<Event>>>,
139) {
140    loop {
141        let frame = match Frame::read(&mut reader).await {
142            Ok(f) => f,
143            Err(err) => {
144                if err.kind() != io::ErrorKind::UnexpectedEof {
145                    tracing::warn!(?err, "ipc reader: read failed");
146                }
147                break;
148            }
149        };
150        match frame {
151            Frame::Response { seq, output } => {
152                if let Some(tx) = in_flight.lock().await.remove(&seq) {
153                    let _ = tx.send(output);
154                } else {
155                    tracing::warn!(seq, "ipc reader: response for unknown seq");
156                }
157            }
158            Frame::Event { stream_seq: _, body } => {
159                if let Some(tx) = events.get() {
160                    let _ = tx.send(body);
161                }
162            }
163            Frame::EventEnd { stream_seq } => {
164                tracing::debug!(stream_seq, "ipc: host closed event stream");
165            }
166            Frame::Request { .. } => {
167                tracing::warn!("ipc reader: host sent a Request frame; protocol violation");
168                break;
169            }
170        }
171    }
172
173    let mut map = in_flight.lock().await;
174    map.clear();
175}