// kratanet/vbridge.rs

use anyhow::{anyhow, Result};
use bytes::BytesMut;
use etherparse::{EtherType, Ethernet2Header, IpNumber, Ipv4Header, Ipv6Header, TcpHeader};
use log::{debug, trace, warn};
use smoltcp::wire::EthernetAddress;
use std::{
    collections::{hash_map::Entry, HashMap},
    sync::Arc,
};
use tokio::sync::broadcast::{
    channel as broadcast_channel, Receiver as BroadcastReceiver, Sender as BroadcastSender,
};
use tokio::{
    runtime::Handle,
    select,
    sync::{
        mpsc::{channel, error::TrySendError, Receiver, Sender},
        Mutex,
    },
    task::JoinHandle,
};

/// Capacity of the shared queue that carries frames from members into the bridge task.
const TO_BRIDGE_QUEUE_LEN: usize = 3_000;
/// Capacity of each member's private queue of unicast frames coming out of the bridge.
const FROM_BRIDGE_QUEUE_LEN: usize = 3_000;
/// Capacity of the broadcast channel used for multicast-addressed frames.
const BROADCAST_QUEUE_LEN: usize = 3_000;
/// Capacity of the queue of member-leave notifications sent when a join handle is dropped.
const MEMBER_LEAVE_QUEUE_LEN: usize = 30;

27#[derive(Debug)]
28struct BridgeMember {
29    pub from_bridge_sender: Sender<BytesMut>,
30}
31
32pub struct BridgeJoinHandle {
33    mac: EthernetAddress,
34    pub to_bridge_sender: Sender<BytesMut>,
35    pub from_bridge_receiver: Receiver<BytesMut>,
36    pub from_broadcast_receiver: BroadcastReceiver<BytesMut>,
37    member_leave_sender: Sender<EthernetAddress>,
38}
39
40impl Drop for BridgeJoinHandle {
41    fn drop(&mut self) {
42        if let Err(error) = self.member_leave_sender.try_send(self.mac) {
43            warn!(
44                "virtual bridge member {} failed to leave: {}",
45                self.mac, error
46            );
47        }
48    }
49}
50
51type VirtualBridgeMemberMap = Arc<Mutex<HashMap<EthernetAddress, BridgeMember>>>;
52
53#[derive(Clone)]
54pub struct VirtualBridge {
55    to_bridge_sender: Sender<BytesMut>,
56    from_broadcast_sender: BroadcastSender<BytesMut>,
57    member_leave_sender: Sender<EthernetAddress>,
58    members: VirtualBridgeMemberMap,
59    _task: Arc<JoinHandle<()>>,
60}
61
62enum VirtualBridgeSelect {
63    BroadcastSent,
64    PacketReceived(Option<BytesMut>),
65    MemberLeave(Option<EthernetAddress>),
66}
67
68impl VirtualBridge {
69    pub fn new() -> Result<VirtualBridge> {
70        let (to_bridge_sender, to_bridge_receiver) = channel::<BytesMut>(TO_BRIDGE_QUEUE_LEN);
71        let (member_leave_sender, member_leave_reciever) =
72            channel::<EthernetAddress>(MEMBER_LEAVE_QUEUE_LEN);
73        let (from_broadcast_sender, from_broadcast_receiver) =
74            broadcast_channel(BROADCAST_QUEUE_LEN);
75
76        let members = Arc::new(Mutex::new(HashMap::new()));
77        let handle = {
78            let members = members.clone();
79            let broadcast_rx_sender = from_broadcast_sender.clone();
80            tokio::task::spawn(async move {
81                if let Err(error) = VirtualBridge::process(
82                    members,
83                    member_leave_reciever,
84                    to_bridge_receiver,
85                    broadcast_rx_sender,
86                    from_broadcast_receiver,
87                )
88                .await
89                {
90                    warn!("virtual bridge processing task failed: {}", error);
91                }
92            })
93        };
94
95        Ok(VirtualBridge {
96            to_bridge_sender,
97            from_broadcast_sender,
98            member_leave_sender,
99            members,
100            _task: Arc::new(handle),
101        })
102    }
103
104    pub async fn join(&self, mac: EthernetAddress) -> Result<BridgeJoinHandle> {
105        let (from_bridge_sender, from_bridge_receiver) = channel::<BytesMut>(FROM_BRIDGE_QUEUE_LEN);
106        let member = BridgeMember { from_bridge_sender };
107
108        match self.members.lock().await.entry(mac) {
109            Entry::Occupied(_) => {
110                return Err(anyhow!("virtual bridge member {} already exists", mac));
111            }
112            Entry::Vacant(entry) => {
113                entry.insert(member);
114            }
115        };
116        debug!("virtual bridge member {} has joined", mac);
117        Ok(BridgeJoinHandle {
118            mac,
119            member_leave_sender: self.member_leave_sender.clone(),
120            from_bridge_receiver,
121            from_broadcast_receiver: self.from_broadcast_sender.subscribe(),
122            to_bridge_sender: self.to_bridge_sender.clone(),
123        })
124    }
125
126    async fn process(
127        members: VirtualBridgeMemberMap,
128        mut member_leave_reciever: Receiver<EthernetAddress>,
129        mut to_bridge_receiver: Receiver<BytesMut>,
130        broadcast_rx_sender: BroadcastSender<BytesMut>,
131        mut from_broadcast_receiver: BroadcastReceiver<BytesMut>,
132    ) -> Result<()> {
133        loop {
134            let selection = select! {
135                biased;
136                x = to_bridge_receiver.recv() => VirtualBridgeSelect::PacketReceived(x),
137                _ = from_broadcast_receiver.recv() => VirtualBridgeSelect::BroadcastSent,
138                x = member_leave_reciever.recv() => VirtualBridgeSelect::MemberLeave(x),
139            };
140
141            match selection {
142                VirtualBridgeSelect::PacketReceived(Some(mut packet)) => {
143                    let (header, payload) = match Ethernet2Header::from_slice(&packet) {
144                        Ok(data) => data,
145                        Err(error) => {
146                            debug!("virtual bridge failed to parse ethernet header: {}", error);
147                            continue;
148                        }
149                    };
150
151                    // recalculate TCP checksums when routing packets.
152                    // the xen network backend / frontend drivers for linux
153                    // use checksum offloading but since we bypass some layers
154                    // of the kernel we have to do it ourselves.
155                    if header.ether_type == EtherType::IPV4 {
156                        let (ipv4, payload) = Ipv4Header::from_slice(payload)?;
157                        if ipv4.protocol == IpNumber::TCP {
158                            let (mut tcp, payload) = TcpHeader::from_slice(payload)?;
159                            tcp.checksum = tcp.calc_checksum_ipv4(&ipv4, payload)?;
160                            let tcp_header_offset = Ethernet2Header::LEN + ipv4.header_len();
161                            let mut header = &mut packet[tcp_header_offset..];
162                            tcp.write(&mut header)?;
163                        }
164                    } else if header.ether_type == EtherType::IPV6 {
165                        let (ipv6, payload) = Ipv6Header::from_slice(payload)?;
166                        if ipv6.next_header == IpNumber::TCP {
167                            let (mut tcp, payload) = TcpHeader::from_slice(payload)?;
168                            tcp.checksum = tcp.calc_checksum_ipv6(&ipv6, payload)?;
169                            let tcp_header_offset = Ethernet2Header::LEN + ipv6.header_len();
170                            let mut header = &mut packet[tcp_header_offset..];
171                            tcp.write(&mut header)?;
172                        }
173                    }
174
175                    let destination = EthernetAddress(header.destination);
176                    if destination.is_multicast() {
177                        broadcast_rx_sender.send(packet)?;
178                        continue;
179                    }
180                    match members.lock().await.get(&destination) {
181                        Some(member) => {
182                            member.from_bridge_sender.try_send(packet)?;
183                            trace!(
184                                "sending bridged packet from {} to {}",
185                                EthernetAddress(header.source),
186                                EthernetAddress(header.destination)
187                            );
188                        }
189                        None => {
190                            trace!("no bridge member with address: {}", destination);
191                        }
192                    }
193                }
194
195                VirtualBridgeSelect::MemberLeave(Some(mac)) => {
196                    if members.lock().await.remove(&mac).is_some() {
197                        debug!("virtual bridge member {} has left", mac);
198                    }
199                }
200
201                VirtualBridgeSelect::PacketReceived(None) => break,
202                VirtualBridgeSelect::MemberLeave(None) => {}
203                VirtualBridgeSelect::BroadcastSent => {}
204            }
205        }
206        Ok(())
207    }
208}