cold_io/
managed_stream.rs

1// Copyright 2021 Vladislav Melnik
2// SPDX-License-Identifier: MIT
3
4use std::{
5    io::{self, Read, Write},
6    rc::{Rc, Weak},
7    cell::{RefCell, RefMut},
8    net::Shutdown,
9};
10use mio::{Token, Interest, net::TcpStream};
11use super::{
12    marked_stream::MarkedStream,
13    proposal::{ReadOnce, WriteOnce, IoResult},
14};
15
16pub struct ManagedStream {
17    inner: Rc<RefCell<MarkedStream>>,
18    token: Token,
19}
20
21impl ManagedStream {
22    pub fn new(stream: TcpStream, token: Token) -> Self {
23        ManagedStream {
24            inner: Rc::new(RefCell::new(MarkedStream {
25                stream,
26                reader: false,
27                reader_discarded: false,
28                reader_used: false,
29                writer: false,
30                writer_discarded: false,
31                writer_used: false,
32            })),
33            token,
34        }
35    }
36
37    pub fn write_once(&self) -> Option<TcpWriteOnce> {
38        let mut s = self.inner.borrow_mut();
39        if !s.writer && !s.writer_discarded {
40            s.writer = true;
41            Some(TcpWriteOnce(Rc::downgrade(&self.inner)))
42        } else {
43            None
44        }
45    }
46
47    pub fn read_once(&self) -> Option<TcpReadOnce> {
48        let mut s = self.inner.borrow_mut();
49        if !s.reader && !s.reader_discarded {
50            s.reader = true;
51            Some(TcpReadOnce(Rc::downgrade(&self.inner)))
52        } else {
53            None
54        }
55    }
56
57    pub fn discard(self) -> io::Result<()> {
58        let mut s = self.inner.borrow_mut();
59        s.reader_discarded = true;
60        s.writer_discarded = true;
61        s.as_mut().shutdown(Shutdown::Both)
62    }
63
64    pub fn borrow_mut(&self) -> RefMut<MarkedStream> {
65        self.inner.as_ref().borrow_mut()
66    }
67
68    pub fn token(&self) -> Token {
69        self.token
70    }
71
72    pub fn closed(&self) -> bool {
73        let s = self.inner.borrow();
74        s.reader_discarded && s.writer_discarded
75    }
76
77    pub fn set_read_closed(&self) {
78        self.borrow_mut().reader_discarded = true;
79    }
80
81    pub fn set_write_closed(&self) {
82        self.borrow_mut().writer_discarded = true;
83    }
84
85    pub fn interests(&self) -> Option<Interest> {
86        let s = self.inner.borrow();
87        let read = !s.reader && !s.reader_discarded;
88        let write = !s.writer && !s.writer_discarded;
89        match (read, write) {
90            (true, true) => Some(Interest::READABLE | Interest::WRITABLE),
91            (true, false) => Some(Interest::READABLE),
92            (false, true) => Some(Interest::WRITABLE),
93            (false, false) => None,
94        }
95    }
96}
97
98pub struct TcpWriteOnce(Weak<RefCell<MarkedStream>>);
99
100impl WriteOnce for TcpWriteOnce {
101    fn write(self, data: &[u8]) -> IoResult {
102        if let Some(s) = self.0.upgrade() {
103            let mut s = s.borrow_mut();
104            let will_close = s.writer_discarded;
105            s.writer_used = true;
106            match s.as_mut().write(data) {
107                Ok(length) => IoResult::Done { length, will_close },
108                Err(error) => {
109                    log::error!("io error: {}", error);
110                    match error.kind() {
111                        io::ErrorKind::NotConnected => IoResult::Closed,
112                        io::ErrorKind::WouldBlock => IoResult::Done {
113                            length: 0,
114                            will_close,
115                        },
116                        _ => IoResult::Closed,
117                    }
118                },
119            }
120        } else {
121            IoResult::Closed
122        }
123    }
124}
125
126impl Drop for TcpWriteOnce {
127    fn drop(&mut self) {
128        if let Some(s) = self.0.upgrade() {
129            let mut s = s.borrow_mut();
130            s.writer_discarded = !s.writer_used;
131            s.writer_used = false;
132            s.writer = false;
133            if let Err(error) = s.as_mut().shutdown(Shutdown::Write) {
134                // it is expected the socket is not connected,
135                // don't report this case
136                if !matches!(error.kind(), io::ErrorKind::NotConnected) {
137                    log::error!("io error: {}", error);
138                }
139            }
140        }
141    }
142}
143
144#[must_use = "discard it if don't need"]
145pub struct TcpReadOnce(Weak<RefCell<MarkedStream>>);
146
147impl ReadOnce for TcpReadOnce {
148    fn read(self, buf: &mut [u8]) -> IoResult {
149        if let Some(s) = self.0.upgrade() {
150            let mut s = s.borrow_mut();
151            let will_close = s.reader_discarded;
152            s.reader_used = true;
153            match s.as_mut().read(buf) {
154                Ok(length) => IoResult::Done { length, will_close },
155                Err(error) => {
156                    log::error!("io error: {}", error);
157                    match error.kind() {
158                        io::ErrorKind::NotConnected => IoResult::Closed,
159                        io::ErrorKind::WouldBlock => IoResult::Done {
160                            length: 0,
161                            will_close,
162                        },
163                        _ => IoResult::Closed,
164                    }
165                },
166            }
167        } else {
168            IoResult::Closed
169        }
170    }
171}
172
173impl Drop for TcpReadOnce {
174    fn drop(&mut self) {
175        if let Some(s) = self.0.upgrade() {
176            let mut s = s.borrow_mut();
177            s.reader_discarded = !s.reader_used;
178            s.reader_used = false;
179            s.reader = false;
180            if let Err(error) = s.as_mut().shutdown(Shutdown::Read) {
181                // it is expected the socket is not connected,
182                // don't report this case
183                if !matches!(error.kind(), io::ErrorKind::NotConnected) {
184                    log::error!("io error: {}", error);
185                }
186            }
187        }
188    }
189}