rama_proxy/proxydb/
update.rs1use super::ProxyDB;
2use arc_swap::ArcSwap;
3use rama_core::error::{BoxError, OpaqueError};
4use std::{fmt, ops::Deref, sync::Arc};
5
6pub fn proxy_db_updater<T>() -> (LiveUpdateProxyDB<T>, LiveUpdateProxyDBSetter<T>)
29where
30 T: ProxyDB<Error: Into<BoxError>>,
31{
32 let data = Arc::new(ArcSwap::from_pointee(None));
33 let reader = LiveUpdateProxyDB(data.clone());
34 let writer = LiveUpdateProxyDBSetter(data);
35 (reader, writer)
36}
37
38pub struct LiveUpdateProxyDB<T>(Arc<ArcSwap<Option<T>>>);
43
44impl<T: fmt::Debug> fmt::Debug for LiveUpdateProxyDB<T> {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 f.debug_tuple("LiveUpdateProxyDB").field(&self.0).finish()
47 }
48}
49
50impl<T> Clone for LiveUpdateProxyDB<T> {
51 fn clone(&self) -> Self {
52 Self(self.0.clone())
53 }
54}
55
56impl<T> ProxyDB for LiveUpdateProxyDB<T>
57where
58 T: ProxyDB<Error: Into<BoxError>>,
59{
60 type Error = BoxError;
61
62 async fn get_proxy_if(
63 &self,
64 ctx: rama_net::transport::TransportContext,
65 filter: super::ProxyFilter,
66 predicate: impl super::ProxyQueryPredicate,
67 ) -> Result<super::Proxy, Self::Error> {
68 match self.0.load().deref().deref() {
69 Some(db) => db
70 .get_proxy_if(ctx, filter, predicate)
71 .await
72 .map_err(Into::into),
73 None => Err(OpaqueError::from_display(
74 "live proxy db: proxy db is None: get_proxy_if unable to proceed",
75 )
76 .into()),
77 }
78 }
79
80 async fn get_proxy(
81 &self,
82 ctx: rama_net::transport::TransportContext,
83 filter: super::ProxyFilter,
84 ) -> Result<super::Proxy, Self::Error> {
85 match self.0.load().deref().deref() {
86 Some(db) => db.get_proxy(ctx, filter).await.map_err(Into::into),
87 None => Err(OpaqueError::from_display(
88 "live proxy db: proxy db is None: get_proxy unable to proceed",
89 )
90 .into()),
91 }
92 }
93}
94
95pub struct LiveUpdateProxyDBSetter<T>(Arc<ArcSwap<Option<T>>>);
102
103impl<T> LiveUpdateProxyDBSetter<T> {
104 pub fn set(&self, db: T) {
107 self.0.store(Arc::new(Some(db)))
108 }
109}
110
111impl<T: fmt::Debug> fmt::Debug for LiveUpdateProxyDBSetter<T> {
112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 f.debug_tuple("LiveUpdateProxyDBSetter")
114 .field(&self.0)
115 .finish()
116 }
117}
118
#[cfg(test)]
mod tests {
    use crate::{Proxy, ProxyFilter};
    use rama_net::{
        asn::Asn,
        transport::{TransportContext, TransportProtocol},
    };
    use rama_utils::str::NonEmptyString;

    use super::*;

    /// Transport context used by every lookup in this module; only the
    /// transport protocol varies between calls.
    fn ctx_for(protocol: TransportProtocol) -> TransportContext {
        TransportContext {
            protocol,
            app_protocol: None,
            http_version: None,
            authority: "proxy.example.com:1080".parse().unwrap(),
        }
    }

    /// The single proxy installed by `test_live_update_db_updated`.
    fn sample_proxy() -> Proxy {
        Proxy {
            id: NonEmptyString::from_static("id"),
            address: "authority".parse().unwrap(),
            tcp: true,
            udp: false,
            http: false,
            https: true,
            socks5: false,
            socks5h: false,
            datacenter: true,
            residential: false,
            mobile: true,
            pool_id: Some("pool_id".into()),
            continent: Some("continent".into()),
            country: Some("country".into()),
            state: Some("state".into()),
            city: Some("city".into()),
            carrier: Some("carrier".into()),
            asn: Some(Asn::from_static(1)),
        }
    }

    #[tokio::test]
    async fn test_empty_live_update_db() {
        // The setter is dropped immediately, so nothing is ever installed.
        let (reader, _) = proxy_db_updater::<Proxy>();

        let lookup = reader
            .get_proxy(ctx_for(TransportProtocol::Tcp), ProxyFilter::default())
            .await;
        assert!(lookup.is_err());
    }

    #[tokio::test]
    async fn test_live_update_db_updated() {
        let (reader, writer) = proxy_db_updater::<Proxy>();

        // Before the first `set`, lookups fail.
        let before_set = reader
            .get_proxy(ctx_for(TransportProtocol::Tcp), ProxyFilter::default())
            .await;
        assert!(before_set.is_err());

        writer.set(sample_proxy());

        // A TCP lookup now resolves to the installed proxy.
        let tcp_hit = reader
            .get_proxy(ctx_for(TransportProtocol::Tcp), ProxyFilter::default())
            .await
            .unwrap();
        assert_eq!("id", tcp_hit.id);

        // A UDP lookup is expected to fail (the installed proxy has `udp: false`).
        let udp_lookup = reader
            .get_proxy(ctx_for(TransportProtocol::Udp), ProxyFilter::default())
            .await;
        assert!(udp_lookup.is_err());

        // The earlier failed lookup does not disturb subsequent TCP lookups.
        let tcp_again = reader
            .get_proxy(ctx_for(TransportProtocol::Tcp), ProxyFilter::default())
            .await
            .unwrap();
        assert_eq!("id", tcp_again.id);
    }
}