rdp/core/
mcs.rs

1use core::x224;
2use core::tpkt;
3use model::error::{RdpResult, Error, RdpError, RdpErrorKind};
4use core::gcc::{KeyboardLayout, client_core_data, ClientData, ServerData, client_security_data, client_network_data, block_header, write_conference_create_request, MessageType, read_conference_create_response, Version};
5use model::data::{Trame, to_vec, Message, DataType, U16};
6use nla::asn1::{Sequence, ImplicitTag, OctetString, Enumerate, ASN1Type, Integer, to_der, from_ber};
7use yasna::{Tag};
8use std::io::{Write, Read, BufRead, Cursor};
9use core::per;
10use std::collections::HashMap;
11
/// T.125 DomainMCSPDU choice indices handled by this layer.
///
/// The discriminant is shifted left by two bits by `mcs_pdu_header`
/// to produce the first byte of each MCS domain PDU.
// NOTE(review): every variant seems to be referenced in this file; the
// allow is presumably a leftover — confirm before removing it.
#[allow(dead_code)]
#[repr(u8)]
enum DomainMCSPDU {
    /// Sent by the client after the connect response to create the domain
    ErectDomainRequest = 1,
    /// Disconnection notice (sent by `shutdown`, detected in `read`)
    DisconnectProviderUltimatum = 8,
    /// Client request for a user id
    AttachUserRequest = 10,
    /// Server answer carrying the assigned user id
    AttachUserConfirm = 11,
    /// Client request to join a channel
    ChannelJoinRequest = 14,
    /// Server answer to a channel join request
    ChannelJoinConfirm = 15,
    /// Client -> server channel data
    SendDataRequest = 25,
    /// Server -> client channel data
    SendDataIndication = 26,
}
24
25/// ASN1 structure use by mcs layer
26/// to inform on conference capability
27fn domain_parameters(max_channel_ids: u32, maw_user_ids: u32, max_token_ids: u32,
28                     num_priorities: u32, min_thoughput: u32, max_height: u32,
29                     max_mcs_pdu_size: u32, protocol_version: u32) -> Sequence {
30    sequence![
31        "maxChannelIds" => max_channel_ids,
32        "maxUserIds" => maw_user_ids,
33        "maxTokenIds" => max_token_ids,
34        "numPriorities" => num_priorities,
35        "minThoughput" => min_thoughput,
36        "maxHeight" => max_height,
37        "maxMCSPDUsize" => max_mcs_pdu_size,
38        "protocolVersion" => protocol_version
39    ]
40}
41
42/// First MCS payload send from client to server
43/// Payload send from client to server
44///
45/// http://www.itu.int/rec/T-REC-T.125-199802-I/en page 25
46fn connect_initial(user_data: Option<OctetString>) -> ImplicitTag<Sequence> {
47    ImplicitTag::new(Tag::application(101), sequence![
48        "callingDomainSelector" => vec![1 as u8] as OctetString,
49        "calledDomainSelector" => vec![1 as u8] as OctetString,
50        "upwardFlag" => true,
51        "targetParameters" => domain_parameters(34, 2, 0, 1, 0, 1, 0xffff, 2),
52        "minimumParameters" => domain_parameters(1, 1, 1, 1, 0, 1, 0x420, 2),
53        "maximumParameters" => domain_parameters(0xffff, 0xfc17, 0xffff, 1, 0, 1, 0xffff, 2),
54        "userData" => user_data.unwrap_or(Vec::new())
55    ])
56}
57
58/// Server response with channel capacity
59fn connect_response(user_data: Option<OctetString>) -> ImplicitTag<Sequence> {
60    ImplicitTag::new(Tag::application(102),
61sequence![
62        "result" => 0 as Enumerate,
63        "calledConnectId" => 0 as Integer,
64        "domainParameters" => domain_parameters(22, 3, 0, 1, 0, 1,0xfff8, 2),
65        "userData" => user_data.unwrap_or(Vec::new())
66    ])
67}
68
69/// Create a basic MCS PDU header
70fn mcs_pdu_header(pdu: Option<DomainMCSPDU>, options: Option<u8>) -> u8 {
71    (pdu.unwrap_or(DomainMCSPDU::AttachUserConfirm) as u8) << 2 | options.unwrap_or(0)
72}
73
74/// Read attach user confirm
75/// Client -- attach_user_request -> Server
76/// Client <- attach_user_confirm -- Server
77fn read_attach_user_confirm(buffer: &mut dyn Read) -> RdpResult<u16> {
78    let mut confirm = trame![0 as u8, Vec::<u8>::new()];
79    confirm.read(buffer)?;
80    if cast!(DataType::U8, confirm[0])? >> 2 != mcs_pdu_header(Some(DomainMCSPDU::AttachUserConfirm), None) >> 2 {
81        return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "MCS: unexpected header on recv_attach_user_confirm")));
82    }
83
84    let mut request = Cursor::new(cast!(DataType::Slice, confirm[1])?);
85    if per::read_enumerates(&mut request)? != 0 {
86        return Err(Error::RdpError(RdpError::new(RdpErrorKind::RejectedByServer, "MCS: recv_attach_user_confirm user rejected by server")));
87    }
88    Ok(per::read_integer_16(1001, &mut request)?)
89}
90
91/// Create a session for the current user
92///
93/// Client -- attach_user_request -> Server
94/// Client <- attach_user_confirm -- Server
95fn attach_user_request() -> u8 {
96    mcs_pdu_header(Some(DomainMCSPDU::AttachUserRequest), None)
97}
98
99
100/// Create a new domain for MCS layer
101fn erect_domain_request() -> RdpResult<Trame> {
102    let mut result = Cursor::new(vec![]);
103    per::write_integer(0, &mut result)?;
104    per::write_integer(0, &mut result)?;
105    Ok(trame![
106        mcs_pdu_header(Some(DomainMCSPDU::ErectDomainRequest), None),
107        result.into_inner()
108    ])
109}
110
111/// Ask to join a new channel
112/// /// The MCS will negotiate each channel
113/// channel join confirm is sent by server
114/// to validate or not the channel requested
115/// by the client
116///
117/// Client -- channel_join_request -> Server
118/// Client <- channel_join_confirm -- Server
119fn channel_join_request(user_id: Option<u16>, channel_id: Option<u16>) -> RdpResult<Trame> {
120    Ok(trame![
121        mcs_pdu_header(Some(DomainMCSPDU::ChannelJoinRequest), None),
122        U16::BE(user_id.unwrap_or(1001) - 1001),
123        U16::BE(channel_id.unwrap_or(0))
124    ])
125}
126
127/// Read channel join confirm
128/// The MCS will negotiate each channel
129/// channel join confirm is sent by server
130/// to validate or not the channel requested
131/// by the client
132///
133/// Client -- channel_join_request -> Server
134/// Client <- channel_join_confirm -- Server
135fn read_channel_join_confirm(user_id: u16, channel_id: u16, buffer: &mut dyn Read) -> RdpResult<bool> {
136    let mut confirm = trame![0 as u8, Vec::<u8>::new()];
137    confirm.read(buffer)?;
138    if cast!(DataType::U8, confirm[0])? >> 2 != mcs_pdu_header(Some(DomainMCSPDU::ChannelJoinConfirm), None) >> 2 {
139        return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "MCS: unexpected header on read_channel_join_confirm")));
140    }
141
142    let mut request = Cursor::new(cast!(DataType::Slice, confirm[1])?);
143    let confirm = per::read_enumerates(&mut request)?;
144    let confirm_user_id = per::read_integer_16(1001, &mut request)?;
145    let confirm_channel_id = per::read_integer_16(0, &mut request)?;
146
147    if user_id != confirm_user_id {
148        return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "MCS: read_channel_join_confirm invalid user id")));
149    }
150
151    if channel_id != confirm_channel_id {
152        return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "MCS: read_channel_join_confirm invalid channel_id")));
153    }
154
155    Ok(confirm == 0)
156}
157
158/// MCS client channel
159pub struct Client<S> {
160    /// X224 transport layer
161    x224: x224::Client<S>,
162    /// Server data send during connection step
163    server_data: Option<ServerData>,
164    /// User id session negotiated by the MCS
165    user_id: Option<u16>,
166    /// Map that translate channel name to channel id
167    channel_ids : HashMap<String, u16>
168}
169
170impl<S: Read + Write> Client<S> {
171    pub fn new(x224: x224::Client<S>) -> Self {
172        Client {
173            server_data: None,
174            x224,
175            user_id: None,
176            channel_ids: HashMap::new()
177        }
178    }
179
180    /// Write connection initial payload
181    /// This payload include a lot of
182    /// client specific config parameters
183    fn write_connect_initial(&mut self, screen_width: u16, screen_height: u16, keyboard_layout: KeyboardLayout, client_name: String) -> RdpResult<()> {
184        let client_core_data = client_core_data(Some(ClientData {
185            width: screen_width,
186            height: screen_height,
187            layout: keyboard_layout,
188            server_selected_protocol: self.x224.get_selected_protocols() as u32,
189            rdp_version: Version::RdpVersion5plus,
190            name: client_name
191        }));
192        let client_security_data = client_security_data();
193        let client_network_data = client_network_data(trame![]);
194        let user_data = to_vec(&trame![
195            trame![block_header(Some(MessageType::CsCore), Some(client_core_data.length() as u16)), client_core_data],
196            trame![block_header(Some(MessageType::CsSecurity), Some(client_security_data.length() as u16)), client_security_data],
197            trame![block_header(Some(MessageType::CsNet), Some(client_network_data.length() as u16)), client_network_data]
198        ]);
199        let conference = write_conference_create_request(&user_data)?;
200        self.x224.write(to_der(&connect_initial(Some(conference))))
201    }
202
203    /// Read a connect response comming from server to client
204    fn read_connect_response(&mut self) -> RdpResult<()> {
205        // Now read response from the server
206        let mut connect_response = connect_response(None);
207        let mut payload = try_let!(tpkt::Payload::Raw, self.x224.read()?)?;
208        from_ber(&mut connect_response, payload.fill_buf()?)?;
209
210        // Get server data
211        // Read conference create response
212        let cc_response = cast!(ASN1Type::OctetString, connect_response.inner["userData"])?;
213        self.server_data = Some(read_conference_create_response(&mut Cursor::new(cc_response))?);
214        Ok(())
215    }
216
217    /// Connect the MCS channel
218    /// Ask connection for each channel requested
219    /// and confirmed by server
220    ///
221    /// # Example
222    /// ```rust, ignore
223    /// let mut mcs = mcs::Client(x224);
224    /// mcs.connect(800, 600, KeyboardLayout::French).unwrap()
225    /// ```
226    pub fn connect(&mut self, client_name: String, screen_width: u16, screen_height: u16, keyboard_layout: KeyboardLayout) -> RdpResult<()> {
227        self.write_connect_initial(screen_width, screen_height, keyboard_layout, client_name)?;
228        self.read_connect_response()?;
229        self.x224.write(erect_domain_request()?)?;
230        self.x224.write(attach_user_request())?;
231
232        self.user_id = Some(read_attach_user_confirm(&mut try_let!(tpkt::Payload::Raw, self.x224.read()?)?)?);
233
234        // Add static channel
235        self.channel_ids.insert("global".to_string(), 1003);
236        self.channel_ids.insert("user".to_string(), self.user_id.unwrap());
237
238        // Create list of requested channels
239        // Actually only the two static main channel are requested
240        for channel_id in self.channel_ids.values() {
241            self.x224.write(channel_join_request(self.user_id, Some(*channel_id))?)?;
242            if !read_channel_join_confirm(self.user_id.unwrap(), *channel_id, &mut try_let!(tpkt::Payload::Raw, self.x224.read()?)?)? {
243                println!("Server reject channel id {:?}", channel_id);
244            }
245        }
246
247        Ok(())
248    }
249
250    /// Send a message to a connected channel
251    /// MCS stand for multi channel
252    /// Write function write a message to specific channel
253    ///
254    /// # Example
255    /// ```rust, ignore
256    /// let mut mcs = mcs::Client(x224);
257    /// mcs.connect(800, 600, KeyboardLayout::French).unwrap();
258    /// mcs.write("global".to_string(), trame![U16::LE(0)])
259    /// ```
260    pub fn write<T: 'static>(&mut self, channel_name: &String, message: T) -> RdpResult<()>
261    where T: Message {
262        self.x224.write(trame![
263            mcs_pdu_header(Some(DomainMCSPDU::SendDataRequest), None),
264            U16::BE(self.user_id.unwrap() - 1001),
265            U16::BE(self.channel_ids[channel_name]),
266            0x70 as u8,
267            per::write_length(message.length() as u16)?,
268            message
269        ])
270    }
271
272    /// Receive a message for a specific channel
273    /// Actually by design you can't ask for a specific channel
274    /// the caller need to handle all channels
275    ///
276    /// # Example
277    /// ```rust, ignore
278    /// let mut mcs = mcs::Client(x224);
279    /// mcs.connect(800, 600, KeyboardLayout::French).unwrap();
280    /// let (channel_name, payload) = mcs.read().unwrap();
281    /// match channel_name.as_str() {
282    ///     "global" => println!("main channel");
283    ///     ...
284    /// }
285    /// ```
286    pub fn read(&mut self) -> RdpResult<(String, tpkt::Payload)> {
287        let message = self.x224.read()?;
288        match message {
289            tpkt::Payload::Raw(mut payload) => {
290                 let mut header = mcs_pdu_header(None, None);
291                header.read(&mut payload)?;
292                if header >> 2 == DomainMCSPDU::DisconnectProviderUltimatum as u8 {
293                    return Err(Error::RdpError(RdpError::new(RdpErrorKind::Disconnect, "MCS: Disconnect Provider Ultimatum")));
294                }
295
296                if header >> 2 != DomainMCSPDU::SendDataIndication as u8 {
297                    return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "MCS: Invalid opcode")));
298                }
299
300                // Server user id
301                per::read_integer_16(1001, &mut payload)?;
302
303                let channel_id = per::read_integer_16(0, &mut payload)?;
304                let channel = self.channel_ids.iter().find(|x| *x.1 == channel_id).ok_or(Error::RdpError(RdpError::new(RdpErrorKind::Unknown, "MCS: unknown channel")))?;
305
306                per::read_enumerates(&mut payload)?;
307                per::read_length(&mut payload)?;
308
309                Ok((channel.0.clone(), tpkt::Payload::Raw(payload)))
310            },
311            tpkt::Payload::FastPath(sec_flag, payload) => {
312                // fastpath packet are dedicated to global channel
313                Ok(("global".to_string(), tpkt::Payload::FastPath(sec_flag, payload)))
314            }
315        }
316
317    }
318
319    /// Send a close event to server
320    pub fn shutdown(&mut self) -> RdpResult<()> {
321        self.x224.write(trame![
322            mcs_pdu_header(Some(DomainMCSPDU::DisconnectProviderUltimatum), Some(1)),
323            per::write_enumerates(0x80)?,
324            b"\x00\x00\x00\x00\x00\x00".to_vec()
325        ])?;
326        self.x224.shutdown()
327    }
328
329    /// This function check if the client
330    /// version protocol choose is 5+
331    pub fn is_rdp_version_5_plus(&self) -> bool {
332        self.server_data.as_ref().unwrap().rdp_version == Version::RdpVersion5plus
333    }
334
335    /// Getter of the user id negotiated during connection steps
336    pub fn get_user_id(&self) -> u16 {
337        self.user_id.unwrap()
338    }
339
340    /// Getter of the global channel id
341    pub fn get_global_channel_id(&self) -> u16 {
342        self.channel_ids["global"]
343    }
344}
345
#[cfg(test)]
mod test {
    use super::*;

    /// Header 46 (AttachUserConfirm << 2), result 0 and user id offset 3
    /// decode to user id 1004
    #[test]
    fn test_read_attach_user_confirm() {
        let mut stream = Cursor::new(vec![46u8, 0, 0, 3]);
        let user_id = read_attach_user_confirm(&mut stream).unwrap();
        assert_eq!(user_id, 1004);
    }

    /// The attach user request is the single header byte 10 << 2
    #[test]
    fn test_attach_user_request() {
        let expected: u8 = 40;
        assert_eq!(attach_user_request(), expected);
    }

    /// Header byte 4 (1 << 2) followed by two PER-encoded zero integers
    #[test]
    fn test_erect_domain_request() {
        let encoded = to_vec(&erect_domain_request().unwrap());
        assert_eq!(encoded, [4, 1, 0, 1, 0]);
    }

    /// With default arguments: header 56 (14 << 2) and two zero u16 values
    #[test]
    fn test_channel_join_request() {
        let encoded = to_vec(&channel_join_request(None, None).unwrap());
        assert_eq!(encoded, [56, 0, 0, 0, 0]);
    }

    /// DER encoding of a DomainParameters sequence with small integers
    #[test]
    fn test_domain_parameters() {
        let encoded = to_der(&domain_parameters(1, 2, 3, 4, 5, 6, 7, 8));
        let expected: Vec<u8> = vec![48, 24, 2, 1, 1, 2, 1, 2, 2, 1, 3, 2, 1, 4, 2, 1, 5, 2, 1, 6, 2, 1, 7, 2, 1, 8];
        assert_eq!(encoded, expected);
    }

    /// DER encoding of a Connect-Initial carrying user data [1, 2, 3]
    #[test]
    fn test_connect_initial() {
        let encoded = to_der(&connect_initial(Some(vec![1, 2, 3])));
        let expected: Vec<u8> = vec![127, 101, 103, 4, 1, 1, 4, 1, 1, 1, 1, 255, 48, 26, 2, 1, 34, 2, 1, 2, 2, 1, 0, 2, 1, 1, 2, 1, 0, 2, 1, 1, 2, 3, 0, 255, 255, 2, 1, 2, 48, 25, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 0, 2, 1, 1, 2, 2, 4, 32, 2, 1, 2, 48, 32, 2, 3, 0, 255, 255, 2, 3, 0, 252, 23, 2, 3, 0, 255, 255, 2, 1, 1, 2, 1, 0, 2, 1, 1, 2, 3, 0, 255, 255, 2, 1, 2, 4, 3, 1, 2, 3];
        assert_eq!(encoded, expected);
    }

    /// DER encoding of a Connect-Response carrying user data [1, 2, 3]
    #[test]
    fn test_connect_response() {
        let encoded = to_der(&connect_response(Some(vec![1, 2, 3])));
        let expected: Vec<u8> = vec![127, 102, 39, 10, 1, 0, 2, 1, 0, 48, 26, 2, 1, 22, 2, 1, 3, 2, 1, 0, 2, 1, 1, 2, 1, 0, 2, 1, 1, 2, 3, 0, 255, 248, 2, 1, 2, 4, 3, 1, 2, 3];
        assert_eq!(encoded, expected);
    }
}