// Source file: pingora_core/protocols/l4/virt.rs

//! Provides [`VirtualSocketStream`].
2
3use std::{
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use tokio::io::{AsyncRead, AsyncWrite};
9
10use super::ext::TcpKeepalive;
11
12/// A limited set of socket options that can be set on a [`VirtualSocket`].
13#[non_exhaustive]
14#[derive(Debug, Clone)]
15pub enum VirtualSockOpt {
16    NoDelay,
17    KeepAlive(TcpKeepalive),
18}
19
20/// A "virtual" socket that supports async read and write operations.
21pub trait VirtualSocket: AsyncRead + AsyncWrite + Unpin + Send + Sync + std::fmt::Debug {
22    /// Set a socket option.
23    fn set_socket_option(&self, opt: VirtualSockOpt) -> std::io::Result<()>;
24}
25
26/// Wrapper around any type implementing  [`VirtualSocket`].
27#[derive(Debug)]
28pub struct VirtualSocketStream {
29    pub(crate) socket: Box<dyn VirtualSocket>,
30}
31
32impl VirtualSocketStream {
33    pub fn new(socket: Box<dyn VirtualSocket>) -> Self {
34        Self { socket }
35    }
36
37    #[inline]
38    pub fn set_socket_option(&self, opt: VirtualSockOpt) -> std::io::Result<()> {
39        self.socket.set_socket_option(opt)
40    }
41}
42
43impl AsyncRead for VirtualSocketStream {
44    #[inline]
45    fn poll_read(
46        self: Pin<&mut Self>,
47        cx: &mut Context<'_>,
48        buf: &mut tokio::io::ReadBuf<'_>,
49    ) -> Poll<std::io::Result<()>> {
50        Pin::new(&mut *self.get_mut().socket).poll_read(cx, buf)
51    }
52}
53
54impl AsyncWrite for VirtualSocketStream {
55    #[inline]
56    fn poll_write(
57        self: Pin<&mut Self>,
58        cx: &mut Context<'_>,
59        buf: &[u8],
60    ) -> Poll<std::io::Result<usize>> {
61        Pin::new(&mut *self.get_mut().socket).poll_write(cx, buf)
62    }
63
64    #[inline]
65    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
66        Pin::new(&mut *self.get_mut().socket).poll_flush(cx)
67    }
68
69    #[inline]
70    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
71        Pin::new(&mut *self.get_mut().socket).poll_shutdown(cx)
72    }
73}
74
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use tokio::io::{AsyncReadExt, AsyncWriteExt as _};

    use crate::protocols::l4::stream::Stream;

    use super::*;

    /// Test double: hands out `content` to readers and appends every write to the
    /// shared `write_buf` so the test can inspect it afterwards.
    #[derive(Debug)]
    struct StaticVirtualSocket {
        content: Vec<u8>,
        read_pos: usize,
        write_buf: Arc<Mutex<Vec<u8>>>,
    }

    impl AsyncRead for StaticVirtualSocket {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            let this = self.get_mut();
            debug_assert!(this.read_pos <= this.content.len());

            let unread = &this.content[this.read_pos..];
            // Once everything has been served, `buf` is left untouched, which the
            // caller observes as end-of-stream.
            if !unread.is_empty() {
                let count = unread.len().min(buf.remaining());
                buf.put_slice(&unread[..count]);
                this.read_pos += count;
            }

            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for StaticVirtualSocket {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            // Record the bytes in the shared buffer and report them all as written.
            let mut sink = self.write_buf.lock().unwrap();
            sink.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            // Nothing is buffered, so flushing completes immediately.
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            // There is nothing to tear down.
            Poll::Ready(Ok(()))
        }
    }

    impl VirtualSocket for StaticVirtualSocket {
        fn set_socket_option(&self, _opt: VirtualSockOpt) -> std::io::Result<()> {
            // Options are accepted and ignored by the test double.
            Ok(())
        }
    }

    /// Basic test that ensures reading and writing works with a virtual socket.
    ///
    /// Mostly just ensures that construction works and the plumbing is correct.
    #[tokio::test]
    async fn test_stream_virtual() {
        let content = b"hello virtual world";
        let write_buf = Arc::new(Mutex::new(Vec::new()));

        let socket = StaticVirtualSocket {
            content: content.to_vec(),
            read_pos: 0,
            write_buf: Arc::clone(&write_buf),
        };
        let mut stream = Stream::from(VirtualSocketStream::new(Box::new(socket)));

        // Read path: everything the socket serves must come back through the stream.
        let mut received = Vec::new();
        let read_len = stream.read_to_end(&mut received).await.unwrap();
        assert_eq!(read_len, content.len());
        assert_eq!(received, content);

        // Write path: bytes written to the stream must land in the socket's buffer.
        stream.write_all(content).await.unwrap();
        assert_eq!(write_buf.lock().unwrap().as_slice(), content);
    }
}