picodata_plugin/transport/
stream.rs1use openssl::error::ErrorStack;
7use openssl::ssl::{self, SslStream};
8use std::io::{self, Read, Write};
9use std::os::fd::AsRawFd;
10use std::time::Duration;
11use tarantool::coio::CoIOStream;
12use thiserror::Error;
13
14#[derive(Error, Debug)]
20pub enum TlsHandshakeError {
21 #[error("setup failure: {0}")]
23 SetupFailure(ErrorStack),
24
25 #[error("handshake error: {0}")]
27 Failure(ssl::Error),
28}
29
30#[derive(Error, Debug)]
32pub enum PicoStreamError {
33 #[error("configuration error: {0}")]
35 Config(String),
36
37 #[error("io error: {0}")]
39 Io(#[from] io::Error),
40
41 #[error("tls error: {0}")]
43 Tls(#[from] TlsHandshakeError),
44}
45
46pub struct PicoStream {
54 inner: PicoStreamImpl,
55}
56
57enum PicoStreamImpl {
58 Plain(CoIOStream),
59 Tls(SslStream<CoIOStream>),
60}
61
62impl PicoStream {
63 pub fn plain(stream: CoIOStream) -> Self {
65 Self {
66 inner: PicoStreamImpl::Plain(stream),
67 }
68 }
69
70 pub fn tls(stream: SslStream<CoIOStream>) -> Self {
72 Self {
73 inner: PicoStreamImpl::Tls(stream),
74 }
75 }
76
77 pub fn is_tls(&self) -> bool {
79 matches!(self.inner, PicoStreamImpl::Tls(_))
80 }
81
82 pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
87 use std::os::fd::BorrowedFd;
88 let fd = self.as_inner().as_raw_fd();
89 let fd = unsafe { BorrowedFd::borrow_raw(fd) };
91 socket2::SockRef::from(&fd).set_nodelay(nodelay)
92 }
93
94 fn as_inner(&self) -> &CoIOStream {
96 match &self.inner {
97 PicoStreamImpl::Plain(s) => s,
98 PicoStreamImpl::Tls(s) => s.get_ref(),
99 }
100 }
101
102 pub fn read_with_timeout(
118 &mut self,
119 buf: &mut [u8],
120 timeout: Option<Duration>,
121 ) -> io::Result<usize> {
122 match &mut self.inner {
123 PicoStreamImpl::Plain(s) => s.read_with_timeout(buf, timeout),
124 PicoStreamImpl::Tls(s) => {
125 if let Some(timeout_duration) = timeout {
126 read_tls_with_timeout(s, buf, timeout_duration)
127 } else {
128 s.read(buf)
129 }
130 }
131 }
132 }
133}
134
135fn read_tls_with_timeout(
137 ssl_stream: &mut SslStream<CoIOStream>,
138 buf: &mut [u8],
139 timeout: Duration,
140) -> io::Result<usize> {
141 use tarantool::ffi::tarantool as ffi;
142
143 match ssl_stream.read(buf) {
144 Ok(n) => return Ok(n),
145 Err(e) if e.kind() != io::ErrorKind::WouldBlock => return Err(e),
146 _ => {}
147 }
148
149 let fd = ssl_stream.get_ref().as_raw_fd();
150 let timeout_secs = timeout.as_secs_f64();
151
152 match unsafe { ffi::coio_wait(fd, ffi::CoIOFlags::READ.bits(), timeout_secs) } {
153 0 => Err(io::Error::new(io::ErrorKind::TimedOut, "read timeout")),
154 _ => ssl_stream.read(buf),
155 }
156}
157
158impl Read for PicoStream {
159 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
160 match &mut self.inner {
161 PicoStreamImpl::Plain(s) => s.read(buf),
162 PicoStreamImpl::Tls(s) => s.read(buf),
163 }
164 }
165}
166
167impl Write for PicoStream {
168 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
169 match &mut self.inner {
170 PicoStreamImpl::Plain(s) => s.write(buf),
171 PicoStreamImpl::Tls(s) => s.write(buf),
172 }
173 }
174
175 fn flush(&mut self) -> io::Result<()> {
176 match &mut self.inner {
177 PicoStreamImpl::Plain(s) => s.flush(),
178 PicoStreamImpl::Tls(s) => s.flush(),
179 }
180 }
181}