bluerobotics_ping/
device.rs1use std::convert::TryFrom;
2
3use futures::{
4 stream::{SplitSink, SplitStream},
5 SinkExt, StreamExt,
6};
7use tokio::sync::{
8 broadcast::{self, Sender},
9 mpsc::{self, Receiver},
10};
11use tokio::{
12 io::{AsyncRead, AsyncWrite},
13 task::JoinHandle,
14};
15use tokio_util::codec::{Decoder, Framed};
16use tracing::{error, info, trace};
17
18use crate::{
19 codec::PingCodec,
20 common,
21 error::PingError,
22 message::{self, MessageInfo, ProtocolMessage},
23 Messages,
24};
25
26pub use crate::ping1d::Device as Ping1D;
28pub use crate::ping360::Device as Ping360;
29
30#[derive(Debug)]
31pub struct Common {
32 tx: mpsc::Sender<ProtocolMessage>,
33 rx: broadcast::Receiver<ProtocolMessage>,
34 task_handles: TaskHandles,
35}
36#[derive(Debug)]
37
38struct TaskHandles {
39 stream_handle: JoinHandle<()>,
40 sink_handle: JoinHandle<()>,
41}
42
43impl Common {
44 pub fn new<T>(io: T) -> Self
45 where
46 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
47 {
48 let serial: Framed<T, PingCodec> = PingCodec::new().framed(io);
50 let (serial_sink, serial_stream) = serial.split();
51
52 let (broadcast_tx, broadcast_rx) = broadcast::channel::<ProtocolMessage>(100);
54 let stream_handle = tokio::spawn(Self::stream(serial_stream, broadcast_tx));
55 let (sender, sender_rx) = mpsc::channel::<ProtocolMessage>(100);
56 let sink_handle = tokio::spawn(Self::sink(serial_sink, sender_rx));
57
58 Common {
59 tx: sender,
60 rx: broadcast_rx,
61 task_handles: TaskHandles {
62 stream_handle,
63 sink_handle,
64 },
65 }
66 }
67
68 async fn sink<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
69 mut sink: SplitSink<Framed<T, PingCodec>, ProtocolMessage>,
70 mut sender_rx: Receiver<ProtocolMessage>,
71 ) {
72 while let Some(item) = sender_rx.recv().await {
73 if let Err(e) = sink.send(item).await {
74 error!("{e:?}");
75 }
76 }
77 }
78
79 async fn stream<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
80 mut serial_stream: SplitStream<Framed<T, PingCodec>>,
81 broadcast_tx: Sender<ProtocolMessage>,
82 ) {
83 'outside_loop: loop {
84 while let Some(msg) = serial_stream.next().await {
85 match msg {
86 Ok(msg) => {
87 if let Err(e) = broadcast_tx.send(msg) {
88 error!("{e:?}");
89 break 'outside_loop;
90 };
91 }
92 Err(e) => {
93 trace!("{e:?}");
94 }
95 }
96 }
97 }
98 }
99
100 pub async fn send_message(&self, message: ProtocolMessage) -> Result<(), PingError> {
101 self.tx.send(message).await.map_err(|err| err.into())
102 }
103
104 fn subscribe(&self) -> tokio::sync::broadcast::Receiver<ProtocolMessage> {
105 self.rx.resubscribe()
106 }
107}
108
109impl Drop for Common {
110 fn drop(&mut self) {
111 self.task_handles.stream_handle.abort();
112 self.task_handles.sink_handle.abort();
113 info!("TaskHandles sink and stream dropped, tasks aborted");
114 }
115}
116
117pub trait PingDevice {
118 fn get_common(&self) -> &Common;
119
120 fn get_mut_common(&mut self) -> &mut Common;
121
122 fn subscribe(&self) -> tokio::sync::broadcast::Receiver<ProtocolMessage> {
123 self.get_common().subscribe()
124 }
125
126 async fn send_general_request(&self, requested_id: u16) -> Result<(), PingError> {
127 let request =
128 common::Messages::GeneralRequest(common::GeneralRequestStruct { requested_id });
129 let mut package = message::ProtocolMessage::new();
130 package.set_message(&request);
131
132 if let Err(e) = self.get_common().send_message(package).await {
133 return Err(e);
134 };
135
136 Ok(())
137 }
138
139 async fn wait_for_message<T: 'static>(
140 &self,
141 mut receiver: tokio::sync::broadcast::Receiver<ProtocolMessage>,
142 ) -> Result<T, PingError>
143 where
144 T: crate::message::MessageInfo + std::marker::Sync + Clone + std::marker::Send,
145 {
146 let future = async move {
147 loop {
148 match receiver.recv().await {
149 Ok(answer) => {
150 if T::id() != answer.message_id {
151 continue;
152 };
153 let message = Messages::try_from(&answer)
154 .map_err(|_e| PingError::TryFromError(answer))?;
155 match message.inner::<T>() {
156 Some(message) => return Ok(message.clone()),
157 None => {
158 error!(
159 "Received message is not of type `{}` ({}), receiving: {:?}",
160 T::name(),
161 T::id(),
162 message,
163 );
164 }
165 }
166 }
167 Err(broadcast::error::RecvError::Lagged(_)) => continue,
168 Err(e) => return Err(e.into()),
169 };
170 }
171 };
172
173 match tokio::time::timeout(tokio::time::Duration::from_secs(15), future).await {
174 Ok(result) => result,
175 Err(_) => Err(PingError::TimeoutError),
176 }
177 }
178
179 async fn wait_for_ack(
180 &self,
181 mut receiver: tokio::sync::broadcast::Receiver<ProtocolMessage>,
182 message_id: u16,
183 ) -> Result<(), PingError> {
184 let future = async move {
185 loop {
186 match receiver.recv().await {
187 Ok(answer) => {
188 if common::AckStruct::id() != answer.message_id
189 && common::NackStruct::id() != answer.message_id
190 {
191 continue;
192 }
193 match Messages::try_from(&answer) {
194 Ok(Messages::Common(common::Messages::Ack(answer))) => {
195 if answer.acked_id != message_id {
196 continue;
197 };
198 return Ok(());
199 }
200 Ok(Messages::Common(common::Messages::Nack(answer))) => {
201 if answer.nacked_id != message_id {
202 continue;
203 };
204 return Err(PingError::NackError(answer.nack_message));
205 }
206 _ => return Err(PingError::TryFromError(answer)), };
208 }
209 Err(broadcast::error::RecvError::Lagged(_)) => continue,
210 Err(e) => return Err(e.into()),
211 };
212 }
213 };
214
215 match tokio::time::timeout(tokio::time::Duration::from_secs(15), future).await {
216 Ok(result) => result,
217 Err(_) => Err(PingError::TimeoutError),
218 }
219 }
220
221 async fn request<T: 'static>(&self) -> Result<T, PingError>
222 where
223 T: crate::message::MessageInfo + std::marker::Sync + Clone + std::marker::Send,
224 {
225 let receiver = self.subscribe();
226
227 self.send_general_request(T::id()).await?;
228
229 self.wait_for_message(receiver).await
230 }
231
232 #[doc = "Device information"]
233 async fn device_information(&self) -> Result<common::DeviceInformationStruct, PingError> {
234 self.request().await
235 }
236 #[doc = "The protocol version"]
237 async fn protocol_version(&self) -> Result<common::ProtocolVersionStruct, PingError> {
238 self.request().await
239 }
240 #[doc = "Set the device ID."]
241 #[doc = "# Arguments"]
242 #[doc = "* `device_id` - Device ID (1-254). 0 is unknown and 255 is reserved for broadcast messages."]
243 async fn set_device_id(&self, device_id: u8) -> Result<(), PingError> {
244 let request = common::Messages::SetDeviceId(common::SetDeviceIdStruct { device_id });
245 let mut package = ProtocolMessage::new();
246 package.set_message(&request);
247 let receiver = self.subscribe();
248 self.get_common().send_message(package).await?;
249 self.wait_for_ack(receiver, common::SetDeviceIdStruct::id())
250 .await
251 }
252}