extern crate libc;
use self::libc::{c_char, c_int};
use std::cmp::Ordering;
use std::collections::HashMap;
use std::ffi::{CString, CStr};
use std::slice;
use std::str::from_utf8;
extern crate rand;
use self::rand::Rng;
use self::rand::distributions::{IndependentSample, Range};
use ctypes::*;
// Raw bindings to the system resolver library (libresolv).
//
// NOTE(review): these declarations appear to diverge from the C prototypes in
// <resolv.h> / <arpa/nameser.h> as commonly documented:
//   * `__res_query`'s `answer` and `ns_sprintrr`'s `buf` are written by the
//     callee, so they would normally be `*mut`;
//   * `ns_initparse`, `ns_parserr` and `ns_sprintrr` return `int` in C, but
//     are declared here without a return value, so their status is dropped;
//   * `ns_sprintrr`'s `buflen` is `size_t` in C, not `int`.
// Confirm against the target libc before relying on these signatures.
#[link(name = "resolv")]
extern {
/// Sends a DNS query for the NUL-terminated `dname` with the given `class`
/// and type (`typef`), storing the raw response in `answer` (`anslen` bytes).
/// Returns the response length, or a negative value on failure (per
/// res_query(3)). Presumably the symbol behind glibc's `res_query` macro.
pub fn __res_query(dname: *const c_char, class: c_int, typef: c_int,
answer: *const u8, anslen: c_int) -> c_int;
/// Parses the first `len` bytes of the raw DNS message in `answer` into `dst`,
/// which is then used as the `msg` argument of `ns_parserr`/`ns_sprintrr`.
pub fn ns_initparse(answer: *const u8, len: c_int, dst: *mut ns_msg);
/// Extracts record number `which` of section `sect` from `msg` into `rr`.
pub fn ns_parserr(msg: *mut ns_msg, sect: ns_sect_q, which: c_int, rr: *mut ns_rr);
/// Formats `rr` in presentation (zone-file) text into `buf` (`buflen` bytes).
/// `b1`/`b2` correspond to C's `name_ctx`/`origin` parameters.
pub fn ns_sprintrr(msg: *mut ns_msg, rr: *mut ns_rr, b1: *const c_char,
b2: *const c_char, buf: *const c_char, buflen: c_int);
}
/// A single SRV answer paired with the IPv4 address of its target host.
///
/// The declaration order of the fields is significant: the derived `Ord`
/// compares `priority` first, then `weight`, `port`, `ip` and finally `ttl`.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct RR {
    /// SRV priority; the smallest value is the preferred group.
    pub priority: u16,
    /// SRV weight used for proportional selection within a priority group.
    pub weight: u16,
    /// TCP/UDP port the service listens on.
    pub port: u16,
    /// IPv4 address of the target host, as four octets in dotted-quad order.
    pub ip: [u8; 4],
    /// TTL of the SRV record, in seconds.
    pub ttl: u32,
}
/// Resolves `name` as an SRV record set through libresolv.
///
/// Each SRV record in the answer section is paired with the IPv4 address of
/// its target host, taken from an A record in the additional section of the
/// *same* response; no follow-up lookup is made. SRV records whose target has
/// no IPv4 address in the additional section are skipped instead of panicking.
///
/// # Errors
///
/// * `Rcode::NameError` if `name` contains an interior NUL byte, if the
///   resolver reports a failed query, or if no usable SRV record remains.
/// * `Rcode::ServerFailure` if an answer record's RDATA is shorter than the
///   6 bytes holding priority, weight and port.
pub fn query_srv(name: &str) -> Result<Vec<RR>, Rcode> {
    const ANS_BUF_LEN: usize = 4096;

    // CString::new rejects interior NUL bytes; such a name cannot be queried.
    let dname = match CString::new(name) {
        Ok(dname) => dname,
        Err(_) => return Err(Rcode::NameError),
    };
    // The resolver writes the raw response into this buffer, so it must be
    // mutable (writing through a pointer to an immutable array is UB).
    let mut ans_buf = [0u8; ANS_BUF_LEN];
    let mut msg = ns_msg { ..Default::default() };
    let mut res = vec![];
    unsafe {
        // SAFETY: `dname` is NUL-terminated and outlives the call; `ans_buf`
        // is writable for ANS_BUF_LEN bytes.
        let len = __res_query(dname.as_ptr(), Class::ANY as c_int, Type::SRV as c_int,
                              ans_buf.as_mut_ptr(), ANS_BUF_LEN as c_int);
        if len < 0 {
            // A negative length means the query failed (see res_query(3)).
            return Err(Rcode::NameError);
        }
        // Clamp so ns_initparse never reads past the end of `ans_buf`, even if
        // the resolver reports a longer response than it could store.
        let len = if len > ANS_BUF_LEN as c_int { ANS_BUF_LEN as c_int } else { len };
        // SAFETY: the first `len` bytes of `ans_buf` belong to the buffer.
        ns_initparse(ans_buf.as_ptr(), len, &mut msg);

        // Additional section (counts[3]): map host name -> IPv4 address.
        let mut host_to_ip: HashMap<String, [u8; 4]> = HashMap::new();
        for i in 0..(msg.counts[3] as c_int) {
            let mut rr = ns_rr { ..Default::default() };
            ns_parserr(&mut msg, ns_sect_q::ns_s_ar, i, &mut rr);
            let text = match sprint_rr(&mut msg, &mut rr) {
                Some(text) => text,
                None => continue,
            };
            let mut fields = text.split_whitespace();
            let host = match fields.next() {
                Some(host) => host,
                None => continue,
            };
            // The last field is the RDATA; only dotted-quad IPv4 values (A
            // records) are usable, anything else (e.g. AAAA) is ignored.
            if let Some(ip) = fields.last().and_then(parse_ipv4) {
                host_to_ip.insert(host.to_string(), ip);
            }
        }

        // Answer section (counts[1]): the SRV records themselves.
        for i in 0..(msg.counts[1] as c_int) {
            let mut rr = ns_rr { ..Default::default() };
            ns_parserr(&mut msg, ns_sect_q::ns_s_an, i, &mut rr);
            if rr.rdlength < 6 {
                return Err(Rcode::ServerFailure);
            }
            // SAFETY: ns_parserr points `rdata` at `rdlength` bytes inside
            // `ans_buf`, which is still alive.
            let rdata = slice::from_raw_parts(rr.rdata, rr.rdlength as usize);
            // SRV RDATA starts with priority, weight, port as big-endian u16s.
            let be16 = |hi: u8, lo: u8| ((hi as u16) << 8) | lo as u16;
            let priority = be16(rdata[0], rdata[1]);
            let weight = be16(rdata[2], rdata[3]);
            let port = be16(rdata[4], rdata[5]);

            // The SRV target is the last field of the presentation form; skip
            // records whose target has no IPv4 address in the additional section.
            let text = match sprint_rr(&mut msg, &mut rr) {
                Some(text) => text,
                None => continue,
            };
            let ip = match text.split_whitespace().last().and_then(|target| host_to_ip.get(target)) {
                Some(ip) => *ip,
                None => continue,
            };
            res.push(RR {
                ip: ip,
                priority: priority,
                weight: weight,
                port: port,
                ttl: rr.ttl,
            });
        }
    }
    if res.is_empty() {
        Err(Rcode::NameError)
    } else {
        Ok(res)
    }
}

/// Formats `rr` in presentation (zone-file) text using `ns_sprintrr`.
///
/// Returns `None` when nothing was written or the text is not valid UTF-8.
///
/// # Safety
///
/// `msg` must have been initialised by `ns_initparse`, and `rr` filled by
/// `ns_parserr` from that same message.
unsafe fn sprint_rr(msg: &mut ns_msg, rr: &mut ns_rr) -> Option<String> {
    const DISP_BUF_LEN: usize = 4096;
    // One extra byte that ns_sprintrr is never told about, so the buffer is
    // always NUL-terminated for CStr::from_ptr below.
    let mut buf = [0 as c_char; DISP_BUF_LEN + 1];
    ns_sprintrr(msg, rr, 0 as *const c_char, 0 as *const c_char,
                buf.as_mut_ptr(), DISP_BUF_LEN as c_int);
    match CStr::from_ptr(buf.as_ptr()).to_str() {
        Ok(text) if !text.is_empty() => Some(text.to_owned()),
        _ => None,
    }
}

/// Parses a dotted-quad IPv4 string such as "10.0.0.1" into its four octets;
/// returns `None` for anything else (e.g. an IPv6 address).
fn parse_ipv4(text: &str) -> Option<[u8; 4]> {
    let octets: Option<Vec<u8>> = text.split('.').map(|o| o.parse::<u8>().ok()).collect();
    match octets {
        Some(ref o) if o.len() == 4 => Some([o[0], o[1], o[2], o[3]]),
        _ => None,
    }
}
/// Resolves the SRV records for `host` and picks one with `srv_chooser`.
///
/// Any resolver error is collapsed into the message "srv lookup failed";
/// errors from `srv_chooser` are passed through unchanged.
///
/// The `&String` parameter is kept (rather than `&str`) so existing uses of
/// this function as a callback over `&String` items keep compiling.
pub fn srv_mapper(host: &String) -> Result<RR, String> {
    match query_srv(host) {
        Ok(records) => srv_chooser(records),
        Err(_) => Err("srv lookup failed".to_string()),
    }
}
/// Picks one SRV record using priority and weight.
///
/// Only records sharing the smallest `priority` value are candidates (lower
/// priority is preferred, as in RFC 2782). Among them, a record is chosen with
/// probability `weight / total_weight`; if every candidate has weight 0 the
/// choice is uniform.
///
/// # Errors
///
/// Returns `Err("no records found")` when `rrs` is empty.
pub fn srv_chooser(rrs: Vec<RR>) -> Result<RR, String> {
    let high_prio = match rrs.iter().map(|r| r.priority).min() {
        Some(prio) => prio,
        None => return Err("no records found".to_string()),
    };
    // The selection below is order-independent, so no shuffle is needed (the
    // previous random comparator was not a valid total order for `sort_by`).
    let choices: Vec<&RR> = rrs.iter().filter(|r| r.priority == high_prio).collect();
    // Sum in u32: adding several u16 weights can overflow u16.
    let total: u32 = choices.iter().map(|r| u32::from(r.weight)).sum();
    let mut rng = rand::thread_rng();

    if total == 0 {
        let idx = Range::new(0, choices.len()).ind_sample(&mut rng);
        return Ok(choices[idx].clone());
    }

    // `sample` is uniform in [0, total). Each record owns the half-open
    // interval [sofar - weight, sofar), so it wins with probability
    // weight / total and weight-0 records are never chosen.
    let sample = Range::new(0, total).ind_sample(&mut rng);
    let mut sofar = 0u32;
    for rr in choices {
        sofar += u32::from(rr.weight);
        if sofar > sample {
            return Ok(rr.clone());
        }
    }
    // Unreachable in practice: sample < total == final `sofar`.
    Err("no srv picked!".to_string())
}
#[test]
fn test_query() {
    // Smoke test against the live system resolver: only prints the outcome
    // and asserts nothing, so it needs a network where this name exists.
    const SRV_NAME: &'static str = "_etcd-server._tcp.etcd-t1.mesos";
    println!("result: {:?}", query_srv(SRV_NAME));
}