use super::ProxyDB;
use arc_swap::ArcSwap;
use rama_core::error::{BoxError, OpaqueError};
use std::{fmt, ops::Deref, sync::Arc};
pub fn proxy_db_updater<T>() -> (LiveUpdateProxyDB<T>, LiveUpdateProxyDBSetter<T>)
where
T: ProxyDB<Error: Into<BoxError>>,
{
let data = Arc::new(ArcSwap::from_pointee(None));
let reader = LiveUpdateProxyDB(data.clone());
let writer = LiveUpdateProxyDBSetter(data);
(reader, writer)
}
/// Read half of a live-updatable [`ProxyDB`], created by [`proxy_db_updater`].
///
/// Every clone of this handle shares the same slot. A database stored through
/// the paired [`LiveUpdateProxyDBSetter`] is therefore seen by all clones.
pub struct LiveUpdateProxyDB<T>(Arc<ArcSwap<Option<T>>>);

impl<T: fmt::Debug> fmt::Debug for LiveUpdateProxyDB<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("LiveUpdateProxyDB");
        tuple.field(&self.0);
        tuple.finish()
    }
}

// Implemented by hand: `#[derive(Clone)]` would add an unwanted `T: Clone`
// bound, while cloning only needs to bump the shared `Arc`'s refcount.
impl<T> Clone for LiveUpdateProxyDB<T> {
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.0);
        Self(shared)
    }
}
/// Forwards every [`ProxyDB`] query to the database currently stored in the
/// shared slot.
///
/// # Errors
///
/// - Returns an [`OpaqueError`] (boxed as [`BoxError`]) while no database has
///   been stored yet, i.e. while the slot still holds `None`.
/// - Otherwise returns whatever error the inner database produces, converted
///   into [`BoxError`].
impl<T> ProxyDB for LiveUpdateProxyDB<T>
where
    T: ProxyDB<Error: Into<BoxError>>,
{
    type Error = BoxError;

    async fn get_proxy_if(
        &self,
        ctx: super::ProxyContext,
        filter: super::ProxyFilter,
        predicate: impl super::ProxyQueryPredicate,
    ) -> Result<super::Proxy, Self::Error> {
        // Take an owned `Arc` snapshot (`load_full`) rather than an arc-swap
        // `Guard` (`load`). The borrow of the inner db lives across the
        // `.await` below, and guards are meant for short-lived local access:
        // holding them for long ties up arc-swap's per-thread fast slots.
        let snapshot = self.0.load_full();
        let Some(db) = snapshot.deref() else {
            return Err(OpaqueError::from_display(
                "live proxy db: proxy db is None: get_proxy_if unable to proceed",
            )
            .into());
        };
        db.get_proxy_if(ctx, filter, predicate)
            .await
            .map_err(Into::into)
    }

    async fn get_proxy(
        &self,
        ctx: super::ProxyContext,
        filter: super::ProxyFilter,
    ) -> Result<super::Proxy, Self::Error> {
        // Same reasoning as in `get_proxy_if`: an owned snapshot is safe to
        // keep across the `.await`.
        let snapshot = self.0.load_full();
        let Some(db) = snapshot.deref() else {
            return Err(OpaqueError::from_display(
                "live proxy db: proxy db is None: get_proxy unable to proceed",
            )
            .into());
        };
        db.get_proxy(ctx, filter).await.map_err(Into::into)
    }
}
/// Write half of a live-updatable [`ProxyDB`], created by [`proxy_db_updater`].
///
/// Storing a database here makes it the target of every query issued through
/// the paired [`LiveUpdateProxyDB`] readers.
pub struct LiveUpdateProxyDBSetter<T>(Arc<ArcSwap<Option<T>>>);

impl<T> LiveUpdateProxyDBSetter<T> {
    /// Replace the shared database with `db`.
    ///
    /// Any previously stored database is swapped out. Subsequent lookups
    /// through a paired reader use `db`.
    pub fn set(&self, db: T) {
        let next = Arc::new(Some(db));
        self.0.store(next);
    }
}

impl<T: fmt::Debug> fmt::Debug for LiveUpdateProxyDBSetter<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("LiveUpdateProxyDBSetter");
        tuple.field(&self.0);
        tuple.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Proxy, ProxyFilter, proxydb::ProxyContext};
    use rama_net::{asn::Asn, transport::TransportProtocol};
    use rama_utils::str::NonEmptyString;

    /// Issue a `get_proxy` query through `reader` for `protocol`, using the
    /// default `ProxyFilter`.
    async fn lookup(
        reader: &LiveUpdateProxyDB<Proxy>,
        protocol: TransportProtocol,
    ) -> Result<Proxy, BoxError> {
        reader
            .get_proxy(ProxyContext { protocol }, ProxyFilter::default())
            .await
    }

    /// Fixture proxy: TCP-only (`udp: false`) with id `"id"`.
    fn sample_proxy() -> Proxy {
        Proxy {
            id: NonEmptyString::from_static("id"),
            address: "authority".parse().unwrap(),
            tcp: true,
            udp: false,
            http: false,
            https: true,
            socks5: false,
            socks5h: false,
            datacenter: true,
            residential: false,
            mobile: true,
            pool_id: Some("pool_id".into()),
            continent: Some("continent".into()),
            country: Some("country".into()),
            state: Some("state".into()),
            city: Some("city".into()),
            carrier: Some("carrier".into()),
            asn: Some(Asn::from_static(1)),
        }
    }

    #[tokio::test]
    async fn test_empty_live_update_db() {
        let (reader, _) = proxy_db_updater::<Proxy>();
        // Nothing has been stored, so every query must fail.
        assert!(lookup(&reader, TransportProtocol::Tcp).await.is_err());
    }

    #[tokio::test]
    async fn test_live_update_db_updated() {
        let (reader, writer) = proxy_db_updater();

        // Before the first `set`, the reader has nothing to forward to.
        assert!(lookup(&reader, TransportProtocol::Tcp).await.is_err());

        writer.set(sample_proxy());

        assert_eq!(
            "id",
            lookup(&reader, TransportProtocol::Tcp).await.unwrap().id
        );
        // The fixture is TCP-only; a UDP query is expected to fail.
        assert!(lookup(&reader, TransportProtocol::Udp).await.is_err());
        // A failed query must leave the stored database usable.
        assert_eq!(
            "id",
            lookup(&reader, TransportProtocol::Tcp).await.unwrap().id
        );
    }
}