rdp/model/
link.rs

1extern crate native_tls;
2
3use model::error::{RdpResult, Error, RdpError, RdpErrorKind};
4use std::io::{Cursor, Read, Write};
5use self::native_tls::{TlsConnector, TlsStream, Certificate};
6use model::data::{Message};
7
8/// This a wrapper to work equals
9/// for a stream and a TLS stream
10pub enum Stream<S> {
11    /// Raw stream that implement Read + Write
12    Raw(S),
13    /// TLS Stream
14    Ssl(TlsStream<S>)
15}
16
17impl<S: Read + Write> Stream<S> {
18    /// Read exactly the number of bytes present in buffer
19    ///
20    /// # Example
21    /// ```
22    /// use rdp::model::link::Stream;
23    /// use std::io::Cursor;
24    /// let mut s = Stream::Raw(Cursor::new(vec![1, 2, 3]));
25    /// let mut result = [0, 0];
26    /// s.read_exact(&mut result).unwrap();
27    /// assert_eq!(result, [1, 2])
28    /// ```
29    pub fn read_exact(&mut self, buf: &mut[u8]) -> RdpResult<()> {
30        match self {
31            Stream::Raw(e) => e.read_exact(buf)?,
32            Stream::Ssl(e) => e.read_exact(buf)?
33        };
34        Ok(())
35    }
36
37    /// Read all available buffer
38    ///
39    /// # Example
40    /// ```
41    /// use rdp::model::link::Stream;
42    /// use std::io::Cursor;
43    /// let mut s = Stream::Raw(Cursor::new(vec![1, 2, 3]));
44    /// let mut result = [0, 0, 0, 0];
45    /// s.read(&mut result).unwrap();
46    /// assert_eq!(result, [1, 2, 3, 0])
47    /// ```
48    pub fn read(&mut self, buf: &mut[u8]) -> RdpResult<usize> {
49        match self {
50            Stream::Raw(e) => Ok(e.read(buf)?),
51            Stream::Ssl(e) => Ok(e.read(buf)?)
52        }
53    }
54
55    /// Write all buffer to the stream
56    ///
57    /// # Example
58    /// ```
59    /// use rdp::model::link::Stream;
60    /// use std::io::Cursor;
61    /// let mut s = Stream::Raw(Cursor::new(vec![]));
62    /// let result = [1, 2, 3, 4];
63    /// s.write(&result).unwrap();
64    /// if let Stream::Raw(r) = s {
65    ///     assert_eq!(r.into_inner(), [1, 2, 3, 4])
66    /// }
67    /// else {
68    ///     panic!("invalid")
69    /// }
70    /// ```
71    pub fn write(&mut self, buffer: &[u8]) -> RdpResult<usize> {
72        Ok(match self {
73            Stream::Raw(e) => e.write(buffer)?,
74            Stream::Ssl(e) => e.write(buffer)?
75        })
76    }
77
78    /// Shutdown the stream
79    /// Only works when stream is a SSL stream
80    pub fn shutdown(&mut self) -> RdpResult<()> {
81        Ok(match self {
82            Stream::Ssl(e) => e.shutdown()?,
83            _ => ()
84        })
85    }
86}
87
88/// Link layer is a wrapper around TCP or SSL stream
89/// It can swicth from TCP to SSL
90pub struct Link<S> {
91    stream: Stream<S>
92}
93
94impl<S: Read + Write> Link<S> {
95    /// Create a new link layer from a Stream
96    ///
97    /// # Example
98    /// ```no_run
99    /// use rdp::model::link::{Link, Stream};
100    /// use std::io::Cursor;
101    /// use std::net::{TcpStream, SocketAddr};
102    /// let link = Link::new(Stream::Raw(Cursor::new(vec![])));
103    /// let addr = "127.0.0.1:3389".parse::<SocketAddr>().unwrap();
104    /// let link_tcp = Link::new(Stream::Raw(TcpStream::connect(&addr).unwrap()));
105    /// ```
106    pub fn new(stream: Stream<S>) -> Self {
107        Link {
108            stream
109        }
110    }
111
112    /// This method is designed to write a Message
113    /// either for TCP or SSL stream
114    ///
115    /// # Example
116    /// ```
117    /// # #[macro_use]
118    /// # extern crate rdp;
119    /// # use rdp::model::data::{Component, U32};
120    /// # use rdp::model::link::{Link, Stream};
121    /// # use std::io::Cursor;
122    /// # fn main() {
123    ///     let mut link = Link::new(Stream::Raw(Cursor::new(vec![])));
124    ///     link.write(&component![
125    ///         "foo" => U32::LE(1)
126    ///     ]).unwrap();
127    ///
128    ///     if let Stream::Raw(r) = link.get_stream() {
129    ///         assert_eq!(r.into_inner(), [1, 0, 0, 0])
130    ///     }
131    ///     else {
132    ///         panic!("invalid")
133    ///     }
134    /// # }
135    /// ```
136    pub fn write(&mut self, message: &dyn Message) -> RdpResult<()> {
137        let mut buffer = Cursor::new(Vec::new());
138        message.write(&mut buffer)?;
139        self.stream.write(buffer.into_inner().as_slice())?;
140        Ok(())
141    }
142
143    /// This function will block until the expected size will be read
144    ///
145    /// # Example
146    /// ```
147    /// use rdp::model::link::{Link, Stream};
148    /// use std::io::Cursor;
149    /// let mut link = Link::new(Stream::Raw(Cursor::new(vec![0, 1, 2])));
150    /// assert_eq!(link.read(2).unwrap(), [0, 1])
151    /// ```
152    pub fn read(&mut self, expected_size: usize) -> RdpResult<Vec<u8>> {
153        if expected_size == 0 {
154            let mut buffer = vec![0; 1500];
155            let size = self.stream.read(&mut buffer)?;
156            buffer.resize(size, 0);
157            Ok(buffer)
158        }
159        else {
160            let mut buffer = vec![0; expected_size];
161            self.stream.read_exact(&mut buffer)?;
162            Ok(buffer)
163        }
164    }
165
166    /// Start a ssl connection from a raw stream
167    ///
168    /// # Example
169    /// ```no_run
170    /// use rdp::model::link::{Link, Stream};
171    /// use std::net::{TcpStream, SocketAddr};
172    /// let addr = "127.0.0.1:3389".parse::<SocketAddr>().unwrap();
173    /// let link_tcp = Link::new(Stream::Raw(TcpStream::connect(&addr).unwrap()));
174    /// let link_ssl = link_tcp.start_ssl(false).unwrap();
175    /// ```
176    pub fn start_ssl(self, check_certificate: bool) -> RdpResult<Link<S>> {
177        let mut builder = TlsConnector::builder();
178        builder.danger_accept_invalid_certs(!check_certificate);
179        builder.use_sni(false);
180
181        let connector = builder.build()?;
182
183        if let Stream::Raw(stream) = self.stream {
184            return Ok(Link::new(Stream::Ssl(connector.connect("", stream)?)))
185        }
186        Err(Error::RdpError(RdpError::new(RdpErrorKind::NotImplemented, "start_ssl on ssl stream is forbidden")))
187    }
188
189    /// Retrive the peer certificate
190    /// Use by the NLA authentication protocol
191    /// to avoid MITM attack
192    /// # Example
193    /// ```no_run
194    /// use rdp::model::link::{Link, Stream};
195    /// use std::net::{TcpStream, SocketAddr};
196    /// let addr = "127.0.0.1:3389".parse::<SocketAddr>().unwrap();
197    /// let link_tcp = Link::new(Stream::Raw(TcpStream::connect(&addr).unwrap()));
198    /// let link_ssl = link_tcp.start_ssl(false).unwrap();
199    /// let certificate = link_ssl.get_peer_certificate().unwrap().unwrap();
200    /// ```
201    pub fn get_peer_certificate(&self) -> RdpResult<Option<Certificate>> {
202        if let Stream::Ssl(stream) = &self.stream {
203            Ok(stream.peer_certificate()?)
204        }
205        else {
206            Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "get peer certificate on non ssl link is impossible")))
207        }
208    }
209
210    /// Close the stream
211    /// Only works on SSL Stream
212    pub fn shutdown(&mut self) -> RdpResult<()> {
213        self.stream.shutdown()
214    }
215
216    #[cfg(feature = "integration")]
217    pub fn get_stream(self) -> Stream<S> {
218        self.stream
219    }
220}