use std::io::IoSlice;
use std::sync::{Condvar, Mutex, MutexGuard, TryLockError};
use std::time::Instant;
use crate::connection::{
compute_length_field, Connection, ReplyOrError, RequestConnection, RequestKind,
};
use crate::cookie::{Cookie, CookieWithFds, VoidCookie};
use crate::errors::DisplayParsingError;
pub use crate::errors::{ConnectError, ConnectionError, ParseError, ReplyError, ReplyOrIdError};
use crate::extension_manager::ExtensionManager;
use crate::protocol::bigreq::{ConnectionExt as _, EnableReply};
use crate::protocol::xproto::{Setup, GET_INPUT_FOCUS_REQUEST, QUERY_EXTENSION_REQUEST};
use crate::utils::RawFdContainer;
use crate::x11_utils::{ExtensionInformation, TryParse, TryParseFd};
use x11rb_protocol::connect::Connect;
use x11rb_protocol::connection::{Connection as ProtoConnection, PollReply, ReplyFdKind};
use x11rb_protocol::id_allocator::IdAllocator;
use x11rb_protocol::{xauth::get_auth, DiscardMode, RawEventAndSeqNumber, SequenceNumber};
mod packet_reader;
mod stream;
mod write_buffer;
use packet_reader::PacketReader;
pub use stream::{DefaultStream, PollMode, Stream};
use write_buffer::WriteBuffer;
/// Buffer type used for raw replies, errors and events: the `Buf` associated
/// type of `RustConnection`'s `RequestConnection` impl (i.e. `Vec<u8>`).
type Buffer = <RustConnection as RequestConnection>::Buf;
/// A raw reply buffer paired with the file descriptors received along with it.
pub type BufWithFds = crate::connection::BufWithFds<Buffer>;
/// Lazily determined maximum request size of the connection.
#[derive(Debug)]
enum MaxRequestBytes {
/// Not queried yet; no BIG-REQUESTS Enable request has been sent.
Unknown,
/// A BIG-REQUESTS Enable request was sent. Holds its sequence number, or
/// `None` if sending the request failed.
Requested(Option<SequenceNumber>),
/// The maximum request size, already converted to bytes.
Known(usize),
}
/// The part of the connection state that is protected by `RustConnection::inner`.
#[derive(Debug)]
struct ConnectionInner {
/// Sans-I/O protocol state: hands out sequence numbers and queues received
/// packets and FDs until they are polled for.
inner: ProtoConnection,
/// Buffers outgoing request data and FDs; drained by `flush_impl`.
write_buffer: WriteBuffer,
}
/// A held lock on the `ConnectionInner` state.
type MutexGuardInner<'a> = MutexGuard<'a, ConnectionInner>;
/// Controls how `read_packet_and_enqueue` behaves when it cannot read right away.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) enum BlockingMode {
/// Wait until the stream is readable, or until another reading thread finishes.
Blocking,
/// Return immediately if another thread is already reading.
NonBlocking,
}
/// An X11 connection implemented in pure Rust on top of a `Stream`
/// (by default `DefaultStream`).
///
/// Different pieces of state live behind separate mutexes. For example, a
/// thread blocked waiting for the stream to become readable does not hold
/// `inner`, so other threads can keep sending requests meanwhile.
#[derive(Debug)]
pub struct RustConnection<S: Stream = DefaultStream> {
/// Protocol state and outgoing write buffer.
inner: Mutex<ConnectionInner>,
/// The underlying transport.
stream: S,
/// Assembles incoming packets. Whoever holds this lock is "the reader";
/// only one thread reads from `stream` at a time.
packet_reader: Mutex<PacketReader>,
/// Notified when a reader finishes, waking threads that wait for it.
reader_condition: Condvar,
/// Setup information of this connection.
setup: Setup,
/// Extension information (queried via `QueryExtension`); also used to name
/// requests, errors and events.
extension_manager: Mutex<ExtensionManager>,
/// Lazily determined maximum request size, see `maximum_request_bytes`.
maximum_request_bytes: Mutex<MaxRequestBytes>,
/// Hands out resource IDs for `generate_id`.
id_allocator: Mutex<IdAllocator>,
}
impl RustConnection<DefaultStream> {
    /// Establish a new connection to an X11 server.
    ///
    /// `dpy_name` is parsed as an X11 display name. Every address derived from
    /// it is tried in order, and the first one that accepts a connection is
    /// used. Credentials come from `get_auth`; if none are found, or the
    /// lookup fails, an empty mechanism name and empty data are sent.
    ///
    /// Returns the connection together with the screen number from the
    /// display name.
    ///
    /// # Errors
    ///
    /// - A display-parsing error.
    /// - The setup error of the first reachable address.
    /// - If no address could be reached, the I/O error of the last attempt
    ///   (or `DisplayParsingError::Unknown` when there was nothing to try).
    pub fn connect(dpy_name: Option<&str>) -> Result<(Self, usize), ConnectError> {
        let parsed_display = x11rb_protocol::parse_display::parse_display(dpy_name)?;
        let screen = parsed_display.screen.into();

        let mut last_error = None;
        for addr in parsed_display.connect_instruction() {
            let started = Instant::now();
            let (stream, (family, address)) = match DefaultStream::connect(&addr) {
                Ok(connected) => connected,
                Err(err) => {
                    crate::debug!("Failed to connect to X11 server via {:?}: {:?}", addr, err);
                    last_error = Some(err);
                    continue;
                }
            };
            crate::trace!(
                "Connected to X11 server via {:?} in {:?}",
                addr,
                started.elapsed()
            );

            // A failed credential lookup is not fatal; fall back to "no auth".
            let (auth_name, auth_data) = get_auth(family, &address, parsed_display.display)
                .unwrap_or(None)
                .unwrap_or_else(|| (Vec::new(), Vec::new()));
            crate::trace!("Picked authentication via auth mechanism {:?}", auth_name);

            let conn =
                Self::connect_to_stream_with_auth_info(stream, screen, auth_name, auth_data)?;
            return Ok((conn, screen));
        }

        Err(last_error.map_or_else(|| DisplayParsingError::Unknown.into(), ConnectError::IoError))
    }
}
impl<S: Stream> RustConnection<S> {
/// Establish a new connection on an already connected stream without sending
/// any authentication data.
///
/// # Errors
///
/// Same as `connect_to_stream_with_auth_info`.
pub fn connect_to_stream(stream: S, screen: usize) -> Result<Self, ConnectError> {
    let (no_auth_name, no_auth_data) = (Vec::new(), Vec::new());
    Self::connect_to_stream_with_auth_info(stream, screen, no_auth_name, no_auth_data)
}
/// Establish a new connection on an already connected stream, authenticating
/// with the given mechanism name and data.
///
/// Writes the connection setup request completely and reads the server's
/// answer until `Connect` reports it complete. It then checks that `screen`
/// indexes one of the setup's roots and builds the connection via
/// `for_connected_stream`.
///
/// # Errors
///
/// - I/O errors, including `WriteZero` / `UnexpectedEof` when the stream
///   stops making progress during the handshake.
/// - Setup failures from `Connect::into_setup`.
/// - `ConnectError::InvalidScreen`.
pub fn connect_to_stream_with_auth_info(
    stream: S,
    screen: usize,
    auth_name: Vec<u8>,
    auth_data: Vec<u8>,
) -> Result<Self, ConnectError> {
    let (mut connect, setup_request) = Connect::with_authorization(auth_name, auth_data);
    // The Stream API takes an FD vector. Nothing is sent from it, and any FDs
    // received during the handshake end up in here and are dropped.
    let mut fd_scratch = Vec::new();

    let total = setup_request.len();
    crate::trace!("Writing connection setup with {} bytes", total);
    let mut written = 0;
    while written != total {
        stream.poll(PollMode::Writable)?;
        match stream.write(&setup_request[written..], &mut fd_scratch) {
            Ok(0) => {
                let err = std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                );
                return Err(err.into());
            }
            Ok(n) => written += n,
            // Not writable after all; poll again.
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
            Err(e) => return Err(e.into()),
        }
    }

    let mut complete = false;
    while !complete {
        stream.poll(PollMode::Readable)?;
        crate::trace!(
            "Reading connection setup with at least {} bytes remaining",
            connect.buffer().len()
        );
        let received = match stream.read(connect.buffer(), &mut fd_scratch) {
            Ok(0) => {
                let err = std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "failed to read whole buffer",
                );
                return Err(err.into());
            }
            Ok(n) => n,
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
            Err(e) => return Err(e.into()),
        };
        crate::trace!("Read {} bytes", received);
        complete = connect.advance(received);
    }

    let setup = connect.into_setup()?;
    if screen < setup.roots.len() {
        Self::for_connected_stream(stream, setup)
    } else {
        Err(ConnectError::InvalidScreen)
    }
}
pub fn for_connected_stream(stream: S, setup: Setup) -> Result<Self, ConnectError> {
let id_allocator = IdAllocator::new(setup.resource_id_base, setup.resource_id_mask)?;
Ok(RustConnection {
inner: Mutex::new(ConnectionInner {
inner: ProtoConnection::new(),
write_buffer: WriteBuffer::new(),
}),
stream,
packet_reader: Mutex::new(PacketReader::new()),
reader_condition: Condvar::new(),
setup,
extension_manager: Default::default(),
maximum_request_bytes: Mutex::new(MaxRequestBytes::Unknown),
id_allocator: Mutex::new(id_allocator),
})
}
/// Send one request (given as I/O slices, opcode bytes first) with its FDs,
/// and return the sequence number assigned to it.
///
/// If the protocol layer refuses to hand out a sequence number (`None`), a
/// sync request is sent first and the attempt is repeated.
///
/// # Panics
///
/// Panics if `bufs` is empty or its first slice is shorter than two bytes,
/// or if the `inner` mutex is poisoned.
fn send_request(
    &self,
    bufs: &[IoSlice<'_>],
    fds: Vec<RawFdContainer>,
    kind: ReplyFdKind,
) -> Result<SequenceNumber, ConnectionError> {
    let _guard = crate::debug_span!("send_request").entered();

    let header = &bufs[0];
    let request_info = RequestInfo {
        extension_manager: &self.extension_manager,
        major_opcode: header[0],
        minor_opcode: header[1],
    };
    crate::debug!("Sending {}", request_info);

    let mut storage = Default::default();
    let bufs = compute_length_field(self, bufs, &mut storage)?;

    let mut inner = self.inner.lock().unwrap();
    let seqno = loop {
        if let Some(seqno) = inner.inner.send_request(kind) {
            break seqno;
        }
        crate::trace!("Syncing with the X11 server since there are too many outstanding void requests");
        inner = self.send_sync(inner)?;
    };
    let _inner = self.write_all_vectored(inner, bufs, fds)?;
    Ok(seqno)
}
/// Send a GetInputFocus request whose reply and error are discarded.
///
/// This gives the protocol layer a request that gets a response. It is used
/// when `send_request` cannot obtain a sequence number and by
/// `check_for_raw_error`.
///
/// # Panics
///
/// If the protocol layer refuses a sequence number for this request.
fn send_sync<'a>(
    &'a self,
    mut inner: MutexGuardInner<'a>,
) -> Result<MutexGuardInner<'a>, std::io::Error> {
    // Opcode, one padding byte, then a length of 1 in native byte order.
    let [len0, len1] = 1u16.to_ne_bytes();
    let request = [GET_INPUT_FOCUS_REQUEST, 0, len0, len1];

    let seqno = inner
        .inner
        .send_request(ReplyFdKind::ReplyWithoutFDs)
        .expect("Sending a HasResponse request should not be blocked by syncs");
    inner.inner.discard_reply(seqno, DiscardMode::DiscardReplyAndError);

    self.write_all_vectored(inner, &[IoSlice::new(&request)], Vec::new())
}
/// Hand all of `bufs` (and all `fds`) to the write buffer, polling the stream
/// until everything has been accepted.
///
/// Data goes through `inner.write_buffer` and may stay buffered there until a
/// later `flush_impl`. Whenever writing would block, a non-blocking read of
/// incoming packets is performed instead.
/// NOTE(review): presumably this keeps the server from stalling on output we
/// have not read yet — confirm.
///
/// # Errors
///
/// - Errors from polling, writing or reading.
/// - `WriteZero` if a write reports zero bytes.
/// - An `Other` error if some FDs were not consumed by the writes.
fn write_all_vectored<'a>(
&'a self,
mut inner: MutexGuardInner<'a>,
mut bufs: &[IoSlice<'_>],
mut fds: Vec<RawFdContainer>,
) -> std::io::Result<MutexGuardInner<'a>> {
// Unwritten tail of a slice from `bufs` that was only partially written.
// It is always written on its own before continuing with `bufs`.
let mut partial_buf: &[u8] = &[];
while !partial_buf.is_empty() || !bufs.is_empty() {
self.stream.poll(PollMode::ReadAndWritable)?;
// `fds` is handed to every write call; leftovers are checked after the loop.
let write_result = if !partial_buf.is_empty() {
inner
.write_buffer
.write(&self.stream, partial_buf, &mut fds)
} else {
inner
.write_buffer
.write_vectored(&self.stream, bufs, &mut fds)
};
match write_result {
Ok(0) => {
return Err(std::io::Error::new(
std::io::ErrorKind::WriteZero,
"failed to write anything",
));
}
Ok(mut count) => {
// Account the written bytes against `partial_buf` first...
if count >= partial_buf.len() {
count -= partial_buf.len();
partial_buf = &[];
} else {
partial_buf = &partial_buf[count..];
count = 0;
}
// ...then against whole slices of `bufs`. A slice that was only
// partially written becomes the new `partial_buf`. Indexing
// `bufs[0]` assumes the writer never reports more bytes than it
// was offered.
while count > 0 {
if count >= bufs[0].len() {
count -= bufs[0].len();
} else {
partial_buf = &bufs[0][count..];
count = 0;
}
bufs = &bufs[1..];
// Skip empty slices so `bufs` is empty once all data is written.
while bufs.first().map(|s| s.len()) == Some(0) {
bufs = &bufs[1..];
}
}
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
crate::trace!("Writing more data would block for now");
// Use the wait to drain incoming data without blocking.
inner = self.read_packet_and_enqueue(inner, BlockingMode::NonBlocking)?;
}
Err(e) => return Err(e),
}
}
if !fds.is_empty() {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Left over FDs after sending the request",
));
}
Ok(inner)
}
/// Write out everything buffered in `inner.write_buffer`.
///
/// While flushing would block, incoming packets are read without blocking.
/// The (possibly re-acquired) `inner` lock is handed back to the caller.
///
/// # Errors
///
/// Errors from polling, flushing or reading the stream.
fn flush_impl<'a>(
    &'a self,
    mut inner: MutexGuardInner<'a>,
) -> std::io::Result<MutexGuardInner<'a>> {
    while inner.write_buffer.needs_flush() {
        self.stream.poll(PollMode::ReadAndWritable)?;
        let outcome = inner.write_buffer.flush(&self.stream);
        match outcome {
            Ok(()) => return Ok(inner),
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                crate::trace!("Flushing more data would block for now");
                inner = self.read_packet_and_enqueue(inner, BlockingMode::NonBlocking)?;
            }
            Err(e) => return Err(e),
        }
    }
    Ok(inner)
}
/// Read incoming packets from the stream and enqueue them in `inner`.
///
/// Only one thread reads at a time: holding the `packet_reader` lock is the
/// "reader" role. The caller passes in the locked `inner` and gets it back
/// locked, although it may have been released temporarily.
///
/// If another thread is currently the reader:
/// - `NonBlocking`: return immediately without doing anything.
/// - `Blocking`: wait on `reader_condition`, which releases `inner` while
///   waiting, until the reader signals, then return. `Condvar::wait` may
///   also wake spuriously; the callers in this file loop and re-check.
///
/// Otherwise this thread becomes the reader. In `Blocking` mode it releases
/// `inner` while waiting for the stream to become readable. It then reads the
/// available packets and FDs and enqueues them in `inner`.
///
/// # Errors
///
/// I/O errors from polling or reading the stream.
///
/// # Panics
///
/// If the `packet_reader` or `inner` mutex is poisoned.
fn read_packet_and_enqueue<'a>(
&'a self,
mut inner: MutexGuardInner<'a>,
mode: BlockingMode,
) -> Result<MutexGuardInner<'a>, std::io::Error> {
match self.packet_reader.try_lock() {
Err(TryLockError::WouldBlock) => {
match mode {
BlockingMode::NonBlocking => {
crate::trace!("read_packet_and_enqueue in NonBlocking mode doing nothing since reader is already locked");
return Ok(inner);
}
BlockingMode::Blocking => {
crate::trace!("read_packet_and_enqueue in Blocking mode waiting for pre-existing reader");
}
}
// `wait` unlocks `inner` while sleeping and relocks it before returning.
// The other reader notifies after enqueuing its packets, so the caller
// simply re-checks its condition afterwards.
Ok(self.reader_condition.wait(inner).unwrap())
}
Err(TryLockError::Poisoned(e)) => panic!("{}", e),
Ok(mut packet_reader) => {
// Notify waiters on every exit path of this branch, including the
// early `?` returns below and unwinding.
let notify_on_drop = NotifyOnDrop(&self.reader_condition);
if mode == BlockingMode::Blocking {
// Do not hold `inner` during a potentially long wait, so other
// threads can keep using the connection meanwhile.
drop(inner);
self.stream.poll(PollMode::Readable)?;
inner = self.inner.lock().unwrap();
}
let mut fds = Vec::new();
let mut packets = Vec::new();
// NOTE(review): if try_read_packets can collect some packets and then
// fail, those packets are discarded by this `?` — confirm.
packet_reader.try_read_packets(&self.stream, &mut packets, &mut fds)?;
// Give up the reader role while `inner` is still held. Every caller
// enters this function with `inner` locked, so no other thread can
// see the reader as free before the packets below are enqueued.
drop(packet_reader);
inner.inner.enqueue_fds(fds);
packets
.into_iter()
.for_each(|packet| inner.inner.enqueue_packet(packet));
// Wake threads waiting in the `Blocking` branch above.
drop(notify_on_drop);
Ok(inner)
}
}
}
/// Send a BIG-REQUESTS Enable request if the maximum request size has not been
/// queried yet.
///
/// Records the request's sequence number, or `None` if sending failed. Does
/// nothing once a request was sent or the size is already known.
fn prefetch_maximum_request_bytes_impl(&self, max_bytes: &mut MutexGuard<'_, MaxRequestBytes>) {
    match **max_bytes {
        MaxRequestBytes::Unknown => {}
        MaxRequestBytes::Requested(_) | MaxRequestBytes::Known(_) => return,
    }
    crate::info!("Prefetching maximum request length");
    let sent = match self.bigreq_enable() {
        Ok(cookie) => Some(cookie.into_sequence_number()),
        Err(_) => None,
    };
    **max_bytes = MaxRequestBytes::Requested(sent);
}
pub fn stream(&self) -> &S {
&self.stream
}
}
impl<S: Stream> RequestConnection for RustConnection<S> {
/// Raw replies, errors and events are handed out as owned byte vectors.
type Buf = Vec<u8>;
/// Send a request that has a reply without FDs; returns a cookie for the reply.
fn send_request_with_reply<Reply>(
    &self,
    bufs: &[IoSlice<'_>],
    fds: Vec<RawFdContainer>,
) -> Result<Cookie<'_, Self, Reply>, ConnectionError>
where
    Reply: TryParse,
{
    let sequence = self.send_request(bufs, fds, ReplyFdKind::ReplyWithoutFDs)?;
    Ok(Cookie::new(self, sequence))
}
/// Send a request whose reply carries FDs; returns a cookie for the reply.
fn send_request_with_reply_with_fds<Reply>(
    &self,
    bufs: &[IoSlice<'_>],
    fds: Vec<RawFdContainer>,
) -> Result<CookieWithFds<'_, Self, Reply>, ConnectionError>
where
    Reply: TryParseFd,
{
    let sequence = self.send_request(bufs, fds, ReplyFdKind::ReplyWithFDs)?;
    Ok(CookieWithFds::new(self, sequence))
}
/// Send a request that has no reply; returns a cookie usable to check for errors.
fn send_request_without_reply(
    &self,
    bufs: &[IoSlice<'_>],
    fds: Vec<RawFdContainer>,
) -> Result<VoidCookie<'_, Self>, ConnectionError> {
    let sequence = self.send_request(bufs, fds, ReplyFdKind::NoReply)?;
    Ok(VoidCookie::new(self, sequence))
}
/// Tell the protocol layer to discard the reply (and, per `mode`, the error)
/// of request `sequence`. The request kind is not needed here.
fn discard_reply(&self, sequence: SequenceNumber, _kind: RequestKind, mode: DiscardMode) {
    crate::debug!(
        "Discarding reply to request {} in mode {:?}",
        sequence,
        mode
    );
    let mut state = self.inner.lock().unwrap();
    state.inner.discard_reply(sequence, mode);
}
/// Start fetching information about `extension_name` without waiting for it.
///
/// The extension manager's lock is held for the whole call.
fn prefetch_extension_information(
    &self,
    extension_name: &'static str,
) -> Result<(), ConnectionError> {
    let mut manager = self.extension_manager.lock().unwrap();
    manager.prefetch_extension_information(self, extension_name)
}
/// Get information about `extension_name`, or `None` if it is not present.
///
/// The extension manager's lock is held for the whole call.
fn extension_information(
    &self,
    extension_name: &'static str,
) -> Result<Option<ExtensionInformation>, ConnectionError> {
    let mut manager = self.extension_manager.lock().unwrap();
    manager.extension_information(self, extension_name)
}
fn wait_for_reply_or_raw_error(
&self,
sequence: SequenceNumber,
) -> Result<ReplyOrError<Vec<u8>>, ConnectionError> {
match self.wait_for_reply_with_fds_raw(sequence)? {
ReplyOrError::Reply((reply, _fds)) => Ok(ReplyOrError::Reply(reply)),
ReplyOrError::Error(e) => Ok(ReplyOrError::Error(e)),
}
}
/// Flush pending output, then block until request `sequence` has a reply.
///
/// Returns `Ok(None)` when the protocol layer reports that no reply will come.
fn wait_for_reply(&self, sequence: SequenceNumber) -> Result<Option<Vec<u8>>, ConnectionError> {
    let _guard = crate::debug_span!("wait_for_reply", sequence).entered();
    let mut inner = self.flush_impl(self.inner.lock().unwrap())?;
    loop {
        crate::trace!({ sequence }, "Polling for reply");
        let polled = inner.inner.poll_for_reply(sequence);
        match polled {
            PollReply::NoReply => return Ok(None),
            PollReply::Reply(buffer) => return Ok(Some(buffer)),
            // Nothing available yet: block until more packets are enqueued.
            PollReply::TryAgain => {
                inner = self.read_packet_and_enqueue(inner, BlockingMode::Blocking)?;
            }
        }
    }
}
/// Block until it is known whether request `sequence` produced an error.
///
/// If the protocol layer asks for it, a sync request is inserted first so
/// that an answer is guaranteed to arrive. Returns `Ok(None)` when the
/// protocol layer reports `NoReply`.
///
/// # Panics
///
/// If the protocol layer still requests a sync after one was sent.
fn check_for_raw_error(
    &self,
    sequence: SequenceNumber,
) -> Result<Option<Buffer>, ConnectionError> {
    let _guard = crate::debug_span!("check_for_raw_error", sequence).entered();
    let mut inner = self.inner.lock().unwrap();
    let needs_sync = inner.inner.prepare_check_for_reply_or_error(sequence);
    if needs_sync {
        crate::trace!("Inserting sync with the X11 server");
        inner = self.send_sync(inner)?;
        assert!(!inner.inner.prepare_check_for_reply_or_error(sequence));
    }
    inner = self.flush_impl(inner)?;
    loop {
        crate::trace!({ sequence }, "Polling for reply or error");
        let polled = inner.inner.poll_check_for_reply_or_error(sequence);
        match polled {
            PollReply::NoReply => return Ok(None),
            PollReply::Reply(buffer) => return Ok(Some(buffer)),
            PollReply::TryAgain => {
                inner = self.read_packet_and_enqueue(inner, BlockingMode::Blocking)?;
            }
        }
    }
}
fn wait_for_reply_with_fds_raw(
&self,
sequence: SequenceNumber,
) -> Result<ReplyOrError<BufWithFds, Buffer>, ConnectionError> {
let _guard = crate::debug_span!("wait_for_reply_with_fds_raw", sequence).entered();
let mut inner = self.inner.lock().unwrap();
inner = self.flush_impl(inner)?;
loop {
crate::trace!({ sequence }, "Polling for reply or error");
if let Some(reply) = inner.inner.poll_for_reply_or_error(sequence) {
if reply.0[0] == 0 {
crate::trace!("Got error");
return Ok(ReplyOrError::Error(reply.0));
} else {
crate::trace!("Got reply");
return Ok(ReplyOrError::Reply(reply));
}
}
inner = self.read_packet_and_enqueue(inner, BlockingMode::Blocking)?;
}
}
/// Maximum size in bytes of a single request on this connection.
///
/// On first use this sends a BIG-REQUESTS Enable request (unless
/// `prefetch_maximum_request_bytes` already did). It then waits for the reply
/// while holding the `maximum_request_bytes` lock and caches the result. If
/// the request could not be sent or its reply failed, the setup's
/// `maximum_request_length` is used instead.
///
/// Both sources count 4-byte units. The conversion to bytes saturates at
/// `usize::MAX` rather than overflowing. An advertised length above
/// `usize::MAX / 4` (possible with a 32-bit `usize`) would otherwise panic in
/// debug builds or wrap in release builds. The `usize::MAX` fallback for a
/// failed `u32 -> usize` conversion would overflow as well.
fn maximum_request_bytes(&self) -> usize {
    let mut max_bytes = self.maximum_request_bytes.lock().unwrap();
    self.prefetch_maximum_request_bytes_impl(&mut max_bytes);
    use MaxRequestBytes::*;
    let max_bytes = &mut *max_bytes;
    match max_bytes {
        Unknown => unreachable!("We just prefetched this"),
        Requested(seqno) => {
            let _guard = crate::info_span!("maximum_request_bytes").entered();
            let units: usize = seqno
                .and_then(|seqno| {
                    Cookie::<_, EnableReply>::new(self, seqno)
                        .reply()
                        .map(|reply| reply.maximum_request_length)
                        .ok()
                })
                .unwrap_or_else(|| self.setup.maximum_request_length.into())
                // u32 -> usize; use the largest value if it does not fit.
                .try_into()
                .unwrap_or(usize::MAX);
            // Convert 4-byte units to bytes without overflowing.
            let length = units.saturating_mul(4);
            *max_bytes = Known(length);
            crate::info!("Maximum request length is {} bytes", length);
            length
        }
        Known(length) => *length,
    }
}
/// Send the BIG-REQUESTS Enable request now, if it has not been sent yet,
/// without waiting for its reply.
fn prefetch_maximum_request_bytes(&self) {
    self.prefetch_maximum_request_bytes_impl(&mut self.maximum_request_bytes.lock().unwrap());
}
fn parse_error(&self, error: &[u8]) -> Result<crate::x11_utils::X11Error, ParseError> {
let ext_mgr = self.extension_manager.lock().unwrap();
crate::x11_utils::X11Error::try_parse(error, &*ext_mgr)
}
/// Parse a raw event packet, using the extension manager to resolve
/// extension event codes.
fn parse_event(&self, event: &[u8]) -> Result<crate::protocol::Event, ParseError> {
    crate::protocol::Event::parse(event, &*self.extension_manager.lock().unwrap())
}
}
impl<S: Stream> Connection for RustConnection<S> {
/// Block until an event is available and return it with its sequence number.
fn wait_for_raw_event_with_sequence(
    &self,
) -> Result<RawEventAndSeqNumber<Vec<u8>>, ConnectionError> {
    let _guard = crate::trace_span!("wait_for_raw_event_with_sequence").entered();
    let mut inner = self.inner.lock().unwrap();
    loop {
        let next = inner.inner.poll_for_event_with_sequence();
        match next {
            Some(event) => return Ok(event),
            None => {
                inner = self.read_packet_and_enqueue(inner, BlockingMode::Blocking)?;
            }
        }
    }
}
/// Return a queued event if there is one. Otherwise perform one non-blocking
/// read and check again.
fn poll_for_raw_event_with_sequence(
    &self,
) -> Result<Option<RawEventAndSeqNumber<Vec<u8>>>, ConnectionError> {
    let _guard = crate::trace_span!("poll_for_raw_event_with_sequence").entered();
    let mut inner = self.inner.lock().unwrap();
    if let Some(event) = inner.inner.poll_for_event_with_sequence() {
        return Ok(Some(event));
    }
    inner = self.read_packet_and_enqueue(inner, BlockingMode::NonBlocking)?;
    Ok(inner.inner.poll_for_event_with_sequence())
}
/// Write out all buffered requests.
fn flush(&self) -> Result<(), ConnectionError> {
    let guard = self.inner.lock().unwrap();
    self.flush_impl(guard)?;
    Ok(())
}
fn setup(&self) -> &Setup {
&self.setup
}
/// Allocate a new resource ID (XID).
///
/// When the current ID range is exhausted, a new range is requested via the
/// XC-MISC extension if the server supports it. The allocator lock is held
/// throughout.
///
/// # Errors
///
/// `ReplyOrIdError::IdsExhausted` if no IDs are left and XC-MISC is unavailable
/// or yields no usable ID, plus any connection or reply error while asking.
fn generate_id(&self) -> Result<u32, ReplyOrIdError> {
    use crate::protocol::xc_misc::{self, ConnectionExt as _};

    let mut id_allocator = self.id_allocator.lock().unwrap();
    if let Some(id) = id_allocator.generate_id() {
        return Ok(id);
    }

    let xc_misc_available = self
        .extension_information(xc_misc::X11_EXTENSION_NAME)?
        .is_some();
    if !xc_misc_available {
        crate::error!("XIDs are exhausted and XC-MISC extension is not available");
        return Err(ReplyOrIdError::IdsExhausted);
    }

    crate::info!("XIDs are exhausted; fetching free range via XC-MISC");
    let range = self.xc_misc_get_xid_range()?.reply()?;
    id_allocator.update_xid_range(&range)?;
    id_allocator
        .generate_id()
        .ok_or(ReplyOrIdError::IdsExhausted)
}
}
/// Calls `notify_all` on the wrapped condition variable when dropped.
///
/// Waiters are thereby woken on every exit path of the scope that owns the
/// guard, including early returns and unwinding.
#[derive(Debug)]
struct NotifyOnDrop<'a>(&'a Condvar);

impl Drop for NotifyOnDrop<'_> {
    fn drop(&mut self) {
        let NotifyOnDrop(condvar) = self;
        condvar.notify_all();
    }
}
/// Human-readable description of a request, used by `send_request` for debug
/// logging. The request name is only looked up when the value is formatted.
struct RequestInfo<'a> {
    /// Used to resolve (extension) opcodes to request names.
    extension_manager: &'a Mutex<ExtensionManager>,
    /// First byte of the request.
    major_opcode: u8,
    /// Second byte of the request.
    minor_opcode: u8,
}

impl std::fmt::Display for RequestInfo<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // QueryExtension is named without touching the extension manager.
        // The manager's mutex is held while it works with this connection
        // (see `prefetch_extension_information`), and presumably it sends
        // QueryExtension from there, so locking it here could self-deadlock.
        if self.major_opcode == QUERY_EXTENSION_REQUEST {
            return write!(f, "QueryExtension request");
        }
        let manager = self.extension_manager.lock().unwrap();
        let name = x11rb_protocol::protocol::get_request_name(
            &*manager,
            self.major_opcode,
            self.minor_opcode,
        );
        write!(f, "{} request", name)
    }
}