use std::collections::HashMap;
use std::net::SocketAddr;
use std::str::FromStr as _;
use async_trait::async_trait;
use uuid::Uuid;
use crate::errors::TranslationError;
/// Information about a peer that is handed to an [`AddressTranslator`] so that its
/// address can be translated.
///
/// The optional string fields borrow data for the lifetime `'a`.
#[derive(Debug)]
pub struct UntranslatedPeer<'a> {
    /// Host ID of the peer.
    pub(crate) host_id: Uuid,
    /// Address of the peer before any translation is applied.
    pub(crate) untranslated_address: SocketAddr,
    /// Datacenter name of the peer, or `None` if not set.
    pub(crate) datacenter: Option<&'a str>,
    /// Rack name of the peer, or `None` if not set.
    pub(crate) rack: Option<&'a str>,
}
impl<'a> UntranslatedPeer<'a> {
    /// Returns the host ID of this peer.
    #[inline]
    pub fn host_id(&self) -> Uuid {
        self.host_id
    }

    /// Returns the address of this peer before any translation is applied.
    #[inline]
    pub fn untranslated_address(&self) -> SocketAddr {
        self.untranslated_address
    }

    /// Returns the datacenter name of this peer, or `None` if it is not set.
    ///
    /// The returned string borrows for `'a` (the lifetime of the data this peer
    /// was built from), not just for the borrow of `self`. It may therefore be
    /// kept after the `UntranslatedPeer` itself is no longer borrowed.
    #[inline]
    pub fn datacenter(&self) -> Option<&'a str> {
        self.datacenter
    }

    /// Returns the rack name of this peer, or `None` if it is not set.
    ///
    /// Like [`Self::datacenter`], the returned string borrows for `'a` rather
    /// than for the borrow of `self`.
    #[inline]
    pub fn rack(&self) -> Option<&'a str> {
        self.rack
    }

    /// Creates a peer with the given untranslated `address`, a nil host ID, and
    /// no datacenter or rack.
    #[cfg(all(scylla_unstable, feature = "unstable-nodejs-rs"))]
    pub fn new(address: SocketAddr) -> Self {
        UntranslatedPeer {
            host_id: Uuid::nil(),
            untranslated_address: address,
            datacenter: None,
            rack: None,
        }
    }
}
/// Maps an [`UntranslatedPeer`] to a translated [`SocketAddr`].
///
/// Implementors must be `Send + Sync`. This module provides implementations for
/// `HashMap<SocketAddr, SocketAddr>` and `HashMap<&'static str, &'static str>`,
/// which treat the map as a set of "untranslated address -> translated address" rules.
#[async_trait]
pub trait AddressTranslator: Send + Sync {
    /// Returns the translated address for `untranslated_peer`.
    ///
    /// # Errors
    ///
    /// Returns a [`TranslationError`] when the implementation cannot produce a
    /// translated address for the peer.
    async fn translate_address(
        &self,
        untranslated_peer: &UntranslatedPeer,
    ) -> Result<SocketAddr, TranslationError>;
}
#[async_trait]
impl AddressTranslator for HashMap<SocketAddr, SocketAddr> {
    /// Uses the peer's untranslated address as a key into the map and returns
    /// the associated value.
    ///
    /// # Errors
    ///
    /// Returns [`TranslationError::NoRuleForAddress`] (carrying the untranslated
    /// address) when the map has no entry for that address.
    async fn translate_address(
        &self,
        untranslated_peer: &UntranslatedPeer,
    ) -> Result<SocketAddr, TranslationError> {
        let source_addr = untranslated_peer.untranslated_address();
        self.get(&source_addr)
            .copied()
            .ok_or_else(|| TranslationError::NoRuleForAddress(source_addr))
    }
}
#[async_trait]
impl AddressTranslator for HashMap<&'static str, &'static str> {
    /// Scans the rules, in the map's iteration order, for the first key that
    /// parses to the peer's untranslated address. It then parses that rule's value
    /// as the translated address.
    ///
    /// Keys that do not parse as a [`SocketAddr`] are skipped.
    ///
    /// # Errors
    ///
    /// - [`TranslationError::InvalidAddressInRule`] if the matching rule's value
    ///   does not parse as a [`SocketAddr`].
    /// - [`TranslationError::NoRuleForAddress`] if no key matches.
    async fn translate_address(
        &self,
        untranslated_peer: &UntranslatedPeer,
    ) -> Result<SocketAddr, TranslationError> {
        let wanted = untranslated_peer.untranslated_address();

        // Keys are compared after parsing, so differently spelled forms of the
        // same address (e.g. IPv6 shorthand) still match.
        let matching_rule = self.iter().find(|&(&rule_addr_str, _)| {
            SocketAddr::from_str(rule_addr_str).is_ok_and(|rule_addr| rule_addr == wanted)
        });

        let Some((_, &translated_addr_str)) = matching_rule else {
            return Err(TranslationError::NoRuleForAddress(wanted));
        };

        SocketAddr::from_str(translated_addr_str).map_err(|reason| {
            TranslationError::InvalidAddressInRule {
                translated_addr_str,
                reason,
            }
        })
    }
}