use std::net::SocketAddr;
use std::rc::Rc;
use std::sync::Arc;
use anyhow::{anyhow, bail, Context, Result};
use ironrdp_acceptor::{self, Acceptor, AcceptorResult, BeginResult, DesktopSize};
use ironrdp_async::{bytes, Framed};
use ironrdp_cliprdr::backend::ClipboardMessage;
use ironrdp_cliprdr::CliprdrServer;
use ironrdp_core::{decode, encode_vec, impl_as_any};
use ironrdp_displaycontrol::pdu::DisplayControlMonitorLayout;
use ironrdp_displaycontrol::server::{DisplayControlHandler, DisplayControlServer};
use ironrdp_pdu::input::fast_path::{FastPathInput, FastPathInputEvent};
use ironrdp_pdu::input::InputEventPdu;
use ironrdp_pdu::mcs::{SendDataIndication, SendDataRequest};
use ironrdp_pdu::rdp::capability_sets::{BitmapCodecs, CapabilitySet, CmdFlags, GeneralExtraFlags};
pub use ironrdp_pdu::rdp::client_info::Credentials;
use ironrdp_pdu::rdp::headers::{ServerDeactivateAll, ShareControlPdu};
use ironrdp_pdu::x224::X224;
use ironrdp_pdu::{self, decode_err, mcs, nego, rdp, Action, PduResult};
use ironrdp_svc::{server_encode_svc_messages, StaticChannelId, StaticChannelSet, SvcProcessor};
use ironrdp_tokio::{split_tokio_framed, unsplit_tokio_framed, FramedRead, FramedWrite, TokioFramed};
use rdpsnd::server::{RdpsndServer, RdpsndServerMessage};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, oneshot, Mutex};
use tokio::task;
use tokio_rustls::TlsAcceptor;
use {ironrdp_dvc as dvc, ironrdp_rdpsnd as rdpsnd};
use crate::clipboard::CliprdrServerFactory;
use crate::display::{DisplayUpdate, RdpServerDisplay};
use crate::encoder::{UpdateEncoder, UpdateEncoderCodecs};
use crate::handler::RdpServerInputHandler;
use crate::{builder, capabilities, SoundServerFactory};
/// Configuration for an [`RdpServer`].
#[derive(Clone)]
pub struct RdpServerOptions {
/// Socket address the TCP listener binds to in `RdpServer::run`.
pub addr: SocketAddr,
/// Security mode offered to connecting clients.
pub security: RdpServerSecurity,
/// When `true`, RemoteFX codecs advertised by the client are enabled on the update encoder.
pub with_remote_fx: bool,
}
/// Security mode offered by the server during connection negotiation.
#[derive(Clone)]
pub enum RdpServerSecurity {
/// No security protocol flag is advertised.
None,
/// TLS: the connection is upgraded using the given acceptor.
Tls(TlsAcceptor),
/// TLS followed by CredSSP: the byte vector is passed to `accept_credssp` as the public key.
Hybrid((TlsAcceptor, Vec<u8>)),
}
impl RdpServerSecurity {
    /// Maps this security mode to the negotiation protocol flags handed to the acceptor.
    ///
    /// `None` yields the empty flag set, `Tls` yields `SSL`, and `Hybrid`
    /// yields both `HYBRID` and `HYBRID_EX`.
    pub fn flag(&self) -> nego::SecurityProtocol {
        match self {
            Self::None => nego::SecurityProtocol::empty(),
            Self::Tls(_) => nego::SecurityProtocol::SSL,
            Self::Hybrid(_) => {
                let hybrid = nego::SecurityProtocol::HYBRID;
                hybrid | nego::SecurityProtocol::HYBRID_EX
            }
        }
    }
}
/// Dynamic-channel processor that forwards AInput mouse PDUs to the server's input handler.
struct AInputHandler {
// Input handler shared with the rest of the server (also used for fast-path and slow-path input).
handler: Arc<Mutex<Box<dyn RdpServerInputHandler>>>,
}
impl_as_any!(AInputHandler);
impl dvc::DvcProcessor for AInputHandler {
    /// Name of the dynamic channel served by this processor.
    fn channel_name(&self) -> &str {
        ironrdp_ainput::CHANNEL_NAME
    }

    /// Opens the channel by replying with the default version PDU.
    fn start(&mut self, _channel_id: u32) -> PduResult<Vec<dvc::DvcMessage>> {
        use ironrdp_ainput::{ServerPdu, VersionPdu};

        let version: dvc::DvcMessage = Box::new(ServerPdu::Version(VersionPdu::default()));
        Ok(vec![version])
    }

    /// Nothing to release when the channel closes.
    fn close(&mut self, _channel_id: u32) {}

    /// Decodes one client PDU and hands the mouse event to the input handler.
    ///
    /// The handler is invoked on a blocking task, so this method returns
    /// without waiting for it. No response messages are produced.
    fn process(&mut self, _channel_id: u32, payload: &[u8]) -> PduResult<Vec<dvc::DvcMessage>> {
        use ironrdp_ainput::ClientPdu;

        let mouse = match decode(payload).map_err(|e| decode_err!(e))? {
            ClientPdu::Mouse(pdu) => pdu,
        };

        let shared_handler = Arc::clone(&self.handler);
        task::spawn_blocking(move || {
            shared_handler.blocking_lock().mouse(mouse.into());
        });

        Ok(Vec::new())
    }
}

impl dvc::DvcServerProcessor for AInputHandler {}
/// Display-control handler that relays monitor layout requests to the server display.
struct DisplayControlBackend {
    // Display shared with the rest of the server.
    display: Arc<Mutex<Box<dyn RdpServerDisplay>>>,
}

impl DisplayControlBackend {
    /// Wraps the shared display handle.
    fn new(display: Arc<Mutex<Box<dyn RdpServerDisplay>>>) -> Self {
        Self { display }
    }
}

impl DisplayControlHandler for DisplayControlBackend {
    /// Forwards the requested layout to the display on a blocking task.
    ///
    /// The task is detached; this method does not wait for the display to react.
    fn monitor_layout(&self, layout: DisplayControlMonitorLayout) {
        let shared_display = Arc::clone(&self.display);
        task::spawn_blocking(move || shared_display.blocking_lock().request_layout(layout));
    }
}
/// RDP server state: configuration, user-supplied backends, and the event channel
/// used to control the server from outside.
pub struct RdpServer {
// Listener address, security mode and codec options.
opts: RdpServerOptions,
// Receives keyboard and mouse events decoded from the client.
handler: Arc<Mutex<Box<dyn RdpServerInputHandler>>>,
// Source of the desktop size and of display updates.
display: Arc<Mutex<Box<dyn RdpServerDisplay>>>,
// Static channels of the current connection; reset after each connection ends.
static_channels: StaticChannelSet,
// Optional factory for the rdpsnd channel backend.
sound_factory: Option<Box<dyn SoundServerFactory>>,
// Optional factory for the cliprdr channel backend.
cliprdr_factory: Option<Box<dyn CliprdrServerFactory>>,
// Sending half of the server event channel; clones are handed to the factories and to callers.
ev_sender: mpsc::UnboundedSender<ServerEvent>,
// Receiving half of the server event channel, locked by whichever loop is currently running.
ev_receiver: Arc<Mutex<mpsc::UnboundedReceiver<ServerEvent>>>,
// Credentials passed to the acceptor for new connections, if any.
creds: Option<Credentials>,
// Address the listener is bound to; `None` until `run` has bound it.
local_addr: Option<SocketAddr>,
}
/// Events that can be sent to a running [`RdpServer`] through its event sender.
#[derive(Debug)]
pub enum ServerEvent {
/// Stops the accept loop, or disconnects the current client; the string is the logged reason.
Quit(String),
/// Clipboard message to forward to the client over the cliprdr channel.
Clipboard(ClipboardMessage),
/// Sound message to forward to the client over the rdpsnd channel.
Rdpsnd(RdpsndServerMessage),
/// Replaces the credentials stored on the server.
SetCredentials(Credentials),
/// Requests the listener's local address; the reply is sent on the given oneshot channel.
GetLocalAddr(oneshot::Sender<Option<SocketAddr>>),
}
/// Implemented by components that need to post [`ServerEvent`]s back to the server.
pub trait ServerEventSender {
/// Hands the component a sender for the server event channel.
fn set_sender(&mut self, sender: mpsc::UnboundedSender<ServerEvent>);
}
impl ServerEvent {
    /// Creates an unbounded channel carrying [`ServerEvent`]s and returns both halves.
    pub fn create_channel() -> (mpsc::UnboundedSender<Self>, mpsc::UnboundedReceiver<Self>) {
        let (sender, receiver) = mpsc::unbounded_channel::<Self>();
        (sender, receiver)
    }
}
/// Outcome of a dispatch step, telling the session loop what to do next.
#[derive(Debug, PartialEq)]
enum RunState {
/// Keep serving the current session.
Continue,
/// End the session.
Disconnect,
/// Re-run the acceptor finalization with a new desktop size (sent after a display resize).
DeactivationReactivation { desktop_size: DesktopSize },
}
impl RdpServer {
/// Creates a server from its options, input handler, display and optional channel factories.
///
/// An event channel is created internally, and a clone of its sending half is
/// given to each factory that is present so the backends can post events.
pub fn new(
    opts: RdpServerOptions,
    handler: Box<dyn RdpServerInputHandler>,
    display: Box<dyn RdpServerDisplay>,
    mut sound_factory: Option<Box<dyn SoundServerFactory>>,
    mut cliprdr_factory: Option<Box<dyn CliprdrServerFactory>>,
) -> Self {
    let (ev_sender, ev_receiver) = ServerEvent::create_channel();

    if let Some(factory) = &mut cliprdr_factory {
        factory.set_sender(ev_sender.clone());
    }
    if let Some(factory) = &mut sound_factory {
        factory.set_sender(ev_sender.clone());
    }

    let handler = Arc::new(Mutex::new(handler));
    let display = Arc::new(Mutex::new(display));
    let ev_receiver = Arc::new(Mutex::new(ev_receiver));

    Self {
        opts,
        handler,
        display,
        static_channels: StaticChannelSet::new(),
        sound_factory,
        cliprdr_factory,
        ev_sender,
        ev_receiver,
        creds: None,
        local_addr: None,
    }
}
/// Returns a new builder in its initial `WantsAddr` state.
pub fn builder() -> builder::RdpServerBuilder<builder::WantsAddr> {
builder::RdpServerBuilder::new()
}
/// Returns the sending half of the server event channel.
///
/// Events sent here are picked up by `run` while idle, or by the client loop
/// while a session is active.
pub fn event_sender(&self) -> &mpsc::UnboundedSender<ServerEvent> {
&self.ev_sender
}
/// Attaches the static channels for a new connection to `acceptor`.
///
/// Clipboard and sound channels are attached only when their factory is
/// configured. The dynamic channel server is always attached and carries the
/// AInput and display-control dynamic channels.
fn attach_channels(&mut self, acceptor: &mut Acceptor) {
    if let Some(factory) = self.cliprdr_factory.as_deref() {
        acceptor.attach_static_channel(CliprdrServer::new(factory.build_cliprdr_backend()));
    }

    if let Some(factory) = self.sound_factory.as_deref() {
        acceptor.attach_static_channel(RdpsndServer::new(factory.build_backend()));
    }

    let display_control = DisplayControlBackend::new(Arc::clone(&self.display));
    let ainput = AInputHandler {
        handler: Arc::clone(&self.handler),
    };
    let drdynvc = dvc::DrdynvcServer::new()
        .with_dynamic_channel(ainput)
        .with_dynamic_channel(DisplayControlServer::new(Box::new(display_control)));
    acceptor.attach_static_channel(drdynvc);
}
/// Serves a single client on an already-accepted TCP stream.
///
/// Runs the start of the acceptor sequence, upgrades to TLS when the acceptor
/// asks for it (followed by CredSSP in `Hybrid` mode), then hands over to
/// `accept_finalize`, which runs the session until it ends.
///
/// A failed TLS handshake is logged and treated as a normal end of the
/// connection (`Ok(())`).
///
/// # Errors
///
/// Returns an error if the acceptor sequence, CredSSP, or the session fails.
pub async fn run_connection(&mut self, stream: TcpStream) -> Result<()> {
    let framed = TokioFramed::new(stream);

    let size = self.display.lock().await.size().await;
    let caps = capabilities::capabilities(&self.opts, size);
    let mut acceptor = Acceptor::new(self.opts.security.flag(), size, caps, self.creds.clone());
    self.attach_channels(&mut acceptor);

    let begin = ironrdp_acceptor::accept_begin(framed, &mut acceptor)
        .await
        .context("accept_begin failed")?;

    // Without a security upgrade the session continues on the plain stream.
    let stream = match begin {
        BeginResult::Continue(framed) => return self.accept_finalize(framed, acceptor).await,
        BeginResult::ShouldUpgrade(stream) => stream,
    };

    let tls_acceptor = match &self.opts.security {
        RdpServerSecurity::Tls(tls) | RdpServerSecurity::Hybrid((tls, _)) => tls,
        RdpServerSecurity::None => unreachable!(),
    };
    let tls_stream = match tls_acceptor.accept(stream).await {
        Ok(tls_stream) => tls_stream,
        Err(e) => {
            warn!("Failed to TLS accept: {}", e);
            return Ok(());
        }
    };

    let mut framed = TokioFramed::new(tls_stream);
    acceptor.mark_security_upgrade_as_done();

    if let RdpServerSecurity::Hybrid((_, pub_key)) = &self.opts.security {
        // The peer address of the underlying TCP stream is used as the client name.
        let client_name = framed.get_inner().0.get_ref().0.peer_addr()?.to_string();
        ironrdp_acceptor::accept_credssp(&mut framed, &mut acceptor, client_name.into(), pub_key.clone(), None)
            .await?;
    }

    self.accept_finalize(framed, acceptor).await
}
/// Binds the listener on `opts.addr` and serves connections until told to quit.
///
/// Connections are handled one at a time: `run_connection` is awaited inline, so
/// no further connection is accepted while one is being served. While idle,
/// server events are processed here: `Quit` ends the loop, `GetLocalAddr` is
/// answered with the bound address, `SetCredentials` updates the stored
/// credentials, and any other event is logged and dropped.
///
/// # Errors
///
/// Returns an error if binding the listener or reading its local address fails.
/// An error from an individual connection is logged and does not stop the server.
pub async fn run(&mut self) -> Result<()> {
let listener = TcpListener::bind(self.opts.addr).await?;
let local_addr = listener.local_addr()?;
debug!("Listening for connections on {local_addr}");
self.local_addr = Some(local_addr);
loop {
// The receiver lock is taken anew on every iteration because it is released
// below before a connection is served (the client loop locks it itself).
let ev_receiver = Arc::clone(&self.ev_receiver);
let mut ev_receiver = ev_receiver.lock().await;
tokio::select! {
Some(event) = ev_receiver.recv() => {
match event {
ServerEvent::Quit(reason) => {
debug!("Got quit event {reason}");
break;
}
ServerEvent::GetLocalAddr(tx) => {
let _ = tx.send(self.local_addr);
}
ServerEvent::SetCredentials(creds) => {
self.set_credentials(Some(creds));
}
ev => {
debug!("Unexpected event {:?}", ev);
}
}
},
Ok((stream, peer)) = listener.accept() => {
debug!(?peer, "Received connection");
// Release the event receiver so the client loop can lock it.
drop(ev_receiver);
if let Err(error) = self.run_connection(stream).await {
error!(?error, "Connection error");
}
// Channels belong to a single connection; start the next one from a clean set.
self.static_channels = StaticChannelSet::new();
}
else => break,
}
}
Ok(())
}
/// Returns the processor of type `T` from the current connection's static channels.
///
/// Yields `None` when no channel of that type is present or the downcast fails.
pub fn get_svc_processor<T: SvcProcessor + 'static>(&mut self) -> Option<&mut T> {
    let svc = self.static_channels.get_by_type_mut::<T>()?;
    svc.channel_processor_downcast_mut()
}
/// Looks up the channel ID of the static channel whose processor type is `T`
/// in the current connection's channel set.
pub fn get_channel_id_by_type<T: SvcProcessor + 'static>(&self) -> Option<StaticChannelId> {
self.static_channels.get_channel_id_by_type::<T>()
}
/// Handles one PDU read from the client.
///
/// Fast-path PDUs are decoded as input and passed to the input handler.
/// X.224 PDUs go through `handle_x224`, which may write responses to `writer`.
///
/// Returns `RunState::Disconnect` when the X.224 handler reports a disconnect
/// request, and `RunState::Continue` otherwise.
///
/// # Errors
///
/// Returns an error if the fast-path PDU fails to decode or the X.224 handler fails.
async fn dispatch_pdu(
    &mut self,
    action: Action,
    bytes: bytes::BytesMut,
    writer: &mut impl FramedWrite,
    io_channel_id: u16,
    user_channel_id: u16,
) -> Result<RunState> {
    let disconnect_requested = match action {
        Action::FastPath => {
            let input = decode(&bytes)?;
            self.handle_fastpath(input).await;
            false
        }
        Action::X224 => self
            .handle_x224(writer, io_channel_id, user_channel_id, &bytes)
            .await
            .context("X224 input error")?,
    };

    if disconnect_requested {
        debug!("Got disconnect request");
        return Ok(RunState::Disconnect);
    }

    Ok(RunState::Continue)
}
/// Encodes one display update and writes it to the client.
///
/// A `Resize` update is not encoded: the encoder is told the new size, a
/// "deactivate all" PDU is sent, and `DeactivationReactivation` is returned so
/// the caller can re-run the capability exchange. Any other update is encoded
/// fragment by fragment through `buffer`, which grows when a fragmenter's size
/// hint exceeds its current length.
///
/// The encoder is taken by value and handed back with the resulting state.
///
/// # Errors
///
/// Returns an error if encoding fails or a write to `writer` fails.
async fn dispatch_display_update(
    update: DisplayUpdate,
    writer: &mut impl FramedWrite,
    user_channel_id: u16,
    io_channel_id: u16,
    buffer: &mut Vec<u8>,
    mut encoder: UpdateEncoder,
) -> Result<(RunState, UpdateEncoder)> {
    if let DisplayUpdate::Resize(desktop_size) = update {
        debug!(?desktop_size, "Display resize");
        encoder.set_desktop_size(desktop_size);
        deactivate_all(io_channel_id, user_channel_id, writer).await?;
        return Ok((RunState::DeactivationReactivation { desktop_size }, encoder));
    }

    let mut fragments = encoder.update(update);
    while let Some(encoded) = fragments.next().await {
        let mut fragmenter = encoded.context("error while encoding")?;

        let needed = fragmenter.size_hint();
        if needed > buffer.len() {
            buffer.resize(needed, 0);
        }

        while let Some(len) = fragmenter.next(buffer) {
            writer
                .write_all(&buffer[..len])
                .await
                .context("failed to write display update")?;
        }
    }

    Ok((RunState::Continue, encoder))
}
/// Drains `events` and handles each one in order.
///
/// Sound and clipboard events are converted to SVC messages by their channel
/// processor, encoded, and written to `writer`. If the matching channel is not
/// present, the event is dropped with a warning. `Error` variants are logged and
/// skipped.
///
/// Returns `RunState::Disconnect` as soon as a `Quit` event is seen (events
/// remaining in the batch are discarded by the drain), otherwise
/// `RunState::Continue`.
///
/// # Errors
///
/// Returns an error if a channel processor fails, the channel ID cannot be
/// found, encoding fails, or a write to `writer` fails.
async fn dispatch_server_events(
&mut self,
events: &mut Vec<ServerEvent>,
writer: &mut impl FramedWrite,
user_channel_id: u16,
) -> Result<RunState> {
// Budget of `Wave` messages forwarded from this batch; any beyond it are dropped.
let mut wave_limit = 4;
for event in events.drain(..) {
trace!(?event, "Dispatching");
match event {
ServerEvent::Quit(reason) => {
debug!("Got quit event: {reason}");
return Ok(RunState::Disconnect);
}
ServerEvent::GetLocalAddr(tx) => {
let _ = tx.send(self.local_addr);
}
ServerEvent::SetCredentials(creds) => {
self.set_credentials(Some(creds));
}
ServerEvent::Rdpsnd(s) => {
let Some(rdpsnd) = self.get_svc_processor::<RdpsndServer>() else {
warn!("No rdpsnd channel, dropping event");
continue;
};
let msgs = match s {
RdpsndServerMessage::Wave(data, ts) => {
if wave_limit == 0 {
debug!("Dropping wave");
continue;
}
wave_limit -= 1;
rdpsnd.wave(data, ts)
}
RdpsndServerMessage::SetVolume { left, right } => rdpsnd.set_volume(left, right),
RdpsndServerMessage::Close => rdpsnd.close(),
RdpsndServerMessage::Error(error) => {
error!(?error, "Handling rdpsnd event");
continue;
}
}
.context("failed to send rdpsnd event")?;
let channel_id = self
.get_channel_id_by_type::<RdpsndServer>()
.ok_or_else(|| anyhow!("SVC channel not found"))?;
let data = server_encode_svc_messages(msgs.into(), channel_id, user_channel_id)?;
writer.write_all(&data).await?;
}
ServerEvent::Clipboard(c) => {
let Some(cliprdr) = self.get_svc_processor::<CliprdrServer>() else {
warn!("No clipboard channel, dropping event");
continue;
};
let msgs = match c {
ClipboardMessage::SendInitiateCopy(formats) => cliprdr.initiate_copy(&formats),
ClipboardMessage::SendFormatData(data) => cliprdr.submit_format_data(data),
ClipboardMessage::SendInitiatePaste(format) => cliprdr.initiate_paste(format),
ClipboardMessage::Error(error) => {
error!(?error, "Handling clipboard event");
continue;
}
}
.context("failed to send clipboard event")?;
let channel_id = self
.get_channel_id_by_type::<CliprdrServer>()
.ok_or_else(|| anyhow!("SVC channel not found"))?;
let data = server_encode_svc_messages(msgs.into(), channel_id, user_channel_id)?;
writer.write_all(&data).await?;
}
}
}
Ok(RunState::Continue)
}
/// Runs the active-session loop for one accepted client.
///
/// Three futures are raced with `tokio::select!`:
///
/// - reading PDUs from the client and dispatching them,
/// - encoding display updates and writing them to the client,
/// - forwarding server events (sound, clipboard, quit, ...).
///
/// The first one to finish — with a non-`Continue` state or an error — decides
/// the result; the other two are dropped. All three write through clones of one
/// `SharedWriter`, so writes to the client are serialized.
///
/// # Errors
///
/// Returns an error if obtaining the display update stream fails, or if any of
/// the three futures fails.
async fn client_loop<R, W>(
    &mut self,
    reader: &mut Framed<R>,
    writer: &mut Framed<W>,
    io_channel_id: u16,
    user_channel_id: u16,
    mut encoder: UpdateEncoder,
) -> Result<RunState>
where
    R: FramedRead,
    W: FramedWrite,
{
    debug!("Starting client loop");
    let mut display_updates = self.display.lock().await.updates().await?;
    let mut writer = SharedWriter::new(writer);
    let mut display_writer = writer.clone();
    let mut event_writer = writer.clone();
    let ev_receiver = Arc::clone(&self.ev_receiver);
    // `self` is needed mutably by both the PDU and the event future, so it is
    // shared through a single-threaded `Rc<Mutex<_>>`.
    let s = Rc::new(Mutex::new(self));

    let this = Rc::clone(&s);
    let dispatch_pdu = async move {
        loop {
            let (action, bytes) = reader.read_pdu().await?;
            let mut this = this.lock().await;
            match this
                .dispatch_pdu(action, bytes, &mut writer, io_channel_id, user_channel_id)
                .await?
            {
                RunState::Continue => continue,
                state => break Ok(state),
            }
        }
    };

    let dispatch_display = async move {
        let mut buffer = vec![0u8; 4096];
        loop {
            if let Some(update) = display_updates.next_update().await {
                // The encoder is moved into the call and handed back on `Continue`.
                match Self::dispatch_display_update(
                    update,
                    &mut display_writer,
                    user_channel_id,
                    io_channel_id,
                    &mut buffer,
                    encoder,
                )
                .await?
                {
                    (RunState::Continue, enc) => {
                        encoder = enc;
                        continue;
                    }
                    (state, _) => {
                        break Ok(state);
                    }
                }
            } else {
                // The update stream ended: nothing more to display.
                break Ok(RunState::Disconnect);
            }
        }
    };

    let this = Rc::clone(&s);
    let mut ev_receiver = ev_receiver.lock().await;
    let dispatch_events = async move {
        let mut events = Vec::with_capacity(100);
        loop {
            let nevents = ev_receiver.recv_many(&mut events, 100).await;
            if nevents == 0 {
                debug!("No server events.. stopping");
                break Ok(RunState::Disconnect);
            }
            // Pick up anything else already queued so the whole backlog is
            // dispatched as one batch.
            while let Ok(ev) = ev_receiver.try_recv() {
                events.push(ev);
            }
            let mut this = this.lock().await;
            match this
                .dispatch_server_events(&mut events, &mut event_writer, user_channel_id)
                .await?
            {
                RunState::Continue => continue,
                state => break Ok(state),
            }
        }
    };

    let state = tokio::select!(
        state = dispatch_pdu => state,
        state = dispatch_display => state,
        state = dispatch_events => state,
    );
    debug!("End of client loop: {state:?}");
    state
}
/// Prepares a freshly accepted (or reactivated) client and runs its session loop.
///
/// Steps, in order:
///
/// 1. replay input PDUs received during the acceptor sequence,
/// 2. adopt the negotiated static channels and, unless this is a reactivation,
///    start each channel that has an ID and send its initial messages,
/// 3. inspect the client capabilities to configure the update encoder,
/// 4. run `client_loop` until it yields a final state.
///
/// # Errors
///
/// Returns an error if the client lacks fast-path output support, if starting a
/// channel or writing to the client fails, or if the client loop fails.
async fn client_accepted<R, W>(
    &mut self,
    reader: &mut Framed<R>,
    writer: &mut Framed<W>,
    result: AcceptorResult,
) -> Result<RunState>
where
    R: FramedRead,
    W: FramedWrite,
{
    debug!("Client accepted");

    let io_channel_id = result.io_channel_id;
    let user_channel_id = result.user_channel_id;

    if !result.input_events.is_empty() {
        debug!("Handling input event backlog from acceptor sequence");
        self.handle_input_backlog(writer, io_channel_id, user_channel_id, result.input_events)
            .await?;
    }

    self.static_channels = result.static_channels;
    if !result.reactivation {
        for (_type_id, channel, channel_id) in self.static_channels.iter_mut() {
            debug!(?channel, ?channel_id, "Start");
            let Some(channel_id) = channel_id else {
                continue;
            };
            let svc_responses = channel.start()?;
            let response = server_encode_svc_messages(svc_responses, channel_id, user_channel_id)?;
            writer.write_all(&response).await?;
        }
    }

    let mut update_codecs = UpdateEncoderCodecs::new();
    let mut surface_flags = CmdFlags::empty();

    for capability in result.capabilities {
        match capability {
            CapabilitySet::General(general) => {
                if !general
                    .extra_flags
                    .contains(GeneralExtraFlags::FASTPATH_OUTPUT_SUPPORTED)
                {
                    bail!("Fastpath output not supported!");
                }
            }
            CapabilitySet::Bitmap(bitmap) => {
                if !bitmap.desktop_resize_flag {
                    debug!("Desktop resize is not supported by the client");
                    continue;
                }
                let client_size = DesktopSize {
                    width: bitmap.desktop_width,
                    height: bitmap.desktop_height,
                };
                let display_size = self.display.lock().await.size().await;
                if client_size.width < display_size.width || client_size.height < display_size.height {
                    warn!(
                        "Client size doesn't fit the server size: {:?} < {:?}",
                        client_size, display_size
                    );
                }
            }
            CapabilitySet::SurfaceCommands(surface) => {
                surface_flags = surface.flags;
            }
            CapabilitySet::BitmapCodecs(BitmapCodecs(codecs)) => {
                for codec in codecs {
                    match codec.property {
                        // Both RemoteFX flavours are handled the same way, and
                        // only when RemoteFX is enabled in the server options.
                        rdp::capability_sets::CodecProperty::RemoteFx(
                            rdp::capability_sets::RemoteFxContainer::ClientContainer(container),
                        )
                        | rdp::capability_sets::CodecProperty::ImageRemoteFx(
                            rdp::capability_sets::RemoteFxContainer::ClientContainer(container),
                        ) if self.opts.with_remote_fx => {
                            for caps in container.caps_data.0 .0 {
                                update_codecs.set_remotefx(Some((caps.entropy_bits, codec.id)));
                            }
                        }
                        // Every other codec is ignored.
                        _ => (),
                    }
                }
            }
            _ => {}
        }
    }

    let desktop_size = self.display.lock().await.size().await;
    let encoder = UpdateEncoder::new(desktop_size, surface_flags, update_codecs);

    self.client_loop(reader, writer, io_channel_id, user_channel_id, encoder)
        .await
        .context("client loop failure")
}
async fn handle_input_backlog(
&mut self,
writer: &mut impl FramedWrite,
io_channel_id: u16,
user_channel_id: u16,
frames: Vec<Vec<u8>>,
) -> Result<()> {
for frame in frames {
match Action::from_fp_output_header(frame[0]) {
Ok(Action::FastPath) => {
let input = decode(&frame)?;
self.handle_fastpath(input).await;
}
Ok(Action::X224) => {
let _ = self.handle_x224(writer, io_channel_id, user_channel_id, &frame).await;
}
Err(_) => unreachable!(),
}
}
Ok(())
}
async fn handle_fastpath(&mut self, input: FastPathInput) {
for event in input.0 {
let mut handler = self.handler.lock().await;
match event {
FastPathInputEvent::KeyboardEvent(flags, key) => {
handler.keyboard((key, flags).into());
}
FastPathInputEvent::UnicodeKeyboardEvent(flags, key) => {
handler.keyboard((key, flags).into());
}
FastPathInputEvent::SyncEvent(flags) => {
handler.keyboard(flags.into());
}
FastPathInputEvent::MouseEvent(mouse) => {
handler.mouse(mouse.into());
}
FastPathInputEvent::MouseEventEx(mouse) => {
handler.mouse(mouse.into());
}
FastPathInputEvent::MouseEventRel(mouse) => {
handler.mouse(mouse.into());
}
FastPathInputEvent::QoeEvent(quality) => {
warn!("Received QoE: {}", quality);
}
}
}
}
/// Handles a send-data request addressed to the I/O channel.
///
/// Input PDUs are forwarded to the input handler. Anything other than an input
/// PDU or a shutdown request is logged and ignored.
///
/// Returns `Ok(true)` when the client sent a shutdown request, `Ok(false)` otherwise.
///
/// # Errors
///
/// Returns an error if the share control header fails to decode.
async fn handle_io_channel_data(&mut self, data: SendDataRequest<'_>) -> Result<bool> {
    let control: rdp::headers::ShareControlHeader = decode(data.user_data.as_ref())?;

    let header = match control.share_control_pdu {
        ShareControlPdu::Data(header) => header,
        unexpected => {
            warn!(?unexpected, "Unexpected share control");
            return Ok(false);
        }
    };

    match header.share_data_pdu {
        rdp::headers::ShareDataPdu::Input(pdu) => {
            self.handle_input_event(pdu).await;
            Ok(false)
        }
        rdp::headers::ShareDataPdu::ShutdownRequest => Ok(true),
        unexpected => {
            warn!(?unexpected, "Unexpected share data pdu");
            Ok(false)
        }
    }
}
/// Decodes one X.224 frame as an MCS message and handles it.
///
/// Send-data requests on the I/O channel go to `handle_io_channel_data`; those
/// on a known static channel are processed by that channel and the responses
/// are written to `writer`; those on an unknown channel are logged. Other MCS
/// messages are logged and ignored.
///
/// Returns `Ok(true)` when the client asked to disconnect — a shutdown request
/// on the I/O channel, or a disconnect-provider ultimatum with reason
/// `UserRequested` — and `Ok(false)` otherwise.
///
/// # Errors
///
/// Returns an error if decoding fails, a channel processor fails, encoding the
/// response fails, or the write to `writer` fails.
async fn handle_x224(
&mut self,
writer: &mut impl FramedWrite,
io_channel_id: u16,
user_channel_id: u16,
frame: &[u8],
) -> Result<bool> {
let message = decode::<X224<mcs::McsMessage<'_>>>(frame)?;
match message.0 {
mcs::McsMessage::SendDataRequest(data) => {
debug!(?data, "McsMessage::SendDataRequest");
if data.channel_id == io_channel_id {
return self.handle_io_channel_data(data).await;
}
if let Some(svc) = self.static_channels.get_by_channel_id_mut(data.channel_id) {
let response_pdus = svc.process(&data.user_data)?;
let response = server_encode_svc_messages(response_pdus, data.channel_id, user_channel_id)?;
writer.write_all(&response).await?;
} else {
warn!(channel_id = data.channel_id, "Unexpected channel received: ID",);
}
}
mcs::McsMessage::DisconnectProviderUltimatum(disconnect) => {
// An ultimatum with any other reason is ignored here.
if disconnect.reason == mcs::DisconnectReason::UserRequested {
return Ok(true);
}
}
_ => {
warn!(name = ironrdp_core::name(&message), "Unexpected mcs message");
}
}
Ok(false)
}
async fn handle_input_event(&mut self, input: InputEventPdu) {
for event in input.0 {
let mut handler = self.handler.lock().await;
match event {
ironrdp_pdu::input::InputEvent::ScanCode(key) => {
handler.keyboard((key.key_code, key.flags).into());
}
ironrdp_pdu::input::InputEvent::Unicode(key) => {
handler.keyboard((key.unicode_code, key.flags).into());
}
ironrdp_pdu::input::InputEvent::Sync(sync) => {
handler.keyboard(sync.flags.into());
}
ironrdp_pdu::input::InputEvent::Mouse(mouse) => {
handler.mouse(mouse.into());
}
ironrdp_pdu::input::InputEvent::MouseX(mouse) => {
handler.mouse(mouse.into());
}
ironrdp_pdu::input::InputEvent::MouseRel(mouse) => {
handler.mouse(mouse.into());
}
ironrdp_pdu::input::InputEvent::Unused(_) => {}
}
}
}
/// Completes the acceptor sequence and runs the session, repeating both after a
/// deactivation-reactivation.
///
/// Each iteration finishes the acceptor sequence, splits the stream into read
/// and write halves, and runs `client_accepted`. When the session ends with
/// `DeactivationReactivation` (sent after a display resize), a new acceptor is
/// built from the old one, the current static channels and the new desktop
/// size, the halves are joined again, and the loop repeats. `Disconnect` ends
/// the loop.
///
/// # Errors
///
/// Returns an error if the acceptor sequence or the session fails.
async fn accept_finalize<S>(&mut self, mut framed: TokioFramed<S>, mut acceptor: Acceptor) -> Result<()>
where
S: AsyncRead + AsyncWrite + Sync + Send + Unpin,
{
loop {
let (new_framed, result) = ironrdp_acceptor::accept_finalize(framed, &mut acceptor)
.await
.context("failed to accept client during finalize")?;
let (mut reader, mut writer) = split_tokio_framed(new_framed);
match self.client_accepted(&mut reader, &mut writer, result).await? {
RunState::Continue => {
// `client_accepted` only returns once the session reached a final state.
unreachable!();
}
RunState::DeactivationReactivation { desktop_size } => {
// The static channels are moved into the new acceptor, leaving an empty set behind.
acceptor = Acceptor::new_deactivation_reactivation(
acceptor,
core::mem::take(&mut self.static_channels),
desktop_size,
);
framed = unsplit_tokio_framed(reader, writer);
continue;
}
RunState::Disconnect => break,
}
}
Ok(())
}
/// Replaces the stored credentials; `None` clears them.
///
/// The stored value is read when an acceptor is created for a new connection.
pub fn set_credentials(&mut self, creds: Option<Credentials>) {
    debug!(?creds, "Changing credentials");
    self.creds = creds;
}
}
/// Sends a "Server Deactivate All" PDU to the client.
///
/// The PDU is wrapped in a share control header, then in an MCS send-data
/// indication on the I/O channel, and finally framed as X.224.
///
/// # Errors
///
/// Returns an error if encoding fails or the write to `writer` fails.
async fn deactivate_all(
    io_channel_id: u16,
    user_channel_id: u16,
    writer: &mut impl FramedWrite,
) -> Result<(), anyhow::Error> {
    let share_control = rdp::headers::ShareControlHeader {
        share_id: 0,
        pdu_source: io_channel_id,
        share_control_pdu: ShareControlPdu::ServerDeactivateAll(ServerDeactivateAll),
    };

    let indication = SendDataIndication {
        initiator_id: user_channel_id,
        channel_id: io_channel_id,
        user_data: encode_vec(&share_control)?.into(),
    };

    let frame = encode_vec(&X224(indication))?;
    writer.write_all(&frame).await?;
    Ok(())
}
/// Cloneable handle to a single borrowed writer.
///
/// All clones share the same underlying writer behind an async mutex, so writes
/// issued through different clones are serialized. Uses `Rc`, so it is confined
/// to a single thread.
struct SharedWriter<'w, W: FramedWrite> {
writer: Rc<Mutex<&'w mut W>>,
}
impl<W: FramedWrite> Clone for SharedWriter<'_, W> {
    /// Returns another handle to the same underlying writer.
    fn clone(&self) -> Self {
        let writer = Rc::clone(&self.writer);
        Self { writer }
    }
}
/// Lets a `SharedWriter` be used wherever a `FramedWrite` is expected.
impl<W> FramedWrite for SharedWriter<'_, W>
where
W: FramedWrite,
{
// The returned future is boxed because it is produced by an `async` block.
type WriteAllFut<'write>
= core::pin::Pin<Box<dyn core::future::Future<Output = std::io::Result<()>> + 'write>>
where
Self: 'write;
/// Locks the shared writer and writes the whole buffer through it.
///
/// The lock is held until the inner `write_all` completes.
fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
Box::pin(async {
let mut writer = self.writer.lock().await;
writer.write_all(buf).await?;
Ok(())
})
}
}
impl<'a, W: FramedWrite> SharedWriter<'a, W> {
    /// Wraps a mutable borrow of `writer` so it can be shared between clones.
    fn new(writer: &'a mut W) -> Self {
        let shared = Rc::new(Mutex::new(writer));
        Self { writer: shared }
    }
}