1use anyhow::{anyhow, Result};
2use bytes::BytesMut;
3use etherparse::{EtherType, Ethernet2Header, IpNumber, Ipv4Header, Ipv6Header, TcpHeader};
4use log::{debug, trace, warn};
5use smoltcp::wire::EthernetAddress;
6use std::{
7 collections::{hash_map::Entry, HashMap},
8 sync::Arc,
9};
10use tokio::sync::broadcast::{
11 channel as broadcast_channel, Receiver as BroadcastReceiver, Sender as BroadcastSender,
12};
13use tokio::{
14 select,
15 sync::{
16 mpsc::{channel, Receiver, Sender},
17 Mutex,
18 },
19 task::JoinHandle,
20};
21
/// Capacity of the shared queue carrying frames from members into the bridge task.
const TO_BRIDGE_QUEUE_LEN: usize = 3_000;
/// Capacity of each member's private queue of unicast frames delivered by the bridge.
const FROM_BRIDGE_QUEUE_LEN: usize = 3_000;
/// Capacity of the broadcast channel used to fan out multicast-destined frames.
const BROADCAST_QUEUE_LEN: usize = 3_000;
/// Capacity of the queue carrying member-leave notifications to the bridge task.
const MEMBER_LEAVE_QUEUE_LEN: usize = 30;
26
27#[derive(Debug)]
28struct BridgeMember {
29 pub from_bridge_sender: Sender<BytesMut>,
30}
31
32pub struct BridgeJoinHandle {
33 mac: EthernetAddress,
34 pub to_bridge_sender: Sender<BytesMut>,
35 pub from_bridge_receiver: Receiver<BytesMut>,
36 pub from_broadcast_receiver: BroadcastReceiver<BytesMut>,
37 member_leave_sender: Sender<EthernetAddress>,
38}
39
40impl Drop for BridgeJoinHandle {
41 fn drop(&mut self) {
42 if let Err(error) = self.member_leave_sender.try_send(self.mac) {
43 warn!(
44 "virtual bridge member {} failed to leave: {}",
45 self.mac, error
46 );
47 }
48 }
49}
50
51type VirtualBridgeMemberMap = Arc<Mutex<HashMap<EthernetAddress, BridgeMember>>>;
52
53#[derive(Clone)]
54pub struct VirtualBridge {
55 to_bridge_sender: Sender<BytesMut>,
56 from_broadcast_sender: BroadcastSender<BytesMut>,
57 member_leave_sender: Sender<EthernetAddress>,
58 members: VirtualBridgeMemberMap,
59 _task: Arc<JoinHandle<()>>,
60}
61
62enum VirtualBridgeSelect {
63 BroadcastSent,
64 PacketReceived(Option<BytesMut>),
65 MemberLeave(Option<EthernetAddress>),
66}
67
68impl VirtualBridge {
69 pub fn new() -> Result<VirtualBridge> {
70 let (to_bridge_sender, to_bridge_receiver) = channel::<BytesMut>(TO_BRIDGE_QUEUE_LEN);
71 let (member_leave_sender, member_leave_reciever) =
72 channel::<EthernetAddress>(MEMBER_LEAVE_QUEUE_LEN);
73 let (from_broadcast_sender, from_broadcast_receiver) =
74 broadcast_channel(BROADCAST_QUEUE_LEN);
75
76 let members = Arc::new(Mutex::new(HashMap::new()));
77 let handle = {
78 let members = members.clone();
79 let broadcast_rx_sender = from_broadcast_sender.clone();
80 tokio::task::spawn(async move {
81 if let Err(error) = VirtualBridge::process(
82 members,
83 member_leave_reciever,
84 to_bridge_receiver,
85 broadcast_rx_sender,
86 from_broadcast_receiver,
87 )
88 .await
89 {
90 warn!("virtual bridge processing task failed: {}", error);
91 }
92 })
93 };
94
95 Ok(VirtualBridge {
96 to_bridge_sender,
97 from_broadcast_sender,
98 member_leave_sender,
99 members,
100 _task: Arc::new(handle),
101 })
102 }
103
104 pub async fn join(&self, mac: EthernetAddress) -> Result<BridgeJoinHandle> {
105 let (from_bridge_sender, from_bridge_receiver) = channel::<BytesMut>(FROM_BRIDGE_QUEUE_LEN);
106 let member = BridgeMember { from_bridge_sender };
107
108 match self.members.lock().await.entry(mac) {
109 Entry::Occupied(_) => {
110 return Err(anyhow!("virtual bridge member {} already exists", mac));
111 }
112 Entry::Vacant(entry) => {
113 entry.insert(member);
114 }
115 };
116 debug!("virtual bridge member {} has joined", mac);
117 Ok(BridgeJoinHandle {
118 mac,
119 member_leave_sender: self.member_leave_sender.clone(),
120 from_bridge_receiver,
121 from_broadcast_receiver: self.from_broadcast_sender.subscribe(),
122 to_bridge_sender: self.to_bridge_sender.clone(),
123 })
124 }
125
126 async fn process(
127 members: VirtualBridgeMemberMap,
128 mut member_leave_reciever: Receiver<EthernetAddress>,
129 mut to_bridge_receiver: Receiver<BytesMut>,
130 broadcast_rx_sender: BroadcastSender<BytesMut>,
131 mut from_broadcast_receiver: BroadcastReceiver<BytesMut>,
132 ) -> Result<()> {
133 loop {
134 let selection = select! {
135 biased;
136 x = to_bridge_receiver.recv() => VirtualBridgeSelect::PacketReceived(x),
137 _ = from_broadcast_receiver.recv() => VirtualBridgeSelect::BroadcastSent,
138 x = member_leave_reciever.recv() => VirtualBridgeSelect::MemberLeave(x),
139 };
140
141 match selection {
142 VirtualBridgeSelect::PacketReceived(Some(mut packet)) => {
143 let (header, payload) = match Ethernet2Header::from_slice(&packet) {
144 Ok(data) => data,
145 Err(error) => {
146 debug!("virtual bridge failed to parse ethernet header: {}", error);
147 continue;
148 }
149 };
150
151 if header.ether_type == EtherType::IPV4 {
156 let (ipv4, payload) = Ipv4Header::from_slice(payload)?;
157 if ipv4.protocol == IpNumber::TCP {
158 let (mut tcp, payload) = TcpHeader::from_slice(payload)?;
159 tcp.checksum = tcp.calc_checksum_ipv4(&ipv4, payload)?;
160 let tcp_header_offset = Ethernet2Header::LEN + ipv4.header_len();
161 let mut header = &mut packet[tcp_header_offset..];
162 tcp.write(&mut header)?;
163 }
164 } else if header.ether_type == EtherType::IPV6 {
165 let (ipv6, payload) = Ipv6Header::from_slice(payload)?;
166 if ipv6.next_header == IpNumber::TCP {
167 let (mut tcp, payload) = TcpHeader::from_slice(payload)?;
168 tcp.checksum = tcp.calc_checksum_ipv6(&ipv6, payload)?;
169 let tcp_header_offset = Ethernet2Header::LEN + ipv6.header_len();
170 let mut header = &mut packet[tcp_header_offset..];
171 tcp.write(&mut header)?;
172 }
173 }
174
175 let destination = EthernetAddress(header.destination);
176 if destination.is_multicast() {
177 broadcast_rx_sender.send(packet)?;
178 continue;
179 }
180 match members.lock().await.get(&destination) {
181 Some(member) => {
182 member.from_bridge_sender.try_send(packet)?;
183 trace!(
184 "sending bridged packet from {} to {}",
185 EthernetAddress(header.source),
186 EthernetAddress(header.destination)
187 );
188 }
189 None => {
190 trace!("no bridge member with address: {}", destination);
191 }
192 }
193 }
194
195 VirtualBridgeSelect::MemberLeave(Some(mac)) => {
196 if members.lock().await.remove(&mac).is_some() {
197 debug!("virtual bridge member {} has left", mac);
198 }
199 }
200
201 VirtualBridgeSelect::PacketReceived(None) => break,
202 VirtualBridgeSelect::MemberLeave(None) => {}
203 VirtualBridgeSelect::BroadcastSent => {}
204 }
205 }
206 Ok(())
207 }
208}