ra-multiplex 0.2.6

share one language server instance between multiple LSP clients to save resources
Documentation
#[cfg(target_family = "unix")]
use std::fs;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, net};

use anyhow::{Context as _, Result};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::{tcp, TcpListener, TcpStream};
#[cfg(target_family = "unix")]
use tokio::net::{unix, UnixListener, UnixStream};

use crate::config::Address;

pub enum SocketAddr {
    Ip(#[allow(dead_code)] net::SocketAddr),
    #[cfg(target_family = "unix")]
    Unix(#[allow(dead_code)] tokio::net::unix::SocketAddr),
}

impl From<net::SocketAddr> for SocketAddr {
    fn from(val: net::SocketAddr) -> Self {
        SocketAddr::Ip(val)
    }
}

#[cfg(target_family = "unix")]
impl From<tokio::net::unix::SocketAddr> for SocketAddr {
    fn from(val: tokio::net::unix::SocketAddr) -> Self {
        SocketAddr::Unix(val)
    }
}

#[cfg(target_family = "unix")]
pin_project! {
    #[project = OwnedReadHalfProj]
    pub enum OwnedReadHalf {
        Tcp{#[pin] tcp: tcp::OwnedReadHalf},
        Unix{#[pin] unix: unix::OwnedReadHalf},
    }
}
#[cfg(not(target_family = "unix"))]
pin_project! {
    #[project = OwnedReadHalfProj]
    pub enum OwnedReadHalf {
        Tcp{#[pin] tcp: tcp::OwnedReadHalf},
    }
}

impl AsyncRead for OwnedReadHalf {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match self.project() {
            OwnedReadHalfProj::Tcp { tcp } => tcp.poll_read(cx, buf),
            #[cfg(target_family = "unix")]
            OwnedReadHalfProj::Unix { unix } => unix.poll_read(cx, buf),
        }
    }
}

#[cfg(target_family = "unix")]
pin_project! {
    #[project = OwnedWriteHalfProj]
    pub enum OwnedWriteHalf {
        Tcp{#[pin] tcp: tcp::OwnedWriteHalf},
        Unix{#[pin] unix: unix::OwnedWriteHalf},
    }
}
#[cfg(not(target_family = "unix"))]
pin_project! {
    #[project = OwnedWriteHalfProj]
    pub enum OwnedWriteHalf {
        Tcp{#[pin] tcp: tcp::OwnedWriteHalf},
    }
}

impl AsyncWrite for OwnedWriteHalf {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        match self.project() {
            OwnedWriteHalfProj::Tcp { tcp } => tcp.poll_write(cx, buf),
            #[cfg(target_family = "unix")]
            OwnedWriteHalfProj::Unix { unix } => unix.poll_write(cx, buf),
        }
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        match self.project() {
            OwnedWriteHalfProj::Tcp { tcp } => tcp.poll_write_vectored(cx, bufs),
            #[cfg(target_family = "unix")]
            OwnedWriteHalfProj::Unix { unix } => unix.poll_write_vectored(cx, bufs),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        match self.project() {
            OwnedWriteHalfProj::Tcp { tcp } => tcp.poll_flush(cx),
            #[cfg(target_family = "unix")]
            OwnedWriteHalfProj::Unix { unix } => unix.poll_flush(cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        match self.project() {
            OwnedWriteHalfProj::Tcp { tcp } => tcp.poll_shutdown(cx),
            #[cfg(target_family = "unix")]
            OwnedWriteHalfProj::Unix { unix } => unix.poll_shutdown(cx),
        }
    }
}

#[cfg(target_family = "unix")]
pin_project! {
    #[project = StreamProj]
    pub enum Stream {
        Tcp{#[pin] tcp: TcpStream},
        Unix{#[pin] unix: UnixStream},
    }
}
#[cfg(not(target_family = "unix"))]
pin_project! {
    #[project = StreamProj]
    pub enum Stream {
        Tcp{#[pin] tcp: TcpStream},
    }
}

impl Stream {
    pub async fn connect(addr: &Address) -> Result<Stream> {
        match addr {
            Address::Tcp(ip_addr, port) => TcpStream::connect((*ip_addr, *port))
                .await
                .with_context(|| format!("connecting to tcp socket {ip_addr}:{port}"))
                .map(|tcp| Stream::Tcp { tcp }),
            #[cfg(target_family = "unix")]
            Address::Unix(path) => UnixStream::connect(path)
                .await
                .with_context(|| format!("connecting to unix socket {path:?}"))
                .map(|unix| Stream::Unix { unix }),
        }
    }

    pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
        match self {
            Stream::Tcp { tcp } => {
                let (read, write) = tcp.into_split();
                (
                    OwnedReadHalf::Tcp { tcp: read },
                    OwnedWriteHalf::Tcp { tcp: write },
                )
            }
            #[cfg(target_family = "unix")]
            Stream::Unix { unix } => {
                let (read, write) = unix.into_split();
                (
                    OwnedReadHalf::Unix { unix: read },
                    OwnedWriteHalf::Unix { unix: write },
                )
            }
        }
    }
}

impl AsyncRead for Stream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match self.project() {
            StreamProj::Tcp { tcp } => tcp.poll_read(cx, buf),
            #[cfg(target_family = "unix")]
            StreamProj::Unix { unix } => unix.poll_read(cx, buf),
        }
    }
}

impl AsyncWrite for Stream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        match self.project() {
            StreamProj::Tcp { tcp } => tcp.poll_write(cx, buf),
            #[cfg(target_family = "unix")]
            StreamProj::Unix { unix } => unix.poll_write(cx, buf),
        }
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        match self.project() {
            StreamProj::Tcp { tcp } => tcp.poll_write_vectored(cx, bufs),
            #[cfg(target_family = "unix")]
            StreamProj::Unix { unix } => unix.poll_write_vectored(cx, bufs),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        match self.project() {
            StreamProj::Tcp { tcp } => tcp.poll_flush(cx),
            #[cfg(target_family = "unix")]
            StreamProj::Unix { unix } => unix.poll_flush(cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        match self.project() {
            StreamProj::Tcp { tcp } => tcp.poll_shutdown(cx),
            #[cfg(target_family = "unix")]
            StreamProj::Unix { unix } => unix.poll_shutdown(cx),
        }
    }
}

pub enum Listener {
    Tcp(TcpListener),
    #[cfg(target_family = "unix")]
    Unix(UnixListener),
}

impl Listener {
    pub async fn bind(addr: &Address) -> Result<Listener> {
        match addr {
            Address::Tcp(ip_addr, port) => TcpListener::bind((*ip_addr, *port))
                .await
                .with_context(|| format!("binding to tcp socket {ip_addr}:{port}"))
                .map(Listener::Tcp),
            #[cfg(target_family = "unix")]
            Address::Unix(path) => {
                match fs::remove_file(path) {
                    Ok(()) => (),
                    Err(e) if e.kind() == io::ErrorKind::NotFound => (),
                    Err(e) => {
                        return Err(e)
                            .with_context(|| format!("removing old unix socket file {path:?}"))
                    }
                }
                UnixListener::bind(path)
                    .with_context(|| format!("binding to unix socket {path:?}"))
                    .map(Listener::Unix)
            }
        }
    }

    pub async fn accept(&self) -> io::Result<(Stream, SocketAddr)> {
        match self {
            Listener::Tcp(tcp) => {
                let (stream, addr) = tcp.accept().await?;
                Ok((Stream::Tcp { tcp: stream }, addr.into()))
            }
            #[cfg(target_family = "unix")]
            Listener::Unix(unix) => {
                let (stream, addr) = unix.accept().await?;
                Ok((Stream::Unix { unix: stream }, addr.into()))
            }
        }
    }
}