1use futures::{SinkExt, StreamExt};
2use mxr_core::id::*;
3use mxr_core::types::*;
4use mxr_core::MxrError;
5use mxr_protocol::*;
6use std::path::Path;
7use std::sync::atomic::{AtomicU64, Ordering};
8use tokio::net::UnixStream;
9use tokio::sync::mpsc;
10use tokio_util::codec::Framed;
11
12pub struct Client {
13 framed: Framed<UnixStream, IpcCodec>,
14 next_id: AtomicU64,
15 event_tx: Option<mpsc::UnboundedSender<DaemonEvent>>,
16}
17
18impl Client {
19 pub async fn connect(socket_path: &Path) -> std::io::Result<Self> {
20 let stream = UnixStream::connect(socket_path).await?;
21 Ok(Self {
22 framed: Framed::new(stream, IpcCodec::new()),
23 next_id: AtomicU64::new(1),
24 event_tx: None,
25 })
26 }
27
28 pub fn with_event_channel(mut self, tx: mpsc::UnboundedSender<DaemonEvent>) -> Self {
29 self.event_tx = Some(tx);
30 self
31 }
32
33 pub async fn raw_request(&mut self, req: Request) -> Result<Response, MxrError> {
34 self.request(req).await
35 }
36
37 async fn request(&mut self, req: Request) -> Result<Response, MxrError> {
38 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
39 let msg = IpcMessage {
40 id,
41 payload: IpcPayload::Request(req),
42 };
43 self.framed
44 .send(msg)
45 .await
46 .map_err(|e| MxrError::Ipc(e.to_string()))?;
47
48 loop {
49 match self.framed.next().await {
50 Some(Ok(resp_msg)) => match resp_msg.payload {
51 IpcPayload::Response(resp) if resp_msg.id == id => return Ok(resp),
52 IpcPayload::Event(event) => {
53 if let Some(ref tx) = self.event_tx {
54 let _ = tx.send(event);
55 }
56 continue;
57 }
58 _ => continue,
59 },
60 Some(Err(e)) => return Err(MxrError::Ipc(describe_ipc_failure(&e.to_string()))),
61 None => {
62 return Err(MxrError::Ipc(
63 "Connection closed. The running daemon may be using an incompatible protocol. Restart the daemon after upgrading.".into(),
64 ))
65 }
66 }
67 }
68 }
69
70 pub async fn list_envelopes(
71 &mut self,
72 limit: u32,
73 offset: u32,
74 ) -> Result<Vec<Envelope>, MxrError> {
75 let resp = self
76 .request(Request::ListEnvelopes {
77 label_id: None,
78 account_id: None,
79 limit,
80 offset,
81 })
82 .await?;
83
84 match resp {
85 Response::Ok {
86 data: ResponseData::Envelopes { envelopes },
87 } => Ok(envelopes),
88 Response::Error { message } => Err(MxrError::Ipc(message)),
89 _ => Err(MxrError::Ipc("Unexpected response".into())),
90 }
91 }
92
93 pub async fn list_labels(&mut self) -> Result<Vec<Label>, MxrError> {
94 let resp = self
95 .request(Request::ListLabels { account_id: None })
96 .await?;
97 match resp {
98 Response::Ok {
99 data: ResponseData::Labels { labels },
100 } => Ok(labels),
101 Response::Error { message } => Err(MxrError::Ipc(message)),
102 _ => Err(MxrError::Ipc("Unexpected response".into())),
103 }
104 }
105
106 pub async fn search(
107 &mut self,
108 query: &str,
109 limit: u32,
110 ) -> Result<Vec<SearchResultItem>, MxrError> {
111 let resp = self
112 .request(Request::Search {
113 query: query.to_string(),
114 limit,
115 offset: 0,
116 mode: None,
117 sort: Some(mxr_core::types::SortOrder::DateDesc),
118 explain: false,
119 })
120 .await?;
121 match resp {
122 Response::Ok {
123 data: ResponseData::SearchResults { results, .. },
124 } => Ok(results),
125 Response::Error { message } => Err(MxrError::Ipc(message)),
126 _ => Err(MxrError::Ipc("Unexpected response".into())),
127 }
128 }
129
130 pub async fn get_envelope(&mut self, message_id: &MessageId) -> Result<Envelope, MxrError> {
131 let resp = self
132 .request(Request::GetEnvelope {
133 message_id: message_id.clone(),
134 })
135 .await?;
136 match resp {
137 Response::Ok {
138 data: ResponseData::Envelope { envelope },
139 } => Ok(envelope),
140 Response::Error { message } => Err(MxrError::Ipc(message)),
141 _ => Err(MxrError::Ipc("Unexpected response".into())),
142 }
143 }
144
145 pub async fn get_body(&mut self, message_id: &MessageId) -> Result<MessageBody, MxrError> {
146 let resp = self
147 .request(Request::GetBody {
148 message_id: message_id.clone(),
149 })
150 .await?;
151 match resp {
152 Response::Ok {
153 data: ResponseData::Body { body },
154 } => Ok(body),
155 Response::Error { message } => Err(MxrError::Ipc(message)),
156 _ => Err(MxrError::Ipc("Unexpected response".into())),
157 }
158 }
159
160 pub async fn get_thread(
161 &mut self,
162 thread_id: &ThreadId,
163 ) -> Result<(Thread, Vec<Envelope>), MxrError> {
164 let resp = self
165 .request(Request::GetThread {
166 thread_id: thread_id.clone(),
167 })
168 .await?;
169 match resp {
170 Response::Ok {
171 data: ResponseData::Thread { thread, messages },
172 } => Ok((thread, messages)),
173 Response::Error { message } => Err(MxrError::Ipc(message)),
174 _ => Err(MxrError::Ipc("Unexpected response".into())),
175 }
176 }
177
178 pub async fn list_saved_searches(
179 &mut self,
180 ) -> Result<Vec<mxr_core::types::SavedSearch>, MxrError> {
181 let resp = self.request(Request::ListSavedSearches).await?;
182 match resp {
183 Response::Ok {
184 data: ResponseData::SavedSearches { searches },
185 } => Ok(searches),
186 Response::Error { message } => Err(MxrError::Ipc(message)),
187 _ => Err(MxrError::Ipc("Unexpected response".into())),
188 }
189 }
190
191 pub async fn list_subscriptions(
192 &mut self,
193 limit: u32,
194 ) -> Result<Vec<mxr_core::types::SubscriptionSummary>, MxrError> {
195 let resp = self.request(Request::ListSubscriptions { limit }).await?;
196 match resp {
197 Response::Ok {
198 data: ResponseData::Subscriptions { subscriptions },
199 } => Ok(subscriptions),
200 Response::Error { message } => Err(MxrError::Ipc(message)),
201 _ => Err(MxrError::Ipc("Unexpected response".into())),
202 }
203 }
204
205 pub async fn ping(&mut self) -> Result<(), MxrError> {
206 let resp = self.request(Request::Ping).await?;
207 match resp {
208 Response::Ok {
209 data: ResponseData::Pong,
210 } => Ok(()),
211 _ => Err(MxrError::Ipc("Unexpected response".into())),
212 }
213 }
214}
215
/// Turns a raw IPC error `message` into the text shown to the caller.
///
/// Messages containing `"unknown variant"` or `"missing field"` are wrapped in a protocol
/// mismatch hint telling the user to restart the daemon; presumably these phrases come from
/// the frame deserializer (verify against `IpcCodec`). Any other message is returned unchanged.
fn describe_ipc_failure(message: &str) -> String {
    const MISMATCH_MARKERS: [&str; 2] = ["unknown variant", "missing field"];

    let looks_like_mismatch = MISMATCH_MARKERS
        .iter()
        .any(|marker| message.contains(marker));
    if !looks_like_mismatch {
        return message.to_owned();
    }

    format!("IPC protocol mismatch: {message}. Restart the daemon after upgrading.")
}