elbus/
ipc.rs

1use crate::borrow::Cow;
2use crate::comm::{Flush, TtlBufWriter};
3use crate::Error;
4use crate::EventChannel;
5use crate::IntoElbusResult;
6use crate::OpConfirm;
7use crate::QoS;
8use crate::GREETINGS;
9use crate::PING_FRAME;
10use crate::PROTOCOL_VERSION;
11use crate::RESPONSE_OK;
12use crate::SECONDARY_SEP;
13use crate::{Frame, FrameData, FrameKind, FrameOp};
14use std::collections::BTreeMap;
15use std::marker::Unpin;
16use std::sync::atomic;
17use std::sync::{Arc, Mutex};
18use std::time::Duration;
19use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
20#[cfg(not(target_os = "windows"))]
21use tokio::net::unix;
22#[cfg(not(target_os = "windows"))]
23use tokio::net::UnixStream;
24use tokio::net::{tcp, TcpStream};
25use tokio::sync::oneshot;
26use tokio::task::JoinHandle;
27
28use crate::client::AsyncClient;
29
30use log::{error, trace, warn};
31
32use async_trait::async_trait;
33
34type ResponseMap = Arc<Mutex<BTreeMap<u32, oneshot::Sender<Result<(), Error>>>>>;
35
36enum Writer {
37    #[cfg(not(target_os = "windows"))]
38    Unix(TtlBufWriter<unix::OwnedWriteHalf>),
39    Tcp(TtlBufWriter<tcp::OwnedWriteHalf>),
40}
41
42impl Writer {
43    pub async fn write(&mut self, buf: &[u8], flush: Flush) -> Result<(), Error> {
44        match self {
45            #[cfg(not(target_os = "windows"))]
46            Writer::Unix(w) => w.write(buf, flush).await.map_err(Into::into),
47            Writer::Tcp(w) => w.write(buf, flush).await.map_err(Into::into),
48        }
49    }
50}
51
52#[derive(Debug, Clone)]
53pub struct Config {
54    path: String,
55    name: String,
56    buf_size: usize,
57    buf_ttl: Duration,
58    queue_size: usize,
59    timeout: Duration,
60}
61
62impl Config {
63    /// path - /path/to/socket (must end with .sock .socket or .ipc) or host:port,
64    /// name - an unique client name
65    pub fn new(path: &str, name: &str) -> Self {
66        Self {
67            path: path.to_owned(),
68            name: name.to_owned(),
69            buf_size: crate::DEFAULT_BUF_SIZE,
70            buf_ttl: crate::DEFAULT_BUF_TTL,
71            queue_size: crate::DEFAULT_QUEUE_SIZE,
72            timeout: crate::DEFAULT_TIMEOUT,
73        }
74    }
75    pub fn buf_size(mut self, size: usize) -> Self {
76        self.buf_size = size;
77        self
78    }
79    pub fn buf_ttl(mut self, ttl: Duration) -> Self {
80        self.buf_ttl = ttl;
81        self
82    }
83    pub fn queue_size(mut self, size: usize) -> Self {
84        self.queue_size = size;
85        self
86    }
87    pub fn timeout(mut self, timeout: Duration) -> Self {
88        self.timeout = timeout;
89        self
90    }
91}
92
93pub struct Client {
94    name: String,
95    writer: Writer,
96    reader_fut: JoinHandle<()>,
97    frame_id: u32,
98    responses: ResponseMap,
99    rx: Option<EventChannel>,
100    connected: Arc<atomic::AtomicBool>,
101    timeout: Duration,
102    config: Config,
103    secondary_counter: atomic::AtomicUsize,
104}
105
// keep these as macros to ensure inlining and avoid unnecessary futures
107
// Advances the client's frame id and starts a new frame header:
// 4 bytes of little-endian frame id followed by one byte that carries the
// operation code in the low bits and the QoS value shifted into bits 6 and up.
macro_rules! prepare_frame_buf {
    ($self: expr, $op: expr, $qos: expr) => {{
        $self.increment_frame_id();
        let op_and_qos = ($op as u8) | (($qos as u8) << 6);
        let mut header = Vec::from($self.frame_id.to_le_bytes());
        header.push(op_and_qos);
        header
    }};
}
116
// Writes `$data` with the client's write timeout.
//
// On ANY failure the enclosing function returns the error, the reader task is
// aborted and the shared `connected` flag is cleared. This deliberately
// includes the timeout case: a write cancelled by the timeout may have been
// partially performed, which would leave the frame stream out of sync, so the
// connection must not be reported as usable afterwards.
macro_rules! send_data_or_mark_disconnected {
    ($self: expr, $data: expr, $flush: expr) => {
        match tokio::time::timeout($self.timeout, $self.writer.write($data, $flush)).await {
            Ok(Ok(())) => {}
            // the writer reported an error
            Ok(Err(e)) => {
                $self.reader_fut.abort();
                $self.connected.store(false, atomic::Ordering::SeqCst);
                return Err(e.into());
            }
            // the write did not complete in time
            Err(e) => {
                $self.reader_fut.abort();
                $self.connected.store(false, atomic::Ordering::SeqCst);
                return Err(e.into());
            }
        }
    };
}
133
// Sends a prepared frame header (`$buf`) followed by `$payload`.
//
// When the QoS asks for an acknowledgement, a oneshot channel is registered
// under the current frame id BEFORE anything is written, and its receiving
// half is the result (`Ok(Some(..))`); otherwise the result is `Ok(None)`.
// The header is written without a flush, the payload is written with a flush
// mode derived from `$qos.is_realtime()`.
macro_rules! send_frame_and_confirm {
    ($self: expr, $buf: expr, $payload: expr, $qos: expr) => {{
        let confirmation = if $qos.needs_ack() {
            let (ack_tx, ack_rx) = oneshot::channel();
            {
                // the lock guard lives only inside this block
                let mut pending = $self.responses.lock().unwrap();
                pending.insert($self.frame_id, ack_tx);
            }
            Some(ack_rx)
        } else {
            None
        };
        send_data_or_mark_disconnected!($self, $buf, Flush::No);
        send_data_or_mark_disconnected!($self, $payload, $qos.is_realtime().into());
        Ok(confirmation)
    }};
}
150
// Builds a frame header and sends it together with the payload.
//
// The header started by `prepare_frame_buf!` is extended with the body length
// as a little-endian u32. For targeted frames the body is
// `target, 0x00, [header,] payload`; for target-less frames it is the payload
// alone. The three arms are told apart by their argument count.
macro_rules! send_frame {
    // send to target or topic
    ($self: expr, $target: expr, $payload: expr, $op: expr, $qos: expr) => {{
        let mut frame_head = prepare_frame_buf!($self, $op, $qos);
        let target_bytes = $target.as_bytes();
        // +1 accounts for the 0x00 separator after the target
        let body_len = (target_bytes.len() + $payload.len() + 1) as u32;
        frame_head.extend_from_slice(&body_len.to_le_bytes());
        frame_head.extend_from_slice(target_bytes);
        frame_head.push(0x00);
        trace!("sending elbus {:?} to {} QoS={:?}", $op, $target, $qos);
        send_frame_and_confirm!($self, &frame_head, $payload, $qos)
    }};
    // zc-send to target or topic
    ($self: expr, $target: expr, $header: expr, $payload: expr, $op: expr, $qos: expr) => {{
        let mut frame_head = prepare_frame_buf!($self, $op, $qos);
        let target_bytes = $target.as_bytes();
        // +1 accounts for the 0x00 separator after the target
        let body_len = (target_bytes.len() + $payload.len() + $header.len() + 1) as u32;
        frame_head.extend_from_slice(&body_len.to_le_bytes());
        frame_head.extend_from_slice(target_bytes);
        frame_head.push(0x00);
        // the user header travels with the frame head, the payload is written separately
        frame_head.extend_from_slice($header);
        trace!("sending elbus {:?} to {} QoS={:?}", $op, $target, $qos);
        send_frame_and_confirm!($self, &frame_head, $payload, $qos)
    }};
    // send w/o a target
    ($self: expr, $payload: expr, $op: expr, $qos: expr) => {{
        let mut frame_head = prepare_frame_buf!($self, $op, $qos);
        let body_len = $payload.len() as u32;
        frame_head.extend_from_slice(&body_len.to_le_bytes());
        send_frame_and_confirm!($self, &frame_head, $payload, $qos)
    }};
}
182
// Performs the handshake with the broker and starts the background reader.
//
// After `chat` succeeds, a bounded channel of `$queue_size` frames is created
// and a task running `handle_read` is spawned. When that task finishes, with
// or without an error, the shared `connected` flag is cleared. Evaluates to
// `(reader task handle, frame receiver)`.
macro_rules! connect_broker {
    ($name: expr, $reader: expr, $writer: expr,
         $responses: expr, $connected: expr, $timeout: expr, $queue_size: expr) => {{
        chat($name, &mut $reader, &mut $writer).await?;
        let (frame_tx, frame_rx) = async_channel::bounded($queue_size);
        let pending = $responses.clone();
        let alive = $connected.clone();
        let read_timeout = $timeout.clone();
        let reader_task = tokio::spawn(async move {
            match handle_read($reader, frame_tx, read_timeout, pending).await {
                Ok(()) => {}
                Err(e) => error!("elbus client reader error: {}", e),
            }
            alive.store(false, atomic::Ordering::SeqCst);
        });
        (reader_task, frame_rx)
    }};
}
200
201impl Client {
202    pub async fn connect(config: &Config) -> Result<Self, Error> {
203        let responses: ResponseMap = <_>::default();
204        let connected = Arc::new(atomic::AtomicBool::new(true));
205        #[allow(clippy::case_sensitive_file_extension_comparisons)]
206        let (writer, reader_fut, rx) = if config.path.ends_with(".sock")
207            || config.path.ends_with(".socket")
208            || config.path.ends_with(".ipc")
209            || config.path.starts_with('/')
210        {
211            #[cfg(target_os = "windows")]
212            {
213                return Err(Error::not_supported("unix sockets"));
214            }
215            #[cfg(not(target_os = "windows"))]
216            {
217                let stream = UnixStream::connect(&config.path).await?;
218                let (r, mut writer) = stream.into_split();
219                let mut reader = BufReader::with_capacity(config.buf_size, r);
220                let (reader_fut, rx) = connect_broker!(
221                    &config.name,
222                    reader,
223                    writer,
224                    responses,
225                    connected,
226                    config.timeout,
227                    config.queue_size
228                );
229                (
230                    Writer::Unix(TtlBufWriter::new(
231                        writer,
232                        config.buf_size,
233                        config.buf_ttl,
234                        config.timeout,
235                    )),
236                    reader_fut,
237                    rx,
238                )
239            }
240        } else {
241            let stream = TcpStream::connect(&config.path).await?;
242            stream.set_nodelay(true)?;
243            let (r, mut writer) = stream.into_split();
244            let mut reader = BufReader::with_capacity(config.buf_size, r);
245            let (reader_fut, rx) = connect_broker!(
246                &config.name,
247                reader,
248                writer,
249                responses,
250                connected,
251                config.timeout,
252                config.queue_size
253            );
254            (
255                Writer::Tcp(TtlBufWriter::new(
256                    writer,
257                    config.buf_size,
258                    config.buf_ttl,
259                    config.timeout,
260                )),
261                reader_fut,
262                rx,
263            )
264        };
265        Ok(Self {
266            name: config.name.clone(),
267            writer,
268            reader_fut,
269            frame_id: 0,
270            responses,
271            rx: Some(rx),
272            connected,
273            timeout: config.timeout,
274            config: config.clone(),
275            secondary_counter: atomic::AtomicUsize::new(0),
276        })
277    }
278    pub async fn register_secondary(&self) -> Result<Self, Error> {
279        if self.name.contains(SECONDARY_SEP) {
280            Err(Error::not_supported("not a primary client"))
281        } else {
282            let secondary_id = self
283                .secondary_counter
284                .fetch_add(1, atomic::Ordering::SeqCst);
285            let secondary_name = format!("{}{}{}", self.name, SECONDARY_SEP, secondary_id);
286            let mut config = self.config.clone();
287            config.name = secondary_name;
288            Self::connect(&config).await
289        }
290    }
291    #[inline]
292    fn increment_frame_id(&mut self) {
293        if self.frame_id == u32::MAX {
294            self.frame_id = 1;
295        } else {
296            self.frame_id += 1;
297        }
298    }
299    #[inline]
300    pub fn get_timeout(&self) -> Duration {
301        self.timeout
302    }
303}
304#[async_trait]
305impl AsyncClient for Client {
306    #[inline]
307    fn take_event_channel(&mut self) -> Option<EventChannel> {
308        self.rx.take()
309    }
310    #[inline]
311    fn get_connected_beacon(&self) -> Option<Arc<atomic::AtomicBool>> {
312        Some(self.connected.clone())
313    }
314    async fn send(
315        &mut self,
316        target: &str,
317        payload: Cow<'async_trait>,
318        qos: QoS,
319    ) -> Result<OpConfirm, Error> {
320        send_frame!(self, target, payload.as_slice(), FrameOp::Message, qos)
321    }
322    async fn zc_send(
323        &mut self,
324        target: &str,
325        header: Cow<'async_trait>,
326        payload: Cow<'async_trait>,
327        qos: QoS,
328    ) -> Result<OpConfirm, Error> {
329        send_frame!(
330            self,
331            target,
332            header.as_slice(),
333            payload.as_slice(),
334            FrameOp::Message,
335            qos
336        )
337    }
338    async fn send_broadcast(
339        &mut self,
340        target: &str,
341        payload: Cow<'async_trait>,
342        qos: QoS,
343    ) -> Result<OpConfirm, Error> {
344        send_frame!(self, target, payload.as_slice(), FrameOp::Broadcast, qos)
345    }
346    async fn publish(
347        &mut self,
348        target: &str,
349        payload: Cow<'async_trait>,
350        qos: QoS,
351    ) -> Result<OpConfirm, Error> {
352        send_frame!(self, target, payload.as_slice(), FrameOp::PublishTopic, qos)
353    }
354    async fn subscribe(&mut self, topic: &str, qos: QoS) -> Result<OpConfirm, Error> {
355        send_frame!(self, topic.as_bytes(), FrameOp::SubscribeTopic, qos)
356    }
357    async fn unsubscribe(&mut self, topic: &str, qos: QoS) -> Result<OpConfirm, Error> {
358        send_frame!(self, topic.as_bytes(), FrameOp::UnsubscribeTopic, qos)
359    }
360    async fn subscribe_bulk(&mut self, topics: &[&str], qos: QoS) -> Result<OpConfirm, Error> {
361        let mut payload = Vec::new();
362        for topic in topics {
363            if !payload.is_empty() {
364                payload.push(0x00);
365            }
366            payload.extend(topic.as_bytes());
367        }
368        send_frame!(self, &payload, FrameOp::SubscribeTopic, qos)
369    }
370    async fn unsubscribe_bulk(&mut self, topics: &[&str], qos: QoS) -> Result<OpConfirm, Error> {
371        let mut payload = Vec::new();
372        for topic in topics {
373            if !payload.is_empty() {
374                payload.push(0x00);
375            }
376            payload.extend(topic.as_bytes());
377        }
378        send_frame!(self, &payload, FrameOp::UnsubscribeTopic, qos)
379    }
380    #[inline]
381    async fn ping(&mut self) -> Result<(), Error> {
382        send_data_or_mark_disconnected!(self, PING_FRAME, Flush::Instant);
383        Ok(())
384    }
385    #[inline]
386    fn is_connected(&self) -> bool {
387        self.connected.load(atomic::Ordering::SeqCst)
388    }
389    #[inline]
390    fn get_timeout(&self) -> Option<Duration> {
391        Some(self.timeout)
392    }
393    #[inline]
394    fn get_name(&self) -> &str {
395        self.name.as_str()
396    }
397}
398
399impl Drop for Client {
400    fn drop(&mut self) {
401        self.reader_fut.abort();
402    }
403}
404
405async fn handle_read<R>(
406    mut reader: R,
407    tx: async_channel::Sender<Frame>,
408    timeout: Duration,
409    responses: ResponseMap,
410) -> Result<(), Error>
411where
412    R: AsyncReadExt + Unpin,
413{
414    loop {
415        let mut buf = vec![0; 6];
416        reader.read_exact(&mut buf).await?;
417        let frame_type: FrameKind = buf[0].try_into()?;
418        let realtime = buf[5] != 0;
419        match frame_type {
420            FrameKind::Nop => {}
421            FrameKind::Acknowledge => {
422                let ack_id = u32::from_le_bytes(buf[1..5].try_into().unwrap());
423                let tx_channel = { responses.lock().unwrap().remove(&ack_id) };
424                if let Some(tx) = tx_channel {
425                    let _r = tx.send(buf[5].to_elbus_result());
426                } else {
427                    warn!("orphaned elbus op ack {}", ack_id);
428                }
429            }
430            _ => {
431                let frame_len = u32::from_le_bytes(buf[1..5].try_into().unwrap());
432                let mut buf = vec![0; frame_len as usize];
433                tokio::time::timeout(timeout, reader.read_exact(&mut buf)).await??;
434                let (sender, topic, payload_pos) = {
435                    if frame_type == FrameKind::Publish {
436                        let mut sp = buf.splitn(3, |c| *c == 0);
437                        let s = sp.next().ok_or_else(|| Error::data("broken frame"))?;
438                        let sender = std::str::from_utf8(s)?.to_owned();
439                        let t = sp.next().ok_or_else(|| Error::data("broken frame"))?;
440                        let topic = std::str::from_utf8(t)?.to_owned();
441                        sp.next().ok_or_else(|| Error::data("broken frame"))?;
442                        let payload_pos = s.len() + t.len() + 2;
443                        (Some(sender), Some(topic), payload_pos)
444                    } else {
445                        let mut sp = buf.splitn(2, |c| *c == 0);
446                        let s = sp.next().ok_or_else(|| Error::data("broken frame"))?;
447                        let sender = std::str::from_utf8(s)?.to_owned();
448                        sp.next().ok_or_else(|| Error::data("broken frame"))?;
449                        let payload_pos = s.len() + 1;
450                        (Some(sender), None, payload_pos)
451                    }
452                };
453                let frame = Arc::new(FrameData::new(
454                    frame_type,
455                    sender,
456                    topic,
457                    None,
458                    buf,
459                    payload_pos,
460                    realtime,
461                ));
462                tx.send(frame).await.map_err(Error::io)?;
463            }
464        }
465    }
466}
467
468async fn chat<R, W>(name: &str, reader: &mut R, writer: &mut W) -> Result<(), Error>
469where
470    R: AsyncReadExt + Unpin,
471    W: AsyncWriteExt + Unpin,
472{
473    if name.len() > u16::MAX as usize {
474        return Err(Error::data("name too long"));
475    }
476    let mut buf = vec![0; 3];
477    reader.read_exact(&mut buf).await?;
478    if buf[0] != GREETINGS[0] {
479        return Err(Error::not_supported("Invalid greetings"));
480    }
481    if u16::from_le_bytes(buf[1..3].try_into().unwrap()) != PROTOCOL_VERSION {
482        return Err(Error::not_supported("Unsupported protocol version"));
483    }
484    writer.write_all(&buf).await?;
485    let mut buf = vec![0; 1];
486    reader.read_exact(&mut buf).await?;
487    if buf[0] != RESPONSE_OK {
488        return Err(Error::new(
489            buf[0].into(),
490            Some(format!("Server greetings response: {:?}", buf[0])),
491        ));
492    }
493    let n = name.as_bytes().to_vec();
494    #[allow(clippy::cast_possible_truncation)]
495    writer.write_all(&(name.len() as u16).to_le_bytes()).await?;
496    writer.write_all(&n).await?;
497    let mut buf = vec![0; 1];
498    reader.read_exact(&mut buf).await?;
499    if buf[0] != RESPONSE_OK {
500        return Err(Error::new(
501            buf[0].into(),
502            Some(format!("Server registration response: {:?}", buf[0])),
503        ));
504    }
505    Ok(())
506}