ironrdp_acceptor/
channel_connection.rs

1use std::collections::HashSet;
2
3use ironrdp_connector::{
4    reason_err, ConnectorError, ConnectorErrorExt as _, ConnectorResult, Sequence, State, Written,
5};
6use ironrdp_core::WriteBuf;
7use ironrdp_pdu::mcs;
8use ironrdp_pdu::x224::X224;
9use tracing::debug;
10
11#[derive(Debug)]
12pub struct ChannelConnectionSequence {
13    state: ChannelConnectionState,
14    user_channel_id: u16,
15    channel_ids: Option<HashSet<u16>>,
16}
17
/// States of [`ChannelConnectionSequence`].
#[derive(Debug, Default)]
pub enum ChannelConnectionState {
    /// Placeholder left in place (via `mem::take`) while a step is being processed;
    /// it stays set if that step returns an error.
    #[default]
    Consumed,

    /// Waiting for the client's MCS Erect Domain Request.
    WaitErectDomainRequest,
    /// Waiting for the client's MCS Attach User Request.
    WaitAttachUserRequest,
    /// Ready to send the MCS Attach User Confirm.
    SendAttachUserConfirm,
    /// Waiting for a Channel Join Request naming one of the `remaining` channel ids.
    WaitChannelJoinRequest { remaining: HashSet<u16> },
    /// Ready to confirm the join of `channel_id`; `remaining` lists ids not yet joined.
    SendChannelJoinConfirm { remaining: HashSet<u16>, channel_id: u16 },
    /// Terminal state: every expected channel has been joined (or joining was skipped).
    AllJoined,
}
35
36impl State for ChannelConnectionState {
37    fn name(&self) -> &'static str {
38        match self {
39            Self::Consumed => "Consumed",
40            Self::WaitErectDomainRequest => "WaitErectDomainRequest",
41            Self::WaitAttachUserRequest => "WaitAttachUserRequest",
42            Self::SendAttachUserConfirm => "SendAttachUserConfirm",
43            Self::WaitChannelJoinRequest { .. } => "WaitChannelJoinRequest",
44            Self::SendChannelJoinConfirm { .. } => "SendChannelJoinConfirm",
45            Self::AllJoined { .. } => "AllJoined",
46        }
47    }
48
49    fn is_terminal(&self) -> bool {
50        matches!(self, Self::AllJoined { .. })
51    }
52
53    fn as_any(&self) -> &dyn core::any::Any {
54        self
55    }
56}
57
58impl Sequence for ChannelConnectionSequence {
59    fn next_pdu_hint(&self) -> Option<&dyn ironrdp_pdu::PduHint> {
60        match &self.state {
61            ChannelConnectionState::Consumed => None,
62            ChannelConnectionState::WaitErectDomainRequest => Some(&ironrdp_pdu::X224_HINT),
63            ChannelConnectionState::WaitAttachUserRequest => Some(&ironrdp_pdu::X224_HINT),
64            ChannelConnectionState::SendAttachUserConfirm => None,
65            ChannelConnectionState::WaitChannelJoinRequest { .. } => Some(&ironrdp_pdu::X224_HINT),
66            ChannelConnectionState::SendChannelJoinConfirm { .. } => None,
67            ChannelConnectionState::AllJoined => None,
68        }
69    }
70
71    fn state(&self) -> &dyn State {
72        &self.state
73    }
74
75    fn step(&mut self, input: &[u8], output: &mut WriteBuf) -> ConnectorResult<Written> {
76        let (written, next_state) = match core::mem::take(&mut self.state) {
77            ChannelConnectionState::WaitErectDomainRequest => {
78                let erect_domain_request = ironrdp_core::decode::<X224<mcs::ErectDomainPdu>>(input)
79                    .map_err(ConnectorError::decode)
80                    .map(|p| p.0)?;
81
82                debug!(message = ?erect_domain_request, "Received");
83
84                (Written::Nothing, ChannelConnectionState::WaitAttachUserRequest)
85            }
86
87            ChannelConnectionState::WaitAttachUserRequest => {
88                let attach_user_request = ironrdp_core::decode::<X224<mcs::AttachUserRequest>>(input)
89                    .map_err(ConnectorError::decode)
90                    .map(|p| p.0)?;
91
92                debug!(message = ?attach_user_request, "Received");
93
94                (Written::Nothing, ChannelConnectionState::SendAttachUserConfirm)
95            }
96
97            ChannelConnectionState::SendAttachUserConfirm => {
98                let attach_user_confirm = mcs::AttachUserConfirm {
99                    result: 0,
100                    initiator_id: self.user_channel_id,
101                };
102
103                debug!(message = ?attach_user_confirm, "Send");
104
105                let written =
106                    ironrdp_core::encode_buf(&X224(attach_user_confirm), output).map_err(ConnectorError::encode)?;
107
108                let next_state = match self.channel_ids.take() {
109                    Some(channel_ids) => ChannelConnectionState::WaitChannelJoinRequest { remaining: channel_ids },
110                    None => ChannelConnectionState::AllJoined,
111                };
112
113                (Written::from_size(written)?, next_state)
114            }
115
116            ChannelConnectionState::WaitChannelJoinRequest { mut remaining } => {
117                let channel_request = ironrdp_core::decode::<X224<mcs::ChannelJoinRequest>>(input)
118                    .map_err(ConnectorError::decode)
119                    .map(|p| p.0)?;
120
121                debug!(message = ?channel_request, "Received");
122
123                let is_expected = remaining.remove(&channel_request.channel_id);
124
125                if !is_expected {
126                    return Err(reason_err!(
127                        "ChannelJoinConfirm",
128                        "unexpected channel_id in MCS Channel Join Request: got {}, expected one of: {:?}",
129                        channel_request.channel_id,
130                        remaining,
131                    ));
132                }
133
134                (
135                    Written::Nothing,
136                    ChannelConnectionState::SendChannelJoinConfirm {
137                        remaining,
138                        channel_id: channel_request.channel_id,
139                    },
140                )
141            }
142
143            ChannelConnectionState::SendChannelJoinConfirm { remaining, channel_id } => {
144                let channel_confirm = mcs::ChannelJoinConfirm {
145                    result: 0,
146                    initiator_id: self.user_channel_id,
147                    requested_channel_id: channel_id,
148                    channel_id,
149                };
150
151                debug!(message = ?channel_confirm, "Send");
152
153                let written =
154                    ironrdp_core::encode_buf(&X224(channel_confirm), output).map_err(ConnectorError::encode)?;
155
156                let next_state = if remaining.is_empty() {
157                    ChannelConnectionState::AllJoined
158                } else {
159                    ChannelConnectionState::WaitChannelJoinRequest { remaining }
160                };
161
162                (Written::from_size(written)?, next_state)
163            }
164
165            _ => unreachable!(),
166        };
167
168        self.state = next_state;
169        Ok(written)
170    }
171}
172
173impl ChannelConnectionSequence {
174    pub fn new(user_channel_id: u16, io_channel_id: u16, other_channels: Vec<u16>) -> Self {
175        Self {
176            state: ChannelConnectionState::WaitErectDomainRequest,
177            user_channel_id,
178            channel_ids: Some(
179                vec![user_channel_id, io_channel_id]
180                    .into_iter()
181                    .chain(other_channels)
182                    .collect(),
183            ),
184        }
185    }
186
187    pub fn skip_channel_join(user_channel_id: u16) -> Self {
188        Self {
189            state: ChannelConnectionState::WaitErectDomainRequest,
190            user_channel_id,
191            channel_ids: None,
192        }
193    }
194
195    pub fn is_done(&self) -> bool {
196        self.state.is_terminal()
197    }
198}