bluerobotics_ping/
device.rs

1use std::convert::TryFrom;
2
3use futures::{
4    stream::{SplitSink, SplitStream},
5    SinkExt, StreamExt,
6};
7use tokio::sync::{
8    broadcast::{self, Sender},
9    mpsc::{self, Receiver},
10};
11use tokio::{
12    io::{AsyncRead, AsyncWrite},
13    task::JoinHandle,
14};
15use tokio_util::codec::{Decoder, Framed};
16use tracing::{error, info, trace};
17
18use crate::{
19    codec::PingCodec,
20    common,
21    error::PingError,
22    message::{self, MessageInfo, ProtocolMessage},
23    Messages,
24};
25
26// Make devices available, each device uses Common and PingDevice.
27pub use crate::ping1d::Device as Ping1D;
28pub use crate::ping360::Device as Ping360;
29
30#[derive(Debug)]
31pub struct Common {
32    tx: mpsc::Sender<ProtocolMessage>,
33    rx: broadcast::Receiver<ProtocolMessage>,
34    task_handles: TaskHandles,
35}
36#[derive(Debug)]
37
38struct TaskHandles {
39    stream_handle: JoinHandle<()>,
40    sink_handle: JoinHandle<()>,
41}
42
43impl Common {
44    pub fn new<T>(io: T) -> Self
45    where
46        T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
47    {
48        // Prepare Serial sink and stream modules
49        let serial: Framed<T, PingCodec> = PingCodec::new().framed(io);
50        let (serial_sink, serial_stream) = serial.split();
51
52        // Prepare Serial receiver broadcast and sender
53        let (broadcast_tx, broadcast_rx) = broadcast::channel::<ProtocolMessage>(100);
54        let stream_handle = tokio::spawn(Self::stream(serial_stream, broadcast_tx));
55        let (sender, sender_rx) = mpsc::channel::<ProtocolMessage>(100);
56        let sink_handle = tokio::spawn(Self::sink(serial_sink, sender_rx));
57
58        Common {
59            tx: sender,
60            rx: broadcast_rx,
61            task_handles: TaskHandles {
62                stream_handle,
63                sink_handle,
64            },
65        }
66    }
67
68    async fn sink<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
69        mut sink: SplitSink<Framed<T, PingCodec>, ProtocolMessage>,
70        mut sender_rx: Receiver<ProtocolMessage>,
71    ) {
72        while let Some(item) = sender_rx.recv().await {
73            if let Err(e) = sink.send(item).await {
74                error!("{e:?}");
75            }
76        }
77    }
78
79    async fn stream<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
80        mut serial_stream: SplitStream<Framed<T, PingCodec>>,
81        broadcast_tx: Sender<ProtocolMessage>,
82    ) {
83        'outside_loop: loop {
84            while let Some(msg) = serial_stream.next().await {
85                match msg {
86                    Ok(msg) => {
87                        if let Err(e) = broadcast_tx.send(msg) {
88                            error!("{e:?}");
89                            break 'outside_loop;
90                        };
91                    }
92                    Err(e) => {
93                        trace!("{e:?}");
94                    }
95                }
96            }
97        }
98    }
99
100    pub async fn send_message(&self, message: ProtocolMessage) -> Result<(), PingError> {
101        self.tx.send(message).await.map_err(|err| err.into())
102    }
103
104    fn subscribe(&self) -> tokio::sync::broadcast::Receiver<ProtocolMessage> {
105        self.rx.resubscribe()
106    }
107}
108
109impl Drop for Common {
110    fn drop(&mut self) {
111        self.task_handles.stream_handle.abort();
112        self.task_handles.sink_handle.abort();
113        info!("TaskHandles sink and stream dropped, tasks aborted");
114    }
115}
116
117pub trait PingDevice {
118    fn get_common(&self) -> &Common;
119
120    fn get_mut_common(&mut self) -> &mut Common;
121
122    fn subscribe(&self) -> tokio::sync::broadcast::Receiver<ProtocolMessage> {
123        self.get_common().subscribe()
124    }
125
126    async fn send_general_request(&self, requested_id: u16) -> Result<(), PingError> {
127        let request =
128            common::Messages::GeneralRequest(common::GeneralRequestStruct { requested_id });
129        let mut package = message::ProtocolMessage::new();
130        package.set_message(&request);
131
132        if let Err(e) = self.get_common().send_message(package).await {
133            return Err(e);
134        };
135
136        Ok(())
137    }
138
139    async fn wait_for_message<T: 'static>(
140        &self,
141        mut receiver: tokio::sync::broadcast::Receiver<ProtocolMessage>,
142    ) -> Result<T, PingError>
143    where
144        T: crate::message::MessageInfo + std::marker::Sync + Clone + std::marker::Send,
145    {
146        let future = async move {
147            loop {
148                match receiver.recv().await {
149                    Ok(answer) => {
150                        if T::id() != answer.message_id {
151                            continue;
152                        };
153                        let message = Messages::try_from(&answer)
154                            .map_err(|_e| PingError::TryFromError(answer))?;
155                        return Ok(message.inner::<T>().unwrap().clone());
156                    }
157                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
158                    Err(e) => return Err(e.into()),
159                };
160            }
161        };
162
163        match tokio::time::timeout(tokio::time::Duration::from_secs(15), future).await {
164            Ok(result) => result,
165            Err(_) => Err(PingError::TimeoutError),
166        }
167    }
168
169    async fn wait_for_ack(
170        &self,
171        mut receiver: tokio::sync::broadcast::Receiver<ProtocolMessage>,
172        message_id: u16,
173    ) -> Result<(), PingError> {
174        let future = async move {
175            loop {
176                match receiver.recv().await {
177                    Ok(answer) => {
178                        if common::AckStruct::id() != answer.message_id
179                            && common::NackStruct::id() != answer.message_id
180                        {
181                            continue;
182                        }
183                        match Messages::try_from(&answer) {
184                            Ok(Messages::Common(common::Messages::Ack(answer))) => {
185                                if answer.acked_id != message_id {
186                                    continue;
187                                };
188                                return Ok(());
189                            }
190                            Ok(Messages::Common(common::Messages::Nack(answer))) => {
191                                if answer.nacked_id != message_id {
192                                    continue;
193                                };
194                                return Err(PingError::NackError(answer.nack_message));
195                            }
196                            _ => return Err(PingError::TryFromError(answer)), // Almost unreachable, but raises error ProtocolMessage
197                        };
198                    }
199                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
200                    Err(e) => return Err(e.into()),
201                };
202            }
203        };
204
205        match tokio::time::timeout(tokio::time::Duration::from_secs(15), future).await {
206            Ok(result) => result,
207            Err(_) => Err(PingError::TimeoutError),
208        }
209    }
210
211    async fn request<T: 'static>(&self) -> Result<T, PingError>
212    where
213        T: crate::message::MessageInfo + std::marker::Sync + Clone + std::marker::Send,
214    {
215        let receiver = self.subscribe();
216
217        self.send_general_request(T::id()).await?;
218
219        self.wait_for_message(receiver).await
220    }
221
222    #[doc = "Device information"]
223    async fn device_information(&self) -> Result<common::DeviceInformationStruct, PingError> {
224        self.request().await
225    }
226    #[doc = "The protocol version"]
227    async fn protocol_version(&self) -> Result<common::ProtocolVersionStruct, PingError> {
228        self.request().await
229    }
230    #[doc = "Set the device ID."]
231    #[doc = "# Arguments"]
232    #[doc = "* `device_id` - Device ID (1-254). 0 is unknown and 255 is reserved for broadcast messages."]
233    async fn set_device_id(&self, device_id: u8) -> Result<(), PingError> {
234        let request = common::Messages::SetDeviceId(common::SetDeviceIdStruct { device_id });
235        let mut package = ProtocolMessage::new();
236        package.set_message(&request);
237        let receiver = self.subscribe();
238        self.get_common().send_message(package).await?;
239        self.wait_for_ack(receiver, common::SetDeviceIdStruct::id())
240            .await
241    }
242}