Skip to main content

rs_netty/context/
stream.rs

1use std::{
2    collections::VecDeque,
3    future::{ready, Future, IntoFuture, Ready},
4    net::SocketAddr,
5    pin::Pin,
6    sync::{
7        atomic::{AtomicBool, AtomicU64, Ordering},
8        Arc, Mutex,
9    },
10};
11
12#[cfg(feature = "tls")]
13use crate::tls::TlsInfo;
14use crate::{
15    channel::Channel,
16    context::{
17        info::{ConnInfo, DatagramInfo},
18        ConnectionStats,
19    },
20    Result,
21};
22
23/// Context passed to TCP/UDP inbound transformation stages.
24///
25/// It exposes connection/datagram identity but does not allow writes.
26pub struct InboundContext {
27    info: DatagramInfo,
28}
29
30impl InboundContext {
31    pub(crate) fn new(info: ConnInfo) -> Self {
32        Self {
33            info: DatagramInfo::from_conn(info),
34        }
35    }
36
37    pub(crate) fn new_datagram(info: DatagramInfo) -> Self {
38        Self { info }
39    }
40
41    pub fn id(&self) -> u64 {
42        self.info.id()
43    }
44
45    pub fn peer_addr(&self) -> SocketAddr {
46        self.info.peer_addr()
47    }
48
49    pub fn local_addr(&self) -> SocketAddr {
50        self.info.local_addr()
51    }
52
53    /// TLS metadata for this TCP connection, when TLS is enabled and negotiated.
54    #[cfg(feature = "tls")]
55    pub fn tls(&self) -> Option<&TlsInfo> {
56        self.info.tls()
57    }
58}
59
60/// Context passed to business transformation stages.
61pub struct BusinessContext {
62    info: DatagramInfo,
63}
64
65impl BusinessContext {
66    pub(crate) fn new(info: ConnInfo) -> Self {
67        Self {
68            info: DatagramInfo::from_conn(info),
69        }
70    }
71
72    pub(crate) fn new_datagram(info: DatagramInfo) -> Self {
73        Self { info }
74    }
75
76    pub fn id(&self) -> u64 {
77        self.info.id()
78    }
79
80    pub fn peer_addr(&self) -> SocketAddr {
81        self.info.peer_addr()
82    }
83
84    pub fn local_addr(&self) -> SocketAddr {
85        self.info.local_addr()
86    }
87
88    /// TLS metadata for this TCP connection, when TLS is enabled and negotiated.
89    #[cfg(feature = "tls")]
90    pub fn tls(&self) -> Option<&TlsInfo> {
91        self.info.tls()
92    }
93}
94
95/// Context passed to outbound transformation stages.
96pub struct OutboundContext {
97    info: DatagramInfo,
98}
99
100impl OutboundContext {
101    pub(crate) fn new(info: ConnInfo) -> Self {
102        Self {
103            info: DatagramInfo::from_conn(info),
104        }
105    }
106
107    pub(crate) fn new_datagram(info: DatagramInfo) -> Self {
108        Self { info }
109    }
110
111    pub fn id(&self) -> u64 {
112        self.info.id()
113    }
114
115    pub fn peer_addr(&self) -> SocketAddr {
116        self.info.peer_addr()
117    }
118
119    pub fn local_addr(&self) -> SocketAddr {
120        self.info.local_addr()
121    }
122
123    /// TLS metadata for this TCP connection, when TLS is enabled and negotiated.
124    #[cfg(feature = "tls")]
125    pub fn tls(&self) -> Option<&TlsInfo> {
126        self.info.tls()
127    }
128}
129
130/// Context passed to a TCP [`crate::Handler`].
131///
132/// Writes through this context are staged in a handler-local outbox. They are
133/// encoded into the connection write buffer when the handler returns or when a
134/// flush is requested. Flush handles may be dropped for fire-and-forget
135/// behavior or awaited to wait until the local socket write completes.
136pub struct Context<W> {
137    info: ConnInfo,
138    channel: Channel<W>,
139    outbox: StreamOutboxHandle<W>,
140    close_requested: bool,
141}
142
143impl<W: Send + 'static> Context<W> {
144    pub(crate) fn new(info: ConnInfo, channel: Channel<W>) -> Self {
145        Self {
146            info,
147            channel,
148            outbox: StreamOutboxHandle::new(),
149            close_requested: false,
150        }
151    }
152
153    pub fn id(&self) -> u64 {
154        self.info.id()
155    }
156
157    /// Remote peer address for this connection.
158    pub fn peer_addr(&self) -> SocketAddr {
159        self.info.peer_addr()
160    }
161
162    /// Local socket address for this connection.
163    pub fn local_addr(&self) -> SocketAddr {
164        self.info.local_addr()
165    }
166
167    /// TLS metadata for this connection, when TLS is enabled and negotiated.
168    #[cfg(feature = "tls")]
169    pub fn tls(&self) -> Option<&TlsInfo> {
170        self.info.tls()
171    }
172
173    /// Returns a cloneable channel for writing from outside the current handler.
174    pub fn channel(&self) -> Channel<W> {
175        self.channel.clone()
176    }
177
178    /// Connection stats when tracking was enabled on the server/client.
179    pub fn stats(&self) -> Option<ConnectionStats> {
180        self.channel.stats()
181    }
182
183    /// Stages a message for outbound processing.
184    ///
185    /// The returned handle is ready immediately; awaiting it is supported for
186    /// source compatibility with earlier async-style handlers.
187    #[inline]
188    pub fn write(&mut self, msg: W) -> WriteHandle {
189        self.outbox.push_write(msg);
190        WriteHandle { _private: () }
191    }
192
193    /// Requests a flush of messages staged by this handler so far.
194    ///
195    /// Dropping the returned handle is fire-and-forget. Awaiting it waits until
196    /// the connection runtime has completed the local socket write.
197    #[inline]
198    pub fn flush(&mut self) -> FlushHandle<'_, W> {
199        self.outbox.push_flush()
200    }
201
202    /// Stages a message and requests an outbound flush.
203    ///
204    /// Dropping the returned handle is fire-and-forget. Awaiting it waits until
205    /// the connection runtime has completed the local socket write for this
206    /// flush boundary.
207    #[inline]
208    pub fn write_and_flush(&mut self, msg: W) -> FlushHandle<'_, W> {
209        self.outbox.push_write_and_flush(msg)
210    }
211
212    /// Requests that the connection close after the current handler returns.
213    pub async fn close(&mut self) -> Result<()> {
214        self.close_requested = true;
215        Ok(())
216    }
217
218    pub(crate) fn outbox(&self) -> StreamOutboxHandle<W> {
219        self.outbox.clone()
220    }
221
222    pub(crate) fn close_requested(&self) -> bool {
223        self.close_requested
224    }
225
226    #[inline]
227    pub(crate) fn has_external_channel(&self) -> bool {
228        self.channel.strong_count() > 1
229    }
230}
231
232pub struct WriteHandle {
233    _private: (),
234}
235
236impl IntoFuture for WriteHandle {
237    type Output = Result<()>;
238    type IntoFuture = Ready<Result<()>>;
239
240    #[inline]
241    fn into_future(self) -> Self::IntoFuture {
242        ready(Ok(()))
243    }
244}
245
246pub struct FlushHandle<'a, W> {
247    outbox: &'a StreamOutboxHandle<W>,
248}
249
250impl<'a, W> IntoFuture for FlushHandle<'a, W> {
251    type Output = Result<()>;
252    type IntoFuture = Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
253
254    #[inline]
255    fn into_future(self) -> Self::IntoFuture {
256        let id = self.outbox.push_flush_completion();
257        let state = &self.outbox.core.flush_state;
258
259        Box::pin(async move {
260            state.mark_awaited(id);
261
262            loop {
263                let notified = state.notify.notified();
264                tokio::pin!(notified);
265                notified.as_mut().enable();
266
267                if state.completed_flush_id.load(Ordering::Acquire) >= id {
268                    return Ok(());
269                }
270
271                notified.await;
272            }
273        })
274    }
275}
276
/// One operation staged in a connection outbox.
pub(crate) enum StreamOutboxCommand<W> {
    /// Stage a message without requesting a flush.
    Write(W),
    /// Stage a message and request a flush, as a single command.
    WriteAndFlush { msg: W },
    /// A flush boundary. `Some(id)` is queued when a [`FlushHandle`] is
    /// awaited, and that future waits until `id` is reported complete; `None`
    /// is a flush nobody waits on.
    Flush { completion: Option<u64> },
}
282
283struct StreamOutboxState<W> {
284    head: Option<StreamOutboxCommand<W>>,
285    tail: VecDeque<StreamOutboxCommand<W>>,
286}
287
288impl<W> StreamOutboxState<W> {
289    fn new() -> Self {
290        Self {
291            head: None,
292            tail: VecDeque::new(),
293        }
294    }
295
296    #[inline]
297    fn push(&mut self, command: StreamOutboxCommand<W>) {
298        if self.head.is_none() {
299            self.head = Some(command);
300        } else {
301            self.tail.push_back(command);
302        }
303    }
304
305    #[inline]
306    fn take_batch(&mut self) -> StreamOutboxBatch<W> {
307        StreamOutboxBatch {
308            head: self.head.take(),
309            tail: std::mem::take(&mut self.tail),
310        }
311    }
312}
313
314pub(crate) struct StreamOutboxBatch<W> {
315    head: Option<StreamOutboxCommand<W>>,
316    tail: VecDeque<StreamOutboxCommand<W>>,
317}
318
319impl<W> Iterator for StreamOutboxBatch<W> {
320    type Item = StreamOutboxCommand<W>;
321
322    #[inline]
323    fn next(&mut self) -> Option<Self::Item> {
324        self.head.take().or_else(|| self.tail.pop_front())
325    }
326}
327
328struct StreamFlushState {
329    next_flush_id: AtomicU64,
330    completed_flush_id: AtomicU64,
331    awaited_flush_id: AtomicU64,
332    notify: tokio::sync::Notify,
333}
334
335impl StreamFlushState {
336    fn new() -> Self {
337        Self {
338            next_flush_id: AtomicU64::new(0),
339            completed_flush_id: AtomicU64::new(0),
340            awaited_flush_id: AtomicU64::new(0),
341            notify: tokio::sync::Notify::new(),
342        }
343    }
344
345    #[inline]
346    fn next_id(&self) -> u64 {
347        self.next_flush_id.fetch_add(1, Ordering::Relaxed) + 1
348    }
349
350    #[inline]
351    fn mark_awaited(&self, id: u64) {
352        self.awaited_flush_id.fetch_max(id, Ordering::Release);
353    }
354
355    #[inline]
356    fn complete(&self, id: u64) {
357        self.completed_flush_id.store(id, Ordering::Release);
358        if self.awaited_flush_id.load(Ordering::Acquire) >= id {
359            self.notify.notify_waiters();
360        }
361    }
362}
363
/// State shared by every clone of a [`StreamOutboxHandle`].
struct StreamOutboxCore<W> {
    /// Commands staged since the last `take_commands` call.
    commands: Mutex<StreamOutboxState<W>>,
    /// Set by every flush-type push (`push_flush`, `push_write_and_flush`,
    /// `push_flush_completion`) and cleared by `take_commands`; read by
    /// `has_flush_command` without taking the `commands` lock.
    flush_requested: AtomicBool,
    /// Flush id allocation and completion tracking for awaited flushes.
    flush_state: StreamFlushState,
}
369
370pub(crate) struct StreamOutboxHandle<W> {
371    core: Arc<StreamOutboxCore<W>>,
372}
373
374impl<W> Clone for StreamOutboxHandle<W> {
375    fn clone(&self) -> Self {
376        Self {
377            core: self.core.clone(),
378        }
379    }
380}
381
382impl<W> StreamOutboxHandle<W> {
383    fn new() -> Self {
384        Self {
385            core: Arc::new(StreamOutboxCore {
386                commands: Mutex::new(StreamOutboxState::new()),
387                flush_requested: AtomicBool::new(false),
388                flush_state: StreamFlushState::new(),
389            }),
390        }
391    }
392
393    #[inline]
394    fn push_write(&self, msg: W) {
395        self.core
396            .commands
397            .lock()
398            .expect("stream outbox lock poisoned")
399            .push(StreamOutboxCommand::Write(msg));
400    }
401
402    #[inline]
403    fn push_flush(&self) -> FlushHandle<'_, W> {
404        self.core
405            .commands
406            .lock()
407            .expect("stream outbox lock poisoned")
408            .push(StreamOutboxCommand::Flush { completion: None });
409        self.core.flush_requested.store(true, Ordering::Release);
410        FlushHandle { outbox: self }
411    }
412
413    #[inline]
414    fn push_write_and_flush(&self, msg: W) -> FlushHandle<'_, W> {
415        self.core
416            .commands
417            .lock()
418            .expect("stream outbox lock poisoned")
419            .push(StreamOutboxCommand::WriteAndFlush { msg });
420        self.core.flush_requested.store(true, Ordering::Release);
421        FlushHandle { outbox: self }
422    }
423
424    #[inline]
425    fn push_flush_completion(&self) -> u64 {
426        let id = self.core.flush_state.next_id();
427        self.core
428            .commands
429            .lock()
430            .expect("stream outbox lock poisoned")
431            .push(StreamOutboxCommand::Flush {
432                completion: Some(id),
433            });
434        self.core.flush_requested.store(true, Ordering::Release);
435        id
436    }
437
438    #[inline]
439    pub(crate) fn has_flush_command(&self) -> bool {
440        self.core.flush_requested.load(Ordering::Acquire)
441    }
442
443    #[inline]
444    pub(crate) fn take_commands(&self) -> StreamOutboxBatch<W> {
445        self.core.flush_requested.store(false, Ordering::Release);
446        self.core
447            .commands
448            .lock()
449            .expect("stream outbox lock poisoned")
450            .take_batch()
451    }
452
453    #[inline]
454    pub(crate) fn complete_flush(&self, id: u64) {
455        self.core.flush_state.complete(id);
456    }
457}