mtcp_rs/stream.rs
1/*
2 * mtcp - TcpListener/TcpStream *with* timeout/cancellation support
3 * This is free and unencumbered software released into the public domain.
4 */
5use std::io::{Read, Write, Result as IoResult, ErrorKind};
6use std::net::{SocketAddr, Shutdown};
7use std::num::NonZeroUsize;
8use std::ops::{Deref, DerefMut};
9use std::rc::Rc;
10use std::time::{Duration};
11
12use mio::{Token, Interest};
13use mio::net::TcpStream as MioTcpStream;
14
15use log::warn;
16use spare_buffer::SpareBuffer;
17
18use crate::utilities::Timeout;
19use crate::{TcpConnection, TcpManager, TcpError};
20use crate::manager::TcpPollContext;
21
/// A TCP stream between a local and a remote socket, akin to
/// [`std::net::TcpStream`](std::net::TcpStream)
///
/// All I/O operations provided by `mtcp_rs::TcpStream` are "blocking", but –
/// unlike the `std::net` implementation – proper ***timeout*** and
/// ***cancellation*** support is available. The `mtcp_rs::TcpStream` is tied
/// to an [`mtcp_rs::TcpManager`](crate::TcpManager) instance.
///
/// The TCP stream is created by [`connect()`](TcpStream::connect())ing to a
/// remote host, or directly [`from()`](TcpStream::from()) an existing
/// [`mtcp_rs::TcpConnection`](crate::TcpConnection).
///
/// If the `timeout` parameter was set to `Some(Duration)` and if the I/O
/// operation does **not** complete before the specified timeout period
/// expires, then the pending I/O operation will be aborted and fail with an
/// [`TcpError::TimedOut`](crate::TcpError::TimedOut) error.
///
/// Functions like [`Read::read()`](std::io::Read::read()) and
/// [`Write::write()`](std::io::Write::write()), which do **not** have an
/// explicit `timeout` parameter, *implicitly* use the timeouts that have been
/// set up via the
/// [`set_default_timeouts()`](TcpStream::set_default_timeouts()) function.
/// Initially, these timeouts are disabled.
#[derive(Debug)]
pub struct TcpStream {
    /// The underlying `mio` socket; reads and writes on it may report
    /// `WouldBlock`, in which case the stream waits on the poll context.
    stream: MioTcpStream,
    /// Token under which `stream` was registered with the manager's poll
    /// context; used to pick this stream's events out of a poll result.
    token: Token,
    /// Default timeouts as `(read, write)`, applied by the `Read` and `Write`
    /// trait implementations. Both start out as `None` (disabled).
    timeouts: (Option<Duration>, Option<Duration>),
    /// The manager this stream is tied to; provides the poll context and the
    /// cancellation flag.
    manager: Rc<TcpManager>,
}
52
53impl TcpStream {
54 /// Initialize a new `TcpStream` from an existing `TcpConnection` instance.
55 ///
56 /// `TcpConnection` instances are usually obtained by
57 /// [`accept()`](crate::TcpListener::accept)ing incoming TCP connections
58 /// via a bound `TcpListener`.
59 ///
60 /// The new `TcpStream` is tied to the specified `TcpManager` instance.
61 pub fn from(manager: &Rc<TcpManager>, connection: TcpConnection) -> IoResult<Self> {
62 let mut stream = connection.stream();
63 let manager = manager.clone();
64 let token = Self::register(&manager.context(), &mut stream)?;
65
66 Ok(Self {
67 stream,
68 token,
69 timeouts: (None, None),
70 manager,
71 })
72 }
73
74 /// Set up the *default* timeouts, to be used by functions like
75 /// [`Read::read()`](std::io::Read::read()) and
76 /// [`Write::write()`](std::io::Write::write()).
77 pub fn set_default_timeouts(&mut self, timeout_rd: Option<Duration>, timeout_wr: Option<Duration>) {
78 self.timeouts = (timeout_rd, timeout_wr);
79 }
80
81 /// Get the *peer* socket address of this TCP stream.
82 pub fn peer_addr(&self) -> Option<SocketAddr> {
83 self.stream.peer_addr().ok()
84 }
85
86 /// Get the *local* socket address of this TCP stream.
87 pub fn local_addr(&self) -> Option<SocketAddr> {
88 self.stream.local_addr().ok()
89 }
90
91 /// Shuts down the read, write, or both halves of this TCP stream.
92 pub fn shutdown(&self, how: Shutdown) -> IoResult<()> {
93 self.stream.shutdown(how)
94 }
95
96 fn register<T>(context: &T, stream: &mut MioTcpStream) -> IoResult<Token>
97 where
98 T: Deref<Target=TcpPollContext>
99 {
100 let token = context.token();
101 context.registry().register(stream, token, Interest::READABLE | Interest::WRITABLE)?;
102 Ok(token)
103 }
104
105 fn deregister<T>(context: &T, stream: &mut MioTcpStream)
106 where
107 T: Deref<Target=TcpPollContext>
108 {
109 if let Err(error) = context.registry().deregister(stream) {
110 warn!("Failed to de-register: {:?}", error);
111 }
112 }
113
114 // ~~~~~~~~~~~~~~~~~~~~~~~
115 // Connect functions
116 // ~~~~~~~~~~~~~~~~~~~~~~~
117
118 /// Opens a new TCP connection to the remote host at the specified address.
119 ///
120 /// An optional ***timeout*** can be specified, after which the operation
121 /// is going to fail, if the connection could **not** be established yet.
122 ///
123 /// The new `TcpStream` is tied to the specified `TcpManager` instance.
124 pub fn connect(manager: &Rc<TcpManager>, addr: SocketAddr, timeout: Option<Duration>) -> Result<Self, TcpError> {
125 if manager.cancelled() {
126 return Err(TcpError::Cancelled);
127 }
128
129 let mut stream = MioTcpStream::connect(addr)?;
130 let manager = manager.clone();
131 let token = Self::init_connection(&manager, &mut stream, timeout)?;
132
133 Ok(Self {
134 stream,
135 token,
136 timeouts: (None, None),
137 manager,
138 })
139 }
140
141 fn init_connection(manager: &Rc<TcpManager>, stream: &mut MioTcpStream, timeout: Option<Duration>) -> Result<Token, TcpError> {
142 let mut context = manager.context_mut();
143 let token = Self::register(&context, stream)?;
144
145 match Self::await_connected(manager, &mut context, stream, token, timeout) {
146 Ok(_) => Ok(token),
147 Err(error) => {
148 Self::deregister(&context, stream);
149 Err(error)
150 },
151 }
152 }
153
154 fn await_connected<T>(manager: &Rc<TcpManager>, context: &mut T, stream: &mut MioTcpStream, token: Token, timeout: Option<Duration>) -> Result<(), TcpError>
155 where
156 T: DerefMut<Target=TcpPollContext>
157 {
158 let timeout = Timeout::start(timeout);
159
160 loop {
161 let remaining = timeout.remaining_time();
162 match context.poll(remaining) {
163 Ok(events) => {
164 for _event in events.iter().filter(|event| (event.token() == token)) {
165 match Self::event_conn(stream) {
166 Ok(true) => return Ok(()),
167 Ok(_) => (),
168 Err(error) => return Err(error.into()),
169 }
170 }
171 },
172 Err(error) => return Err(error.into()),
173 }
174 if manager.cancelled() {
175 return Err(TcpError::Cancelled);
176 }
177 if remaining.map(|time| time.is_zero()).unwrap_or(false) {
178 return Err(TcpError::TimedOut);
179 }
180 }
181 }
182
183 fn event_conn(stream: &mut MioTcpStream) -> IoResult<bool> {
184 loop {
185 if let Some(err) = stream.take_error()? {
186 return Err(err);
187 }
188 match stream.peer_addr() {
189 Ok(_addr) => return Ok(true),
190 Err(error) => match error.kind() {
191 ErrorKind::Interrupted => (),
192 ErrorKind::NotConnected => return Ok(false),
193 _ => return Err(error),
194 },
195 }
196 }
197 }
198
199 // ~~~~~~~~~~~~~~~~~~~~~~~
200 // Read functions
201 // ~~~~~~~~~~~~~~~~~~~~~~~
202
203 /// Read the next "chunk" of incoming data from the TCP stream into the
204 /// specified destination buffer.
205 ///
206 /// This function attempts to read a maximum of `buffer.len()` bytes, but
207 /// *fewer* bytes may actually be read! Specifically, the function waits
208 /// until *some* data become available for reading, or the end of the
209 /// stream (or an error) is encountered. It then reads as many bytes as are
210 /// available and returns immediately. The function does **not** wait any
211 /// longer, even if the `buffer` is **not** filled completely.
212 ///
213 /// An optional ***timeout*** can be specified, after which the operation
214 /// is going to fail, if still **no** data is available for reading.
215 ///
216 /// Returns the number of bytes that have been pulled from the stream into
217 /// the buffer, which is less than or equal to `buffer.len()`. A ***zero***
218 /// return value indicates the end of the stream. Otherwise, more data may
219 /// become available for reading soon!
220 pub fn read_timeout(&mut self, buffer: &mut [u8], timeout: Option<Duration>) -> Result<usize, TcpError> {
221 if self.manager.cancelled() {
222 return Err(TcpError::Cancelled);
223 }
224
225 let timeout = Timeout::start(timeout);
226
227 match Self::event_read(&mut self.stream, buffer) {
228 Ok(Some(len)) => return Ok(len),
229 Ok(_) => (),
230 Err(error) => return Err(error.into()),
231 }
232
233 let mut context = self.manager.context_mut();
234
235 loop {
236 let remaining = timeout.remaining_time();
237 match context.poll(remaining) {
238 Ok(events) => {
239 for _event in events.iter().filter(|event| (event.token() == self.token) && event.is_readable()) {
240 match Self::event_read(&mut self.stream, buffer) {
241 Ok(Some(len)) => return Ok(len),
242 Ok(_) => (),
243 Err(error) => return Err(error.into()),
244 }
245 }
246 },
247 Err(error) => return Err(error.into()),
248 }
249 if self.manager.cancelled() {
250 return Err(TcpError::Cancelled);
251 }
252 if remaining.map(|time| time.is_zero()).unwrap_or(false) {
253 return Err(TcpError::TimedOut);
254 }
255 }
256 }
257
258 /// Read **all** incoming data from the TCP stream into the specified
259 /// destination buffer.
260 ///
261 /// This function keeps on [reading](Self::read_timeout) from the stream,
262 /// until the input data has been read *completely*, as defined by the
263 /// `fn_complete` closure, or an error is encountered. All input data is
264 /// appended to the given `buffer`, extending the buffer as needed. The
265 /// `fn_complete` closure is invoked every time that a new "chunk" of input
266 /// was received. Unless the closure returned `true`, the function waits
267 /// for more input. If the end of the stream is encountered while the data
268 /// still is incomplete, the function fails.
269 ///
270 /// The closure `fn_complete` takes a single parameter, a reference to the
271 /// current buffer, which contains *all* data that has been read so far.
272 /// That closure shall return `true` if and only if the data in the buffer
273 /// is considered "complete".
274 ///
275 /// An optional ***timeout*** can be specified, after which the operation
276 /// is going to fail, if the data still is **not** complete.
277 ///
278 /// The optional ***chunk size*** specifies the maximum amount of data that
279 /// can be [read](Self::read_timeout) at once.
280 ///
281 /// An optional ***maximum length*** can be specified. If the total size
282 /// exceeds this limit *before* the data is complete, the function fails.
283 pub fn read_all_timeout<F>(&mut self, buffer: &mut Vec<u8>, timeout: Option<Duration>, chunk_size: Option<NonZeroUsize>, maximum_length: Option<NonZeroUsize>, fn_complete: F) -> Result<(), TcpError>
284 where
285 F: Fn(&[u8]) -> bool,
286 {
287 let chunk_size = chunk_size.unwrap_or_else(|| NonZeroUsize::new(4096).unwrap());
288 if maximum_length.map_or(false, |value| value < chunk_size) {
289 panic!("maximum_length must be greater than or equal to chunk_size!")
290 }
291
292 let mut buffer = SpareBuffer::from(buffer, maximum_length);
293
294 loop {
295 let spare = buffer.allocate_spare(chunk_size);
296 match self.read_timeout(spare, timeout) {
297 Ok(0) => return Err(TcpError::Incomplete),
298 Ok(count) => {
299 buffer.commit(count).map_err(|_err| TcpError::TooBig)?;
300 match fn_complete(buffer.data()) {
301 true => return Ok(()),
302 false => {},
303 }
304 },
305 Err(error) => return Err(error),
306 };
307 }
308 }
309
310 fn event_read(stream: &mut MioTcpStream, buffer: &mut [u8]) -> IoResult<Option<usize>> {
311 loop {
312 match stream.read(buffer) {
313 Ok(count) => return Ok(Some(count)),
314 Err(error) => match error.kind() {
315 ErrorKind::Interrupted => (),
316 ErrorKind::WouldBlock => return Ok(None),
317 _ => return Err(error),
318 },
319 }
320 }
321 }
322
323 // ~~~~~~~~~~~~~~~~~~~~~~~
324 // Write functions
325 // ~~~~~~~~~~~~~~~~~~~~~~~
326
327 /// Write the next "chunk" of outgoing data from the specified source
328 /// buffer to the TCP stream.
329 ///
330 /// This function attempts to write a maximum of `buffer.len()` bytes, but
331 /// *fewer* bytes may actually be written! Specifically, the function waits
332 /// until *some* data can be written, the stream is closed by the peer, or
333 /// an error is encountered. It then writes as many bytes as possible to
334 /// the stream. The function does **not** wait any longer, even if **not**
335 /// all data in `buffer` could be written yet.
336 ///
337 /// An optional ***timeout*** can be specified, after which the operation
338 /// is going to fail, if still **no** data could be written.
339 ///
340 /// Returns the number of bytes that have been pushed from the buffer into
341 /// the stream, which is less than or equal to `buffer.len()`. A ***zero***
342 /// return value indicates that the stream was closed. Otherwise, it may be
343 /// possible to write more data soon!
344 pub fn write_timeout(&mut self, buffer: &[u8], timeout: Option<Duration>) -> Result<usize, TcpError> {
345 if self.manager.cancelled() {
346 return Err(TcpError::Cancelled);
347 }
348
349 let timeout = Timeout::start(timeout);
350
351 match Self::event_write(&mut self.stream, buffer) {
352 Ok(Some(len)) => return Ok(len),
353 Ok(_) => (),
354 Err(error) => return Err(error.into()),
355 }
356
357 let mut context = self.manager.context_mut();
358
359 loop {
360 let remaining = timeout.remaining_time();
361 match context.poll(remaining) {
362 Ok(events) => {
363 for _event in events.iter().filter(|event| (event.token() == self.token) && event.is_writable()) {
364 match Self::event_write(&mut self.stream, buffer) {
365 Ok(Some(len)) => return Ok(len),
366 Ok(_) => (),
367 Err(error) => return Err(error.into()),
368 }
369 }
370 },
371 Err(error) => return Err(error.into()),
372 }
373 if self.manager.cancelled() {
374 return Err(TcpError::Cancelled);
375 }
376 if remaining.map(|time| time.is_zero()).unwrap_or(false) {
377 return Err(TcpError::TimedOut);
378 }
379 }
380 }
381
382 /// Write **all** outgoing data from the specified source buffer to the TCP
383 /// stream.
384 ///
385 /// This function keeps on [writing](Self::write_timeout) to the stream,
386 /// until the output data has been written *completely*, the peer closes
387 /// the stream, or an error is encountered. If the stream is closed
388 /// *before* all data could be written, the function fails.
389 ///
390 /// An optional ***timeout*** can be specified, after which the operation
391 /// is going to fail, if the data still was **not** written completely.
392 pub fn write_all_timeout(&mut self, mut buffer: &[u8], timeout: Option<Duration>) -> Result<(), TcpError> {
393 loop {
394 match self.write_timeout(buffer, timeout) {
395 Ok(0) => return Err(TcpError::Incomplete),
396 Ok(count) => {
397 buffer = &buffer[count..];
398 if buffer.is_empty() { return Ok(()); }
399 },
400 Err(error) => return Err(error),
401 };
402 }
403 }
404
405 fn event_write(stream: &mut MioTcpStream, buffer: &[u8]) -> IoResult<Option<usize>> {
406 loop {
407 match stream.write(buffer) {
408 Ok(count) => return Ok(Some(count)),
409 Err(error) => match error.kind() {
410 ErrorKind::Interrupted => (),
411 ErrorKind::WouldBlock => return Ok(None),
412 _ => return Err(error),
413 },
414 }
415 }
416 }
417}
418
419impl Read for TcpStream {
420 fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
421 into_io_result(self.read_timeout(buf, self.timeouts.0))
422 }
423}
424
425impl Write for TcpStream {
426 fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
427 into_io_result(self.write_timeout(buf, self.timeouts.1))
428 }
429
430 fn flush(&mut self) -> IoResult<()> {
431 self.stream.flush()
432 }
433}
434
435impl Drop for TcpStream {
436 fn drop(&mut self) {
437 let context = self.manager.context();
438 Self::deregister(&context, &mut self.stream);
439 }
440}
441
442fn into_io_result<T>(result: Result<T, TcpError>) -> IoResult<T> {
443 match result {
444 Ok(value) => Ok(value),
445 Err(error) => Err(error.into()),
446 }
447}