ironrdp_server/
server.rs

1use core::net::SocketAddr;
2use std::rc::Rc;
3use std::sync::Arc;
4
5use anyhow::{anyhow, bail, Context as _, Result};
6use ironrdp_acceptor::{Acceptor, AcceptorResult, BeginResult, DesktopSize};
7use ironrdp_async::Framed;
8use ironrdp_cliprdr::backend::ClipboardMessage;
9use ironrdp_cliprdr::CliprdrServer;
10use ironrdp_core::{decode, encode_vec, impl_as_any};
11use ironrdp_displaycontrol::pdu::DisplayControlMonitorLayout;
12use ironrdp_displaycontrol::server::{DisplayControlHandler, DisplayControlServer};
13use ironrdp_pdu::input::fast_path::{FastPathInput, FastPathInputEvent};
14use ironrdp_pdu::input::InputEventPdu;
15use ironrdp_pdu::mcs::{SendDataIndication, SendDataRequest};
16use ironrdp_pdu::rdp::capability_sets::{BitmapCodecs, CapabilitySet, CmdFlags, CodecProperty, GeneralExtraFlags};
17pub use ironrdp_pdu::rdp::client_info::Credentials;
18use ironrdp_pdu::rdp::headers::{ServerDeactivateAll, ShareControlPdu};
19use ironrdp_pdu::x224::X224;
20use ironrdp_pdu::{decode_err, mcs, nego, rdp, Action, PduResult};
21use ironrdp_svc::{server_encode_svc_messages, StaticChannelId, StaticChannelSet, SvcProcessor};
22use ironrdp_tokio::{split_tokio_framed, unsplit_tokio_framed, FramedRead, FramedWrite, TokioFramed};
23use rdpsnd::server::{RdpsndServer, RdpsndServerMessage};
24use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _};
25use tokio::net::{TcpListener, TcpStream};
26use tokio::sync::{mpsc, oneshot, Mutex};
27use tokio::task;
28use tokio_rustls::TlsAcceptor;
29use tracing::{debug, error, trace, warn};
30use {ironrdp_dvc as dvc, ironrdp_rdpsnd as rdpsnd};
31
32use crate::clipboard::CliprdrServerFactory;
33use crate::display::{DisplayUpdate, RdpServerDisplay};
34use crate::encoder::{UpdateEncoder, UpdateEncoderCodecs};
35use crate::handler::RdpServerInputHandler;
36use crate::{builder, capabilities, SoundServerFactory};
37
38#[derive(Clone)]
39pub struct RdpServerOptions {
40    pub addr: SocketAddr,
41    pub security: RdpServerSecurity,
42    pub codecs: BitmapCodecs,
43}
44
45impl RdpServerOptions {
46    fn has_image_remote_fx(&self) -> bool {
47        self.codecs
48            .0
49            .iter()
50            .any(|codec| matches!(codec.property, CodecProperty::ImageRemoteFx(_)))
51    }
52
53    fn has_remote_fx(&self) -> bool {
54        self.codecs
55            .0
56            .iter()
57            .any(|codec| matches!(codec.property, CodecProperty::RemoteFx(_)))
58    }
59
60    #[cfg(feature = "qoi")]
61    fn has_qoi(&self) -> bool {
62        self.codecs
63            .0
64            .iter()
65            .any(|codec| matches!(codec.property, CodecProperty::Qoi))
66    }
67
68    #[cfg(feature = "qoiz")]
69    fn has_qoiz(&self) -> bool {
70        self.codecs
71            .0
72            .iter()
73            .any(|codec| matches!(codec.property, CodecProperty::QoiZ))
74    }
75}
76
77#[derive(Clone)]
78pub enum RdpServerSecurity {
79    None,
80    Tls(TlsAcceptor),
81    /// Used for both hybrid + hybrid-ex.
82    Hybrid((TlsAcceptor, Vec<u8>)),
83}
84
85impl RdpServerSecurity {
86    pub fn flag(&self) -> nego::SecurityProtocol {
87        match self {
88            RdpServerSecurity::None => nego::SecurityProtocol::empty(),
89            RdpServerSecurity::Tls(_) => nego::SecurityProtocol::SSL,
90            RdpServerSecurity::Hybrid(_) => nego::SecurityProtocol::HYBRID | nego::SecurityProtocol::HYBRID_EX,
91        }
92    }
93}
94
95struct AInputHandler {
96    handler: Arc<Mutex<Box<dyn RdpServerInputHandler>>>,
97}
98
99impl_as_any!(AInputHandler);
100
101impl dvc::DvcProcessor for AInputHandler {
102    fn channel_name(&self) -> &str {
103        ironrdp_ainput::CHANNEL_NAME
104    }
105
106    fn start(&mut self, _channel_id: u32) -> PduResult<Vec<dvc::DvcMessage>> {
107        use ironrdp_ainput::{ServerPdu, VersionPdu};
108
109        let pdu = ServerPdu::Version(VersionPdu::default());
110
111        Ok(vec![Box::new(pdu)])
112    }
113
114    fn close(&mut self, _channel_id: u32) {}
115
116    fn process(&mut self, _channel_id: u32, payload: &[u8]) -> PduResult<Vec<dvc::DvcMessage>> {
117        use ironrdp_ainput::ClientPdu;
118
119        match decode(payload).map_err(|e| decode_err!(e))? {
120            ClientPdu::Mouse(pdu) => {
121                let handler = Arc::clone(&self.handler);
122                task::spawn_blocking(move || {
123                    handler.blocking_lock().mouse(pdu.into());
124                });
125            }
126        }
127
128        Ok(Vec::new())
129    }
130}
131
132impl dvc::DvcServerProcessor for AInputHandler {}
133
134struct DisplayControlBackend {
135    display: Arc<Mutex<Box<dyn RdpServerDisplay>>>,
136}
137
138impl DisplayControlBackend {
139    fn new(display: Arc<Mutex<Box<dyn RdpServerDisplay>>>) -> Self {
140        Self { display }
141    }
142}
143
144impl DisplayControlHandler for DisplayControlBackend {
145    fn monitor_layout(&self, layout: DisplayControlMonitorLayout) {
146        let display = Arc::clone(&self.display);
147        task::spawn_blocking(move || display.blocking_lock().request_layout(layout));
148    }
149}
150
151/// RDP Server
152///
153/// A server is created to listen for connections.
154/// After the connection sequence is finalized using the provided security mechanism, the server can:
155///  - receive display updates from a [`RdpServerDisplay`] and forward them to the client
156///  - receive input events from a client and forward them to an [`RdpServerInputHandler`]
157///
158/// # Example
159///
160/// ```
161/// use ironrdp_server::{RdpServer, RdpServerInputHandler, RdpServerDisplay, RdpServerDisplayUpdates};
162///
163///# use anyhow::Result;
164///# use ironrdp_server::{DisplayUpdate, DesktopSize, KeyboardEvent, MouseEvent};
165///# use tokio_rustls::TlsAcceptor;
166///# struct NoopInputHandler;
167///# impl RdpServerInputHandler for NoopInputHandler {
168///#     fn keyboard(&mut self, _: KeyboardEvent) {}
169///#     fn mouse(&mut self, _: MouseEvent) {}
170///# }
171///# struct NoopDisplay;
172///# #[async_trait::async_trait]
173///# impl RdpServerDisplay for NoopDisplay {
174///#     async fn size(&mut self) -> DesktopSize {
175///#         todo!()
176///#     }
177///#     async fn updates(&mut self) -> Result<Box<dyn RdpServerDisplayUpdates>> {
178///#         todo!()
179///#     }
180///# }
181///# async fn stub() -> Result<()> {
182/// fn make_tls_acceptor() -> TlsAcceptor {
183///    /* snip */
184///#    todo!()
185/// }
186///
187/// fn make_input_handler() -> impl RdpServerInputHandler {
188///    /* snip */
189///#    NoopInputHandler
190/// }
191///
192/// fn make_display_handler() -> impl RdpServerDisplay {
193///    /* snip */
194///#    NoopDisplay
195/// }
196///
197/// let tls_acceptor = make_tls_acceptor();
198/// let input_handler = make_input_handler();
199/// let display_handler = make_display_handler();
200///
201/// let mut server = RdpServer::builder()
202///     .with_addr(([127, 0, 0, 1], 3389))
203///     .with_tls(tls_acceptor)
204///     .with_input_handler(input_handler)
205///     .with_display_handler(display_handler)
206///     .build();
207///
208/// server.run().await;
209/// Ok(())
210///# }
211/// ```
212pub struct RdpServer {
213    opts: RdpServerOptions,
214    // FIXME: replace with a channel and poll/process the handler?
215    handler: Arc<Mutex<Box<dyn RdpServerInputHandler>>>,
216    display: Arc<Mutex<Box<dyn RdpServerDisplay>>>,
217    static_channels: StaticChannelSet,
218    sound_factory: Option<Box<dyn SoundServerFactory>>,
219    cliprdr_factory: Option<Box<dyn CliprdrServerFactory>>,
220    ev_sender: mpsc::UnboundedSender<ServerEvent>,
221    ev_receiver: Arc<Mutex<mpsc::UnboundedReceiver<ServerEvent>>>,
222    creds: Option<Credentials>,
223    local_addr: Option<SocketAddr>,
224}
225
226#[derive(Debug)]
227pub enum ServerEvent {
228    Quit(String),
229    Clipboard(ClipboardMessage),
230    Rdpsnd(RdpsndServerMessage),
231    SetCredentials(Credentials),
232    GetLocalAddr(oneshot::Sender<Option<SocketAddr>>),
233}
234
235pub trait ServerEventSender {
236    fn set_sender(&mut self, sender: mpsc::UnboundedSender<ServerEvent>);
237}
238
239impl ServerEvent {
240    pub fn create_channel() -> (mpsc::UnboundedSender<Self>, mpsc::UnboundedReceiver<Self>) {
241        mpsc::unbounded_channel()
242    }
243}
244
245#[derive(Debug, PartialEq)]
246enum RunState {
247    Continue,
248    Disconnect,
249    DeactivationReactivation { desktop_size: DesktopSize },
250}
251
252impl RdpServer {
253    pub fn new(
254        opts: RdpServerOptions,
255        handler: Box<dyn RdpServerInputHandler>,
256        display: Box<dyn RdpServerDisplay>,
257        mut sound_factory: Option<Box<dyn SoundServerFactory>>,
258        mut cliprdr_factory: Option<Box<dyn CliprdrServerFactory>>,
259    ) -> Self {
260        let (ev_sender, ev_receiver) = ServerEvent::create_channel();
261        if let Some(cliprdr) = cliprdr_factory.as_mut() {
262            cliprdr.set_sender(ev_sender.clone());
263        }
264        if let Some(snd) = sound_factory.as_mut() {
265            snd.set_sender(ev_sender.clone());
266        }
267        Self {
268            opts,
269            handler: Arc::new(Mutex::new(handler)),
270            display: Arc::new(Mutex::new(display)),
271            static_channels: StaticChannelSet::new(),
272            sound_factory,
273            cliprdr_factory,
274            ev_sender,
275            ev_receiver: Arc::new(Mutex::new(ev_receiver)),
276            creds: None,
277            local_addr: None,
278        }
279    }
280
281    pub fn builder() -> builder::RdpServerBuilder<builder::WantsAddr> {
282        builder::RdpServerBuilder::new()
283    }
284
285    pub fn event_sender(&self) -> &mpsc::UnboundedSender<ServerEvent> {
286        &self.ev_sender
287    }
288
289    fn attach_channels(&mut self, acceptor: &mut Acceptor) {
290        if let Some(cliprdr_factory) = self.cliprdr_factory.as_deref() {
291            let backend = cliprdr_factory.build_cliprdr_backend();
292
293            let cliprdr = CliprdrServer::new(backend);
294
295            acceptor.attach_static_channel(cliprdr);
296        }
297
298        if let Some(factory) = self.sound_factory.as_deref() {
299            let backend = factory.build_backend();
300
301            acceptor.attach_static_channel(RdpsndServer::new(backend));
302        }
303
304        let dcs_backend = DisplayControlBackend::new(Arc::clone(&self.display));
305        let dvc = dvc::DrdynvcServer::new()
306            .with_dynamic_channel(AInputHandler {
307                handler: Arc::clone(&self.handler),
308            })
309            .with_dynamic_channel(DisplayControlServer::new(Box::new(dcs_backend)));
310        acceptor.attach_static_channel(dvc);
311    }
312
313    pub async fn run_connection(&mut self, stream: TcpStream) -> Result<()> {
314        let framed = TokioFramed::new(stream);
315
316        let size = self.display.lock().await.size().await;
317        let capabilities = capabilities::capabilities(&self.opts, size);
318        let mut acceptor = Acceptor::new(self.opts.security.flag(), size, capabilities, self.creds.clone());
319
320        self.attach_channels(&mut acceptor);
321
322        let res = ironrdp_acceptor::accept_begin(framed, &mut acceptor)
323            .await
324            .context("accept_begin failed")?;
325
326        match res {
327            BeginResult::ShouldUpgrade(stream) => {
328                let tls_acceptor = match &self.opts.security {
329                    RdpServerSecurity::Tls(acceptor) => acceptor,
330                    RdpServerSecurity::Hybrid((acceptor, _)) => acceptor,
331                    RdpServerSecurity::None => unreachable!(),
332                };
333                let accept = match tls_acceptor.accept(stream).await {
334                    Ok(accept) => accept,
335                    Err(e) => {
336                        warn!("Failed to TLS accept: {}", e);
337                        return Ok(());
338                    }
339                };
340                let mut framed = TokioFramed::new(accept);
341
342                acceptor.mark_security_upgrade_as_done();
343
344                if let RdpServerSecurity::Hybrid((_, pub_key)) = &self.opts.security {
345                    // how to get the client name?
346                    // doesn't seem to matter yet
347                    let client_name = framed.get_inner().0.get_ref().0.peer_addr()?.to_string();
348
349                    ironrdp_acceptor::accept_credssp(
350                        &mut framed,
351                        &mut acceptor,
352                        &mut ironrdp_tokio::reqwest::ReqwestNetworkClient::new(),
353                        client_name.into(),
354                        pub_key.clone(),
355                        None,
356                    )
357                    .await?;
358                }
359
360                let framed = self.accept_finalize(framed, acceptor).await?;
361                debug!("Shutting down TLS connection");
362                let (mut tls_stream, _) = framed.into_inner();
363                if let Err(e) = tls_stream.shutdown().await {
364                    debug!(?e, "TLS shutdown error");
365                }
366            }
367
368            BeginResult::Continue(framed) => {
369                self.accept_finalize(framed, acceptor).await?;
370            }
371        };
372
373        Ok(())
374    }
375
376    pub async fn run(&mut self) -> Result<()> {
377        let listener = TcpListener::bind(self.opts.addr).await?;
378        let local_addr = listener.local_addr()?;
379
380        debug!("Listening for connections on {local_addr}");
381        self.local_addr = Some(local_addr);
382
383        loop {
384            let ev_receiver = Arc::clone(&self.ev_receiver);
385            let mut ev_receiver = ev_receiver.lock().await;
386            tokio::select! {
387                Some(event) = ev_receiver.recv() => {
388                    match event {
389                        ServerEvent::Quit(reason) => {
390                            debug!("Got quit event {reason}");
391                            break;
392                        }
393                        ServerEvent::GetLocalAddr(tx) => {
394                            let _ = tx.send(self.local_addr);
395                        }
396                        ServerEvent::SetCredentials(creds) => {
397                            self.set_credentials(Some(creds));
398                        }
399                        ev => {
400                            debug!("Unexpected event {:?}", ev);
401                        }
402                    }
403                },
404                Ok((stream, peer)) = listener.accept() => {
405                    debug!(?peer, "Received connection");
406                    drop(ev_receiver);
407                    if let Err(error) = self.run_connection(stream).await {
408                        error!(?error, "Connection error");
409                    }
410                    self.static_channels = StaticChannelSet::new();
411                }
412                else => break,
413            }
414        }
415
416        Ok(())
417    }
418
419    pub fn get_svc_processor<T: SvcProcessor + 'static>(&mut self) -> Option<&mut T> {
420        self.static_channels
421            .get_by_type_mut::<T>()
422            .and_then(|svc| svc.channel_processor_downcast_mut())
423    }
424
425    pub fn get_channel_id_by_type<T: SvcProcessor + 'static>(&self) -> Option<StaticChannelId> {
426        self.static_channels.get_channel_id_by_type::<T>()
427    }
428
429    async fn dispatch_pdu(
430        &mut self,
431        action: Action,
432        bytes: bytes::BytesMut,
433        writer: &mut impl FramedWrite,
434        io_channel_id: u16,
435        user_channel_id: u16,
436    ) -> Result<RunState> {
437        match action {
438            Action::FastPath => {
439                let input = decode(&bytes)?;
440                self.handle_fastpath(input).await;
441            }
442
443            Action::X224 => {
444                if self
445                    .handle_x224(writer, io_channel_id, user_channel_id, &bytes)
446                    .await
447                    .context("X224 input error")?
448                {
449                    debug!("Got disconnect request");
450                    return Ok(RunState::Disconnect);
451                }
452            }
453        }
454
455        Ok(RunState::Continue)
456    }
457
458    async fn dispatch_display_update(
459        update: DisplayUpdate,
460        writer: &mut impl FramedWrite,
461        user_channel_id: u16,
462        io_channel_id: u16,
463        buffer: &mut Vec<u8>,
464        mut encoder: UpdateEncoder,
465    ) -> Result<(RunState, UpdateEncoder)> {
466        if let DisplayUpdate::Resize(desktop_size) = update {
467            debug!(?desktop_size, "Display resize");
468            encoder.set_desktop_size(desktop_size);
469            deactivate_all(io_channel_id, user_channel_id, writer).await?;
470            return Ok((RunState::DeactivationReactivation { desktop_size }, encoder));
471        }
472
473        let mut encoder_iter = encoder.update(update);
474        loop {
475            let Some(fragmenter) = encoder_iter.next().await else {
476                break;
477            };
478
479            let mut fragmenter = fragmenter.context("error while encoding")?;
480            if fragmenter.size_hint() > buffer.len() {
481                buffer.resize(fragmenter.size_hint(), 0);
482            }
483
484            while let Some(len) = fragmenter.next(buffer) {
485                writer
486                    .write_all(&buffer[..len])
487                    .await
488                    .context("failed to write display update")?;
489            }
490        }
491
492        Ok((RunState::Continue, encoder))
493    }
494
495    async fn dispatch_server_events(
496        &mut self,
497        events: &mut Vec<ServerEvent>,
498        writer: &mut impl FramedWrite,
499        user_channel_id: u16,
500    ) -> Result<RunState> {
501        // Avoid wave message queuing up and causing extra delays.
502        // This is a naive solution, better solutions should compute the actual delay, add IO priority, encode audio, use UDP etc.
503        // 4 frames should roughly corresponds to hundreds of ms in regular setups.
504        let mut wave_limit = 4;
505        for event in events.drain(..) {
506            trace!(?event, "Dispatching");
507            match event {
508                ServerEvent::Quit(reason) => {
509                    debug!("Got quit event: {reason}");
510                    return Ok(RunState::Disconnect);
511                }
512                ServerEvent::GetLocalAddr(tx) => {
513                    let _ = tx.send(self.local_addr);
514                }
515                ServerEvent::SetCredentials(creds) => {
516                    self.set_credentials(Some(creds));
517                }
518                ServerEvent::Rdpsnd(s) => {
519                    let Some(rdpsnd) = self.get_svc_processor::<RdpsndServer>() else {
520                        warn!("No rdpsnd channel, dropping event");
521                        continue;
522                    };
523                    let msgs = match s {
524                        RdpsndServerMessage::Wave(data, ts) => {
525                            if wave_limit == 0 {
526                                debug!("Dropping wave");
527                                continue;
528                            }
529                            wave_limit -= 1;
530                            rdpsnd.wave(data, ts)
531                        }
532                        RdpsndServerMessage::SetVolume { left, right } => rdpsnd.set_volume(left, right),
533                        RdpsndServerMessage::Close => rdpsnd.close(),
534                        RdpsndServerMessage::Error(error) => {
535                            error!(?error, "Handling rdpsnd event");
536                            continue;
537                        }
538                    }
539                    .context("failed to send rdpsnd event")?;
540                    let channel_id = self
541                        .get_channel_id_by_type::<RdpsndServer>()
542                        .ok_or_else(|| anyhow!("SVC channel not found"))?;
543                    let data = server_encode_svc_messages(msgs.into(), channel_id, user_channel_id)?;
544                    writer.write_all(&data).await?;
545                }
546                ServerEvent::Clipboard(c) => {
547                    let Some(cliprdr) = self.get_svc_processor::<CliprdrServer>() else {
548                        warn!("No clipboard channel, dropping event");
549                        continue;
550                    };
551                    let msgs = match c {
552                        ClipboardMessage::SendInitiateCopy(formats) => cliprdr.initiate_copy(&formats),
553                        ClipboardMessage::SendFormatData(data) => cliprdr.submit_format_data(data),
554                        ClipboardMessage::SendInitiatePaste(format) => cliprdr.initiate_paste(format),
555                        ClipboardMessage::Error(error) => {
556                            error!(?error, "Handling clipboard event");
557                            continue;
558                        }
559                    }
560                    .context("failed to send clipboard event")?;
561                    let channel_id = self
562                        .get_channel_id_by_type::<CliprdrServer>()
563                        .ok_or_else(|| anyhow!("SVC channel not found"))?;
564                    let data = server_encode_svc_messages(msgs.into(), channel_id, user_channel_id)?;
565                    writer.write_all(&data).await?;
566                }
567            }
568        }
569
570        Ok(RunState::Continue)
571    }
572
573    async fn client_loop<R, W>(
574        &mut self,
575        reader: &mut Framed<R>,
576        writer: &mut Framed<W>,
577        io_channel_id: u16,
578        user_channel_id: u16,
579        mut encoder: UpdateEncoder,
580    ) -> Result<RunState>
581    where
582        R: FramedRead,
583        W: FramedWrite,
584    {
585        debug!("Starting client loop");
586        let mut display_updates = self.display.lock().await.updates().await?;
587        let mut writer = SharedWriter::new(writer);
588        let mut display_writer = writer.clone();
589        let mut event_writer = writer.clone();
590        let ev_receiver = Arc::clone(&self.ev_receiver);
591        let s = Rc::new(Mutex::new(self));
592
593        let this = Rc::clone(&s);
594        let dispatch_pdu = async move {
595            loop {
596                let (action, bytes) = reader.read_pdu().await?;
597                let mut this = this.lock().await;
598                match this
599                    .dispatch_pdu(action, bytes, &mut writer, io_channel_id, user_channel_id)
600                    .await?
601                {
602                    RunState::Continue => continue,
603                    state => break Ok(state),
604                }
605            }
606        };
607
608        let dispatch_display = async move {
609            let mut buffer = vec![0u8; 4096];
610
611            loop {
612                match display_updates.next_update().await {
613                    Ok(Some(update)) => {
614                        match Self::dispatch_display_update(
615                            update,
616                            &mut display_writer,
617                            user_channel_id,
618                            io_channel_id,
619                            &mut buffer,
620                            encoder,
621                        )
622                        .await?
623                        {
624                            (RunState::Continue, enc) => {
625                                encoder = enc;
626                                continue;
627                            }
628                            (state, _) => {
629                                break Ok(state);
630                            }
631                        }
632                    }
633                    Ok(None) => {
634                        break Ok(RunState::Disconnect);
635                    }
636                    Err(error) => {
637                        warn!(error = format!("{error:#}"), "next_updated failed");
638                    }
639                }
640            }
641        };
642
643        let this = Rc::clone(&s);
644        let mut ev_receiver = ev_receiver.lock().await;
645        let dispatch_events = async move {
646            let mut events = Vec::with_capacity(100);
647            loop {
648                let nevents = ev_receiver.recv_many(&mut events, 100).await;
649                if nevents == 0 {
650                    debug!("No sever events.. stopping");
651                    break Ok(RunState::Disconnect);
652                }
653                while let Ok(ev) = ev_receiver.try_recv() {
654                    events.push(ev);
655                }
656                let mut this = this.lock().await;
657                match this
658                    .dispatch_server_events(&mut events, &mut event_writer, user_channel_id)
659                    .await?
660                {
661                    RunState::Continue => continue,
662                    state => break Ok(state),
663                }
664            }
665        };
666
667        let state = tokio::select!(
668            state = dispatch_pdu => state,
669            state = dispatch_display => state,
670            state = dispatch_events => state,
671        );
672
673        debug!("End of client loop: {state:?}");
674        state
675    }
676
677    async fn client_accepted<R, W>(
678        &mut self,
679        reader: &mut Framed<R>,
680        writer: &mut Framed<W>,
681        result: AcceptorResult,
682    ) -> Result<RunState>
683    where
684        R: FramedRead,
685        W: FramedWrite,
686    {
687        debug!("Client accepted");
688
689        if !result.input_events.is_empty() {
690            debug!("Handling input event backlog from acceptor sequence");
691            self.handle_input_backlog(
692                writer,
693                result.io_channel_id,
694                result.user_channel_id,
695                result.input_events,
696            )
697            .await?;
698        }
699
700        self.static_channels = result.static_channels;
701        if !result.reactivation {
702            for (_type_id, channel, channel_id) in self.static_channels.iter_mut() {
703                debug!(?channel, ?channel_id, "Start");
704                let Some(channel_id) = channel_id else {
705                    continue;
706                };
707                let svc_responses = channel.start()?;
708                let response = server_encode_svc_messages(svc_responses, channel_id, result.user_channel_id)?;
709                writer.write_all(&response).await?;
710            }
711        }
712
713        let mut update_codecs = UpdateEncoderCodecs::new();
714        let mut surface_flags = CmdFlags::empty();
715        for c in result.capabilities {
716            match c {
717                CapabilitySet::General(c) => {
718                    let fastpath = c.extra_flags.contains(GeneralExtraFlags::FASTPATH_OUTPUT_SUPPORTED);
719                    if !fastpath {
720                        bail!("Fastpath output not supported!");
721                    }
722                }
723                CapabilitySet::Bitmap(b) => {
724                    if !b.desktop_resize_flag {
725                        debug!("Desktop resize is not supported by the client");
726                        continue;
727                    }
728
729                    let client_size = DesktopSize {
730                        width: b.desktop_width,
731                        height: b.desktop_height,
732                    };
733                    let display_size = self.display.lock().await.size().await;
734
735                    // It's problematic when the client didn't resize, as we send bitmap updates that don't fit.
736                    // The client will likely drop the connection.
737                    if client_size.width < display_size.width || client_size.height < display_size.height {
738                        // TODO: we may have different behaviour instead, such as clipping or scaling?
739                        warn!(
740                            "Client size doesn't fit the server size: {:?} < {:?}",
741                            client_size, display_size
742                        );
743                    }
744                }
745                CapabilitySet::SurfaceCommands(c) => {
746                    surface_flags = c.flags;
747                }
748                CapabilitySet::BitmapCodecs(BitmapCodecs(codecs)) => {
749                    for codec in codecs {
750                        match codec.property {
751                            // FIXME: The encoder operates in image mode only.
752                            //
753                            // See [MS-RDPRFX] 3.1.1.1 "State Machine" for
754                            // implementation of the video mode. which allows to
755                            // skip sending Header for each image.
756                            //
757                            // We should distinguish parameters for both modes,
758                            // and somehow choose the "best", instead of picking
759                            // the last parsed here.
760                            CodecProperty::RemoteFx(rdp::capability_sets::RemoteFxContainer::ClientContainer(c))
761                                if self.opts.has_remote_fx() =>
762                            {
763                                for caps in c.caps_data.0 .0 {
764                                    update_codecs.set_remotefx(Some((caps.entropy_bits, codec.id)));
765                                }
766                            }
767                            CodecProperty::ImageRemoteFx(rdp::capability_sets::RemoteFxContainer::ClientContainer(
768                                c,
769                            )) if self.opts.has_image_remote_fx() => {
770                                for caps in c.caps_data.0 .0 {
771                                    update_codecs.set_remotefx(Some((caps.entropy_bits, codec.id)));
772                                }
773                            }
774                            CodecProperty::NsCodec(_) => (),
775                            #[cfg(feature = "qoi")]
776                            CodecProperty::Qoi if self.opts.has_qoi() => {
777                                update_codecs.set_qoi(Some(codec.id));
778                            }
779                            #[cfg(feature = "qoiz")]
780                            CodecProperty::QoiZ if self.opts.has_qoiz() => {
781                                update_codecs.set_qoiz(Some(codec.id));
782                            }
783                            _ => (),
784                        }
785                    }
786                }
787                _ => {}
788            }
789        }
790
791        let desktop_size = self.display.lock().await.size().await;
792        let encoder = UpdateEncoder::new(desktop_size, surface_flags, update_codecs)
793            .context("failed to initialize update encoder")?;
794
795        let state = self
796            .client_loop(reader, writer, result.io_channel_id, result.user_channel_id, encoder)
797            .await
798            .context("client loop failure")?;
799
800        Ok(state)
801    }
802
803    async fn handle_input_backlog(
804        &mut self,
805        writer: &mut impl FramedWrite,
806        io_channel_id: u16,
807        user_channel_id: u16,
808        frames: Vec<Vec<u8>>,
809    ) -> Result<()> {
810        for frame in frames {
811            match Action::from_fp_output_header(frame[0]) {
812                Ok(Action::FastPath) => {
813                    let input = decode(&frame)?;
814                    self.handle_fastpath(input).await;
815                }
816
817                Ok(Action::X224) => {
818                    let _ = self.handle_x224(writer, io_channel_id, user_channel_id, &frame).await;
819                }
820
821                // the frame here is always valid, because otherwise it would
822                // have failed during the acceptor loop
823                Err(_) => unreachable!(),
824            }
825        }
826
827        Ok(())
828    }
829
830    async fn handle_fastpath(&mut self, input: FastPathInput) {
831        for event in input.input_events().iter().copied() {
832            let mut handler = self.handler.lock().await;
833            match event {
834                FastPathInputEvent::KeyboardEvent(flags, key) => {
835                    handler.keyboard((key, flags).into());
836                }
837
838                FastPathInputEvent::UnicodeKeyboardEvent(flags, key) => {
839                    handler.keyboard((key, flags).into());
840                }
841
842                FastPathInputEvent::SyncEvent(flags) => {
843                    handler.keyboard(flags.into());
844                }
845
846                FastPathInputEvent::MouseEvent(mouse) => {
847                    handler.mouse(mouse.into());
848                }
849
850                FastPathInputEvent::MouseEventEx(mouse) => {
851                    handler.mouse(mouse.into());
852                }
853
854                FastPathInputEvent::MouseEventRel(mouse) => {
855                    handler.mouse(mouse.into());
856                }
857
858                FastPathInputEvent::QoeEvent(quality) => {
859                    warn!("Received QoE: {}", quality);
860                }
861            }
862        }
863    }
864
865    async fn handle_io_channel_data(&mut self, data: SendDataRequest<'_>) -> Result<bool> {
866        let control: rdp::headers::ShareControlHeader = decode(data.user_data.as_ref())?;
867
868        match control.share_control_pdu {
869            ShareControlPdu::Data(header) => match header.share_data_pdu {
870                rdp::headers::ShareDataPdu::Input(pdu) => {
871                    self.handle_input_event(pdu).await;
872                }
873
874                rdp::headers::ShareDataPdu::ShutdownRequest => {
875                    return Ok(true);
876                }
877
878                unexpected => {
879                    warn!(?unexpected, "Unexpected share data pdu");
880                }
881            },
882
883            unexpected => {
884                warn!(?unexpected, "Unexpected share control");
885            }
886        }
887
888        Ok(false)
889    }
890
891    async fn handle_x224(
892        &mut self,
893        writer: &mut impl FramedWrite,
894        io_channel_id: u16,
895        user_channel_id: u16,
896        frame: &[u8],
897    ) -> Result<bool> {
898        let message = decode::<X224<mcs::McsMessage<'_>>>(frame)?;
899        match message.0 {
900            mcs::McsMessage::SendDataRequest(data) => {
901                debug!(?data, "McsMessage::SendDataRequest");
902                if data.channel_id == io_channel_id {
903                    return self.handle_io_channel_data(data).await;
904                }
905
906                if let Some(svc) = self.static_channels.get_by_channel_id_mut(data.channel_id) {
907                    let response_pdus = svc.process(&data.user_data)?;
908                    let response = server_encode_svc_messages(response_pdus, data.channel_id, user_channel_id)?;
909                    writer.write_all(&response).await?;
910                } else {
911                    warn!(channel_id = data.channel_id, "Unexpected channel received: ID",);
912                }
913            }
914
915            mcs::McsMessage::DisconnectProviderUltimatum(disconnect) => {
916                if disconnect.reason == mcs::DisconnectReason::UserRequested {
917                    return Ok(true);
918                }
919            }
920
921            _ => {
922                warn!(name = ironrdp_core::name(&message), "Unexpected mcs message");
923            }
924        }
925
926        Ok(false)
927    }
928
929    async fn handle_input_event(&mut self, input: InputEventPdu) {
930        for event in input.0 {
931            let mut handler = self.handler.lock().await;
932            match event {
933                ironrdp_pdu::input::InputEvent::ScanCode(key) => {
934                    handler.keyboard((key.key_code, key.flags).into());
935                }
936
937                ironrdp_pdu::input::InputEvent::Unicode(key) => {
938                    handler.keyboard((key.unicode_code, key.flags).into());
939                }
940
941                ironrdp_pdu::input::InputEvent::Sync(sync) => {
942                    handler.keyboard(sync.flags.into());
943                }
944
945                ironrdp_pdu::input::InputEvent::Mouse(mouse) => {
946                    handler.mouse(mouse.into());
947                }
948
949                ironrdp_pdu::input::InputEvent::MouseX(mouse) => {
950                    handler.mouse(mouse.into());
951                }
952
953                ironrdp_pdu::input::InputEvent::MouseRel(mouse) => {
954                    handler.mouse(mouse.into());
955                }
956
957                ironrdp_pdu::input::InputEvent::Unused(_) => {}
958            }
959        }
960    }
961
962    async fn accept_finalize<S>(&mut self, mut framed: TokioFramed<S>, mut acceptor: Acceptor) -> Result<TokioFramed<S>>
963    where
964        S: AsyncRead + AsyncWrite + Sync + Send + Unpin,
965    {
966        loop {
967            let (new_framed, result) = ironrdp_acceptor::accept_finalize(framed, &mut acceptor)
968                .await
969                .context("failed to accept client during finalize")?;
970
971            let (mut reader, mut writer) = split_tokio_framed(new_framed);
972
973            match self.client_accepted(&mut reader, &mut writer, result).await? {
974                RunState::Continue => {
975                    unreachable!();
976                }
977                RunState::DeactivationReactivation { desktop_size } => {
978                    // No description of such behavior was found in the
979                    // specification, but apparently, we must keep the channel
980                    // state as they were during reactivation. This fixes
981                    // various state issues during client resize.
982                    acceptor = Acceptor::new_deactivation_reactivation(
983                        acceptor,
984                        core::mem::take(&mut self.static_channels),
985                        desktop_size,
986                    )?;
987                    framed = unsplit_tokio_framed(reader, writer);
988                    continue;
989                }
990                RunState::Disconnect => {
991                    let final_framed = unsplit_tokio_framed(reader, writer);
992                    return Ok(final_framed);
993                }
994            }
995        }
996    }
997
998    pub fn set_credentials(&mut self, creds: Option<Credentials>) {
999        debug!(?creds, "Changing credentials");
1000        self.creds = creds
1001    }
1002}
1003
1004async fn deactivate_all(
1005    io_channel_id: u16,
1006    user_channel_id: u16,
1007    writer: &mut impl FramedWrite,
1008) -> Result<(), anyhow::Error> {
1009    let pdu = ShareControlPdu::ServerDeactivateAll(ServerDeactivateAll);
1010    let pdu = rdp::headers::ShareControlHeader {
1011        share_id: 0,
1012        pdu_source: io_channel_id,
1013        share_control_pdu: pdu,
1014    };
1015    let user_data = encode_vec(&pdu)?.into();
1016    let pdu = SendDataIndication {
1017        initiator_id: user_channel_id,
1018        channel_id: io_channel_id,
1019        user_data,
1020    };
1021    let msg = encode_vec(&X224(pdu))?;
1022    writer.write_all(&msg).await?;
1023    Ok(())
1024}
1025
1026struct SharedWriter<'w, W: FramedWrite> {
1027    writer: Rc<Mutex<&'w mut W>>,
1028}
1029
1030impl<W: FramedWrite> Clone for SharedWriter<'_, W> {
1031    fn clone(&self) -> Self {
1032        Self {
1033            writer: Rc::clone(&self.writer),
1034        }
1035    }
1036}
1037
1038impl<W> FramedWrite for SharedWriter<'_, W>
1039where
1040    W: FramedWrite,
1041{
1042    type WriteAllFut<'write>
1043        = core::pin::Pin<Box<dyn core::future::Future<Output = std::io::Result<()>> + 'write>>
1044    where
1045        Self: 'write;
1046
1047    fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
1048        Box::pin(async {
1049            let mut writer = self.writer.lock().await;
1050
1051            writer.write_all(buf).await?;
1052            Ok(())
1053        })
1054    }
1055}
1056
1057impl<'a, W: FramedWrite> SharedWriter<'a, W> {
1058    fn new(writer: &'a mut W) -> Self {
1059        Self {
1060            writer: Rc::new(Mutex::new(writer)),
1061        }
1062    }
1063}