use std::io::IoSlice;
use x11rb_protocol::x11_utils::{ReplyFDsRequest, ReplyRequest, VoidRequest};
use crate::cookie::{Cookie, CookieWithFds, VoidCookie};
use crate::errors::{ConnectionError, ParseError, ReplyError, ReplyOrIdError};
use crate::protocol::xproto::Setup;
use crate::protocol::Event;
use crate::utils::RawFdContainer;
use crate::x11_utils::{ExtensionInformation, TryParse, TryParseFd, X11Error};
pub use x11rb_protocol::{DiscardMode, RawEventAndSeqNumber, SequenceNumber};
mod impls;
/// A buffer of type `B` paired with the `RawFdContainer`s that accompany it.
pub type BufWithFds<B> = (B, Vec<RawFdContainer>);
/// A parsed [`Event`] paired with a [`SequenceNumber`].
pub type EventAndSeqNumber = (Event, SequenceNumber);
/// Outcome of waiting for the response to a request: either the reply or an error.
///
/// `R` is the type carrying a reply and `E` the type carrying an error. `E` must be viewable
/// as raw bytes and defaults to `R`, so `ReplyOrError<B>` uses the same buffer type for both.
#[derive(Debug)]
pub enum ReplyOrError<
    R: std::fmt::Debug,
    E: AsRef<[u8]> + std::fmt::Debug = R,
> {
    /// The request was answered with a reply.
    Reply(R),
    /// The request was answered with an error.
    Error(E),
}
/// A connection that can send X11 requests and hand back their replies and errors.
///
/// Implementors provide the raw sending, waiting and parsing primitives. The provided
/// methods (`send_trait_request_*`, `wait_for_reply_or_error`, `wait_for_reply_with_fds`
/// and `check_for_error`) are convenience wrappers built on top of those primitives.
pub trait RequestConnection {
    /// Byte buffer type in which this connection hands out raw replies and errors.
    type Buf: AsRef<[u8]> + std::fmt::Debug + Send + Sync + 'static;

    /// Send an already serialized request and return a [`Cookie`] for its reply.
    ///
    /// `bufs` holds the request bytes and `fds` the `RawFdContainer`s to send along with it.
    /// The reply type `R` must implement [`TryParse`].
    fn send_request_with_reply<R>(
        &self,
        bufs: &[IoSlice<'_>],
        fds: Vec<RawFdContainer>,
    ) -> Result<Cookie<'_, Self, R>, ConnectionError>
    where
        R: TryParse;

    /// Serialize `request` and send it via [`Self::send_request_with_reply`].
    ///
    /// Core requests (`R::EXTENSION_NAME` is `None`) are serialized with opcode `0`. For
    /// extension requests the major opcode is looked up with
    /// [`Self::extension_information`].
    ///
    /// # Errors
    ///
    /// Returns `ConnectionError::UnsupportedExtension` when the extension lookup yields
    /// `None`, and forwards any error from the lookup or from sending.
    fn send_trait_request_with_reply<R>(
        &self,
        request: R,
    ) -> Result<Cookie<'_, Self, <R as ReplyRequest>::Reply>, ConnectionError>
    where
        R: ReplyRequest,
    {
        let opcode = if let Some(extension) = R::EXTENSION_NAME {
            let info = self.extension_information(extension)?;
            info.ok_or(ConnectionError::UnsupportedExtension)?
                .major_opcode
        } else {
            0
        };
        let (bytes, fds) = request.serialize(opcode);
        let slices = [IoSlice::new(&bytes)];
        self.send_request_with_reply(&slices, fds)
    }

    /// Send an already serialized request and return a [`CookieWithFds`] for its reply.
    ///
    /// Same inputs as [`Self::send_request_with_reply`], but the reply type `R` must
    /// implement [`TryParseFd`].
    fn send_request_with_reply_with_fds<R>(
        &self,
        bufs: &[IoSlice<'_>],
        fds: Vec<RawFdContainer>,
    ) -> Result<CookieWithFds<'_, Self, R>, ConnectionError>
    where
        R: TryParseFd;

    /// Serialize `request` and send it via [`Self::send_request_with_reply_with_fds`].
    ///
    /// The opcode is chosen exactly as in [`Self::send_trait_request_with_reply`].
    ///
    /// # Errors
    ///
    /// Returns `ConnectionError::UnsupportedExtension` when the extension lookup yields
    /// `None`, and forwards any error from the lookup or from sending.
    fn send_trait_request_with_reply_with_fds<R>(
        &self,
        request: R,
    ) -> Result<CookieWithFds<'_, Self, R::Reply>, ConnectionError>
    where
        R: ReplyFDsRequest,
    {
        let opcode = if let Some(extension) = R::EXTENSION_NAME {
            let info = self.extension_information(extension)?;
            info.ok_or(ConnectionError::UnsupportedExtension)?
                .major_opcode
        } else {
            0
        };
        let (bytes, fds) = request.serialize(opcode);
        let slices = [IoSlice::new(&bytes)];
        self.send_request_with_reply_with_fds(&slices, fds)
    }

    /// Send an already serialized request and return a [`VoidCookie`] for it.
    ///
    /// `bufs` holds the request bytes and `fds` the `RawFdContainer`s to send along with it.
    fn send_request_without_reply(
        &self,
        bufs: &[IoSlice<'_>],
        fds: Vec<RawFdContainer>,
    ) -> Result<VoidCookie<'_, Self>, ConnectionError>;

    /// Serialize `request` and send it via [`Self::send_request_without_reply`].
    ///
    /// The opcode is chosen exactly as in [`Self::send_trait_request_with_reply`].
    ///
    /// # Errors
    ///
    /// Returns `ConnectionError::UnsupportedExtension` when the extension lookup yields
    /// `None`, and forwards any error from the lookup or from sending.
    fn send_trait_request_without_reply<R>(
        &self,
        request: R,
    ) -> Result<VoidCookie<'_, Self>, ConnectionError>
    where
        R: VoidRequest,
    {
        let opcode = if let Some(extension) = R::EXTENSION_NAME {
            let info = self.extension_information(extension)?;
            info.ok_or(ConnectionError::UnsupportedExtension)?
                .major_opcode
        } else {
            0
        };
        let (bytes, fds) = request.serialize(opcode);
        let slices = [IoSlice::new(&bytes)];
        self.send_request_without_reply(&slices, fds)
    }

    /// Tell the connection that the response to request `sequence` will not be collected.
    ///
    /// NOTE(review): presumably `kind` says whether the request expects a reply and `mode`
    /// selects what gets discarded; confirm against the implementors and `DiscardMode`.
    fn discard_reply(&self, sequence: SequenceNumber, kind: RequestKind, mode: DiscardMode);

    /// Start looking up the extension `extension_name` ahead of time.
    ///
    /// NOTE(review): presumably this lets a later [`Self::extension_information`] call
    /// return without a round trip; confirm against the implementors.
    fn prefetch_extension_information(
        &self,
        extension_name: &'static str,
    ) -> Result<(), ConnectionError>;

    /// Look up information about the extension `extension_name`.
    ///
    /// The provided `send_trait_request_*` methods treat `Ok(None)` as "extension not
    /// supported" and read `major_opcode` from the returned information.
    fn extension_information(
        &self,
        extension_name: &'static str,
    ) -> Result<Option<ExtensionInformation>, ConnectionError>;

    /// Wait for the response to request `sequence`, returning the reply buffer.
    ///
    /// # Errors
    ///
    /// A raw error response is parsed with [`Self::parse_error`] and returned as
    /// `ReplyError::X11Error`. Failures of waiting or of parsing are converted into
    /// `ReplyError` as well.
    fn wait_for_reply_or_error(&self, sequence: SequenceNumber) -> Result<Self::Buf, ReplyError> {
        let raw_error = match self.wait_for_reply_or_raw_error(sequence)? {
            ReplyOrError::Reply(reply) => return Ok(reply),
            ReplyOrError::Error(raw_error) => raw_error,
        };
        let parsed = self.parse_error(raw_error.as_ref())?;
        Err(ReplyError::X11Error(parsed))
    }

    /// Wait for the response to request `sequence` without parsing an error response.
    fn wait_for_reply_or_raw_error(
        &self,
        sequence: SequenceNumber,
    ) -> Result<ReplyOrError<Self::Buf>, ConnectionError>;

    /// Wait for the reply to request `sequence`.
    ///
    /// NOTE(review): presumably `Ok(None)` means the request was answered with an error
    /// that is reported through another channel; confirm against the implementors.
    fn wait_for_reply(
        &self,
        sequence: SequenceNumber,
    ) -> Result<Option<Self::Buf>, ConnectionError>;

    /// Wait for the response to request `sequence`, returning the reply buffer together
    /// with its `RawFdContainer`s.
    ///
    /// # Errors
    ///
    /// A raw error response is parsed with [`Self::parse_error`] and returned as
    /// `ReplyError::X11Error`. Failures of waiting or of parsing are converted into
    /// `ReplyError` as well.
    fn wait_for_reply_with_fds(
        &self,
        sequence: SequenceNumber,
    ) -> Result<BufWithFds<Self::Buf>, ReplyError> {
        let raw_error = match self.wait_for_reply_with_fds_raw(sequence)? {
            ReplyOrError::Reply(reply_and_fds) => return Ok(reply_and_fds),
            ReplyOrError::Error(raw_error) => raw_error,
        };
        let parsed = self.parse_error(raw_error.as_ref())?;
        Err(ReplyError::X11Error(parsed))
    }

    /// Like [`Self::wait_for_reply_with_fds`], but an error response is returned unparsed.
    fn wait_for_reply_with_fds_raw(
        &self,
        sequence: SequenceNumber,
    ) -> Result<ReplyOrError<BufWithFds<Self::Buf>, Self::Buf>, ConnectionError>;

    /// Check whether request `sequence` caused an error.
    ///
    /// Returns `Ok(())` when [`Self::check_for_raw_error`] reports no error.
    ///
    /// # Errors
    ///
    /// A reported raw error is parsed with [`Self::parse_error`] and returned converted
    /// into `ReplyError`. Failures of the check or of parsing are converted as well.
    fn check_for_error(&self, sequence: SequenceNumber) -> Result<(), ReplyError> {
        if let Some(raw_error) = self.check_for_raw_error(sequence)? {
            let parsed = self.parse_error(raw_error.as_ref())?;
            return Err(parsed.into());
        }
        Ok(())
    }

    /// Check whether request `sequence` caused an error, returning the unparsed error
    /// bytes if it did.
    fn check_for_raw_error(
        &self,
        sequence: SequenceNumber,
    ) -> Result<Option<Self::Buf>, ConnectionError>;

    /// Start determining the maximum request size ahead of time.
    ///
    /// NOTE(review): presumably this lets a later [`Self::maximum_request_bytes`] call
    /// return without blocking; confirm against the implementors.
    fn prefetch_maximum_request_bytes(&self);

    /// The maximum size, in bytes, of a request that may be sent over this connection.
    fn maximum_request_bytes(&self) -> usize;

    /// Parse the raw bytes of an error response into an [`X11Error`].
    fn parse_error(&self, error: &[u8]) -> Result<X11Error, ParseError>;

    /// Parse the raw bytes of an event into an [`Event`].
    fn parse_event(&self, event: &[u8]) -> Result<Event, ParseError>;
}
/// A [`RequestConnection`] that can additionally deliver events, be flushed, expose the
/// server [`Setup`] and generate identifiers.
///
/// Implementors provide the two `*_raw_event_with_sequence` methods plus `flush`, `setup`
/// and `generate_id`. The remaining event methods are derived from the raw ones.
pub trait Connection: RequestConnection {
    /// Wait for the next event and return it parsed, dropping its sequence number.
    fn wait_for_event(&self) -> Result<Event, ConnectionError> {
        let (event, _sequence) = self.wait_for_event_with_sequence()?;
        Ok(event)
    }

    /// Wait for the next event and return its raw bytes, dropping its sequence number.
    fn wait_for_raw_event(&self) -> Result<Self::Buf, ConnectionError> {
        let (raw, _sequence) = self.wait_for_raw_event_with_sequence()?;
        Ok(raw)
    }

    /// Wait for the next event and return it parsed together with its sequence number.
    ///
    /// The raw event is parsed with [`RequestConnection::parse_event`]; a parse failure is
    /// converted into a `ConnectionError`.
    fn wait_for_event_with_sequence(&self) -> Result<EventAndSeqNumber, ConnectionError> {
        let (raw, sequence) = self.wait_for_raw_event_with_sequence()?;
        let parsed = self.parse_event(raw.as_ref())?;
        Ok((parsed, sequence))
    }

    /// Wait for the next event and return its raw bytes together with its sequence number.
    fn wait_for_raw_event_with_sequence(
        &self,
    ) -> Result<RawEventAndSeqNumber<Self::Buf>, ConnectionError>;

    /// Return the next event parsed, if polling yields one, dropping its sequence number.
    fn poll_for_event(&self) -> Result<Option<Event>, ConnectionError> {
        let polled = self.poll_for_event_with_sequence()?;
        Ok(polled.map(|(event, _sequence)| event))
    }

    /// Return the raw bytes of the next event, if polling yields one, dropping its
    /// sequence number.
    fn poll_for_raw_event(&self) -> Result<Option<Self::Buf>, ConnectionError> {
        let polled = self.poll_for_raw_event_with_sequence()?;
        Ok(polled.map(|(raw, _sequence)| raw))
    }

    /// Return the next event parsed together with its sequence number, if polling yields
    /// one.
    ///
    /// The raw event is parsed with [`RequestConnection::parse_event`]; a parse failure is
    /// converted into a `ConnectionError`.
    fn poll_for_event_with_sequence(&self) -> Result<Option<EventAndSeqNumber>, ConnectionError> {
        match self.poll_for_raw_event_with_sequence()? {
            None => Ok(None),
            Some((raw, sequence)) => {
                let parsed = self.parse_event(raw.as_ref())?;
                Ok(Some((parsed, sequence)))
            }
        }
    }

    /// Return the raw bytes and sequence number of the next event, or `None` if there is
    /// no event to return.
    ///
    /// NOTE(review): presumably this never blocks; confirm against the implementors.
    fn poll_for_raw_event_with_sequence(
        &self,
    ) -> Result<Option<RawEventAndSeqNumber<Self::Buf>>, ConnectionError>;

    /// Flush the connection.
    ///
    /// NOTE(review): presumably this writes out buffered requests; confirm against the
    /// implementors.
    fn flush(&self) -> Result<(), ConnectionError>;

    /// The [`Setup`] associated with this connection.
    fn setup(&self) -> &Setup;

    /// Generate a new identifier.
    ///
    /// NOTE(review): presumably an X11 resource id that is unused on this connection;
    /// confirm against the implementors.
    fn generate_id(&self) -> Result<u32, ReplyOrIdError>;
}
/// Classifies a request by whether a response is expected for it.
///
/// Passed to `RequestConnection::discard_reply` alongside the request's sequence number.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RequestKind {
    /// The request is void, i.e. it has no reply.
    IsVoid,
    /// The request has a response.
    HasResponse,
}
/// Check the length field of a serialized request and, if the request is too long for the
/// 16-bit field, rewrite its header to carry a 32-bit length.
///
/// `request_buffers` together form one request; its total byte length divided by four is
/// the length in 4-byte units.
///
/// * If that length fits into a `u16`, the native-endian `u16` at bytes 2..4 of the first
///   buffer must already equal it, and `request_buffers` is returned unchanged.
/// * Otherwise an 8-byte header is written to `storage.1`: the first two bytes of the
///   request, two zero bytes, and the length plus one (for the four added bytes) as a
///   native-endian `u32`. `storage.0` then receives that header, the first buffer without
///   its original four header bytes, and all remaining buffers; `storage.0` is returned.
///
/// Entries already present in `storage.0` are kept and precede the new ones, so callers
/// presumably pass an empty vector.
///
/// # Errors
///
/// Returns `ConnectionError::MaximumRequestLengthExceeded` when the request does not fit a
/// 16-bit length and is longer than `conn.maximum_request_bytes()`, or when adding one to
/// the length overflows `usize`.
///
/// # Panics
///
/// Panics if `request_buffers` is empty, if the total length is not a multiple of 4, if the
/// first buffer is too short to hold the header bytes that are read, if the existing 16-bit
/// length field is wrong, or if the extended length does not fit into a `u32`.
pub fn compute_length_field<'b>(
    conn: &impl RequestConnection,
    request_buffers: &'b [IoSlice<'b>],
    storage: &'b mut (Vec<IoSlice<'b>>, [u8; 8]),
) -> Result<&'b [IoSlice<'b>], ConnectionError> {
    let length = request_buffers
        .iter()
        .map(|buf| buf.len())
        .sum::<usize>();
    assert_eq!(
        length % 4,
        0,
        "The length of X11 requests must be a multiple of 4, got {length}"
    );
    let word_count = length / 4;
    let head = &request_buffers[0];

    match u16::try_from(word_count) {
        // Short request: only verify that the serializer wrote the right length.
        Ok(short_length) => {
            let length_field = u16::from_ne_bytes([head[2], head[3]]);
            assert_eq!(
                short_length, length_field,
                "Length field contains incorrect value"
            );
            Ok(request_buffers)
        }
        // Long request: switch to the extended 32-bit length encoding.
        Err(_) => {
            if length > conn.maximum_request_bytes() {
                return Err(ConnectionError::MaximumRequestLengthExceeded);
            }

            // One extra 4-byte unit accounts for the inserted 32-bit length.
            let extended_words = word_count
                .checked_add(1)
                .ok_or(ConnectionError::MaximumRequestLengthExceeded)?;
            let extended_words =
                u32::try_from(extended_words).expect("X11 request larger than 2^34 bytes?!?");
            let [len0, len1, len2, len3] = extended_words.to_ne_bytes();

            let (slices, header) = storage;
            // A zeroed 16-bit length field is followed by the 32-bit length.
            *header = [head[0], head[1], 0, 0, len0, len1, len2, len3];

            slices.push(IoSlice::new(&*header));
            // Remainder of the first buffer after its original four header bytes.
            slices.push(IoSlice::new(&head[4..]));
            slices.extend(
                request_buffers
                    .iter()
                    .skip(1)
                    .map(|buf| IoSlice::new(&**buf)),
            );
            Ok(&slices[..])
        }
    }
}