retty/bootstrap/bootstrap_udp/
mod.rs1use super::*;
2use crate::transport::Protocol;
3use async_transport::{AsyncUdpSocket, Capabilities, RecvMeta, Transmit, UdpSocket, BATCH_SIZE};
4use std::mem::MaybeUninit;
5
6pub(crate) mod bootstrap_udp_client;
7pub(crate) mod bootstrap_udp_server;
8
9struct BootstrapUdp<W> {
10 boostrap: Bootstrap<W>,
11
12 socket: Option<UdpSocket>,
13}
14
15impl<W: 'static> Default for BootstrapUdp<W> {
16 fn default() -> Self {
17 Self::new()
18 }
19}
20
21impl<W: 'static> BootstrapUdp<W> {
22 fn new() -> Self {
23 Self {
24 boostrap: Bootstrap::new(),
25
26 socket: None,
27 }
28 }
29
30 fn max_payload_size(&mut self, max_payload_size: usize) -> &mut Self {
31 self.boostrap.max_payload_size(max_payload_size);
32 self
33 }
34
35 fn pipeline(&mut self, pipeline_factory_fn: PipelineFactoryFn<TaggedBytesMut, W>) -> &mut Self {
36 self.boostrap.pipeline(pipeline_factory_fn);
37 self
38 }
39
40 async fn bind<A: AsyncToSocketAddrs>(&mut self, addr: A) -> Result<SocketAddr, Error> {
41 let socket = UdpSocket::bind(addr).await?;
42 let local_addr = socket.local_addr()?;
43 self.socket = Some(socket);
44 Ok(local_addr)
45 }
46
47 async fn connect(
48 &mut self,
49 _peer_addr: Option<SocketAddr>,
50 ) -> Result<Rc<dyn OutboundPipeline<TaggedBytesMut, W>>, Error> {
51 let socket = self.socket.take().unwrap();
52 let local_addr = socket.local_addr()?;
53
54 let pipeline_factory_fn = Rc::clone(self.boostrap.pipeline_factory_fn.as_ref().unwrap());
55 let pipeline = (pipeline_factory_fn)();
56 let pipeline_wr = Rc::clone(&pipeline);
57
58 let (close_tx, mut close_rx) = async_broadcast::broadcast(1);
59 {
60 let mut tx = self.boostrap.close_tx.borrow_mut();
61 *tx = Some(close_tx);
62 }
63
64 let worker = {
65 let workgroup = WaitGroup::new();
66 let worker = workgroup.worker();
67 {
68 let mut wg = self.boostrap.wg.borrow_mut();
69 *wg = Some(workgroup);
70 }
71 worker
72 };
73
74 let max_payload_size = self.boostrap.max_payload_size;
75
76 spawn_local(async move {
77 let _w = worker;
78
79 let capabilities = Capabilities::new();
80 let buf = vec![0u8; max_payload_size * capabilities.gro_segments() * BATCH_SIZE];
81 let buf_len = buf.len();
82 let mut recv_buf: Box<[u8]> = buf.into();
83 let mut metas = [RecvMeta::default(); BATCH_SIZE];
84 let mut iovs = MaybeUninit::<[std::io::IoSliceMut<'_>; BATCH_SIZE]>::uninit();
85 recv_buf
86 .chunks_mut(buf_len / BATCH_SIZE)
87 .enumerate()
88 .for_each(|(i, buf)| unsafe {
89 iovs.as_mut_ptr()
90 .cast::<std::io::IoSliceMut<'_>>()
91 .add(i)
92 .write(std::io::IoSliceMut::<'_>::new(buf));
93 });
94 let mut iovs = unsafe { iovs.assume_init() };
95
96 pipeline.transport_active();
97 loop {
98 while let Some(msg) = pipeline.poll_transmit() {
100 let transmit = Transmit {
101 destination: msg.transport.peer_addr,
102 ecn: msg.transport.ecn,
103 contents: msg.message.to_vec(),
104 segment_size: None,
105 src_ip: Some(msg.transport.local_addr.ip()),
106 };
107 match socket.send(&capabilities, &[transmit]).await {
108 Ok(_) => {
109 trace!("socket write {} bytes", msg.message.len());
110 }
111 Err(err) => {
112 warn!("socket write error {}", err);
113 break;
114 }
115 }
116 }
117
118 let mut eto = Instant::now() + Duration::from_secs(MAX_DURATION_IN_SECS);
119 pipeline.poll_timeout(&mut eto);
120
121 let delay_from_now = eto
122 .checked_duration_since(Instant::now())
123 .unwrap_or(Duration::from_secs(0));
124 if delay_from_now.is_zero() {
125 pipeline.handle_timeout(Instant::now());
126 continue;
127 }
128
129 let timeout = Timer::after(delay_from_now);
130
131 tokio::select! {
132 _ = close_rx.recv() => {
133 trace!("pipeline socket exit loop");
134 break;
135 }
136 _ = timeout => {
137 pipeline.handle_timeout(Instant::now());
138 }
139 res = socket.recv(&mut iovs, &mut metas) => {
140 match res {
141 Ok(n) => {
142 if n == 0 {
143 pipeline.handle_read_eof();
144 break;
145 }
146
147 for (meta, buf) in metas.iter().zip(iovs.iter()).take(n) {
148 let message: BytesMut = buf[0..meta.len].into();
149 if !message.is_empty() {
150 trace!("socket read {} bytes", message.len());
151 pipeline
152 .read(TaggedBytesMut {
153 now: Instant::now(),
154 transport: TransportContext {
155 local_addr,
156 peer_addr: meta.addr,
157 ecn: meta.ecn,
158 protocol: Protocol::UDP,
159 },
160 message,
161 });
162 }
163 }
164 }
165 Err(err) => {
166 warn!("socket read error {}", err);
167 break;
168 }
169 }
170 }
171 }
172 }
173 pipeline.transport_inactive();
174 })
175 .detach();
176
177 Ok(pipeline_wr)
178 }
179
180 async fn stop(&self) {
181 self.boostrap.stop().await
182 }
183
184 async fn wait_for_stop(&self) {
185 self.boostrap.wait_for_stop().await
186 }
187
188 async fn graceful_stop(&self) {
189 self.boostrap.graceful_stop().await
190 }
191}