mtcp_rs/
stream.rs

1/*
2 * mtcp - TcpListener/TcpStream *with* timeout/cancellation support
3 * This is free and unencumbered software released into the public domain.
4 */
5use std::io::{Read, Write, Result as IoResult, ErrorKind};
6use std::net::{SocketAddr, Shutdown};
7use std::num::NonZeroUsize;
8use std::ops::{Deref, DerefMut};
9use std::rc::Rc;
10use std::time::{Duration};
11
12use mio::{Token, Interest};
13use mio::net::TcpStream as MioTcpStream;
14
15use log::warn;
16use spare_buffer::SpareBuffer;
17
18use crate::utilities::Timeout;
19use crate::{TcpConnection, TcpManager, TcpError};
20use crate::manager::TcpPollContext;
21
/// A TCP stream between a local and a remote socket, akin to
/// [`std::net::TcpStream`](std::net::TcpStream)
///
/// All I/O operations provided by `mtcp_rs::TcpStream` are "blocking", but –
/// unlike the `std::net` implementation – proper ***timeout*** and
/// ***cancellation*** support is available. The `mtcp_rs::TcpStream` is tied
/// to an [`mtcp_rs::TcpManager`](crate::TcpManager) instance.
///
/// The TCP stream is created by [`connect()`](TcpStream::connect())ing to a
/// remote host, or directly [`from()`](TcpStream::from()) an existing
/// [`mtcp_rs::TcpConnection`](crate::TcpConnection).
///
/// If the `timeout` parameter was set to `Some(Duration)` and if the I/O
/// operation does **not** complete before the specified timeout period
/// expires, then the pending I/O operation will be aborted and fail with an
/// [`TcpError::TimedOut`](crate::TcpError::TimedOut) error.
///
/// If the associated `TcpManager` has been cancelled, I/O operations fail
/// with a [`TcpError::Cancelled`](crate::TcpError::Cancelled) error.
///
/// Functions like [`Read::read()`](std::io::Read::read()) and
/// [`Write::write()`](std::io::Write::write()), which do **not** have an
/// explicit `timeout` parameter, *implicitly* use the timeouts that have been
/// set up via the
/// [`set_default_timeouts()`](TcpStream::set_default_timeouts()) function.
/// Initially, these timeouts are disabled.
#[derive(Debug)]
pub struct TcpStream {
    // The underlying non-blocking `mio` socket.
    stream: MioTcpStream,
    // Token under which `stream` is registered with the manager's poll registry.
    token: Token,
    // Default `(read, write)` timeouts used by the `Read`/`Write` impls.
    timeouts: (Option<Duration>, Option<Duration>),
    // The manager this stream is tied to; it provides the poll context and
    // the cancellation flag.
    manager: Rc<TcpManager>,
}
52
53impl TcpStream {
54    /// Initialize a new `TcpStream` from an existing `TcpConnection` instance.
55    /// 
56    /// `TcpConnection` instances are usually obtained by
57    /// [`accept()`](crate::TcpListener::accept)ing incoming TCP connections
58    /// via a bound `TcpListener`.
59    /// 
60    /// The new `TcpStream` is tied to the specified `TcpManager` instance.
61    pub fn from(manager: &Rc<TcpManager>, connection: TcpConnection) -> IoResult<Self> {
62        let mut stream = connection.stream();
63        let manager = manager.clone();
64        let token = Self::register(&manager.context(), &mut stream)?;
65
66        Ok(Self {
67            stream,
68            token,
69            timeouts: (None, None),
70            manager,
71        })
72    }
73
74    /// Set up the *default* timeouts, to be used by functions like
75    /// [`Read::read()`](std::io::Read::read()) and
76    /// [`Write::write()`](std::io::Write::write()).
77    pub fn set_default_timeouts(&mut self, timeout_rd: Option<Duration>, timeout_wr: Option<Duration>) {
78        self.timeouts = (timeout_rd, timeout_wr);
79    }
80
81    /// Get the *peer* socket address of this TCP stream.
82    pub fn peer_addr(&self) -> Option<SocketAddr> {
83        self.stream.peer_addr().ok()
84    }
85
86    /// Get the *local* socket address of this TCP stream.
87    pub fn local_addr(&self) -> Option<SocketAddr> {
88        self.stream.local_addr().ok()
89    }
90
91    /// Shuts down the read, write, or both halves of this TCP stream.
92    pub fn shutdown(&self, how: Shutdown) -> IoResult<()> {
93        self.stream.shutdown(how)
94    }
95
96    fn register<T>(context: &T, stream: &mut MioTcpStream) -> IoResult<Token>
97    where
98        T: Deref<Target=TcpPollContext>
99    {
100        let token = context.token();
101        context.registry().register(stream, token, Interest::READABLE | Interest::WRITABLE)?;
102        Ok(token)
103    }
104
105    fn deregister<T>(context: &T, stream: &mut MioTcpStream)
106    where
107        T: Deref<Target=TcpPollContext>
108    {
109        if let Err(error) = context.registry().deregister(stream) {
110            warn!("Failed to de-register: {:?}", error);
111        }
112    }
113
114    // ~~~~~~~~~~~~~~~~~~~~~~~
115    // Connect functions
116    // ~~~~~~~~~~~~~~~~~~~~~~~
117
118    /// Opens a new TCP connection to the remote host at the specified address.
119    /// 
120    /// An optional ***timeout*** can be specified, after which the operation
121    /// is going to fail, if the connection could **not** be established yet.
122    /// 
123    /// The new `TcpStream` is tied to the specified `TcpManager` instance.
124    pub fn connect(manager: &Rc<TcpManager>, addr: SocketAddr, timeout: Option<Duration>) -> Result<Self, TcpError> {
125        if manager.cancelled() {
126            return Err(TcpError::Cancelled);
127        }
128
129        let mut stream = MioTcpStream::connect(addr)?;
130        let manager = manager.clone();
131        let token = Self::init_connection(&manager, &mut stream, timeout)?;
132
133        Ok(Self {
134            stream,
135            token,
136            timeouts: (None, None),
137            manager,
138        })
139    }
140
141    fn init_connection(manager: &Rc<TcpManager>, stream: &mut MioTcpStream, timeout: Option<Duration>) -> Result<Token, TcpError> {
142        let mut context = manager.context_mut();
143        let token = Self::register(&context, stream)?;
144
145        match Self::await_connected(manager, &mut context, stream, token, timeout) {
146            Ok(_) => Ok(token),
147            Err(error) => {
148                Self::deregister(&context, stream);
149                Err(error)
150            },
151        }
152    }
153
154    fn await_connected<T>(manager: &Rc<TcpManager>, context: &mut T, stream: &mut MioTcpStream, token: Token, timeout: Option<Duration>) -> Result<(), TcpError>
155    where
156        T: DerefMut<Target=TcpPollContext>
157    {
158        let timeout = Timeout::start(timeout);
159
160        loop {
161            let remaining = timeout.remaining_time();
162            match context.poll(remaining) {
163                Ok(events) => {
164                    for _event in events.iter().filter(|event| (event.token() == token)) {
165                        match Self::event_conn(stream) {
166                            Ok(true) => return Ok(()),
167                            Ok(_) => (),
168                            Err(error) => return Err(error.into()),
169                        }
170                    }
171                },
172                Err(error) => return Err(error.into()),
173            }
174            if manager.cancelled() {
175                return Err(TcpError::Cancelled);
176            }
177            if remaining.map(|time| time.is_zero()).unwrap_or(false) {
178                return Err(TcpError::TimedOut);
179            }
180        }
181    }
182
183    fn event_conn(stream: &mut MioTcpStream) -> IoResult<bool> {
184        loop {
185            if let Some(err) = stream.take_error()? {
186                return Err(err);
187            }
188            match stream.peer_addr() {
189                Ok(_addr) => return Ok(true),
190                Err(error) => match error.kind() {
191                    ErrorKind::Interrupted => (),
192                    ErrorKind::NotConnected => return Ok(false),
193                    _ => return Err(error),
194                },
195            }
196        }
197    }
198
199    // ~~~~~~~~~~~~~~~~~~~~~~~
200    // Read functions
201    // ~~~~~~~~~~~~~~~~~~~~~~~
202
203    /// Read the next "chunk" of incoming data from the TCP stream into the
204    /// specified destination buffer.
205    /// 
206    /// This function attempts to read a maximum of `buffer.len()` bytes, but
207    /// *fewer* bytes may actually be read! Specifically, the function waits
208    /// until *some* data become available for reading, or the end of the
209    /// stream (or an error) is encountered. It then reads as many bytes as are
210    /// available and returns immediately. The function does **not** wait any
211    /// longer, even if the `buffer` is **not** filled completely.
212    /// 
213    /// An optional ***timeout*** can be specified, after which the operation
214    /// is going to fail, if still **no** data is available for reading.
215    /// 
216    /// Returns the number of bytes that have been pulled from the stream into
217    /// the buffer, which is less than or equal to `buffer.len()`. A ***zero***
218    /// return value indicates the end of the stream. Otherwise, more data may
219    /// become available for reading soon!
220    pub fn read_timeout(&mut self, buffer: &mut [u8], timeout: Option<Duration>) -> Result<usize, TcpError> {
221        if self.manager.cancelled() {
222            return Err(TcpError::Cancelled);
223        }
224
225        let timeout = Timeout::start(timeout);
226
227        match Self::event_read(&mut self.stream, buffer) {
228            Ok(Some(len)) => return Ok(len),
229            Ok(_) => (),
230            Err(error) => return Err(error.into()),
231        }
232
233        let mut context = self.manager.context_mut();
234
235        loop {
236            let remaining = timeout.remaining_time();
237            match context.poll(remaining) {
238                Ok(events) => {
239                    for _event in events.iter().filter(|event| (event.token() == self.token) && event.is_readable()) {
240                        match Self::event_read(&mut self.stream, buffer) {
241                            Ok(Some(len)) => return Ok(len),
242                            Ok(_) => (),
243                            Err(error) => return Err(error.into()),
244                        }
245                    }
246                },
247                Err(error) => return Err(error.into()),
248            }
249            if self.manager.cancelled() {
250                return Err(TcpError::Cancelled);
251            }
252            if remaining.map(|time| time.is_zero()).unwrap_or(false) {
253                return Err(TcpError::TimedOut);
254            }
255        }
256    }
257
258    /// Read **all** incoming data from the TCP stream into the specified
259    /// destination buffer.
260    /// 
261    /// This function keeps on [reading](Self::read_timeout) from the stream,
262    /// until the input data has been read *completely*, as defined by the
263    /// `fn_complete` closure, or an error is encountered. All input data is
264    /// appended to the given `buffer`, extending the buffer as needed. The
265    /// `fn_complete` closure is invoked every time that a new "chunk" of input
266    /// was received. Unless the closure returned `true`, the function waits
267    /// for more input. If the end of the stream is encountered while the data
268    /// still is incomplete, the function fails.
269    /// 
270    /// The closure `fn_complete` takes a single parameter, a reference to the
271    /// current buffer, which contains *all* data that has been read so far.
272    /// That closure shall return `true` if and only if the data in the buffer
273    /// is considered "complete".
274    /// 
275    /// An optional ***timeout*** can be specified, after which the operation
276    /// is going to fail, if the data still is **not** complete.
277    ///
278    /// The optional ***chunk size*** specifies the maximum amount of data that
279    /// can be [read](Self::read_timeout) at once.
280    /// 
281    /// An optional ***maximum length*** can be specified. If the total size
282    /// exceeds this limit *before* the data is complete, the function fails.
283    pub fn read_all_timeout<F>(&mut self, buffer: &mut Vec<u8>, timeout: Option<Duration>, chunk_size: Option<NonZeroUsize>, maximum_length: Option<NonZeroUsize>, fn_complete: F) -> Result<(), TcpError>
284    where
285        F: Fn(&[u8]) -> bool,
286    {
287        let chunk_size = chunk_size.unwrap_or_else(|| NonZeroUsize::new(4096).unwrap());
288        if maximum_length.map_or(false, |value| value < chunk_size) {
289            panic!("maximum_length must be greater than or equal to chunk_size!")
290        }
291
292        let mut buffer = SpareBuffer::from(buffer, maximum_length);
293
294        loop {
295            let spare = buffer.allocate_spare(chunk_size);
296            match self.read_timeout(spare, timeout) {
297                Ok(0) => return Err(TcpError::Incomplete),
298                Ok(count) => {
299                    buffer.commit(count).map_err(|_err| TcpError::TooBig)?;
300                    match fn_complete(buffer.data()) {
301                        true => return Ok(()),
302                        false => {},
303                    }
304                },
305                Err(error) => return Err(error),
306            };
307        }
308    }
309
310    fn event_read(stream: &mut MioTcpStream, buffer: &mut [u8]) -> IoResult<Option<usize>> {
311        loop {
312            match stream.read(buffer) {
313                Ok(count) => return Ok(Some(count)),
314                Err(error) => match error.kind() {
315                    ErrorKind::Interrupted => (),
316                    ErrorKind::WouldBlock => return Ok(None),
317                    _ => return Err(error),
318                },
319            }
320        }
321    }
322
323    // ~~~~~~~~~~~~~~~~~~~~~~~
324    // Write functions
325    // ~~~~~~~~~~~~~~~~~~~~~~~
326
327    /// Write the next "chunk" of outgoing data from the specified source
328    /// buffer to the TCP stream.
329    /// 
330    /// This function attempts to write a maximum of `buffer.len()` bytes, but
331    /// *fewer* bytes may actually be written! Specifically, the function waits
332    /// until *some* data can be written, the stream is closed by the peer, or
333    /// an error is encountered. It then writes as many bytes as possible to
334    /// the stream. The function does **not** wait any longer, even if **not**
335    /// all data in `buffer` could be written yet.
336    /// 
337    /// An optional ***timeout*** can be specified, after which the operation
338    /// is going to fail, if still **no** data could be written.
339    /// 
340    /// Returns the number of bytes that have been pushed from the buffer into
341    /// the stream, which is less than or equal to `buffer.len()`. A ***zero***
342    /// return value indicates that the stream was closed. Otherwise, it may be
343    /// possible to write more data soon!
344    pub fn write_timeout(&mut self, buffer: &[u8], timeout: Option<Duration>) -> Result<usize, TcpError> {
345        if self.manager.cancelled() {
346            return Err(TcpError::Cancelled);
347        }
348
349        let timeout = Timeout::start(timeout);
350
351        match Self::event_write(&mut self.stream, buffer) {
352            Ok(Some(len)) => return Ok(len),
353            Ok(_) => (),
354            Err(error) => return Err(error.into()),
355        }
356
357        let mut context = self.manager.context_mut();
358
359        loop {
360            let remaining = timeout.remaining_time();
361            match context.poll(remaining) {
362                Ok(events) => {
363                    for _event in events.iter().filter(|event| (event.token() == self.token) && event.is_writable()) {
364                        match Self::event_write(&mut self.stream, buffer) {
365                            Ok(Some(len)) => return Ok(len),
366                            Ok(_) => (),
367                            Err(error) => return Err(error.into()),
368                        }
369                    }
370                },
371                Err(error) => return Err(error.into()),
372            }
373            if self.manager.cancelled() {
374                return Err(TcpError::Cancelled);
375            }
376            if remaining.map(|time| time.is_zero()).unwrap_or(false) {
377                return Err(TcpError::TimedOut);
378            }
379        }
380    }
381
382    /// Write **all** outgoing data from the specified source buffer to the TCP
383    /// stream.
384    /// 
385    /// This function keeps on [writing](Self::write_timeout) to the stream,
386    /// until the output data has been written *completely*, the peer closes
387    /// the stream, or an error is encountered. If the stream is closed
388    /// *before* all data could be written, the function fails.
389    /// 
390    /// An optional ***timeout*** can be specified, after which the operation
391    /// is going to fail, if the data still was **not** written completely.
392    pub fn write_all_timeout(&mut self, mut buffer: &[u8], timeout: Option<Duration>) -> Result<(), TcpError> {
393        loop {
394            match self.write_timeout(buffer, timeout) {
395                Ok(0) => return Err(TcpError::Incomplete),
396                Ok(count) => {
397                    buffer = &buffer[count..];
398                    if buffer.is_empty() { return Ok(()); }
399                },
400                Err(error) => return Err(error),
401            };
402        }
403    }
404
405    fn event_write(stream: &mut MioTcpStream, buffer: &[u8]) -> IoResult<Option<usize>> {
406        loop {
407            match stream.write(buffer) {
408                Ok(count) => return Ok(Some(count)),
409                Err(error) => match error.kind() {
410                    ErrorKind::Interrupted => (),
411                    ErrorKind::WouldBlock => return Ok(None),
412                    _ => return Err(error),
413                },
414            }
415        }
416    }
417}
418
419impl Read for TcpStream {
420    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
421        into_io_result(self.read_timeout(buf, self.timeouts.0))
422    }
423}
424
425impl Write for TcpStream {
426    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
427        into_io_result(self.write_timeout(buf, self.timeouts.1))
428    }
429
430    fn flush(&mut self) -> IoResult<()> {
431        self.stream.flush()
432    }
433}
434
435impl Drop for TcpStream {
436    fn drop(&mut self) {
437        let context = self.manager.context();
438        Self::deregister(&context, &mut self.stream);
439    }
440}
441
442fn into_io_result<T>(result: Result<T, TcpError>) -> IoResult<T> {
443    match result {
444        Ok(value) => Ok(value),
445        Err(error) => Err(error.into()),
446    }
447}