use crate::dedup::ConcurrentDedup;
use crate::error::{Result, RouterError};
use crate::filter::EndpointFilters;
use crate::framing::{MavlinkFrame, StreamParser};
use crate::mavlink_utils::extract_target;
use crate::router::{EndpointId, RoutedMessage};
use crate::routing::RoutingTable;
use mavlink::Message;
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufWriter};
use tokio::sync::broadcast::{self, error::RecvError, error::TryRecvError};
use tokio_util::sync::CancellationToken;
use tracing::{trace, warn};
/// Produces an exponentially growing sequence of delays, bounded above by `max`.
///
/// Each call to [`ExponentialBackoff::next_backoff`] hands out the current delay and then
/// multiplies it by `multiplier` for the following call. [`ExponentialBackoff::reset`]
/// restarts the sequence from `min`.
#[derive(Debug)]
pub struct ExponentialBackoff {
    /// Delay that the next call to `next_backoff` will return.
    current: Duration,
    /// Initial delay; restored by `reset`.
    min: Duration,
    /// Upper bound applied to every delay computed after the first one.
    max: Duration,
    /// Factor the current delay is multiplied by after each call to `next_backoff`.
    multiplier: f64,
}

impl ExponentialBackoff {
    /// Creates a backoff that starts at `min`, grows by `multiplier` and is capped at `max`.
    pub fn new(min: Duration, max: Duration, multiplier: f64) -> Self {
        Self {
            current: min,
            min,
            max,
            multiplier,
        }
    }

    /// Returns the delay to wait now and advances the internal state for the next call.
    ///
    /// The following delay is `current * multiplier`, capped at `max`. If that product
    /// cannot be represented as a `Duration` (it is negative, NaN, or overflows), the
    /// following delay saturates at `max` instead of panicking.
    pub fn next_backoff(&mut self) -> Duration {
        let wait = self.current;
        let scaled_secs = self.current.as_secs_f64() * self.multiplier;
        // `Duration::from_secs_f64` panics on negative, non-finite or overflowing input;
        // the fallible variant lets a bad product degrade to the configured ceiling.
        self.current = Duration::try_from_secs_f64(scaled_secs)
            .map_or(self.max, |next| next.min(self.max));
        wait
    }

    /// Restarts the sequence so that the next delay handed out is `min` again.
    pub fn reset(&mut self) {
        self.current = self.min;
    }
}
/// Per-endpoint state used to move frames between an endpoint and the shared message bus.
///
/// `handle_incoming_frame` publishes accepted frames on the bus, and `check_outgoing`
/// decides whether a bus message should be written to this endpoint.
#[derive(Clone)]
pub struct EndpointCore {
/// Identifier of this endpoint. It is stamped as `source_id` on messages this endpoint
/// publishes, and messages carrying the same `source_id` are not sent back out to it.
pub id: EndpointId,
/// Broadcast sender on which accepted incoming frames are published as `RoutedMessage`s.
pub bus_tx: broadcast::Sender<RoutedMessage>,
/// Routing table behind a read/write lock. It is consulted for outgoing routing
/// decisions and updated from incoming frames when `update_routing` is set.
pub routing_table: Arc<RwLock<RoutingTable>>,
/// Duplicate detector fed the raw bytes of every incoming frame; a frame is dropped
/// when `check_and_insert` returns `true` for it.
pub dedup: ConcurrentDedup,
/// Filters checked against a frame's header and message id, both for incoming frames
/// and for outgoing bus messages.
pub filters: EndpointFilters,
/// When `true`, incoming frames are used to update the routing table for this endpoint.
pub update_routing: bool,
}
/// Returns the current time as microseconds since the Unix epoch, saturating at `u64::MAX`.
///
/// The system clock is sampled only on the first call, together with an `Instant`.
/// Every call then adds the time elapsed on that `Instant` to the recorded epoch offset,
/// so the returned values never go backwards within a process.
#[inline(always)]
fn timestamp_us_fast() -> u64 {
    static BASE: std::sync::OnceLock<(Instant, Duration)> = std::sync::OnceLock::new();

    let anchor = BASE.get_or_init(|| {
        // A system clock set before the epoch degrades to an offset of zero.
        let unix_at_start = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default();
        (Instant::now(), unix_at_start)
    });

    let micros: u128 = anchor.1.saturating_add(anchor.0.elapsed()).as_micros();
    if micros > u128::from(u64::MAX) {
        u64::MAX
    } else {
        micros as u64
    }
}
impl EndpointCore {
    /// Filters, de-duplicates and publishes a frame received on this endpoint.
    ///
    /// The frame is dropped without being published when:
    /// - its header carries system id 0,
    /// - the incoming filter rejects its header / message id, or
    /// - `dedup.check_and_insert` returns `true` for its raw bytes.
    ///
    /// When `update_routing` is set, the routing table is updated for the frame's
    /// system / component id before the frame is published on the bus. A failed bus
    /// send is logged and otherwise ignored.
    pub fn handle_incoming_frame(&self, frame: MavlinkFrame) {
        let message_id = frame.message.message_id();
        let system_id = frame.header.system_id;
        let component_id = frame.header.component_id;

        if system_id == 0 {
            trace!(
                "Dropping message with sysid 0 (msg_id {}) from endpoint {}",
                message_id,
                self.id
            );
            return;
        }
        if !self.filters.check_incoming(&frame.header, message_id) {
            return;
        }

        let serialized_bytes = frame.raw_bytes;
        if self.dedup.check_and_insert(&serialized_bytes) {
            return;
        }

        if self.update_routing {
            let now = Instant::now();
            // Ask under the shared lock first so the exclusive lock is only taken when
            // the table reports that an update is needed. The read guard is a temporary
            // and is released at the end of this statement.
            let needs_update = self.routing_table.read().needs_update_for_endpoint(
                self.id,
                system_id,
                component_id,
                now,
            );
            if needs_update {
                // A blocking `write()` covers both the uncontended and the contended
                // case, so no separate `try_write()` attempt is needed.
                self.routing_table
                    .write()
                    .update(self.id, system_id, component_id, now);
            }
        }

        let timestamp_us = timestamp_us_fast();
        let target = extract_target(&frame.message);
        let routed = RoutedMessage {
            source_id: self.id,
            header: frame.header,
            message_id,
            version: frame.version,
            timestamp_us,
            serialized_bytes,
            target,
        };
        if let Err(e) = self.bus_tx.send(routed) {
            warn!("Bus send error: {:?}", e);
        }
    }

    /// Decides whether a bus message should be written out on this endpoint.
    ///
    /// Returns `false` for messages that originated from this endpoint or that the
    /// outgoing filter rejects. Messages whose target system id is 0 are always
    /// accepted; for any other target the routing table's `should_send` decides.
    pub fn check_outgoing(&self, msg: &RoutedMessage) -> bool {
        if msg.source_id == self.id {
            return false;
        }
        if !self.filters.check_outgoing(&msg.header, msg.message_id) {
            return false;
        }

        let target = msg.target;
        if target.system_id == 0 {
            return true;
        }

        // The read guard is a temporary dropped at the end of this statement, so the
        // routing table is not kept locked while the decision is being logged.
        let should_send =
            self.routing_table
                .read()
                .should_send(self.id, target.system_id, target.component_id);

        if should_send {
            trace!(
                endpoint_id = %self.id,
                source_id = %msg.source_id,
                target_sys = target.system_id,
                target_comp = target.component_id,
                msg_id = msg.message_id,
                "Routing decision: FORWARD"
            );
        } else {
            trace!(
                endpoint_id = %self.id,
                source_id = %msg.source_id,
                target_sys = target.system_id,
                target_comp = target.component_id,
                msg_id = msg.message_id,
                "Routing decision: DROP (no route to target)"
            );
        }
        should_send
    }
}
/// Runs the read and write halves of a stream endpoint until one of them finishes or
/// `cancel_token` is cancelled.
///
/// - **Read half:** reads up to 4096 bytes at a time from `reader`, feeds them to a
///   `StreamParser`, and passes every parsed frame to `core.handle_incoming_frame`.
///   A zero-length read ends the loop with `Ok(())`.
/// - **Write half:** waits for a message on `bus_rx`, writes it if `core.check_outgoing`
///   accepts it, then drains up to 1024 further queued messages without waiting before
///   flushing the buffered writer once.
///
/// `name` is used in log lines and passed to `RouterError::network` when building errors.
///
/// # Errors
///
/// Returns the value built by `RouterError::network` when reading from `reader`, or
/// writing to / flushing `writer`, fails.
pub async fn run_stream_loop<R, W>(
    mut reader: R,
    writer: W,
    mut bus_rx: broadcast::Receiver<RoutedMessage>,
    core: EndpointCore,
    cancel_token: CancellationToken,
    name: String,
) -> Result<()>
where
    R: AsyncRead + Unpin + Send + 'static,
    W: AsyncWrite + Unpin + Send + 'static,
{
    let core_read = core.clone();
    let name_read = name.clone();
    let reader_cancel = cancel_token.clone();
    let writer_cancel = cancel_token.clone();

    let reader_loop = async move {
        let mut parser = StreamParser::new();
        let mut buf = [0u8; 4096];
        loop {
            if reader_cancel.is_cancelled() {
                return Ok(());
            }
            match reader.read(&mut buf).await {
                // A zero-length read ends the loop.
                Ok(0) => return Ok(()),
                Ok(n) => {
                    parser.push(&buf[..n]);
                    while let Some(frame) = parser.parse_next() {
                        core_read.handle_incoming_frame(frame);
                    }
                }
                Err(e) => {
                    return Err(RouterError::network(&name_read, e));
                }
            }
        }
    };

    let writer_loop = async move {
        let mut writer = BufWriter::with_capacity(65536, writer);
        loop {
            let msg = match bus_rx.recv().await {
                Ok(msg) => msg,
                Err(RecvError::Lagged(n)) => {
                    warn!("{} Sender lagged: missed {} messages", name, n);
                    continue;
                }
                Err(RecvError::Closed) => break,
            };
            if writer_cancel.is_cancelled() {
                break;
            }
            if !core.check_outgoing(&msg) {
                continue;
            }
            if let Err(e) = writer.write_all(&msg.serialized_bytes).await {
                return Err(RouterError::network(&name, e));
            }

            // Opportunistically drain messages that are already queued so that several
            // frames share a single flush.
            const BATCH_SIZE: usize = 1024;
            let mut bus_closed = false;
            for _ in 0..BATCH_SIZE {
                match bus_rx.try_recv() {
                    Ok(m) => {
                        if core.check_outgoing(&m) {
                            if let Err(e) = writer.write_all(&m.serialized_bytes).await {
                                return Err(RouterError::network(&name, e));
                            }
                        }
                    }
                    Err(TryRecvError::Empty) => break,
                    Err(TryRecvError::Lagged(n)) => {
                        warn!("{} Sender lagged: missed {} messages", name, n);
                    }
                    Err(TryRecvError::Closed) => {
                        // Do not return yet: frames written above are still sitting in
                        // the buffer and must be flushed first.
                        bus_closed = true;
                        break;
                    }
                }
            }

            if let Err(e) = writer.flush().await {
                return Err(RouterError::network(&name, e));
            }
            if bus_closed {
                return Ok(());
            }
        }
        Ok(())
    };

    // Whichever half finishes first decides the result; the other half is dropped.
    tokio::select! {
        result = reader_loop => result,
        result = writer_loop => result,
        _ = cancel_token.cancelled() => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::SystemTime;

    /// Builds a backoff whose bounds are given in whole seconds.
    fn backoff_secs(min: u64, max: u64, multiplier: f64) -> ExponentialBackoff {
        ExponentialBackoff::new(
            Duration::from_secs(min),
            Duration::from_secs(max),
            multiplier,
        )
    }

    /// Asserts that successive `next_backoff` calls yield `expected_secs`, in order.
    fn assert_delays(backoff: &mut ExponentialBackoff, expected_secs: &[u64]) {
        for &secs in expected_secs {
            assert_eq!(backoff.next_backoff(), Duration::from_secs(secs));
        }
    }

    #[test]
    fn test_exponential_backoff_initial() {
        let mut backoff = backoff_secs(1, 60, 2.0);
        assert_delays(&mut backoff, &[1]);
    }

    #[test]
    fn test_exponential_backoff_doubles() {
        let mut backoff = backoff_secs(1, 60, 2.0);
        assert_delays(&mut backoff, &[1, 2, 4, 8]);
    }

    #[test]
    fn test_exponential_backoff_caps_at_max() {
        let mut backoff = backoff_secs(10, 30, 2.0);
        assert_delays(&mut backoff, &[10, 20, 30, 30]);
    }

    #[test]
    fn test_exponential_backoff_reset() {
        let mut backoff = backoff_secs(1, 60, 2.0);
        // Advance the sequence a few steps so that `reset` has something to undo.
        for _ in 0..3 {
            backoff.next_backoff();
        }
        backoff.reset();
        assert_delays(&mut backoff, &[1]);
    }

    #[test]
    fn test_exponential_backoff_custom_multiplier() {
        let mut backoff = backoff_secs(1, 100, 3.0);
        assert_delays(&mut backoff, &[1, 3, 9, 27]);
    }

    #[test]
    fn test_timestamp_us_fast_monotonic_walltime() {
        // Allowed skew between the cached-base timestamp and a fresh system clock read.
        const TOLERANCE_US: u64 = 5_000_000;

        let first = timestamp_us_fast();
        let second = timestamp_us_fast();
        assert!(second >= first);

        let wall_us = SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_micros() as u64;
        assert!(second + TOLERANCE_US >= wall_us);
        assert!(second <= wall_us + TOLERANCE_US);
    }
}