rdp/model/link.rs
1extern crate native_tls;
2
3use model::error::{RdpResult, Error, RdpError, RdpErrorKind};
4use std::io::{Cursor, Read, Write};
5use self::native_tls::{TlsConnector, TlsStream, Certificate};
6use model::data::{Message};
7
8/// This a wrapper to work equals
9/// for a stream and a TLS stream
10pub enum Stream<S> {
11 /// Raw stream that implement Read + Write
12 Raw(S),
13 /// TLS Stream
14 Ssl(TlsStream<S>)
15}
16
17impl<S: Read + Write> Stream<S> {
18 /// Read exactly the number of bytes present in buffer
19 ///
20 /// # Example
21 /// ```
22 /// use rdp::model::link::Stream;
23 /// use std::io::Cursor;
24 /// let mut s = Stream::Raw(Cursor::new(vec![1, 2, 3]));
25 /// let mut result = [0, 0];
26 /// s.read_exact(&mut result).unwrap();
27 /// assert_eq!(result, [1, 2])
28 /// ```
29 pub fn read_exact(&mut self, buf: &mut[u8]) -> RdpResult<()> {
30 match self {
31 Stream::Raw(e) => e.read_exact(buf)?,
32 Stream::Ssl(e) => e.read_exact(buf)?
33 };
34 Ok(())
35 }
36
37 /// Read all available buffer
38 ///
39 /// # Example
40 /// ```
41 /// use rdp::model::link::Stream;
42 /// use std::io::Cursor;
43 /// let mut s = Stream::Raw(Cursor::new(vec![1, 2, 3]));
44 /// let mut result = [0, 0, 0, 0];
45 /// s.read(&mut result).unwrap();
46 /// assert_eq!(result, [1, 2, 3, 0])
47 /// ```
48 pub fn read(&mut self, buf: &mut[u8]) -> RdpResult<usize> {
49 match self {
50 Stream::Raw(e) => Ok(e.read(buf)?),
51 Stream::Ssl(e) => Ok(e.read(buf)?)
52 }
53 }
54
55 /// Write all buffer to the stream
56 ///
57 /// # Example
58 /// ```
59 /// use rdp::model::link::Stream;
60 /// use std::io::Cursor;
61 /// let mut s = Stream::Raw(Cursor::new(vec![]));
62 /// let result = [1, 2, 3, 4];
63 /// s.write(&result).unwrap();
64 /// if let Stream::Raw(r) = s {
65 /// assert_eq!(r.into_inner(), [1, 2, 3, 4])
66 /// }
67 /// else {
68 /// panic!("invalid")
69 /// }
70 /// ```
71 pub fn write(&mut self, buffer: &[u8]) -> RdpResult<usize> {
72 Ok(match self {
73 Stream::Raw(e) => e.write(buffer)?,
74 Stream::Ssl(e) => e.write(buffer)?
75 })
76 }
77
78 /// Shutdown the stream
79 /// Only works when stream is a SSL stream
80 pub fn shutdown(&mut self) -> RdpResult<()> {
81 Ok(match self {
82 Stream::Ssl(e) => e.shutdown()?,
83 _ => ()
84 })
85 }
86}
87
88/// Link layer is a wrapper around TCP or SSL stream
89/// It can swicth from TCP to SSL
90pub struct Link<S> {
91 stream: Stream<S>
92}
93
94impl<S: Read + Write> Link<S> {
95 /// Create a new link layer from a Stream
96 ///
97 /// # Example
98 /// ```no_run
99 /// use rdp::model::link::{Link, Stream};
100 /// use std::io::Cursor;
101 /// use std::net::{TcpStream, SocketAddr};
102 /// let link = Link::new(Stream::Raw(Cursor::new(vec![])));
103 /// let addr = "127.0.0.1:3389".parse::<SocketAddr>().unwrap();
104 /// let link_tcp = Link::new(Stream::Raw(TcpStream::connect(&addr).unwrap()));
105 /// ```
106 pub fn new(stream: Stream<S>) -> Self {
107 Link {
108 stream
109 }
110 }
111
112 /// This method is designed to write a Message
113 /// either for TCP or SSL stream
114 ///
115 /// # Example
116 /// ```
117 /// # #[macro_use]
118 /// # extern crate rdp;
119 /// # use rdp::model::data::{Component, U32};
120 /// # use rdp::model::link::{Link, Stream};
121 /// # use std::io::Cursor;
122 /// # fn main() {
123 /// let mut link = Link::new(Stream::Raw(Cursor::new(vec![])));
124 /// link.write(&component![
125 /// "foo" => U32::LE(1)
126 /// ]).unwrap();
127 ///
128 /// if let Stream::Raw(r) = link.get_stream() {
129 /// assert_eq!(r.into_inner(), [1, 0, 0, 0])
130 /// }
131 /// else {
132 /// panic!("invalid")
133 /// }
134 /// # }
135 /// ```
136 pub fn write(&mut self, message: &dyn Message) -> RdpResult<()> {
137 let mut buffer = Cursor::new(Vec::new());
138 message.write(&mut buffer)?;
139 self.stream.write(buffer.into_inner().as_slice())?;
140 Ok(())
141 }
142
143 /// This function will block until the expected size will be read
144 ///
145 /// # Example
146 /// ```
147 /// use rdp::model::link::{Link, Stream};
148 /// use std::io::Cursor;
149 /// let mut link = Link::new(Stream::Raw(Cursor::new(vec![0, 1, 2])));
150 /// assert_eq!(link.read(2).unwrap(), [0, 1])
151 /// ```
152 pub fn read(&mut self, expected_size: usize) -> RdpResult<Vec<u8>> {
153 if expected_size == 0 {
154 let mut buffer = vec![0; 1500];
155 let size = self.stream.read(&mut buffer)?;
156 buffer.resize(size, 0);
157 Ok(buffer)
158 }
159 else {
160 let mut buffer = vec![0; expected_size];
161 self.stream.read_exact(&mut buffer)?;
162 Ok(buffer)
163 }
164 }
165
166 /// Start a ssl connection from a raw stream
167 ///
168 /// # Example
169 /// ```no_run
170 /// use rdp::model::link::{Link, Stream};
171 /// use std::net::{TcpStream, SocketAddr};
172 /// let addr = "127.0.0.1:3389".parse::<SocketAddr>().unwrap();
173 /// let link_tcp = Link::new(Stream::Raw(TcpStream::connect(&addr).unwrap()));
174 /// let link_ssl = link_tcp.start_ssl(false).unwrap();
175 /// ```
176 pub fn start_ssl(self, check_certificate: bool) -> RdpResult<Link<S>> {
177 let mut builder = TlsConnector::builder();
178 builder.danger_accept_invalid_certs(!check_certificate);
179 builder.use_sni(false);
180
181 let connector = builder.build()?;
182
183 if let Stream::Raw(stream) = self.stream {
184 return Ok(Link::new(Stream::Ssl(connector.connect("", stream)?)))
185 }
186 Err(Error::RdpError(RdpError::new(RdpErrorKind::NotImplemented, "start_ssl on ssl stream is forbidden")))
187 }
188
189 /// Retrive the peer certificate
190 /// Use by the NLA authentication protocol
191 /// to avoid MITM attack
192 /// # Example
193 /// ```no_run
194 /// use rdp::model::link::{Link, Stream};
195 /// use std::net::{TcpStream, SocketAddr};
196 /// let addr = "127.0.0.1:3389".parse::<SocketAddr>().unwrap();
197 /// let link_tcp = Link::new(Stream::Raw(TcpStream::connect(&addr).unwrap()));
198 /// let link_ssl = link_tcp.start_ssl(false).unwrap();
199 /// let certificate = link_ssl.get_peer_certificate().unwrap().unwrap();
200 /// ```
201 pub fn get_peer_certificate(&self) -> RdpResult<Option<Certificate>> {
202 if let Stream::Ssl(stream) = &self.stream {
203 Ok(stream.peer_certificate()?)
204 }
205 else {
206 Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "get peer certificate on non ssl link is impossible")))
207 }
208 }
209
210 /// Close the stream
211 /// Only works on SSL Stream
212 pub fn shutdown(&mut self) -> RdpResult<()> {
213 self.stream.shutdown()
214 }
215
216 #[cfg(feature = "integration")]
217 pub fn get_stream(self) -> Stream<S> {
218 self.stream
219 }
220}