ironrdp_server/
server.rs

1use core::net::SocketAddr;
2use std::rc::Rc;
3use std::sync::Arc;
4
5use anyhow::{anyhow, bail, Context as _, Result};
6use ironrdp_acceptor::{Acceptor, AcceptorResult, BeginResult, DesktopSize};
7use ironrdp_async::Framed;
8use ironrdp_cliprdr::backend::ClipboardMessage;
9use ironrdp_cliprdr::CliprdrServer;
10use ironrdp_core::{decode, encode_vec, impl_as_any};
11use ironrdp_displaycontrol::pdu::DisplayControlMonitorLayout;
12use ironrdp_displaycontrol::server::{DisplayControlHandler, DisplayControlServer};
13use ironrdp_pdu::input::fast_path::{FastPathInput, FastPathInputEvent};
14use ironrdp_pdu::input::InputEventPdu;
15use ironrdp_pdu::mcs::{SendDataIndication, SendDataRequest};
16use ironrdp_pdu::rdp::capability_sets::{BitmapCodecs, CapabilitySet, CmdFlags, CodecProperty, GeneralExtraFlags};
17pub use ironrdp_pdu::rdp::client_info::Credentials;
18use ironrdp_pdu::rdp::headers::{ServerDeactivateAll, ShareControlPdu};
19use ironrdp_pdu::x224::X224;
20use ironrdp_pdu::{decode_err, mcs, nego, rdp, Action, PduResult};
21use ironrdp_svc::{server_encode_svc_messages, StaticChannelId, StaticChannelSet, SvcProcessor};
22use ironrdp_tokio::{split_tokio_framed, unsplit_tokio_framed, FramedRead, FramedWrite, TokioFramed};
23use rdpsnd::server::{RdpsndServer, RdpsndServerMessage};
24use tokio::io::{AsyncRead, AsyncWrite};
25use tokio::net::{TcpListener, TcpStream};
26use tokio::sync::{mpsc, oneshot, Mutex};
27use tokio::task;
28use tokio_rustls::TlsAcceptor;
29use tracing::{debug, error, trace, warn};
30use {ironrdp_dvc as dvc, ironrdp_rdpsnd as rdpsnd};
31
32use crate::clipboard::CliprdrServerFactory;
33use crate::display::{DisplayUpdate, RdpServerDisplay};
34use crate::encoder::{UpdateEncoder, UpdateEncoderCodecs};
35use crate::handler::RdpServerInputHandler;
36use crate::{builder, capabilities, SoundServerFactory};
37
38#[derive(Clone)]
39pub struct RdpServerOptions {
40    pub addr: SocketAddr,
41    pub security: RdpServerSecurity,
42    pub codecs: BitmapCodecs,
43}
44
45impl RdpServerOptions {
46    fn has_image_remote_fx(&self) -> bool {
47        self.codecs
48            .0
49            .iter()
50            .any(|codec| matches!(codec.property, CodecProperty::ImageRemoteFx(_)))
51    }
52
53    fn has_remote_fx(&self) -> bool {
54        self.codecs
55            .0
56            .iter()
57            .any(|codec| matches!(codec.property, CodecProperty::RemoteFx(_)))
58    }
59
60    #[cfg(feature = "qoi")]
61    fn has_qoi(&self) -> bool {
62        self.codecs
63            .0
64            .iter()
65            .any(|codec| matches!(codec.property, CodecProperty::Qoi))
66    }
67
68    #[cfg(feature = "qoiz")]
69    fn has_qoiz(&self) -> bool {
70        self.codecs
71            .0
72            .iter()
73            .any(|codec| matches!(codec.property, CodecProperty::QoiZ))
74    }
75}
76
77#[derive(Clone)]
78pub enum RdpServerSecurity {
79    None,
80    Tls(TlsAcceptor),
81    /// Used for both hybrid + hybrid-ex.
82    Hybrid((TlsAcceptor, Vec<u8>)),
83}
84
85impl RdpServerSecurity {
86    pub fn flag(&self) -> nego::SecurityProtocol {
87        match self {
88            RdpServerSecurity::None => nego::SecurityProtocol::empty(),
89            RdpServerSecurity::Tls(_) => nego::SecurityProtocol::SSL,
90            RdpServerSecurity::Hybrid(_) => nego::SecurityProtocol::HYBRID | nego::SecurityProtocol::HYBRID_EX,
91        }
92    }
93}
94
95struct AInputHandler {
96    handler: Arc<Mutex<Box<dyn RdpServerInputHandler>>>,
97}
98
99impl_as_any!(AInputHandler);
100
101impl dvc::DvcProcessor for AInputHandler {
102    fn channel_name(&self) -> &str {
103        ironrdp_ainput::CHANNEL_NAME
104    }
105
106    fn start(&mut self, _channel_id: u32) -> PduResult<Vec<dvc::DvcMessage>> {
107        use ironrdp_ainput::{ServerPdu, VersionPdu};
108
109        let pdu = ServerPdu::Version(VersionPdu::default());
110
111        Ok(vec![Box::new(pdu)])
112    }
113
114    fn close(&mut self, _channel_id: u32) {}
115
116    fn process(&mut self, _channel_id: u32, payload: &[u8]) -> PduResult<Vec<dvc::DvcMessage>> {
117        use ironrdp_ainput::ClientPdu;
118
119        match decode(payload).map_err(|e| decode_err!(e))? {
120            ClientPdu::Mouse(pdu) => {
121                let handler = Arc::clone(&self.handler);
122                task::spawn_blocking(move || {
123                    handler.blocking_lock().mouse(pdu.into());
124                });
125            }
126        }
127
128        Ok(Vec::new())
129    }
130}
131
132impl dvc::DvcServerProcessor for AInputHandler {}
133
134struct DisplayControlBackend {
135    display: Arc<Mutex<Box<dyn RdpServerDisplay>>>,
136}
137
138impl DisplayControlBackend {
139    fn new(display: Arc<Mutex<Box<dyn RdpServerDisplay>>>) -> Self {
140        Self { display }
141    }
142}
143
144impl DisplayControlHandler for DisplayControlBackend {
145    fn monitor_layout(&self, layout: DisplayControlMonitorLayout) {
146        let display = Arc::clone(&self.display);
147        task::spawn_blocking(move || display.blocking_lock().request_layout(layout));
148    }
149}
150
151/// RDP Server
152///
153/// A server is created to listen for connections.
154/// After the connection sequence is finalized using the provided security mechanism, the server can:
155///  - receive display updates from a [`RdpServerDisplay`] and forward them to the client
156///  - receive input events from a client and forward them to an [`RdpServerInputHandler`]
157///
158/// # Example
159///
160/// ```
161/// use ironrdp_server::{RdpServer, RdpServerInputHandler, RdpServerDisplay, RdpServerDisplayUpdates};
162///
163///# use anyhow::Result;
164///# use ironrdp_server::{DisplayUpdate, DesktopSize, KeyboardEvent, MouseEvent};
165///# use tokio_rustls::TlsAcceptor;
166///# struct NoopInputHandler;
167///# impl RdpServerInputHandler for NoopInputHandler {
168///#     fn keyboard(&mut self, _: KeyboardEvent) {}
169///#     fn mouse(&mut self, _: MouseEvent) {}
170///# }
171///# struct NoopDisplay;
172///# #[async_trait::async_trait]
173///# impl RdpServerDisplay for NoopDisplay {
174///#     async fn size(&mut self) -> DesktopSize {
175///#         todo!()
176///#     }
177///#     async fn updates(&mut self) -> Result<Box<dyn RdpServerDisplayUpdates>> {
178///#         todo!()
179///#     }
180///# }
181///# async fn stub() -> Result<()> {
182/// fn make_tls_acceptor() -> TlsAcceptor {
183///    /* snip */
184///#    todo!()
185/// }
186///
187/// fn make_input_handler() -> impl RdpServerInputHandler {
188///    /* snip */
189///#    NoopInputHandler
190/// }
191///
192/// fn make_display_handler() -> impl RdpServerDisplay {
193///    /* snip */
194///#    NoopDisplay
195/// }
196///
197/// let tls_acceptor = make_tls_acceptor();
198/// let input_handler = make_input_handler();
199/// let display_handler = make_display_handler();
200///
201/// let mut server = RdpServer::builder()
202///     .with_addr(([127, 0, 0, 1], 3389))
203///     .with_tls(tls_acceptor)
204///     .with_input_handler(input_handler)
205///     .with_display_handler(display_handler)
206///     .build();
207///
208/// server.run().await;
209/// Ok(())
210///# }
211/// ```
212pub struct RdpServer {
213    opts: RdpServerOptions,
214    // FIXME: replace with a channel and poll/process the handler?
215    handler: Arc<Mutex<Box<dyn RdpServerInputHandler>>>,
216    display: Arc<Mutex<Box<dyn RdpServerDisplay>>>,
217    static_channels: StaticChannelSet,
218    sound_factory: Option<Box<dyn SoundServerFactory>>,
219    cliprdr_factory: Option<Box<dyn CliprdrServerFactory>>,
220    ev_sender: mpsc::UnboundedSender<ServerEvent>,
221    ev_receiver: Arc<Mutex<mpsc::UnboundedReceiver<ServerEvent>>>,
222    creds: Option<Credentials>,
223    local_addr: Option<SocketAddr>,
224}
225
226#[derive(Debug)]
227pub enum ServerEvent {
228    Quit(String),
229    Clipboard(ClipboardMessage),
230    Rdpsnd(RdpsndServerMessage),
231    SetCredentials(Credentials),
232    GetLocalAddr(oneshot::Sender<Option<SocketAddr>>),
233}
234
235pub trait ServerEventSender {
236    fn set_sender(&mut self, sender: mpsc::UnboundedSender<ServerEvent>);
237}
238
239impl ServerEvent {
240    pub fn create_channel() -> (mpsc::UnboundedSender<Self>, mpsc::UnboundedReceiver<Self>) {
241        mpsc::unbounded_channel()
242    }
243}
244
245#[derive(Debug, PartialEq)]
246enum RunState {
247    Continue,
248    Disconnect,
249    DeactivationReactivation { desktop_size: DesktopSize },
250}
251
252impl RdpServer {
253    pub fn new(
254        opts: RdpServerOptions,
255        handler: Box<dyn RdpServerInputHandler>,
256        display: Box<dyn RdpServerDisplay>,
257        mut sound_factory: Option<Box<dyn SoundServerFactory>>,
258        mut cliprdr_factory: Option<Box<dyn CliprdrServerFactory>>,
259    ) -> Self {
260        let (ev_sender, ev_receiver) = ServerEvent::create_channel();
261        if let Some(cliprdr) = cliprdr_factory.as_mut() {
262            cliprdr.set_sender(ev_sender.clone());
263        }
264        if let Some(snd) = sound_factory.as_mut() {
265            snd.set_sender(ev_sender.clone());
266        }
267        Self {
268            opts,
269            handler: Arc::new(Mutex::new(handler)),
270            display: Arc::new(Mutex::new(display)),
271            static_channels: StaticChannelSet::new(),
272            sound_factory,
273            cliprdr_factory,
274            ev_sender,
275            ev_receiver: Arc::new(Mutex::new(ev_receiver)),
276            creds: None,
277            local_addr: None,
278        }
279    }
280
281    pub fn builder() -> builder::RdpServerBuilder<builder::WantsAddr> {
282        builder::RdpServerBuilder::new()
283    }
284
285    pub fn event_sender(&self) -> &mpsc::UnboundedSender<ServerEvent> {
286        &self.ev_sender
287    }
288
289    fn attach_channels(&mut self, acceptor: &mut Acceptor) {
290        if let Some(cliprdr_factory) = self.cliprdr_factory.as_deref() {
291            let backend = cliprdr_factory.build_cliprdr_backend();
292
293            let cliprdr = CliprdrServer::new(backend);
294
295            acceptor.attach_static_channel(cliprdr);
296        }
297
298        if let Some(factory) = self.sound_factory.as_deref() {
299            let backend = factory.build_backend();
300
301            acceptor.attach_static_channel(RdpsndServer::new(backend));
302        }
303
304        let dcs_backend = DisplayControlBackend::new(Arc::clone(&self.display));
305        let dvc = dvc::DrdynvcServer::new()
306            .with_dynamic_channel(AInputHandler {
307                handler: Arc::clone(&self.handler),
308            })
309            .with_dynamic_channel(DisplayControlServer::new(Box::new(dcs_backend)));
310        acceptor.attach_static_channel(dvc);
311    }
312
313    pub async fn run_connection(&mut self, stream: TcpStream) -> Result<()> {
314        let framed = TokioFramed::new(stream);
315
316        let size = self.display.lock().await.size().await;
317        let capabilities = capabilities::capabilities(&self.opts, size);
318        let mut acceptor = Acceptor::new(self.opts.security.flag(), size, capabilities, self.creds.clone());
319
320        self.attach_channels(&mut acceptor);
321
322        let res = ironrdp_acceptor::accept_begin(framed, &mut acceptor)
323            .await
324            .context("accept_begin failed")?;
325
326        match res {
327            BeginResult::ShouldUpgrade(stream) => {
328                let tls_acceptor = match &self.opts.security {
329                    RdpServerSecurity::Tls(acceptor) => acceptor,
330                    RdpServerSecurity::Hybrid((acceptor, _)) => acceptor,
331                    RdpServerSecurity::None => unreachable!(),
332                };
333                let accept = match tls_acceptor.accept(stream).await {
334                    Ok(accept) => accept,
335                    Err(e) => {
336                        warn!("Failed to TLS accept: {}", e);
337                        return Ok(());
338                    }
339                };
340                let mut framed = TokioFramed::new(accept);
341
342                acceptor.mark_security_upgrade_as_done();
343
344                if let RdpServerSecurity::Hybrid((_, pub_key)) = &self.opts.security {
345                    // how to get the client name?
346                    // doesn't seem to matter yet
347                    let client_name = framed.get_inner().0.get_ref().0.peer_addr()?.to_string();
348
349                    ironrdp_acceptor::accept_credssp(
350                        &mut framed,
351                        &mut acceptor,
352                        client_name.into(),
353                        pub_key.clone(),
354                        None,
355                        None,
356                    )
357                    .await?;
358                }
359
360                self.accept_finalize(framed, acceptor).await?;
361            }
362
363            BeginResult::Continue(framed) => {
364                self.accept_finalize(framed, acceptor).await?;
365            }
366        };
367
368        Ok(())
369    }
370
371    pub async fn run(&mut self) -> Result<()> {
372        let listener = TcpListener::bind(self.opts.addr).await?;
373        let local_addr = listener.local_addr()?;
374
375        debug!("Listening for connections on {local_addr}");
376        self.local_addr = Some(local_addr);
377
378        loop {
379            let ev_receiver = Arc::clone(&self.ev_receiver);
380            let mut ev_receiver = ev_receiver.lock().await;
381            tokio::select! {
382                Some(event) = ev_receiver.recv() => {
383                    match event {
384                        ServerEvent::Quit(reason) => {
385                            debug!("Got quit event {reason}");
386                            break;
387                        }
388                        ServerEvent::GetLocalAddr(tx) => {
389                            let _ = tx.send(self.local_addr);
390                        }
391                        ServerEvent::SetCredentials(creds) => {
392                            self.set_credentials(Some(creds));
393                        }
394                        ev => {
395                            debug!("Unexpected event {:?}", ev);
396                        }
397                    }
398                },
399                Ok((stream, peer)) = listener.accept() => {
400                    debug!(?peer, "Received connection");
401                    drop(ev_receiver);
402                    if let Err(error) = self.run_connection(stream).await {
403                        error!(?error, "Connection error");
404                    }
405                    self.static_channels = StaticChannelSet::new();
406                }
407                else => break,
408            }
409        }
410
411        Ok(())
412    }
413
414    pub fn get_svc_processor<T: SvcProcessor + 'static>(&mut self) -> Option<&mut T> {
415        self.static_channels
416            .get_by_type_mut::<T>()
417            .and_then(|svc| svc.channel_processor_downcast_mut())
418    }
419
420    pub fn get_channel_id_by_type<T: SvcProcessor + 'static>(&self) -> Option<StaticChannelId> {
421        self.static_channels.get_channel_id_by_type::<T>()
422    }
423
424    async fn dispatch_pdu(
425        &mut self,
426        action: Action,
427        bytes: bytes::BytesMut,
428        writer: &mut impl FramedWrite,
429        io_channel_id: u16,
430        user_channel_id: u16,
431    ) -> Result<RunState> {
432        match action {
433            Action::FastPath => {
434                let input = decode(&bytes)?;
435                self.handle_fastpath(input).await;
436            }
437
438            Action::X224 => {
439                if self
440                    .handle_x224(writer, io_channel_id, user_channel_id, &bytes)
441                    .await
442                    .context("X224 input error")?
443                {
444                    debug!("Got disconnect request");
445                    return Ok(RunState::Disconnect);
446                }
447            }
448        }
449
450        Ok(RunState::Continue)
451    }
452
453    async fn dispatch_display_update(
454        update: DisplayUpdate,
455        writer: &mut impl FramedWrite,
456        user_channel_id: u16,
457        io_channel_id: u16,
458        buffer: &mut Vec<u8>,
459        mut encoder: UpdateEncoder,
460    ) -> Result<(RunState, UpdateEncoder)> {
461        if let DisplayUpdate::Resize(desktop_size) = update {
462            debug!(?desktop_size, "Display resize");
463            encoder.set_desktop_size(desktop_size);
464            deactivate_all(io_channel_id, user_channel_id, writer).await?;
465            return Ok((RunState::DeactivationReactivation { desktop_size }, encoder));
466        }
467
468        let mut encoder_iter = encoder.update(update);
469        loop {
470            let Some(fragmenter) = encoder_iter.next().await else {
471                break;
472            };
473
474            let mut fragmenter = fragmenter.context("error while encoding")?;
475            if fragmenter.size_hint() > buffer.len() {
476                buffer.resize(fragmenter.size_hint(), 0);
477            }
478
479            while let Some(len) = fragmenter.next(buffer) {
480                writer
481                    .write_all(&buffer[..len])
482                    .await
483                    .context("failed to write display update")?;
484            }
485        }
486
487        Ok((RunState::Continue, encoder))
488    }
489
    /// Handles a batch of queued [`ServerEvent`]s for the connected client, draining `events`.
    ///
    /// - Clipboard and rdpsnd events are turned into SVC messages on their static channel and
    ///   written to `writer`. If that channel is not attached, the event is dropped with a warning.
    /// - `Error` variants are logged and skipped.
    /// - At most 4 wave messages are forwarded per batch; further waves are dropped.
    ///
    /// Returns [`RunState::Disconnect`] on [`ServerEvent::Quit`], otherwise [`RunState::Continue`].
    /// On early return, the drain removes the rest of the batch without handling it.
    async fn dispatch_server_events(
        &mut self,
        events: &mut Vec<ServerEvent>,
        writer: &mut impl FramedWrite,
        user_channel_id: u16,
    ) -> Result<RunState> {
        // Avoid wave message queuing up and causing extra delays.
        // This is a naive solution, better solutions should compute the actual delay, add IO priority, encode audio, use UDP etc.
        // 4 frames should roughly correspond to hundreds of ms in regular setups.
        let mut wave_limit = 4;
        for event in events.drain(..) {
            trace!(?event, "Dispatching");
            match event {
                ServerEvent::Quit(reason) => {
                    debug!("Got quit event: {reason}");
                    return Ok(RunState::Disconnect);
                }
                ServerEvent::GetLocalAddr(tx) => {
                    // The requester may have gone away; nothing to do in that case.
                    let _ = tx.send(self.local_addr);
                }
                ServerEvent::SetCredentials(creds) => {
                    self.set_credentials(Some(creds));
                }
                ServerEvent::Rdpsnd(s) => {
                    let Some(rdpsnd) = self.get_svc_processor::<RdpsndServer>() else {
                        warn!("No rdpsnd channel, dropping event");
                        continue;
                    };
                    let msgs = match s {
                        RdpsndServerMessage::Wave(data, ts) => {
                            if wave_limit == 0 {
                                debug!("Dropping wave");
                                continue;
                            }
                            wave_limit -= 1;
                            rdpsnd.wave(data, ts)
                        }
                        RdpsndServerMessage::SetVolume { left, right } => rdpsnd.set_volume(left, right),
                        RdpsndServerMessage::Close => rdpsnd.close(),
                        RdpsndServerMessage::Error(error) => {
                            error!(?error, "Handling rdpsnd event");
                            continue;
                        }
                    }
                    .context("failed to send rdpsnd event")?;
                    // The processor borrow has ended; look up the channel ID to frame the messages.
                    let channel_id = self
                        .get_channel_id_by_type::<RdpsndServer>()
                        .ok_or_else(|| anyhow!("SVC channel not found"))?;
                    let data = server_encode_svc_messages(msgs.into(), channel_id, user_channel_id)?;
                    writer.write_all(&data).await?;
                }
                ServerEvent::Clipboard(c) => {
                    let Some(cliprdr) = self.get_svc_processor::<CliprdrServer>() else {
                        warn!("No clipboard channel, dropping event");
                        continue;
                    };
                    let msgs = match c {
                        ClipboardMessage::SendInitiateCopy(formats) => cliprdr.initiate_copy(&formats),
                        ClipboardMessage::SendFormatData(data) => cliprdr.submit_format_data(data),
                        ClipboardMessage::SendInitiatePaste(format) => cliprdr.initiate_paste(format),
                        ClipboardMessage::Error(error) => {
                            error!(?error, "Handling clipboard event");
                            continue;
                        }
                    }
                    .context("failed to send clipboard event")?;
                    let channel_id = self
                        .get_channel_id_by_type::<CliprdrServer>()
                        .ok_or_else(|| anyhow!("SVC channel not found"))?;
                    let data = server_encode_svc_messages(msgs.into(), channel_id, user_channel_id)?;
                    writer.write_all(&data).await?;
                }
            }
        }

        Ok(RunState::Continue)
    }
567
    /// Serves an activated client until the first of three concurrent loops finishes:
    ///
    /// - `dispatch_pdu` reads client PDUs and routes them through `Self::dispatch_pdu`;
    /// - `dispatch_display` pulls display updates and encodes/writes them;
    /// - `dispatch_events` drains queued [`ServerEvent`]s through `Self::dispatch_server_events`.
    ///
    /// The loops are raced with `tokio::select!` on the current task, not spawned. The first
    /// loop to complete supplies the returned [`RunState`] and the others are dropped. `self`
    /// is shared between them through an `Rc<Mutex<_>>`, and the writer through `SharedWriter`
    /// clones.
    async fn client_loop<R, W>(
        &mut self,
        reader: &mut Framed<R>,
        writer: &mut Framed<W>,
        io_channel_id: u16,
        user_channel_id: u16,
        mut encoder: UpdateEncoder,
    ) -> Result<RunState>
    where
        R: FramedRead,
        W: FramedWrite,
    {
        debug!("Starting client loop");
        let mut display_updates = self.display.lock().await.updates().await?;
        let mut writer = SharedWriter::new(writer);
        let mut display_writer = writer.clone();
        let mut event_writer = writer.clone();
        let ev_receiver = Arc::clone(&self.ev_receiver);
        let s = Rc::new(Mutex::new(self));

        let this = Rc::clone(&s);
        // Client -> server: read PDUs until one of them ends the session.
        let dispatch_pdu = async move {
            loop {
                let (action, bytes) = reader.read_pdu().await?;
                let mut this = this.lock().await;
                match this
                    .dispatch_pdu(action, bytes, &mut writer, io_channel_id, user_channel_id)
                    .await?
                {
                    RunState::Continue => continue,
                    state => break Ok(state),
                }
            }
        };

        // Display -> client: does not need `this`, the encoder is threaded through by value.
        let dispatch_display = async move {
            let mut buffer = vec![0u8; 4096];

            loop {
                match display_updates.next_update().await {
                    Ok(Some(update)) => {
                        match Self::dispatch_display_update(
                            update,
                            &mut display_writer,
                            user_channel_id,
                            io_channel_id,
                            &mut buffer,
                            encoder,
                        )
                        .await?
                        {
                            (RunState::Continue, enc) => {
                                encoder = enc;
                                continue;
                            }
                            (state, _) => {
                                break Ok(state);
                            }
                        }
                    }
                    Ok(None) => {
                        break Ok(RunState::Disconnect);
                    }
                    Err(error) => {
                        // NOTE(review): the error is only logged and `next_update` is retried at
                        // once; a persistently failing source could spin here — confirm intent.
                        warn!(error = format!("{error:#}"), "next_updated failed");
                    }
                }
            }
        };

        let this = Rc::clone(&s);
        // Held for the whole client loop; `run` releases its own guard before calling
        // `run_connection`, so the two never hold it at the same time.
        let mut ev_receiver = ev_receiver.lock().await;
        // Server events -> client, handled in batches.
        let dispatch_events = async move {
            let mut events = Vec::with_capacity(100);
            loop {
                // `recv_many` only returns 0 once the channel is closed and empty.
                let nevents = ev_receiver.recv_many(&mut events, 100).await;
                if nevents == 0 {
                    debug!("No sever events.. stopping");
                    break Ok(RunState::Disconnect);
                }
                // Also pick up whatever else is already queued, beyond the 100-event limit.
                while let Ok(ev) = ev_receiver.try_recv() {
                    events.push(ev);
                }
                let mut this = this.lock().await;
                match this
                    .dispatch_server_events(&mut events, &mut event_writer, user_channel_id)
                    .await?
                {
                    RunState::Continue => continue,
                    state => break Ok(state),
                }
            }
        };

        let state = tokio::select!(
            state = dispatch_pdu => state,
            state = dispatch_display => state,
            state = dispatch_events => state,
        );

        debug!("End of client loop: {state:?}");
        state
    }
671
    /// Finishes setting up an accepted client and runs its session.
    ///
    /// Steps, in order:
    /// 1. Replays any input PDUs received during the connection sequence.
    /// 2. Installs the negotiated static channels. Unless this is a reactivation, it also
    ///    starts them and sends their initial messages.
    /// 3. Inspects the client's capability sets:
    ///    - fast-path output is required;
    ///    - when the client supports desktop resize, a client desktop smaller than the
    ///      display is only warned about;
    ///    - the surface-command flags are recorded;
    ///    - the codecs enabled on both sides are recorded (RemoteFX, plus QOI/QOIZ when
    ///      those features are compiled in).
    /// 4. Builds the [`UpdateEncoder`] for the current display size and enters `client_loop`.
    ///
    /// # Errors
    ///
    /// Fails in these cases:
    /// - the client does not support fast-path output;
    /// - replaying the backlog, starting a channel or writing to the client fails;
    /// - the encoder cannot be created;
    /// - the client loop fails.
    async fn client_accepted<R, W>(
        &mut self,
        reader: &mut Framed<R>,
        writer: &mut Framed<W>,
        result: AcceptorResult,
    ) -> Result<RunState>
    where
        R: FramedRead,
        W: FramedWrite,
    {
        debug!("Client accepted");

        // NOTE(review): the backlog is replayed before `self.static_channels` is replaced
        // below. Any SVC data in it would therefore see the previous (reset) channel set —
        // confirm whether `handle_x224` needs the new channels here.
        if !result.input_events.is_empty() {
            debug!("Handling input event backlog from acceptor sequence");
            self.handle_input_backlog(
                writer,
                result.io_channel_id,
                result.user_channel_id,
                result.input_events,
            )
            .await?;
        }

        self.static_channels = result.static_channels;
        // On reactivation the channels were already started during the first activation.
        if !result.reactivation {
            for (_type_id, channel, channel_id) in self.static_channels.iter_mut() {
                debug!(?channel, ?channel_id, "Start");
                let Some(channel_id) = channel_id else {
                    continue;
                };
                let svc_responses = channel.start()?;
                let response = server_encode_svc_messages(svc_responses, channel_id, result.user_channel_id)?;
                writer.write_all(&response).await?;
            }
        }

        let mut update_codecs = UpdateEncoderCodecs::new();
        let mut surface_flags = CmdFlags::empty();
        for c in result.capabilities {
            match c {
                CapabilitySet::General(c) => {
                    // Display updates are sent as fast-path output only.
                    let fastpath = c.extra_flags.contains(GeneralExtraFlags::FASTPATH_OUTPUT_SUPPORTED);
                    if !fastpath {
                        bail!("Fastpath output not supported!");
                    }
                }
                CapabilitySet::Bitmap(b) => {
                    if !b.desktop_resize_flag {
                        debug!("Desktop resize is not supported by the client");
                        continue;
                    }

                    let client_size = DesktopSize {
                        width: b.desktop_width,
                        height: b.desktop_height,
                    };
                    let display_size = self.display.lock().await.size().await;

                    // It's problematic when the client didn't resize, as we send bitmap updates that don't fit.
                    // The client will likely drop the connection.
                    if client_size.width < display_size.width || client_size.height < display_size.height {
                        // TODO: we may have different behaviour instead, such as clipping or scaling?
                        warn!(
                            "Client size doesn't fit the server size: {:?} < {:?}",
                            client_size, display_size
                        );
                    }
                }
                CapabilitySet::SurfaceCommands(c) => {
                    surface_flags = c.flags;
                }
                CapabilitySet::BitmapCodecs(BitmapCodecs(codecs)) => {
                    for codec in codecs {
                        // A client codec is only used if the server options enable the same kind.
                        match codec.property {
                            // FIXME: The encoder operates in image mode only.
                            //
                            // See [MS-RDPRFX] 3.1.1.1 "State Machine" for
                            // implementation of the video mode. which allows to
                            // skip sending Header for each image.
                            //
                            // We should distinguish parameters for both modes,
                            // and somehow choose the "best", instead of picking
                            // the last parsed here.
                            CodecProperty::RemoteFx(rdp::capability_sets::RemoteFxContainer::ClientContainer(c))
                                if self.opts.has_remote_fx() =>
                            {
                                for caps in c.caps_data.0 .0 {
                                    update_codecs.set_remotefx(Some((caps.entropy_bits, codec.id)));
                                }
                            }
                            CodecProperty::ImageRemoteFx(rdp::capability_sets::RemoteFxContainer::ClientContainer(
                                c,
                            )) if self.opts.has_image_remote_fx() => {
                                for caps in c.caps_data.0 .0 {
                                    update_codecs.set_remotefx(Some((caps.entropy_bits, codec.id)));
                                }
                            }
                            CodecProperty::NsCodec(_) => (),
                            #[cfg(feature = "qoi")]
                            CodecProperty::Qoi if self.opts.has_qoi() => {
                                update_codecs.set_qoi(Some(codec.id));
                            }
                            #[cfg(feature = "qoiz")]
                            CodecProperty::QoiZ if self.opts.has_qoiz() => {
                                update_codecs.set_qoiz(Some(codec.id));
                            }
                            _ => (),
                        }
                    }
                }
                _ => {}
            }
        }

        let desktop_size = self.display.lock().await.size().await;
        let encoder = UpdateEncoder::new(desktop_size, surface_flags, update_codecs)
            .context("failed to initialize update encoder")?;

        let state = self
            .client_loop(reader, writer, result.io_channel_id, result.user_channel_id, encoder)
            .await
            .context("client loop failure")?;

        Ok(state)
    }
797
798    async fn handle_input_backlog(
799        &mut self,
800        writer: &mut impl FramedWrite,
801        io_channel_id: u16,
802        user_channel_id: u16,
803        frames: Vec<Vec<u8>>,
804    ) -> Result<()> {
805        for frame in frames {
806            match Action::from_fp_output_header(frame[0]) {
807                Ok(Action::FastPath) => {
808                    let input = decode(&frame)?;
809                    self.handle_fastpath(input).await;
810                }
811
812                Ok(Action::X224) => {
813                    let _ = self.handle_x224(writer, io_channel_id, user_channel_id, &frame).await;
814                }
815
816                // the frame here is always valid, because otherwise it would
817                // have failed during the acceptor loop
818                Err(_) => unreachable!(),
819            }
820        }
821
822        Ok(())
823    }
824
825    async fn handle_fastpath(&mut self, input: FastPathInput) {
826        for event in input.0 {
827            let mut handler = self.handler.lock().await;
828            match event {
829                FastPathInputEvent::KeyboardEvent(flags, key) => {
830                    handler.keyboard((key, flags).into());
831                }
832
833                FastPathInputEvent::UnicodeKeyboardEvent(flags, key) => {
834                    handler.keyboard((key, flags).into());
835                }
836
837                FastPathInputEvent::SyncEvent(flags) => {
838                    handler.keyboard(flags.into());
839                }
840
841                FastPathInputEvent::MouseEvent(mouse) => {
842                    handler.mouse(mouse.into());
843                }
844
845                FastPathInputEvent::MouseEventEx(mouse) => {
846                    handler.mouse(mouse.into());
847                }
848
849                FastPathInputEvent::MouseEventRel(mouse) => {
850                    handler.mouse(mouse.into());
851                }
852
853                FastPathInputEvent::QoeEvent(quality) => {
854                    warn!("Received QoE: {}", quality);
855                }
856            }
857        }
858    }
859
860    async fn handle_io_channel_data(&mut self, data: SendDataRequest<'_>) -> Result<bool> {
861        let control: rdp::headers::ShareControlHeader = decode(data.user_data.as_ref())?;
862
863        match control.share_control_pdu {
864            ShareControlPdu::Data(header) => match header.share_data_pdu {
865                rdp::headers::ShareDataPdu::Input(pdu) => {
866                    self.handle_input_event(pdu).await;
867                }
868
869                rdp::headers::ShareDataPdu::ShutdownRequest => {
870                    return Ok(true);
871                }
872
873                unexpected => {
874                    warn!(?unexpected, "Unexpected share data pdu");
875                }
876            },
877
878            unexpected => {
879                warn!(?unexpected, "Unexpected share control");
880            }
881        }
882
883        Ok(false)
884    }
885
886    async fn handle_x224(
887        &mut self,
888        writer: &mut impl FramedWrite,
889        io_channel_id: u16,
890        user_channel_id: u16,
891        frame: &[u8],
892    ) -> Result<bool> {
893        let message = decode::<X224<mcs::McsMessage<'_>>>(frame)?;
894        match message.0 {
895            mcs::McsMessage::SendDataRequest(data) => {
896                debug!(?data, "McsMessage::SendDataRequest");
897                if data.channel_id == io_channel_id {
898                    return self.handle_io_channel_data(data).await;
899                }
900
901                if let Some(svc) = self.static_channels.get_by_channel_id_mut(data.channel_id) {
902                    let response_pdus = svc.process(&data.user_data)?;
903                    let response = server_encode_svc_messages(response_pdus, data.channel_id, user_channel_id)?;
904                    writer.write_all(&response).await?;
905                } else {
906                    warn!(channel_id = data.channel_id, "Unexpected channel received: ID",);
907                }
908            }
909
910            mcs::McsMessage::DisconnectProviderUltimatum(disconnect) => {
911                if disconnect.reason == mcs::DisconnectReason::UserRequested {
912                    return Ok(true);
913                }
914            }
915
916            _ => {
917                warn!(name = ironrdp_core::name(&message), "Unexpected mcs message");
918            }
919        }
920
921        Ok(false)
922    }
923
924    async fn handle_input_event(&mut self, input: InputEventPdu) {
925        for event in input.0 {
926            let mut handler = self.handler.lock().await;
927            match event {
928                ironrdp_pdu::input::InputEvent::ScanCode(key) => {
929                    handler.keyboard((key.key_code, key.flags).into());
930                }
931
932                ironrdp_pdu::input::InputEvent::Unicode(key) => {
933                    handler.keyboard((key.unicode_code, key.flags).into());
934                }
935
936                ironrdp_pdu::input::InputEvent::Sync(sync) => {
937                    handler.keyboard(sync.flags.into());
938                }
939
940                ironrdp_pdu::input::InputEvent::Mouse(mouse) => {
941                    handler.mouse(mouse.into());
942                }
943
944                ironrdp_pdu::input::InputEvent::MouseX(mouse) => {
945                    handler.mouse(mouse.into());
946                }
947
948                ironrdp_pdu::input::InputEvent::MouseRel(mouse) => {
949                    handler.mouse(mouse.into());
950                }
951
952                ironrdp_pdu::input::InputEvent::Unused(_) => {}
953            }
954        }
955    }
956
957    async fn accept_finalize<S>(&mut self, mut framed: TokioFramed<S>, mut acceptor: Acceptor) -> Result<()>
958    where
959        S: AsyncRead + AsyncWrite + Sync + Send + Unpin,
960    {
961        loop {
962            let (new_framed, result) = ironrdp_acceptor::accept_finalize(framed, &mut acceptor)
963                .await
964                .context("failed to accept client during finalize")?;
965
966            let (mut reader, mut writer) = split_tokio_framed(new_framed);
967
968            match self.client_accepted(&mut reader, &mut writer, result).await? {
969                RunState::Continue => {
970                    unreachable!();
971                }
972                RunState::DeactivationReactivation { desktop_size } => {
973                    // No description of such behavior was found in the
974                    // specification, but apparently, we must keep the channel
975                    // state as they were during reactivation. This fixes
976                    // various state issues during client resize.
977                    acceptor = Acceptor::new_deactivation_reactivation(
978                        acceptor,
979                        core::mem::take(&mut self.static_channels),
980                        desktop_size,
981                    )?;
982                    framed = unsplit_tokio_framed(reader, writer);
983                    continue;
984                }
985                RunState::Disconnect => break,
986            }
987        }
988
989        Ok(())
990    }
991
992    pub fn set_credentials(&mut self, creds: Option<Credentials>) {
993        debug!(?creds, "Changing credentials");
994        self.creds = creds
995    }
996}
997
998async fn deactivate_all(
999    io_channel_id: u16,
1000    user_channel_id: u16,
1001    writer: &mut impl FramedWrite,
1002) -> Result<(), anyhow::Error> {
1003    let pdu = ShareControlPdu::ServerDeactivateAll(ServerDeactivateAll);
1004    let pdu = rdp::headers::ShareControlHeader {
1005        share_id: 0,
1006        pdu_source: io_channel_id,
1007        share_control_pdu: pdu,
1008    };
1009    let user_data = encode_vec(&pdu)?.into();
1010    let pdu = SendDataIndication {
1011        initiator_id: user_channel_id,
1012        channel_id: io_channel_id,
1013        user_data,
1014    };
1015    let msg = encode_vec(&X224(pdu))?;
1016    writer.write_all(&msg).await?;
1017    Ok(())
1018}
1019
1020struct SharedWriter<'w, W: FramedWrite> {
1021    writer: Rc<Mutex<&'w mut W>>,
1022}
1023
1024impl<W: FramedWrite> Clone for SharedWriter<'_, W> {
1025    fn clone(&self) -> Self {
1026        Self {
1027            writer: Rc::clone(&self.writer),
1028        }
1029    }
1030}
1031
1032impl<W> FramedWrite for SharedWriter<'_, W>
1033where
1034    W: FramedWrite,
1035{
1036    type WriteAllFut<'write>
1037        = core::pin::Pin<Box<dyn core::future::Future<Output = std::io::Result<()>> + 'write>>
1038    where
1039        Self: 'write;
1040
1041    fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
1042        Box::pin(async {
1043            let mut writer = self.writer.lock().await;
1044
1045            writer.write_all(buf).await?;
1046            Ok(())
1047        })
1048    }
1049}
1050
1051impl<'a, W: FramedWrite> SharedWriter<'a, W> {
1052    fn new(writer: &'a mut W) -> Self {
1053        Self {
1054            writer: Rc::new(Mutex::new(writer)),
1055        }
1056    }
1057}