1use core::net::SocketAddr;
2use std::rc::Rc;
3use std::sync::Arc;
4
5use anyhow::{anyhow, bail, Context as _, Result};
6use ironrdp_acceptor::{Acceptor, AcceptorResult, BeginResult, DesktopSize};
7use ironrdp_async::Framed;
8use ironrdp_cliprdr::backend::ClipboardMessage;
9use ironrdp_cliprdr::CliprdrServer;
10use ironrdp_core::{decode, encode_vec, impl_as_any};
11use ironrdp_displaycontrol::pdu::DisplayControlMonitorLayout;
12use ironrdp_displaycontrol::server::{DisplayControlHandler, DisplayControlServer};
13use ironrdp_pdu::input::fast_path::{FastPathInput, FastPathInputEvent};
14use ironrdp_pdu::input::InputEventPdu;
15use ironrdp_pdu::mcs::{SendDataIndication, SendDataRequest};
16use ironrdp_pdu::rdp::capability_sets::{BitmapCodecs, CapabilitySet, CmdFlags, CodecProperty, GeneralExtraFlags};
17pub use ironrdp_pdu::rdp::client_info::Credentials;
18use ironrdp_pdu::rdp::headers::{ServerDeactivateAll, ShareControlPdu};
19use ironrdp_pdu::x224::X224;
20use ironrdp_pdu::{decode_err, mcs, nego, rdp, Action, PduResult};
21use ironrdp_svc::{server_encode_svc_messages, StaticChannelId, StaticChannelSet, SvcProcessor};
22use ironrdp_tokio::{split_tokio_framed, unsplit_tokio_framed, FramedRead, FramedWrite, TokioFramed};
23use rdpsnd::server::{RdpsndServer, RdpsndServerMessage};
24use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _};
25use tokio::net::{TcpListener, TcpStream};
26use tokio::sync::{mpsc, oneshot, Mutex};
27use tokio::task;
28use tokio_rustls::TlsAcceptor;
29use tracing::{debug, error, trace, warn};
30use {ironrdp_dvc as dvc, ironrdp_rdpsnd as rdpsnd};
31
32use crate::clipboard::CliprdrServerFactory;
33use crate::display::{DisplayUpdate, RdpServerDisplay};
34use crate::encoder::{UpdateEncoder, UpdateEncoderCodecs};
35use crate::handler::RdpServerInputHandler;
36use crate::{builder, capabilities, SoundServerFactory};
37
/// Static configuration of an [`RdpServer`].
#[derive(Clone)]
pub struct RdpServerOptions {
    /// Address the server's TCP listener binds to in [`RdpServer::run`].
    pub addr: SocketAddr,
    /// Transport security mode; determines the advertised security protocol flags
    /// and, for TLS/hybrid, the TLS acceptor used to upgrade the connection.
    pub security: RdpServerSecurity,
    /// Bitmap codecs enabled on the server side. Client-advertised RemoteFX and QOI
    /// codecs are only selected when they are also listed here.
    pub codecs: BitmapCodecs,
}
44
45impl RdpServerOptions {
46 fn has_image_remote_fx(&self) -> bool {
47 self.codecs
48 .0
49 .iter()
50 .any(|codec| matches!(codec.property, CodecProperty::ImageRemoteFx(_)))
51 }
52
53 fn has_remote_fx(&self) -> bool {
54 self.codecs
55 .0
56 .iter()
57 .any(|codec| matches!(codec.property, CodecProperty::RemoteFx(_)))
58 }
59
60 #[cfg(feature = "qoi")]
61 fn has_qoi(&self) -> bool {
62 self.codecs
63 .0
64 .iter()
65 .any(|codec| matches!(codec.property, CodecProperty::Qoi))
66 }
67
68 #[cfg(feature = "qoiz")]
69 fn has_qoiz(&self) -> bool {
70 self.codecs
71 .0
72 .iter()
73 .any(|codec| matches!(codec.property, CodecProperty::QoiZ))
74 }
75}
76
/// Transport security modes the server can offer to clients.
#[derive(Clone)]
pub enum RdpServerSecurity {
    /// No security protocol flags are advertised during negotiation.
    None,
    /// TLS; the acceptor performs the handshake when the client requests the upgrade.
    Tls(TlsAcceptor),
    /// TLS followed by CredSSP. The byte vector is the server public key passed to
    /// `ironrdp_acceptor::accept_credssp`.
    Hybrid((TlsAcceptor, Vec<u8>)),
}
84
85impl RdpServerSecurity {
86 pub fn flag(&self) -> nego::SecurityProtocol {
87 match self {
88 RdpServerSecurity::None => nego::SecurityProtocol::empty(),
89 RdpServerSecurity::Tls(_) => nego::SecurityProtocol::SSL,
90 RdpServerSecurity::Hybrid(_) => nego::SecurityProtocol::HYBRID | nego::SecurityProtocol::HYBRID_EX,
91 }
92 }
93}
94
95struct AInputHandler {
96 handler: Arc<Mutex<Box<dyn RdpServerInputHandler>>>,
97}
98
99impl_as_any!(AInputHandler);
100
101impl dvc::DvcProcessor for AInputHandler {
102 fn channel_name(&self) -> &str {
103 ironrdp_ainput::CHANNEL_NAME
104 }
105
106 fn start(&mut self, _channel_id: u32) -> PduResult<Vec<dvc::DvcMessage>> {
107 use ironrdp_ainput::{ServerPdu, VersionPdu};
108
109 let pdu = ServerPdu::Version(VersionPdu::default());
110
111 Ok(vec![Box::new(pdu)])
112 }
113
114 fn close(&mut self, _channel_id: u32) {}
115
116 fn process(&mut self, _channel_id: u32, payload: &[u8]) -> PduResult<Vec<dvc::DvcMessage>> {
117 use ironrdp_ainput::ClientPdu;
118
119 match decode(payload).map_err(|e| decode_err!(e))? {
120 ClientPdu::Mouse(pdu) => {
121 let handler = Arc::clone(&self.handler);
122 task::spawn_blocking(move || {
123 handler.blocking_lock().mouse(pdu.into());
124 });
125 }
126 }
127
128 Ok(Vec::new())
129 }
130}
131
132impl dvc::DvcServerProcessor for AInputHandler {}
133
134struct DisplayControlBackend {
135 display: Arc<Mutex<Box<dyn RdpServerDisplay>>>,
136}
137
138impl DisplayControlBackend {
139 fn new(display: Arc<Mutex<Box<dyn RdpServerDisplay>>>) -> Self {
140 Self { display }
141 }
142}
143
144impl DisplayControlHandler for DisplayControlBackend {
145 fn monitor_layout(&self, layout: DisplayControlMonitorLayout) {
146 let display = Arc::clone(&self.display);
147 task::spawn_blocking(move || display.blocking_lock().request_layout(layout));
148 }
149}
150
/// An RDP server that accepts clients on a TCP listener and serves them one at a time.
pub struct RdpServer {
    /// Listener address, security mode and codec configuration.
    opts: RdpServerOptions,
    /// Receives keyboard and mouse input decoded from fast-path, slow-path and AInput PDUs.
    handler: Arc<Mutex<Box<dyn RdpServerInputHandler>>>,
    /// Provides the desktop size and display updates; also receives layout requests
    /// from the display control channel.
    display: Arc<Mutex<Box<dyn RdpServerDisplay>>>,
    /// Static virtual channels of the current connection. They are replaced on each
    /// (re)activation and reset after each connection ends.
    static_channels: StaticChannelSet,
    /// Builds the rdpsnd backend for each connection, when audio output is enabled.
    sound_factory: Option<Box<dyn SoundServerFactory>>,
    /// Builds the cliprdr backend for each connection, when clipboard is enabled.
    cliprdr_factory: Option<Box<dyn CliprdrServerFactory>>,
    /// Sender half of the server event channel, handed out via `event_sender`.
    ev_sender: mpsc::UnboundedSender<ServerEvent>,
    /// Receiver half of the server event channel. It is shared so both the accept
    /// loop (`run`) and the session loop (`client_loop`) can drain it.
    ev_receiver: Arc<Mutex<mpsc::UnboundedReceiver<ServerEvent>>>,
    /// Credentials passed to the acceptor of each new connection.
    creds: Option<Credentials>,
    /// Address the listener actually bound to; `None` until `run` has bound it.
    local_addr: Option<SocketAddr>,
}
225
/// Events that can be posted to a running [`RdpServer`] through its event channel.
#[derive(Debug)]
pub enum ServerEvent {
    /// Stops the server. The reason is logged. While a client is connected, that
    /// client is disconnected first.
    Quit(String),
    /// Clipboard message forwarded to the cliprdr static channel.
    Clipboard(ClipboardMessage),
    /// Audio message forwarded to the rdpsnd static channel.
    Rdpsnd(RdpsndServerMessage),
    /// Replaces the credentials used for connections accepted afterwards.
    SetCredentials(Credentials),
    /// Asks for the listener's local address. The reply is `None` before the server is listening.
    GetLocalAddr(oneshot::Sender<Option<SocketAddr>>),
}
234
/// Implemented by components that post [`ServerEvent`]s back to the server.
///
/// `RdpServer::new` calls `set_sender` on the clipboard and sound factories it is given.
pub trait ServerEventSender {
    /// Stores the sender used to post events to the server.
    fn set_sender(&mut self, sender: mpsc::UnboundedSender<ServerEvent>);
}
238
239impl ServerEvent {
240 pub fn create_channel() -> (mpsc::UnboundedSender<Self>, mpsc::UnboundedReceiver<Self>) {
241 mpsc::unbounded_channel()
242 }
243}
244
/// Outcome of one step of the session loop.
#[derive(Debug, PartialEq)]
enum RunState {
    /// Keep processing.
    Continue,
    /// End the session.
    Disconnect,
    /// The desktop was resized. The connection goes through deactivation-reactivation
    /// with the new size.
    DeactivationReactivation { desktop_size: DesktopSize },
}
251
252impl RdpServer {
253 pub fn new(
254 opts: RdpServerOptions,
255 handler: Box<dyn RdpServerInputHandler>,
256 display: Box<dyn RdpServerDisplay>,
257 mut sound_factory: Option<Box<dyn SoundServerFactory>>,
258 mut cliprdr_factory: Option<Box<dyn CliprdrServerFactory>>,
259 ) -> Self {
260 let (ev_sender, ev_receiver) = ServerEvent::create_channel();
261 if let Some(cliprdr) = cliprdr_factory.as_mut() {
262 cliprdr.set_sender(ev_sender.clone());
263 }
264 if let Some(snd) = sound_factory.as_mut() {
265 snd.set_sender(ev_sender.clone());
266 }
267 Self {
268 opts,
269 handler: Arc::new(Mutex::new(handler)),
270 display: Arc::new(Mutex::new(display)),
271 static_channels: StaticChannelSet::new(),
272 sound_factory,
273 cliprdr_factory,
274 ev_sender,
275 ev_receiver: Arc::new(Mutex::new(ev_receiver)),
276 creds: None,
277 local_addr: None,
278 }
279 }
280
281 pub fn builder() -> builder::RdpServerBuilder<builder::WantsAddr> {
282 builder::RdpServerBuilder::new()
283 }
284
285 pub fn event_sender(&self) -> &mpsc::UnboundedSender<ServerEvent> {
286 &self.ev_sender
287 }
288
289 fn attach_channels(&mut self, acceptor: &mut Acceptor) {
290 if let Some(cliprdr_factory) = self.cliprdr_factory.as_deref() {
291 let backend = cliprdr_factory.build_cliprdr_backend();
292
293 let cliprdr = CliprdrServer::new(backend);
294
295 acceptor.attach_static_channel(cliprdr);
296 }
297
298 if let Some(factory) = self.sound_factory.as_deref() {
299 let backend = factory.build_backend();
300
301 acceptor.attach_static_channel(RdpsndServer::new(backend));
302 }
303
304 let dcs_backend = DisplayControlBackend::new(Arc::clone(&self.display));
305 let dvc = dvc::DrdynvcServer::new()
306 .with_dynamic_channel(AInputHandler {
307 handler: Arc::clone(&self.handler),
308 })
309 .with_dynamic_channel(DisplayControlServer::new(Box::new(dcs_backend)));
310 acceptor.attach_static_channel(dvc);
311 }
312
313 pub async fn run_connection(&mut self, stream: TcpStream) -> Result<()> {
314 let framed = TokioFramed::new(stream);
315
316 let size = self.display.lock().await.size().await;
317 let capabilities = capabilities::capabilities(&self.opts, size);
318 let mut acceptor = Acceptor::new(self.opts.security.flag(), size, capabilities, self.creds.clone());
319
320 self.attach_channels(&mut acceptor);
321
322 let res = ironrdp_acceptor::accept_begin(framed, &mut acceptor)
323 .await
324 .context("accept_begin failed")?;
325
326 match res {
327 BeginResult::ShouldUpgrade(stream) => {
328 let tls_acceptor = match &self.opts.security {
329 RdpServerSecurity::Tls(acceptor) => acceptor,
330 RdpServerSecurity::Hybrid((acceptor, _)) => acceptor,
331 RdpServerSecurity::None => unreachable!(),
332 };
333 let accept = match tls_acceptor.accept(stream).await {
334 Ok(accept) => accept,
335 Err(e) => {
336 warn!("Failed to TLS accept: {}", e);
337 return Ok(());
338 }
339 };
340 let mut framed = TokioFramed::new(accept);
341
342 acceptor.mark_security_upgrade_as_done();
343
344 if let RdpServerSecurity::Hybrid((_, pub_key)) = &self.opts.security {
345 let client_name = framed.get_inner().0.get_ref().0.peer_addr()?.to_string();
348
349 ironrdp_acceptor::accept_credssp(
350 &mut framed,
351 &mut acceptor,
352 &mut ironrdp_tokio::reqwest::ReqwestNetworkClient::new(),
353 client_name.into(),
354 pub_key.clone(),
355 None,
356 )
357 .await?;
358 }
359
360 let framed = self.accept_finalize(framed, acceptor).await?;
361 debug!("Shutting down TLS connection");
362 let (mut tls_stream, _) = framed.into_inner();
363 if let Err(e) = tls_stream.shutdown().await {
364 debug!(?e, "TLS shutdown error");
365 }
366 }
367
368 BeginResult::Continue(framed) => {
369 self.accept_finalize(framed, acceptor).await?;
370 }
371 };
372
373 Ok(())
374 }
375
376 pub async fn run(&mut self) -> Result<()> {
377 let listener = TcpListener::bind(self.opts.addr).await?;
378 let local_addr = listener.local_addr()?;
379
380 debug!("Listening for connections on {local_addr}");
381 self.local_addr = Some(local_addr);
382
383 loop {
384 let ev_receiver = Arc::clone(&self.ev_receiver);
385 let mut ev_receiver = ev_receiver.lock().await;
386 tokio::select! {
387 Some(event) = ev_receiver.recv() => {
388 match event {
389 ServerEvent::Quit(reason) => {
390 debug!("Got quit event {reason}");
391 break;
392 }
393 ServerEvent::GetLocalAddr(tx) => {
394 let _ = tx.send(self.local_addr);
395 }
396 ServerEvent::SetCredentials(creds) => {
397 self.set_credentials(Some(creds));
398 }
399 ev => {
400 debug!("Unexpected event {:?}", ev);
401 }
402 }
403 },
404 Ok((stream, peer)) = listener.accept() => {
405 debug!(?peer, "Received connection");
406 drop(ev_receiver);
407 if let Err(error) = self.run_connection(stream).await {
408 error!(?error, "Connection error");
409 }
410 self.static_channels = StaticChannelSet::new();
411 }
412 else => break,
413 }
414 }
415
416 Ok(())
417 }
418
419 pub fn get_svc_processor<T: SvcProcessor + 'static>(&mut self) -> Option<&mut T> {
420 self.static_channels
421 .get_by_type_mut::<T>()
422 .and_then(|svc| svc.channel_processor_downcast_mut())
423 }
424
425 pub fn get_channel_id_by_type<T: SvcProcessor + 'static>(&self) -> Option<StaticChannelId> {
426 self.static_channels.get_channel_id_by_type::<T>()
427 }
428
429 async fn dispatch_pdu(
430 &mut self,
431 action: Action,
432 bytes: bytes::BytesMut,
433 writer: &mut impl FramedWrite,
434 io_channel_id: u16,
435 user_channel_id: u16,
436 ) -> Result<RunState> {
437 match action {
438 Action::FastPath => {
439 let input = decode(&bytes)?;
440 self.handle_fastpath(input).await;
441 }
442
443 Action::X224 => {
444 if self
445 .handle_x224(writer, io_channel_id, user_channel_id, &bytes)
446 .await
447 .context("X224 input error")?
448 {
449 debug!("Got disconnect request");
450 return Ok(RunState::Disconnect);
451 }
452 }
453 }
454
455 Ok(RunState::Continue)
456 }
457
458 async fn dispatch_display_update(
459 update: DisplayUpdate,
460 writer: &mut impl FramedWrite,
461 user_channel_id: u16,
462 io_channel_id: u16,
463 buffer: &mut Vec<u8>,
464 mut encoder: UpdateEncoder,
465 ) -> Result<(RunState, UpdateEncoder)> {
466 if let DisplayUpdate::Resize(desktop_size) = update {
467 debug!(?desktop_size, "Display resize");
468 encoder.set_desktop_size(desktop_size);
469 deactivate_all(io_channel_id, user_channel_id, writer).await?;
470 return Ok((RunState::DeactivationReactivation { desktop_size }, encoder));
471 }
472
473 let mut encoder_iter = encoder.update(update);
474 loop {
475 let Some(fragmenter) = encoder_iter.next().await else {
476 break;
477 };
478
479 let mut fragmenter = fragmenter.context("error while encoding")?;
480 if fragmenter.size_hint() > buffer.len() {
481 buffer.resize(fragmenter.size_hint(), 0);
482 }
483
484 while let Some(len) = fragmenter.next(buffer) {
485 writer
486 .write_all(&buffer[..len])
487 .await
488 .context("failed to write display update")?;
489 }
490 }
491
492 Ok((RunState::Continue, encoder))
493 }
494
495 async fn dispatch_server_events(
496 &mut self,
497 events: &mut Vec<ServerEvent>,
498 writer: &mut impl FramedWrite,
499 user_channel_id: u16,
500 ) -> Result<RunState> {
501 let mut wave_limit = 4;
505 for event in events.drain(..) {
506 trace!(?event, "Dispatching");
507 match event {
508 ServerEvent::Quit(reason) => {
509 debug!("Got quit event: {reason}");
510 return Ok(RunState::Disconnect);
511 }
512 ServerEvent::GetLocalAddr(tx) => {
513 let _ = tx.send(self.local_addr);
514 }
515 ServerEvent::SetCredentials(creds) => {
516 self.set_credentials(Some(creds));
517 }
518 ServerEvent::Rdpsnd(s) => {
519 let Some(rdpsnd) = self.get_svc_processor::<RdpsndServer>() else {
520 warn!("No rdpsnd channel, dropping event");
521 continue;
522 };
523 let msgs = match s {
524 RdpsndServerMessage::Wave(data, ts) => {
525 if wave_limit == 0 {
526 debug!("Dropping wave");
527 continue;
528 }
529 wave_limit -= 1;
530 rdpsnd.wave(data, ts)
531 }
532 RdpsndServerMessage::SetVolume { left, right } => rdpsnd.set_volume(left, right),
533 RdpsndServerMessage::Close => rdpsnd.close(),
534 RdpsndServerMessage::Error(error) => {
535 error!(?error, "Handling rdpsnd event");
536 continue;
537 }
538 }
539 .context("failed to send rdpsnd event")?;
540 let channel_id = self
541 .get_channel_id_by_type::<RdpsndServer>()
542 .ok_or_else(|| anyhow!("SVC channel not found"))?;
543 let data = server_encode_svc_messages(msgs.into(), channel_id, user_channel_id)?;
544 writer.write_all(&data).await?;
545 }
546 ServerEvent::Clipboard(c) => {
547 let Some(cliprdr) = self.get_svc_processor::<CliprdrServer>() else {
548 warn!("No clipboard channel, dropping event");
549 continue;
550 };
551 let msgs = match c {
552 ClipboardMessage::SendInitiateCopy(formats) => cliprdr.initiate_copy(&formats),
553 ClipboardMessage::SendFormatData(data) => cliprdr.submit_format_data(data),
554 ClipboardMessage::SendInitiatePaste(format) => cliprdr.initiate_paste(format),
555 ClipboardMessage::Error(error) => {
556 error!(?error, "Handling clipboard event");
557 continue;
558 }
559 }
560 .context("failed to send clipboard event")?;
561 let channel_id = self
562 .get_channel_id_by_type::<CliprdrServer>()
563 .ok_or_else(|| anyhow!("SVC channel not found"))?;
564 let data = server_encode_svc_messages(msgs.into(), channel_id, user_channel_id)?;
565 writer.write_all(&data).await?;
566 }
567 }
568 }
569
570 Ok(RunState::Continue)
571 }
572
573 async fn client_loop<R, W>(
574 &mut self,
575 reader: &mut Framed<R>,
576 writer: &mut Framed<W>,
577 io_channel_id: u16,
578 user_channel_id: u16,
579 mut encoder: UpdateEncoder,
580 ) -> Result<RunState>
581 where
582 R: FramedRead,
583 W: FramedWrite,
584 {
585 debug!("Starting client loop");
586 let mut display_updates = self.display.lock().await.updates().await?;
587 let mut writer = SharedWriter::new(writer);
588 let mut display_writer = writer.clone();
589 let mut event_writer = writer.clone();
590 let ev_receiver = Arc::clone(&self.ev_receiver);
591 let s = Rc::new(Mutex::new(self));
592
593 let this = Rc::clone(&s);
594 let dispatch_pdu = async move {
595 loop {
596 let (action, bytes) = reader.read_pdu().await?;
597 let mut this = this.lock().await;
598 match this
599 .dispatch_pdu(action, bytes, &mut writer, io_channel_id, user_channel_id)
600 .await?
601 {
602 RunState::Continue => continue,
603 state => break Ok(state),
604 }
605 }
606 };
607
608 let dispatch_display = async move {
609 let mut buffer = vec![0u8; 4096];
610
611 loop {
612 match display_updates.next_update().await {
613 Ok(Some(update)) => {
614 match Self::dispatch_display_update(
615 update,
616 &mut display_writer,
617 user_channel_id,
618 io_channel_id,
619 &mut buffer,
620 encoder,
621 )
622 .await?
623 {
624 (RunState::Continue, enc) => {
625 encoder = enc;
626 continue;
627 }
628 (state, _) => {
629 break Ok(state);
630 }
631 }
632 }
633 Ok(None) => {
634 break Ok(RunState::Disconnect);
635 }
636 Err(error) => {
637 warn!(error = format!("{error:#}"), "next_updated failed");
638 }
639 }
640 }
641 };
642
643 let this = Rc::clone(&s);
644 let mut ev_receiver = ev_receiver.lock().await;
645 let dispatch_events = async move {
646 let mut events = Vec::with_capacity(100);
647 loop {
648 let nevents = ev_receiver.recv_many(&mut events, 100).await;
649 if nevents == 0 {
650 debug!("No sever events.. stopping");
651 break Ok(RunState::Disconnect);
652 }
653 while let Ok(ev) = ev_receiver.try_recv() {
654 events.push(ev);
655 }
656 let mut this = this.lock().await;
657 match this
658 .dispatch_server_events(&mut events, &mut event_writer, user_channel_id)
659 .await?
660 {
661 RunState::Continue => continue,
662 state => break Ok(state),
663 }
664 }
665 };
666
667 let state = tokio::select!(
668 state = dispatch_pdu => state,
669 state = dispatch_display => state,
670 state = dispatch_events => state,
671 );
672
673 debug!("End of client loop: {state:?}");
674 state
675 }
676
677 async fn client_accepted<R, W>(
678 &mut self,
679 reader: &mut Framed<R>,
680 writer: &mut Framed<W>,
681 result: AcceptorResult,
682 ) -> Result<RunState>
683 where
684 R: FramedRead,
685 W: FramedWrite,
686 {
687 debug!("Client accepted");
688
689 if !result.input_events.is_empty() {
690 debug!("Handling input event backlog from acceptor sequence");
691 self.handle_input_backlog(
692 writer,
693 result.io_channel_id,
694 result.user_channel_id,
695 result.input_events,
696 )
697 .await?;
698 }
699
700 self.static_channels = result.static_channels;
701 if !result.reactivation {
702 for (_type_id, channel, channel_id) in self.static_channels.iter_mut() {
703 debug!(?channel, ?channel_id, "Start");
704 let Some(channel_id) = channel_id else {
705 continue;
706 };
707 let svc_responses = channel.start()?;
708 let response = server_encode_svc_messages(svc_responses, channel_id, result.user_channel_id)?;
709 writer.write_all(&response).await?;
710 }
711 }
712
713 let mut update_codecs = UpdateEncoderCodecs::new();
714 let mut surface_flags = CmdFlags::empty();
715 for c in result.capabilities {
716 match c {
717 CapabilitySet::General(c) => {
718 let fastpath = c.extra_flags.contains(GeneralExtraFlags::FASTPATH_OUTPUT_SUPPORTED);
719 if !fastpath {
720 bail!("Fastpath output not supported!");
721 }
722 }
723 CapabilitySet::Bitmap(b) => {
724 if !b.desktop_resize_flag {
725 debug!("Desktop resize is not supported by the client");
726 continue;
727 }
728
729 let client_size = DesktopSize {
730 width: b.desktop_width,
731 height: b.desktop_height,
732 };
733 let display_size = self.display.lock().await.size().await;
734
735 if client_size.width < display_size.width || client_size.height < display_size.height {
738 warn!(
740 "Client size doesn't fit the server size: {:?} < {:?}",
741 client_size, display_size
742 );
743 }
744 }
745 CapabilitySet::SurfaceCommands(c) => {
746 surface_flags = c.flags;
747 }
748 CapabilitySet::BitmapCodecs(BitmapCodecs(codecs)) => {
749 for codec in codecs {
750 match codec.property {
751 CodecProperty::RemoteFx(rdp::capability_sets::RemoteFxContainer::ClientContainer(c))
761 if self.opts.has_remote_fx() =>
762 {
763 for caps in c.caps_data.0 .0 {
764 update_codecs.set_remotefx(Some((caps.entropy_bits, codec.id)));
765 }
766 }
767 CodecProperty::ImageRemoteFx(rdp::capability_sets::RemoteFxContainer::ClientContainer(
768 c,
769 )) if self.opts.has_image_remote_fx() => {
770 for caps in c.caps_data.0 .0 {
771 update_codecs.set_remotefx(Some((caps.entropy_bits, codec.id)));
772 }
773 }
774 CodecProperty::NsCodec(_) => (),
775 #[cfg(feature = "qoi")]
776 CodecProperty::Qoi if self.opts.has_qoi() => {
777 update_codecs.set_qoi(Some(codec.id));
778 }
779 #[cfg(feature = "qoiz")]
780 CodecProperty::QoiZ if self.opts.has_qoiz() => {
781 update_codecs.set_qoiz(Some(codec.id));
782 }
783 _ => (),
784 }
785 }
786 }
787 _ => {}
788 }
789 }
790
791 let desktop_size = self.display.lock().await.size().await;
792 let encoder = UpdateEncoder::new(desktop_size, surface_flags, update_codecs)
793 .context("failed to initialize update encoder")?;
794
795 let state = self
796 .client_loop(reader, writer, result.io_channel_id, result.user_channel_id, encoder)
797 .await
798 .context("client loop failure")?;
799
800 Ok(state)
801 }
802
803 async fn handle_input_backlog(
804 &mut self,
805 writer: &mut impl FramedWrite,
806 io_channel_id: u16,
807 user_channel_id: u16,
808 frames: Vec<Vec<u8>>,
809 ) -> Result<()> {
810 for frame in frames {
811 match Action::from_fp_output_header(frame[0]) {
812 Ok(Action::FastPath) => {
813 let input = decode(&frame)?;
814 self.handle_fastpath(input).await;
815 }
816
817 Ok(Action::X224) => {
818 let _ = self.handle_x224(writer, io_channel_id, user_channel_id, &frame).await;
819 }
820
821 Err(_) => unreachable!(),
824 }
825 }
826
827 Ok(())
828 }
829
830 async fn handle_fastpath(&mut self, input: FastPathInput) {
831 for event in input.input_events().iter().copied() {
832 let mut handler = self.handler.lock().await;
833 match event {
834 FastPathInputEvent::KeyboardEvent(flags, key) => {
835 handler.keyboard((key, flags).into());
836 }
837
838 FastPathInputEvent::UnicodeKeyboardEvent(flags, key) => {
839 handler.keyboard((key, flags).into());
840 }
841
842 FastPathInputEvent::SyncEvent(flags) => {
843 handler.keyboard(flags.into());
844 }
845
846 FastPathInputEvent::MouseEvent(mouse) => {
847 handler.mouse(mouse.into());
848 }
849
850 FastPathInputEvent::MouseEventEx(mouse) => {
851 handler.mouse(mouse.into());
852 }
853
854 FastPathInputEvent::MouseEventRel(mouse) => {
855 handler.mouse(mouse.into());
856 }
857
858 FastPathInputEvent::QoeEvent(quality) => {
859 warn!("Received QoE: {}", quality);
860 }
861 }
862 }
863 }
864
865 async fn handle_io_channel_data(&mut self, data: SendDataRequest<'_>) -> Result<bool> {
866 let control: rdp::headers::ShareControlHeader = decode(data.user_data.as_ref())?;
867
868 match control.share_control_pdu {
869 ShareControlPdu::Data(header) => match header.share_data_pdu {
870 rdp::headers::ShareDataPdu::Input(pdu) => {
871 self.handle_input_event(pdu).await;
872 }
873
874 rdp::headers::ShareDataPdu::ShutdownRequest => {
875 return Ok(true);
876 }
877
878 unexpected => {
879 warn!(?unexpected, "Unexpected share data pdu");
880 }
881 },
882
883 unexpected => {
884 warn!(?unexpected, "Unexpected share control");
885 }
886 }
887
888 Ok(false)
889 }
890
891 async fn handle_x224(
892 &mut self,
893 writer: &mut impl FramedWrite,
894 io_channel_id: u16,
895 user_channel_id: u16,
896 frame: &[u8],
897 ) -> Result<bool> {
898 let message = decode::<X224<mcs::McsMessage<'_>>>(frame)?;
899 match message.0 {
900 mcs::McsMessage::SendDataRequest(data) => {
901 debug!(?data, "McsMessage::SendDataRequest");
902 if data.channel_id == io_channel_id {
903 return self.handle_io_channel_data(data).await;
904 }
905
906 if let Some(svc) = self.static_channels.get_by_channel_id_mut(data.channel_id) {
907 let response_pdus = svc.process(&data.user_data)?;
908 let response = server_encode_svc_messages(response_pdus, data.channel_id, user_channel_id)?;
909 writer.write_all(&response).await?;
910 } else {
911 warn!(channel_id = data.channel_id, "Unexpected channel received: ID",);
912 }
913 }
914
915 mcs::McsMessage::DisconnectProviderUltimatum(disconnect) => {
916 if disconnect.reason == mcs::DisconnectReason::UserRequested {
917 return Ok(true);
918 }
919 }
920
921 _ => {
922 warn!(name = ironrdp_core::name(&message), "Unexpected mcs message");
923 }
924 }
925
926 Ok(false)
927 }
928
929 async fn handle_input_event(&mut self, input: InputEventPdu) {
930 for event in input.0 {
931 let mut handler = self.handler.lock().await;
932 match event {
933 ironrdp_pdu::input::InputEvent::ScanCode(key) => {
934 handler.keyboard((key.key_code, key.flags).into());
935 }
936
937 ironrdp_pdu::input::InputEvent::Unicode(key) => {
938 handler.keyboard((key.unicode_code, key.flags).into());
939 }
940
941 ironrdp_pdu::input::InputEvent::Sync(sync) => {
942 handler.keyboard(sync.flags.into());
943 }
944
945 ironrdp_pdu::input::InputEvent::Mouse(mouse) => {
946 handler.mouse(mouse.into());
947 }
948
949 ironrdp_pdu::input::InputEvent::MouseX(mouse) => {
950 handler.mouse(mouse.into());
951 }
952
953 ironrdp_pdu::input::InputEvent::MouseRel(mouse) => {
954 handler.mouse(mouse.into());
955 }
956
957 ironrdp_pdu::input::InputEvent::Unused(_) => {}
958 }
959 }
960 }
961
962 async fn accept_finalize<S>(&mut self, mut framed: TokioFramed<S>, mut acceptor: Acceptor) -> Result<TokioFramed<S>>
963 where
964 S: AsyncRead + AsyncWrite + Sync + Send + Unpin,
965 {
966 loop {
967 let (new_framed, result) = ironrdp_acceptor::accept_finalize(framed, &mut acceptor)
968 .await
969 .context("failed to accept client during finalize")?;
970
971 let (mut reader, mut writer) = split_tokio_framed(new_framed);
972
973 match self.client_accepted(&mut reader, &mut writer, result).await? {
974 RunState::Continue => {
975 unreachable!();
976 }
977 RunState::DeactivationReactivation { desktop_size } => {
978 acceptor = Acceptor::new_deactivation_reactivation(
983 acceptor,
984 core::mem::take(&mut self.static_channels),
985 desktop_size,
986 )?;
987 framed = unsplit_tokio_framed(reader, writer);
988 continue;
989 }
990 RunState::Disconnect => {
991 let final_framed = unsplit_tokio_framed(reader, writer);
992 return Ok(final_framed);
993 }
994 }
995 }
996 }
997
998 pub fn set_credentials(&mut self, creds: Option<Credentials>) {
999 debug!(?creds, "Changing credentials");
1000 self.creds = creds
1001 }
1002}
1003
1004async fn deactivate_all(
1005 io_channel_id: u16,
1006 user_channel_id: u16,
1007 writer: &mut impl FramedWrite,
1008) -> Result<(), anyhow::Error> {
1009 let pdu = ShareControlPdu::ServerDeactivateAll(ServerDeactivateAll);
1010 let pdu = rdp::headers::ShareControlHeader {
1011 share_id: 0,
1012 pdu_source: io_channel_id,
1013 share_control_pdu: pdu,
1014 };
1015 let user_data = encode_vec(&pdu)?.into();
1016 let pdu = SendDataIndication {
1017 initiator_id: user_channel_id,
1018 channel_id: io_channel_id,
1019 user_data,
1020 };
1021 let msg = encode_vec(&X224(pdu))?;
1022 writer.write_all(&msg).await?;
1023 Ok(())
1024}
1025
1026struct SharedWriter<'w, W: FramedWrite> {
1027 writer: Rc<Mutex<&'w mut W>>,
1028}
1029
1030impl<W: FramedWrite> Clone for SharedWriter<'_, W> {
1031 fn clone(&self) -> Self {
1032 Self {
1033 writer: Rc::clone(&self.writer),
1034 }
1035 }
1036}
1037
1038impl<W> FramedWrite for SharedWriter<'_, W>
1039where
1040 W: FramedWrite,
1041{
1042 type WriteAllFut<'write>
1043 = core::pin::Pin<Box<dyn core::future::Future<Output = std::io::Result<()>> + 'write>>
1044 where
1045 Self: 'write;
1046
1047 fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
1048 Box::pin(async {
1049 let mut writer = self.writer.lock().await;
1050
1051 writer.write_all(buf).await?;
1052 Ok(())
1053 })
1054 }
1055}
1056
1057impl<'a, W: FramedWrite> SharedWriter<'a, W> {
1058 fn new(writer: &'a mut W) -> Self {
1059 Self {
1060 writer: Rc::new(Mutex::new(writer)),
1061 }
1062 }
1063}