use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
use crate::error::{BacnetError, BacnetResult};
use super::bvlc::{BvlcMessage, BACNET_IP_PORT};
/// Default UDP port for this transport; an alias of [`BACNET_IP_PORT`] from the BVLC module.
pub const DEFAULT_PORT: u16 = BACNET_IP_PORT;
/// Default size in bytes used for both `recv_buffer_size` and `send_buffer_size`
/// in [`NetworkConfig::default`].
// NOTE(review): 1500 presumably mirrors a typical Ethernet MTU — confirm.
pub const DEFAULT_BUFFER_SIZE: usize = 1500;
/// Default value of `NetworkConfig::max_apdu_size`.
// NOTE(review): 1476 looks like the BACnet/IP maximum APDU length — confirm against the standard.
pub const MAX_APDU_SIZE: u16 = 1476;
/// Configuration for the BACnet/IP UDP transport.
///
/// Start from [`NetworkConfig::default`] and adjust it with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct NetworkConfig {
/// Local address the UDP socket is bound to.
pub bind_addr: SocketAddr,
/// Destination address used by `NetworkHandle::broadcast`.
pub broadcast_addr: SocketAddr,
/// Size in bytes of the buffer the receive loop reads each datagram into.
pub recv_buffer_size: usize,
/// Send buffer size in bytes.
// NOTE(review): not read by any code visible in this file — confirm where it is consumed.
pub send_buffer_size: usize,
/// Maximum APDU size.
// NOTE(review): not read by any code visible in this file — confirm where it is consumed.
pub max_apdu_size: u16,
/// Capacity of the channel that carries received packets to the consumer.
/// Must be greater than zero.
pub channel_buffer_size: usize,
}
impl Default for NetworkConfig {
fn default() -> Self {
Self {
bind_addr: format!("0.0.0.0:{}", DEFAULT_PORT).parse().unwrap(),
broadcast_addr: format!("255.255.255.255:{}", DEFAULT_PORT).parse().unwrap(),
recv_buffer_size: DEFAULT_BUFFER_SIZE,
send_buffer_size: DEFAULT_BUFFER_SIZE,
max_apdu_size: MAX_APDU_SIZE,
channel_buffer_size: 10_000,
}
}
}
impl NetworkConfig {
    /// Returns the configuration with the local bind address replaced by `addr`.
    pub fn with_bind_addr(self, addr: SocketAddr) -> Self {
        Self {
            bind_addr: addr,
            ..self
        }
    }

    /// Returns the configuration with the broadcast destination replaced by `addr`.
    pub fn with_broadcast_addr(self, addr: SocketAddr) -> Self {
        Self {
            broadcast_addr: addr,
            ..self
        }
    }

    /// Returns the configuration with `port` applied to both the bind and the
    /// broadcast address; the IP parts of both addresses are left untouched.
    pub fn with_port(mut self, port: u16) -> Self {
        for addr in [&mut self.bind_addr, &mut self.broadcast_addr] {
            addr.set_port(port);
        }
        self
    }
}
/// A datagram received from the network, together with its decode result.
#[derive(Debug, Clone)]
pub struct IncomingPacket {
/// Raw bytes of the datagram exactly as passed to [`IncomingPacket::new`].
pub data: Vec<u8>,
/// Socket address the datagram was received from.
pub source: SocketAddr,
/// Moment the packet object was constructed (taken in [`IncomingPacket::new`]).
pub timestamp: Instant,
/// Decoded BVLC message, or `None` when `data` could not be decoded as BVLC.
pub bvlc: Option<BvlcMessage>,
}
impl IncomingPacket {
    /// Wraps a received datagram, attempting to decode it as a BVLC message.
    ///
    /// A decode failure is not an error: the packet is still created, with
    /// `bvlc` left as `None`. The timestamp is taken after the decode attempt.
    pub fn new(data: Vec<u8>, source: SocketAddr) -> Self {
        let bvlc = match BvlcMessage::decode(&data) {
            Ok(message) => Some(message),
            Err(_) => None,
        };
        let timestamp = Instant::now();
        Self {
            bvlc,
            timestamp,
            source,
            data,
        }
    }

    /// Borrows the decoded BVLC message, if decoding succeeded.
    pub fn message(&self) -> Option<&BvlcMessage> {
        Option::as_ref(&self.bvlc)
    }

    /// Returns the source reported by the decoded message when it provides
    /// one, falling back to the UDP peer address otherwise.
    pub fn effective_source(&self) -> SocketAddr {
        if let Some(message) = &self.bvlc {
            if let Some(addr) = message.effective_source() {
                return addr;
            }
        }
        self.source
    }
}
/// Atomic traffic counters for the transport.
///
/// All counters start at zero (via `Default`) and are only ever incremented
/// by the `record_*` methods; use `snapshot` to read them as plain integers.
#[derive(Debug, Default)]
pub struct NetworkStats {
/// Number of datagrams received.
pub packets_received: AtomicU64,
/// Number of datagrams sent.
pub packets_sent: AtomicU64,
/// Total bytes received.
pub bytes_received: AtomicU64,
/// Total bytes sent.
pub bytes_sent: AtomicU64,
/// Number of failed receive operations.
pub receive_errors: AtomicU64,
/// Number of failed send operations.
pub send_errors: AtomicU64,
/// Number of received packets that were discarded instead of being delivered.
pub dropped_packets: AtomicU64,
}
impl NetworkStats {
pub fn record_received(&self, bytes: u64) {
self.packets_received.fetch_add(1, Ordering::Relaxed);
self.bytes_received.fetch_add(bytes, Ordering::Relaxed);
}
pub fn record_sent(&self, bytes: u64) {
self.packets_sent.fetch_add(1, Ordering::Relaxed);
self.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
}
pub fn record_receive_error(&self) {
self.receive_errors.fetch_add(1, Ordering::Relaxed);
}
pub fn record_send_error(&self) {
self.send_errors.fetch_add(1, Ordering::Relaxed);
}
pub fn record_dropped(&self) {
self.dropped_packets.fetch_add(1, Ordering::Relaxed);
}
pub fn snapshot(&self) -> NetworkStatsSnapshot {
NetworkStatsSnapshot {
packets_received: self.packets_received.load(Ordering::Relaxed),
packets_sent: self.packets_sent.load(Ordering::Relaxed),
bytes_received: self.bytes_received.load(Ordering::Relaxed),
bytes_sent: self.bytes_sent.load(Ordering::Relaxed),
receive_errors: self.receive_errors.load(Ordering::Relaxed),
send_errors: self.send_errors.load(Ordering::Relaxed),
dropped_packets: self.dropped_packets.load(Ordering::Relaxed),
}
}
}
/// Plain-integer copy of the transport counters at one point in time.
///
/// Derives `PartialEq`/`Eq` so two snapshots can be compared directly, for
/// example to detect whether any traffic happened between two readings.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NetworkStatsSnapshot {
    /// Number of datagrams received.
    pub packets_received: u64,
    /// Number of datagrams sent.
    pub packets_sent: u64,
    /// Total bytes received.
    pub bytes_received: u64,
    /// Total bytes sent.
    pub bytes_sent: u64,
    /// Number of failed receive operations.
    pub receive_errors: u64,
    /// Number of failed send operations.
    pub send_errors: u64,
    /// Number of received packets that were discarded instead of being delivered.
    pub dropped_packets: u64,
}
/// Cloneable handle for sending packets and reading statistics.
///
/// The socket and the counters are held behind `Arc`, so clones of a handle
/// refer to the same socket and the same statistics.
#[derive(Clone)]
pub struct NetworkHandle {
// UDP socket used for all sends.
socket: Arc<UdpSocket>,
// Configuration captured when the handle was created; supplies the broadcast address.
config: NetworkConfig,
// Counters updated on every send attempt.
stats: Arc<NetworkStats>,
}
impl NetworkHandle {
pub async fn send_to(&self, data: &[u8], dest: SocketAddr) -> BacnetResult<()> {
match self.socket.send_to(data, dest).await {
Ok(len) => {
self.stats.record_sent(len as u64);
debug!(dest = %dest, len, "Sent packet");
Ok(())
}
Err(e) => {
self.stats.record_send_error();
Err(BacnetError::Io(e))
}
}
}
pub async fn send_message(&self, msg: &BvlcMessage, dest: SocketAddr) -> BacnetResult<()> {
let data = msg.encode();
self.send_to(&data, dest).await
}
pub async fn broadcast(&self, data: &[u8]) -> BacnetResult<()> {
self.send_to(data, self.config.broadcast_addr).await
}
pub async fn broadcast_message(&self, msg: &BvlcMessage) -> BacnetResult<()> {
let data = msg.encode();
self.broadcast(&data).await
}
pub fn stats(&self) -> NetworkStatsSnapshot {
self.stats.snapshot()
}
pub fn local_addr(&self) -> BacnetResult<SocketAddr> {
self.socket.local_addr().map_err(BacnetError::Io)
}
}
/// Owner of the BACnet/IP UDP socket and of the receive side of the transport.
///
/// Created with `bind`; received packets are pushed into the channel whose
/// receiving end `bind` returns.
pub struct BACnetNetwork {
// Configuration the network was bound with.
config: NetworkConfig,
// UDP socket, shared with every `NetworkHandle`.
socket: Arc<UdpSocket>,
// Sending end of the incoming-packet channel.
recv_tx: mpsc::Sender<IncomingPacket>,
// Traffic counters, shared with every `NetworkHandle`.
stats: Arc<NetworkStats>,
// Set to `true` by `shutdown`; polled by the receive loop.
shutdown: Arc<AtomicBool>,
}
impl BACnetNetwork {
pub async fn bind(
config: NetworkConfig,
) -> BacnetResult<(Self, mpsc::Receiver<IncomingPacket>)> {
let socket = UdpSocket::bind(&config.bind_addr).await.map_err(|e| {
BacnetError::Server(format!("Failed to bind to {}: {}", config.bind_addr, e))
})?;
socket
.set_broadcast(true)
.map_err(|e| BacnetError::Server(format!("Failed to enable broadcast: {}", e)))?;
let (recv_tx, recv_rx) = mpsc::channel(config.channel_buffer_size);
info!(addr = %config.bind_addr, "BACnet/IP network bound");
Ok((
Self {
socket: Arc::new(socket),
recv_tx,
stats: Arc::new(NetworkStats::default()),
shutdown: Arc::new(AtomicBool::new(false)),
config,
},
recv_rx,
))
}
pub fn handle(&self) -> NetworkHandle {
NetworkHandle {
socket: self.socket.clone(),
config: self.config.clone(),
stats: self.stats.clone(),
}
}
pub fn stats(&self) -> &Arc<NetworkStats> {
&self.stats
}
pub fn shutdown(&self) {
self.shutdown.store(true, Ordering::SeqCst);
}
pub fn is_shutdown(&self) -> bool {
self.shutdown.load(Ordering::SeqCst)
}
pub async fn run_receive_loop(&self) -> BacnetResult<()> {
let mut buf = vec![0u8; self.config.recv_buffer_size];
info!("Starting BACnet/IP receive loop");
loop {
if self.shutdown.load(Ordering::SeqCst) {
info!("Receive loop shutdown requested");
break;
}
let recv_result = tokio::time::timeout(
std::time::Duration::from_millis(100),
self.socket.recv_from(&mut buf),
)
.await;
match recv_result {
Ok(Ok((len, source))) => {
self.stats.record_received(len as u64);
let packet = IncomingPacket::new(buf[..len].to_vec(), source);
if packet.bvlc.is_some() {
debug!(source = %source, len, "Received valid BVLC packet");
} else {
debug!(source = %source, len, "Received non-BVLC packet");
}
if self.recv_tx.try_send(packet).is_err() {
self.stats.record_dropped();
warn!("Receive channel full, dropping packet");
}
}
Ok(Err(e)) => {
self.stats.record_receive_error();
error!(error = %e, "UDP receive error");
}
Err(_) => {
continue;
}
}
}
info!("BACnet/IP receive loop stopped");
Ok(())
}
pub fn local_addr(&self) -> BacnetResult<SocketAddr> {
self.socket.local_addr().map_err(BacnetError::Io)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses the default port and APDU limit.
    #[test]
    fn test_network_config_default() {
        let NetworkConfig {
            bind_addr,
            max_apdu_size,
            ..
        } = NetworkConfig::default();
        assert_eq!(bind_addr.port(), DEFAULT_PORT);
        assert_eq!(max_apdu_size, MAX_APDU_SIZE);
    }

    /// `with_port` rewrites the port of both the bind and broadcast address.
    #[test]
    fn test_network_config_with_port() {
        const PORT: u16 = 12345;
        let NetworkConfig {
            bind_addr,
            broadcast_addr,
            ..
        } = NetworkConfig::default().with_port(PORT);
        assert_eq!(bind_addr.port(), PORT);
        assert_eq!(broadcast_addr.port(), PORT);
    }

    /// Recorded events show up in the snapshot with the expected totals.
    #[test]
    fn test_network_stats() {
        let stats = NetworkStats::default();
        for bytes in [100, 200] {
            stats.record_received(bytes);
        }
        stats.record_sent(50);
        stats.record_receive_error();

        let NetworkStatsSnapshot {
            packets_received,
            bytes_received,
            packets_sent,
            bytes_sent,
            receive_errors,
            ..
        } = stats.snapshot();
        assert_eq!(packets_received, 2);
        assert_eq!(bytes_received, 300);
        assert_eq!(packets_sent, 1);
        assert_eq!(bytes_sent, 50);
        assert_eq!(receive_errors, 1);
    }

    /// An encoded original-unicast message decodes, and its effective source
    /// is the UDP peer address.
    #[test]
    fn test_incoming_packet() {
        let source: SocketAddr = "192.168.1.100:47808".parse().unwrap();
        let encoded = BvlcMessage::original_unicast(vec![0x01, 0x04])
            .encode()
            .to_vec();
        let packet = IncomingPacket::new(encoded, source);
        assert!(packet.message().is_some());
        assert_eq!(packet.effective_source(), source);
    }
}