Skip to main content

picodata_plugin/transport/
stream.rs

1//! Stream API for plugin connections.
2//!
3//! This module provides `PicoStream` which can be either a plain TCP stream
4//! or a TLS-encrypted stream, both using Tarantool's cooperative I/O.
5
6use openssl::error::ErrorStack;
7use openssl::ssl::{self, SslStream};
8use std::io::{self, Read, Write};
9use std::os::fd::AsRawFd;
10use std::time::Duration;
11use tarantool::coio::CoIOStream;
12use thiserror::Error;
13
14////////////////////////////////////////////////////////////////////////////////
15// Errors
16////////////////////////////////////////////////////////////////////////////////
17
18/// Error that can occur during TLS handshake.
19#[derive(Error, Debug)]
20pub enum TlsHandshakeError {
21    /// Error during TLS setup.
22    #[error("setup failure: {0}")]
23    SetupFailure(ErrorStack),
24
25    /// Error during handshake.
26    #[error("handshake error: {0}")]
27    Failure(ssl::Error),
28}
29
30/// Error that can occur when accepting a connection.
31#[derive(Error, Debug)]
32pub enum PicoStreamError {
33    /// Configuration error.
34    #[error("configuration error: {0}")]
35    Config(String),
36
37    /// IO error.
38    #[error("io error: {0}")]
39    Io(#[from] io::Error),
40
41    /// TLS handshake error.
42    #[error("tls error: {0}")]
43    Tls(#[from] TlsHandshakeError),
44}
45
46////////////////////////////////////////////////////////////////////////////////
47// PicoStream
48////////////////////////////////////////////////////////////////////////////////
49
50/// A stream that can be either plain or TLS-encrypted.
51///
52/// Uses Tarantool's cooperative I/O stream for non-blocking operations.
53pub struct PicoStream {
54    inner: PicoStreamImpl,
55}
56
57enum PicoStreamImpl {
58    Plain(CoIOStream),
59    Tls(SslStream<CoIOStream>),
60}
61
62impl PicoStream {
63    /// Creates a new plain stream.
64    pub fn plain(stream: CoIOStream) -> Self {
65        Self {
66            inner: PicoStreamImpl::Plain(stream),
67        }
68    }
69
70    /// Creates a new TLS-encrypted stream.
71    pub fn tls(stream: SslStream<CoIOStream>) -> Self {
72        Self {
73            inner: PicoStreamImpl::Tls(stream),
74        }
75    }
76
77    /// Returns true if this stream is TLS-encrypted.
78    pub fn is_tls(&self) -> bool {
79        matches!(self.inner, PicoStreamImpl::Tls(_))
80    }
81
82    /// Sets or clears the `TCP_NODELAY` option on the underlying socket.
83    ///
84    /// When enabled, disables Nagle's algorithm for lower latency at the cost
85    /// of potentially higher network overhead for small writes.
86    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
87        use std::os::fd::BorrowedFd;
88        let fd = self.as_inner().as_raw_fd();
89        // SAFETY: stream contains a valid descriptor
90        let fd = unsafe { BorrowedFd::borrow_raw(fd) };
91        socket2::SockRef::from(&fd).set_nodelay(nodelay)
92    }
93
94    /// Returns a reference to the underlying `CoIOStream`.
95    fn as_inner(&self) -> &CoIOStream {
96        match &self.inner {
97            PicoStreamImpl::Plain(s) => s,
98            PicoStreamImpl::Tls(s) => s.get_ref(),
99        }
100    }
101
102    /// Reads data from the stream with a timeout.
103    ///
104    /// This method provides timeout functionality similar to CoIOStream::read_with_timeout.
105    /// For plain connections, it delegates to the underlying CoIOStream.
106    /// For TLS connections, implements timeout using Tarantool's fiber API.
107    ///
108    /// # Parameters
109    ///
110    /// - `buf`: Buffer to read data into
111    /// - `timeout`: Optional timeout duration
112    ///
113    /// # Returns
114    ///
115    /// - `Ok(n)`: Number of bytes read
116    /// - `Err(e)`: IO error (including TimedOut if timeout expires)
117    pub fn read_with_timeout(
118        &mut self,
119        buf: &mut [u8],
120        timeout: Option<Duration>,
121    ) -> io::Result<usize> {
122        match &mut self.inner {
123            PicoStreamImpl::Plain(s) => s.read_with_timeout(buf, timeout),
124            PicoStreamImpl::Tls(s) => {
125                if let Some(timeout_duration) = timeout {
126                    read_tls_with_timeout(s, buf, timeout_duration)
127                } else {
128                    s.read(buf)
129                }
130            }
131        }
132    }
133}
134
135/// Helper function to read from TLS stream with timeout using fiber API.
136fn read_tls_with_timeout(
137    ssl_stream: &mut SslStream<CoIOStream>,
138    buf: &mut [u8],
139    timeout: Duration,
140) -> io::Result<usize> {
141    use tarantool::ffi::tarantool as ffi;
142
143    match ssl_stream.read(buf) {
144        Ok(n) => return Ok(n),
145        Err(e) if e.kind() != io::ErrorKind::WouldBlock => return Err(e),
146        _ => {}
147    }
148
149    let fd = ssl_stream.get_ref().as_raw_fd();
150    let timeout_secs = timeout.as_secs_f64();
151
152    match unsafe { ffi::coio_wait(fd, ffi::CoIOFlags::READ.bits(), timeout_secs) } {
153        0 => Err(io::Error::new(io::ErrorKind::TimedOut, "read timeout")),
154        _ => ssl_stream.read(buf),
155    }
156}
157
158impl Read for PicoStream {
159    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
160        match &mut self.inner {
161            PicoStreamImpl::Plain(s) => s.read(buf),
162            PicoStreamImpl::Tls(s) => s.read(buf),
163        }
164    }
165}
166
167impl Write for PicoStream {
168    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
169        match &mut self.inner {
170            PicoStreamImpl::Plain(s) => s.write(buf),
171            PicoStreamImpl::Tls(s) => s.write(buf),
172        }
173    }
174
175    fn flush(&mut self) -> io::Result<()> {
176        match &mut self.inner {
177            PicoStreamImpl::Plain(s) => s.flush(),
178            PicoStreamImpl::Tls(s) => s.flush(),
179        }
180    }
181}