use Result;
use Error::ArgumentError;
use std::collections::BTreeMap;
use trust_dns_resolver::Resolver;
/// Scheme prefix of a standard connection string that lists its hosts explicitly.
pub const URI_SCHEME: &str = "mongodb://";
/// Scheme prefix of a connection string whose hosts are discovered through DNS SRV records.
pub const URI_SCHEME_DNS_SEEDLIST: &str = "mongodb+srv://";
/// Port assigned to a host entry that does not specify one.
pub const DEFAULT_PORT: u16 = 27017;
/// A DNS seed list taken from a `mongodb+srv://` connection string.
///
/// Only the name is known after parsing; the member hosts are filled in by
/// [`DNS::discover_hosts`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DNS {
    /// The DNS name whose SRV records are queried.
    pub name: String,
    /// Hosts found by the most recent successful call to `discover_hosts`.
    discovered_hosts: Vec<Host>,
}

impl DNS {
    /// Creates a seed list for `name` with no hosts discovered yet.
    pub fn new<S: Into<String>>(name: S) -> Self {
        Self {
            name: name.into(),
            discovered_hosts: Vec::new(),
        }
    }

    /// Looks up the SRV records for `_mongodb._tcp.<name>` using the system resolver
    /// configuration and records one `Host` (target name + port) per record.
    ///
    /// The previously discovered hosts are *replaced*, not extended, so calling this
    /// repeatedly (e.g. to refresh the seed list) does not accumulate duplicates.
    ///
    /// NOTE(review): `to_utf8()` on an SRV target may keep the FQDN trailing dot
    /// (e.g. `"host.example.com."`) — confirm downstream code tolerates it.
    ///
    /// # Errors
    ///
    /// Propagates failures from reading the resolver configuration or from the SRV
    /// lookup. On error the previously discovered hosts are left untouched.
    pub fn discover_hosts(&mut self) -> Result<()> {
        let query = format!("_mongodb._tcp.{}", self.name);
        let srv_lookup = Resolver::from_system_conf()?.lookup_srv(&query)?;

        let mut hosts = Vec::new();
        for srv in srv_lookup {
            hosts.push(Host::new(srv.target().to_utf8(), srv.port()));
        }
        // Assign only after the lookup fully succeeded.
        self.discovered_hosts = hosts;
        Ok(())
    }
}
/// One server entry from a connection string.
///
/// A TCP entry carries `host_name` and `port` with an empty `ipc`; a Unix-domain
/// socket entry carries the socket path in `ipc`, an empty `host_name`, and
/// `DEFAULT_PORT` as its port.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Host {
    pub host_name: String,
    pub ipc: String,
    pub port: u16,
}

impl Host {
    /// Builds a TCP entry for `host_name:port`.
    fn new(host_name: String, port: u16) -> Host {
        Host {
            host_name,
            ipc: String::new(),
            port,
        }
    }

    /// Builds a socket entry whose path is `ipc`.
    fn with_ipc(ipc: String) -> Host {
        Host {
            ipc,
            host_name: String::new(),
            port: DEFAULT_PORT,
        }
    }

    /// Returns `true` when this entry refers to a socket path rather than a TCP host.
    pub fn has_ipc(&self) -> bool {
        let path = &self.ipc;
        !path.is_empty()
    }
}
/// Options taken from the query portion of a connection string.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub struct ConnectionOptions {
    /// Option values keyed by their raw (case-preserved) key.
    pub options: BTreeMap<String, String>,
    /// Raw values of each `readPreferenceTags` option, in order of appearance.
    pub read_pref_tags: Vec<String>,
}

impl ConnectionOptions {
    /// Bundles an option map and a list of read-preference tag sets.
    pub fn new(
        options: BTreeMap<String, String>,
        read_pref_tags: Vec<String>,
    ) -> ConnectionOptions {
        ConnectionOptions {
            read_pref_tags,
            options,
        }
    }

    /// Looks up the value of option `key` (exact, case-sensitive key match).
    pub fn get(&self, key: &str) -> Option<&String> {
        let map = &self.options;
        map.get(key)
    }
}
/// Where the hosts of a connection string come from: a DNS seed list or an explicit list.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ConnectionProtocol {
    DNS(DNS),
    Hosts(Vec<Host>),
}

impl ConnectionProtocol {
    /// Borrows the host list, whichever variant holds it.
    fn host_list(&self) -> &[Host] {
        match *self {
            ConnectionProtocol::DNS(ref dns) => &dns.discovered_hosts,
            ConnectionProtocol::Hosts(ref hosts) => hosts,
        }
    }

    /// Mutably borrows the host list, whichever variant holds it.
    fn host_list_mut(&mut self) -> &mut [Host] {
        match *self {
            ConnectionProtocol::DNS(ref mut dns) => &mut dns.discovered_hosts,
            ConnectionProtocol::Hosts(ref mut hosts) => hosts,
        }
    }

    /// Number of hosts currently known (for DNS: hosts discovered so far).
    pub fn num_hosts(&self) -> usize {
        self.host_list().len()
    }

    /// Consumes the protocol and yields its hosts by value.
    pub fn into_iter(self) -> impl Iterator<Item=Host> {
        let hosts = match self {
            ConnectionProtocol::DNS(dns) => dns.discovered_hosts,
            ConnectionProtocol::Hosts(hosts) => hosts,
        };
        hosts.into_iter()
    }

    /// Iterates over the hosts by reference.
    pub fn iter(&self) -> impl Iterator<Item=&Host> {
        self.host_list().iter()
    }

    /// Iterates over the hosts by mutable reference.
    pub fn iter_mut(&mut self) -> impl Iterator<Item=&mut Host> {
        self.host_list_mut().iter_mut()
    }
}
/// The components of a parsed MongoDB connection string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionString {
    /// The hosts to connect to, or the DNS seed list that yields them.
    pub hosts: ConnectionProtocol,
    /// The original address text when produced by `parse`; `None` when built directly.
    pub string: Option<String>,
    pub user: Option<String>,
    pub password: Option<String>,
    /// Target database; `"test"` unless the address names another.
    pub database: Option<String>,
    pub collection: Option<String>,
    pub options: Option<ConnectionOptions>,
}

impl ConnectionString {
    /// Builds a connection string for the single TCP host `host_name:port`.
    pub fn new(host_name: &str, port: u16) -> ConnectionString {
        ConnectionString::with_host(Host::new(host_name.to_owned(), port))
    }

    /// Wraps one host in a connection string with no credentials, collection, or
    /// options, targeting the `"test"` database.
    fn with_host(host: Host) -> ConnectionString {
        let hosts = ConnectionProtocol::Hosts(vec![host]);
        ConnectionString {
            hosts,
            database: Some("test".to_owned()),
            string: None,
            user: None,
            password: None,
            collection: None,
            options: None,
        }
    }
}
pub fn parse(address: &str) -> Result<ConnectionString> {
let (dns_seed_discovery, addr) = if address.starts_with(URI_SCHEME_DNS_SEEDLIST) {
(true, &address[URI_SCHEME_DNS_SEEDLIST.len()..])
} else {
if !address.starts_with(URI_SCHEME) {
return Err(ArgumentError(String::from(
"MongoDB connection string must start with 'mongodb://' or 'mongodb+srv://'.",
)));
}
(false, &address[URI_SCHEME.len()..])
};
let mut user: Option<String> = None;
let mut password: Option<String> = None;
let mut database: Option<String> = Some(String::from("test"));
let mut collection: Option<String> = None;
let mut options: Option<ConnectionOptions> = None;
let (host_str, path_str) = if addr.contains(".sock") {
let (host_part, path_part) = rsplit(addr, ".sock");
if path_part.starts_with('/') {
(host_part, &path_part[1..])
} else {
(host_part, path_part)
}
} else {
partition(addr, "/")
};
if path_str.is_empty() && host_str.contains('?') {
return Err(ArgumentError(String::from(
"A '/' is required between the host list and any options.",
)));
}
let hosts = if host_str.contains('@') {
let (user_info, host_str) = rpartition(host_str, "@");
let (u, p) = parse_user_info(user_info)?;
user = Some(String::from(u));
password = Some(String::from(p));
if dns_seed_discovery {
ConnectionProtocol::DNS(parse_dns_addr(host_str)?)
} else {
ConnectionProtocol::Hosts(split_hosts(host_str)?)
}
} else {
if dns_seed_discovery {
ConnectionProtocol::DNS(parse_dns_addr(host_str)?)
} else {
ConnectionProtocol::Hosts(split_hosts(host_str)?)
}
};
let mut opts = "";
if !path_str.is_empty() {
if path_str.starts_with('?') {
opts = &path_str[1..];
} else {
let (dbase, options) = partition(path_str, "?");
let (dbase_new, coll) = partition(dbase, ".");
database = Some(String::from(dbase_new));
collection = Some(String::from(coll));
opts = options;
}
}
if !opts.is_empty() {
options = Some(split_options(opts).unwrap());
}
if dns_seed_discovery {
let mut conn_options = options.take().unwrap_or_default();
if conn_options.get("ssl").is_none() {
conn_options.options.insert("ssl".to_owned(), "true".to_owned());
}
options = Some(conn_options);
}
Ok(ConnectionString {
hosts: hosts,
string: Some(String::from(address)),
user: user,
password: password,
database: database,
collection: collection,
options: options,
})
}
/// Validates the host portion of a `mongodb+srv://` connection string and wraps it
/// in a `DNS` seed list.
///
/// The name must contain at least three dot-separated parts (e.g.
/// `cluster.example.com`), must not carry a port, and must name exactly one host.
///
/// # Errors
///
/// Returns `ArgumentError` when any of those requirements is violated.
fn parse_dns_addr(dns_str: &str) -> Result<DNS> {
    // Count labels directly; no need to collect them into a Vec first.
    if dns_str.split('.').count() < 3 {
        return Err(ArgumentError(String::from(
            "DNS must consists of a least a hostname, a domain name and a TLD",
        )));
    }
    if dns_str.contains(':') {
        return Err(ArgumentError(String::from(
            "Connection string using DNS MUST not contains a port",
        )));
    }
    if dns_str.contains(',') {
        return Err(ArgumentError(String::from(
            "Connection string using DNS can't contains more than one host name",
        )));
    }
    Ok(DNS::new(dns_str))
}
/// Splits the user-info section (text before the last '@') into `(user, password)`.
///
/// The split happens at the last ':'; without a ':' the password is the empty string.
///
/// # Errors
///
/// Returns `ArgumentError` if an unescaped ':' or '@' remains in the user name or
/// user info, or if the user name is empty.
fn parse_user_info(user_info: &str) -> Result<(&str, &str)> {
    let (name, secret) = rpartition(user_info, ":");

    let has_reserved_char = name.contains(':') || user_info.contains('@');
    if has_reserved_char {
        return Err(ArgumentError(String::from(
            "':' or '@' characters in a username or password must be escaped according to RFC 2396.",
        )));
    }

    match name.is_empty() {
        true => Err(ArgumentError(
            String::from("The empty string is not a valid username."),
        )),
        false => Ok((name, secret)),
    }
}
/// Parses a bracketed IPv6 host entry such as `[::1]` or `[::1]:27017`.
///
/// The address between the brackets is lowercased. The port defaults to
/// `DEFAULT_PORT` when none is given.
///
/// # Errors
///
/// Returns `ArgumentError` in these cases:
/// * the entry is not of the form `[address]` or `[address]:port`: it lacks the
///   leading '[' or the closing ']', or something other than `:port` follows ']';
/// * the port is not a valid `u16`.
fn parse_ipv6_literal_host(entity: &str) -> Result<Host> {
    let close = match entity.find(']') {
        Some(idx) if entity.starts_with('[') => idx,
        _ => {
            return Err(ArgumentError(String::from(
                "An IPv6 address must be enclosed in '[' and ']' according to RFC 2732.",
            )))
        }
    };

    // Slice strictly between the brackets so neither one leaks into the host name.
    let address = entity[1..close].to_ascii_lowercase();
    let rest = &entity[close + 1..];

    if rest.is_empty() {
        return Ok(Host::new(address, DEFAULT_PORT));
    }
    if !rest.starts_with(':') {
        return Err(ArgumentError(String::from(
            "An IPv6 address must be enclosed in '[' and ']' according to RFC 2732.",
        )));
    }
    match rest[1..].parse::<u16>() {
        Ok(port) => Ok(Host::new(address, port)),
        Err(_) => Err(ArgumentError(String::from("Port must be an integer."))),
    }
}
/// Parses a single entry of a connection string's host list.
///
/// Accepted forms:
/// * `[ipv6]` or `[ipv6]:port` (delegated to `parse_ipv6_literal_host`);
/// * a Unix-domain socket path ending in `.sock`, kept verbatim because
///   file-system paths are case-sensitive;
/// * `host:port`;
/// * bare `host`, which gets `DEFAULT_PORT`.
///
/// TCP host names are lowercased.
///
/// # Errors
///
/// Returns `ArgumentError` in these cases:
/// * an unbracketed entry contains more than one ':';
/// * the port is not a valid `u16`;
/// * a bracketed IPv6 entry is malformed.
pub fn parse_host(entity: &str) -> Result<Host> {
    if entity.starts_with('[') {
        return parse_ipv6_literal_host(entity);
    }
    // Check for a socket path before looking for ':' so a path is never split as host:port.
    if entity.ends_with(".sock") {
        return Ok(Host::with_ipc(String::from(entity)));
    }
    if !entity.contains(':') {
        return Ok(Host::new(entity.to_ascii_lowercase(), DEFAULT_PORT));
    }

    let (host, port) = partition(entity, ":");
    if port.contains(':') {
        return Err(ArgumentError(String::from(
            "Reserved characters such as ':' must be escaped according to RFC 2396. \
             An IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732.",
        )));
    }
    match port.parse::<u16>() {
        Ok(val) => Ok(Host::new(host.to_ascii_lowercase(), val)),
        Err(_) => Err(ArgumentError(
            String::from("Port must be an unsigned integer."),
        )),
    }
}
/// Parses a comma-separated host list into `Host` entries, preserving order.
///
/// # Errors
///
/// Stops at the first failure: an empty entry (e.g. a doubled or trailing comma)
/// yields `ArgumentError`, and any error from `parse_host` is returned as-is.
fn split_hosts(host_str: &str) -> Result<Vec<Host>> {
    host_str
        .split(',')
        .map(|entity| {
            if entity.is_empty() {
                Err(ArgumentError(
                    String::from("Empty host, or extra comma in host list."),
                ))
            } else {
                parse_host(entity)
            }
        })
        .collect()
}
/// Splits `opts` on `delim` (or treats it as one option when `delim` is `None`) and
/// builds a `ConnectionOptions` from the resulting `key=value` pairs.
///
/// Every `readPreferenceTags` value (key matched ASCII case-insensitively) is appended
/// to `read_pref_tags`. All other pairs go into the option map under their raw key,
/// and a later duplicate key overwrites an earlier one. A pair without '=' is stored
/// with an empty value.
fn parse_options(opts: &str, delim: Option<&str>) -> ConnectionOptions {
    let pairs: Vec<&str> = if let Some(sep) = delim {
        opts.split(sep).collect()
    } else {
        vec![opts]
    };

    let mut tag_sets = Vec::new();
    let mut map = BTreeMap::new();
    for pair in pairs {
        let (key, value) = partition(pair, "=");
        if key.eq_ignore_ascii_case("readpreferencetags") {
            tag_sets.push(value.to_owned());
        } else {
            map.insert(key.to_owned(), value.to_owned());
        }
    }
    ConnectionOptions::new(map, tag_sets)
}
/// Picks the option separator ('&' or ';') used in `opts` and parses the options.
///
/// # Errors
///
/// Returns `ArgumentError` in these cases:
/// * both separators appear;
/// * no separator appears and the single option has no '='.
fn split_options(opts: &str) -> Result<ConnectionOptions> {
    let uses_amp = opts.contains('&');
    let uses_semi = opts.contains(';');

    let delim = match (uses_amp, uses_semi) {
        (true, true) => {
            return Err(ArgumentError(String::from(
                "Cannot mix '&' and ';' for option separators.",
            )))
        }
        (true, false) => Some("&"),
        (false, true) => Some(";"),
        (false, false) => {
            if !opts.contains('=') {
                return Err(ArgumentError(String::from(
                    "InvalidURI: MongoDB URI options are key=value pairs.",
                )));
            }
            None
        }
    };

    Ok(parse_options(opts, delim))
}
/// Splits `string` around the first occurrence of `sep`, dropping the separator.
///
/// Returns `(string, "")` when `sep` does not occur.
fn partition<'a>(string: &'a str, sep: &str) -> (&'a str, &'a str) {
    if let Some(pos) = string.find(sep) {
        let (head, tail) = string.split_at(pos);
        (head, &tail[sep.len()..])
    } else {
        (string, "")
    }
}
/// Splits `string` around the last occurrence of `sep`, dropping the separator.
///
/// Returns `(string, "")` when `sep` does not occur.
fn rpartition<'a>(string: &'a str, sep: &str) -> (&'a str, &'a str) {
    match string.rfind(sep).map(|pos| string.split_at(pos)) {
        Some((head, tail)) => (head, &tail[sep.len()..]),
        None => (string, ""),
    }
}
/// Splits `string` immediately after the last occurrence of `sep`, keeping the
/// separator at the end of the first half.
///
/// Returns `(string, "")` when `sep` does not occur.
fn rsplit<'a>(string: &'a str, sep: &str) -> (&'a str, &'a str) {
    string
        .rfind(sep)
        .map_or((string, ""), |pos| string.split_at(pos + sep.len()))
}