#![cfg(feature = "std")]
use super::Connection;
use crate::Fd;
use alloc::{borrow::Cow, format, string::String, vec::Vec};
use core::mem;
use memchr::memrchr;
use std::{env, net, path::Path};
#[cfg(feature = "async")]
use super::AsyncConnection;
#[cfg(feature = "async")]
use core::task::{Context, Poll};
#[cfg(all(feature = "async", not(feature = "tokio-support")))]
use async_io::Async;
#[cfg(all(feature = "async", not(feature = "tokio-support")))]
use std::net::ToSocketAddrs;
#[cfg(all(feature = "async", feature = "tokio-support"))]
use spinning_top::Spinlock;
#[cfg(all(feature = "async", feature = "tokio-support"))]
use tokio::net::TcpStream as TokioTcpStream;
#[cfg(all(feature = "async", feature = "tokio-support", unix))]
use tokio::net::UnixStream as TokioUnixStream;
#[cfg(test)]
use std::borrow::ToOwned;
#[cfg(unix)]
use std::os::unix::net as unet;
/// A blocking connection to an X server, carried over TCP or, on Unix targets, over a
/// Unix domain socket.
pub enum NameConnection {
    /// Connection over a TCP stream.
    #[doc(hidden)]
    Tcp(net::TcpStream),
    /// Connection over a Unix domain socket (Unix targets only).
    #[doc(hidden)]
    #[cfg(unix)]
    Socket(unet::UnixStream),
}
/// Blocking I/O on an owned [`NameConnection`] is forwarded to the wrapped stream.
impl Connection for NameConnection {
    #[inline]
    fn send_packet(&mut self, bytes: &[u8], fds: &mut Vec<Fd>) -> crate::Result {
        match *self {
            NameConnection::Tcp(ref mut stream) => stream.send_packet(bytes, fds),
            #[cfg(unix)]
            NameConnection::Socket(ref mut stream) => stream.send_packet(bytes, fds),
        }
    }

    #[inline]
    fn read_packet(&mut self, bytes: &mut [u8], fds: &mut Vec<Fd>) -> crate::Result {
        match *self {
            NameConnection::Tcp(ref mut stream) => stream.read_packet(bytes, fds),
            #[cfg(unix)]
            NameConnection::Socket(ref mut stream) => stream.read_packet(bytes, fds),
        }
    }
}
/// Blocking I/O through a shared reference to a [`NameConnection`].
///
/// Each arm re-binds the borrowed stream mutably so the call can take `&mut &Stream`,
/// presumably dispatching to a `Connection` impl on the shared stream reference.
impl<'a> Connection for &'a NameConnection {
    #[inline]
    fn send_packet(&mut self, bytes: &[u8], fds: &mut Vec<Fd>) -> crate::Result {
        match **self {
            NameConnection::Tcp(ref tcp) => {
                let mut shared = tcp;
                shared.send_packet(bytes, fds)
            }
            #[cfg(unix)]
            NameConnection::Socket(ref sock) => {
                let mut shared = sock;
                shared.send_packet(bytes, fds)
            }
        }
    }

    #[inline]
    fn read_packet(&mut self, bytes: &mut [u8], fds: &mut Vec<Fd>) -> crate::Result {
        match **self {
            NameConnection::Tcp(ref tcp) => {
                let mut shared = tcp;
                shared.read_packet(bytes, fds)
            }
            #[cfg(unix)]
            NameConnection::Socket(ref sock) => {
                let mut shared = sock;
                shared.read_packet(bytes, fds)
            }
        }
    }
}
/// An asynchronous connection to an X server.
///
/// The wrapped stream depends on the async backend. With `tokio-support` the tokio stream
/// sits behind a `Spinlock`, which is locked with `try_lock` when the connection is used
/// through `&AsyncNameConnection`. Otherwise the std stream is wrapped in `async_io::Async`.
#[cfg(feature = "async")]
pub enum AsyncNameConnection {
    // --- tokio backend ---
    #[cfg(feature = "tokio-support")]
    #[doc(hidden)]
    Tcp(Spinlock<TokioTcpStream>),
    #[cfg(all(feature = "tokio-support", unix))]
    #[doc(hidden)]
    Socket(Spinlock<TokioUnixStream>),
    // --- async-io backend ---
    #[cfg(not(feature = "tokio-support"))]
    #[doc(hidden)]
    Tcp(Async<net::TcpStream>),
    #[cfg(all(not(feature = "tokio-support"), unix))]
    #[doc(hidden)]
    Socket(Async<unet::UnixStream>),
}
/// Owned async connections forward to whichever stream the active backend wrapped.
///
/// With `tokio-support` the stream sits inside a `Spinlock`. Because `&mut self` already
/// gives exclusive access, `get_mut` reaches the stream without locking.
#[cfg(feature = "async")]
impl AsyncConnection for AsyncNameConnection {
    #[inline]
    fn poll_send_packet(
        &mut self,
        bytes: &[u8],
        fds: &mut Vec<Fd>,
        cx: &mut Context<'_>,
        bytes_sent: &mut usize,
    ) -> Poll<crate::Result> {
        match self {
            #[cfg(feature = "tokio-support")]
            AsyncNameConnection::Tcp(tcp) => {
                let stream = tcp.get_mut();
                stream.poll_send_packet(bytes, fds, cx, bytes_sent)
            }
            #[cfg(all(feature = "tokio-support", unix))]
            AsyncNameConnection::Socket(sock) => {
                let stream = sock.get_mut();
                stream.poll_send_packet(bytes, fds, cx, bytes_sent)
            }
            #[cfg(not(feature = "tokio-support"))]
            AsyncNameConnection::Tcp(tcp) => tcp.poll_send_packet(bytes, fds, cx, bytes_sent),
            #[cfg(all(not(feature = "tokio-support"), unix))]
            AsyncNameConnection::Socket(sock) => sock.poll_send_packet(bytes, fds, cx, bytes_sent),
        }
    }

    #[inline]
    fn poll_read_packet(
        &mut self,
        bytes: &mut [u8],
        fds: &mut Vec<Fd>,
        cx: &mut Context<'_>,
        bytes_read: &mut usize,
    ) -> Poll<crate::Result> {
        match self {
            #[cfg(feature = "tokio-support")]
            AsyncNameConnection::Tcp(tcp) => {
                let stream = tcp.get_mut();
                stream.poll_read_packet(bytes, fds, cx, bytes_read)
            }
            #[cfg(all(feature = "tokio-support", unix))]
            AsyncNameConnection::Socket(sock) => {
                let stream = sock.get_mut();
                stream.poll_read_packet(bytes, fds, cx, bytes_read)
            }
            #[cfg(not(feature = "tokio-support"))]
            AsyncNameConnection::Tcp(tcp) => tcp.poll_read_packet(bytes, fds, cx, bytes_read),
            #[cfg(all(not(feature = "tokio-support"), unix))]
            AsyncNameConnection::Socket(sock) => sock.poll_read_packet(bytes, fds, cx, bytes_read),
        }
    }
}
/// Async I/O through a shared reference to an [`AsyncNameConnection`].
///
/// The `async-io` streams are driven through the shared `&Async<_>` directly. The tokio
/// streams need `&mut` access, so their `Spinlock` is taken with `try_lock`; if the lock
/// is already held, the call panics.
#[cfg(feature = "async")]
impl<'a> AsyncConnection for &'a AsyncNameConnection {
    #[inline]
    fn poll_send_packet(
        &mut self,
        bytes: &[u8],
        fds: &mut Vec<Fd>,
        cx: &mut Context<'_>,
        bytes_sent: &mut usize,
    ) -> Poll<crate::Result> {
        match self {
            #[cfg(not(feature = "tokio-support"))]
            AsyncNameConnection::Tcp(ref t) => {
                let mut t = t;
                t.poll_send_packet(bytes, fds, cx, bytes_sent)
            }
            #[cfg(all(not(feature = "tokio-support"), unix))]
            AsyncNameConnection::Socket(ref s) => {
                let mut s = s;
                s.poll_send_packet(bytes, fds, cx, bytes_sent)
            }
            #[cfg(feature = "tokio-support")]
            AsyncNameConnection::Tcp(t) => t
                .try_lock()
                .expect("Tried to access tokio connection concurrently")
                .poll_send_packet(bytes, fds, cx, bytes_sent),
            #[cfg(all(feature = "tokio-support", unix))]
            AsyncNameConnection::Socket(s) => s
                .try_lock()
                .expect("Tried to access tokio connection concurrently")
                .poll_send_packet(bytes, fds, cx, bytes_sent),
        }
    }

    /// Reads into `bytes` from the underlying stream.
    ///
    /// # Panics
    ///
    /// With `tokio-support`, panics if the stream's `Spinlock` is already held.
    #[inline]
    fn poll_read_packet(
        &mut self,
        bytes: &mut [u8],
        fds: &mut Vec<Fd>,
        cx: &mut Context<'_>,
        bytes_read: &mut usize,
    ) -> Poll<crate::Result> {
        match self {
            #[cfg(not(feature = "tokio-support"))]
            AsyncNameConnection::Tcp(ref t) => {
                let mut t = t;
                t.poll_read_packet(bytes, fds, cx, bytes_read)
            }
            #[cfg(all(not(feature = "tokio-support"), unix))]
            AsyncNameConnection::Socket(ref s) => {
                let mut s = s;
                s.poll_read_packet(bytes, fds, cx, bytes_read)
            }
            // Both tokio arms must dispatch to the read half: calling `poll_send_packet`
            // here would write `bytes` out instead of filling it.
            #[cfg(feature = "tokio-support")]
            AsyncNameConnection::Tcp(t) => t
                .try_lock()
                .expect("Tried to access tokio connection concurrently")
                .poll_read_packet(bytes, fds, cx, bytes_read),
            #[cfg(all(feature = "tokio-support", unix))]
            AsyncNameConnection::Socket(s) => s
                .try_lock()
                .expect("Tried to access tokio connection concurrently")
                .poll_read_packet(bytes, fds, cx, bytes_read),
        }
    }
}
/// Base TCP port for X11; the TCP port for display `N` is computed as `X_TCP_PORT + N`.
const X_TCP_PORT: u16 = 6_000;
/// Prefix of the conventional X11 Unix socket path; the display number is appended to it.
#[cfg(unix)]
const PART1: &'static str = "/tmp/.X11-unix/X";
/// Transport selected by the optional `protocol/` prefix of an X11 display name.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum Protocol {
    /// Parsed from the protocol name `unix`.
    Unix,
    /// Parsed from the protocol name `tcp`.
    Tcp,
    /// Parsed from the protocol name `inet`.
    Inet,
    /// Parsed from the protocol name `inet6`.
    Inet6,
}

impl Protocol {
    /// Parses a protocol name case-insensitively: `unix`, `tcp`, `inet` or `inet6`.
    ///
    /// Returns `None` for any other name. The name comes from a display name, possibly
    /// the `DISPLAY` environment variable. An unknown protocol is therefore reported as
    /// `None` in every build profile, rather than panicking in debug builds.
    #[inline]
    fn from_str(s: String) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "unix" => Some(Self::Unix),
            "tcp" => Some(Self::Tcp),
            "inet" => Some(Self::Inet),
            "inet6" => Some(Self::Inet6),
            _ => None,
        }
    }
}
/// A parsed X11 display name: where to connect and which screen to use.
#[derive(Clone, Debug, PartialEq, Eq)]
struct XConnection<'a> {
    /// Host name, or the socket path for names recognized as a Unix socket; `None` when
    /// the name had nothing before its `:`.
    host: Option<Cow<'a, str>>,
    /// Protocol from an explicit `protocol/` prefix, or `Unix` for socket paths.
    protocol: Option<Protocol>,
    /// Display number (the part after the last `:`).
    display: u16,
    /// Screen number (the part after the `.`), 0 when absent.
    screen: usize,
}
impl<'a> XConnection<'a> {
    /// Tries to interpret `name` as the filesystem path of an X11 Unix socket.
    ///
    /// `name` is accepted as-is, with screen 0, when it exists on disk. Otherwise, it may
    /// end in `.N`, where `N` parses as a `usize`, and the part before that last `.` may
    /// exist on disk. In that case the part before the `.` becomes the path and `N` the
    /// screen.
    ///
    /// On success the result uses `Protocol::Unix`, the path as `host`, and display 0. On
    /// failure `name` is handed back unchanged so it can be parsed as a regular display name.
    #[inline]
    fn parse_from_socket(name: Cow<'a, str>) -> Result<Self, Cow<'a, str>> {
        let (host, screen) = if Path::new(&*name).exists() {
            (name, 0)
        } else {
            let rposn = match memrchr(b'.', name.as_bytes()) {
                Some(rposn) => rposn,
                None => return Err(name),
            };
            let screen = &name[rposn + 1..];
            let screen: usize = match screen.parse() {
                Ok(s) => s,
                Err(_) => return Err(name),
            };
            if Path::new(&name[..rposn]).exists() {
                (
                    // Drop the ".screen" suffix, keeping only the socket path.
                    match name {
                        Cow::Borrowed(s) => Cow::Borrowed(&s[..rposn]),
                        Cow::Owned(mut sr) => {
                            sr.truncate(rposn);
                            Cow::Owned(sr)
                        }
                    },
                    screen,
                )
            } else {
                return Err(name);
            }
        };
        Ok(XConnection {
            host: Some(host),
            protocol: Some(Protocol::Unix),
            display: 0,
            screen,
        })
    }

    /// Parses an X11 display name of the form `[protocol/][host]:display[.screen]`.
    ///
    /// When `name` is `None`, the `DISPLAY` environment variable is read instead. Outside
    /// of test builds, the name is first tried as a Unix socket path (`parse_from_socket`).
    /// Borrowed input yields a borrowed `host` whenever no protocol prefix is present.
    ///
    /// # Errors
    ///
    /// Returns `BreadError::UnableToParseConnection` in these cases:
    /// - `DISPLAY` is needed but cannot be read.
    /// - The name contains no `:`.
    /// - The display number is missing or is not a valid `u16`.
    /// - A non-empty screen is not a valid `usize`.
    pub fn parse(name: Option<Cow<'a, str>>) -> crate::Result<XConnection<'a>> {
        // This binding is only mutated in test builds, where the socket check below (which
        // shadows it) is compiled out.
        #[allow(unused_mut)]
        let mut name = match name {
            Some(name) => name,
            None => Cow::Owned(
                env::var("DISPLAY").map_err(|_| crate::BreadError::UnableToParseConnection)?,
            ),
        };
        #[cfg(not(test))]
        let mut name = match Self::parse_from_socket(name) {
            Ok(sock) => return Ok(sock),
            Err(name) => name,
        };

        // Strip an optional "protocol/" prefix (up to the last '/'), leaving `name` as
        // "[host]:display[.screen]".
        let protocol = match memrchr(b'/', name.as_bytes()) {
            Some(posn) => {
                let mut protocol = name.to_mut().split_off(posn + 1);
                mem::swap(name.to_mut(), &mut protocol);
                protocol.pop(); // drop the trailing '/'
                Protocol::from_str(protocol)
            }
            None => None,
        };

        // Split at the last ':' into the host and ":display[.screen]". Both halves are
        // sliced out, for borrowed and owned input alike, so the host never leaks into the
        // display/screen parsing below.
        let (host, name) = match memrchr(b':', name.as_bytes()) {
            None => return Err(crate::BreadError::UnableToParseConnection),
            Some(0) => (None, name),
            Some(brek) => {
                let (host, rest) = Self::split_cow(name, brek);
                (Some(host), rest)
            }
        };

        // `name` now starts with the ':' located above.
        let (display, screen) = Self::parse_display_and_screen(&name[1..])?;
        Ok(XConnection {
            host,
            protocol,
            display,
            screen,
        })
    }

    /// Splits `s` at byte offset `at` into `(s[..at], s[at..])`.
    ///
    /// Borrowed input stays borrowed (no copy); owned input is split with
    /// `String::split_off`. `at` must lie on a `char` boundary.
    fn split_cow(s: Cow<'a, str>, at: usize) -> (Cow<'a, str>, Cow<'a, str>) {
        match s {
            Cow::Borrowed(s) => {
                let (head, tail) = s.split_at(at);
                (Cow::Borrowed(head), Cow::Borrowed(tail))
            }
            Cow::Owned(mut s) => {
                let tail = s.split_off(at);
                (Cow::Owned(s), Cow::Owned(tail))
            }
        }
    }

    /// Parses the `display[.screen]` text that follows the last ':' of a display name.
    ///
    /// The display must be a valid `u16`. The screen defaults to 0 when absent or empty,
    /// and otherwise must be a valid `usize`.
    fn parse_display_and_screen(spec: &str) -> crate::Result<(u16, usize)> {
        let (display, screen) = match spec.find('.') {
            Some(dot) => (&spec[..dot], &spec[dot + 1..]),
            None => (spec, ""),
        };
        // An empty display is rejected as well, because parsing "" fails.
        let display: u16 = display
            .parse()
            .map_err(|_| crate::BreadError::UnableToParseConnection)?;
        let screen: usize = if screen.is_empty() {
            0
        } else {
            screen
                .parse()
                .map_err(|_| crate::BreadError::UnableToParseConnection)?
        };
        Ok((display, screen))
    }

    /// Whether `open` / `open_async` should attempt TCP before the Unix-socket fallbacks.
    ///
    /// TCP is skipped, in favour of the Unix socket, in two cases: an explicit `unix`
    /// protocol (also set by `parse_from_socket`), or a host literally named `unix`. A
    /// missing host still tries TCP, on `127.0.0.1` (see `host_and_port`).
    fn should_try_tcp(&self) -> bool {
        self.protocol != Some(Protocol::Unix) && self.host.as_deref() != Some("unix")
    }

    /// Resolves the TCP endpoint: `host` (default `127.0.0.1`) and port
    /// `X_TCP_PORT + display`.
    ///
    /// # Errors
    ///
    /// Returns `BreadError::UnableToOpenConnection` if `X_TCP_PORT + display` does not fit
    /// in a `u16`, instead of overflowing.
    #[inline]
    fn host_and_port(self) -> crate::Result<(Cow<'a, str>, u16)> {
        let XConnection { host, display, .. } = self;
        let host = host.unwrap_or(Cow::Borrowed("127.0.0.1"));
        let port = X_TCP_PORT
            .checked_add(display)
            .ok_or(crate::BreadError::UnableToOpenConnection)?;
        Ok((host, port))
    }

    /// Opens a blocking TCP connection to `host_and_port`.
    #[inline]
    fn open_tcp(self) -> crate::Result<NameConnection> {
        let (host, port) = self.host_and_port()?;
        let connection = net::TcpStream::connect((&*host, port))?;
        Ok(NameConnection::Tcp(connection))
    }

    /// The Unix socket path to connect to, taken from `host`.
    ///
    /// # Errors
    ///
    /// Returns `BreadError::UnableToOpenConnection` if there is no host.
    #[cfg(unix)]
    #[inline]
    fn socket_filename(self) -> crate::Result<Cow<'a, str>> {
        self.host.ok_or(crate::BreadError::UnableToOpenConnection)
    }

    /// Opens a blocking Unix-socket connection to the path in `host`.
    #[cfg(unix)]
    #[inline]
    fn open_unix(self) -> crate::Result<NameConnection> {
        let fname = self.socket_filename()?;
        Ok(NameConnection::Socket(unet::UnixStream::connect(&*fname)?))
    }

    /// Opens a blocking connection to the display.
    ///
    /// Attempts are made in this order:
    /// 1. TCP, but only when `Self::should_try_tcp` allows it.
    /// 2. On Unix, a Unix socket at `host`.
    /// 3. On Unix, the conventional socket path `PART1` followed by the display number.
    ///
    /// Earlier failures are discarded and only the last attempt's error is returned. On
    /// non-Unix targets anything short of a successful TCP connection yields
    /// `UnableToOpenConnection`.
    // `mut self` is only needed on Unix, where `host` is rewritten for the final attempt.
    #[allow(unused_mut)]
    pub fn open(mut self) -> crate::Result<NameConnection> {
        if self.should_try_tcp() {
            if let Ok(c) = self.clone().open_tcp() {
                return Ok(c);
            }
        }
        #[cfg(unix)]
        {
            if let Ok(u) = self.clone().open_unix() {
                return Ok(u);
            }
            self.host = Some(Cow::Owned(format!("{}{}", PART1, self.display)));
            self.open_unix()
        }
        #[cfg(not(unix))]
        Err(crate::BreadError::UnableToOpenConnection)
    }

    /// Async counterpart of `open_tcp`, using tokio or async-io depending on features.
    ///
    /// With async-io, name resolution runs on a blocking thread. Each resolved address is
    /// tried in turn, and the last connect error is returned if all of them fail.
    #[cfg(feature = "async")]
    #[inline]
    async fn open_tcp_async(self) -> crate::Result<AsyncNameConnection> {
        let (host, port) = self.host_and_port()?;
        cfg_if::cfg_if! {
            if #[cfg(feature = "tokio-support")] {
                let conn = TokioTcpStream::connect(&(&*host, port)).await?;
                Ok(AsyncNameConnection::Tcp(Spinlock::new(conn)))
            } else {
                let host = host.into_owned();
                let mut last_error: Option<crate::BreadError> = None;
                let addrs =
                    blocking::unblock(move || ToSocketAddrs::to_socket_addrs(&(&*host, port))).await?;
                for addr in addrs {
                    let connection = match Async::<net::TcpStream>::connect(addr).await {
                        Ok(conn) => conn,
                        Err(e) => {
                            last_error = Some(e.into());
                            continue;
                        }
                    };
                    return Ok(AsyncNameConnection::Tcp(connection));
                }
                Err(last_error.unwrap_or(crate::BreadError::StaticMsg(
                    "Could not connect to any of the given addresses",
                )))
            }
        }
    }

    /// Async counterpart of `open_unix`.
    #[cfg(all(feature = "async", unix))]
    async fn open_unix_async(self) -> crate::Result<AsyncNameConnection> {
        let fname = self.socket_filename()?;
        cfg_if::cfg_if! {
            if #[cfg(feature = "tokio-support")] {
                let conn = TokioUnixStream::connect(&*fname).await?;
                Ok(AsyncNameConnection::Socket(Spinlock::new(conn)))
            } else {
                let conn = Async::<unet::UnixStream>::connect(&*fname).await?;
                Ok(AsyncNameConnection::Socket(conn))
            }
        }
    }

    /// Async counterpart of `open`, with the same order of connection attempts.
    // `mut self` is only needed on Unix, where `host` is rewritten for the final attempt.
    #[allow(unused_mut)]
    #[cfg(feature = "async")]
    pub async fn open_async(mut self) -> crate::Result<AsyncNameConnection> {
        if self.should_try_tcp() {
            if let Ok(c) = self.clone().open_tcp_async().await {
                return Ok(c);
            }
        }
        #[cfg(unix)]
        {
            if let Ok(u) = self.clone().open_unix_async().await {
                return Ok(u);
            }
            self.host = Some(Cow::Owned(format!("{}{}", PART1, self.display)));
            self.open_unix_async().await
        }
        #[cfg(not(unix))]
        Err(crate::BreadError::UnableToOpenConnection)
    }
}
impl NameConnection {
    /// Parses `name` (or `DISPLAY` when `None`) and opens a blocking connection to it.
    ///
    /// Returns the connection together with the screen number taken from the display name.
    #[inline]
    pub(crate) fn connect_internal(
        name: Option<Cow<'_, str>>,
    ) -> crate::Result<(NameConnection, usize)> {
        XConnection::parse(name).and_then(|parsed| {
            let screen = parsed.screen;
            parsed.open().map(|connection| (connection, screen))
        })
    }
}
#[cfg(feature = "async")]
impl AsyncNameConnection {
    /// Async counterpart of `NameConnection::connect_internal`.
    ///
    /// Parses `name` (or `DISPLAY` when `None`) and opens an async connection to it.
    /// Returns the connection together with the screen number from the display name.
    #[inline]
    pub(crate) async fn connect_internal_async(
        name: Option<Cow<'_, str>>,
    ) -> crate::Result<(AsyncNameConnection, usize)> {
        let parsed = XConnection::parse(name)?;
        let screen = parsed.screen;
        let connection = parsed.open_async().await?;
        Ok((connection, screen))
    }
}
/// Asserts that parsing `$input` yields `$expected`, from both a borrowed and an owned
/// `Cow`, so both storage paths of `XConnection::parse` are exercised.
#[cfg(test)]
macro_rules! borrowed_test {
    ($input: expr, $expected: expr) => {{
        let borrowed = XConnection::parse(Some(Cow::Borrowed($input))).unwrap();
        assert_eq!(borrowed, ($expected), "input: {}", $input);
        let owned = XConnection::parse(Some(Cow::Owned(($input).to_owned()))).unwrap();
        assert_eq!(owned, ($expected), "input: {}", $input);
    }};
}
#[test]
fn parse_basic_display() {
    // A bare ":N" names display N on the default host and screen 0.
    let expected = XConnection {
        host: None,
        protocol: None,
        display: 3,
        screen: 0,
    };
    borrowed_test!(":3", expected);
}
#[test]
fn parse_display_and_screen() {
    // ":N.S" selects screen S of display N.
    let expected = XConnection {
        host: None,
        protocol: None,
        display: 3,
        screen: 6,
    };
    borrowed_test!(":3.6", expected);
}
#[test]
fn parse_display_screen_and_protocol() {
    let parsed = XConnection::parse(Some(Cow::Borrowed("inet/:5"))).unwrap();
    let expected = XConnection {
        host: None,
        protocol: Some(Protocol::Inet),
        display: 5,
        screen: 0,
    };
    assert_eq!(parsed, expected);

    // Every recognized protocol prefix should round-trip through `parse`.
    let cases = [
        ("unix", Protocol::Unix),
        ("inet", Protocol::Inet),
        ("inet6", Protocol::Inet6),
        ("tcp", Protocol::Tcp),
    ];
    for &(name, protocol) in cases.iter() {
        let parsed = XConnection::parse(Some(Cow::Owned(format!("{}/:9.2", name)))).unwrap();
        let expected = XConnection {
            host: None,
            protocol: Some(protocol),
            display: 9,
            screen: 2,
        };
        assert_eq!(parsed, expected);
    }
}
#[should_panic]
#[test]
fn parse_arbitrary() {
    // A name without any ':' cannot be parsed, so unwrapping the result must panic.
    let result = XConnection::parse(Some(Cow::Borrowed("arbitrary")));
    result.unwrap();
}