1use core::x224;
2use core::tpkt;
3use model::error::{RdpResult, Error, RdpError, RdpErrorKind};
4use core::gcc::{KeyboardLayout, client_core_data, ClientData, ServerData, client_security_data, client_network_data, block_header, write_conference_create_request, MessageType, read_conference_create_response, Version};
5use model::data::{Trame, to_vec, Message, DataType, U16};
6use nla::asn1::{Sequence, ImplicitTag, OctetString, Enumerate, ASN1Type, Integer, to_der, from_ber};
7use yasna::{Tag};
8use std::io::{Write, Read, BufRead, Cursor};
9use core::per;
10use std::collections::HashMap;
11
/// DomainMCSPDU choice indices used by this module.
///
/// `mcs_pdu_header` shifts the discriminant left by two bits to build the
/// first byte of a domain PDU, and the readers compare `byte >> 2` against it.
// NOTE(review): every variant appears to be referenced in this file; confirm
// whether this allow is still needed.
#[allow(dead_code)]
#[repr(u8)]
enum DomainMCSPDU {
    /// Sent by the client after the connect response.
    ErectDomainRequest = 1,
    /// Treated by `Client::read` as a disconnect.
    DisconnectProviderUltimatum = 8,
    /// Client request for a user id.
    AttachUserRequest = 10,
    /// Server reply carrying the user id.
    AttachUserConfirm = 11,
    /// Client request to join a channel.
    ChannelJoinRequest = 14,
    /// Server reply to a channel join.
    ChannelJoinConfirm = 15,
    /// Client-to-server data on a channel.
    SendDataRequest = 25,
    /// Server-to-client data on a channel.
    SendDataIndication = 26,
}
24
25fn domain_parameters(max_channel_ids: u32, maw_user_ids: u32, max_token_ids: u32,
28 num_priorities: u32, min_thoughput: u32, max_height: u32,
29 max_mcs_pdu_size: u32, protocol_version: u32) -> Sequence {
30 sequence![
31 "maxChannelIds" => max_channel_ids,
32 "maxUserIds" => maw_user_ids,
33 "maxTokenIds" => max_token_ids,
34 "numPriorities" => num_priorities,
35 "minThoughput" => min_thoughput,
36 "maxHeight" => max_height,
37 "maxMCSPDUsize" => max_mcs_pdu_size,
38 "protocolVersion" => protocol_version
39 ]
40}
41
42fn connect_initial(user_data: Option<OctetString>) -> ImplicitTag<Sequence> {
47 ImplicitTag::new(Tag::application(101), sequence![
48 "callingDomainSelector" => vec![1 as u8] as OctetString,
49 "calledDomainSelector" => vec![1 as u8] as OctetString,
50 "upwardFlag" => true,
51 "targetParameters" => domain_parameters(34, 2, 0, 1, 0, 1, 0xffff, 2),
52 "minimumParameters" => domain_parameters(1, 1, 1, 1, 0, 1, 0x420, 2),
53 "maximumParameters" => domain_parameters(0xffff, 0xfc17, 0xffff, 1, 0, 1, 0xffff, 2),
54 "userData" => user_data.unwrap_or(Vec::new())
55 ])
56}
57
58fn connect_response(user_data: Option<OctetString>) -> ImplicitTag<Sequence> {
60 ImplicitTag::new(Tag::application(102),
61sequence![
62 "result" => 0 as Enumerate,
63 "calledConnectId" => 0 as Integer,
64 "domainParameters" => domain_parameters(22, 3, 0, 1, 0, 1,0xfff8, 2),
65 "userData" => user_data.unwrap_or(Vec::new())
66 ])
67}
68
69fn mcs_pdu_header(pdu: Option<DomainMCSPDU>, options: Option<u8>) -> u8 {
71 (pdu.unwrap_or(DomainMCSPDU::AttachUserConfirm) as u8) << 2 | options.unwrap_or(0)
72}
73
74fn read_attach_user_confirm(buffer: &mut dyn Read) -> RdpResult<u16> {
78 let mut confirm = trame![0 as u8, Vec::<u8>::new()];
79 confirm.read(buffer)?;
80 if cast!(DataType::U8, confirm[0])? >> 2 != mcs_pdu_header(Some(DomainMCSPDU::AttachUserConfirm), None) >> 2 {
81 return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "MCS: unexpected header on recv_attach_user_confirm")));
82 }
83
84 let mut request = Cursor::new(cast!(DataType::Slice, confirm[1])?);
85 if per::read_enumerates(&mut request)? != 0 {
86 return Err(Error::RdpError(RdpError::new(RdpErrorKind::RejectedByServer, "MCS: recv_attach_user_confirm user rejected by server")));
87 }
88 Ok(per::read_integer_16(1001, &mut request)?)
89}
90
91fn attach_user_request() -> u8 {
96 mcs_pdu_header(Some(DomainMCSPDU::AttachUserRequest), None)
97}
98
99
100fn erect_domain_request() -> RdpResult<Trame> {
102 let mut result = Cursor::new(vec![]);
103 per::write_integer(0, &mut result)?;
104 per::write_integer(0, &mut result)?;
105 Ok(trame![
106 mcs_pdu_header(Some(DomainMCSPDU::ErectDomainRequest), None),
107 result.into_inner()
108 ])
109}
110
111fn channel_join_request(user_id: Option<u16>, channel_id: Option<u16>) -> RdpResult<Trame> {
120 Ok(trame![
121 mcs_pdu_header(Some(DomainMCSPDU::ChannelJoinRequest), None),
122 U16::BE(user_id.unwrap_or(1001) - 1001),
123 U16::BE(channel_id.unwrap_or(0))
124 ])
125}
126
127fn read_channel_join_confirm(user_id: u16, channel_id: u16, buffer: &mut dyn Read) -> RdpResult<bool> {
136 let mut confirm = trame![0 as u8, Vec::<u8>::new()];
137 confirm.read(buffer)?;
138 if cast!(DataType::U8, confirm[0])? >> 2 != mcs_pdu_header(Some(DomainMCSPDU::ChannelJoinConfirm), None) >> 2 {
139 return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "MCS: unexpected header on read_channel_join_confirm")));
140 }
141
142 let mut request = Cursor::new(cast!(DataType::Slice, confirm[1])?);
143 let confirm = per::read_enumerates(&mut request)?;
144 let confirm_user_id = per::read_integer_16(1001, &mut request)?;
145 let confirm_channel_id = per::read_integer_16(0, &mut request)?;
146
147 if user_id != confirm_user_id {
148 return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "MCS: read_channel_join_confirm invalid user id")));
149 }
150
151 if channel_id != confirm_channel_id {
152 return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "MCS: read_channel_join_confirm invalid channel_id")));
153 }
154
155 Ok(confirm == 0)
156}
157
158pub struct Client<S> {
160 x224: x224::Client<S>,
162 server_data: Option<ServerData>,
164 user_id: Option<u16>,
166 channel_ids : HashMap<String, u16>
168}
169
170impl<S: Read + Write> Client<S> {
171 pub fn new(x224: x224::Client<S>) -> Self {
172 Client {
173 server_data: None,
174 x224,
175 user_id: None,
176 channel_ids: HashMap::new()
177 }
178 }
179
180 fn write_connect_initial(&mut self, screen_width: u16, screen_height: u16, keyboard_layout: KeyboardLayout, client_name: String) -> RdpResult<()> {
184 let client_core_data = client_core_data(Some(ClientData {
185 width: screen_width,
186 height: screen_height,
187 layout: keyboard_layout,
188 server_selected_protocol: self.x224.get_selected_protocols() as u32,
189 rdp_version: Version::RdpVersion5plus,
190 name: client_name
191 }));
192 let client_security_data = client_security_data();
193 let client_network_data = client_network_data(trame![]);
194 let user_data = to_vec(&trame![
195 trame![block_header(Some(MessageType::CsCore), Some(client_core_data.length() as u16)), client_core_data],
196 trame![block_header(Some(MessageType::CsSecurity), Some(client_security_data.length() as u16)), client_security_data],
197 trame![block_header(Some(MessageType::CsNet), Some(client_network_data.length() as u16)), client_network_data]
198 ]);
199 let conference = write_conference_create_request(&user_data)?;
200 self.x224.write(to_der(&connect_initial(Some(conference))))
201 }
202
203 fn read_connect_response(&mut self) -> RdpResult<()> {
205 let mut connect_response = connect_response(None);
207 let mut payload = try_let!(tpkt::Payload::Raw, self.x224.read()?)?;
208 from_ber(&mut connect_response, payload.fill_buf()?)?;
209
210 let cc_response = cast!(ASN1Type::OctetString, connect_response.inner["userData"])?;
213 self.server_data = Some(read_conference_create_response(&mut Cursor::new(cc_response))?);
214 Ok(())
215 }
216
217 pub fn connect(&mut self, client_name: String, screen_width: u16, screen_height: u16, keyboard_layout: KeyboardLayout) -> RdpResult<()> {
227 self.write_connect_initial(screen_width, screen_height, keyboard_layout, client_name)?;
228 self.read_connect_response()?;
229 self.x224.write(erect_domain_request()?)?;
230 self.x224.write(attach_user_request())?;
231
232 self.user_id = Some(read_attach_user_confirm(&mut try_let!(tpkt::Payload::Raw, self.x224.read()?)?)?);
233
234 self.channel_ids.insert("global".to_string(), 1003);
236 self.channel_ids.insert("user".to_string(), self.user_id.unwrap());
237
238 for channel_id in self.channel_ids.values() {
241 self.x224.write(channel_join_request(self.user_id, Some(*channel_id))?)?;
242 if !read_channel_join_confirm(self.user_id.unwrap(), *channel_id, &mut try_let!(tpkt::Payload::Raw, self.x224.read()?)?)? {
243 println!("Server reject channel id {:?}", channel_id);
244 }
245 }
246
247 Ok(())
248 }
249
250 pub fn write<T: 'static>(&mut self, channel_name: &String, message: T) -> RdpResult<()>
261 where T: Message {
262 self.x224.write(trame![
263 mcs_pdu_header(Some(DomainMCSPDU::SendDataRequest), None),
264 U16::BE(self.user_id.unwrap() - 1001),
265 U16::BE(self.channel_ids[channel_name]),
266 0x70 as u8,
267 per::write_length(message.length() as u16)?,
268 message
269 ])
270 }
271
272 pub fn read(&mut self) -> RdpResult<(String, tpkt::Payload)> {
287 let message = self.x224.read()?;
288 match message {
289 tpkt::Payload::Raw(mut payload) => {
290 let mut header = mcs_pdu_header(None, None);
291 header.read(&mut payload)?;
292 if header >> 2 == DomainMCSPDU::DisconnectProviderUltimatum as u8 {
293 return Err(Error::RdpError(RdpError::new(RdpErrorKind::Disconnect, "MCS: Disconnect Provider Ultimatum")));
294 }
295
296 if header >> 2 != DomainMCSPDU::SendDataIndication as u8 {
297 return Err(Error::RdpError(RdpError::new(RdpErrorKind::InvalidData, "MCS: Invalid opcode")));
298 }
299
300 per::read_integer_16(1001, &mut payload)?;
302
303 let channel_id = per::read_integer_16(0, &mut payload)?;
304 let channel = self.channel_ids.iter().find(|x| *x.1 == channel_id).ok_or(Error::RdpError(RdpError::new(RdpErrorKind::Unknown, "MCS: unknown channel")))?;
305
306 per::read_enumerates(&mut payload)?;
307 per::read_length(&mut payload)?;
308
309 Ok((channel.0.clone(), tpkt::Payload::Raw(payload)))
310 },
311 tpkt::Payload::FastPath(sec_flag, payload) => {
312 Ok(("global".to_string(), tpkt::Payload::FastPath(sec_flag, payload)))
314 }
315 }
316
317 }
318
319 pub fn shutdown(&mut self) -> RdpResult<()> {
321 self.x224.write(trame![
322 mcs_pdu_header(Some(DomainMCSPDU::DisconnectProviderUltimatum), Some(1)),
323 per::write_enumerates(0x80)?,
324 b"\x00\x00\x00\x00\x00\x00".to_vec()
325 ])?;
326 self.x224.shutdown()
327 }
328
329 pub fn is_rdp_version_5_plus(&self) -> bool {
332 self.server_data.as_ref().unwrap().rdp_version == Version::RdpVersion5plus
333 }
334
335 pub fn get_user_id(&self) -> u16 {
337 self.user_id.unwrap()
338 }
339
340 pub fn get_global_channel_id(&self) -> u16 {
342 self.channel_ids["global"]
343 }
344}
345
#[cfg(test)]
mod test {
    use super::*;

    /// Header 46 (index 11), result 0 and user id offset 3 decode to 1004.
    #[test]
    fn test_read_attach_user_confirm() {
        let mut stream = Cursor::new(vec![46u8, 0, 0, 3]);
        let user_id = read_attach_user_confirm(&mut stream).unwrap();
        assert_eq!(user_id, 1004)
    }

    /// AttachUserRequest (10) shifted left by two bits.
    #[test]
    fn test_attach_user_request() {
        let header = attach_user_request();
        assert_eq!(header, 40)
    }

    #[test]
    fn test_erect_domain_request() {
        let encoded = to_vec(&erect_domain_request().unwrap());
        let expected: [u8; 5] = [4, 1, 0, 1, 0];
        assert_eq!(encoded, expected)
    }

    /// Default arguments give user id 1001 (offset 0) and channel 0.
    #[test]
    fn test_channel_join_request() {
        let encoded = to_vec(&channel_join_request(None, None).unwrap());
        let expected: [u8; 5] = [56, 0, 0, 0, 0];
        assert_eq!(encoded, expected)
    }

    #[test]
    fn test_domain_parameters() {
        let der = to_der(&domain_parameters(1, 2, 3, 4, 5, 6, 7, 8));
        let expected: Vec<u8> = vec![48, 24, 2, 1, 1, 2, 1, 2, 2, 1, 3, 2, 1, 4, 2, 1, 5, 2, 1, 6, 2, 1, 7, 2, 1, 8];
        assert_eq!(der, expected)
    }

    #[test]
    fn test_connect_initial() {
        let der = to_der(&connect_initial(Some(vec![1, 2, 3])));
        let expected: Vec<u8> = vec![127, 101, 103, 4, 1, 1, 4, 1, 1, 1, 1, 255, 48, 26, 2, 1, 34, 2, 1, 2, 2, 1, 0, 2, 1, 1, 2, 1, 0, 2, 1, 1, 2, 3, 0, 255, 255, 2, 1, 2, 48, 25, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 0, 2, 1, 1, 2, 2, 4, 32, 2, 1, 2, 48, 32, 2, 3, 0, 255, 255, 2, 3, 0, 252, 23, 2, 3, 0, 255, 255, 2, 1, 1, 2, 1, 0, 2, 1, 1, 2, 3, 0, 255, 255, 2, 1, 2, 4, 3, 1, 2, 3];
        assert_eq!(der, expected)
    }

    #[test]
    fn test_connect_response() {
        let der = to_der(&connect_response(Some(vec![1, 2, 3])));
        let expected: Vec<u8> = vec![127, 102, 39, 10, 1, 0, 2, 1, 0, 48, 26, 2, 1, 22, 2, 1, 3, 2, 1, 0, 2, 1, 1, 2, 1, 0, 2, 1, 1, 2, 3, 0, 255, 248, 2, 1, 2, 4, 3, 1, 2, 3];
        assert_eq!(der, expected)
    }
}