use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use http::Uri;
use super::auth::AuthVerifier;
/// How the proxy reaches an [`Upstream`].
#[derive(Debug, Clone)]
pub enum Transport {
    /// Plain TCP transport with no gRPC framing.
    /// NOTE(review): wire behaviour lives in the connector, not in this file — confirm there.
    Tcp,
    /// gRPC-tunnelled transport.
    Grpc {
        /// gRPC service name (the test helpers use `"GunService"`).
        service_name: String,
        /// Host name used for TLS SNI. Paired with [`Upstream::addr`], it identifies a
        /// distinct gRPC endpoint in [`Validator::grpc_endpoints`].
        tls_sni: String,
        /// Pre-parsed request URI, e.g. `https://one.example.com:9001/GunService/Tun`
        /// as built by the test helpers.
        request_uri: Uri,
    },
}
/// A destination that authenticated clients are routed to.
#[derive(Debug, Clone)]
pub struct Upstream {
    /// Upstream address string, e.g. `127.0.0.1:9000`. This exact string is what
    /// [`Validator::grpc_endpoints`] reports.
    pub addr: String,
    /// Socket-address form of the upstream. The test helpers derive it by parsing `addr`.
    /// NOTE(review): presumably the config loader keeps the two in sync — confirm.
    pub parsed_addr: SocketAddr,
    /// Transport used to reach this upstream.
    pub transport: Transport,
    /// Presumably enables TCP Fast Open for connections to this upstream.
    // `dead_code` is allowed because the compiler sees no read of this field.
    // NOTE(review): confirm whether TFO is actually applied anywhere.
    #[allow(dead_code)]
    pub tcp_fast_open: bool,
}
/// One configured client: its UUID, the verifier derived from it, and its upstream.
struct Entry {
    /// The UUID string as configured.
    // Stored but never read in this file, hence `allow(dead_code)`. Presumably kept
    // for debugging/diagnostics.
    #[allow(dead_code)]
    uuid: String,
    /// Built from `uuid` by `AuthVerifier::from_uuid` in `Validator::new`.
    verifier: AuthVerifier,
    /// Upstream returned when `verifier` accepts an auth ID.
    upstream: Arc<Upstream>,
}
/// Routes VMess auth IDs to the upstream configured for the UUID that produced them.
///
/// Holds one `Entry` per configured UUID, in the order they were supplied to
/// `Validator::new`.
pub struct Validator {
    entries: Vec<Entry>,
}
impl Validator {
    /// Builds a validator from `(uuid, upstream)` pairs.
    ///
    /// Entries keep the order of `pairs`; [`Validator::match_auth_id`] checks them in
    /// that order and returns the first hit.
    ///
    /// # Errors
    ///
    /// Returns an error if the same UUID string appears more than once (`"duplicate
    /// UUID: …"`), or propagates the error from `AuthVerifier::from_uuid` if it rejects
    /// a UUID. Entries are processed in order, so the first offending pair decides which
    /// error is reported.
    pub fn new(pairs: Vec<(String, Arc<Upstream>)>) -> Result<Self> {
        let mut seen = HashSet::with_capacity(pairs.len());
        let mut entries = Vec::with_capacity(pairs.len());
        for (uuid, upstream) in pairs {
            // Lookups return the first matching entry, so a repeated UUID would
            // presumably shadow the later one silently; reject it up front.
            // NOTE(review): this compares raw strings, so spellings that
            // `AuthVerifier::from_uuid` might treat as equal (e.g. letter case) are not
            // caught here — confirm against its parser.
            if !seen.insert(uuid.clone()) {
                return Err(anyhow!("duplicate UUID: {}", uuid));
            }
            let verifier = AuthVerifier::from_uuid(&uuid)?;
            entries.push(Entry {
                uuid,
                verifier,
                upstream,
            });
        }
        Ok(Self { entries })
    }

    /// Returns the upstream of the first entry whose verifier accepts `auth_id`.
    ///
    /// This scans entries linearly in construction order, making one `verify` call per
    /// entry until one succeeds. Returns `None` when no entry accepts the ID. The
    /// returned `Arc` is a cheap refcount bump of the configured upstream.
    pub fn match_auth_id(&self, auth_id: &[u8; 16]) -> Option<Arc<Upstream>> {
        self.entries
            .iter()
            .find(|entry| entry.verifier.verify(auth_id))
            .map(|entry| Arc::clone(&entry.upstream))
    }

    /// Returns the distinct `(addr, tls_sni)` pairs of all gRPC upstreams.
    ///
    /// TCP upstreams are skipped. Upstreams shared by several UUIDs (or with identical
    /// address and SNI) appear once, because the result is a set.
    pub fn grpc_endpoints(&self) -> HashSet<(String, String)> {
        self.entries
            .iter()
            .filter_map(|entry| match &entry.upstream.transport {
                Transport::Grpc { tls_sni, .. } => {
                    Some((entry.upstream.addr.clone(), tls_sni.clone()))
                }
                Transport::Tcp => None,
            })
            .collect()
    }

    /// Number of configured UUIDs.
    // `dead_code` allowed: only the tests in this file call it.
    #[allow(dead_code)]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Whether no UUIDs are configured.
    // `dead_code` allowed: provided alongside `len` (clippy `len_without_is_empty`).
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vmess::auth::AuthVerifier;
    use aes::cipher::{Block, BlockCipherEncrypt, KeyInit};
    use aes::Aes128;
    use std::thread;
    use std::time::{SystemTime, UNIX_EPOCH};

    const UUID_A: &str = "550e8400-e29b-41d4-a716-446655440000";
    const UUID_B: &str = "550e8400-e29b-41d4-a716-446655440001";
    const UUID_C: &str = "550e8400-e29b-41d4-a716-446655440002";

    /// Builds a fresh auth ID for `uuid`.
    ///
    /// Layout: an 8-byte big-endian Unix timestamp, 4 random bytes, and a CRC32 of the
    /// first 12 bytes. The whole thing is encrypted as one AES-128 block with the
    /// verifier's key. Each call yields a different ID thanks to the random bytes.
    fn make_auth_id(uuid: &str) -> [u8; 16] {
        let verifier = AuthVerifier::from_uuid(uuid).unwrap();
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let mut plain = [0u8; 16];
        plain[0..8].copy_from_slice(&now.to_be_bytes());
        rand::fill(&mut plain[8..12]);
        let checksum = crc32fast::hash(&plain[0..12]);
        plain[12..16].copy_from_slice(&checksum.to_be_bytes());
        let cipher = Aes128::new_from_slice(&verifier.ecb_key).unwrap();
        let mut block =
            Block::<Aes128>::try_from(&plain[..]).expect("plain auth ID is one AES block");
        cipher.encrypt_block(&mut block);
        block.into()
    }

    /// A TCP upstream at `host:port` (`host` must be an IP literal so it parses).
    fn tcp_upstream(host: &str, port: u16) -> Arc<Upstream> {
        let addr_str = format!("{}:{}", host, port);
        let parsed_addr = addr_str.parse().unwrap();
        Arc::new(Upstream {
            addr: addr_str,
            parsed_addr,
            transport: Transport::Tcp,
            tcp_fast_open: false,
        })
    }

    /// A gRPC upstream at `host:port` using `tls_sni` for TLS and the URI.
    fn grpc_upstream(host: &str, port: u16, tls_sni: &str) -> Arc<Upstream> {
        let addr_str = format!("{}:{}", host, port);
        let parsed_addr = addr_str.parse().unwrap();
        Arc::new(Upstream {
            addr: addr_str,
            parsed_addr,
            transport: Transport::Grpc {
                service_name: "GunService".to_string(),
                tls_sni: tls_sni.to_string(),
                request_uri: format!("https://{}:{}/GunService/Tun", tls_sni, port)
                    .parse()
                    .unwrap(),
            },
            tcp_fast_open: false,
        })
    }

    #[test]
    fn test_validator_routing() {
        let upstream = tcp_upstream("127.0.0.1", 9000);
        let validator = Validator::new(vec![(UUID_A.to_string(), upstream.clone())]).unwrap();
        let matched = validator
            .match_auth_id(&make_auth_id(UUID_A))
            .expect("auth ID built from the configured UUID should match");
        assert_eq!(matched.addr, upstream.addr);
        // The validator hands back the configured Arc itself, not a copy.
        assert!(Arc::ptr_eq(&matched, &upstream));
    }

    #[test]
    fn test_validator_no_match() {
        let validator =
            Validator::new(vec![(UUID_A.to_string(), tcp_upstream("127.0.0.1", 9000))]).unwrap();
        let bad_id = [0u8; 16];
        assert!(validator.match_auth_id(&bad_id).is_none());
    }

    #[test]
    fn test_validator_empty() {
        let validator = Validator::new(Vec::new()).unwrap();
        assert!(validator.is_empty());
        assert_eq!(validator.len(), 0);
        assert!(validator.match_auth_id(&[0u8; 16]).is_none());
        assert!(validator.grpc_endpoints().is_empty());
    }

    #[test]
    fn test_validator_grpc_endpoints() {
        let tcp = tcp_upstream("127.0.0.1", 9000);
        let grpc1 = grpc_upstream("127.0.0.1", 9001, "one.example.com");
        let grpc2 = grpc_upstream("127.0.0.1", 9002, "two.example.com");
        let validator = Validator::new(vec![
            (UUID_A.to_string(), tcp),
            (UUID_B.to_string(), grpc1),
            (UUID_C.to_string(), grpc2),
        ])
        .unwrap();
        let endpoints = validator.grpc_endpoints();
        assert_eq!(endpoints.len(), 2);
        assert!(endpoints.contains(&("127.0.0.1:9001".to_string(), "one.example.com".to_string())));
        assert!(endpoints.contains(&("127.0.0.1:9002".to_string(), "two.example.com".to_string())));
    }

    #[test]
    fn test_validator_grpc_endpoints_dedup() {
        // Two UUIDs sharing one gRPC upstream must yield a single endpoint.
        let shared = grpc_upstream("127.0.0.1", 9001, "one.example.com");
        let validator = Validator::new(vec![
            (UUID_A.to_string(), shared.clone()),
            (UUID_B.to_string(), shared),
        ])
        .unwrap();
        let endpoints = validator.grpc_endpoints();
        assert_eq!(endpoints.len(), 1);
        assert!(endpoints.contains(&("127.0.0.1:9001".to_string(), "one.example.com".to_string())));
    }

    #[test]
    fn test_validator_duplicate_uuid() {
        let upstream = tcp_upstream("127.0.0.1", 9000);
        let result = Validator::new(vec![
            (UUID_A.to_string(), upstream.clone()),
            (UUID_A.to_string(), upstream),
        ]);
        // `Validator` is not `Debug`, so `unwrap_err` is unavailable; destructure instead.
        let Err(err) = result else {
            panic!("duplicate UUIDs must be rejected");
        };
        assert!(err.to_string().contains("duplicate UUID"));
    }

    #[test]
    fn test_validator_multiple_uuids() {
        let up1 = tcp_upstream("127.0.0.1", 9001);
        let up2 = tcp_upstream("127.0.0.1", 9002);
        let validator = Validator::new(vec![
            (UUID_A.to_string(), up1.clone()),
            (UUID_B.to_string(), up2.clone()),
        ])
        .unwrap();
        let hit1 = validator
            .match_auth_id(&make_auth_id(UUID_A))
            .expect("UUID_A auth ID should match");
        let hit2 = validator
            .match_auth_id(&make_auth_id(UUID_B))
            .expect("UUID_B auth ID should match");
        assert_eq!(hit1.addr, up1.addr);
        assert_eq!(hit2.addr, up2.addr);
        assert!(Arc::ptr_eq(&hit1, &up1));
        assert!(Arc::ptr_eq(&hit2, &up2));
    }

    #[test]
    fn test_validator_concurrent_reads() {
        let upstream = tcp_upstream("127.0.0.1", 9000);
        let validator =
            Arc::new(Validator::new(vec![(UUID_A.to_string(), upstream.clone())]).unwrap());
        let bad_id = [0u8; 16];
        let handles: Vec<_> = (0..8)
            .map(|_| {
                let v = Arc::clone(&validator);
                let expected = Arc::clone(&upstream);
                // A fresh ID per thread keeps the test valid even if `verify` rejects
                // replayed IDs.
                let good_id = make_auth_id(UUID_A);
                thread::spawn(move || {
                    assert!(v.match_auth_id(&bad_id).is_none());
                    let matched = v
                        .match_auth_id(&good_id)
                        .expect("valid auth ID should match from any thread");
                    assert!(Arc::ptr_eq(&matched, &expected));
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
    }
}