1#![cfg_attr(doc, doc = include_str!("../README.md"))]
2#![doc(html_logo_url = "https://cdnweb.devolutions.net/images/projects/devolutions/logos/devolutions-icon-shadow.svg")]
3
4extern crate alloc;
7
8use alloc::collections::BTreeMap;
9use core::any::TypeId;
10use core::fmt;
11use core::marker::PhantomData;
12use std::borrow::Cow;
13
14use bitflags::bitflags;
15use ironrdp_core::{
16 assert_obj_safe, decode_cursor, encode_buf, AsAny, DecodeResult, Encode, EncodeResult, ReadCursor, WriteBuf,
17 WriteCursor,
18};
19use ironrdp_pdu::gcc::{ChannelDef, ChannelName, ChannelOptions};
20use ironrdp_pdu::rdp::vc::ChannelControlFlags;
21use ironrdp_pdu::x224::X224;
22use ironrdp_pdu::{decode_err, mcs, PduResult};
23
24#[rustfmt::skip] pub use ironrdp_pdu as pdu;
27
28#[doc(hidden)]
30#[deprecated(since = "0.1.0", note = "use ironrdp-core")]
31#[rustfmt::skip] pub use ironrdp_core::impl_as_any;
33
/// Numeric identifier of a static virtual channel.
///
/// A channel registered in a [`StaticChannelSet`] is bound to one of these with
/// [`StaticChannelSet::attach_channel_id`]; the same `u16` is used as the MCS `channel_id` when encoding
/// outgoing data (see `client_encode_svc_messages` / `server_encode_svc_messages`).
pub type StaticChannelId = u16;
39
40pub struct SvcProcessorMessages<P: SvcProcessor> {
43 messages: Vec<SvcMessage>,
44 _channel: PhantomData<P>,
45}
46
47impl<P: SvcProcessor> SvcProcessorMessages<P> {
48 pub fn new(messages: Vec<SvcMessage>) -> Self {
49 Self {
50 messages,
51 _channel: PhantomData,
52 }
53 }
54}
55
56impl<P: SvcProcessor> From<Vec<SvcMessage>> for SvcProcessorMessages<P> {
57 fn from(messages: Vec<SvcMessage>) -> Self {
58 Self::new(messages)
59 }
60}
61
62impl<P: SvcProcessor> From<SvcProcessorMessages<P>> for Vec<SvcMessage> {
63 fn from(request: SvcProcessorMessages<P>) -> Self {
64 request.messages
65 }
66}
67
/// A PDU that can be queued on a static virtual channel: anything that is [`Encode`] and [`Send`].
///
/// Values of implementing types become [`SvcMessage`]s through `SvcMessage::from`.
pub trait SvcEncode: Encode + Send {}

// Allows raw byte buffers to be sent; presumably their `Encode` impl writes the bytes verbatim
// (NOTE(review): confirm against ironrdp_core's `Encode for Vec<u8>`).
impl SvcEncode for Vec<u8> {}
76
77pub struct SvcMessage {
81 pdu: Box<dyn SvcEncode>,
82 flags: ChannelFlags,
83}
84
85impl fmt::Debug for SvcMessage {
86 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87 f.debug_struct("SvcMessage")
88 .field("pdu", &self.pdu.name())
89 .field("flags", &self.flags)
90 .finish()
91 }
92}
93
94impl SvcMessage {
95 #[must_use]
97 pub fn with_flags(mut self, flags: ChannelFlags) -> Self {
98 self.flags |= flags;
99 self
100 }
101}
102
103impl<T> From<T> for SvcMessage
104where
105 T: SvcEncode + 'static,
106{
107 fn from(pdu: T) -> Self {
108 Self {
109 pdu: Box::new(pdu),
110 flags: ChannelFlags::empty(),
111 }
112 }
113}
114
/// When data on a static virtual channel may be compressed.
///
/// Translated into the channel's [`ChannelOptions`] by [`make_channel_options`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionCondition {
    /// Never compress; maps to `ChannelOptions::empty()`.
    Never,
    /// Compress only when RDP data is compressed; maps to `ChannelOptions::COMPRESS_RDP`.
    WhenRdpDataIsCompressed,
    /// Always compress; maps to `ChannelOptions::COMPRESS`.
    Always,
}
125
126#[derive(Debug)]
128pub struct StaticVirtualChannel {
129 channel_processor: Box<dyn SvcProcessor>,
130 chunk_processor: ChunkProcessor,
131}
132
133impl StaticVirtualChannel {
134 pub fn new<T: SvcProcessor + 'static>(channel_processor: T) -> Self {
135 Self {
136 channel_processor: Box::new(channel_processor),
137 chunk_processor: ChunkProcessor::new(),
138 }
139 }
140
141 pub fn channel_name(&self) -> ChannelName {
142 self.channel_processor.channel_name()
143 }
144
145 pub fn compression_condition(&self) -> CompressionCondition {
146 self.channel_processor.compression_condition()
147 }
148
149 pub fn start(&mut self) -> PduResult<Vec<SvcMessage>> {
150 self.channel_processor.start()
151 }
152
153 pub fn process(&mut self, payload: &[u8]) -> PduResult<Vec<SvcMessage>> {
156 if let Some(payload) = self.dechunkify(payload).map_err(|e| decode_err!(e))? {
157 return self.channel_processor.process(&payload);
158 }
159
160 Ok(Vec::new())
161 }
162
163 pub fn chunkify(messages: Vec<SvcMessage>) -> EncodeResult<Vec<WriteBuf>> {
164 ChunkProcessor::chunkify(messages, CHANNEL_CHUNK_LENGTH)
165 }
166
167 pub fn channel_processor_downcast_ref<T: SvcProcessor + 'static>(&self) -> Option<&T> {
168 self.channel_processor.as_any().downcast_ref()
169 }
170
171 pub fn channel_processor_downcast_mut<T: SvcProcessor + 'static>(&mut self) -> Option<&mut T> {
172 self.channel_processor.as_any_mut().downcast_mut()
173 }
174
175 fn dechunkify(&mut self, payload: &[u8]) -> DecodeResult<Option<Vec<u8>>> {
176 self.chunk_processor.dechunkify(payload)
177 }
178}
179
180fn encode_svc_messages(
181 messages: Vec<SvcMessage>,
182 channel_id: u16,
183 initiator_id: u16,
184 client: bool,
185) -> EncodeResult<Vec<u8>> {
186 let mut fully_encoded_responses = WriteBuf::new(); let chunks = StaticVirtualChannel::chunkify(messages)?;
190
191 if client {
200 for chunk in chunks {
201 let pdu = mcs::SendDataRequest {
202 initiator_id,
203 channel_id,
204 user_data: Cow::Borrowed(chunk.filled()),
205 };
206 encode_buf(&X224(pdu), &mut fully_encoded_responses)?;
207 }
208 } else {
209 for chunk in chunks {
210 let pdu = mcs::SendDataIndication {
211 initiator_id,
212 channel_id,
213 user_data: Cow::Borrowed(chunk.filled()),
214 };
215 encode_buf(&X224(pdu), &mut fully_encoded_responses)?;
216 }
217 }
218
219 Ok(fully_encoded_responses.into_inner())
220}
221
222pub fn client_encode_svc_messages(
229 messages: Vec<SvcMessage>,
230 channel_id: u16,
231 initiator_id: u16,
232) -> EncodeResult<Vec<u8>> {
233 encode_svc_messages(messages, channel_id, initiator_id, true)
234}
235
236pub fn server_encode_svc_messages(
243 messages: Vec<SvcMessage>,
244 channel_id: u16,
245 initiator_id: u16,
246) -> EncodeResult<Vec<u8>> {
247 encode_svc_messages(messages, channel_id, initiator_id, false)
248}
249
250pub trait SvcProcessor: AsAny + fmt::Debug + Send {
257 fn channel_name(&self) -> ChannelName;
259
260 fn compression_condition(&self) -> CompressionCondition {
262 CompressionCondition::Never
263 }
264
265 fn start(&mut self) -> PduResult<Vec<SvcMessage>> {
269 Ok(Vec::new())
270 }
271
272 fn process(&mut self, payload: &[u8]) -> PduResult<Vec<SvcMessage>>;
277}
278
279assert_obj_safe!(SvcProcessor);
280
/// Marker trait (no extra methods) for [`SvcProcessor`]s that, by name, are intended for the client side.
pub trait SvcClientProcessor: SvcProcessor {}

// Compile-time check that the trait stays usable as `dyn SvcClientProcessor`.
assert_obj_safe!(SvcClientProcessor);

/// Marker trait (no extra methods) for [`SvcProcessor`]s that, by name, are intended for the server side.
pub trait SvcServerProcessor: SvcProcessor {}

// Compile-time check that the trait stays usable as `dyn SvcServerProcessor`.
assert_obj_safe!(SvcServerProcessor);
288
289#[derive(Debug)]
291struct ChunkProcessor {
292 chunked_pdu: Vec<u8>,
295}
296
297impl ChunkProcessor {
298 fn new() -> Self {
299 Self {
300 chunked_pdu: Vec::new(),
301 }
302 }
303
304 fn chunkify(messages: Vec<SvcMessage>, max_chunk_len: usize) -> EncodeResult<Vec<WriteBuf>> {
308 let mut results = Vec::new();
309 for message in messages {
310 results.extend(Self::chunkify_one(message, max_chunk_len)?);
311 }
312 Ok(results)
313 }
314
315 fn dechunkify(&mut self, payload: &[u8]) -> DecodeResult<Option<Vec<u8>>> {
321 let mut cursor = ReadCursor::new(payload);
322 let last = Self::process_header(&mut cursor)?;
323
324 self.chunked_pdu.extend_from_slice(cursor.remaining());
326
327 if last {
329 return Ok(Some(core::mem::take(&mut self.chunked_pdu)));
331 }
332
333 Ok(None)
335 }
336
337 fn process_header(payload: &mut ReadCursor<'_>) -> DecodeResult<bool> {
339 let channel_header: ironrdp_pdu::rdp::vc::ChannelPduHeader = decode_cursor(payload)?;
340
341 Ok(channel_header.flags.contains(ChannelControlFlags::FLAG_LAST))
342 }
343
344 fn chunkify_one(message: SvcMessage, max_chunk_len: usize) -> EncodeResult<Vec<WriteBuf>> {
353 let mut encoded_pdu = WriteBuf::new(); encode_buf(message.pdu.as_ref(), &mut encoded_pdu)?;
355
356 let mut chunks = Vec::new();
357
358 let total_len = encoded_pdu.filled_len();
359 let mut chunk_start_index: usize = 0;
360 let mut chunk_end_index = core::cmp::min(total_len, max_chunk_len);
361 loop {
362 let mut chunk = WriteBuf::new();
367
368 let first = chunk_start_index == 0;
370 let last = chunk_end_index == total_len;
371
372 let header = {
374 let mut flags = ChannelFlags::empty();
375 if first {
376 flags |= ChannelFlags::FIRST;
377 }
378 if last {
379 flags |= ChannelFlags::LAST;
380 }
381
382 flags |= message.flags;
383
384 ChannelPduHeader {
385 length: ironrdp_core::cast_int!(ChannelPduHeader::NAME, "length", total_len)?,
386 flags,
387 }
388 };
389
390 encode_buf(&header, &mut chunk)?;
392 chunk.write_slice(&encoded_pdu[chunk_start_index..chunk_end_index]);
394 chunks.push(chunk);
396
397 if last {
399 break;
400 }
401
402 chunk_start_index = chunk_end_index;
404 chunk_end_index = core::cmp::min(total_len, chunk_end_index.saturating_add(max_chunk_len));
405 }
406
407 Ok(chunks)
408 }
409}
410
411impl Default for ChunkProcessor {
412 fn default() -> Self {
413 Self::new()
414 }
415}
416
417pub fn make_channel_options(channel: &StaticVirtualChannel) -> ChannelOptions {
419 match channel.compression_condition() {
420 CompressionCondition::Never => ChannelOptions::empty(),
421 CompressionCondition::WhenRdpDataIsCompressed => ChannelOptions::COMPRESS_RDP,
422 CompressionCondition::Always => ChannelOptions::COMPRESS,
423 }
424}
425
426pub fn make_channel_definition(channel: &StaticVirtualChannel) -> ChannelDef {
428 let name = channel.channel_name();
429 let options = make_channel_options(channel);
430 ChannelDef { name, options }
431}
432
433#[derive(Debug)]
447pub struct StaticChannelSet {
448 channels: BTreeMap<TypeId, StaticVirtualChannel>,
449 to_channel_id: BTreeMap<TypeId, StaticChannelId>,
450 to_type_id: BTreeMap<StaticChannelId, TypeId>,
451}
452
453impl StaticChannelSet {
454 #[inline]
455 pub fn new() -> Self {
456 Self {
457 channels: BTreeMap::new(),
458 to_channel_id: BTreeMap::new(),
459 to_type_id: BTreeMap::new(),
460 }
461 }
462
463 pub fn insert<T: SvcProcessor + 'static>(&mut self, val: T) -> Option<StaticVirtualChannel> {
467 self.channels.insert(TypeId::of::<T>(), StaticVirtualChannel::new(val))
468 }
469
470 pub fn get_by_type_id(&self, type_id: TypeId) -> Option<&StaticVirtualChannel> {
472 self.channels.get(&type_id)
473 }
474
475 pub fn get_by_type_id_mut(&mut self, type_id: TypeId) -> Option<&mut StaticVirtualChannel> {
477 self.channels.get_mut(&type_id)
478 }
479
480 pub fn get_by_type<T: SvcProcessor + 'static>(&self) -> Option<&StaticVirtualChannel> {
482 self.get_by_type_id(TypeId::of::<T>())
483 }
484
485 pub fn get_by_type_mut<T: SvcProcessor + 'static>(&mut self) -> Option<&mut StaticVirtualChannel> {
487 self.get_by_type_id_mut(TypeId::of::<T>())
488 }
489
490 pub fn get_by_channel_name(&self, name: &ChannelName) -> Option<(TypeId, &StaticVirtualChannel)> {
492 self.iter().find(|(_, x)| x.channel_processor.channel_name() == *name)
493 }
494
495 pub fn get_by_channel_id(&self, channel_id: StaticChannelId) -> Option<&StaticVirtualChannel> {
497 self.get_type_id_by_channel_id(channel_id)
498 .and_then(|type_id| self.get_by_type_id(type_id))
499 }
500
501 pub fn get_by_channel_id_mut(&mut self, channel_id: StaticChannelId) -> Option<&mut StaticVirtualChannel> {
503 self.get_type_id_by_channel_id(channel_id)
504 .and_then(|type_id| self.get_by_type_id_mut(type_id))
505 }
506
507 pub fn remove_by_type_id(&mut self, type_id: TypeId) -> Option<StaticVirtualChannel> {
511 let svc = self.channels.remove(&type_id);
512 if let Some(channel_id) = self.to_channel_id.remove(&type_id) {
513 self.to_type_id.remove(&channel_id);
514 }
515 svc
516 }
517
518 pub fn remove_by_type<T: SvcProcessor + 'static>(&mut self) -> Option<StaticVirtualChannel> {
522 let type_id = TypeId::of::<T>();
523 self.remove_by_type_id(type_id)
524 }
525
526 pub fn attach_channel_id(&mut self, type_id: TypeId, channel_id: StaticChannelId) -> Option<StaticChannelId> {
530 self.to_type_id.insert(channel_id, type_id);
531 self.to_channel_id.insert(type_id, channel_id)
532 }
533
534 pub fn get_channel_id_by_type_id(&self, type_id: TypeId) -> Option<StaticChannelId> {
536 self.to_channel_id.get(&type_id).copied()
537 }
538
539 pub fn get_channel_id_by_type<T: SvcProcessor + 'static>(&self) -> Option<StaticChannelId> {
541 self.get_channel_id_by_type_id(TypeId::of::<T>())
542 }
543
544 pub fn get_type_id_by_channel_id(&self, channel_id: StaticChannelId) -> Option<TypeId> {
546 self.to_type_id.get(&channel_id).copied()
547 }
548
549 pub fn detach_channel_id(&mut self, type_id: TypeId) -> Option<StaticChannelId> {
551 if let Some(channel_id) = self.to_channel_id.remove(&type_id) {
552 self.to_type_id.remove(&channel_id);
553 Some(channel_id)
554 } else {
555 None
556 }
557 }
558
559 #[inline]
560 pub fn iter(&self) -> impl Iterator<Item = (TypeId, &StaticVirtualChannel)> {
561 self.channels.iter().map(|(type_id, svc)| (*type_id, svc))
562 }
563
564 #[inline]
565 pub fn iter_mut(&mut self) -> impl Iterator<Item = (TypeId, &mut StaticVirtualChannel, Option<StaticChannelId>)> {
566 let to_channel_id = self.to_channel_id.clone();
567 self.channels
568 .iter_mut()
569 .map(move |(type_id, svc)| (*type_id, svc, to_channel_id.get(type_id).copied()))
570 }
571
572 #[inline]
573 pub fn values(&self) -> impl Iterator<Item = &StaticVirtualChannel> {
574 self.channels.values()
575 }
576
577 #[inline]
578 pub fn type_ids(&self) -> impl Iterator<Item = TypeId> + '_ {
579 self.channels.keys().copied()
580 }
581
582 #[inline]
583 pub fn channel_ids(&self) -> impl Iterator<Item = StaticChannelId> + '_ {
584 self.to_channel_id.values().copied()
585 }
586
587 #[inline]
588 pub fn clear(&mut self) {
589 self.channels.clear();
590 self.to_channel_id.clear();
591 self.to_type_id.clear();
592 }
593}
594
595impl Default for StaticChannelSet {
596 fn default() -> Self {
597 Self::new()
598 }
599}
600
/// Maximum number of payload bytes placed in a single chunk by [`StaticVirtualChannel::chunkify`].
///
/// The chunk's channel PDU header is written in addition to this payload, so an encoded chunk can be
/// up to `CHANNEL_CHUNK_LENGTH + 8` bytes. (Presumably mirrors the MS-RDPBCGR default chunk length —
/// NOTE(review): confirm no negotiated chunk size is expected to override it.)
pub const CHANNEL_CHUNK_LENGTH: usize = 1600;
611
bitflags! {
    /// Flags carried in the channel PDU header of each chunk sent on a static virtual channel.
    ///
    /// `FIRST` and `LAST` mark a chunk's position and are set during chunking. Extra flags can be attached
    /// to a message with `SvcMessage::with_flags`; they are ORed into the headers of its chunks.
    /// Values follow the MS-RDPBCGR `CHANNEL_PDU_HEADER` flags.
    #[derive(Debug, PartialEq, Copy, Clone)]
    pub struct ChannelFlags: u32 {
        /// First chunk of a PDU (`CHANNEL_FLAG_FIRST`).
        const FIRST = 0x0000_0001;
        /// Last chunk of a PDU (`CHANNEL_FLAG_LAST`); a single-chunk PDU has both `FIRST` and `LAST`.
        const LAST = 0x0000_0002;
        /// `CHANNEL_FLAG_SHOW_PROTOCOL`.
        const SHOW_PROTOCOL = 0x0000_0010;
        /// `CHANNEL_FLAG_SUSPEND`.
        const SUSPEND = 0x0000_0020;
        /// `CHANNEL_FLAG_RESUME`.
        const RESUME = 0x0000_0040;
        /// `CHANNEL_FLAG_SHADOW_PERSISTENT`.
        const SHADOW_PERSISTENT = 0x0000_0080;
        /// `CHANNEL_PACKET_COMPRESSED`: the chunk data is compressed.
        const COMPRESSED = 0x0020_0000;
        /// `CHANNEL_PACKET_AT_FRONT` (bulk-compression history control).
        const AT_FRONT = 0x0040_0000;
        /// `CHANNEL_PACKET_FLUSHED` (bulk-compression history control).
        const FLUSHED = 0x0080_0000;
    }
}
638
639#[derive(Debug)]
648struct ChannelPduHeader {
649 length: u32,
655 flags: ChannelFlags,
656}
657
658impl ChannelPduHeader {
659 const NAME: &'static str = "CHANNEL_PDU_HEADER";
660
661 const FIXED_PART_SIZE: usize = 4 + 4 ;
662}
663
664impl Encode for ChannelPduHeader {
665 fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
666 dst.write_u32(self.length);
667 dst.write_u32(self.flags.bits());
668 Ok(())
669 }
670
671 fn name(&self) -> &'static str {
672 Self::NAME
673 }
674
675 fn size(&self) -> usize {
676 Self::FIXED_PART_SIZE
677 }
678}