ironrdp_acceptor/
channel_connection.rs1use std::collections::HashSet;
2
3use ironrdp_connector::{
4 reason_err, ConnectorError, ConnectorErrorExt as _, ConnectorResult, Sequence, State, Written,
5};
6use ironrdp_core::WriteBuf;
7use ironrdp_pdu::mcs;
8use ironrdp_pdu::x224::X224;
9use tracing::debug;
10
11#[derive(Debug)]
12pub struct ChannelConnectionSequence {
13 state: ChannelConnectionState,
14 user_channel_id: u16,
15 channel_ids: Option<HashSet<u16>>,
16}
17
/// States traversed by the channel connection sequence.
#[derive(Default, Debug)]
pub enum ChannelConnectionState {
    /// Placeholder left behind while a step is in progress (the state is
    /// moved out with `mem::take`); also the `Default` value.
    #[default]
    Consumed,

    /// Waiting for the peer's Erect Domain Request.
    WaitErectDomainRequest,
    /// Waiting for the peer's Attach User Request.
    WaitAttachUserRequest,
    /// Ready to write the Attach User Confirm.
    SendAttachUserConfirm,
    /// Waiting for a Channel Join Request.
    WaitChannelJoinRequest {
        /// Channel IDs that have not been requested yet.
        remaining: HashSet<u16>,
    },
    /// Ready to write the Channel Join Confirm for `channel_id`.
    SendChannelJoinConfirm {
        /// Channel IDs still to be requested after this confirm.
        remaining: HashSet<u16>,
        /// Channel ID taken from the request being confirmed.
        channel_id: u16,
    },
    /// Terminal state: nothing left to exchange.
    AllJoined,
}
35
36impl State for ChannelConnectionState {
37 fn name(&self) -> &'static str {
38 match self {
39 Self::Consumed => "Consumed",
40 Self::WaitErectDomainRequest => "WaitErectDomainRequest",
41 Self::WaitAttachUserRequest => "WaitAttachUserRequest",
42 Self::SendAttachUserConfirm => "SendAttachUserConfirm",
43 Self::WaitChannelJoinRequest { .. } => "WaitChannelJoinRequest",
44 Self::SendChannelJoinConfirm { .. } => "SendChannelJoinConfirm",
45 Self::AllJoined { .. } => "AllJoined",
46 }
47 }
48
49 fn is_terminal(&self) -> bool {
50 matches!(self, Self::AllJoined { .. })
51 }
52
53 fn as_any(&self) -> &dyn core::any::Any {
54 self
55 }
56}
57
58impl Sequence for ChannelConnectionSequence {
59 fn next_pdu_hint(&self) -> Option<&dyn ironrdp_pdu::PduHint> {
60 match &self.state {
61 ChannelConnectionState::Consumed => None,
62 ChannelConnectionState::WaitErectDomainRequest => Some(&ironrdp_pdu::X224_HINT),
63 ChannelConnectionState::WaitAttachUserRequest => Some(&ironrdp_pdu::X224_HINT),
64 ChannelConnectionState::SendAttachUserConfirm => None,
65 ChannelConnectionState::WaitChannelJoinRequest { .. } => Some(&ironrdp_pdu::X224_HINT),
66 ChannelConnectionState::SendChannelJoinConfirm { .. } => None,
67 ChannelConnectionState::AllJoined => None,
68 }
69 }
70
71 fn state(&self) -> &dyn State {
72 &self.state
73 }
74
75 fn step(&mut self, input: &[u8], output: &mut WriteBuf) -> ConnectorResult<Written> {
76 let (written, next_state) = match core::mem::take(&mut self.state) {
77 ChannelConnectionState::WaitErectDomainRequest => {
78 let erect_domain_request = ironrdp_core::decode::<X224<mcs::ErectDomainPdu>>(input)
79 .map_err(ConnectorError::decode)
80 .map(|p| p.0)?;
81
82 debug!(message = ?erect_domain_request, "Received");
83
84 (Written::Nothing, ChannelConnectionState::WaitAttachUserRequest)
85 }
86
87 ChannelConnectionState::WaitAttachUserRequest => {
88 let attach_user_request = ironrdp_core::decode::<X224<mcs::AttachUserRequest>>(input)
89 .map_err(ConnectorError::decode)
90 .map(|p| p.0)?;
91
92 debug!(message = ?attach_user_request, "Received");
93
94 (Written::Nothing, ChannelConnectionState::SendAttachUserConfirm)
95 }
96
97 ChannelConnectionState::SendAttachUserConfirm => {
98 let attach_user_confirm = mcs::AttachUserConfirm {
99 result: 0,
100 initiator_id: self.user_channel_id,
101 };
102
103 debug!(message = ?attach_user_confirm, "Send");
104
105 let written =
106 ironrdp_core::encode_buf(&X224(attach_user_confirm), output).map_err(ConnectorError::encode)?;
107
108 let next_state = match self.channel_ids.take() {
109 Some(channel_ids) => ChannelConnectionState::WaitChannelJoinRequest { remaining: channel_ids },
110 None => ChannelConnectionState::AllJoined,
111 };
112
113 (Written::from_size(written)?, next_state)
114 }
115
116 ChannelConnectionState::WaitChannelJoinRequest { mut remaining } => {
117 let channel_request = ironrdp_core::decode::<X224<mcs::ChannelJoinRequest>>(input)
118 .map_err(ConnectorError::decode)
119 .map(|p| p.0)?;
120
121 debug!(message = ?channel_request, "Received");
122
123 let is_expected = remaining.remove(&channel_request.channel_id);
124
125 if !is_expected {
126 return Err(reason_err!(
127 "ChannelJoinConfirm",
128 "unexpected channel_id in MCS Channel Join Request: got {}, expected one of: {:?}",
129 channel_request.channel_id,
130 remaining,
131 ));
132 }
133
134 (
135 Written::Nothing,
136 ChannelConnectionState::SendChannelJoinConfirm {
137 remaining,
138 channel_id: channel_request.channel_id,
139 },
140 )
141 }
142
143 ChannelConnectionState::SendChannelJoinConfirm { remaining, channel_id } => {
144 let channel_confirm = mcs::ChannelJoinConfirm {
145 result: 0,
146 initiator_id: self.user_channel_id,
147 requested_channel_id: channel_id,
148 channel_id,
149 };
150
151 debug!(message = ?channel_confirm, "Send");
152
153 let written =
154 ironrdp_core::encode_buf(&X224(channel_confirm), output).map_err(ConnectorError::encode)?;
155
156 let next_state = if remaining.is_empty() {
157 ChannelConnectionState::AllJoined
158 } else {
159 ChannelConnectionState::WaitChannelJoinRequest { remaining }
160 };
161
162 (Written::from_size(written)?, next_state)
163 }
164
165 _ => unreachable!(),
166 };
167
168 self.state = next_state;
169 Ok(written)
170 }
171}
172
173impl ChannelConnectionSequence {
174 pub fn new(user_channel_id: u16, io_channel_id: u16, other_channels: Vec<u16>) -> Self {
175 Self {
176 state: ChannelConnectionState::WaitErectDomainRequest,
177 user_channel_id,
178 channel_ids: Some(
179 vec![user_channel_id, io_channel_id]
180 .into_iter()
181 .chain(other_channels)
182 .collect(),
183 ),
184 }
185 }
186
187 pub fn skip_channel_join(user_channel_id: u16) -> Self {
188 Self {
189 state: ChannelConnectionState::WaitErectDomainRequest,
190 user_channel_id,
191 channel_ids: None,
192 }
193 }
194
195 pub fn is_done(&self) -> bool {
196 self.state.is_terminal()
197 }
198}