1use std::{
2 collections::VecDeque,
3 future::{ready, Future, IntoFuture, Ready},
4 net::SocketAddr,
5 pin::Pin,
6 sync::{
7 atomic::{AtomicBool, AtomicU64, Ordering},
8 Arc, Mutex,
9 },
10};
11
12#[cfg(feature = "tls")]
13use crate::tls::TlsInfo;
14use crate::{
15 channel::Channel,
16 context::{
17 info::{ConnInfo, DatagramInfo},
18 ConnectionStats,
19 },
20 Result,
21};
22
23pub struct InboundContext {
27 info: DatagramInfo,
28}
29
30impl InboundContext {
31 pub(crate) fn new(info: ConnInfo) -> Self {
32 Self {
33 info: DatagramInfo::from_conn(info),
34 }
35 }
36
37 pub(crate) fn new_datagram(info: DatagramInfo) -> Self {
38 Self { info }
39 }
40
41 pub fn id(&self) -> u64 {
42 self.info.id()
43 }
44
45 pub fn peer_addr(&self) -> SocketAddr {
46 self.info.peer_addr()
47 }
48
49 pub fn local_addr(&self) -> SocketAddr {
50 self.info.local_addr()
51 }
52
53 #[cfg(feature = "tls")]
55 pub fn tls(&self) -> Option<&TlsInfo> {
56 self.info.tls()
57 }
58}
59
60pub struct BusinessContext {
62 info: DatagramInfo,
63}
64
65impl BusinessContext {
66 pub(crate) fn new(info: ConnInfo) -> Self {
67 Self {
68 info: DatagramInfo::from_conn(info),
69 }
70 }
71
72 pub(crate) fn new_datagram(info: DatagramInfo) -> Self {
73 Self { info }
74 }
75
76 pub fn id(&self) -> u64 {
77 self.info.id()
78 }
79
80 pub fn peer_addr(&self) -> SocketAddr {
81 self.info.peer_addr()
82 }
83
84 pub fn local_addr(&self) -> SocketAddr {
85 self.info.local_addr()
86 }
87
88 #[cfg(feature = "tls")]
90 pub fn tls(&self) -> Option<&TlsInfo> {
91 self.info.tls()
92 }
93}
94
95pub struct OutboundContext {
97 info: DatagramInfo,
98}
99
100impl OutboundContext {
101 pub(crate) fn new(info: ConnInfo) -> Self {
102 Self {
103 info: DatagramInfo::from_conn(info),
104 }
105 }
106
107 pub(crate) fn new_datagram(info: DatagramInfo) -> Self {
108 Self { info }
109 }
110
111 pub fn id(&self) -> u64 {
112 self.info.id()
113 }
114
115 pub fn peer_addr(&self) -> SocketAddr {
116 self.info.peer_addr()
117 }
118
119 pub fn local_addr(&self) -> SocketAddr {
120 self.info.local_addr()
121 }
122
123 #[cfg(feature = "tls")]
125 pub fn tls(&self) -> Option<&TlsInfo> {
126 self.info.tls()
127 }
128}
129
130pub struct Context<W> {
137 info: ConnInfo,
138 channel: Channel<W>,
139 outbox: StreamOutboxHandle<W>,
140 close_requested: bool,
141}
142
143impl<W: Send + 'static> Context<W> {
144 pub(crate) fn new(info: ConnInfo, channel: Channel<W>) -> Self {
145 Self {
146 info,
147 channel,
148 outbox: StreamOutboxHandle::new(),
149 close_requested: false,
150 }
151 }
152
153 pub fn id(&self) -> u64 {
154 self.info.id()
155 }
156
157 pub fn peer_addr(&self) -> SocketAddr {
159 self.info.peer_addr()
160 }
161
162 pub fn local_addr(&self) -> SocketAddr {
164 self.info.local_addr()
165 }
166
167 #[cfg(feature = "tls")]
169 pub fn tls(&self) -> Option<&TlsInfo> {
170 self.info.tls()
171 }
172
173 pub fn channel(&self) -> Channel<W> {
175 self.channel.clone()
176 }
177
178 pub fn stats(&self) -> Option<ConnectionStats> {
180 self.channel.stats()
181 }
182
183 #[inline]
188 pub fn write(&mut self, msg: W) -> WriteHandle {
189 self.outbox.push_write(msg);
190 WriteHandle { _private: () }
191 }
192
193 #[inline]
198 pub fn flush(&mut self) -> FlushHandle<'_, W> {
199 self.outbox.push_flush()
200 }
201
202 #[inline]
208 pub fn write_and_flush(&mut self, msg: W) -> FlushHandle<'_, W> {
209 self.outbox.push_write_and_flush(msg)
210 }
211
212 pub async fn close(&mut self) -> Result<()> {
214 self.close_requested = true;
215 Ok(())
216 }
217
218 pub(crate) fn outbox(&self) -> StreamOutboxHandle<W> {
219 self.outbox.clone()
220 }
221
222 pub(crate) fn close_requested(&self) -> bool {
223 self.close_requested
224 }
225
226 #[inline]
227 pub(crate) fn has_external_channel(&self) -> bool {
228 self.channel.strong_count() > 1
229 }
230}
231
/// Handle returned by `Context::write`.
///
/// The message has already been queued by the time this handle exists, so
/// awaiting it resolves immediately with `Ok(())`. Dropping it without awaiting
/// does not un-queue the write.
pub struct WriteHandle {
    _private: (),
}

impl IntoFuture for WriteHandle {
    type Output = Result<()>;
    type IntoFuture = Ready<Self::Output>;

    #[inline]
    fn into_future(self) -> Self::IntoFuture {
        let outcome: Result<()> = Ok(());
        ready(outcome)
    }
}
245
/// Awaitable handle returned by `Context::flush` and `Context::write_and_flush`.
///
/// A plain (untracked) flush command has already been queued by the time this
/// handle exists, so it stays queued even if the handle is dropped. Awaiting
/// the handle additionally queues a tracked flush and waits for its completion.
pub struct FlushHandle<'a, W> {
    outbox: &'a StreamOutboxHandle<W>,
}

impl<'a, W> IntoFuture for FlushHandle<'a, W> {
    type Output = Result<()>;
    type IntoFuture = Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;

    /// Enqueues a flush command carrying a fresh completion id, then returns a
    /// future that resolves with `Ok(())` once `completed_flush_id` reaches
    /// that id.
    ///
    /// NOTE(review): awaiting therefore queues a second flush on top of the one
    /// queued when the handle was created. The future never yields `Err`, and
    /// it stays pending until some `complete_flush` call reports an id >= this
    /// one. Avoiding a lost wake-up also relies on the memory orderings used by
    /// `StreamFlushState::mark_awaited` / `complete`.
    #[inline]
    fn into_future(self) -> Self::IntoFuture {
        // The tracked flush is enqueued here, eagerly, not on first poll.
        let id = self.outbox.push_flush_completion();
        let state = &self.outbox.core.flush_state;

        Box::pin(async move {
            // Advertise that a task waits for `id`; `complete` only notifies
            // when the awaited id is high enough.
            state.mark_awaited(id);

            loop {
                // Register interest *before* checking the counter so that a
                // `notify_waiters` racing with the check below is not missed.
                let notified = state.notify.notified();
                tokio::pin!(notified);
                notified.as_mut().enable();

                if state.completed_flush_id.load(Ordering::Acquire) >= id {
                    return Ok(());
                }

                // A wake-up may be for an earlier id; loop and re-check.
                notified.await;
            }
        })
    }
}
276
/// One queued outbound operation for a stream connection.
pub(crate) enum StreamOutboxCommand<W> {
    /// Write the message without requesting a flush.
    Write(W),
    /// Flush previously queued writes. `Some(id)` is pushed by awaited flush
    /// handles (ids from the flush state); plain flush requests push `None`.
    Flush { completion: Option<u64> },
    /// Write `msg` and request a flush.
    WriteAndFlush { msg: W },
}

/// FIFO of pending commands.
///
/// The first queued command is stored inline in `head`; only later commands go
/// into `tail`. Invariant: `tail` is non-empty only while `head` is `Some`.
struct StreamOutboxState<W> {
    head: Option<StreamOutboxCommand<W>>,
    tail: VecDeque<StreamOutboxCommand<W>>,
}

impl<W> StreamOutboxState<W> {
    /// Creates an empty queue.
    fn new() -> Self {
        Self {
            tail: VecDeque::new(),
            head: None,
        }
    }

    /// Appends `command`, storing it inline when the queue is empty.
    #[inline]
    fn push(&mut self, command: StreamOutboxCommand<W>) {
        if self.head.is_some() {
            self.tail.push_back(command);
        } else {
            self.head = Some(command);
        }
    }

    /// Moves every queued command into a batch, leaving the queue empty.
    #[inline]
    fn take_batch(&mut self) -> StreamOutboxBatch<W> {
        let head = self.head.take();
        let tail = std::mem::take(&mut self.tail);
        StreamOutboxBatch { head, tail }
    }
}

/// Owned snapshot of drained commands, iterated in the order they were pushed.
pub(crate) struct StreamOutboxBatch<W> {
    head: Option<StreamOutboxCommand<W>>,
    tail: VecDeque<StreamOutboxCommand<W>>,
}

impl<W> Iterator for StreamOutboxBatch<W> {
    type Item = StreamOutboxCommand<W>;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if let Some(first) = self.head.take() {
            return Some(first);
        }
        self.tail.pop_front()
    }
}
327
328struct StreamFlushState {
329 next_flush_id: AtomicU64,
330 completed_flush_id: AtomicU64,
331 awaited_flush_id: AtomicU64,
332 notify: tokio::sync::Notify,
333}
334
335impl StreamFlushState {
336 fn new() -> Self {
337 Self {
338 next_flush_id: AtomicU64::new(0),
339 completed_flush_id: AtomicU64::new(0),
340 awaited_flush_id: AtomicU64::new(0),
341 notify: tokio::sync::Notify::new(),
342 }
343 }
344
345 #[inline]
346 fn next_id(&self) -> u64 {
347 self.next_flush_id.fetch_add(1, Ordering::Relaxed) + 1
348 }
349
350 #[inline]
351 fn mark_awaited(&self, id: u64) {
352 self.awaited_flush_id.fetch_max(id, Ordering::Release);
353 }
354
355 #[inline]
356 fn complete(&self, id: u64) {
357 self.completed_flush_id.store(id, Ordering::Release);
358 if self.awaited_flush_id.load(Ordering::Acquire) >= id {
359 self.notify.notify_waiters();
360 }
361 }
362}
363
364struct StreamOutboxCore<W> {
365 commands: Mutex<StreamOutboxState<W>>,
366 flush_requested: AtomicBool,
367 flush_state: StreamFlushState,
368}
369
370pub(crate) struct StreamOutboxHandle<W> {
371 core: Arc<StreamOutboxCore<W>>,
372}
373
374impl<W> Clone for StreamOutboxHandle<W> {
375 fn clone(&self) -> Self {
376 Self {
377 core: self.core.clone(),
378 }
379 }
380}
381
382impl<W> StreamOutboxHandle<W> {
383 fn new() -> Self {
384 Self {
385 core: Arc::new(StreamOutboxCore {
386 commands: Mutex::new(StreamOutboxState::new()),
387 flush_requested: AtomicBool::new(false),
388 flush_state: StreamFlushState::new(),
389 }),
390 }
391 }
392
393 #[inline]
394 fn push_write(&self, msg: W) {
395 self.core
396 .commands
397 .lock()
398 .expect("stream outbox lock poisoned")
399 .push(StreamOutboxCommand::Write(msg));
400 }
401
402 #[inline]
403 fn push_flush(&self) -> FlushHandle<'_, W> {
404 self.core
405 .commands
406 .lock()
407 .expect("stream outbox lock poisoned")
408 .push(StreamOutboxCommand::Flush { completion: None });
409 self.core.flush_requested.store(true, Ordering::Release);
410 FlushHandle { outbox: self }
411 }
412
413 #[inline]
414 fn push_write_and_flush(&self, msg: W) -> FlushHandle<'_, W> {
415 self.core
416 .commands
417 .lock()
418 .expect("stream outbox lock poisoned")
419 .push(StreamOutboxCommand::WriteAndFlush { msg });
420 self.core.flush_requested.store(true, Ordering::Release);
421 FlushHandle { outbox: self }
422 }
423
424 #[inline]
425 fn push_flush_completion(&self) -> u64 {
426 let id = self.core.flush_state.next_id();
427 self.core
428 .commands
429 .lock()
430 .expect("stream outbox lock poisoned")
431 .push(StreamOutboxCommand::Flush {
432 completion: Some(id),
433 });
434 self.core.flush_requested.store(true, Ordering::Release);
435 id
436 }
437
438 #[inline]
439 pub(crate) fn has_flush_command(&self) -> bool {
440 self.core.flush_requested.load(Ordering::Acquire)
441 }
442
443 #[inline]
444 pub(crate) fn take_commands(&self) -> StreamOutboxBatch<W> {
445 self.core.flush_requested.store(false, Ordering::Release);
446 self.core
447 .commands
448 .lock()
449 .expect("stream outbox lock poisoned")
450 .take_batch()
451 }
452
453 #[inline]
454 pub(crate) fn complete_flush(&self, id: u64) {
455 self.core.flush_state.complete(id);
456 }
457}