// vsock 0.5.0
//
// Virtio socket support for Rust
// Documentation
/*
 * Copyright 2019 fsyncd, Berlin, Germany.
 * Additional material Copyright the Rust project and its contributors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//! Virtio socket support for Rust.

use libc::{
    accept4, ioctl, sa_family_t, sockaddr, sockaddr_vm, socklen_t, suseconds_t, timeval, AF_VSOCK,
    FIONBIO, SOCK_CLOEXEC,
};
use nix::{
    ioctl_read_bad,
    sys::socket::{
        self, bind, connect, getpeername, getsockname, listen, recv, send, shutdown, socket,
        sockopt::{ReceiveTimeout, SendTimeout, SocketError},
        AddressFamily, Backlog, GetSockOpt, MsgFlags, SetSockOpt, SockFlag, SockType,
    },
};
use std::mem::size_of;
use std::net::Shutdown;
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::time::Duration;
use std::{fs::File, os::fd::OwnedFd};
use std::{
    io::{Error, ErrorKind, Read, Result, Write},
    os::fd::{AsFd, BorrowedFd},
};

pub use libc::{VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, VMADDR_CID_LOCAL};
pub use nix::sys::socket::{SockaddrLike, VsockAddr};

/// Open a fresh, unconnected virtio stream socket with close-on-exec set.
///
/// The returned descriptor is owned; dropping it closes the socket.
fn new_socket() -> Result<OwnedFd> {
    let fd = socket(AddressFamily::Vsock, SockType::Stream, SockFlag::SOCK_CLOEXEC, None)?;
    Ok(fd)
}

/// An iterator that infinitely accepts connections on a VsockListener.
///
/// Each call to `next` performs one [`VsockListener::accept`] and yields the accepted
/// stream, discarding the peer address. It never returns `None`.
#[derive(Debug)]
pub struct Incoming<'a> {
    listener: &'a VsockListener,
}

impl Iterator for Incoming<'_> {
    type Item = Result<VsockStream>;

    fn next(&mut self) -> Option<Self::Item> {
        let accepted = self.listener.accept();
        Some(accepted.map(|(stream, _peer)| stream))
    }
}

/// A virtio socket server, listening for connections.
#[derive(Debug)]
pub struct VsockListener {
    socket: OwnedFd,
}

impl VsockListener {
    /// Create a new VsockListener which is bound and listening on the socket address.
    pub fn bind(addr: &impl SockaddrLike) -> Result<Self> {
        if addr.family() != Some(AddressFamily::Vsock) {
            return Err(Error::new(
                ErrorKind::Other,
                "requires a virtio socket address",
            ));
        }

        let socket = new_socket()?;

        bind(socket.as_raw_fd(), addr)?;

        // rust stdlib uses a 128 connection backlog
        listen(&socket, Backlog::new(128).unwrap_or(Backlog::MAXCONN))?;

        Ok(Self { socket })
    }

    /// Create a new VsockListener with specified cid and port.
    pub fn bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener> {
        Self::bind(&VsockAddr::new(cid, port))
    }

    /// The local socket address of the listener.
    pub fn local_addr(&self) -> Result<VsockAddr> {
        Ok(getsockname(self.socket.as_raw_fd())?)
    }

    /// Create a new independently owned handle to the underlying socket.
    pub fn try_clone(&self) -> Result<Self> {
        Ok(Self {
            socket: self.socket.try_clone()?,
        })
    }

    /// Accept a new incoming connection from this listener.
    pub fn accept(&self) -> Result<(VsockStream, VsockAddr)> {
        let mut vsock_addr = sockaddr_vm {
            svm_family: AF_VSOCK as sa_family_t,
            svm_reserved1: 0,
            svm_port: 0,
            svm_cid: 0,
            svm_zero: [0u8; 4],
        };
        let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
        let socket = unsafe {
            accept4(
                self.socket.as_raw_fd(),
                &mut vsock_addr as *mut _ as *mut sockaddr,
                &mut vsock_addr_len,
                SOCK_CLOEXEC,
            )
        };
        if socket < 0 {
            Err(Error::last_os_error())
        } else {
            Ok((
                unsafe { VsockStream::from_raw_fd(socket as RawFd) },
                VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port),
            ))
        }
    }

    /// An iterator over the connections being received on this listener.
    pub fn incoming(&self) -> Incoming {
        Incoming { listener: self }
    }

    /// Retrieve the latest error associated with the underlying socket.
    pub fn take_error(&self) -> Result<Option<Error>> {
        let error = SocketError.get(&self.socket)?;
        Ok(if error == 0 {
            None
        } else {
            Some(Error::from_raw_os_error(error))
        })
    }

    /// Move this stream in and out of nonblocking mode.
    pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
        let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
        if unsafe { ioctl(self.socket.as_raw_fd(), FIONBIO, &mut nonblocking) } < 0 {
            Err(Error::last_os_error())
        } else {
            Ok(())
        }
    }
}

impl AsRawFd for VsockListener {
    /// Expose the listening socket's descriptor; ownership stays with the listener.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.socket)
    }
}

impl AsFd for VsockListener {
    /// Borrow the listening socket's descriptor for the lifetime of `&self`.
    fn as_fd(&self) -> BorrowedFd<'_> {
        AsFd::as_fd(&self.socket)
    }
}

impl FromRawFd for VsockListener {
    /// Wrap an existing descriptor as a listener.
    ///
    /// # Safety
    ///
    /// `socket` must be an open descriptor that the caller owns and does not close
    /// elsewhere; ownership moves into the returned listener's `OwnedFd`.
    unsafe fn from_raw_fd(socket: RawFd) -> Self {
        let owned = OwnedFd::from_raw_fd(socket);
        Self { socket: owned }
    }
}

impl IntoRawFd for VsockListener {
    /// Give up ownership of the descriptor; the caller becomes responsible for closing it.
    fn into_raw_fd(self) -> RawFd {
        let Self { socket } = self;
        socket.into_raw_fd()
    }
}

/// A virtio stream between a local and a remote socket.
#[derive(Debug)]
pub struct VsockStream {
    socket: OwnedFd,
}

impl VsockStream {
    /// Open a connection to a remote host.
    pub fn connect(addr: &impl SockaddrLike) -> Result<Self> {
        if addr.family() != Some(AddressFamily::Vsock) {
            return Err(Error::new(
                ErrorKind::Other,
                "requires a virtio socket address",
            ));
        }

        let socket = new_socket()?;
        connect(socket.as_raw_fd(), addr)?;
        Ok(Self { socket })
    }

    /// Open a connection to a remote host with specified cid and port.
    pub fn connect_with_cid_port(cid: u32, port: u32) -> Result<Self> {
        Self::connect(&VsockAddr::new(cid, port))
    }

    /// Virtio socket address of the remote peer associated with this connection.
    pub fn peer_addr(&self) -> Result<VsockAddr> {
        Ok(getpeername(self.socket.as_raw_fd())?)
    }

    /// Virtio socket address of the local address associated with this connection.
    pub fn local_addr(&self) -> Result<VsockAddr> {
        Ok(getsockname(self.socket.as_raw_fd())?)
    }

    /// Shutdown the read, write, or both halves of this connection.
    pub fn shutdown(&self, how: Shutdown) -> Result<()> {
        let how = match how {
            Shutdown::Write => socket::Shutdown::Write,
            Shutdown::Read => socket::Shutdown::Read,
            Shutdown::Both => socket::Shutdown::Both,
        };
        Ok(shutdown(self.socket.as_raw_fd(), how)?)
    }

    /// Create a new independently owned handle to the underlying socket.
    pub fn try_clone(&self) -> Result<Self> {
        Ok(Self {
            socket: self.socket.try_clone()?,
        })
    }

    /// Set the timeout on read operations.
    pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
        let timeout = Self::timeval_from_duration(dur)?.into();
        Ok(ReceiveTimeout.set(&self.socket, &timeout)?)
    }

    /// Set the timeout on write operations.
    pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
        let timeout = Self::timeval_from_duration(dur)?.into();
        Ok(SendTimeout.set(&self.socket, &timeout)?)
    }

    /// Retrieve the latest error associated with the underlying socket.
    pub fn take_error(&self) -> Result<Option<Error>> {
        let error = SocketError.get(&self.socket)?;
        Ok(if error == 0 {
            None
        } else {
            Some(Error::from_raw_os_error(error))
        })
    }

    /// Move this stream in and out of nonblocking mode.
    pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
        let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
        if unsafe { ioctl(self.socket.as_raw_fd(), FIONBIO, &mut nonblocking) } < 0 {
            Err(Error::last_os_error())
        } else {
            Ok(())
        }
    }

    fn timeval_from_duration(dur: Option<Duration>) -> Result<timeval> {
        match dur {
            Some(dur) => {
                if dur.as_secs() == 0 && dur.subsec_nanos() == 0 {
                    return Err(Error::new(
                        ErrorKind::InvalidInput,
                        "cannot set a zero duration timeout",
                    ));
                }

                // https://github.com/rust-lang/libc/issues/1848
                #[cfg_attr(target_env = "musl", allow(deprecated))]
                let secs = if dur.as_secs() > libc::time_t::max_value() as u64 {
                    libc::time_t::max_value()
                } else {
                    dur.as_secs() as libc::time_t
                };
                let mut timeout = timeval {
                    tv_sec: secs,
                    tv_usec: i64::from(dur.subsec_micros()) as suseconds_t,
                };
                if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
                    timeout.tv_usec = 1;
                }
                Ok(timeout)
            }
            None => Ok(timeval {
                tv_sec: 0,
                tv_usec: 0,
            }),
        }
    }
}

impl Read for VsockStream {
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        <&Self>::read(&mut &*self, buf)
    }
}

impl Write for VsockStream {
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        <&Self>::write(&mut &*self, buf)
    }

    fn flush(&mut self) -> Result<()> {
        Ok(())
    }
}

impl Read for &VsockStream {
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        Ok(recv(self.socket.as_raw_fd(), buf, MsgFlags::empty())?)
    }
}

impl Write for &VsockStream {
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        Ok(send(self.socket.as_raw_fd(), buf, MsgFlags::MSG_NOSIGNAL)?)
    }

    fn flush(&mut self) -> Result<()> {
        Ok(())
    }
}

impl AsRawFd for VsockStream {
    /// Expose the stream's descriptor; ownership stays with the stream.
    fn as_raw_fd(&self) -> RawFd {
        let Self { socket } = self;
        socket.as_raw_fd()
    }
}

impl AsFd for VsockStream {
    /// Borrow the stream's descriptor for the lifetime of `&self`.
    fn as_fd(&self) -> BorrowedFd<'_> {
        let Self { socket } = self;
        socket.as_fd()
    }
}

impl FromRawFd for VsockStream {
    /// Wrap an existing descriptor as a stream.
    ///
    /// # Safety
    ///
    /// `socket` must be an open descriptor that the caller owns and does not close
    /// elsewhere; ownership moves into the returned stream's `OwnedFd`.
    unsafe fn from_raw_fd(socket: RawFd) -> Self {
        let socket = OwnedFd::from_raw_fd(socket);
        Self { socket }
    }
}

impl IntoRawFd for VsockStream {
    /// Give up ownership of the descriptor; the caller becomes responsible for closing it.
    fn into_raw_fd(self) -> RawFd {
        IntoRawFd::into_raw_fd(self.socket)
    }
}

/// Request number passed to the local-CID ioctl; `0x7b9` is `(7 << 8) | 0xb9`.
const IOCTL_VM_SOCKETS_GET_LOCAL_CID: usize = 0x7b9;
ioctl_read_bad!(
    /// Raw wrapper (generated by nix's `ioctl_read_bad!`) that issues the
    /// `IOCTL_VM_SOCKETS_GET_LOCAL_CID` request on `fd`, reading into a `u32`.
    vm_sockets_get_local_cid,
    IOCTL_VM_SOCKETS_GET_LOCAL_CID,
    u32
);

/// Gets the CID of the local machine.
///
/// Note that when calling [`VsockListener::bind`], you should generally use [`VMADDR_CID_ANY`]
/// instead, and for making a loopback connection you should use [`VMADDR_CID_LOCAL`].
///
/// # Errors
///
/// Fails if `/dev/vsock` cannot be opened or if the ioctl on it returns an error.
pub fn get_local_cid() -> Result<u32> {
    let device = File::open("/dev/vsock")?;
    let mut local_cid: u32 = 0;
    // SAFETY: the kernel only modifies the given u32 integer.
    unsafe { vm_sockets_get_local_cid(device.as_raw_fd(), &mut local_cid)? };
    Ok(local_cid)
}