// retty/bootstrap/bootstrap_udp/mod.rs

1use super::*;
2use crate::transport::Protocol;
3use async_transport::{AsyncUdpSocket, Capabilities, RecvMeta, Transmit, UdpSocket, BATCH_SIZE};
4use std::mem::MaybeUninit;
5
6pub(crate) mod bootstrap_udp_client;
7pub(crate) mod bootstrap_udp_server;
8
9struct BootstrapUdp<W> {
10    boostrap: Bootstrap<W>,
11
12    socket: Option<UdpSocket>,
13}
14
15impl<W: 'static> Default for BootstrapUdp<W> {
16    fn default() -> Self {
17        Self::new()
18    }
19}
20
21impl<W: 'static> BootstrapUdp<W> {
22    fn new() -> Self {
23        Self {
24            boostrap: Bootstrap::new(),
25
26            socket: None,
27        }
28    }
29
30    fn max_payload_size(&mut self, max_payload_size: usize) -> &mut Self {
31        self.boostrap.max_payload_size(max_payload_size);
32        self
33    }
34
35    fn pipeline(&mut self, pipeline_factory_fn: PipelineFactoryFn<TaggedBytesMut, W>) -> &mut Self {
36        self.boostrap.pipeline(pipeline_factory_fn);
37        self
38    }
39
40    async fn bind<A: AsyncToSocketAddrs>(&mut self, addr: A) -> Result<SocketAddr, Error> {
41        let socket = UdpSocket::bind(addr).await?;
42        let local_addr = socket.local_addr()?;
43        self.socket = Some(socket);
44        Ok(local_addr)
45    }
46
47    async fn connect(
48        &mut self,
49        _peer_addr: Option<SocketAddr>,
50    ) -> Result<Rc<dyn OutboundPipeline<TaggedBytesMut, W>>, Error> {
51        let socket = self.socket.take().unwrap();
52        let local_addr = socket.local_addr()?;
53
54        let pipeline_factory_fn = Rc::clone(self.boostrap.pipeline_factory_fn.as_ref().unwrap());
55        let pipeline = (pipeline_factory_fn)();
56        let pipeline_wr = Rc::clone(&pipeline);
57
58        let (close_tx, mut close_rx) = async_broadcast::broadcast(1);
59        {
60            let mut tx = self.boostrap.close_tx.borrow_mut();
61            *tx = Some(close_tx);
62        }
63
64        let worker = {
65            let workgroup = WaitGroup::new();
66            let worker = workgroup.worker();
67            {
68                let mut wg = self.boostrap.wg.borrow_mut();
69                *wg = Some(workgroup);
70            }
71            worker
72        };
73
74        let max_payload_size = self.boostrap.max_payload_size;
75
76        spawn_local(async move {
77            let _w = worker;
78
79            let capabilities = Capabilities::new();
80            let buf = vec![0u8; max_payload_size * capabilities.gro_segments() * BATCH_SIZE];
81            let buf_len = buf.len();
82            let mut recv_buf: Box<[u8]> = buf.into();
83            let mut metas = [RecvMeta::default(); BATCH_SIZE];
84            let mut iovs = MaybeUninit::<[std::io::IoSliceMut<'_>; BATCH_SIZE]>::uninit();
85            recv_buf
86                .chunks_mut(buf_len / BATCH_SIZE)
87                .enumerate()
88                .for_each(|(i, buf)| unsafe {
89                    iovs.as_mut_ptr()
90                        .cast::<std::io::IoSliceMut<'_>>()
91                        .add(i)
92                        .write(std::io::IoSliceMut::<'_>::new(buf));
93                });
94            let mut iovs = unsafe { iovs.assume_init() };
95
96            pipeline.transport_active();
97            loop {
98                // prioritize socket.write than socket.read
99                while let Some(msg) = pipeline.poll_transmit() {
100                    let transmit = Transmit {
101                        destination: msg.transport.peer_addr,
102                        ecn: msg.transport.ecn,
103                        contents: msg.message.to_vec(),
104                        segment_size: None,
105                        src_ip: Some(msg.transport.local_addr.ip()),
106                    };
107                    match socket.send(&capabilities, &[transmit]).await {
108                        Ok(_) => {
109                            trace!("socket write {} bytes", msg.message.len());
110                        }
111                        Err(err) => {
112                            warn!("socket write error {}", err);
113                            break;
114                        }
115                    }
116                }
117
118                let mut eto = Instant::now() + Duration::from_secs(MAX_DURATION_IN_SECS);
119                pipeline.poll_timeout(&mut eto);
120
121                let delay_from_now = eto
122                    .checked_duration_since(Instant::now())
123                    .unwrap_or(Duration::from_secs(0));
124                if delay_from_now.is_zero() {
125                    pipeline.handle_timeout(Instant::now());
126                    continue;
127                }
128
129                let timeout = Timer::after(delay_from_now);
130
131                tokio::select! {
132                    _ = close_rx.recv() => {
133                        trace!("pipeline socket exit loop");
134                        break;
135                    }
136                    _ = timeout => {
137                        pipeline.handle_timeout(Instant::now());
138                    }
139                    res = socket.recv(&mut iovs, &mut metas) => {
140                        match res {
141                            Ok(n) => {
142                                if n == 0 {
143                                    pipeline.handle_read_eof();
144                                    break;
145                                }
146
147                                for (meta, buf) in metas.iter().zip(iovs.iter()).take(n) {
148                                    let message: BytesMut = buf[0..meta.len].into();
149                                    if !message.is_empty() {
150                                        trace!("socket read {} bytes", message.len());
151                                        pipeline
152                                            .read(TaggedBytesMut {
153                                                now: Instant::now(),
154                                                transport: TransportContext {
155                                                    local_addr,
156                                                    peer_addr: meta.addr,
157                                                    ecn: meta.ecn,
158                                                    protocol: Protocol::UDP,
159                                                },
160                                                message,
161                                            });
162                                    }
163                                }
164                            }
165                            Err(err) => {
166                                warn!("socket read error {}", err);
167                                break;
168                            }
169                        }
170                    }
171                }
172            }
173            pipeline.transport_inactive();
174        })
175        .detach();
176
177        Ok(pipeline_wr)
178    }
179
180    async fn stop(&self) {
181        self.boostrap.stop().await
182    }
183
184    async fn wait_for_stop(&self) {
185        self.boostrap.wait_for_stop().await
186    }
187
188    async fn graceful_stop(&self) {
189        self.boostrap.graceful_stop().await
190    }
191}