1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
use super::{link_message, WireGuardDeviceLinkOperation};
use crate::attr::WgDeviceAttribute;
use crate::cmd::WgCmd;
use crate::consts::{WG_GENL_NAME, WG_GENL_VERSION};
use crate::err::{ConnectError, GetDeviceError, LinkDeviceError, SetDeviceError};
use crate::get;
use crate::set;
use crate::set::create_set_device_messages;
use crate::socket::parse::*;
use crate::socket::NlWgMsgType;
use crate::DeviceInterface;
use libc::IFNAMSIZ;
use neli::consts::{NlFamily, NlmF, Nlmsg};
use neli::genl::Genlmsghdr;
use neli::nl::Nlmsghdr;
use neli::nlattr::Nlattr;
use neli::socket::NlSocket;
use neli::Nl;
use neli::StreamWriteBuffer;

/// Handle for querying and configuring WireGuard devices over netlink.
///
/// Obtain one with `Socket::connect`.
pub struct Socket {
    /// Generic netlink socket used by `get_device` and `set_device`.
    sock: NlSocket,
    /// Netlink message type resolved for the `WG_GENL_NAME` generic netlink family; used as
    /// the message type of every request sent on `sock`.
    family_id: NlWgMsgType,
}

impl Socket {
    pub fn connect() -> Result<Self, ConnectError> {
        let family_id = {
            NlSocket::new(NlFamily::Generic, true)?
                .resolve_genl_family(WG_GENL_NAME)
                .map_err(ConnectError::ResolveFamilyError)?
        };

        let track_seq = true;
        let mut wgsock = NlSocket::new(NlFamily::Generic, track_seq)?;

        // Autoselect a PID
        let pid = None;
        let groups = None;
        wgsock.bind(pid, groups)?;

        Ok(Self {
            sock: wgsock,
            family_id,
        })
    }

    pub fn get_device(
        &mut self,
        interface: DeviceInterface,
    ) -> Result<get::Device, GetDeviceError> {
        let mut mem = StreamWriteBuffer::new_growable(None);
        let attr = match interface {
            DeviceInterface::Name(name) => {
                Some(name.len())
                    .filter(|&len| 0 < len && len < IFNAMSIZ)
                    .ok_or_else(|| GetDeviceError::InvalidInterfaceName)?;
                name.as_ref().serialize(&mut mem)?;
                Nlattr::new(None, WgDeviceAttribute::Ifname, mem.as_ref())?
            }
            DeviceInterface::Index(index) => {
                index.serialize(&mut mem)?;
                Nlattr::new(None, WgDeviceAttribute::Ifindex, mem.as_ref())?
            }
        };
        let genlhdr = {
            let cmd = WgCmd::GetDevice;
            let version = WG_GENL_VERSION;
            let attrs = vec![attr];
            Genlmsghdr::new(cmd, version, attrs)?
        };
        let nlhdr = {
            let size = None;
            let nl_type = self.family_id;
            let flags = vec![NlmF::Request, NlmF::Ack, NlmF::Dump];
            let seq = None;
            let pid = None;
            let payload = genlhdr;
            Nlmsghdr::new(size, nl_type, flags, seq, pid, payload)
        };

        self.sock.send_nl(nlhdr)?;

        // In the future, neli will return multiple Netlink messages. We have to go through each
        // message and coalesce peers in the way described by the WireGuard UAPI when this change
        // happens. For now, parsing is broken if the entire response doesn't fit in a single
        // payload.
        //
        // See: https://github.com/jbaublitz/neli/issues/15

        let mut iter = self
            .sock
            .iter::<Nlmsg, Genlmsghdr<WgCmd, WgDeviceAttribute>>();

        let mut device = None;
        while let Some(Ok(response)) = iter.next() {
            match response.nl_type {
                Nlmsg::Error => return Err(GetDeviceError::AccessError),
                Nlmsg::Done => break,
                _ => (),
            };

            let handle = response.nl_payload.get_attr_handle();
            device = Some(match device {
                Some(device) => extend_device(device, handle)?,
                None => parse_device(handle)?,
            });
        }

        device.ok_or(GetDeviceError::AccessError)
    }

    pub fn set_device(&mut self, device: set::Device) -> Result<(), SetDeviceError> {
        for nl_message in create_set_device_messages(device, self.family_id)? {
            self.sock.send_nl(nl_message)?;
            self.sock.recv_ack()?;
        }

        Ok(())
    }

    pub fn add_device(&self, ifname: &str) -> Result<(), LinkDeviceError> {
        let mut sock = NlSocket::connect(NlFamily::Route, None, None, true)?;
        sock.send_nl(link_message(ifname, WireGuardDeviceLinkOperation::Add)?)?;
        sock.recv_ack()?;
        Ok(())
    }

    pub fn del_device(&self, ifname: &str) -> Result<(), LinkDeviceError> {
        let mut sock = NlSocket::connect(NlFamily::Route, None, None, true)?;
        sock.send_nl(link_message(ifname, WireGuardDeviceLinkOperation::Delete)?)?;
        sock.recv_ack()?;
        Ok(())
    }
}