1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
extern crate libc;
#[macro_use]
extern crate lazy_static;
#[macro_use]
extern crate libnss;

extern crate debug_print;

use libnss::host::{AddressFamily, Addresses, Host, HostHooks};
use libnss::interop::Response;

use rusqlite::{params, Connection, Error, OpenFlags, Result};

use debug_print::debug_eprintln;

use std::net::{IpAddr, Ipv4Addr};

/// Filesystem location of the wiregarden SQLite database consulted by every lookup.
static DBPATH: &str = "/var/lib/wiregarden/db";

/// Flags used for every connection this module opens to `DBPATH`.
///
/// - read-only: lookups never modify the database;
/// - no-follow: SQLite refuses to open the file if the path is a symlink;
/// - private cache: the connection does not share a page cache with others;
/// - full mutex: the connection is opened in SQLite's serialized threading mode.
fn dbflags() -> OpenFlags {
    let mut flags = OpenFlags::SQLITE_OPEN_READ_ONLY;
    flags |= OpenFlags::SQLITE_OPEN_NOFOLLOW;
    flags |= OpenFlags::SQLITE_OPEN_PRIVATE_CACHE;
    flags |= OpenFlags::SQLITE_OPEN_FULL_MUTEX;
    flags
}

/// Field-less type that carries this module's `HostHooks` implementation
/// (see the `impl HostHooks for WiregardenHost` below).
struct WiregardenHost;
// NOTE(review): `libnss_host_hooks!` comes from the libnss crate; it presumably
// generates the exported `_nss_wiregarden_*` host-lookup entry points that
// dispatch to `WiregardenHost` — confirm against the libnss documentation.
libnss_host_hooks!(wiregarden, WiregardenHost);

impl HostHooks for WiregardenHost {
    fn get_all_entries() -> Response<Vec<Host>> {
        match get_all_entries() {
            Ok(hosts) => {
                if hosts.is_empty() {
                    Response::NotFound
                } else {
                    Response::Success(hosts)
                }
            }
            Err(Error::QueryReturnedNoRows) => Response::NotFound,
            Err(_e) => {
                debug_eprintln!("get_all_entries failed: {}", _e);
                Response::Unavail
            }
        }
    }

    fn get_host_by_addr(addr: IpAddr) -> Response<Host> {
        match get_host_by_addr(addr) {
            Ok(Some(host)) => Response::Success(host),
            Ok(None) => Response::NotFound,
            Err(Error::QueryReturnedNoRows) => Response::NotFound,
            Err(_e) => {
                debug_eprintln!("get_host_by_addr {} failed: {}", addr, _e);
                Response::Unavail
            }
        }
    }

    fn get_host_by_name(name: &str, family: AddressFamily) -> Response<Host> {
        if family != AddressFamily::IPv4 {
            Response::NotFound
        } else {
            match get_host_by_name(name) {
                Ok(Some(host)) => Response::Success(host),
                Ok(None) => Response::NotFound,
                Err(Error::QueryReturnedNoRows) => Response::NotFound,
                Err(_e) => {
                    debug_eprintln!("get_host_by_name {} failed: {}", name, _e);
                    Response::Unavail
                }
            }
        }
    }
}

/// Lists every device in the wiregarden database — local interfaces and
/// their peers — as a `device_name.net_name` host with its IPv4 address.
///
/// Rows whose `device_addr` does not parse as an IPv4 address (after any
/// trailing `/prefix` is stripped by `trim_subnet`) are skipped.
///
/// # Errors
/// Returns any rusqlite error raised while opening the database, preparing
/// the query, or reading a row.
fn get_all_entries() -> Result<Vec<Host>> {
    let db = Connection::open_with_flags(&DBPATH, dbflags())?;
    let mut stmt = db.prepare(
        "
select device_name, net_name, device_addr
from iface
union
select p.device_name, i.net_name, p.device_addr
from iface i join peer p on (i.id = p.iface_id)",
    )?;
    // `query_map` is lazy: rows are fetched only as the iterator is consumed,
    // so the rows must be iterated (not just mapped) to produce any hosts.
    let rows = stmt.query_map(params![], |row| {
        let device_name: String = row.get(0)?;
        let net_name: String = row.get(1)?;
        let device_addr: String = row.get(2)?;
        Ok((device_name, net_name, device_addr))
    })?;

    let mut hosts = Vec::new();
    for row in rows {
        let (device_name, net_name, device_addr_s) = row?;
        // Unparsable addresses are skipped rather than failing the whole listing.
        if let Ok(v4addr) = trim_subnet(&device_addr_s).parse::<Ipv4Addr>() {
            hosts.push(Host {
                name: format!("{}.{}", device_name, net_name),
                addresses: Addresses::V4(vec![v4addr]),
                aliases: vec![],
            });
        }
    }
    Ok(hosts)
}

/// Reverse-resolves `addr` to a `device_name.net_name` host.
///
/// A row matches when its `device_addr` is exactly `addr`, or is `addr`
/// followed by a `/prefix` suffix (e.g. `10.0.0.1/24`). This mirrors the
/// forward lookups, which accept both forms via `trim_subnet`. Only IPv4
/// addresses are looked up; any other address yields `Ok(None)` without
/// opening the database.
///
/// # Errors
/// Returns `Error::QueryReturnedNoRows` when no device has that address, and
/// any other rusqlite error raised while opening or querying the database.
fn get_host_by_addr(addr: IpAddr) -> Result<Option<Host>> {
    let v4addr = match addr {
        IpAddr::V4(v4addr) => v4addr,
        // Only IPv4 device addresses are handled by this module.
        _ => return Ok(None),
    };
    let db = Connection::open_with_flags(&DBPATH, dbflags())?;
    let addr_s = v4addr.to_string();
    // `?1` is reused for every comparison, so a single bound value suffices.
    // `||` binds tighter than LIKE, so the pattern is `addr || '/%'`.
    let mut stmt = db.prepare(
        "
select device_name, net_name
from iface
where device_addr = ?1 or device_addr like ?1 || '/%'
union
select p.device_name, i.net_name
from iface i join peer p on (i.id = p.iface_id)
where p.device_addr = ?1 or p.device_addr like ?1 || '/%'",
    )?;
    stmt.query_row(params![addr_s], |row| {
        let device_name: String = row.get(0)?;
        let net_name: String = row.get(1)?;
        Ok(Some(Host {
            name: format!("{}.{}", device_name, net_name),
            addresses: Addresses::V4(vec![v4addr]),
            aliases: vec![],
        }))
    })
}

/// Forward-resolves a `device_name.net_name` host name.
///
/// Both local interfaces and peers are searched. The first matching row's
/// `device_addr` has any `/prefix` removed and is parsed as IPv4. If that
/// parse fails, `Ok(None)` is returned.
///
/// # Errors
/// Returns `Error::QueryReturnedNoRows` when no device has that name, and any
/// other rusqlite error raised while opening or querying the database.
fn get_host_by_name(name: &str) -> Result<Option<Host>> {
    let conn = Connection::open_with_flags(&DBPATH, dbflags())?;
    let sql = "
select device_name, net_name, device_addr
from iface
where device_name || '.' || net_name = ?
union
select p.device_name, i.net_name, p.device_addr
from iface i join peer p on (i.id = p.iface_id)
where p.device_name || '.' || i.net_name = ?";
    let mut query = conn.prepare(sql)?;
    query.query_row(params![name, name], |r| {
        let (device, network, addr_text): (String, String, String) =
            (r.get(0)?, r.get(1)?, r.get(2)?);
        let host = trim_subnet(&addr_text)
            .parse::<Ipv4Addr>()
            .ok()
            .map(|ip| Host {
                name: format!("{}.{}", device, network),
                addresses: Addresses::V4(vec![ip]),
                aliases: vec![],
            });
        Ok(host)
    })
}

/// Removes a trailing `/prefix` (CIDR suffix) from an address string.
///
/// Everything from the last `/` onward is dropped; a string containing no
/// `/` is returned unchanged.
fn trim_subnet(s: &str) -> &str {
    // '/' is ASCII, so the index returned by `rfind` is always a char boundary.
    s.rfind('/').map_or(s, |slash| &s[..slash])
}