rama_proxy/proxydb/
update.rs1use super::ProxyDB;
2use arc_swap::ArcSwap;
3use rama_core::error::{BoxError, OpaqueError};
4use std::{fmt, ops::Deref, sync::Arc};
5
6pub fn proxy_db_updater<T>() -> (LiveUpdateProxyDB<T>, LiveUpdateProxyDBSetter<T>)
29where
30 T: ProxyDB<Error: Into<BoxError>>,
31{
32 let data = Arc::new(ArcSwap::from_pointee(None));
33 let reader = LiveUpdateProxyDB(data.clone());
34 let writer = LiveUpdateProxyDBSetter(data);
35 (reader, writer)
36}
37
38pub struct LiveUpdateProxyDB<T>(Arc<ArcSwap<Option<T>>>);
43
44impl<T: fmt::Debug> fmt::Debug for LiveUpdateProxyDB<T> {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 f.debug_tuple("LiveUpdateProxyDB").field(&self.0).finish()
47 }
48}
49
50impl<T> Clone for LiveUpdateProxyDB<T> {
51 fn clone(&self) -> Self {
52 Self(self.0.clone())
53 }
54}
55
56impl<T> ProxyDB for LiveUpdateProxyDB<T>
57where
58 T: ProxyDB<Error: Into<BoxError>>,
59{
60 type Error = BoxError;
61
62 async fn get_proxy_if(
63 &self,
64 ctx: super::ProxyContext,
65 filter: super::ProxyFilter,
66 predicate: impl super::ProxyQueryPredicate,
67 ) -> Result<super::Proxy, Self::Error> {
68 match self.0.load().deref().deref() {
69 Some(db) => db
70 .get_proxy_if(ctx, filter, predicate)
71 .await
72 .map_err(Into::into),
73 None => Err(OpaqueError::from_display(
74 "live proxy db: proxy db is None: get_proxy_if unable to proceed",
75 )
76 .into()),
77 }
78 }
79
80 async fn get_proxy(
81 &self,
82 ctx: super::ProxyContext,
83 filter: super::ProxyFilter,
84 ) -> Result<super::Proxy, Self::Error> {
85 match self.0.load().deref().deref() {
86 Some(db) => db.get_proxy(ctx, filter).await.map_err(Into::into),
87 None => Err(OpaqueError::from_display(
88 "live proxy db: proxy db is None: get_proxy unable to proceed",
89 )
90 .into()),
91 }
92 }
93}
94
95pub struct LiveUpdateProxyDBSetter<T>(Arc<ArcSwap<Option<T>>>);
102
103impl<T> LiveUpdateProxyDBSetter<T> {
104 pub fn set(&self, db: T) {
107 self.0.store(Arc::new(Some(db)))
108 }
109}
110
111impl<T: fmt::Debug> fmt::Debug for LiveUpdateProxyDBSetter<T> {
112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 f.debug_tuple("LiveUpdateProxyDBSetter")
114 .field(&self.0)
115 .finish()
116 }
117}
118
#[cfg(test)]
mod tests {
    use crate::{Proxy, ProxyFilter, proxydb::ProxyContext};
    use rama_net::{asn::Asn, transport::TransportProtocol};
    use rama_utils::str::NonEmptyString;

    use super::*;

    /// Ask `reader` for a proxy over `protocol`, using the default filter.
    async fn lookup(
        reader: &LiveUpdateProxyDB<Proxy>,
        protocol: TransportProtocol,
    ) -> Result<Proxy, BoxError> {
        reader
            .get_proxy(ProxyContext { protocol }, ProxyFilter::default())
            .await
    }

    /// Fixture proxy with id `"id"`; it is flagged `tcp: true` and `udp: false`.
    fn sample_proxy() -> Proxy {
        Proxy {
            id: NonEmptyString::from_static("id"),
            address: "authority".parse().unwrap(),
            tcp: true,
            udp: false,
            http: false,
            https: true,
            socks5: false,
            socks5h: false,
            datacenter: true,
            residential: false,
            mobile: true,
            pool_id: Some("pool_id".into()),
            continent: Some("continent".into()),
            country: Some("country".into()),
            state: Some("state".into()),
            city: Some("city".into()),
            carrier: Some("carrier".into()),
            asn: Some(Asn::from_static(1)),
        }
    }

    #[tokio::test]
    async fn test_empty_live_update_db() {
        let (reader, _) = proxy_db_updater::<Proxy>();

        // Nothing was ever published, so the lookup must fail.
        let outcome = lookup(&reader, TransportProtocol::Tcp).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_live_update_db_updated() {
        let (reader, writer) = proxy_db_updater();

        // Before the first `set` the slot is empty and the lookup fails.
        assert!(lookup(&reader, TransportProtocol::Tcp).await.is_err());

        writer.set(sample_proxy());

        // Once published, the reader sees the fixture for TCP lookups.
        let found = lookup(&reader, TransportProtocol::Tcp).await.unwrap();
        assert_eq!("id", found.id);

        // A UDP lookup is expected to fail against the `udp: false` fixture.
        assert!(lookup(&reader, TransportProtocol::Udp).await.is_err());

        // A failed lookup does not disturb later successful ones.
        let found_again = lookup(&reader, TransportProtocol::Tcp).await.unwrap();
        assert_eq!("id", found_again.id);
    }
}