#![cfg(feature = "async-std-support")]
cfg_std_unix! {
use std::os::unix::io::AsRawFd;
}
cfg_std_windows! {
use std::os::windows::io::AsRawSocket;
}
use core::task::{Context, Poll};
use crate::{
connection::Connection,
display::{
AsyncDisplay, AsyncStatus, BasicDisplay, CanBeAsyncDisplay, DisplayBase, DisplayConnection,
Interest, RawReply, RawRequest,
},
Error, NameConnection, Result,
};
use alloc::{string::ToString, sync::Arc, vec, vec::Vec};
use async_io::Async;
use core::future::Future;
use tracing::Instrument;
use x11rb_protocol::{
connect::Connect,
parse_display,
protocol::{xproto::Setup, Event},
xauth,
};
cfg_std_unix! {
    // Bound placed on the values this module wraps in `Async` on Unix targets:
    // any type that exposes a raw file descriptor qualifies, via the blanket impl.
    #[doc(hidden)]
    pub trait Source: AsRawFd {}
    impl<T> Source for T where T: AsRawFd {}
}
cfg_std_windows! {
    // Windows counterpart of the `Source` bound: any type that exposes a raw
    // socket handle qualifies, via the blanket impl.
    #[doc(hidden)]
    pub trait Source: AsRawSocket {}
    impl<T> Source for T where T: AsRawSocket {}
}
/// `AsyncDisplay` for an owned display wrapped in `async-io`'s `Async`.
///
/// Readiness for the requested `interest` is polled first, and `callback` runs
/// only once that poll yields `Ready(Ok(()))`. If the callback then reports a
/// would-block error, the task is woken immediately and `Poll::Pending` is
/// returned. The next poll therefore starts again from the readiness check.
impl<D: CanBeAsyncDisplay + Source> AsyncDisplay for Async<D> {
    fn poll_for_interest(
        &mut self,
        interest: Interest,
        callback: &mut dyn FnMut(&mut dyn AsyncDisplay, &mut Context<'_>) -> Result<()>,
        ctx: &mut Context<'_>,
    ) -> Poll<Result<()>> {
        let span = tracing::trace_span!(
            "async_std_support::poll_for_interest",
            interest = ?interest
        );
        let _guard = span.enter();

        // Anything other than "ready without error" is handed straight back.
        let readiness = poll_ready(self, interest, ctx);
        if !matches!(readiness, Poll::Ready(Ok(()))) {
            return readiness;
        }

        let outcome = callback(self, ctx);
        if let Err(err) = &outcome {
            if err.would_block() {
                // The readiness report did not hold up; ask to be polled again.
                ctx.waker().wake_by_ref();
                return Poll::Pending;
            }
        }
        Poll::Ready(outcome)
    }
}
/// `AsyncDisplay` for a shared reference to an `Async`-wrapped display.
///
/// This applies when `&D` itself can act as an async display. The flow matches
/// the owned impl: poll readiness for `interest`, then run `callback`. A
/// would-block error from the callback turns into an immediate self-wake plus
/// `Poll::Pending`.
impl<'lt, D: DisplayBase + Source> AsyncDisplay for &'lt Async<D>
where
    &'lt D: CanBeAsyncDisplay,
{
    fn poll_for_interest(
        &mut self,
        interest: Interest,
        callback: &mut dyn FnMut(&mut dyn AsyncDisplay, &mut Context<'_>) -> Result<()>,
        ctx: &mut Context<'_>,
    ) -> Poll<Result<()>> {
        let span = tracing::trace_span!(
            "async_std_support::poll_for_interest",
            interest = ?interest
        );
        let _guard = span.enter();

        // Anything other than "ready without error" is handed straight back.
        let readiness = poll_ready(self, interest, ctx);
        if !matches!(readiness, Poll::Ready(Ok(()))) {
            return readiness;
        }

        let outcome = callback(self, ctx);
        if let Err(err) = &outcome {
            if err.would_block() {
                // The readiness report did not hold up; ask to be polled again.
                ctx.waker().wake_by_ref();
                return Poll::Pending;
            }
        }
        Poll::Ready(outcome)
    }
}
fn poll_ready<D>(a: &Async<D>, interest: Interest, ctx: &mut Context<'_>) -> Poll<Result<()>> {
tracing::trace!("polling for interest in {:?}", interest);
let res = match interest {
Interest::Readable => a.poll_readable(ctx),
Interest::Writable => a.poll_writable(ctx),
};
tracing::trace!(is_ready = res.is_ready(), "polled for readiness");
res.map_err(Error::io)
}
pub fn connect(name: Option<&str>) -> impl Future<Output = Result<Async<DisplayConnection>>> {
let name = name.map(ToString::to_string);
async move {
let dpy = parse_display::parse_display(name.as_deref())
.ok_or_else(|| Error::couldnt_parse_display(name.is_none()))?;
let screen = dpy.screen;
let display_num = dpy.display;
let conn =
NameConnection::from_parsed_display_async(&dpy, name.is_none(), |name| async move {
let registered = Async::new(name).map_err(Error::io)?;
registered.writable().await.map_err(Error::io)?;
let name = registered.into_inner().map_err(Error::io)?;
if let Some(err) = name.take_error() {
Err(err)
} else {
Ok(name)
}
})
.await?;
let (family, address) = conn.get_address()?;
let (name, data) = blocking::unblock(move || {
match xauth::get_auth(family, &address, display_num).map_err(Error::io) {
Err(e) => Err(e),
Ok(Some(auth)) => Ok(auth),
Ok(None) => {
tracing::warn!("No Xauth found for display {}", display_num);
Ok((vec![], vec![]))
}
}
})
.await?;
establish_connect(conn.into(), screen as usize, name, data).await
}
}
pub fn establish_connect<Conn: Source + Connection>(
conn: Conn,
default_screen: usize,
auth_name: Vec<u8>,
auth_data: Vec<u8>,
) -> impl Future<Output = Result<Async<BasicDisplay<Conn>>>> {
let span = tracing::info_span!("establish_connect");
async move {
let (mut connect, setup_request) = Connect::with_authorization(auth_name, auth_data);
let mut registered = Async::new(conn).map_err(Error::io)?;
let mut written = 0;
while written < setup_request.len() {
write_with_mut(&mut registered, |conn| {
let n = conn.send_slice(&setup_request[written..])?;
written += n;
Ok(())
})
.await?;
}
write_with_mut(&mut registered, Connection::flush).await?;
loop {
let adv =
read_with_mut(&mut registered, |conn| conn.recv_slice(connect.buffer())).await?;
if connect.advance(adv) {
break;
}
}
let setup = connect.into_setup().map_err(Error::make_connect_error)?;
let dpy = BasicDisplay::with_connection(
registered.into_inner().map_err(Error::io)?,
setup,
default_screen,
)?;
Async::new(dpy).map_err(Error::io)
}
.instrument(span)
}
async fn write_with_mut<D, R: Default>(
a: &mut Async<D>,
mut f: impl FnMut(&mut D) -> Result<R>,
) -> Result<R> {
let mut res: Result<()> = Ok(());
let io_res = a
.write_with_mut(|conn| match f(conn) {
Ok(r) => Ok(r),
Err(e) => match e.into_io_error() {
Ok(e) => Err(e),
Err(e) => {
res = Err(e);
Ok(Default::default())
}
},
})
.await;
res.and(io_res.map_err(Error::io))
}
async fn read_with_mut<D, R: Default>(
a: &mut Async<D>,
mut f: impl FnMut(&mut D) -> Result<R>,
) -> Result<R> {
let mut res: Result<()> = Ok(());
let io_res = a
.read_with_mut(|conn| match f(conn) {
Ok(r) => Ok(r),
Err(e) => match e.into_io_error() {
Ok(e) => Err(e),
Err(e) => {
res = Err(e);
Ok(Default::default())
}
},
})
.await;
res.and(io_res.map_err(Error::io))
}
/// Forwards `DisplayBase` to the display owned by the `Async` wrapper.
///
/// Read-only accessors go through `Async::get_ref`; polling methods go through
/// `Async::get_mut`.
impl<D: DisplayBase> DisplayBase for Async<D> {
    fn setup(&self) -> &Arc<Setup> {
        <D as DisplayBase>::setup(self.get_ref())
    }

    fn default_screen_index(&self) -> usize {
        <D as DisplayBase>::default_screen_index(self.get_ref())
    }

    fn poll_for_event(&mut self) -> Result<Option<Event>> {
        <D as DisplayBase>::poll_for_event(self.get_mut())
    }

    fn poll_for_reply_raw(&mut self, seq: u64) -> Result<Option<RawReply>> {
        <D as DisplayBase>::poll_for_reply_raw(self.get_mut(), seq)
    }
}
/// Forwards `DisplayBase` for a shared reference to an `Async`-wrapped display.
///
/// The read-only accessors use `D`'s own impl. The `&mut self` polling methods
/// cannot borrow `D` mutably through a shared reference, so they use the
/// `&'lt D` impl instead.
impl<'lt, D: DisplayBase> DisplayBase for &'lt Async<D>
where
    &'lt D: DisplayBase,
{
    fn setup(&self) -> &Arc<Setup> {
        <D as DisplayBase>::setup(self.get_ref())
    }

    fn default_screen_index(&self) -> usize {
        <D as DisplayBase>::default_screen_index(self.get_ref())
    }

    fn poll_for_event(&mut self) -> Result<Option<Event>> {
        <&'lt D as DisplayBase>::poll_for_event(&mut self.get_ref())
    }

    fn poll_for_reply_raw(&mut self, seq: u64) -> Result<Option<RawReply>> {
        <&'lt D as DisplayBase>::poll_for_reply_raw(&mut self.get_ref(), seq)
    }
}
/// Forwards every `CanBeAsyncDisplay` operation to the display owned by the
/// `Async` wrapper, borrowed mutably via `Async::get_mut`.
impl<D: CanBeAsyncDisplay> CanBeAsyncDisplay for Async<D> {
    fn format_request(
        &mut self,
        req: &mut RawRequest<'_, '_>,
        ctx: &mut Context<'_>,
    ) -> Result<AsyncStatus<u64>> {
        <D as CanBeAsyncDisplay>::format_request(self.get_mut(), req, ctx)
    }

    fn try_send_request_raw(
        &mut self,
        req: &mut RawRequest<'_, '_>,
        ctx: &mut Context<'_>,
    ) -> Result<AsyncStatus<()>> {
        <D as CanBeAsyncDisplay>::try_send_request_raw(self.get_mut(), req, ctx)
    }

    fn try_wait_for_reply_raw(
        &mut self,
        seq: u64,
        ctx: &mut Context<'_>,
    ) -> Result<AsyncStatus<RawReply>> {
        <D as CanBeAsyncDisplay>::try_wait_for_reply_raw(self.get_mut(), seq, ctx)
    }

    fn try_wait_for_event(&mut self, ctx: &mut Context<'_>) -> Result<AsyncStatus<Event>> {
        <D as CanBeAsyncDisplay>::try_wait_for_event(self.get_mut(), ctx)
    }

    fn try_flush(&mut self, ctx: &mut Context<'_>) -> Result<AsyncStatus<()>> {
        <D as CanBeAsyncDisplay>::try_flush(self.get_mut(), ctx)
    }

    fn try_generate_xid(&mut self, ctx: &mut Context<'_>) -> Result<AsyncStatus<u32>> {
        <D as CanBeAsyncDisplay>::try_generate_xid(self.get_mut(), ctx)
    }

    fn try_maximum_request_length(&mut self, ctx: &mut Context<'_>) -> Result<AsyncStatus<usize>> {
        <D as CanBeAsyncDisplay>::try_maximum_request_length(self.get_mut(), ctx)
    }

    fn try_check_for_error(&mut self, seq: u64, ctx: &mut Context<'_>) -> Result<AsyncStatus<()>> {
        <D as CanBeAsyncDisplay>::try_check_for_error(self.get_mut(), seq, ctx)
    }
}
/// Forwards every `CanBeAsyncDisplay` operation for a shared reference to an
/// `Async`-wrapped display.
///
/// Each call goes to the `&'lt D` impl, applied to the reference returned by
/// `Async::get_ref`.
impl<'lt, D: DisplayBase> CanBeAsyncDisplay for &'lt Async<D>
where
    &'lt D: CanBeAsyncDisplay,
{
    fn format_request(
        &mut self,
        req: &mut RawRequest<'_, '_>,
        ctx: &mut Context<'_>,
    ) -> Result<AsyncStatus<u64>> {
        <&'lt D as CanBeAsyncDisplay>::format_request(&mut self.get_ref(), req, ctx)
    }

    fn try_send_request_raw(
        &mut self,
        req: &mut RawRequest<'_, '_>,
        ctx: &mut Context<'_>,
    ) -> Result<AsyncStatus<()>> {
        <&'lt D as CanBeAsyncDisplay>::try_send_request_raw(&mut self.get_ref(), req, ctx)
    }

    fn try_wait_for_event(&mut self, ctx: &mut Context<'_>) -> Result<AsyncStatus<Event>> {
        <&'lt D as CanBeAsyncDisplay>::try_wait_for_event(&mut self.get_ref(), ctx)
    }

    fn try_maximum_request_length(&mut self, ctx: &mut Context<'_>) -> Result<AsyncStatus<usize>> {
        <&'lt D as CanBeAsyncDisplay>::try_maximum_request_length(&mut self.get_ref(), ctx)
    }

    fn try_wait_for_reply_raw(
        &mut self,
        seq: u64,
        ctx: &mut Context<'_>,
    ) -> Result<AsyncStatus<RawReply>> {
        <&'lt D as CanBeAsyncDisplay>::try_wait_for_reply_raw(&mut self.get_ref(), seq, ctx)
    }

    fn try_generate_xid(&mut self, ctx: &mut Context<'_>) -> Result<AsyncStatus<u32>> {
        <&'lt D as CanBeAsyncDisplay>::try_generate_xid(&mut self.get_ref(), ctx)
    }

    fn try_flush(&mut self, ctx: &mut Context<'_>) -> Result<AsyncStatus<()>> {
        <&'lt D as CanBeAsyncDisplay>::try_flush(&mut self.get_ref(), ctx)
    }

    fn try_check_for_error(&mut self, seq: u64, ctx: &mut Context<'_>) -> Result<AsyncStatus<()>> {
        <&'lt D as CanBeAsyncDisplay>::try_check_for_error(&mut self.get_ref(), seq, ctx)
    }
}