1use crate::borrow::Cow;
2use crate::comm::{Flush, TtlBufWriter};
3use crate::Error;
4use crate::EventChannel;
5use crate::IntoElbusResult;
6use crate::OpConfirm;
7use crate::QoS;
8use crate::GREETINGS;
9use crate::PING_FRAME;
10use crate::PROTOCOL_VERSION;
11use crate::RESPONSE_OK;
12use crate::SECONDARY_SEP;
13use crate::{Frame, FrameData, FrameKind, FrameOp};
14use std::collections::BTreeMap;
15use std::marker::Unpin;
16use std::sync::atomic;
17use std::sync::{Arc, Mutex};
18use std::time::Duration;
19use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
20#[cfg(not(target_os = "windows"))]
21use tokio::net::unix;
22#[cfg(not(target_os = "windows"))]
23use tokio::net::UnixStream;
24use tokio::net::{tcp, TcpStream};
25use tokio::sync::oneshot;
26use tokio::task::JoinHandle;
27
28use crate::client::AsyncClient;
29
30use log::{error, trace, warn};
31
32use async_trait::async_trait;
33
34type ResponseMap = Arc<Mutex<BTreeMap<u32, oneshot::Sender<Result<(), Error>>>>>;
35
36enum Writer {
37 #[cfg(not(target_os = "windows"))]
38 Unix(TtlBufWriter<unix::OwnedWriteHalf>),
39 Tcp(TtlBufWriter<tcp::OwnedWriteHalf>),
40}
41
42impl Writer {
43 pub async fn write(&mut self, buf: &[u8], flush: Flush) -> Result<(), Error> {
44 match self {
45 #[cfg(not(target_os = "windows"))]
46 Writer::Unix(w) => w.write(buf, flush).await.map_err(Into::into),
47 Writer::Tcp(w) => w.write(buf, flush).await.map_err(Into::into),
48 }
49 }
50}
51
/// Connection settings for [`Client::connect`].
///
/// Created with [`Config::new`] and adjusted with the builder-style setters.
#[derive(Debug, Clone)]
pub struct Config {
    /// Broker address. `Client::connect` treats it as a Unix socket path when
    /// it ends with `.sock`, `.socket` or `.ipc` or starts with `/`; otherwise
    /// it is passed to `TcpStream::connect`.
    path: String,
    /// Client name sent to the broker during the handshake.
    name: String,
    /// Read-buffer capacity; also passed to `TtlBufWriter::new` for the writer.
    buf_size: usize,
    /// Passed to `TtlBufWriter::new` as the write-buffer TTL.
    buf_ttl: Duration,
    /// Capacity of the bounded incoming-event channel.
    queue_size: usize,
    /// Timeout used for writes, for reading frame bodies, and passed to
    /// `TtlBufWriter::new`.
    timeout: Duration,
}
61
62impl Config {
63 pub fn new(path: &str, name: &str) -> Self {
66 Self {
67 path: path.to_owned(),
68 name: name.to_owned(),
69 buf_size: crate::DEFAULT_BUF_SIZE,
70 buf_ttl: crate::DEFAULT_BUF_TTL,
71 queue_size: crate::DEFAULT_QUEUE_SIZE,
72 timeout: crate::DEFAULT_TIMEOUT,
73 }
74 }
75 pub fn buf_size(mut self, size: usize) -> Self {
76 self.buf_size = size;
77 self
78 }
79 pub fn buf_ttl(mut self, ttl: Duration) -> Self {
80 self.buf_ttl = ttl;
81 self
82 }
83 pub fn queue_size(mut self, size: usize) -> Self {
84 self.queue_size = size;
85 self
86 }
87 pub fn timeout(mut self, timeout: Duration) -> Self {
88 self.timeout = timeout;
89 self
90 }
91}
92
/// Asynchronous client connected to an elbus broker over a Unix socket or TCP.
pub struct Client {
    /// Name the client registered with.
    name: String,
    /// Buffered write half of the connection.
    writer: Writer,
    /// Task running `handle_read`; aborted on write failure and on drop.
    reader_fut: JoinHandle<()>,
    /// Id of the most recently prepared frame (0 until the first frame).
    frame_id: u32,
    /// Pending acknowledgements, resolved by the reader task.
    responses: ResponseMap,
    /// Incoming event channel, until taken via `take_event_channel`.
    rx: Option<EventChannel>,
    /// Shared connection flag; cleared when a write fails or the reader task
    /// finishes.
    connected: Arc<atomic::AtomicBool>,
    /// Timeout applied to each write.
    timeout: Duration,
    /// Copy of the connect configuration, reused by `register_secondary`.
    config: Config,
    /// Counter used to build unique names for secondary clients.
    secondary_counter: atomic::AtomicUsize,
}
105
// Advances the client's frame id and starts an outgoing frame header: the new
// frame id as 4 little-endian bytes, followed by one flags byte holding the
// operation code OR-ed with the QoS value shifted left by 6.
// Evaluates to the header `Vec<u8>`; callers append the length and body.
macro_rules! prepare_frame_buf {
    ($self: expr, $op: expr, $qos: expr) => {{
        $self.increment_frame_id();
        let id_bytes = $self.frame_id.to_le_bytes();
        let flags = ($op as u8) | (($qos as u8) << 6);
        let mut header: Vec<u8> = Vec::with_capacity(id_bytes.len() + 1);
        header.extend_from_slice(&id_bytes);
        header.push(flags);
        header
    }};
}
116
// Writes `$data` with the given flush mode, bounded by the client timeout.
//
// Must expand inside an `async fn` returning `Result<_, Error>`. On a write
// error OR a timeout it tears the connection down and returns early with the
// error:
// * the reader task is aborted and the connected flag cleared;
// * every pending ack sender is dropped. No acknowledgement can arrive once the
//   reader is gone, so callers awaiting an `OpConfirm` get an error instead of
//   waiting forever. This also removes the slot registered for the frame that
//   failed.
// A timed-out write may have left a partial frame on the wire, so the stream is
// treated as unusable just like on an I/O error.
macro_rules! send_data_or_mark_disconnected {
    ($self: expr, $data: expr, $flush: expr) => {{
        let failure: Option<Error> =
            match tokio::time::timeout($self.timeout, $self.writer.write($data, $flush)).await {
                Ok(Ok(())) => None,
                Ok(Err(e)) => Some(e),
                Err(e) => Some(e.into()),
            };
        if let Some(e) = failure {
            $self.reader_fut.abort();
            $self.connected.store(false, atomic::Ordering::SeqCst);
            // Best effort: a poisoned map must not turn teardown into a panic.
            if let Ok(mut pending) = $self.responses.lock() {
                pending.clear();
            }
            return Err(e);
        }
    }};
}
133
// Sends a prepared frame prefix (`$buf`) followed by `$payload`, evaluating to
// `Result<OpConfirm, Error>`.
//
// When the QoS requires an acknowledgement, a oneshot slot is registered under
// the current frame id BEFORE anything is written, so the reader task can
// never see the ack before the slot exists. The prefix is written without
// flushing; the payload's flush mode follows `$qos.is_realtime()`.
macro_rules! send_frame_and_confirm {
    ($self: expr, $buf: expr, $payload: expr, $qos: expr) => {{
        let ack_rx = $qos.needs_ack().then(|| {
            let (ack_tx, ack_rx) = oneshot::channel();
            $self.responses.lock().unwrap().insert($self.frame_id, ack_tx);
            ack_rx
        });
        send_data_or_mark_disconnected!($self, $buf, Flush::No);
        send_data_or_mark_disconnected!($self, $payload, $qos.is_realtime().into());
        Ok(ack_rx)
    }};
}
150
// Builds and sends one frame, evaluating to `Result<OpConfirm, Error>`.
//
// Forms:
// * `(self, target, payload, op, qos)`: targeted frame without a header;
// * `(self, target, header, payload, op, qos)`: targeted frame whose body is
//   `target \0 header payload`. The header travels in the same chunk as the
//   frame prefix and the payload is written separately;
// * `(self, payload, op, qos)`: untargeted frame whose body is just `payload`.
//
// The 4-byte little-endian length field counts every body byte after it.
macro_rules! send_frame {
    ($self: expr, $target: expr, $payload: expr, $op: expr, $qos: expr) => {{
        // Same wire layout as the header form, with a zero-length header.
        let no_header: &[u8] = &[];
        send_frame!($self, $target, no_header, $payload, $op, $qos)
    }};
    ($self: expr, $target: expr, $header: expr, $payload: expr, $op: expr, $qos: expr) => {{
        let mut frame = prepare_frame_buf!($self, $op, $qos);
        let target_bytes = $target.as_bytes();
        let body_len = target_bytes.len() + $header.len() + $payload.len() + 1;
        frame.extend_from_slice(&(body_len as u32).to_le_bytes());
        frame.extend_from_slice(target_bytes);
        frame.push(0x00);
        frame.extend_from_slice($header);
        trace!("sending elbus {:?} to {} QoS={:?}", $op, $target, $qos);
        send_frame_and_confirm!($self, &frame, $payload, $qos)
    }};
    ($self: expr, $payload: expr, $op: expr, $qos: expr) => {{
        let mut frame = prepare_frame_buf!($self, $op, $qos);
        let body_len = $payload.len() as u32;
        frame.extend_from_slice(&body_len.to_le_bytes());
        send_frame_and_confirm!($self, &frame, $payload, $qos)
    }};
}
182
// Runs the client handshake (`chat`) on a freshly split connection, then spawns
// the reader task. Must expand inside `Client::connect` (it uses `?`).
// Evaluates to `(reader JoinHandle, incoming-event receiver)`.
//
// * The handshake is bounded by `$timeout`, so a peer that accepts the
//   connection but never answers cannot stall `connect` forever.
// * When the reader task finishes (read error, timeout or closed channel), it
//   clears the connected flag. It also drops every pending ack sender: no
//   acknowledgement can arrive any more, so callers awaiting an `OpConfirm`
//   receive an error instead of hanging.
macro_rules! connect_broker {
    ($name: expr, $reader: expr, $writer: expr,
     $responses: expr, $connected: expr, $timeout: expr, $queue_size: expr) => {{
        let timeout: Duration = $timeout;
        tokio::time::timeout(timeout, chat($name, &mut $reader, &mut $writer)).await??;
        let (tx, rx) = async_channel::bounded($queue_size);
        let reader_responses = $responses.clone();
        let pending_acks = $responses.clone();
        let rconn = $connected.clone();
        let reader_fut = tokio::spawn(async move {
            if let Err(e) = handle_read($reader, tx, timeout, reader_responses).await {
                error!("elbus client reader error: {}", e);
            }
            rconn.store(false, atomic::Ordering::SeqCst);
            // Best effort: a poisoned map must not panic the reader task.
            if let Ok(mut pending) = pending_acks.lock() {
                pending.clear();
            }
        });
        (reader_fut, rx)
    }};
}
200
201impl Client {
202 pub async fn connect(config: &Config) -> Result<Self, Error> {
203 let responses: ResponseMap = <_>::default();
204 let connected = Arc::new(atomic::AtomicBool::new(true));
205 #[allow(clippy::case_sensitive_file_extension_comparisons)]
206 let (writer, reader_fut, rx) = if config.path.ends_with(".sock")
207 || config.path.ends_with(".socket")
208 || config.path.ends_with(".ipc")
209 || config.path.starts_with('/')
210 {
211 #[cfg(target_os = "windows")]
212 {
213 return Err(Error::not_supported("unix sockets"));
214 }
215 #[cfg(not(target_os = "windows"))]
216 {
217 let stream = UnixStream::connect(&config.path).await?;
218 let (r, mut writer) = stream.into_split();
219 let mut reader = BufReader::with_capacity(config.buf_size, r);
220 let (reader_fut, rx) = connect_broker!(
221 &config.name,
222 reader,
223 writer,
224 responses,
225 connected,
226 config.timeout,
227 config.queue_size
228 );
229 (
230 Writer::Unix(TtlBufWriter::new(
231 writer,
232 config.buf_size,
233 config.buf_ttl,
234 config.timeout,
235 )),
236 reader_fut,
237 rx,
238 )
239 }
240 } else {
241 let stream = TcpStream::connect(&config.path).await?;
242 stream.set_nodelay(true)?;
243 let (r, mut writer) = stream.into_split();
244 let mut reader = BufReader::with_capacity(config.buf_size, r);
245 let (reader_fut, rx) = connect_broker!(
246 &config.name,
247 reader,
248 writer,
249 responses,
250 connected,
251 config.timeout,
252 config.queue_size
253 );
254 (
255 Writer::Tcp(TtlBufWriter::new(
256 writer,
257 config.buf_size,
258 config.buf_ttl,
259 config.timeout,
260 )),
261 reader_fut,
262 rx,
263 )
264 };
265 Ok(Self {
266 name: config.name.clone(),
267 writer,
268 reader_fut,
269 frame_id: 0,
270 responses,
271 rx: Some(rx),
272 connected,
273 timeout: config.timeout,
274 config: config.clone(),
275 secondary_counter: atomic::AtomicUsize::new(0),
276 })
277 }
278 pub async fn register_secondary(&self) -> Result<Self, Error> {
279 if self.name.contains(SECONDARY_SEP) {
280 Err(Error::not_supported("not a primary client"))
281 } else {
282 let secondary_id = self
283 .secondary_counter
284 .fetch_add(1, atomic::Ordering::SeqCst);
285 let secondary_name = format!("{}{}{}", self.name, SECONDARY_SEP, secondary_id);
286 let mut config = self.config.clone();
287 config.name = secondary_name;
288 Self::connect(&config).await
289 }
290 }
291 #[inline]
292 fn increment_frame_id(&mut self) {
293 if self.frame_id == u32::MAX {
294 self.frame_id = 1;
295 } else {
296 self.frame_id += 1;
297 }
298 }
299 #[inline]
300 pub fn get_timeout(&self) -> Duration {
301 self.timeout
302 }
303}
304#[async_trait]
305impl AsyncClient for Client {
306 #[inline]
307 fn take_event_channel(&mut self) -> Option<EventChannel> {
308 self.rx.take()
309 }
310 #[inline]
311 fn get_connected_beacon(&self) -> Option<Arc<atomic::AtomicBool>> {
312 Some(self.connected.clone())
313 }
314 async fn send(
315 &mut self,
316 target: &str,
317 payload: Cow<'async_trait>,
318 qos: QoS,
319 ) -> Result<OpConfirm, Error> {
320 send_frame!(self, target, payload.as_slice(), FrameOp::Message, qos)
321 }
322 async fn zc_send(
323 &mut self,
324 target: &str,
325 header: Cow<'async_trait>,
326 payload: Cow<'async_trait>,
327 qos: QoS,
328 ) -> Result<OpConfirm, Error> {
329 send_frame!(
330 self,
331 target,
332 header.as_slice(),
333 payload.as_slice(),
334 FrameOp::Message,
335 qos
336 )
337 }
338 async fn send_broadcast(
339 &mut self,
340 target: &str,
341 payload: Cow<'async_trait>,
342 qos: QoS,
343 ) -> Result<OpConfirm, Error> {
344 send_frame!(self, target, payload.as_slice(), FrameOp::Broadcast, qos)
345 }
346 async fn publish(
347 &mut self,
348 target: &str,
349 payload: Cow<'async_trait>,
350 qos: QoS,
351 ) -> Result<OpConfirm, Error> {
352 send_frame!(self, target, payload.as_slice(), FrameOp::PublishTopic, qos)
353 }
354 async fn subscribe(&mut self, topic: &str, qos: QoS) -> Result<OpConfirm, Error> {
355 send_frame!(self, topic.as_bytes(), FrameOp::SubscribeTopic, qos)
356 }
357 async fn unsubscribe(&mut self, topic: &str, qos: QoS) -> Result<OpConfirm, Error> {
358 send_frame!(self, topic.as_bytes(), FrameOp::UnsubscribeTopic, qos)
359 }
360 async fn subscribe_bulk(&mut self, topics: &[&str], qos: QoS) -> Result<OpConfirm, Error> {
361 let mut payload = Vec::new();
362 for topic in topics {
363 if !payload.is_empty() {
364 payload.push(0x00);
365 }
366 payload.extend(topic.as_bytes());
367 }
368 send_frame!(self, &payload, FrameOp::SubscribeTopic, qos)
369 }
370 async fn unsubscribe_bulk(&mut self, topics: &[&str], qos: QoS) -> Result<OpConfirm, Error> {
371 let mut payload = Vec::new();
372 for topic in topics {
373 if !payload.is_empty() {
374 payload.push(0x00);
375 }
376 payload.extend(topic.as_bytes());
377 }
378 send_frame!(self, &payload, FrameOp::UnsubscribeTopic, qos)
379 }
380 #[inline]
381 async fn ping(&mut self) -> Result<(), Error> {
382 send_data_or_mark_disconnected!(self, PING_FRAME, Flush::Instant);
383 Ok(())
384 }
385 #[inline]
386 fn is_connected(&self) -> bool {
387 self.connected.load(atomic::Ordering::SeqCst)
388 }
389 #[inline]
390 fn get_timeout(&self) -> Option<Duration> {
391 Some(self.timeout)
392 }
393 #[inline]
394 fn get_name(&self) -> &str {
395 self.name.as_str()
396 }
397}
398
399impl Drop for Client {
400 fn drop(&mut self) {
401 self.reader_fut.abort();
402 }
403}
404
405async fn handle_read<R>(
406 mut reader: R,
407 tx: async_channel::Sender<Frame>,
408 timeout: Duration,
409 responses: ResponseMap,
410) -> Result<(), Error>
411where
412 R: AsyncReadExt + Unpin,
413{
414 loop {
415 let mut buf = vec![0; 6];
416 reader.read_exact(&mut buf).await?;
417 let frame_type: FrameKind = buf[0].try_into()?;
418 let realtime = buf[5] != 0;
419 match frame_type {
420 FrameKind::Nop => {}
421 FrameKind::Acknowledge => {
422 let ack_id = u32::from_le_bytes(buf[1..5].try_into().unwrap());
423 let tx_channel = { responses.lock().unwrap().remove(&ack_id) };
424 if let Some(tx) = tx_channel {
425 let _r = tx.send(buf[5].to_elbus_result());
426 } else {
427 warn!("orphaned elbus op ack {}", ack_id);
428 }
429 }
430 _ => {
431 let frame_len = u32::from_le_bytes(buf[1..5].try_into().unwrap());
432 let mut buf = vec![0; frame_len as usize];
433 tokio::time::timeout(timeout, reader.read_exact(&mut buf)).await??;
434 let (sender, topic, payload_pos) = {
435 if frame_type == FrameKind::Publish {
436 let mut sp = buf.splitn(3, |c| *c == 0);
437 let s = sp.next().ok_or_else(|| Error::data("broken frame"))?;
438 let sender = std::str::from_utf8(s)?.to_owned();
439 let t = sp.next().ok_or_else(|| Error::data("broken frame"))?;
440 let topic = std::str::from_utf8(t)?.to_owned();
441 sp.next().ok_or_else(|| Error::data("broken frame"))?;
442 let payload_pos = s.len() + t.len() + 2;
443 (Some(sender), Some(topic), payload_pos)
444 } else {
445 let mut sp = buf.splitn(2, |c| *c == 0);
446 let s = sp.next().ok_or_else(|| Error::data("broken frame"))?;
447 let sender = std::str::from_utf8(s)?.to_owned();
448 sp.next().ok_or_else(|| Error::data("broken frame"))?;
449 let payload_pos = s.len() + 1;
450 (Some(sender), None, payload_pos)
451 }
452 };
453 let frame = Arc::new(FrameData::new(
454 frame_type,
455 sender,
456 topic,
457 None,
458 buf,
459 payload_pos,
460 realtime,
461 ));
462 tx.send(frame).await.map_err(Error::io)?;
463 }
464 }
465 }
466}
467
468async fn chat<R, W>(name: &str, reader: &mut R, writer: &mut W) -> Result<(), Error>
469where
470 R: AsyncReadExt + Unpin,
471 W: AsyncWriteExt + Unpin,
472{
473 if name.len() > u16::MAX as usize {
474 return Err(Error::data("name too long"));
475 }
476 let mut buf = vec![0; 3];
477 reader.read_exact(&mut buf).await?;
478 if buf[0] != GREETINGS[0] {
479 return Err(Error::not_supported("Invalid greetings"));
480 }
481 if u16::from_le_bytes(buf[1..3].try_into().unwrap()) != PROTOCOL_VERSION {
482 return Err(Error::not_supported("Unsupported protocol version"));
483 }
484 writer.write_all(&buf).await?;
485 let mut buf = vec![0; 1];
486 reader.read_exact(&mut buf).await?;
487 if buf[0] != RESPONSE_OK {
488 return Err(Error::new(
489 buf[0].into(),
490 Some(format!("Server greetings response: {:?}", buf[0])),
491 ));
492 }
493 let n = name.as_bytes().to_vec();
494 #[allow(clippy::cast_possible_truncation)]
495 writer.write_all(&(name.len() as u16).to_le_bytes()).await?;
496 writer.write_all(&n).await?;
497 let mut buf = vec![0; 1];
498 reader.read_exact(&mut buf).await?;
499 if buf[0] != RESPONSE_OK {
500 return Err(Error::new(
501 buf[0].into(),
502 Some(format!("Server registration response: {:?}", buf[0])),
503 ));
504 }
505 Ok(())
506}