ironrdp_pdu/
mcs.rs

1use std::borrow::Cow;
2
3use ironrdp_core::{
4    cast_length, ensure_fixed_part_size, ensure_size, invalid_field_err, other_err, read_padding,
5    unexpected_message_type_err, IntoOwned, ReadCursor, WriteCursor,
6};
7
8use crate::gcc::{ChannelDef, ClientGccBlocks, ConferenceCreateRequest, ConferenceCreateResponse};
9use crate::tpdu::{TpduCode, TpduHeader};
10use crate::tpkt::TpktHeader;
11use crate::x224::{user_data_size, X224Pdu};
12use crate::{impl_x224_pdu_borrowing, impl_x224_pdu_pod, per, DecodeResult, EncodeResult, PduError};
13
14// T.125 MCS is defined in:
15//
16// http://www.itu.int/rec/T-REC-T.125-199802-I/
17// ITU-T T.125 Multipoint Communication Service Protocol Specification
18//
19// Connect-Initial ::= [APPLICATION 101] IMPLICIT SEQUENCE
20// {
21//     callingDomainSelector	OCTET_STRING,
22//     calledDomainSelector		OCTET_STRING,
23//     upwardFlag			    BOOLEAN,
24//     targetParameters		    DomainParameters,
25//     minimumParameters		DomainParameters,
26//     maximumParameters		DomainParameters,
27//     userData			        OCTET_STRING
28// }
29//
30// DomainParameters ::= SEQUENCE
31// {
32//     maxChannelIds		INTEGER (0..MAX),
33//     maxUserIds			INTEGER (0..MAX),
34//     maxTokenIds			INTEGER (0..MAX),
35//     numPriorities		INTEGER (0..MAX),
36//     minThroughput		INTEGER (0..MAX),
37//     maxHeight			INTEGER (0..MAX),
38//     maxMCSPDUsize		INTEGER (0..MAX),
39//     protocolVersion		INTEGER (0..MAX)
40// }
41//
42// Connect-Response ::= [APPLICATION 102] IMPLICIT SEQUENCE
43// {
44//     result				Result,
45//     calledConnectId		INTEGER (0..MAX),
46//     domainParameters		DomainParameters,
47//     userData			    OCTET_STRING
48// }
49//
50// Result ::= ENUMERATED
51// {
52//     rt-successful			    (0),
53//     rt-domain-merging		    (1),
54//     rt-domain-not-hierarchical	(2),
55//     rt-no-such-channel		    (3),
56//     rt-no-such-domain		    (4),
57//     rt-no-such-user			    (5),
58//     rt-not-admitted			    (6),
59//     rt-other-user-id		        (7),
60//     rt-parameters-unacceptable	(8),
61//     rt-token-not-available		(9),
62//     rt-token-not-possessed		(10),
63//     rt-too-many-channels		    (11),
64//     rt-too-many-tokens		    (12),
65//     rt-too-many-users		    (13),
66//     rt-unspecified-failure		(14),
67//     rt-user-rejected		        (15)
68// }
69//
70// ErectDomainRequest ::= [APPLICATION 1] IMPLICIT SEQUENCE
71// {
72//     subHeight		INTEGER (0..MAX),
73//     subInterval		INTEGER (0..MAX)
74// }
75//
76// AttachUserRequest ::= [APPLICATION 10] IMPLICIT SEQUENCE
77// {
78// }
79//
80// AttachUserConfirm ::= [APPLICATION 11] IMPLICIT SEQUENCE
81// {
82//     result			Result,
83//     initiator		UserId OPTIONAL
84// }
85//
86// ChannelJoinRequest ::= [APPLICATION 14] IMPLICIT SEQUENCE
87// {
88//     initiator		UserId,
89//     channelId		ChannelId
90// }
91//
92// ChannelJoinConfirm ::= [APPLICATION 15] IMPLICIT SEQUENCE
93// {
94//     result		Result,
95//     initiator	UserId,
96//     requested	ChannelId,
97//     channelId	ChannelId OPTIONAL
98// }
99//
100// SendDataRequest ::= [APPLICATION 25] IMPLICIT SEQUENCE
101// {
102//     initiator		UserId,
103//     channelId		ChannelId,
104//     dataPriority		DataPriority,
105//     segmentation		Segmentation,
106//     userData			OCTET_STRING
107// }
108//
109// DataPriority ::= CHOICE
110// {
111//     top		NULL,
112//     high		NULL,
113//     medium	NULL,
114//     low		NULL,
115//     ...
116// }
117//
118// Segmentation ::= BIT_STRING
119// {
120//     begin	(0),
121//     end		(1)
122// } (SIZE(2))
123//
124// SendDataIndication ::= [APPLICATION 26] IMPLICIT SEQUENCE
125// {
126//     initiator		UserId,
127//     channelId		ChannelId,
128//     dataPriority		DataPriority,
129//     segmentation		Segmentation,
130//     userData			OCTET_STRING
131// }
132
/// Number of alternatives in the T.125 `Result` enumeration (`rt-successful` .. `rt-user-rejected`).
pub const RESULT_ENUM_LENGTH: u8 = 16;

/// Offset handed to the PER 16-bit integer codec for `initiator` (UserId) fields.
const BASE_CHANNEL_ID: u16 = 1001;
/// Fixed byte written in place of the `dataPriority` and `segmentation` fields of send-data PDUs.
const SEND_DATA_PDU_DATA_PRIORITY_AND_SEGMENTATION: u8 = 0b0111_0000;
137
/// Creates a closure mapping a PER error to a decode/encode error with field-level context.
///
/// The expansion refers to `Self::MCS_NAME`, so the macro must be used inside an `impl` of
/// `McsPdu`. `per_field_err!("initiator")` expands to:
///
/// ```ignore
/// |error| ironrdp_core::invalid_field_err_with_source(Self::MCS_NAME, "initiator", "PER", error)
/// ```
macro_rules! per_field_err {
    ($field_name:expr) => {{
        |error| ironrdp_core::invalid_field_err_with_source(Self::MCS_NAME, $field_name, "PER", error)
    }};
}
149
150#[doc(hidden)]
151pub trait McsPdu<'de>: Sized {
152    const MCS_NAME: &'static str;
153
154    fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()>;
155
156    fn mcs_body_decode(src: &mut ReadCursor<'de>, tpdu_user_data_size: usize) -> DecodeResult<Self>;
157
158    fn mcs_size(&self) -> usize;
159
160    fn name(&self) -> &'static str {
161        Self::MCS_NAME
162    }
163}
164
165impl<'de, T> X224Pdu<'de> for T
166where
167    T: McsPdu<'de>,
168{
169    const X224_NAME: &'static str = T::MCS_NAME;
170
171    const TPDU_CODE: TpduCode = TpduCode::DATA;
172
173    fn x224_body_encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
174        self.mcs_body_encode(dst)
175    }
176
177    fn x224_body_decode(src: &mut ReadCursor<'de>, tpkt: &TpktHeader, tpdu: &TpduHeader) -> DecodeResult<Self> {
178        let tpdu_user_data_size = user_data_size(tpkt, tpdu);
179        T::mcs_body_decode(src, tpdu_user_data_size)
180    }
181
182    fn tpdu_header_variable_part_size(&self) -> usize {
183        0
184    }
185
186    fn tpdu_user_data_size(&self) -> usize {
187        self.mcs_size()
188    }
189}
190
191#[derive(Debug, Copy, Clone, PartialEq)]
192#[repr(u8)]
193enum DomainMcsPdu {
194    ErectDomainRequest = 1,
195    DisconnectProviderUltimatum = 8,
196    AttachUserRequest = 10,
197    AttachUserConfirm = 11,
198    ChannelJoinRequest = 14,
199    ChannelJoinConfirm = 15,
200    SendDataRequest = 25,
201    SendDataIndication = 26,
202}
203
204impl DomainMcsPdu {
205    fn check_expected(self, name: &'static str, expected: DomainMcsPdu) -> DecodeResult<()> {
206        if self != expected {
207            Err(unexpected_message_type_err!(name, self.as_u8()))
208        } else {
209            Ok(())
210        }
211    }
212
213    fn from_choice(choice: u8) -> Option<Self> {
214        Self::from_u8(choice >> 2)
215    }
216
217    fn to_choice(self) -> u8 {
218        self.as_u8() << 2
219    }
220
221    fn from_u8(value: u8) -> Option<Self> {
222        match value {
223            1 => Some(Self::ErectDomainRequest),
224            8 => Some(Self::DisconnectProviderUltimatum),
225            10 => Some(Self::AttachUserRequest),
226            11 => Some(Self::AttachUserConfirm),
227            14 => Some(Self::ChannelJoinRequest),
228            15 => Some(Self::ChannelJoinConfirm),
229            25 => Some(Self::SendDataRequest),
230            26 => Some(Self::SendDataIndication),
231            _ => None,
232        }
233    }
234
235    fn as_u8(self) -> u8 {
236        self as u8
237    }
238}
239
240fn read_mcspdu_header(src: &mut ReadCursor<'_>, ctx: &'static str) -> DecodeResult<DomainMcsPdu> {
241    let choice = src.try_read_u8().map_err(|e| other_err!(ctx, source: e))?;
242
243    DomainMcsPdu::from_choice(choice)
244        .ok_or_else(|| invalid_field_err(ctx, "domain-mcspdu", "unexpected application tag for CHOICE"))
245}
246
247fn peek_mcspdu_header(src: &mut ReadCursor<'_>, ctx: &'static str) -> DecodeResult<DomainMcsPdu> {
248    let choice = src.try_peek_u8().map_err(|e| other_err!(ctx, source: e))?;
249
250    DomainMcsPdu::from_choice(choice)
251        .ok_or_else(|| invalid_field_err(ctx, "domain-mcspdu", "unexpected application tag for CHOICE"))
252}
253
254fn write_mcspdu_header(dst: &mut WriteCursor<'_>, domain_mcspdu: DomainMcsPdu, options: u8) {
255    let choice = domain_mcspdu.to_choice();
256
257    debug_assert_eq!(options & !0b11, 0);
258    debug_assert_eq!(choice & 0b11, 0);
259
260    dst.write_u8(choice | options);
261}
262
263/// The kind of the RDP header message that may carry additional data.
264#[derive(Debug, Clone, PartialEq, Eq)]
265pub enum McsMessage<'a> {
266    ErectDomainRequest(ErectDomainPdu),
267    AttachUserRequest(AttachUserRequest),
268    AttachUserConfirm(AttachUserConfirm),
269    ChannelJoinRequest(ChannelJoinRequest),
270    ChannelJoinConfirm(ChannelJoinConfirm),
271    SendDataRequest(SendDataRequest<'a>),
272    SendDataIndication(SendDataIndication<'a>),
273    DisconnectProviderUltimatum(DisconnectProviderUltimatum),
274}
275
276impl_x224_pdu_borrowing!(McsMessage<'_>, OwnedMcsMessage);
277
278impl IntoOwned for McsMessage<'_> {
279    type Owned = OwnedMcsMessage;
280
281    fn into_owned(self) -> Self::Owned {
282        match self {
283            Self::ErectDomainRequest(msg) => McsMessage::ErectDomainRequest(msg.into_owned()),
284            Self::AttachUserRequest(msg) => McsMessage::AttachUserRequest(msg.into_owned()),
285            Self::AttachUserConfirm(msg) => McsMessage::AttachUserConfirm(msg.into_owned()),
286            Self::ChannelJoinRequest(msg) => McsMessage::ChannelJoinRequest(msg.into_owned()),
287            Self::ChannelJoinConfirm(msg) => McsMessage::ChannelJoinConfirm(msg.into_owned()),
288            Self::SendDataRequest(msg) => McsMessage::SendDataRequest(msg.into_owned()),
289            Self::SendDataIndication(msg) => McsMessage::SendDataIndication(msg.into_owned()),
290            Self::DisconnectProviderUltimatum(msg) => McsMessage::DisconnectProviderUltimatum(msg.into_owned()),
291        }
292    }
293}
294
295impl<'de> McsPdu<'de> for McsMessage<'de> {
296    const MCS_NAME: &'static str = "McsMessage";
297
298    fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
299        match self {
300            Self::ErectDomainRequest(msg) => msg.mcs_body_encode(dst),
301            Self::AttachUserRequest(msg) => msg.mcs_body_encode(dst),
302            Self::AttachUserConfirm(msg) => msg.mcs_body_encode(dst),
303            Self::ChannelJoinRequest(msg) => msg.mcs_body_encode(dst),
304            Self::ChannelJoinConfirm(msg) => msg.mcs_body_encode(dst),
305            Self::SendDataRequest(msg) => msg.mcs_body_encode(dst),
306            Self::SendDataIndication(msg) => msg.mcs_body_encode(dst),
307            Self::DisconnectProviderUltimatum(msg) => msg.mcs_body_encode(dst),
308        }
309    }
310
311    fn mcs_body_decode(src: &mut ReadCursor<'de>, tpdu_user_data_size: usize) -> DecodeResult<Self> {
312        match peek_mcspdu_header(src, Self::MCS_NAME)? {
313            DomainMcsPdu::ErectDomainRequest => Ok(McsMessage::ErectDomainRequest(ErectDomainPdu::mcs_body_decode(
314                src,
315                tpdu_user_data_size,
316            )?)),
317            DomainMcsPdu::AttachUserRequest => Ok(McsMessage::AttachUserRequest(AttachUserRequest::mcs_body_decode(
318                src,
319                tpdu_user_data_size,
320            )?)),
321            DomainMcsPdu::AttachUserConfirm => Ok(McsMessage::AttachUserConfirm(AttachUserConfirm::mcs_body_decode(
322                src,
323                tpdu_user_data_size,
324            )?)),
325            DomainMcsPdu::ChannelJoinRequest => Ok(McsMessage::ChannelJoinRequest(
326                ChannelJoinRequest::mcs_body_decode(src, tpdu_user_data_size)?,
327            )),
328            DomainMcsPdu::ChannelJoinConfirm => Ok(McsMessage::ChannelJoinConfirm(
329                ChannelJoinConfirm::mcs_body_decode(src, tpdu_user_data_size)?,
330            )),
331            DomainMcsPdu::SendDataRequest => Ok(McsMessage::SendDataRequest(SendDataRequest::mcs_body_decode(
332                src,
333                tpdu_user_data_size,
334            )?)),
335            DomainMcsPdu::SendDataIndication => Ok(McsMessage::SendDataIndication(
336                SendDataIndication::mcs_body_decode(src, tpdu_user_data_size)?,
337            )),
338            DomainMcsPdu::DisconnectProviderUltimatum => Ok(McsMessage::DisconnectProviderUltimatum(
339                DisconnectProviderUltimatum::mcs_body_decode(src, tpdu_user_data_size)?,
340            )),
341        }
342    }
343
344    fn mcs_size(&self) -> usize {
345        match self {
346            Self::ErectDomainRequest(msg) => msg.mcs_size(),
347            Self::AttachUserRequest(msg) => msg.mcs_size(),
348            Self::AttachUserConfirm(msg) => msg.mcs_size(),
349            Self::ChannelJoinRequest(msg) => msg.mcs_size(),
350            Self::ChannelJoinConfirm(msg) => msg.mcs_size(),
351            Self::SendDataRequest(msg) => msg.mcs_size(),
352            Self::SendDataIndication(msg) => msg.mcs_size(),
353            Self::DisconnectProviderUltimatum(msg) => msg.mcs_size(),
354        }
355    }
356
357    fn name(&self) -> &'static str {
358        match self {
359            Self::ErectDomainRequest(msg) => msg.name(),
360            Self::AttachUserRequest(msg) => msg.name(),
361            Self::AttachUserConfirm(msg) => msg.name(),
362            Self::ChannelJoinRequest(msg) => msg.name(),
363            Self::ChannelJoinConfirm(msg) => msg.name(),
364            Self::SendDataRequest(msg) => msg.name(),
365            Self::SendDataIndication(msg) => msg.name(),
366            Self::DisconnectProviderUltimatum(msg) => msg.name(),
367        }
368    }
369}
370
371#[derive(Debug, Clone, PartialEq, Eq)]
372pub struct ErectDomainPdu {
373    pub sub_height: u32,
374    pub sub_interval: u32,
375}
376
377impl_x224_pdu_pod!(ErectDomainPdu);
378
379impl<'de> McsPdu<'de> for ErectDomainPdu {
380    const MCS_NAME: &'static str = "ErectDomainPdu";
381
382    fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
383        write_mcspdu_header(dst, DomainMcsPdu::ErectDomainRequest, 0);
384
385        per::write_u32(dst, self.sub_height);
386        per::write_u32(dst, self.sub_interval);
387
388        Ok(())
389    }
390
391    fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> DecodeResult<Self> {
392        read_mcspdu_header(src, Self::MCS_NAME)?.check_expected(Self::MCS_NAME, DomainMcsPdu::ErectDomainRequest)?;
393
394        let sub_height = per::read_u32(src).map_err(per_field_err!("subHeight"))?;
395        let sub_interval = per::read_u32(src).map_err(per_field_err!("subInterval"))?;
396
397        Ok(Self {
398            sub_height,
399            sub_interval,
400        })
401    }
402
403    fn mcs_size(&self) -> usize {
404        per::CHOICE_SIZE + per::sizeof_u32(self.sub_height) + per::sizeof_u32(self.sub_interval)
405    }
406}
407
408#[derive(Debug, Clone, PartialEq, Eq)]
409pub struct AttachUserRequest;
410
411impl_x224_pdu_pod!(AttachUserRequest);
412
413impl<'de> McsPdu<'de> for AttachUserRequest {
414    const MCS_NAME: &'static str = "AttachUserRequest";
415
416    fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
417        write_mcspdu_header(dst, DomainMcsPdu::AttachUserRequest, 0);
418
419        Ok(())
420    }
421
422    fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> DecodeResult<Self> {
423        read_mcspdu_header(src, Self::MCS_NAME)?.check_expected(Self::MCS_NAME, DomainMcsPdu::AttachUserRequest)?;
424
425        Ok(Self)
426    }
427
428    fn mcs_size(&self) -> usize {
429        per::CHOICE_SIZE
430    }
431}
432
433#[derive(Debug, Clone, PartialEq, Eq)]
434pub struct AttachUserConfirm {
435    pub result: u8,
436    pub initiator_id: u16,
437}
438
439impl_x224_pdu_pod!(AttachUserConfirm);
440
441impl<'de> McsPdu<'de> for AttachUserConfirm {
442    const MCS_NAME: &'static str = "AttachUserConfirm";
443
444    fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
445        write_mcspdu_header(dst, DomainMcsPdu::AttachUserConfirm, 2);
446
447        per::write_enum(dst, self.result);
448        per::write_u16(dst, self.initiator_id, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
449
450        Ok(())
451    }
452
453    fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> DecodeResult<Self> {
454        read_mcspdu_header(src, Self::MCS_NAME)?.check_expected(Self::MCS_NAME, DomainMcsPdu::AttachUserConfirm)?;
455
456        let result = per::read_enum(src, RESULT_ENUM_LENGTH).map_err(per_field_err!("result"))?;
457        let user_id = per::read_u16(src, BASE_CHANNEL_ID).map_err(per_field_err!("userId"))?;
458
459        Ok(Self {
460            result,
461            initiator_id: user_id,
462        })
463    }
464
465    fn mcs_size(&self) -> usize {
466        per::CHOICE_SIZE + per::ENUM_SIZE + per::U16_SIZE
467    }
468}
469
470#[derive(Debug, Clone, PartialEq, Eq)]
471pub struct ChannelJoinRequest {
472    pub initiator_id: u16,
473    pub channel_id: u16,
474}
475
476impl_x224_pdu_pod!(ChannelJoinRequest);
477
478impl<'de> McsPdu<'de> for ChannelJoinRequest {
479    const MCS_NAME: &'static str = "ChannelJoinRequest";
480
481    fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
482        write_mcspdu_header(dst, DomainMcsPdu::ChannelJoinRequest, 0);
483
484        per::write_u16(dst, self.initiator_id, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
485        per::write_u16(dst, self.channel_id, 0).map_err(per_field_err!("channelId"))?;
486
487        Ok(())
488    }
489
490    fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> DecodeResult<Self> {
491        read_mcspdu_header(src, Self::MCS_NAME)?.check_expected(Self::MCS_NAME, DomainMcsPdu::ChannelJoinRequest)?;
492
493        let initiator_id = per::read_u16(src, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
494        let channel_id = per::read_u16(src, 0).map_err(per_field_err!("channelID"))?;
495
496        Ok(Self {
497            initiator_id,
498            channel_id,
499        })
500    }
501
502    fn mcs_size(&self) -> usize {
503        per::CHOICE_SIZE + per::U16_SIZE * 2
504    }
505}
506
507#[derive(Debug, Clone, PartialEq, Eq)]
508pub struct ChannelJoinConfirm {
509    pub result: u8,
510    pub initiator_id: u16,
511    pub requested_channel_id: u16,
512    pub channel_id: u16,
513}
514
515impl_x224_pdu_pod!(ChannelJoinConfirm);
516
517impl<'de> McsPdu<'de> for ChannelJoinConfirm {
518    const MCS_NAME: &'static str = "ChannelJoinConfirm";
519
520    fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
521        write_mcspdu_header(dst, DomainMcsPdu::ChannelJoinConfirm, 2);
522
523        per::write_enum(dst, self.result);
524        per::write_u16(dst, self.initiator_id, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
525        per::write_u16(dst, self.requested_channel_id, 0).map_err(per_field_err!("requested"))?;
526        per::write_u16(dst, self.channel_id, 0).map_err(per_field_err!("channelId"))?;
527
528        Ok(())
529    }
530
531    fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> DecodeResult<Self> {
532        read_mcspdu_header(src, Self::MCS_NAME)?.check_expected(Self::MCS_NAME, DomainMcsPdu::ChannelJoinConfirm)?;
533
534        let result = per::read_enum(src, RESULT_ENUM_LENGTH).map_err(per_field_err!("result"))?;
535        let initiator_id = per::read_u16(src, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
536        let requested_channel_id = per::read_u16(src, 0).map_err(per_field_err!("requested"))?;
537        let channel_id = per::read_u16(src, 0).map_err(per_field_err!("channelId"))?;
538
539        Ok(Self {
540            result,
541            initiator_id,
542            requested_channel_id,
543            channel_id,
544        })
545    }
546
547    fn mcs_size(&self) -> usize {
548        per::CHOICE_SIZE + per::ENUM_SIZE + per::U16_SIZE * 3
549    }
550}
551
552#[derive(Debug, Clone, PartialEq, Eq)]
553pub struct SendDataRequest<'a> {
554    pub initiator_id: u16,
555    pub channel_id: u16,
556    pub user_data: Cow<'a, [u8]>,
557}
558
559impl_x224_pdu_borrowing!(SendDataRequest<'_>, OwnedSendDataRequest);
560
561impl IntoOwned for SendDataRequest<'_> {
562    type Owned = OwnedSendDataRequest;
563
564    fn into_owned(self) -> Self::Owned {
565        SendDataRequest {
566            user_data: Cow::Owned(self.user_data.into_owned()),
567            ..self
568        }
569    }
570}
571
572impl<'de> McsPdu<'de> for SendDataRequest<'de> {
573    const MCS_NAME: &'static str = "SendDataRequest";
574
575    fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
576        write_mcspdu_header(dst, DomainMcsPdu::SendDataRequest, 0);
577
578        per::write_u16(dst, self.initiator_id, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
579        per::write_u16(dst, self.channel_id, 0).map_err(per_field_err!("channelID"))?;
580
581        dst.write_u8(SEND_DATA_PDU_DATA_PRIORITY_AND_SEGMENTATION);
582
583        per::write_length(dst, cast_length!("user-data-length", self.user_data.len())?);
584        dst.write_slice(&self.user_data);
585
586        Ok(())
587    }
588
589    fn mcs_body_decode(src: &mut ReadCursor<'de>, tpdu_user_data_size: usize) -> DecodeResult<Self> {
590        let src_len_before = src.len();
591
592        read_mcspdu_header(src, Self::MCS_NAME)?.check_expected(Self::MCS_NAME, DomainMcsPdu::SendDataRequest)?;
593
594        let initiator_id = per::read_u16(src, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
595        let channel_id = per::read_u16(src, 0).map_err(per_field_err!("channelId"))?;
596
597        // dataPriority + segmentation
598        ensure_size!(ctx: Self::MCS_NAME, in: src, size: 1);
599        read_padding!(src, 1);
600
601        let (length, _) = per::read_length(src).map_err(per_field_err!("userDataLength"))?;
602        let length = usize::from(length);
603
604        let src_len_after = src.len();
605
606        if length > tpdu_user_data_size.saturating_sub(src_len_before - src_len_after) {
607            return Err(invalid_field_err(
608                Self::MCS_NAME,
609                "userDataLength",
610                "inconsistent with user data size advertised in TPDU",
611            ));
612        }
613
614        ensure_size!(ctx: Self::MCS_NAME, in: src, size: length);
615        let user_data = Cow::Borrowed(src.read_slice(length));
616
617        Ok(Self {
618            initiator_id,
619            channel_id,
620            user_data,
621        })
622    }
623
624    fn mcs_size(&self) -> usize {
625        per::CHOICE_SIZE
626            + per::U16_SIZE * 2
627            + 1
628            + per::sizeof_length(u16::try_from(self.user_data.len()).unwrap_or(u16::MAX))
629            + self.user_data.len()
630    }
631}
632
633#[derive(Debug, Clone, PartialEq, Eq)]
634pub struct SendDataIndication<'a> {
635    pub initiator_id: u16,
636    pub channel_id: u16,
637    pub user_data: Cow<'a, [u8]>,
638}
639
640impl_x224_pdu_borrowing!(SendDataIndication<'_>, OwnedSendDataIndication);
641
642impl IntoOwned for SendDataIndication<'_> {
643    type Owned = OwnedSendDataIndication;
644
645    fn into_owned(self) -> Self::Owned {
646        SendDataIndication {
647            user_data: Cow::Owned(self.user_data.into_owned()),
648            ..self
649        }
650    }
651}
652
653impl<'de> McsPdu<'de> for SendDataIndication<'de> {
654    const MCS_NAME: &'static str = "SendDataIndication";
655
656    fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
657        write_mcspdu_header(dst, DomainMcsPdu::SendDataIndication, 0);
658
659        per::write_u16(dst, self.initiator_id, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
660        per::write_u16(dst, self.channel_id, 0).map_err(per_field_err!("channelId"))?;
661
662        dst.write_u8(SEND_DATA_PDU_DATA_PRIORITY_AND_SEGMENTATION);
663
664        per::write_length(dst, cast_length!("userDataLength", self.user_data.len())?);
665        dst.write_slice(&self.user_data);
666
667        Ok(())
668    }
669
670    fn mcs_body_decode(src: &mut ReadCursor<'de>, tpdu_user_data_size: usize) -> DecodeResult<Self> {
671        let src_len_before = src.len();
672
673        read_mcspdu_header(src, Self::MCS_NAME)?.check_expected(Self::MCS_NAME, DomainMcsPdu::SendDataIndication)?;
674
675        let initiator_id = per::read_u16(src, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
676        let channel_id = per::read_u16(src, 0).map_err(per_field_err!("channelId"))?;
677
678        // dataPriority + segmentation
679        ensure_size!(ctx: Self::MCS_NAME, in: src, size: 1);
680        read_padding!(src, 1);
681
682        let (length, _) = per::read_length(src).map_err(per_field_err!("userDataLength"))?;
683        let length = usize::from(length);
684
685        let src_len_after = src.len();
686
687        if length > tpdu_user_data_size.saturating_sub(src_len_before - src_len_after) {
688            return Err(invalid_field_err(
689                Self::MCS_NAME,
690                "userDataLength",
691                "inconsistent with user data size advertised in TPDU",
692            ));
693        }
694
695        ensure_size!(ctx: Self::MCS_NAME, in: src, size: length);
696        let user_data = Cow::Borrowed(src.read_slice(length));
697
698        Ok(Self {
699            initiator_id,
700            channel_id,
701            user_data,
702        })
703    }
704
705    fn mcs_size(&self) -> usize {
706        per::CHOICE_SIZE
707            + per::U16_SIZE * 2
708            + 1
709            + per::sizeof_length(u16::try_from(self.user_data.len()).unwrap_or(u16::MAX))
710            + self.user_data.len()
711    }
712}
713
/// The reason code carried by a `DisconnectProviderUltimatum`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[repr(u8)]
pub enum DisconnectReason {
    DomainDisconnected = 0,
    ProviderInitiated = 1,
    TokenPurged = 2,
    UserRequested = 3,
    ChannelPurged = 4,
}

impl DisconnectReason {
    /// Every variant, positioned at the index equal to its numeric value.
    const ALL: [Self; 5] = [
        Self::DomainDisconnected,
        Self::ProviderInitiated,
        Self::TokenPurged,
        Self::UserRequested,
        Self::ChannelPurged,
    ];

    /// Numeric wire value of this reason.
    pub fn as_u8(self) -> u8 {
        self as u8
    }

    /// Converts a numeric value to a reason; `None` when it is outside `0..=4`.
    pub fn from_u8(value: u8) -> Option<Self> {
        Self::ALL.get(usize::from(value)).copied()
    }

    /// Human-readable text; also what `Display` prints.
    pub fn description(self) -> &'static str {
        match self {
            Self::DomainDisconnected => "domain disconnected",
            Self::ProviderInitiated => "server-initiated disconnect",
            Self::TokenPurged => "token purged",
            Self::UserRequested => "user-requested disconnect",
            Self::ChannelPurged => "channel purged",
        }
    }
}

impl core::fmt::Display for DisconnectReason {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let text = self.description();
        f.write_str(text)
    }
}
757
758#[derive(Debug, Copy, Clone, PartialEq, Eq)]
759pub struct DisconnectProviderUltimatum {
760    pub reason: DisconnectReason,
761}
762
763impl_x224_pdu_pod!(DisconnectProviderUltimatum);
764
765impl DisconnectProviderUltimatum {
766    pub const NAME: &'static str = "DisconnectProviderUltimatum";
767
768    pub const FIXED_PART_SIZE: usize = 2;
769
770    pub fn from_reason(reason: DisconnectReason) -> Self {
771        Self { reason }
772    }
773}
774
775impl<'de> McsPdu<'de> for DisconnectProviderUltimatum {
776    const MCS_NAME: &'static str = "DisconnectProviderUltimatum";
777
778    fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
779        let domain_mcspdu = DomainMcsPdu::DisconnectProviderUltimatum.as_u8();
780        let reason = self.reason.as_u8();
781
782        let b1 = (domain_mcspdu << 2) | ((reason >> 1) & 0x03);
783        let b2 = reason << 7;
784
785        dst.write_array([b1, b2]);
786
787        Ok(())
788    }
789
790    fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> DecodeResult<Self> {
791        // http://msdn.microsoft.com/en-us/library/cc240872.aspx:
792        //
793        // PER encoded (ALIGNED variant of BASIC-PER) PDU contents:
794        // 21 80
795        //
796        // 0x21:
797        // 0 - --\
798        // 0 -   |
799        // 1 -   | CHOICE: From DomainMCSPDU select disconnectProviderUltimatum (8)
800        // 0 -   | of type DisconnectProviderUltimatum
801        // 0 -   |
802        // 0 - --/
803        // 0 - --\
804        // 1 -   |
805        //       | DisconnectProviderUltimatum::reason = rn-user-requested (3)
806        // 0x80: |
807        // 1 - --/
808        // 0 - padding
809        // 0 - padding
810        // 0 - padding
811        // 0 - padding
812        // 0 - padding
813        // 0 - padding
814        // 0 - padding
815
816        ensure_fixed_part_size!(in: src);
817
818        let [b1, b2] = src.read_array();
819
820        let domain_mcspdu_choice = b1 >> 2;
821        let reason = ((b1 & 0x03) << 1) | (b2 >> 7);
822
823        DomainMcsPdu::from_u8(domain_mcspdu_choice)
824            .ok_or_else(|| invalid_field_err(Self::MCS_NAME, "domain-mcspdu", "unexpected application tag for CHOICE"))?
825            .check_expected(Self::MCS_NAME, DomainMcsPdu::DisconnectProviderUltimatum)?;
826
827        Ok(Self {
828            reason: DisconnectReason::from_u8(reason)
829                .ok_or_else(|| invalid_field_err(Self::MCS_NAME, "reason", "unknown variant"))?,
830        })
831    }
832
833    fn mcs_size(&self) -> usize {
834        2
835    }
836}
837
838#[derive(Clone, Debug, PartialEq, Eq)]
839pub struct ConnectInitial {
840    pub conference_create_request: ConferenceCreateRequest,
841    pub calling_domain_selector: Vec<u8>,
842    pub called_domain_selector: Vec<u8>,
843    pub upward_flag: bool,
844    pub target_parameters: DomainParameters,
845    pub min_parameters: DomainParameters,
846    pub max_parameters: DomainParameters,
847}
848
849impl ConnectInitial {
850    pub fn with_gcc_blocks(gcc_blocks: ClientGccBlocks) -> Self {
851        Self {
852            conference_create_request: ConferenceCreateRequest { gcc_blocks },
853            calling_domain_selector: vec![0x01],
854            called_domain_selector: vec![0x01],
855            upward_flag: true,
856            target_parameters: DomainParameters::target(),
857            min_parameters: DomainParameters::min(),
858            max_parameters: DomainParameters::max(),
859        }
860    }
861
862    pub fn channel_names(&self) -> Option<Vec<ChannelDef>> {
863        self.conference_create_request.gcc_blocks.channel_names()
864    }
865}
866
867#[derive(Clone, Debug, PartialEq, Eq)]
868pub struct ConnectResponse {
869    pub conference_create_response: ConferenceCreateResponse,
870    pub called_connect_id: u32,
871    pub domain_parameters: DomainParameters,
872}
873
874impl ConnectResponse {
875    pub fn channel_ids(&self) -> Vec<u16> {
876        self.conference_create_response.gcc_blocks.channel_ids()
877    }
878
879    pub fn global_channel_id(&self) -> u16 {
880        self.conference_create_response.gcc_blocks.global_channel_id()
881    }
882}
883
/// T.125 `DomainParameters`, exchanged in Connect-Initial and Connect-Response.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DomainParameters {
    pub max_channel_ids: u32,
    pub max_user_ids: u32,
    pub max_token_ids: u32,
    pub num_priorities: u32,
    pub min_throughput: u32,
    pub max_height: u32,
    pub max_mcs_pdu_size: u32,
    pub protocol_version: u32,
}

impl DomainParameters {
    /// Builds a preset from the fields that vary between presets. The remaining fields
    /// (one priority, zero minimum throughput, height 1, protocol version 2) are shared by all.
    fn preset(max_channel_ids: u32, max_user_ids: u32, max_token_ids: u32, max_mcs_pdu_size: u32) -> Self {
        Self {
            max_channel_ids,
            max_user_ids,
            max_token_ids,
            num_priorities: 1,
            min_throughput: 0,
            max_height: 1,
            max_mcs_pdu_size,
            protocol_version: 2,
        }
    }

    /// Minimum parameters proposed in Connect-Initial.
    pub fn min() -> Self {
        Self::preset(1, 1, 1, 1056)
    }

    /// Target parameters proposed in Connect-Initial.
    pub fn target() -> Self {
        Self::preset(34, 2, 0, 65535)
    }

    /// Maximum parameters proposed in Connect-Initial.
    pub fn max() -> Self {
        // NOTE(review): max_user_ids is 64535 (0xFC17), not 65535. This looks deliberate
        // (presumably mirroring other RDP implementations); verify against MS-RDPBCGR before changing.
        Self::preset(65535, 64535, 65535, 65535)
    }
}
936
937pub use legacy::McsError;
938
939mod legacy {
940    use std::io;
941
942    use ironrdp_core::{Decode, DecodeResult, Encode};
943    use thiserror::Error;
944
945    use super::{
946        cast_length, ensure_size, ConnectInitial, ConnectResponse, DomainParameters, PduError, ReadCursor, WriteCursor,
947        RESULT_ENUM_LENGTH,
948    };
949    use crate::gcc::{ConferenceCreateRequest, ConferenceCreateResponse, GccError};
950    use crate::{ber, EncodeResult};
951
952    // impl<'de> McsPdu<'de> for ConnectInitial {
953    //     const MCS_NAME: &'static str = "DisconnectProviderUltimatum";
954
955    //     fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
956    //         todo!()
957    //     }
958
959    //     fn mcs_body_decode(src: &mut ReadCursor<'de>, tpdu_user_data_size: usize) -> Result<Self> {
960    //         todo!()
961    //     }
962
963    //     fn mcs_size(&self) -> usize {
964    //         todo!()
965    //     }
966    // }
967
    /// BER application tag number of the T.125 Connect-Initial PDU (`[APPLICATION 101]`, 0x65).
    const MCS_TYPE_CONNECT_INITIAL: u8 = 0x65;
    /// BER application tag number of the T.125 Connect-Response PDU (`[APPLICATION 102]`, 0x66).
    const MCS_TYPE_CONNECT_RESPONSE: u8 = 0x66;
970
971    impl ConnectInitial {
972        const NAME: &'static str = "ConnectInitial";
973
974        fn fields_buffer_ber_length(&self) -> usize {
975            ber::sizeof_octet_string(self.calling_domain_selector.len() as u16)
976                + ber::sizeof_octet_string(self.called_domain_selector.len() as u16)
977                + ber::SIZEOF_BOOL
978                + (self.target_parameters.size() + self.min_parameters.size() + self.max_parameters.size())
979                + ber::sizeof_octet_string(self.conference_create_request.size() as u16)
980        }
981    }
982
983    impl Encode for ConnectInitial {
984        fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
985            ensure_size!(in: dst, size: self.size());
986
987            ber::write_application_tag(dst, MCS_TYPE_CONNECT_INITIAL, self.fields_buffer_ber_length() as u16)?;
988            ber::write_octet_string(dst, self.calling_domain_selector.as_ref())?;
989            ber::write_octet_string(dst, self.called_domain_selector.as_ref())?;
990            ber::write_bool(dst, self.upward_flag)?;
991            self.target_parameters.encode(dst)?;
992            self.min_parameters.encode(dst)?;
993            self.max_parameters.encode(dst)?;
994            ber::write_octet_string_tag(dst, cast_length!("len", self.conference_create_request.size())?)?;
995            self.conference_create_request.encode(dst)?;
996
997            Ok(())
998        }
999
1000        fn name(&self) -> &'static str {
1001            Self::NAME
1002        }
1003
1004        fn size(&self) -> usize {
1005            let fields_buffer_ber_length = self.fields_buffer_ber_length();
1006
1007            fields_buffer_ber_length
1008                + ber::sizeof_application_tag(MCS_TYPE_CONNECT_INITIAL, fields_buffer_ber_length as u16)
1009        }
1010    }
1011
1012    impl<'de> Decode<'de> for ConnectInitial {
1013        fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
1014            ber::read_application_tag(src, MCS_TYPE_CONNECT_INITIAL)?;
1015            let calling_domain_selector = ber::read_octet_string(src)?;
1016            let called_domain_selector = ber::read_octet_string(src)?;
1017            let upward_flag = ber::read_bool(src)?;
1018            let target_parameters = DomainParameters::decode(src)?;
1019            let min_parameters = DomainParameters::decode(src)?;
1020            let max_parameters = DomainParameters::decode(src)?;
1021            let _user_data_buffer_length = ber::read_octet_string_tag(src)?;
1022            let conference_create_request = ConferenceCreateRequest::decode(src)?;
1023
1024            Ok(Self {
1025                conference_create_request,
1026                calling_domain_selector,
1027                called_domain_selector,
1028                upward_flag,
1029                target_parameters,
1030                min_parameters,
1031                max_parameters,
1032            })
1033        }
1034    }
1035
1036    impl ConnectResponse {
1037        const NAME: &'static str = "ConnectResponse";
1038
1039        fn fields_buffer_ber_length(&self) -> usize {
1040            ber::SIZEOF_ENUMERATED
1041                + ber::sizeof_integer(self.called_connect_id)
1042                + self.domain_parameters.size()
1043                + ber::sizeof_octet_string(self.conference_create_response.size() as u16)
1044        }
1045    }
1046
1047    impl Encode for ConnectResponse {
1048        fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
1049            ensure_size!(in: dst, size: self.size());
1050
1051            ber::write_application_tag(dst, MCS_TYPE_CONNECT_RESPONSE, self.fields_buffer_ber_length() as u16)?;
1052            ber::write_enumerated(dst, 0)?;
1053            ber::write_integer(dst, self.called_connect_id)?;
1054            self.domain_parameters.encode(dst)?;
1055            ber::write_octet_string_tag(dst, cast_length!("len", self.conference_create_response.size())?)?;
1056            self.conference_create_response.encode(dst)?;
1057
1058            Ok(())
1059        }
1060
1061        fn name(&self) -> &'static str {
1062            Self::NAME
1063        }
1064
1065        fn size(&self) -> usize {
1066            let fields_buffer_ber_length = self.fields_buffer_ber_length();
1067
1068            fields_buffer_ber_length
1069                + ber::sizeof_application_tag(MCS_TYPE_CONNECT_RESPONSE, fields_buffer_ber_length as u16)
1070        }
1071    }
1072
1073    impl<'de> Decode<'de> for ConnectResponse {
1074        fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
1075            ber::read_application_tag(src, MCS_TYPE_CONNECT_RESPONSE)?;
1076            ber::read_enumerated(src, RESULT_ENUM_LENGTH)?;
1077            let called_connect_id = ber::read_integer(src)? as u32;
1078            let domain_parameters = DomainParameters::decode(src)?;
1079            let _user_data_buffer_length = ber::read_octet_string_tag(src)?;
1080            let conference_create_response = ConferenceCreateResponse::decode(src)?;
1081
1082            Ok(Self {
1083                called_connect_id,
1084                domain_parameters,
1085                conference_create_response,
1086            })
1087        }
1088    }
1089
1090    impl DomainParameters {
1091        const NAME: &'static str = "DomainParameters";
1092
1093        fn fields_buffer_ber_length(&self) -> usize {
1094            ber::sizeof_integer(self.max_channel_ids)
1095                + ber::sizeof_integer(self.max_user_ids)
1096                + ber::sizeof_integer(self.max_token_ids)
1097                + ber::sizeof_integer(self.num_priorities)
1098                + ber::sizeof_integer(self.min_throughput)
1099                + ber::sizeof_integer(self.max_height)
1100                + ber::sizeof_integer(self.max_mcs_pdu_size)
1101                + ber::sizeof_integer(self.protocol_version)
1102        }
1103    }
1104
1105    impl Encode for DomainParameters {
1106        fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
1107            ensure_size!(in: dst, size: self.size());
1108
1109            ber::write_sequence_tag(dst, cast_length!("seqTagLen", self.fields_buffer_ber_length())?)?;
1110            ber::write_integer(dst, self.max_channel_ids)?;
1111            ber::write_integer(dst, self.max_user_ids)?;
1112            ber::write_integer(dst, self.max_token_ids)?;
1113            ber::write_integer(dst, self.num_priorities)?;
1114            ber::write_integer(dst, self.min_throughput)?;
1115            ber::write_integer(dst, self.max_height)?;
1116            ber::write_integer(dst, self.max_mcs_pdu_size)?;
1117            ber::write_integer(dst, self.protocol_version)?;
1118
1119            Ok(())
1120        }
1121
1122        fn name(&self) -> &'static str {
1123            Self::NAME
1124        }
1125
1126        fn size(&self) -> usize {
1127            let fields_buffer_ber_length = self.fields_buffer_ber_length();
1128
1129            // FIXME: maybe size should return PduResult...
1130            fields_buffer_ber_length + ber::sizeof_sequence_tag(fields_buffer_ber_length as u16)
1131        }
1132    }
1133
1134    impl<'de> Decode<'de> for DomainParameters {
1135        fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
1136            ber::read_sequence_tag(src)?;
1137            let max_channel_ids = ber::read_integer(src)? as u32;
1138            let max_user_ids = ber::read_integer(src)? as u32;
1139            let max_token_ids = ber::read_integer(src)? as u32;
1140            let num_priorities = ber::read_integer(src)? as u32;
1141            let min_throughput = ber::read_integer(src)? as u32;
1142            let max_height = ber::read_integer(src)? as u32;
1143            let max_mcs_pdu_size = ber::read_integer(src)? as u32;
1144            let protocol_version = ber::read_integer(src)? as u32;
1145
1146            Ok(Self {
1147                max_channel_ids,
1148                max_user_ids,
1149                max_token_ids,
1150                num_priorities,
1151                min_throughput,
1152                max_height,
1153                max_mcs_pdu_size,
1154                protocol_version,
1155            })
1156        }
1157    }
1158
1159    #[derive(Debug, Error)]
1160    pub enum McsError {
1161        #[error("IO error")]
1162        IOError(#[from] io::Error),
1163        #[error("GCC block error")]
1164        GccError(#[from] GccError),
1165        #[error("invalid disconnect provider ultimatum")]
1166        InvalidDisconnectProviderUltimatum,
1167        #[error("invalid domain MCS PDU")]
1168        InvalidDomainMcsPdu,
1169        #[error("invalid MCS Connection Sequence PDU")]
1170        InvalidPdu(String),
1171        #[error("invalid invalid MCS channel id")]
1172        UnexpectedChannelId(String),
1173        #[error("PDU error: {0}")]
1174        Pdu(PduError),
1175    }
1176
1177    impl From<PduError> for McsError {
1178        fn from(e: PduError) -> Self {
1179            Self::Pdu(e)
1180        }
1181    }
1182
1183    impl From<McsError> for io::Error {
1184        fn from(e: McsError) -> io::Error {
1185            io::Error::other(format!("MCS Connection Sequence error: {e}"))
1186        }
1187    }
1188}