// When the `jemalloc` feature is enabled, every heap allocation in the
// process goes through jemalloc instead of the system allocator.
// The allocator type is named by its full crate path, so no `use` item is needed.
#[cfg(feature = "jemalloc")]
#[global_allocator]
static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
use bytes::Bytes;
use parking_lot::{Mutex, MutexGuard};
use std::ffi::{CStr, CString};
use std::io::{Error, ErrorKind, Result};
use std::mem;
use std::net::{IpAddr, SocketAddr};
use std::ptr;
use std::sync::Arc;
/// Raw FFI bindings to the C resolver library.
///
/// The bindings are included from `bindings.rs` in the build output directory
/// (`OUT_DIR`); presumably they are generated by the build script.
mod c {
    // The generated items keep their C names, so Rust's naming lints are silenced.
    #![allow(non_snake_case, non_camel_case_types, non_upper_case_globals)]
    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}
/// Resolver state returned by the C library (`lkr_state`), re-exported for callers.
pub use self::c::lkr_state as State;
/// Shared handle to a C resolver context (`lkr_context`).
///
/// The raw pointer lives behind a mutex, and the wrapper methods lock it before
/// handing the pointer to C. Obtain one through `Context::new` or
/// `Context::with_cache`, which return it inside an `Arc`.
pub struct Context {
    inner: Mutex<*mut c::lkr_context>,
}

// SAFETY: shared access to the raw pointer is serialized by the `inner` mutex.
unsafe impl Sync for Context {}

// SAFETY: NOTE(review) — this assumes the C context may be used and freed from
// any thread, provided access is serialized; confirm with the C library.
unsafe impl Send for Context {}
impl Context {
    /// Creates a resolver context via `lkr_context_new`.
    ///
    /// The context is returned inside an [`Arc`] so that it can be shared by
    /// any number of [`Request`]s.
    pub fn new() -> Arc<Self> {
        // SAFETY: `lkr_context_new` takes no arguments. Ownership of the returned
        // pointer passes to the new `Context`, whose `Drop` impl frees it.
        let inner = unsafe { c::lkr_context_new() };
        Arc::new(Self {
            inner: Mutex::new(inner),
        })
    }

    /// Creates a resolver context with an on-disk cache.
    ///
    /// `path` and `max_bytes` are passed to `lkr_cache_open`. On success the
    /// `cache` module is loaded into the context.
    ///
    /// # Errors
    ///
    /// - `ErrorKind::InvalidInput` if `path` contains an interior NUL byte.
    /// - `ErrorKind::Other` if `lkr_cache_open` reports a failure.
    /// - `ErrorKind::NotFound` if the `cache` module fails to load.
    ///
    /// On every error path the partially initialized C context is freed.
    pub fn with_cache(path: &str, max_bytes: usize) -> Result<Arc<Self>> {
        // Validate the path before any C resources exist, and report a bad path
        // as an error instead of panicking.
        let path_c = CString::new(path)?;
        let cache_c = CStr::from_bytes_with_nul(b"cache\0")
            .expect("literal is NUL-terminated and has no interior NUL");
        // Wrap the C context immediately. Any early return below drops
        // `context`, which frees the C context through `Drop` instead of leaking it.
        let context = Self::new();
        {
            let inner = context.locked();
            // SAFETY: `*inner` is the pointer owned by `context`, and `path_c` is a
            // NUL-terminated string that outlives the call.
            if unsafe { c::lkr_cache_open(*inner, path_c.as_ptr(), max_bytes) } != 0 {
                return Err(Error::new(ErrorKind::Other, "failed to open cache"));
            }
            // SAFETY: same pointer as above; `cache_c` is a static NUL-terminated string.
            if unsafe { c::lkr_module_load(*inner, cache_c.as_ptr()) } != 0 {
                return Err(Error::new(ErrorKind::NotFound, "failed to load module"));
            }
        }
        Ok(context)
    }

    /// Loads the named module into the context via `lkr_module_load`.
    ///
    /// # Errors
    ///
    /// - `ErrorKind::InvalidInput` if `name` contains an interior NUL byte.
    /// - `ErrorKind::NotFound` if the C call reports a failure.
    pub fn add_module(&self, name: &str) -> Result<()> {
        let name_c = CString::new(name)?;
        let inner = self.locked();
        // SAFETY: `*inner` is this context's pointer, held under its mutex;
        // `name_c` outlives the call.
        let res = unsafe { c::lkr_module_load(*inner, name_c.as_ptr()) };
        if res != 0 {
            return Err(Error::new(ErrorKind::NotFound, "failed to load module"));
        }
        Ok(())
    }

    /// Unloads the named module from the context via `lkr_module_unload`.
    ///
    /// # Errors
    ///
    /// - `ErrorKind::InvalidInput` if `name` contains an interior NUL byte.
    /// - `ErrorKind::NotFound` if the C call reports a failure.
    pub fn remove_module(&self, name: &str) -> Result<()> {
        let name_c = CString::new(name)?;
        let inner = self.locked();
        // SAFETY: `*inner` is this context's pointer, held under its mutex;
        // `name_c` outlives the call.
        let res = unsafe { c::lkr_module_unload(*inner, name_c.as_ptr()) };
        if res != 0 {
            return Err(Error::new(ErrorKind::NotFound, "failed to unload module"));
        }
        Ok(())
    }

    /// Adds a root server hint via `lkr_root_hint`.
    ///
    /// The address is passed as its raw octets: 4 bytes for IPv4, 16 for IPv6.
    ///
    /// # Errors
    ///
    /// `ErrorKind::InvalidInput` if the C call reports a failure.
    pub fn add_root_hint(&self, addr: IpAddr) -> Result<()> {
        let octets = match addr {
            IpAddr::V4(ip) => ip.octets().to_vec(),
            IpAddr::V6(ip) => ip.octets().to_vec(),
        };
        let inner = self.locked();
        // SAFETY: `*inner` is this context's pointer, held under its mutex;
        // the pointer/length pair describes `octets`, which outlives the call.
        let res = unsafe { c::lkr_root_hint(*inner, octets.as_ptr(), octets.len()) };
        if res != 0 {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "failed to add a root hint",
            ));
        }
        Ok(())
    }

    /// Adds a trust anchor from raw RDATA bytes via `lkr_trust_anchor`.
    ///
    /// # Errors
    ///
    /// `ErrorKind::InvalidInput` if the C call reports a failure.
    pub fn add_trust_anchor(&self, rdata: &[u8]) -> Result<()> {
        let inner = self.locked();
        // SAFETY: `*inner` is this context's pointer, held under its mutex;
        // the pointer/length pair describes `rdata`, which outlives the call.
        let res = unsafe { c::lkr_trust_anchor(*inner, rdata.as_ptr(), rdata.len()) };
        if res != 0 {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "failed to add trust anchor",
            ));
        }
        Ok(())
    }

    /// Turns verbose output of the C library on or off via `lkr_verbose`.
    pub fn set_verbose(&self, val: bool) {
        let inner = self.locked();
        // SAFETY: `*inner` is this context's pointer, held under its mutex.
        unsafe {
            c::lkr_verbose(*inner, val);
        }
    }

    /// Locks the context mutex and returns a guard over the raw C pointer.
    fn locked(&self) -> MutexGuard<'_, *mut c::lkr_context> {
        self.inner.lock()
    }
}
impl Drop for Context {
    /// Releases the C context, if one was allocated.
    fn drop(&mut self) {
        let context_ptr = self.locked();
        // Nothing to free if `lkr_context_new` returned a null pointer.
        if context_ptr.is_null() {
            return;
        }
        // SAFETY: the pointer came from `lkr_context_new` and is freed exactly once
        // here. Every `Request` holds an `Arc<Context>`, so this only runs after
        // all requests using the context have been dropped.
        unsafe { c::lkr_context_free(*context_ptr) };
    }
}
/// A single resolution request (`lkr_request`) created from a shared [`Context`].
///
/// The request keeps its context alive through an `Arc` and guards its own raw
/// pointer with a mutex.
pub struct Request {
    context: Arc<Context>,
    inner: Mutex<*mut c::lkr_request>,
}

// SAFETY: shared access to the raw request pointer is serialized by the `inner`
// mutex, and the embedded `Context` handle is itself `Sync`.
unsafe impl Sync for Request {}

// SAFETY: NOTE(review) — this assumes a C request may be used and freed from a
// thread other than the one that created it; confirm with the C library.
unsafe impl Send for Request {}
impl Request {
    /// Size of the buffer handed to `lkr_produce` for the outgoing query.
    const QUERY_BUFFER_LEN: usize = 512;
    /// Number of address slots handed to `lkr_produce`.
    const MAX_TARGETS: usize = 4;
    /// Upper bound on consecutive `lkr_produce` calls while it keeps returning
    /// `State::PRODUCE`.
    const MAX_PRODUCE_CALLS: usize = 7;

    /// Creates a request for `context` via `lkr_request_new`.
    ///
    /// The `Arc` is stored in the request, so the context outlives it.
    pub fn new(context: Arc<Context>) -> Self {
        // SAFETY: the context pointer belongs to a live `Context` that this request
        // keeps alive. The context guard is a temporary held for this statement only.
        let inner = unsafe { c::lkr_request_new(*context.locked()) };
        Self {
            context,
            inner: Mutex::new(inner),
        }
    }

    /// Feeds a DNS message `msg` received from `from` into the request and
    /// returns the resulting state from `lkr_consume`.
    pub fn consume(&self, msg: &[u8], from: SocketAddr) -> State {
        let from = socket2::SockAddr::from(from);
        let (_, inner) = self.locked();
        // SAFETY: `*inner` is held under the request mutex; `from` and `msg` outlive
        // the call, and `msg` is passed with its exact length.
        unsafe { c::lkr_consume(*inner, from.as_ptr() as *const _, msg.as_ptr(), msg.len()) }
    }

    /// Asks the resolver for the next outgoing query.
    ///
    /// Calls `lkr_produce` up to `MAX_PRODUCE_CALLS` times while it returns
    /// `State::PRODUCE`. If it then returns `State::CONSUME`, the function returns
    /// the query bytes and the target addresses. Only IPv4 and IPv6 addresses are
    /// returned; other families are skipped. Any other outcome yields `None`.
    ///
    /// # Panics
    ///
    /// Panics if the C library reports more bytes than the buffer holds.
    pub fn produce(&self) -> Option<(Bytes, Vec<SocketAddr>)> {
        let mut msg = vec![0u8; Self::QUERY_BUFFER_LEN];
        let mut sa_vec: Vec<*mut c::sockaddr> = vec![ptr::null_mut(); Self::MAX_TARGETS];
        let mut size = 0;
        let mut state = State::PRODUCE;
        let (_, inner) = self.locked();
        for _ in 0..Self::MAX_PRODUCE_CALLS {
            // `size` goes in as the usable buffer length: the vector's initialized
            // length, not its capacity. It presumably comes back as the number of
            // bytes written — TODO confirm against the C API.
            size = msg.len();
            // SAFETY: `*inner` is held under the request mutex. Both buffers are
            // alive, not otherwise borrowed, and passed with their lengths.
            state = unsafe {
                c::lkr_produce(
                    *inner,
                    sa_vec.as_mut_ptr() as *mut _,
                    sa_vec.len(),
                    msg.as_mut_ptr() as *mut _,
                    &mut size,
                    false,
                )
            };
            if state != State::PRODUCE {
                break;
            }
        }
        match state {
            State::DONE => None,
            State::CONSUME => {
                assert!(
                    size <= msg.len(),
                    "lkr_produce reported {} bytes for a {}-byte buffer",
                    size,
                    msg.len()
                );
                // Reuse the buffer instead of copying the message out of it.
                msg.truncate(size);
                // The address list is terminated by the first null slot. The
                // conversion happens while the request lock is still held.
                let addresses: Vec<SocketAddr> = sa_vec
                    .iter()
                    .take_while(|sa| !sa.is_null())
                    // SAFETY: non-null slots were filled in by `lkr_produce`.
                    .filter_map(|&sa| unsafe { Self::sockaddr_to_std(sa) })
                    .collect();
                Some((Bytes::from(msg), addresses))
            }
            _ => None,
        }
    }

    /// Finalizes the request with `state` and returns the answer bytes.
    ///
    /// `lkr_finish` reports the answer length, and `lkr_write_answer` fills a
    /// buffer of that size. This consumes the request; its C handle is released
    /// when `self` is dropped at the end of the call.
    pub fn finish(self, state: State) -> Result<Bytes> {
        let (_, inner) = self.locked();
        // SAFETY: `*inner` is held under the request mutex.
        let answer_len = unsafe { c::lkr_finish(*inner, state) };
        // Zero-initialize, so the returned bytes are always initialized even if the
        // C side writes fewer than `answer_len` bytes.
        let mut answer = vec![0u8; answer_len];
        // SAFETY: `answer` is a live, exclusively borrowed buffer of exactly
        // `answer.len()` bytes.
        unsafe {
            c::lkr_write_answer(*inner, answer.as_mut_ptr(), answer.len());
        }
        Ok(Bytes::from(answer))
    }

    /// Converts an address written by `lkr_produce` into a [`SocketAddr`].
    ///
    /// Returns `None` for address families other than IPv4 and IPv6.
    ///
    /// # Safety
    ///
    /// `sa` must be non-null and point to a valid `sockaddr` whose length
    /// `lkr_sockaddr_len` reports correctly.
    unsafe fn sockaddr_to_std(sa: *mut c::sockaddr) -> Option<SocketAddr> {
        let addr = socket2::SockAddr::from_raw_parts(sa as *const _, c::lkr_sockaddr_len(sa) as u32);
        addr.as_inet()
            .map(SocketAddr::V4)
            .or_else(|| addr.as_inet6().map(SocketAddr::V6))
    }

    /// Locks the context mutex and then the request mutex, returning both guards.
    ///
    /// Callers in this impl bind the context guard to `_`, which drops it at the
    /// end of the `let` statement. Only the request mutex is held across the C calls.
    /// NOTE(review): confirm that the C library tolerates concurrent use of a
    /// shared context by different requests.
    fn locked(
        &self,
    ) -> (
        MutexGuard<'_, *mut c::lkr_context>,
        MutexGuard<'_, *mut c::lkr_request>,
    ) {
        (self.context.locked(), self.inner.lock())
    }
}
impl Drop for Request {
    /// Releases the C request, if one was allocated.
    fn drop(&mut self) {
        // Only the request guard is kept. The context guard stays in the temporary
        // tuple and is released at the end of this statement.
        let request_ptr = self.locked().1;
        // Nothing to free if `lkr_request_new` returned a null pointer.
        if request_ptr.is_null() {
            return;
        }
        // SAFETY: the pointer came from `lkr_request_new` and is freed exactly once
        // here. `self.context` (an `Arc`) still keeps the context alive.
        unsafe { c::lkr_request_free(*request_ptr) };
    }
}
#[cfg(test)]
mod tests {
    use super::{Context, Request, State};
    use dnssector::constants::*;
    use dnssector::synth::gen;
    use dnssector::{DNSSector, Section};
    use std::net::SocketAddr;

    /// Two requests created from one context get distinct C request handles.
    #[test]
    fn context_create() {
        let ctx = Context::new();
        let first = Request::new(ctx.clone());
        let second = Request::new(ctx.clone());
        // The context guard must stay unbound (`_`), or the second lock would block.
        let (_, first_ptr) = first.locked();
        let (_, second_ptr) = second.locked();
        assert_ne!(*first_ptr, *second_ptr);
    }

    /// A context with a small cache in the current directory can be created.
    #[test]
    fn context_create_cached() {
        let cached = Context::with_cache(".", 64 * 1024);
        assert!(cached.is_ok());
    }

    /// Both IPv4 and IPv6 root hints are accepted.
    #[test]
    fn context_root_hints() {
        let ctx = Context::new();
        for hint in &["127.0.0.1", "::1"] {
            assert!(ctx.add_root_hint(hint.parse().unwrap()).is_ok());
        }
    }

    /// The `iterate` module can be loaded and then unloaded again.
    #[test]
    fn context_with_module() {
        let ctx = Context::new();
        let module = "iterate";
        assert!(ctx.add_module(module).is_ok());
        assert!(ctx.remove_module(module).is_ok());
    }

    /// A DS record's RDATA is accepted as a trust anchor.
    #[test]
    fn context_trust_anchor() {
        let ctx = Context::new();
        let anchor = gen::RR::from_string(
            ". 0 IN DS 20326 8 2 E06D44B80B8F1D39A95C0B0D7C65D08458E880409BBC683457104237C7F8EC8D",
        )
        .unwrap();
        assert!(ctx.add_trust_anchor(anchor.rdata()).is_ok());
    }

    /// Verbose output can be toggled on and off.
    #[test]
    fn context_verbose() {
        let ctx = Context::new();
        for &flag in &[true, false] {
            ctx.set_verbose(flag);
        }
    }

    /// A root NS query runs through consume/produce/consume/finish and yields a
    /// parseable answer.
    #[test]
    fn request_processing() {
        let ctx = Context::new();
        let request = Request::new(ctx.clone());
        let query = gen::query(
            b".",
            Type::from_string("NS").unwrap(),
            Class::from_string("IN").unwrap(),
        )
        .unwrap();
        let upstream: SocketAddr = "1.1.1.1:53".parse().unwrap();
        request.consume(query.packet(), upstream);

        let state = if let Some((outgoing, targets)) = request.produce() {
            // Turn the outgoing query into a synthetic response from a root server.
            let mut reply = DNSSector::new(outgoing.to_vec()).unwrap().parse().unwrap();
            reply.set_response(true);
            reply
                .insert_rr(
                    Section::Answer,
                    gen::RR::from_string(". 86399 IN NS e.root-servers.net").unwrap(),
                )
                .unwrap();
            reply
                .insert_rr(
                    Section::Additional,
                    gen::RR::from_string("e.root-servers.net 86399 IN A 192.203.230.10").unwrap(),
                )
                .unwrap();
            request.consume(reply.packet(), targets[0])
        } else {
            State::DONE
        };
        assert_eq!(state, State::DONE);

        let answer = request.finish(state).unwrap();
        let parsed = DNSSector::new(answer.to_vec()).unwrap().parse();
        assert!(parsed.is_ok());
    }
}