bluerobotics_ping/
device.rs

1use std::convert::TryFrom;
2
3use futures::{
4    stream::{SplitSink, SplitStream},
5    SinkExt, StreamExt,
6};
7use tokio::sync::{
8    broadcast::{self, Sender},
9    mpsc::{self, Receiver},
10};
11use tokio::{
12    io::{AsyncRead, AsyncWrite},
13    task::JoinHandle,
14};
15use tokio_util::codec::{Decoder, Framed};
16use tracing::{error, info, trace};
17
18use crate::{
19    codec::PingCodec,
20    common,
21    error::PingError,
22    message::{self, MessageInfo, ProtocolMessage},
23    Messages,
24};
25
26// Make devices available, each device uses Common and PingDevice.
27pub use crate::ping1d::Device as Ping1D;
28pub use crate::ping360::Device as Ping360;
29
30#[derive(Debug)]
31pub struct Common {
32    tx: mpsc::Sender<ProtocolMessage>,
33    rx: broadcast::Receiver<ProtocolMessage>,
34    task_handles: TaskHandles,
35}
36#[derive(Debug)]
37
38struct TaskHandles {
39    stream_handle: JoinHandle<()>,
40    sink_handle: JoinHandle<()>,
41}
42
43impl Common {
44    pub fn new<T>(io: T) -> Self
45    where
46        T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
47    {
48        // Prepare Serial sink and stream modules
49        let serial: Framed<T, PingCodec> = PingCodec::new().framed(io);
50        let (serial_sink, serial_stream) = serial.split();
51
52        // Prepare Serial receiver broadcast and sender
53        let (broadcast_tx, broadcast_rx) = broadcast::channel::<ProtocolMessage>(100);
54        let stream_handle = tokio::spawn(Self::stream(serial_stream, broadcast_tx));
55        let (sender, sender_rx) = mpsc::channel::<ProtocolMessage>(100);
56        let sink_handle = tokio::spawn(Self::sink(serial_sink, sender_rx));
57
58        Common {
59            tx: sender,
60            rx: broadcast_rx,
61            task_handles: TaskHandles {
62                stream_handle,
63                sink_handle,
64            },
65        }
66    }
67
68    async fn sink<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
69        mut sink: SplitSink<Framed<T, PingCodec>, ProtocolMessage>,
70        mut sender_rx: Receiver<ProtocolMessage>,
71    ) {
72        while let Some(item) = sender_rx.recv().await {
73            if let Err(e) = sink.send(item).await {
74                error!("{e:?}");
75            }
76        }
77    }
78
79    async fn stream<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
80        mut serial_stream: SplitStream<Framed<T, PingCodec>>,
81        broadcast_tx: Sender<ProtocolMessage>,
82    ) {
83        'outside_loop: loop {
84            while let Some(msg) = serial_stream.next().await {
85                match msg {
86                    Ok(msg) => {
87                        if let Err(e) = broadcast_tx.send(msg) {
88                            error!("{e:?}");
89                            break 'outside_loop;
90                        };
91                    }
92                    Err(e) => {
93                        trace!("{e:?}");
94                    }
95                }
96            }
97        }
98    }
99
100    pub async fn send_message(&self, message: ProtocolMessage) -> Result<(), PingError> {
101        self.tx.send(message).await.map_err(|err| err.into())
102    }
103
104    fn subscribe(&self) -> tokio::sync::broadcast::Receiver<ProtocolMessage> {
105        self.rx.resubscribe()
106    }
107}
108
109impl Drop for Common {
110    fn drop(&mut self) {
111        self.task_handles.stream_handle.abort();
112        self.task_handles.sink_handle.abort();
113        info!("TaskHandles sink and stream dropped, tasks aborted");
114    }
115}
116
117pub trait PingDevice {
118    fn get_common(&self) -> &Common;
119
120    fn get_mut_common(&mut self) -> &mut Common;
121
122    fn subscribe(&self) -> tokio::sync::broadcast::Receiver<ProtocolMessage> {
123        self.get_common().subscribe()
124    }
125
126    async fn send_general_request(&self, requested_id: u16) -> Result<(), PingError> {
127        let request =
128            common::Messages::GeneralRequest(common::GeneralRequestStruct { requested_id });
129        let mut package = message::ProtocolMessage::new();
130        package.set_message(&request);
131
132        if let Err(e) = self.get_common().send_message(package).await {
133            return Err(e);
134        };
135
136        Ok(())
137    }
138
139    async fn wait_for_message<T: 'static>(
140        &self,
141        mut receiver: tokio::sync::broadcast::Receiver<ProtocolMessage>,
142    ) -> Result<T, PingError>
143    where
144        T: crate::message::MessageInfo + std::marker::Sync + Clone + std::marker::Send,
145    {
146        let future = async move {
147            loop {
148                match receiver.recv().await {
149                    Ok(answer) => {
150                        if T::id() != answer.message_id {
151                            continue;
152                        };
153                        let message = Messages::try_from(&answer)
154                            .map_err(|_e| PingError::TryFromError(answer))?;
155                        match message.inner::<T>() {
156                            Some(message) => return Ok(message.clone()),
157                            None => {
158                                error!(
159                                    "Received message is not of type `{}` ({}), receiving: {:?}",
160                                    T::name(),
161                                    T::id(),
162                                    message,
163                                );
164                            }
165                        }
166                    }
167                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
168                    Err(e) => return Err(e.into()),
169                };
170            }
171        };
172
173        match tokio::time::timeout(tokio::time::Duration::from_secs(15), future).await {
174            Ok(result) => result,
175            Err(_) => Err(PingError::TimeoutError),
176        }
177    }
178
179    async fn wait_for_ack(
180        &self,
181        mut receiver: tokio::sync::broadcast::Receiver<ProtocolMessage>,
182        message_id: u16,
183    ) -> Result<(), PingError> {
184        let future = async move {
185            loop {
186                match receiver.recv().await {
187                    Ok(answer) => {
188                        if common::AckStruct::id() != answer.message_id
189                            && common::NackStruct::id() != answer.message_id
190                        {
191                            continue;
192                        }
193                        match Messages::try_from(&answer) {
194                            Ok(Messages::Common(common::Messages::Ack(answer))) => {
195                                if answer.acked_id != message_id {
196                                    continue;
197                                };
198                                return Ok(());
199                            }
200                            Ok(Messages::Common(common::Messages::Nack(answer))) => {
201                                if answer.nacked_id != message_id {
202                                    continue;
203                                };
204                                return Err(PingError::NackError(answer.nack_message));
205                            }
206                            _ => return Err(PingError::TryFromError(answer)), // Almost unreachable, but raises error ProtocolMessage
207                        };
208                    }
209                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
210                    Err(e) => return Err(e.into()),
211                };
212            }
213        };
214
215        match tokio::time::timeout(tokio::time::Duration::from_secs(15), future).await {
216            Ok(result) => result,
217            Err(_) => Err(PingError::TimeoutError),
218        }
219    }
220
221    async fn request<T: 'static>(&self) -> Result<T, PingError>
222    where
223        T: crate::message::MessageInfo + std::marker::Sync + Clone + std::marker::Send,
224    {
225        let receiver = self.subscribe();
226
227        self.send_general_request(T::id()).await?;
228
229        self.wait_for_message(receiver).await
230    }
231
232    #[doc = "Device information"]
233    async fn device_information(&self) -> Result<common::DeviceInformationStruct, PingError> {
234        self.request().await
235    }
236    #[doc = "The protocol version"]
237    async fn protocol_version(&self) -> Result<common::ProtocolVersionStruct, PingError> {
238        self.request().await
239    }
240    #[doc = "Set the device ID."]
241    #[doc = "# Arguments"]
242    #[doc = "* `device_id` - Device ID (1-254). 0 is unknown and 255 is reserved for broadcast messages."]
243    async fn set_device_id(&self, device_id: u8) -> Result<(), PingError> {
244        let request = common::Messages::SetDeviceId(common::SetDeviceIdStruct { device_id });
245        let mut package = ProtocolMessage::new();
246        package.set_message(&request);
247        let receiver = self.subscribe();
248        self.get_common().send_message(package).await?;
249        self.wait_for_ack(receiver, common::SetDeviceIdStruct::id())
250            .await
251    }
252}