1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::OnceLock;
4use std::sync::atomic::AtomicU64;
5
6use serde::de::DeserializeOwned;
7#[cfg(unix)]
8use tokio::net::unix;
9use tokio::sync::broadcast;
10use tokio::sync::{Mutex, oneshot};
11use tokio::task::JoinHandle;
12
13use crate::ipc::protocol::Request;
14use crate::model::account::Account;
15use crate::model::errors::{LbErrKind, LbResult};
16use crate::service::events::Event;
17
18#[cfg(unix)]
19use {
20 crate::ipc::protocol::Frame, std::io, std::path::Path, std::sync::atomic::Ordering,
21 tokio::net::UnixStream, tokio::net::unix::OwnedWriteHalf,
22};
23
24type InFlight = Arc<Mutex<HashMap<u64, oneshot::Sender<Vec<u8>>>>>;
25
/// Buffer size of the broadcast channel that fans host events out to local
/// subscribers (passed to `broadcast::channel` in `RemoteLb::subscribe`).
const EVENT_CHANNEL_CAPACITY: usize = 10 * 1_000;
27
/// Client side of the IPC connection to a host process, reached over a Unix socket.
///
/// Requests are written through `writer` and matched to their responses by
/// sequence number. A background task (`reader_task`) reads frames from the
/// socket and routes responses and events.
// On non-unix targets there is no constructor (`connect` is unix-only), so the
// fields are never read there; hence the dead_code allowance.
#[cfg_attr(not(unix), allow(dead_code))]
pub struct RemoteLb {
    /// Account cached after a successful lookup; a `OnceLock`, so set at most once.
    account: OnceLock<Account>,
    /// Broadcast sender for host events, created lazily by `subscribe`.
    /// Shared (via `Arc`) with the reader task, which forwards event frames into it.
    events: Arc<OnceLock<broadcast::Sender<Event>>>,
    /// Write half of the socket. The async mutex is held for a whole frame write,
    /// so concurrent requests do not interleave their bytes.
    #[cfg(unix)]
    writer: Mutex<OwnedWriteHalf>,
    /// Counter that hands out a fresh sequence number per request.
    seq: AtomicU64,
    /// Callers awaiting a response, keyed by request sequence number.
    in_flight: InFlight,
    /// Background task reading frames from the socket; aborted in `Drop`.
    reader_task: JoinHandle<()>,
}
38
39impl Drop for RemoteLb {
40 fn drop(&mut self) {
41 self.reader_task.abort();
42 }
43}
44
45impl RemoteLb {
46 #[cfg(unix)]
47 pub async fn connect(socket: &Path) -> io::Result<Arc<Self>> {
48 let stream = UnixStream::connect(socket).await?;
49 let (read_half, write_half) = stream.into_split();
50 let in_flight: InFlight = Arc::new(Mutex::new(HashMap::new()));
51 let events: Arc<OnceLock<broadcast::Sender<Event>>> = Arc::new(OnceLock::new());
52 let reader_task =
53 tokio::spawn(reader_loop(read_half, Arc::clone(&in_flight), Arc::clone(&events)));
54
55 let me = Arc::new(Self {
56 account: OnceLock::new(),
57 events,
58 writer: Mutex::new(write_half),
59 seq: AtomicU64::new(0),
60 in_flight,
61 reader_task,
62 });
63
64 if let Ok(account) = me.try_call::<Account>(Request::GetAccount).await {
65 me.cache_account(account);
66 }
67
68 Ok(me)
69 }
70
71 pub fn get_account(&self) -> LbResult<&Account> {
72 self.account
73 .get()
74 .ok_or_else(|| LbErrKind::AccountNonexistent.into())
75 }
76
77 pub fn cache_account(&self, account: Account) {
78 let _ = self.account.set(account);
79 }
80
81 pub fn subscribe(self: &Arc<Self>) -> broadcast::Receiver<Event> {
82 let tx = self.events.get_or_init(|| {
83 let (tx, _) = broadcast::channel::<Event>(EVENT_CHANNEL_CAPACITY);
84 let me = Arc::clone(self);
85 tokio::spawn(async move {
86 let _ = me.try_call::<()>(Request::Subscribe).await;
87 });
88 tx
89 });
90 tx.subscribe()
91 }
92
93 pub(crate) async fn try_call<Out>(&self, req: Request) -> Result<Out, RemoteCallError>
94 where
95 Out: DeserializeOwned,
96 {
97 #[cfg(not(unix))]
98 {
99 let _ = req;
100 unreachable!("RemoteLb cannot be constructed on non-unix targets")
101 }
102 #[cfg(unix)]
103 {
104 let seq = self.seq.fetch_add(1, Ordering::Relaxed);
105 let (tx, rx) = oneshot::channel();
106 self.in_flight.lock().await.insert(seq, tx);
107
108 let frame = Frame::Request { seq, body: req };
109 {
110 let mut writer = self.writer.lock().await;
111 frame
112 .write(&mut *writer)
113 .await
114 .map_err(|_| RemoteCallError::HostUnavailable)?;
115 }
116
117 let output_bytes = rx.await.map_err(|_| RemoteCallError::HostUnavailable)?;
118
119 let result: LbResult<Out> = bincode::deserialize(&output_bytes).map_err(|e| {
120 RemoteCallError::Other(
121 LbErrKind::Unexpected(format!("ipc: deserialize response: {e}")).into(),
122 )
123 })?;
124 result.map_err(RemoteCallError::Other)
125 }
126 }
127}
128
/// Failure modes of `RemoteLb::try_call`.
// On non-unix targets `RemoteLb` cannot be constructed, so these variants are
// never built there; hence the dead_code allowance.
#[cfg_attr(not(unix), allow(dead_code))]
pub(crate) enum RemoteCallError {
    /// The request frame could not be written, or the response channel was
    /// closed before a response arrived (the connection is unusable).
    HostUnavailable,
    /// The response could not be decoded, or the host answered with an error.
    Other(crate::model::errors::LbErr),
}
134
135#[cfg(unix)]
136async fn reader_loop(
137 mut reader: unix::OwnedReadHalf, in_flight: InFlight,
138 events: Arc<OnceLock<broadcast::Sender<Event>>>,
139) {
140 loop {
141 let frame = match Frame::read(&mut reader).await {
142 Ok(f) => f,
143 Err(err) => {
144 if err.kind() != io::ErrorKind::UnexpectedEof {
145 tracing::warn!(?err, "ipc reader: read failed");
146 }
147 break;
148 }
149 };
150 match frame {
151 Frame::Response { seq, output } => {
152 if let Some(tx) = in_flight.lock().await.remove(&seq) {
153 let _ = tx.send(output);
154 } else {
155 tracing::warn!(seq, "ipc reader: response for unknown seq");
156 }
157 }
158 Frame::Event { stream_seq: _, body } => {
159 if let Some(tx) = events.get() {
160 let _ = tx.send(body);
161 }
162 }
163 Frame::EventEnd { stream_seq } => {
164 tracing::debug!(stream_seq, "ipc: host closed event stream");
165 }
166 Frame::Request { .. } => {
167 tracing::warn!("ipc reader: host sent a Request frame; protocol violation");
168 break;
169 }
170 }
171 }
172
173 let mut map = in_flight.lock().await;
174 map.clear();
175}