use std::collections::hash_map;
use std::collections::HashMap;
use std::ffi::CStr;
use std::fmt;
use std::str;
use std::str::FromStr;
use std::string::String;
use external::c_tor_version_as_new_as;
use errors::ProtoverError;
use protoset::ProtoSet;
use protoset::Version;
/// Tor releases at least this new advertise their own protocol list, so
/// `compute_for_old_tor_cstr` has nothing to infer for them.
const FIRST_TOR_VERSION_TO_ADVERTISE_PROTOCOLS: &str = "0.2.9.3-alpha";
/// Expansion limit: parsing a `ProtoEntry` fails with `ExceedsMax` when its
/// protocol count exceeds this, and `ProtoverVote::compute` ignores any vote
/// whose total version count exceeds it.
const MAX_PROTOCOLS_TO_EXPAND: usize = 1 << 16;
/// Maximum length, in bytes, of a protocol name accepted by
/// `UnknownProtocol::from_str`.
pub(crate) const MAX_PROTOCOL_NAME_LENGTH: usize = 100;
/// A subprotocol recognised by this implementation.
///
/// The textual name of each variant (written by `Display`, accepted by
/// `FromStr`) is exactly the variant identifier, e.g. `HSDir` or `FlowCtrl`.
#[derive(Clone, Hash, Eq, PartialEq, Debug)]
pub enum Protocol {
    Cons,
    Desc,
    DirCache,
    HSDir,
    HSIntro,
    HSRend,
    Link,
    LinkAuth,
    Microdesc,
    Relay,
    Padding,
    FlowCtrl,
}

impl fmt::Display for Protocol {
    /// Writes the protocol's wire name.
    ///
    /// The names are spelled out explicitly instead of being taken from the
    /// derived `Debug` output. They are part of the serialised protover string
    /// format, and the standard library documents derived `Debug` formatting as
    /// not stable.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let name = match *self {
            Protocol::Cons => "Cons",
            Protocol::Desc => "Desc",
            Protocol::DirCache => "DirCache",
            Protocol::HSDir => "HSDir",
            Protocol::HSIntro => "HSIntro",
            Protocol::HSRend => "HSRend",
            Protocol::Link => "Link",
            Protocol::LinkAuth => "LinkAuth",
            Protocol::Microdesc => "Microdesc",
            Protocol::Relay => "Relay",
            Protocol::Padding => "Padding",
            Protocol::FlowCtrl => "FlowCtrl",
        };
        f.write_str(name)
    }
}
impl FromStr for Protocol {
    type Err = ProtoverError;

    /// Parses a protocol name exactly as it appears on the wire. Matching is
    /// case-sensitive.
    ///
    /// # Errors
    ///
    /// Returns `ProtoverError::UnknownProtocol` for any name that is not one of
    /// the `Protocol` variants.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let protocol = match s {
            "Cons" => Protocol::Cons,
            "Desc" => Protocol::Desc,
            "DirCache" => Protocol::DirCache,
            "HSDir" => Protocol::HSDir,
            "HSIntro" => Protocol::HSIntro,
            "HSRend" => Protocol::HSRend,
            "Link" => Protocol::Link,
            "LinkAuth" => Protocol::LinkAuth,
            "Microdesc" => Protocol::Microdesc,
            "Relay" => Protocol::Relay,
            "Padding" => Protocol::Padding,
            "FlowCtrl" => Protocol::FlowCtrl,
            _ => return Err(ProtoverError::UnknownProtocol),
        };
        Ok(protocol)
    }
}
/// A protocol name that passed syntactic validation but may not be one this
/// implementation recognises as a [`Protocol`].
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct UnknownProtocol(String);

impl fmt::Display for UnknownProtocol {
    /// Writes the stored protocol name verbatim.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.0)
    }
}
/// Returns `true` when every character of `s` is an ASCII letter, an ASCII
/// digit, or `-`. The empty string is considered valid.
fn is_valid_proto(s: &str) -> bool {
    // Every byte of a non-ASCII UTF-8 character is >= 0x80 and fails the ASCII
    // test, so checking bytes gives the same result as checking chars.
    !s.bytes().any(|b| !(b.is_ascii_alphanumeric() || b == b'-'))
}
impl FromStr for UnknownProtocol {
    type Err = ProtoverError;

    /// Parses a protocol name, checking only its syntax and its length.
    ///
    /// # Errors
    ///
    /// * `ProtoverError::InvalidProtocol` if `s` contains anything other than
    ///   ASCII alphanumerics and `-`. This check runs first.
    /// * `ProtoverError::ExceedsNameLimit` if `s` is longer than
    ///   `MAX_PROTOCOL_NAME_LENGTH` bytes.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if !is_valid_proto(s) {
            return Err(ProtoverError::InvalidProtocol);
        }
        if s.len() > MAX_PROTOCOL_NAME_LENGTH {
            return Err(ProtoverError::ExceedsNameLimit);
        }
        Ok(UnknownProtocol(s.to_string()))
    }
}
impl UnknownProtocol {
    /// Same as `UnknownProtocol::from_str`, but without the
    /// `MAX_PROTOCOL_NAME_LENGTH` cap.
    ///
    /// # Errors
    ///
    /// Returns `ProtoverError::InvalidProtocol` if `s` contains anything other
    /// than ASCII alphanumerics and `-`.
    fn from_str_any_len(s: &str) -> Result<Self, ProtoverError> {
        if is_valid_proto(s) {
            Ok(UnknownProtocol(s.to_string()))
        } else {
            Err(ProtoverError::InvalidProtocol)
        }
    }
}
impl From<Protocol> for UnknownProtocol {
    /// Wraps a recognised protocol's `Display` name, e.g. `Protocol::LinkAuth`
    /// becomes `"LinkAuth"`.
    fn from(p: Protocol) -> UnknownProtocol {
        let name: String = p.to_string();
        UnknownProtocol(name)
    }
}
/// Variant used when the `test_linking_hack` feature is on: always reports
/// LinkAuth=1 as available. Presumably this exists because such builds do not
/// link the C side; confirm before relying on that.
#[cfg(feature = "test_linking_hack")]
fn have_linkauth_v1() -> bool {
    true
}

/// Reports whether LinkAuth=1 can be offered. It is unavailable when the C side
/// reports that it is using NSS.
#[cfg(not(feature = "test_linking_hack"))]
fn have_linkauth_v1() -> bool {
    use external::c_tor_is_using_nss;
    let using_nss = c_tor_is_using_nss();
    !using_nss
}
/// Returns the subprotocol list this Tor advertises, as a static C string.
///
/// The two possible lists differ only in `LinkAuth`:
/// * `LinkAuth=1,3` when `have_linkauth_v1()` is true;
/// * `LinkAuth=3` otherwise.
pub(crate) fn get_supported_protocols_cstr() -> &'static CStr {
    if have_linkauth_v1() {
        cstr!(
            "Cons=1-2 \
             Desc=1-2 \
             DirCache=2 \
             FlowCtrl=1 \
             HSDir=1-2 \
             HSIntro=3-5 \
             HSRend=1-2 \
             Link=1-5 \
             LinkAuth=1,3 \
             Microdesc=1-2 \
             Padding=2 \
             Relay=1-3"
        )
    } else {
        cstr!(
            "Cons=1-2 \
             Desc=1-2 \
             DirCache=2 \
             FlowCtrl=1 \
             HSDir=1-2 \
             HSIntro=3-5 \
             HSRend=1-2 \
             Link=1-5 \
             LinkAuth=3 \
             Microdesc=1-2 \
             Padding=2 \
             Relay=1-3"
        )
    }
}
/// A parsed protocol list that contains only protocols this implementation
/// recognises. Each [`Protocol`] maps to the set of versions listed for it.
///
/// The `Default` value is an empty list. It is derived, since it is exactly an
/// empty `HashMap`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ProtoEntry(HashMap<Protocol, ProtoSet>);
impl ProtoEntry {
    /// Parses the protocol list this Tor advertises (`get_supported_protocols_cstr`).
    ///
    /// # Errors
    ///
    /// Propagates any `ProtoverError` from parsing that list. A list that is not
    /// valid UTF-8 is treated as `""`, which parses to an empty entry.
    pub fn supported() -> Result<Self, ProtoverError> {
        get_supported_protocols_cstr()
            .to_str()
            .unwrap_or("")
            .parse()
    }

    /// Iterates over `(protocol, versions)` pairs in arbitrary (hash) order.
    pub fn iter(&self) -> hash_map::Iter<Protocol, ProtoSet> {
        self.0.iter()
    }

    /// Number of distinct protocols present. This is not the number of versions.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// `true` when no protocols are present.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns the version set recorded for `protocol`, if any.
    pub fn get(&self, protocol: &Protocol) -> Option<&ProtoSet> {
        self.0.get(protocol)
    }

    /// Records `value` as the version set for `key`, replacing any previous set.
    pub fn insert(&mut self, key: Protocol, value: ProtoSet) {
        let _replaced = self.0.insert(key, value);
    }

    /// Removes `key`, returning its version set if it was present.
    pub fn remove(&mut self, key: &Protocol) -> Option<ProtoSet> {
        self.0.remove(key)
    }
}
impl FromStr for ProtoEntry {
    type Err = ProtoverError;

    /// Parses a space-separated list of `Name=versions` items. Every name must
    /// be a recognised [`Protocol`].
    ///
    /// The empty string yields an empty entry. If a name repeats, the later
    /// item replaces the earlier one.
    ///
    /// # Errors
    ///
    /// * `ProtoverError::Unparseable` if an item has no `=`.
    /// * Any error from parsing the version set. Versions are parsed before the
    ///   name, so a bad version set takes precedence.
    /// * `ProtoverError::UnknownProtocol` (via `Protocol::from_str`) for an
    ///   unrecognised name.
    /// * `ProtoverError::ExceedsMax` if the number of protocols exceeds
    ///   `MAX_PROTOCOLS_TO_EXPAND`.
    fn from_str(protocol_entry: &str) -> Result<ProtoEntry, ProtoverError> {
        let mut entry_map = ProtoEntry::default();
        if protocol_entry.is_empty() {
            return Ok(entry_map);
        }

        for item in protocol_entry.split(' ') {
            let mut halves = item.splitn(2, '=');
            let (name, vers) = match (halves.next(), halves.next()) {
                (Some(name), Some(vers)) => (name, vers),
                _ => return Err(ProtoverError::Unparseable),
            };

            let versions: ProtoSet = vers.parse()?;
            let protocol: Protocol = name.parse()?;
            entry_map.insert(protocol, versions);

            if entry_map.len() > MAX_PROTOCOLS_TO_EXPAND {
                return Err(ProtoverError::ExceedsMax);
            }
        }

        Ok(entry_map)
    }
}
/// Gives a protocol-entry type `$t` its canonical string form.
///
/// `$t` must provide an `iter()` that yields `(key, versions)` pairs, where
/// `key: Display` and `versions: ToString`. The output renders each pair as
/// `key=versions`, sorts the pieces lexicographically, and joins them with
/// single spaces. The result therefore does not depend on hash-map iteration
/// order.
///
/// The macro implements `Display` rather than `ToString`, as the standard
/// library recommends. `to_string()` still exists through the blanket
/// `impl<T: Display> ToString for T`, so existing callers are unaffected.
macro_rules! impl_to_string_for_proto_entry {
    ($t:ty) => {
        impl ::std::fmt::Display for $t {
            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
                let mut parts: Vec<String> = self
                    .iter()
                    .map(|(protocol, versions)| format!("{}={}", protocol, versions.to_string()))
                    .collect();
                // Hash-map order is arbitrary; sort for a canonical encoding.
                parts.sort_unstable();
                f.write_str(&parts.join(" "))
            }
        }
    };
}
// Both entry types take their `to_string()` from the same macro, so the two
// serialisations cannot drift apart.
impl_to_string_for_proto_entry!(UnvalidatedProtoEntry);
impl_to_string_for_proto_entry!(ProtoEntry);
/// A parsed protocol list whose names are only syntactically checked. Each
/// [`UnknownProtocol`] maps to the set of versions listed for it, so names this
/// implementation does not recognise are kept.
///
/// The `Default` value is an empty list. It is derived, since it is exactly an
/// empty `HashMap`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct UnvalidatedProtoEntry(HashMap<UnknownProtocol, ProtoSet>);
impl UnvalidatedProtoEntry {
pub fn iter(&self) -> hash_map::Iter<UnknownProtocol, ProtoSet> {
self.0.iter()
}
pub fn get(&self, protocol: &UnknownProtocol) -> Option<&ProtoSet> {
self.0.get(protocol)
}
pub fn insert(&mut self, key: UnknownProtocol, value: ProtoSet) {
self.0.insert(key, value);
}
pub fn remove(&mut self, key: &UnknownProtocol) -> Option<ProtoSet> {
self.0.remove(key)
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn len(&self) -> usize {
let mut total: usize = 0;
for (_, versions) in self.iter() {
total += versions.len();
}
total
}
pub fn all_supported(&self) -> Option<UnvalidatedProtoEntry> {
let mut unsupported: UnvalidatedProtoEntry = UnvalidatedProtoEntry::default();
let supported: ProtoEntry = match ProtoEntry::supported() {
Ok(x) => x,
Err(_) => return None,
};
for (protocol, versions) in self.iter() {
let is_supported: Result<Protocol, ProtoverError> = protocol.0.parse();
let supported_protocol: Protocol;
if is_supported.is_err() {
if !versions.is_empty() {
unsupported.insert(protocol.clone(), versions.clone());
}
continue;
} else {
supported_protocol = is_supported.unwrap();
}
let maybe_supported_versions: Option<&ProtoSet> = supported.get(&supported_protocol);
let supported_versions: &ProtoSet;
if maybe_supported_versions.is_none() {
if !versions.is_empty() {
unsupported.insert(protocol.clone(), versions.clone());
}
continue;
} else {
supported_versions = maybe_supported_versions.unwrap();
}
let unsupported_versions = versions.and_not_in(supported_versions);
if !unsupported_versions.is_empty() {
unsupported.insert(protocol.clone(), unsupported_versions);
}
}
if unsupported.is_empty() {
return None;
}
Some(unsupported)
}
pub fn supports_protocol(&self, proto: &UnknownProtocol, vers: &Version) -> bool {
let supported_versions: &ProtoSet = match self.get(proto) {
Some(n) => n,
None => return false,
};
supported_versions.contains(&vers)
}
pub fn supports_protocol_or_later(&self, proto: &UnknownProtocol, vers: &Version) -> bool {
let supported_versions: &ProtoSet = match self.get(&proto) {
Some(n) => n,
None => return false,
};
supported_versions.iter().any(|v| v.1 >= *vers)
}
fn parse_protocol_and_version_str<'a>(
protocol_string: &'a str,
) -> Result<Vec<(&'a str, &'a str)>, ProtoverError> {
let mut protovers: Vec<(&str, &str)> = Vec::new();
if protocol_string.is_empty() {
return Ok(protovers);
}
for subproto in protocol_string.split(' ') {
let mut parts = subproto.splitn(2, '=');
let name = match parts.next() {
Some("") => return Err(ProtoverError::Unparseable),
Some(n) => n,
None => return Err(ProtoverError::Unparseable),
};
let vers = match parts.next() {
Some(n) => n,
None => return Err(ProtoverError::Unparseable),
};
protovers.push((name, vers));
}
Ok(protovers)
}
}
impl FromStr for UnvalidatedProtoEntry {
type Err = ProtoverError;
fn from_str(protocol_string: &str) -> Result<UnvalidatedProtoEntry, ProtoverError> {
let mut parsed: UnvalidatedProtoEntry = UnvalidatedProtoEntry::default();
let parts: Vec<(&str, &str)> =
UnvalidatedProtoEntry::parse_protocol_and_version_str(protocol_string)?;
for &(name, vers) in parts.iter() {
let versions = ProtoSet::from_str(vers)?;
let protocol = UnknownProtocol::from_str(name)?;
parsed.insert(protocol, versions);
}
Ok(parsed)
}
}
impl UnvalidatedProtoEntry {
pub(crate) fn from_str_any_len(
protocol_string: &str,
) -> Result<UnvalidatedProtoEntry, ProtoverError> {
let mut parsed: UnvalidatedProtoEntry = UnvalidatedProtoEntry::default();
let parts: Vec<(&str, &str)> =
UnvalidatedProtoEntry::parse_protocol_and_version_str(protocol_string)?;
for &(name, vers) in parts.iter() {
let versions = ProtoSet::from_str(vers)?;
let protocol = UnknownProtocol::from_str_any_len(name)?;
parsed.insert(protocol, versions);
}
Ok(parsed)
}
}
impl From<ProtoEntry> for UnvalidatedProtoEntry {
fn from(proto_entry: ProtoEntry) -> UnvalidatedProtoEntry {
let mut unvalidated: UnvalidatedProtoEntry = UnvalidatedProtoEntry::default();
for (protocol, versions) in proto_entry.iter() {
unvalidated.insert(UnknownProtocol::from(protocol.clone()), versions.clone());
}
unvalidated
}
}
/// Running tally used by `ProtoverVote::compute`. For each protocol name, it
/// records how many votes listed each individual version.
///
/// The `Default` value is an empty tally. It is derived, since it is exactly an
/// empty `HashMap`.
#[derive(Default)]
pub struct ProtoverVote(HashMap<UnknownProtocol, HashMap<Version, usize>>);
impl IntoIterator for ProtoverVote {
type Item = (UnknownProtocol, HashMap<Version, usize>);
type IntoIter = hash_map::IntoIter<UnknownProtocol, HashMap<Version, usize>>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl ProtoverVote {
    /// Returns the `HashMap` entry for `key`, for in-place updates.
    pub fn entry(
        &mut self,
        key: UnknownProtocol,
    ) -> hash_map::Entry<UnknownProtocol, HashMap<Version, usize>> {
        self.0.entry(key)
    }

    /// Tallies `proto_entries` and keeps each `(protocol, version)` pair that
    /// appears in at least `threshold` of them.
    ///
    /// A vote whose total version count (`UnvalidatedProtoEntry::len`) exceeds
    /// `MAX_PROTOCOLS_TO_EXPAND` is skipped entirely. Protocols left with no
    /// versions after thresholding are omitted from the result.
    pub fn compute(
        proto_entries: &[UnvalidatedProtoEntry],
        threshold: &usize,
    ) -> UnvalidatedProtoEntry {
        let mut all_count: ProtoverVote = ProtoverVote::default();
        let mut final_output: UnvalidatedProtoEntry = UnvalidatedProtoEntry::default();

        if proto_entries.is_empty() {
            return final_output;
        }

        for vote in proto_entries {
            // Caps the work done per vote on expanding version ranges.
            if vote.len() > MAX_PROTOCOLS_TO_EXPAND {
                continue;
            }

            for (protocol, versions) in vote.iter() {
                let supported_vers = all_count
                    .entry(protocol.clone())
                    .or_insert_with(HashMap::new);

                // `expand` consumes the set, so it has to work on a copy.
                for version in versions.clone().expand() {
                    *supported_vers.entry(version).or_insert(0) += 1;
                }
            }
        }

        for (protocol, mut versions) in all_count {
            versions.retain(|_, count| *count >= *threshold);

            if !versions.is_empty() {
                let voted_versions: Vec<Version> = versions.keys().cloned().collect();
                final_output.insert(protocol, ProtoSet::from(voted_versions));
            }
        }
        final_output
    }
}
/// Returns `true` if this Tor advertises version `vers` of `proto`.
///
/// Returns `false` if our own supported-protocols list cannot be parsed.
pub fn is_supported_here(proto: &Protocol, vers: &Version) -> bool {
    match ProtoEntry::supported() {
        Ok(ours) => ours
            .get(proto)
            .map_or(false, |versions| versions.contains(vers)),
        Err(_) => false,
    }
}
/// Infers the protocol list of a Tor release too old to advertise one, from its
/// `version` string.
///
/// Returns the empty string in two cases:
/// * releases at least as new as `FIRST_TOR_VERSION_TO_ADVERTISE_PROTOCOLS`,
///   which advertise their own list;
/// * releases older than 0.2.4.19.
pub(crate) fn compute_for_old_tor_cstr(version: &str) -> &'static CStr {
    let empty: &'static CStr = cstr!("");

    if c_tor_version_as_new_as(version, FIRST_TOR_VERSION_TO_ADVERTISE_PROTOCOLS) {
        return empty;
    }

    // (minimum release, implied protocol list), checked newest first; the first
    // release that `version` is at least as new as wins.
    let known_releases: [(&str, &'static CStr); 3] = [
        (
            "0.2.9.1-alpha",
            cstr!(
                "Cons=1-2 Desc=1-2 DirCache=1 HSDir=1 HSIntro=3 HSRend=1-2 \
                 Link=1-4 LinkAuth=1 Microdesc=1-2 Relay=1-2"
            ),
        ),
        (
            "0.2.7.5",
            cstr!(
                "Cons=1-2 Desc=1-2 DirCache=1 HSDir=1 HSIntro=3 HSRend=1 \
                 Link=1-4 LinkAuth=1 Microdesc=1-2 Relay=1-2"
            ),
        ),
        (
            "0.2.4.19",
            cstr!(
                "Cons=1 Desc=1 DirCache=1 HSDir=1 HSIntro=3 HSRend=1 \
                 Link=1-4 LinkAuth=1 Microdesc=1 Relay=1-2"
            ),
        ),
    ];

    for &(minimum, protocols) in known_releases.iter() {
        if c_tor_version_as_new_as(version, minimum) {
            return protocols;
        }
    }
    empty
}
pub fn compute_for_old_tor(version: &str) -> Result<&'static str, ProtoverError> {
compute_for_old_tor_cstr(version)
.to_str()
.or(Err(ProtoverError::Unparseable))
}
#[cfg(test)]
mod test {
    use std::str::FromStr;
    use std::string::ToString;

    use super::*;

    /// Parses `$e` as an `UnknownProtocol` through both `from_str` and
    /// `from_str_any_len`, and asserts that they agree. Only suitable for names
    /// short enough that the length cap does not apply.
    macro_rules! parse_proto {
        ($e:expr) => {{
            let proto: Result<UnknownProtocol, _> = $e.parse();
            let proto2 = UnknownProtocol::from_str_any_len($e);
            assert_eq!(proto, proto2);
            proto
        }};
    }

    #[test]
    fn test_protocol_from_str() {
        assert!(parse_proto!("Cons").is_ok());
        assert!(parse_proto!("123").is_ok());
        assert!(parse_proto!("1-2-3").is_ok());

        let err = Err(ProtoverError::InvalidProtocol);
        assert_eq!(err, parse_proto!("a_b_c"));
        assert_eq!(err, parse_proto!("a b"));
        assert_eq!(err, parse_proto!("a,"));
        assert_eq!(err, parse_proto!("b."));
        assert_eq!(err, parse_proto!("é"));
    }

    #[test]
    fn test_protocol_name_round_trip() {
        let names = [
            "Cons", "Desc", "DirCache", "HSDir", "HSIntro", "HSRend", "Link", "LinkAuth",
            "Microdesc", "Relay", "Padding", "FlowCtrl",
        ];
        for name in names.iter() {
            let proto: Protocol = name.parse().unwrap();
            assert_eq!(*name, proto.to_string());
        }
    }

    // The message is passed as a format string plus arguments. A pre-built
    // `String` as the message is rejected by the `non_fmt_panics` lint and is an
    // error in the 2021 edition.
    macro_rules! assert_protoentry_is_parseable {
        ($e:expr) => {
            let protoentry: Result<ProtoEntry, ProtoverError> = $e.parse();
            assert!(protoentry.is_ok(), "{:?}", protoentry.err());
        };
    }

    macro_rules! assert_protoentry_is_unparseable {
        ($e:expr) => {
            let protoentry: Result<ProtoEntry, ProtoverError> = $e.parse();
            assert!(protoentry.is_err());
        };
    }

    #[test]
    fn test_protoentry_from_str_multiple_protocols_multiple_versions() {
        assert_protoentry_is_parseable!("Cons=3-4 Link=1,3-5");
    }

    #[test]
    fn test_protoentry_from_str_empty() {
        assert_protoentry_is_parseable!("");
        assert!(UnvalidatedProtoEntry::from_str("").is_ok());
    }

    #[test]
    fn test_protoentry_from_str_single_protocol_single_version() {
        assert_protoentry_is_parseable!("HSDir=1");
    }

    #[test]
    fn test_protoentry_from_str_unknown_protocol() {
        assert_protoentry_is_unparseable!("Ducks=5-7,8");
    }

    #[test]
    fn test_protoentry_from_str_allowed_number_of_versions() {
        assert_protoentry_is_parseable!("Desc=1-63");
    }

    #[test]
    fn test_protoentry_from_str_too_many_versions() {
        assert_protoentry_is_unparseable!("Desc=1-64");
    }

    #[test]
    fn test_protoentry_all_supported_single_protocol_single_version() {
        let protocol: UnvalidatedProtoEntry = "Cons=1".parse().unwrap();
        let unsupported: Option<UnvalidatedProtoEntry> = protocol.all_supported();
        assert_eq!(true, unsupported.is_none());
    }

    #[test]
    fn test_protoentry_all_supported_multiple_protocol_multiple_versions() {
        let protocols: UnvalidatedProtoEntry = "Link=3-4 Desc=2".parse().unwrap();
        let unsupported: Option<UnvalidatedProtoEntry> = protocols.all_supported();
        assert_eq!(true, unsupported.is_none());
    }

    #[test]
    fn test_protoentry_all_supported_three_values() {
        let protocols: UnvalidatedProtoEntry = "LinkAuth=1 Microdesc=1-2 Relay=2".parse().unwrap();
        let unsupported: Option<UnvalidatedProtoEntry> = protocols.all_supported();
        assert_eq!(true, unsupported.is_none());
    }

    #[test]
    fn test_protoentry_all_supported_unknown_protocol() {
        let protocols: UnvalidatedProtoEntry = "Wombat=9".parse().unwrap();
        let unsupported: Option<UnvalidatedProtoEntry> = protocols.all_supported();
        assert_eq!(true, unsupported.is_some());
        assert_eq!("Wombat=9", &unsupported.unwrap().to_string());
    }

    #[test]
    fn test_protoentry_all_supported_unsupported_high_version() {
        let protocols: UnvalidatedProtoEntry = "HSDir=12-60".parse().unwrap();
        let unsupported: Option<UnvalidatedProtoEntry> = protocols.all_supported();
        assert_eq!(true, unsupported.is_some());
        assert_eq!("HSDir=12-60", &unsupported.unwrap().to_string());
    }

    #[test]
    fn test_protoentry_all_supported_unsupported_low_version() {
        let protocols: UnvalidatedProtoEntry = "HSIntro=2-3".parse().unwrap();
        let unsupported: Option<UnvalidatedProtoEntry> = protocols.all_supported();
        assert_eq!(true, unsupported.is_some());
        assert_eq!("HSIntro=2", &unsupported.unwrap().to_string());
    }

    /// Each canonical version list must survive a parse / to_string round trip
    /// unchanged.
    #[test]
    fn test_contract_protocol_list() {
        let cases = ["", "1", "1-2", "1,3", "1-4", "1,3,5-7", "1-3,50"];
        for versions in cases.iter() {
            assert_eq!(
                String::from(*versions),
                ProtoSet::from_str(*versions).unwrap().to_string()
            );
        }
    }
}