1use core::net::SocketAddr;
2use std::rc::Rc;
3use std::sync::Arc;
4
5use anyhow::{anyhow, bail, Context as _, Result};
6use ironrdp_acceptor::{Acceptor, AcceptorResult, BeginResult, DesktopSize};
7use ironrdp_async::Framed;
8use ironrdp_cliprdr::backend::ClipboardMessage;
9use ironrdp_cliprdr::CliprdrServer;
10use ironrdp_core::{decode, encode_vec, impl_as_any};
11use ironrdp_displaycontrol::pdu::DisplayControlMonitorLayout;
12use ironrdp_displaycontrol::server::{DisplayControlHandler, DisplayControlServer};
13use ironrdp_pdu::input::fast_path::{FastPathInput, FastPathInputEvent};
14use ironrdp_pdu::input::InputEventPdu;
15use ironrdp_pdu::mcs::{SendDataIndication, SendDataRequest};
16use ironrdp_pdu::rdp::capability_sets::{BitmapCodecs, CapabilitySet, CmdFlags, CodecProperty, GeneralExtraFlags};
17pub use ironrdp_pdu::rdp::client_info::Credentials;
18use ironrdp_pdu::rdp::headers::{ServerDeactivateAll, ShareControlPdu};
19use ironrdp_pdu::x224::X224;
20use ironrdp_pdu::{decode_err, mcs, nego, rdp, Action, PduResult};
21use ironrdp_svc::{server_encode_svc_messages, StaticChannelId, StaticChannelSet, SvcProcessor};
22use ironrdp_tokio::{split_tokio_framed, unsplit_tokio_framed, FramedRead, FramedWrite, TokioFramed};
23use rdpsnd::server::{RdpsndServer, RdpsndServerMessage};
24use tokio::io::{AsyncRead, AsyncWrite};
25use tokio::net::{TcpListener, TcpStream};
26use tokio::sync::{mpsc, oneshot, Mutex};
27use tokio::task;
28use tokio_rustls::TlsAcceptor;
29use tracing::{debug, error, trace, warn};
30use {ironrdp_dvc as dvc, ironrdp_rdpsnd as rdpsnd};
31
32use crate::clipboard::CliprdrServerFactory;
33use crate::display::{DisplayUpdate, RdpServerDisplay};
34use crate::encoder::{UpdateEncoder, UpdateEncoderCodecs};
35use crate::handler::RdpServerInputHandler;
36use crate::{builder, capabilities, SoundServerFactory};
37
/// Static configuration for an [`RdpServer`].
#[derive(Clone)]
pub struct RdpServerOptions {
    /// Address the listener binds to in [`RdpServer::run`].
    pub addr: SocketAddr,
    /// Transport security mode: none, TLS, or TLS followed by CredSSP.
    pub security: RdpServerSecurity,
    /// Bitmap codecs configured on the server side. The RemoteFX / QOI codecs a
    /// client advertises are only enabled when a matching entry is present here.
    pub codecs: BitmapCodecs,
}
44
45impl RdpServerOptions {
46 fn has_image_remote_fx(&self) -> bool {
47 self.codecs
48 .0
49 .iter()
50 .any(|codec| matches!(codec.property, CodecProperty::ImageRemoteFx(_)))
51 }
52
53 fn has_remote_fx(&self) -> bool {
54 self.codecs
55 .0
56 .iter()
57 .any(|codec| matches!(codec.property, CodecProperty::RemoteFx(_)))
58 }
59
60 #[cfg(feature = "qoi")]
61 fn has_qoi(&self) -> bool {
62 self.codecs
63 .0
64 .iter()
65 .any(|codec| matches!(codec.property, CodecProperty::Qoi))
66 }
67
68 #[cfg(feature = "qoiz")]
69 fn has_qoiz(&self) -> bool {
70 self.codecs
71 .0
72 .iter()
73 .any(|codec| matches!(codec.property, CodecProperty::QoiZ))
74 }
75}
76
/// Security mode used when accepting a client.
#[derive(Clone)]
pub enum RdpServerSecurity {
    /// No security upgrade; no protocol flags are offered.
    None,
    /// Upgrade the connection to TLS using this acceptor.
    Tls(TlsAcceptor),
    /// Upgrade to TLS, then run CredSSP. The byte vector is passed as the public key
    /// to `accept_credssp`.
    Hybrid((TlsAcceptor, Vec<u8>)),
}
84
85impl RdpServerSecurity {
86 pub fn flag(&self) -> nego::SecurityProtocol {
87 match self {
88 RdpServerSecurity::None => nego::SecurityProtocol::empty(),
89 RdpServerSecurity::Tls(_) => nego::SecurityProtocol::SSL,
90 RdpServerSecurity::Hybrid(_) => nego::SecurityProtocol::HYBRID | nego::SecurityProtocol::HYBRID_EX,
91 }
92 }
93}
94
/// Dynamic virtual channel processor for the ainput channel.
///
/// Forwards decoded mouse PDUs to the shared input handler.
struct AInputHandler {
    /// Input handler shared with the rest of the server.
    handler: Arc<Mutex<Box<dyn RdpServerInputHandler>>>,
}
98
// `impl_as_any!` (from `ironrdp_core`) generates the any-casting boilerplate for
// `AInputHandler`; presumably required by the DVC processor traits — confirm.
impl_as_any!(AInputHandler);
100
101impl dvc::DvcProcessor for AInputHandler {
102 fn channel_name(&self) -> &str {
103 ironrdp_ainput::CHANNEL_NAME
104 }
105
106 fn start(&mut self, _channel_id: u32) -> PduResult<Vec<dvc::DvcMessage>> {
107 use ironrdp_ainput::{ServerPdu, VersionPdu};
108
109 let pdu = ServerPdu::Version(VersionPdu::default());
110
111 Ok(vec![Box::new(pdu)])
112 }
113
114 fn close(&mut self, _channel_id: u32) {}
115
116 fn process(&mut self, _channel_id: u32, payload: &[u8]) -> PduResult<Vec<dvc::DvcMessage>> {
117 use ironrdp_ainput::ClientPdu;
118
119 match decode(payload).map_err(|e| decode_err!(e))? {
120 ClientPdu::Mouse(pdu) => {
121 let handler = Arc::clone(&self.handler);
122 task::spawn_blocking(move || {
123 handler.blocking_lock().mouse(pdu.into());
124 });
125 }
126 }
127
128 Ok(Vec::new())
129 }
130}
131
// Relies entirely on the trait's default methods (the impl body is empty).
impl dvc::DvcServerProcessor for AInputHandler {}
133
/// Display-control channel backend that forwards monitor layout requests to the display.
struct DisplayControlBackend {
    /// Display shared with the rest of the server.
    display: Arc<Mutex<Box<dyn RdpServerDisplay>>>,
}
137
138impl DisplayControlBackend {
139 fn new(display: Arc<Mutex<Box<dyn RdpServerDisplay>>>) -> Self {
140 Self { display }
141 }
142}
143
144impl DisplayControlHandler for DisplayControlBackend {
145 fn monitor_layout(&self, layout: DisplayControlMonitorLayout) {
146 let display = Arc::clone(&self.display);
147 task::spawn_blocking(move || display.blocking_lock().request_layout(layout));
148 }
149}
150
/// RDP server that accepts TCP clients and serves them one at a time.
pub struct RdpServer {
    /// Listener address, security mode and codec configuration.
    opts: RdpServerOptions,
    /// Input handler, shared with the ainput dynamic channel.
    handler: Arc<Mutex<Box<dyn RdpServerInputHandler>>>,
    /// Display source: queried for its size and for display updates. Shared with the
    /// display-control backend.
    display: Arc<Mutex<Box<dyn RdpServerDisplay>>>,
    /// Static channels of the current connection. Replaced when a client is accepted
    /// and reset after each connection.
    static_channels: StaticChannelSet,
    /// Optional factory for the rdpsnd static channel backend.
    sound_factory: Option<Box<dyn SoundServerFactory>>,
    /// Optional factory for the clipboard static channel backend.
    cliprdr_factory: Option<Box<dyn CliprdrServerFactory>>,
    /// Sending half of the server event channel. Clones go to factories and callers.
    ev_sender: mpsc::UnboundedSender<ServerEvent>,
    /// Receiving half of the event channel. It sits behind a mutex because both `run`
    /// and the per-client loop lock it.
    ev_receiver: Arc<Mutex<mpsc::UnboundedReceiver<ServerEvent>>>,
    /// Credentials handed to the acceptor for new connections.
    creds: Option<Credentials>,
    /// Address the listener actually bound to. Set by `run`.
    local_addr: Option<SocketAddr>,
}
225
/// Events that can be sent to a running [`RdpServer`] through its event channel.
#[derive(Debug)]
pub enum ServerEvent {
    /// Stop: ends the listener loop when idle, or disconnects the current client.
    /// The string is a reason that is logged.
    Quit(String),
    /// Clipboard message to forward on the cliprdr static channel.
    Clipboard(ClipboardMessage),
    /// Audio message to forward on the rdpsnd static channel.
    Rdpsnd(RdpsndServerMessage),
    /// Replace the credentials used for subsequent connections.
    SetCredentials(Credentials),
    /// Ask for the bound local address; the reply is sent on the oneshot.
    GetLocalAddr(oneshot::Sender<Option<SocketAddr>>),
}
234
/// Something that can be given a handle to the server event channel.
///
/// `RdpServer::new` calls `set_sender` on the clipboard and sound factories.
pub trait ServerEventSender {
    /// Stores `sender` so the implementor can post [`ServerEvent`]s later.
    fn set_sender(&mut self, sender: mpsc::UnboundedSender<ServerEvent>);
}
238
239impl ServerEvent {
240 pub fn create_channel() -> (mpsc::UnboundedSender<Self>, mpsc::UnboundedReceiver<Self>) {
241 mpsc::unbounded_channel()
242 }
243}
244
/// Outcome of a dispatch step within a client session.
#[derive(Debug, PartialEq)]
enum RunState {
    /// Keep serving the client.
    Continue,
    /// End the session.
    Disconnect,
    /// Run a deactivation–reactivation sequence with the new desktop size.
    DeactivationReactivation { desktop_size: DesktopSize },
}
251
252impl RdpServer {
253 pub fn new(
254 opts: RdpServerOptions,
255 handler: Box<dyn RdpServerInputHandler>,
256 display: Box<dyn RdpServerDisplay>,
257 mut sound_factory: Option<Box<dyn SoundServerFactory>>,
258 mut cliprdr_factory: Option<Box<dyn CliprdrServerFactory>>,
259 ) -> Self {
260 let (ev_sender, ev_receiver) = ServerEvent::create_channel();
261 if let Some(cliprdr) = cliprdr_factory.as_mut() {
262 cliprdr.set_sender(ev_sender.clone());
263 }
264 if let Some(snd) = sound_factory.as_mut() {
265 snd.set_sender(ev_sender.clone());
266 }
267 Self {
268 opts,
269 handler: Arc::new(Mutex::new(handler)),
270 display: Arc::new(Mutex::new(display)),
271 static_channels: StaticChannelSet::new(),
272 sound_factory,
273 cliprdr_factory,
274 ev_sender,
275 ev_receiver: Arc::new(Mutex::new(ev_receiver)),
276 creds: None,
277 local_addr: None,
278 }
279 }
280
281 pub fn builder() -> builder::RdpServerBuilder<builder::WantsAddr> {
282 builder::RdpServerBuilder::new()
283 }
284
285 pub fn event_sender(&self) -> &mpsc::UnboundedSender<ServerEvent> {
286 &self.ev_sender
287 }
288
289 fn attach_channels(&mut self, acceptor: &mut Acceptor) {
290 if let Some(cliprdr_factory) = self.cliprdr_factory.as_deref() {
291 let backend = cliprdr_factory.build_cliprdr_backend();
292
293 let cliprdr = CliprdrServer::new(backend);
294
295 acceptor.attach_static_channel(cliprdr);
296 }
297
298 if let Some(factory) = self.sound_factory.as_deref() {
299 let backend = factory.build_backend();
300
301 acceptor.attach_static_channel(RdpsndServer::new(backend));
302 }
303
304 let dcs_backend = DisplayControlBackend::new(Arc::clone(&self.display));
305 let dvc = dvc::DrdynvcServer::new()
306 .with_dynamic_channel(AInputHandler {
307 handler: Arc::clone(&self.handler),
308 })
309 .with_dynamic_channel(DisplayControlServer::new(Box::new(dcs_backend)));
310 acceptor.attach_static_channel(dvc);
311 }
312
313 pub async fn run_connection(&mut self, stream: TcpStream) -> Result<()> {
314 let framed = TokioFramed::new(stream);
315
316 let size = self.display.lock().await.size().await;
317 let capabilities = capabilities::capabilities(&self.opts, size);
318 let mut acceptor = Acceptor::new(self.opts.security.flag(), size, capabilities, self.creds.clone());
319
320 self.attach_channels(&mut acceptor);
321
322 let res = ironrdp_acceptor::accept_begin(framed, &mut acceptor)
323 .await
324 .context("accept_begin failed")?;
325
326 match res {
327 BeginResult::ShouldUpgrade(stream) => {
328 let tls_acceptor = match &self.opts.security {
329 RdpServerSecurity::Tls(acceptor) => acceptor,
330 RdpServerSecurity::Hybrid((acceptor, _)) => acceptor,
331 RdpServerSecurity::None => unreachable!(),
332 };
333 let accept = match tls_acceptor.accept(stream).await {
334 Ok(accept) => accept,
335 Err(e) => {
336 warn!("Failed to TLS accept: {}", e);
337 return Ok(());
338 }
339 };
340 let mut framed = TokioFramed::new(accept);
341
342 acceptor.mark_security_upgrade_as_done();
343
344 if let RdpServerSecurity::Hybrid((_, pub_key)) = &self.opts.security {
345 let client_name = framed.get_inner().0.get_ref().0.peer_addr()?.to_string();
348
349 ironrdp_acceptor::accept_credssp(
350 &mut framed,
351 &mut acceptor,
352 client_name.into(),
353 pub_key.clone(),
354 None,
355 None,
356 )
357 .await?;
358 }
359
360 self.accept_finalize(framed, acceptor).await?;
361 }
362
363 BeginResult::Continue(framed) => {
364 self.accept_finalize(framed, acceptor).await?;
365 }
366 };
367
368 Ok(())
369 }
370
371 pub async fn run(&mut self) -> Result<()> {
372 let listener = TcpListener::bind(self.opts.addr).await?;
373 let local_addr = listener.local_addr()?;
374
375 debug!("Listening for connections on {local_addr}");
376 self.local_addr = Some(local_addr);
377
378 loop {
379 let ev_receiver = Arc::clone(&self.ev_receiver);
380 let mut ev_receiver = ev_receiver.lock().await;
381 tokio::select! {
382 Some(event) = ev_receiver.recv() => {
383 match event {
384 ServerEvent::Quit(reason) => {
385 debug!("Got quit event {reason}");
386 break;
387 }
388 ServerEvent::GetLocalAddr(tx) => {
389 let _ = tx.send(self.local_addr);
390 }
391 ServerEvent::SetCredentials(creds) => {
392 self.set_credentials(Some(creds));
393 }
394 ev => {
395 debug!("Unexpected event {:?}", ev);
396 }
397 }
398 },
399 Ok((stream, peer)) = listener.accept() => {
400 debug!(?peer, "Received connection");
401 drop(ev_receiver);
402 if let Err(error) = self.run_connection(stream).await {
403 error!(?error, "Connection error");
404 }
405 self.static_channels = StaticChannelSet::new();
406 }
407 else => break,
408 }
409 }
410
411 Ok(())
412 }
413
414 pub fn get_svc_processor<T: SvcProcessor + 'static>(&mut self) -> Option<&mut T> {
415 self.static_channels
416 .get_by_type_mut::<T>()
417 .and_then(|svc| svc.channel_processor_downcast_mut())
418 }
419
420 pub fn get_channel_id_by_type<T: SvcProcessor + 'static>(&self) -> Option<StaticChannelId> {
421 self.static_channels.get_channel_id_by_type::<T>()
422 }
423
424 async fn dispatch_pdu(
425 &mut self,
426 action: Action,
427 bytes: bytes::BytesMut,
428 writer: &mut impl FramedWrite,
429 io_channel_id: u16,
430 user_channel_id: u16,
431 ) -> Result<RunState> {
432 match action {
433 Action::FastPath => {
434 let input = decode(&bytes)?;
435 self.handle_fastpath(input).await;
436 }
437
438 Action::X224 => {
439 if self
440 .handle_x224(writer, io_channel_id, user_channel_id, &bytes)
441 .await
442 .context("X224 input error")?
443 {
444 debug!("Got disconnect request");
445 return Ok(RunState::Disconnect);
446 }
447 }
448 }
449
450 Ok(RunState::Continue)
451 }
452
453 async fn dispatch_display_update(
454 update: DisplayUpdate,
455 writer: &mut impl FramedWrite,
456 user_channel_id: u16,
457 io_channel_id: u16,
458 buffer: &mut Vec<u8>,
459 mut encoder: UpdateEncoder,
460 ) -> Result<(RunState, UpdateEncoder)> {
461 if let DisplayUpdate::Resize(desktop_size) = update {
462 debug!(?desktop_size, "Display resize");
463 encoder.set_desktop_size(desktop_size);
464 deactivate_all(io_channel_id, user_channel_id, writer).await?;
465 return Ok((RunState::DeactivationReactivation { desktop_size }, encoder));
466 }
467
468 let mut encoder_iter = encoder.update(update);
469 loop {
470 let Some(fragmenter) = encoder_iter.next().await else {
471 break;
472 };
473
474 let mut fragmenter = fragmenter.context("error while encoding")?;
475 if fragmenter.size_hint() > buffer.len() {
476 buffer.resize(fragmenter.size_hint(), 0);
477 }
478
479 while let Some(len) = fragmenter.next(buffer) {
480 writer
481 .write_all(&buffer[..len])
482 .await
483 .context("failed to write display update")?;
484 }
485 }
486
487 Ok((RunState::Continue, encoder))
488 }
489
    /// Dispatches a batch of [`ServerEvent`]s for the active connection.
    ///
    /// `events` is drained in arrival order. Rdpsnd and clipboard messages are handed
    /// to their static channel processors, encoded for that channel, and written to
    /// `writer`. If the corresponding channel is not present, the event is dropped
    /// with a warning.
    ///
    /// Returns [`RunState::Disconnect`] on a `Quit` event; the rest of the batch is
    /// then discarded. Otherwise returns [`RunState::Continue`].
    async fn dispatch_server_events(
        &mut self,
        events: &mut Vec<ServerEvent>,
        writer: &mut impl FramedWrite,
        user_channel_id: u16,
    ) -> Result<RunState> {
        // At most this many `Wave` messages are forwarded per batch; further waves in
        // the same batch are dropped.
        let mut wave_limit = 4;
        for event in events.drain(..) {
            trace!(?event, "Dispatching");
            match event {
                ServerEvent::Quit(reason) => {
                    debug!("Got quit event: {reason}");
                    return Ok(RunState::Disconnect);
                }
                ServerEvent::GetLocalAddr(tx) => {
                    let _ = tx.send(self.local_addr);
                }
                ServerEvent::SetCredentials(creds) => {
                    self.set_credentials(Some(creds));
                }
                ServerEvent::Rdpsnd(s) => {
                    let Some(rdpsnd) = self.get_svc_processor::<RdpsndServer>() else {
                        warn!("No rdpsnd channel, dropping event");
                        continue;
                    };
                    let msgs = match s {
                        RdpsndServerMessage::Wave(data, ts) => {
                            if wave_limit == 0 {
                                debug!("Dropping wave");
                                continue;
                            }
                            wave_limit -= 1;
                            rdpsnd.wave(data, ts)
                        }
                        RdpsndServerMessage::SetVolume { left, right } => rdpsnd.set_volume(left, right),
                        RdpsndServerMessage::Close => rdpsnd.close(),
                        RdpsndServerMessage::Error(error) => {
                            // Backend-reported errors are only logged; nothing is sent.
                            error!(?error, "Handling rdpsnd event");
                            continue;
                        }
                    }
                    .context("failed to send rdpsnd event")?;
                    let channel_id = self
                        .get_channel_id_by_type::<RdpsndServer>()
                        .ok_or_else(|| anyhow!("SVC channel not found"))?;
                    let data = server_encode_svc_messages(msgs.into(), channel_id, user_channel_id)?;
                    writer.write_all(&data).await?;
                }
                ServerEvent::Clipboard(c) => {
                    let Some(cliprdr) = self.get_svc_processor::<CliprdrServer>() else {
                        warn!("No clipboard channel, dropping event");
                        continue;
                    };
                    let msgs = match c {
                        ClipboardMessage::SendInitiateCopy(formats) => cliprdr.initiate_copy(&formats),
                        ClipboardMessage::SendFormatData(data) => cliprdr.submit_format_data(data),
                        ClipboardMessage::SendInitiatePaste(format) => cliprdr.initiate_paste(format),
                        ClipboardMessage::Error(error) => {
                            // Backend-reported errors are only logged; nothing is sent.
                            error!(?error, "Handling clipboard event");
                            continue;
                        }
                    }
                    .context("failed to send clipboard event")?;
                    let channel_id = self
                        .get_channel_id_by_type::<CliprdrServer>()
                        .ok_or_else(|| anyhow!("SVC channel not found"))?;
                    let data = server_encode_svc_messages(msgs.into(), channel_id, user_channel_id)?;
                    writer.write_all(&data).await?;
                }
            }
        }

        Ok(RunState::Continue)
    }
567
    /// Serves an accepted client until one of three concurrent activities finishes.
    ///
    /// The three activities are:
    /// - `dispatch_pdu`: reads client PDUs and handles them;
    /// - `dispatch_display`: streams display updates through the encoder;
    /// - `dispatch_events`: forwards batches of server events.
    ///
    /// They are raced with `tokio::select!`, and the first result (a terminal
    /// [`RunState`] or an error) is returned. The other futures are dropped at that
    /// point. `self` is shared between the PDU and event futures through
    /// `Rc<Mutex<_>>`, and the writer through [`SharedWriter`]. The display future uses
    /// only the associated function `Self::dispatch_display_update`, so it never locks
    /// `self`. The event receiver stays locked for the whole loop.
    async fn client_loop<R, W>(
        &mut self,
        reader: &mut Framed<R>,
        writer: &mut Framed<W>,
        io_channel_id: u16,
        user_channel_id: u16,
        mut encoder: UpdateEncoder,
    ) -> Result<RunState>
    where
        R: FramedRead,
        W: FramedWrite,
    {
        debug!("Starting client loop");
        let mut display_updates = self.display.lock().await.updates().await?;
        let mut writer = SharedWriter::new(writer);
        let mut display_writer = writer.clone();
        let mut event_writer = writer.clone();
        let ev_receiver = Arc::clone(&self.ev_receiver);
        let s = Rc::new(Mutex::new(self));

        let this = Rc::clone(&s);
        let dispatch_pdu = async move {
            loop {
                let (action, bytes) = reader.read_pdu().await?;
                let mut this = this.lock().await;
                match this
                    .dispatch_pdu(action, bytes, &mut writer, io_channel_id, user_channel_id)
                    .await?
                {
                    RunState::Continue => continue,
                    state => break Ok(state),
                }
            }
        };

        let dispatch_display = async move {
            // Scratch buffer reused across fragments; grown on demand by the dispatcher.
            let mut buffer = vec![0u8; 4096];

            loop {
                match display_updates.next_update().await {
                    Ok(Some(update)) => {
                        // The encoder is moved in and handed back on `Continue`.
                        match Self::dispatch_display_update(
                            update,
                            &mut display_writer,
                            user_channel_id,
                            io_channel_id,
                            &mut buffer,
                            encoder,
                        )
                        .await?
                        {
                            (RunState::Continue, enc) => {
                                encoder = enc;
                                continue;
                            }
                            (state, _) => {
                                break Ok(state);
                            }
                        }
                    }
                    Ok(None) => {
                        break Ok(RunState::Disconnect);
                    }
                    Err(error) => {
                        // Logged and retried. NOTE(review): a persistently failing
                        // `next_update` would make this loop spin — confirm intended.
                        warn!(error = format!("{error:#}"), "next_updated failed");
                    }
                }
            }
        };

        let this = Rc::clone(&s);
        let mut ev_receiver = ev_receiver.lock().await;
        let dispatch_events = async move {
            let mut events = Vec::with_capacity(100);
            loop {
                // Wait for up to 100 events, then greedily take whatever else is queued.
                let nevents = ev_receiver.recv_many(&mut events, 100).await;
                if nevents == 0 {
                    debug!("No sever events.. stopping");
                    break Ok(RunState::Disconnect);
                }
                while let Ok(ev) = ev_receiver.try_recv() {
                    events.push(ev);
                }
                let mut this = this.lock().await;
                match this
                    .dispatch_server_events(&mut events, &mut event_writer, user_channel_id)
                    .await?
                {
                    RunState::Continue => continue,
                    state => break Ok(state),
                }
            }
        };

        let state = tokio::select!(
            state = dispatch_pdu => state,
            state = dispatch_display => state,
            state = dispatch_events => state,
        );

        debug!("End of client loop: {state:?}");
        state
    }
671
    /// Serves a client after the acceptor has finished (or re-finished) the connection
    /// sequence.
    ///
    /// Steps, in order:
    /// 1. Replay any input captured during the acceptor sequence.
    /// 2. Install the negotiated static channels, starting them only on a first
    ///    activation.
    /// 3. Derive encoder settings from the client's capability sets.
    /// 4. Run [`Self::client_loop`] until it yields a terminal [`RunState`].
    ///
    /// # Errors
    ///
    /// Fails if:
    /// - the client does not advertise fast-path output;
    /// - starting or writing a static channel fails;
    /// - the update encoder cannot be created;
    /// - the client loop fails.
    async fn client_accepted<R, W>(
        &mut self,
        reader: &mut Framed<R>,
        writer: &mut Framed<W>,
        result: AcceptorResult,
    ) -> Result<RunState>
    where
        R: FramedRead,
        W: FramedWrite,
    {
        debug!("Client accepted");

        if !result.input_events.is_empty() {
            debug!("Handling input event backlog from acceptor sequence");
            self.handle_input_backlog(
                writer,
                result.io_channel_id,
                result.user_channel_id,
                result.input_events,
            )
            .await?;
        }

        self.static_channels = result.static_channels;
        // On reactivation, the channels carried over from the previous activation
        // (see `accept_finalize`) are not started again.
        if !result.reactivation {
            for (_type_id, channel, channel_id) in self.static_channels.iter_mut() {
                debug!(?channel, ?channel_id, "Start");
                // Channels without an assigned ID are skipped.
                let Some(channel_id) = channel_id else {
                    continue;
                };
                let svc_responses = channel.start()?;
                let response = server_encode_svc_messages(svc_responses, channel_id, result.user_channel_id)?;
                writer.write_all(&response).await?;
            }
        }

        let mut update_codecs = UpdateEncoderCodecs::new();
        let mut surface_flags = CmdFlags::empty();
        for c in result.capabilities {
            match c {
                CapabilitySet::General(c) => {
                    // Display updates require fast-path output; refuse clients without it.
                    let fastpath = c.extra_flags.contains(GeneralExtraFlags::FASTPATH_OUTPUT_SUPPORTED);
                    if !fastpath {
                        bail!("Fastpath output not supported!");
                    }
                }
                CapabilitySet::Bitmap(b) => {
                    if !b.desktop_resize_flag {
                        debug!("Desktop resize is not supported by the client");
                        continue;
                    }

                    let client_size = DesktopSize {
                        width: b.desktop_width,
                        height: b.desktop_height,
                    };
                    let display_size = self.display.lock().await.size().await;

                    // A size mismatch is only reported; the connection continues.
                    if client_size.width < display_size.width || client_size.height < display_size.height {
                        warn!(
                            "Client size doesn't fit the server size: {:?} < {:?}",
                            client_size, display_size
                        );
                    }
                }
                CapabilitySet::SurfaceCommands(c) => {
                    surface_flags = c.flags;
                }
                CapabilitySet::BitmapCodecs(BitmapCodecs(codecs)) => {
                    // A client codec is enabled only if the server is configured with it too.
                    for codec in codecs {
                        match codec.property {
                            CodecProperty::RemoteFx(rdp::capability_sets::RemoteFxContainer::ClientContainer(c))
                                if self.opts.has_remote_fx() =>
                            {
                                // Presumably each call replaces the previous setting, so the
                                // last capset wins — confirm.
                                for caps in c.caps_data.0 .0 {
                                    update_codecs.set_remotefx(Some((caps.entropy_bits, codec.id)));
                                }
                            }
                            CodecProperty::ImageRemoteFx(rdp::capability_sets::RemoteFxContainer::ClientContainer(
                                c,
                            )) if self.opts.has_image_remote_fx() => {
                                for caps in c.caps_data.0 .0 {
                                    update_codecs.set_remotefx(Some((caps.entropy_bits, codec.id)));
                                }
                            }
                            CodecProperty::NsCodec(_) => (),
                            #[cfg(feature = "qoi")]
                            CodecProperty::Qoi if self.opts.has_qoi() => {
                                update_codecs.set_qoi(Some(codec.id));
                            }
                            #[cfg(feature = "qoiz")]
                            CodecProperty::QoiZ if self.opts.has_qoiz() => {
                                update_codecs.set_qoiz(Some(codec.id));
                            }
                            _ => (),
                        }
                    }
                }
                _ => {}
            }
        }

        let desktop_size = self.display.lock().await.size().await;
        let encoder = UpdateEncoder::new(desktop_size, surface_flags, update_codecs)
            .context("failed to initialize update encoder")?;

        let state = self
            .client_loop(reader, writer, result.io_channel_id, result.user_channel_id, encoder)
            .await
            .context("client loop failure")?;

        Ok(state)
    }
797
798 async fn handle_input_backlog(
799 &mut self,
800 writer: &mut impl FramedWrite,
801 io_channel_id: u16,
802 user_channel_id: u16,
803 frames: Vec<Vec<u8>>,
804 ) -> Result<()> {
805 for frame in frames {
806 match Action::from_fp_output_header(frame[0]) {
807 Ok(Action::FastPath) => {
808 let input = decode(&frame)?;
809 self.handle_fastpath(input).await;
810 }
811
812 Ok(Action::X224) => {
813 let _ = self.handle_x224(writer, io_channel_id, user_channel_id, &frame).await;
814 }
815
816 Err(_) => unreachable!(),
819 }
820 }
821
822 Ok(())
823 }
824
825 async fn handle_fastpath(&mut self, input: FastPathInput) {
826 for event in input.0 {
827 let mut handler = self.handler.lock().await;
828 match event {
829 FastPathInputEvent::KeyboardEvent(flags, key) => {
830 handler.keyboard((key, flags).into());
831 }
832
833 FastPathInputEvent::UnicodeKeyboardEvent(flags, key) => {
834 handler.keyboard((key, flags).into());
835 }
836
837 FastPathInputEvent::SyncEvent(flags) => {
838 handler.keyboard(flags.into());
839 }
840
841 FastPathInputEvent::MouseEvent(mouse) => {
842 handler.mouse(mouse.into());
843 }
844
845 FastPathInputEvent::MouseEventEx(mouse) => {
846 handler.mouse(mouse.into());
847 }
848
849 FastPathInputEvent::MouseEventRel(mouse) => {
850 handler.mouse(mouse.into());
851 }
852
853 FastPathInputEvent::QoeEvent(quality) => {
854 warn!("Received QoE: {}", quality);
855 }
856 }
857 }
858 }
859
860 async fn handle_io_channel_data(&mut self, data: SendDataRequest<'_>) -> Result<bool> {
861 let control: rdp::headers::ShareControlHeader = decode(data.user_data.as_ref())?;
862
863 match control.share_control_pdu {
864 ShareControlPdu::Data(header) => match header.share_data_pdu {
865 rdp::headers::ShareDataPdu::Input(pdu) => {
866 self.handle_input_event(pdu).await;
867 }
868
869 rdp::headers::ShareDataPdu::ShutdownRequest => {
870 return Ok(true);
871 }
872
873 unexpected => {
874 warn!(?unexpected, "Unexpected share data pdu");
875 }
876 },
877
878 unexpected => {
879 warn!(?unexpected, "Unexpected share control");
880 }
881 }
882
883 Ok(false)
884 }
885
886 async fn handle_x224(
887 &mut self,
888 writer: &mut impl FramedWrite,
889 io_channel_id: u16,
890 user_channel_id: u16,
891 frame: &[u8],
892 ) -> Result<bool> {
893 let message = decode::<X224<mcs::McsMessage<'_>>>(frame)?;
894 match message.0 {
895 mcs::McsMessage::SendDataRequest(data) => {
896 debug!(?data, "McsMessage::SendDataRequest");
897 if data.channel_id == io_channel_id {
898 return self.handle_io_channel_data(data).await;
899 }
900
901 if let Some(svc) = self.static_channels.get_by_channel_id_mut(data.channel_id) {
902 let response_pdus = svc.process(&data.user_data)?;
903 let response = server_encode_svc_messages(response_pdus, data.channel_id, user_channel_id)?;
904 writer.write_all(&response).await?;
905 } else {
906 warn!(channel_id = data.channel_id, "Unexpected channel received: ID",);
907 }
908 }
909
910 mcs::McsMessage::DisconnectProviderUltimatum(disconnect) => {
911 if disconnect.reason == mcs::DisconnectReason::UserRequested {
912 return Ok(true);
913 }
914 }
915
916 _ => {
917 warn!(name = ironrdp_core::name(&message), "Unexpected mcs message");
918 }
919 }
920
921 Ok(false)
922 }
923
924 async fn handle_input_event(&mut self, input: InputEventPdu) {
925 for event in input.0 {
926 let mut handler = self.handler.lock().await;
927 match event {
928 ironrdp_pdu::input::InputEvent::ScanCode(key) => {
929 handler.keyboard((key.key_code, key.flags).into());
930 }
931
932 ironrdp_pdu::input::InputEvent::Unicode(key) => {
933 handler.keyboard((key.unicode_code, key.flags).into());
934 }
935
936 ironrdp_pdu::input::InputEvent::Sync(sync) => {
937 handler.keyboard(sync.flags.into());
938 }
939
940 ironrdp_pdu::input::InputEvent::Mouse(mouse) => {
941 handler.mouse(mouse.into());
942 }
943
944 ironrdp_pdu::input::InputEvent::MouseX(mouse) => {
945 handler.mouse(mouse.into());
946 }
947
948 ironrdp_pdu::input::InputEvent::MouseRel(mouse) => {
949 handler.mouse(mouse.into());
950 }
951
952 ironrdp_pdu::input::InputEvent::Unused(_) => {}
953 }
954 }
955 }
956
    /// Completes the acceptor sequence and serves the client.
    ///
    /// After each deactivation–reactivation cycle the acceptor is rebuilt and the
    /// sequence runs again. Returns once the client session ends with
    /// [`RunState::Disconnect`].
    async fn accept_finalize<S>(&mut self, mut framed: TokioFramed<S>, mut acceptor: Acceptor) -> Result<()>
    where
        S: AsyncRead + AsyncWrite + Sync + Send + Unpin,
    {
        loop {
            let (new_framed, result) = ironrdp_acceptor::accept_finalize(framed, &mut acceptor)
                .await
                .context("failed to accept client during finalize")?;

            let (mut reader, mut writer) = split_tokio_framed(new_framed);

            match self.client_accepted(&mut reader, &mut writer, result).await? {
                // `client_loop` only ever returns terminal states.
                RunState::Continue => {
                    unreachable!();
                }
                RunState::DeactivationReactivation { desktop_size } => {
                    // Rebuild the acceptor for the new desktop size and hand it the current
                    // static channels. Then reassemble the stream and run finalize again.
                    acceptor = Acceptor::new_deactivation_reactivation(
                        acceptor,
                        core::mem::take(&mut self.static_channels),
                        desktop_size,
                    )?;
                    framed = unsplit_tokio_framed(reader, writer);
                    continue;
                }
                RunState::Disconnect => break,
            }
        }

        Ok(())
    }
991
992 pub fn set_credentials(&mut self, creds: Option<Credentials>) {
993 debug!(?creds, "Changing credentials");
994 self.creds = creds
995 }
996}
997
998async fn deactivate_all(
999 io_channel_id: u16,
1000 user_channel_id: u16,
1001 writer: &mut impl FramedWrite,
1002) -> Result<(), anyhow::Error> {
1003 let pdu = ShareControlPdu::ServerDeactivateAll(ServerDeactivateAll);
1004 let pdu = rdp::headers::ShareControlHeader {
1005 share_id: 0,
1006 pdu_source: io_channel_id,
1007 share_control_pdu: pdu,
1008 };
1009 let user_data = encode_vec(&pdu)?.into();
1010 let pdu = SendDataIndication {
1011 initiator_id: user_channel_id,
1012 channel_id: io_channel_id,
1013 user_data,
1014 };
1015 let msg = encode_vec(&X224(pdu))?;
1016 writer.write_all(&msg).await?;
1017 Ok(())
1018}
1019
/// Cloneable handle to one borrowed writer, with writes serialized by an async mutex.
///
/// It is built on `Rc`, so handles are not `Send`. In `client_loop` they are used by
/// futures raced within a single `tokio::select!`.
struct SharedWriter<'w, W: FramedWrite> {
    writer: Rc<Mutex<&'w mut W>>,
}
1023
1024impl<W: FramedWrite> Clone for SharedWriter<'_, W> {
1025 fn clone(&self) -> Self {
1026 Self {
1027 writer: Rc::clone(&self.writer),
1028 }
1029 }
1030}
1031
1032impl<W> FramedWrite for SharedWriter<'_, W>
1033where
1034 W: FramedWrite,
1035{
1036 type WriteAllFut<'write>
1037 = core::pin::Pin<Box<dyn core::future::Future<Output = std::io::Result<()>> + 'write>>
1038 where
1039 Self: 'write;
1040
1041 fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
1042 Box::pin(async {
1043 let mut writer = self.writer.lock().await;
1044
1045 writer.write_all(buf).await?;
1046 Ok(())
1047 })
1048 }
1049}
1050
1051impl<'a, W: FramedWrite> SharedWriter<'a, W> {
1052 fn new(writer: &'a mut W) -> Self {
1053 Self {
1054 writer: Rc::new(Mutex::new(writer)),
1055 }
1056 }
1057}