pingora_core/protocols/l4/
virt.rs1use std::{
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use tokio::io::{AsyncRead, AsyncWrite};
9
10use super::ext::TcpKeepalive;
11
12#[non_exhaustive]
14#[derive(Debug, Clone)]
15pub enum VirtualSockOpt {
16 NoDelay,
17 KeepAlive(TcpKeepalive),
18}
19
20pub trait VirtualSocket: AsyncRead + AsyncWrite + Unpin + Send + Sync + std::fmt::Debug {
22 fn set_socket_option(&self, opt: VirtualSockOpt) -> std::io::Result<()>;
24}
25
26#[derive(Debug)]
28pub struct VirtualSocketStream {
29 pub(crate) socket: Box<dyn VirtualSocket>,
30}
31
32impl VirtualSocketStream {
33 pub fn new(socket: Box<dyn VirtualSocket>) -> Self {
34 Self { socket }
35 }
36
37 #[inline]
38 pub fn set_socket_option(&self, opt: VirtualSockOpt) -> std::io::Result<()> {
39 self.socket.set_socket_option(opt)
40 }
41}
42
43impl AsyncRead for VirtualSocketStream {
44 #[inline]
45 fn poll_read(
46 self: Pin<&mut Self>,
47 cx: &mut Context<'_>,
48 buf: &mut tokio::io::ReadBuf<'_>,
49 ) -> Poll<std::io::Result<()>> {
50 Pin::new(&mut *self.get_mut().socket).poll_read(cx, buf)
51 }
52}
53
54impl AsyncWrite for VirtualSocketStream {
55 #[inline]
56 fn poll_write(
57 self: Pin<&mut Self>,
58 cx: &mut Context<'_>,
59 buf: &[u8],
60 ) -> Poll<std::io::Result<usize>> {
61 Pin::new(&mut *self.get_mut().socket).poll_write(cx, buf)
62 }
63
64 #[inline]
65 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
66 Pin::new(&mut *self.get_mut().socket).poll_flush(cx)
67 }
68
69 #[inline]
70 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
71 Pin::new(&mut *self.get_mut().socket).poll_shutdown(cx)
72 }
73}
74
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};

    use crate::protocols::l4::stream::Stream;

    use super::*;

    /// In-memory [`VirtualSocket`] used by these tests.
    ///
    /// Reads drain `content` from `read_pos`. Writes are appended to the shared
    /// `write_buf`. Every option passed to `set_socket_option` is recorded in
    /// the shared `opts`, so tests can verify forwarding.
    #[derive(Debug)]
    struct StaticVirtualSocket {
        content: Vec<u8>,
        read_pos: usize,
        write_buf: Arc<Mutex<Vec<u8>>>,
        opts: Arc<Mutex<Vec<VirtualSockOpt>>>,
    }

    impl StaticVirtualSocket {
        /// Build a socket that will serve `content` on reads and report writes
        /// and option changes through the given shared handles.
        fn new(
            content: &[u8],
            write_buf: Arc<Mutex<Vec<u8>>>,
            opts: Arc<Mutex<Vec<VirtualSockOpt>>>,
        ) -> Self {
            Self {
                content: content.to_vec(),
                read_pos: 0,
                write_buf,
                opts,
            }
        }
    }

    impl AsyncRead for StaticVirtualSocket {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            debug_assert!(self.read_pos <= self.content.len());

            let remaining = self.content.len() - self.read_pos;
            if remaining == 0 {
                // Filling nothing while returning Ready(Ok) signals EOF.
                return Poll::Ready(Ok(()));
            }

            let to_read = std::cmp::min(remaining, buf.remaining());
            buf.put_slice(&self.content[self.read_pos..self.read_pos + to_read]);
            self.read_pos += to_read;

            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for StaticVirtualSocket {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            let this = self.get_mut();
            this.write_buf.lock().unwrap().extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl VirtualSocket for StaticVirtualSocket {
        fn set_socket_option(&self, opt: VirtualSockOpt) -> std::io::Result<()> {
            self.opts.lock().unwrap().push(opt);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_stream_virtual() {
        let content = b"hello virtual world";
        let write_buf = Arc::new(Mutex::new(Vec::new()));
        let mut stream = Stream::from(VirtualSocketStream::new(Box::new(
            StaticVirtualSocket::new(content, write_buf.clone(), Arc::default()),
        )));

        let mut buf = Vec::new();
        let out = stream.read_to_end(&mut buf).await.unwrap();
        assert_eq!(out, content.len());
        assert_eq!(buf, content);

        stream.write_all(content).await.unwrap();
        assert_eq!(write_buf.lock().unwrap().as_slice(), content);
    }

    /// Reads smaller than the content must advance through it and then hit EOF.
    #[tokio::test]
    async fn test_virtual_partial_reads() {
        let mut stream = VirtualSocketStream::new(Box::new(StaticVirtualSocket::new(
            b"abcdef",
            Arc::default(),
            Arc::default(),
        )));

        let mut chunk = [0u8; 4];
        assert_eq!(stream.read(&mut chunk).await.unwrap(), 4);
        assert_eq!(&chunk, b"abcd");
        assert_eq!(stream.read(&mut chunk).await.unwrap(), 2);
        assert_eq!(&chunk[..2], b"ef");
        assert_eq!(stream.read(&mut chunk).await.unwrap(), 0);
    }

    /// `VirtualSocketStream::set_socket_option` must reach the inner socket.
    #[test]
    fn test_set_socket_option_forwarded() {
        let opts = Arc::new(Mutex::new(Vec::new()));
        let stream = VirtualSocketStream::new(Box::new(StaticVirtualSocket::new(
            b"",
            Arc::default(),
            opts.clone(),
        )));

        stream.set_socket_option(VirtualSockOpt::NoDelay).unwrap();

        let recorded = opts.lock().unwrap();
        assert_eq!(recorded.len(), 1);
        assert!(matches!(recorded[0], VirtualSockOpt::NoDelay));
    }
}