#![deny(missing_docs)]
#![deny(unsafe_code)]
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Instant;
use arc_swap::ArcSwapOption;
use blocking::Unblock;
use futures::{
channel::{
mpsc::{self, UnboundedSender},
oneshot,
},
future::Either,
prelude::*,
StreamExt,
};
use mavlink::{MavHeader, Message};
pub mod prelude;
mod types;
mod util;
pub use types::{AsyncMavlinkError, MavMessageType};
pub use util::{to_char_arr, to_string};
/// Asynchronous handle to a MAVLink connection.
///
/// Created by [`AsyncMavConn::new`], which also returns the event-loop future
/// that performs the actual I/O. The handle forwards send and subscribe
/// requests to that event loop over a channel. The event-loop future must
/// therefore be polled (e.g. spawned on an executor) for sends to complete and
/// for subscriptions to receive messages.
pub struct AsyncMavConn<M: mavlink::Message> {
    /// Queue of work items consumed by the event loop created in `new`.
    task_dispatcher: UnboundedSender<Task<M>>,
    /// Time at which the most recent HEARTBEAT was received, or `None` if no
    /// heartbeat has been seen yet. Updated by the event loop.
    last_heartbeat: Arc<ArcSwapOption<Instant>>,
}
/// Work item sent from an [`AsyncMavConn`] handle to its event loop.
enum Task<M: mavlink::Message> {
    /// Register `backchannel` to receive every incoming message whose type
    /// matches `message_type`.
    Subscribe {
        /// Kind of message the subscriber wants to receive.
        message_type: MavMessageType<M>,
        /// Sending half of the subscriber's stream.
        backchannel: UnboundedSender<M>,
    },
    /// Write `message` with `header` to the connection and report the outcome
    /// through `backchannel`.
    Emit {
        /// Header to attach to the outgoing message.
        header: MavHeader,
        /// Payload to write.
        message: M,
        /// Receives the result of the write.
        backchannel: oneshot::Sender<Result<(), AsyncMavlinkError>>,
    },
}
impl<M: 'static + mavlink::Message + Clone + Send + Sync> AsyncMavConn<M> {
    /// Opens a MAVLink connection and wraps it in an asynchronous handle.
    ///
    /// `address` is passed unchanged to [`mavlink::connect`], and
    /// `mavlink_version` is applied to the resulting connection.
    ///
    /// On success this returns the shared handle together with the future
    /// that runs the connection's event loop. The future is lazy: it must be
    /// polled (typically by spawning it on an executor) for sends to complete
    /// and for subscriptions to receive messages. It does not complete in
    /// normal operation, because incoming messages come from an endless
    /// `recv` loop.
    ///
    /// # Errors
    ///
    /// Returns an error if [`mavlink::connect`] fails.
    pub fn new(
        address: &str,
        mavlink_version: mavlink::MavlinkVersion,
    ) -> Result<(Arc<Self>, impl Future<Output = impl Send> + Send), AsyncMavlinkError> {
        let mut conn = mavlink::connect::<M>(address)?;
        conn.set_protocol_version(mavlink_version);
        let (task_dispatcher, incoming_tasks) = mpsc::unbounded();
        let last_heartbeat = Arc::new(ArcSwapOption::from(None));
        let f = {
            let last_heartbeat = last_heartbeat.clone();
            async move {
                let mut subscriptions: HashMap<MavMessageType<M>, Vec<UnboundedSender<M>>> =
                    HashMap::new();
                let conn = Arc::new(conn);
                // `recv` is a blocking call, so this endless iterator is driven
                // through `Unblock` rather than directly on the executor.
                let messages_iter = std::iter::repeat_with({
                    let conn = conn.clone();
                    move || conn.recv()
                });
                // Incoming messages (Right) and handle requests (Left) are merged
                // into one stream and processed one at a time.
                let messages = Unblock::new(messages_iter).map(Either::Right);
                let operations = incoming_tasks.map(Either::Left);
                let mut combined = stream::select(operations, messages);
                let heartbeat_id =
                    mavlink::common::MavMessage::HEARTBEAT(Default::default()).message_id();
                // `messages` never ends, so in practice this never yields `None`.
                while let Some(event) = combined.next().await {
                    match event {
                        Either::Left(Task::Subscribe {
                            message_type,
                            backchannel,
                        }) => {
                            subscriptions
                                .entry(message_type)
                                .or_insert_with(|| Vec::with_capacity(1))
                                .push(backchannel);
                        }
                        Either::Left(Task::Emit {
                            header,
                            message,
                            backchannel,
                        }) => {
                            let result = conn.send(&header, &message).map_err(|e| e.into());
                            // The caller may have stopped waiting for the result;
                            // that is not an error for the event loop.
                            let _ = backchannel.send(result);
                        }
                        Either::Right(Ok((_header, msg))) => {
                            // Only the message id is compared, so any dialect message
                            // sharing the common HEARTBEAT id counts as a heartbeat.
                            if msg.message_id() == heartbeat_id {
                                last_heartbeat.store(Some(Arc::new(Instant::now())));
                            }
                            // Look up instead of using `entry`, so message types
                            // nobody subscribed to do not accumulate empty vectors.
                            if let Some(subscribers) =
                                subscriptions.get_mut(&MavMessageType::new(&msg))
                            {
                                // `unbounded_send` never blocks and fails only when the
                                // subscriber's stream is disconnected. Such subscribers
                                // are dropped here instead of panicking the event loop.
                                subscribers.retain(|subscriber| {
                                    subscriber.unbounded_send(msg.clone()).is_ok()
                                });
                            }
                        }
                        // Failed reads from the link are skipped; the loop keeps
                        // receiving.
                        Either::Right(Err(_)) => {}
                    }
                }
            }
        };
        Ok((
            Arc::new(Self {
                task_dispatcher,
                last_heartbeat,
            }),
            Box::pin(f),
        ))
    }

    /// Subscribes to every incoming message of type `message_type`.
    ///
    /// The returned stream yields each matching message that arrives after the
    /// event loop has registered the subscription. Dropping the stream ends
    /// the subscription; the event loop discards it the next time a matching
    /// message arrives.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be queued, e.g. because the
    /// event-loop future has been dropped.
    pub async fn subscribe(
        &self,
        message_type: MavMessageType<M>,
    ) -> Result<Pin<Box<dyn Stream<Item = M>>>, AsyncMavlinkError> {
        let (backchannel, rx) = mpsc::unbounded();
        self.task_dispatcher
            .clone()
            .send(Task::Subscribe {
                message_type,
                backchannel,
            })
            .await?;
        Ok(Box::pin(rx))
    }

    /// Waits for the next incoming message of type `message_type` and returns it.
    ///
    /// This registers a temporary subscription and drops it after the first
    /// matching message.
    ///
    /// # Panics
    ///
    /// Panics if the event-loop future has been dropped, either before the
    /// request is queued or before a matching message arrives.
    pub async fn request(&self, message_type: MavMessageType<M>) -> M {
        let (backchannel, mut rx) = mpsc::unbounded();
        self.task_dispatcher
            .clone()
            .send(Task::Subscribe {
                message_type,
                backchannel,
            })
            .await
            .expect("AsyncMavConn event loop is not running");
        rx.next()
            .await
            .expect("AsyncMavConn event loop stopped before a matching message arrived")
    }

    /// Sends `message` with the given `header` and waits until the event loop
    /// has written it to the connection.
    ///
    /// # Errors
    ///
    /// Returns an error if the event loop cannot be reached, if it is dropped
    /// before reporting back, or if writing to the connection fails.
    pub async fn send(&self, header: &MavHeader, message: &M) -> Result<(), AsyncMavlinkError> {
        let (backchannel, receiver) = oneshot::channel();
        self.task_dispatcher
            .clone()
            .send(Task::Emit {
                header: *header,
                message: message.clone(),
                backchannel,
            })
            .await
            .map_err(AsyncMavlinkError::from)?;
        receiver.await.map_err(AsyncMavlinkError::from)?
    }

    /// Sends `message` using [`MavHeader::default`] as the header.
    ///
    /// # Errors
    ///
    /// Same as [`AsyncMavConn::send`].
    pub async fn send_default(&self, message: &M) -> Result<(), AsyncMavlinkError> {
        self.send(&MavHeader::default(), message).await
    }

    /// Returns when the most recent HEARTBEAT was received, or `None` if no
    /// heartbeat has been received yet.
    ///
    /// This does not actually await anything; it stays `async` for API
    /// compatibility.
    pub async fn last_heartbeat(&self) -> Option<Instant> {
        self.last_heartbeat.load_full().map(|arc| *arc)
    }
}
// Integration tests that exchange real UDP traffic over loopback. Each test
// uses its own 127.0.0.x address on port 14551, presumably so tests running in
// parallel do not share sockets.
#[cfg(test)]
mod test {
    use super::*;
    use mavlink::common::*;
    use std::time::Duration;

    // Creates an `AsyncMavConn` for address `arg` using MAVLink v1 and spawns
    // its event-loop future as a detached smol task so the connection is live.
    // Panics if `AsyncMavConn::new` fails.
    async fn new_conn<M: 'static + mavlink::Message + Clone + Send + Sync>(
        arg: &str,
    ) -> Arc<AsyncMavConn<M>> {
        let (conn, future) = AsyncMavConn::new(arg, mavlink::MavlinkVersion::V1).unwrap();
        smol::spawn(async move { future.await }).detach();
        conn
    }

    // Checks that a HEARTBEAT subscription delivers messages. A raw `udpout`
    // connection sends a heartbeat every 10 ms to the `udpin` connection under
    // test. The test finishes once more than five heartbeats have arrived on
    // the subscription stream.
    #[test]
    fn subscribe() -> Result<(), AsyncMavlinkError> {
        smol::block_on(async {
            let conn = new_conn("udpin:127.0.0.7:14551").await;
            // Background heartbeat sender; it loops forever and is detached.
            smol::spawn(async move {
                let mut conn = mavlink::connect("udpout:127.0.0.7:14551").unwrap();
                conn.set_protocol_version(mavlink::MavlinkVersion::V1);
                loop {
                    conn.send_default(&MavMessage::HEARTBEAT(Default::default()))
                        .unwrap();
                    smol::Timer::after(Duration::from_millis(10)).await;
                }
            })
            .detach();
            let message_type = MavMessageType::new(&MavMessage::HEARTBEAT(Default::default()));
            let mut stream = conn.subscribe(message_type).await?;
            let mut i = 0;
            // The loop also exits (without failing) if the stream ends or
            // yields something other than a HEARTBEAT.
            while let Some(MavMessage::HEARTBEAT(_data)) = (stream.next()).await {
                i += 1;
                if i > 5 {
                    break;
                }
            }
            Ok(())
        })
    }

    // Checks that `AsyncMavConn::send` puts a HEARTBEAT on the wire. The
    // connection under test is `udpout` to 127.0.0.8:14551, and a raw `udpin`
    // connection there receives. The first message the raw side reads must be
    // a HEARTBEAT.
    #[test]
    fn send() -> Result<(), AsyncMavlinkError> {
        smol::block_on(async {
            let conn = new_conn("udpout:127.0.0.8:14551").await;
            let mut raw_conn = mavlink::connect("udpin:127.0.0.8:14551").unwrap();
            raw_conn.set_protocol_version(mavlink::MavlinkVersion::V1);
            // `recv` is blocking, so it is wrapped with `blocking::unblock`.
            let received = blocking::unblock(move || raw_conn.recv());
            // Sender task: issues 100 identical sends with an explicit header.
            smol::spawn(async move {
                let header = mavlink::MavHeader::default();
                let message = MavMessage::HEARTBEAT(HEARTBEAT_DATA::default());
                for _ in 0usize..100 {
                    conn.send(&header, &message).await.unwrap();
                }
            })
            .detach();
            let received = received.await;
            assert!(received.is_ok());
            if let Ok((_, MavMessage::HEARTBEAT(_))) = received {
            } else {
                panic!("received wrong message");
            }
            Ok(())
        })
    }

    // Same as `send`, but exercises `AsyncMavConn::send_default` on
    // 127.0.0.9:14551.
    #[test]
    fn send_default() -> std::io::Result<()> {
        smol::block_on(async {
            let conn = new_conn("udpout:127.0.0.9:14551").await;
            let mut raw_conn = mavlink::connect("udpin:127.0.0.9:14551").unwrap();
            raw_conn.set_protocol_version(mavlink::MavlinkVersion::V1);
            // `recv` is blocking, so it is wrapped with `blocking::unblock`.
            let received = blocking::unblock(move || raw_conn.recv());
            // Sender task: issues 100 sends with the default header.
            smol::spawn(async move {
                let message = MavMessage::HEARTBEAT(HEARTBEAT_DATA::default());
                for _ in 0usize..100 {
                    conn.send_default(&message).await.unwrap();
                }
            })
            .detach();
            let received = received.await;
            assert!(received.is_ok());
            if let Ok((_, MavMessage::HEARTBEAT(_))) = received {
            } else {
                panic!("received wrong message");
            }
            Ok(())
        })
    }
}