rama_proxy/proxydb/
update.rs

1use super::ProxyDB;
2use arc_swap::ArcSwap;
3use rama_core::error::{BoxError, OpaqueError};
4use std::{fmt, ops::Deref, sync::Arc};
5
6/// Create a new [`ProxyDB`] updater which allows you to have a (typically in-memory) [`ProxyDB`]
7/// which you can update live.
8///
9/// This construct returns a pair of:
10///
11/// - [`LiveUpdateProxyDB`]: to be used as the [`ProxyDB`] instead of the inner `T`, dubbed the "reader";
12/// - [`LiveUpdateProxyDBSetter`]: to be used as the _only_ way to set the inner `T` as many time as you wish, dubbed the "writer".
13///
14/// Note that the inner `T` is not yet created when this construct returns this pair.
15/// Until you actually called [`LiveUpdateProxyDBSetter::set`] with the inner `T` [`ProxyDB`],
16/// any [`ProxyDB`] trait method call to [`LiveUpdateProxyDB`] will fail.
17///
18/// It is therefore recommended that you immediately set the inner `T` [`ProxyDB`] upon
19/// receiving the reader/writer pair, prior to starting to actually use the [`ProxyDB`]
20/// in your rama service stack.
21///
22/// This goal of this updater is to be fast for reading (getting proxies),
23/// and slow for the infrequent updates (setting the proxy db). As such it is recommended
24/// to not update the [`ProxyDB`] to frequent. An example use case for this updater
25/// could be to update your in-memory proxy database every 15 minutes, by populating it from
26/// a shared external database (e.g. MySQL`). Failures to create a new `T` ProxyDB should be handled
27/// by the Writer, and can be as simple as just logging it and move on without an update.
28pub fn proxy_db_updater<T>() -> (LiveUpdateProxyDB<T>, LiveUpdateProxyDBSetter<T>)
29where
30    T: ProxyDB<Error: Into<BoxError>>,
31{
32    let data = Arc::new(ArcSwap::from_pointee(None));
33    let reader = LiveUpdateProxyDB(data.clone());
34    let writer = LiveUpdateProxyDBSetter(data);
35    (reader, writer)
36}
37
38/// A wrapper around a `T` [`ProxyDB`] which can be updated
39/// through the _only_ linked writer [`LiveUpdateProxyDBSetter`].
40///
41/// See [`proxy_db_updater`] for more details.
42pub struct LiveUpdateProxyDB<T>(Arc<ArcSwap<Option<T>>>);
43
44impl<T: fmt::Debug> fmt::Debug for LiveUpdateProxyDB<T> {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        f.debug_tuple("LiveUpdateProxyDB").field(&self.0).finish()
47    }
48}
49
50impl<T> Clone for LiveUpdateProxyDB<T> {
51    fn clone(&self) -> Self {
52        Self(self.0.clone())
53    }
54}
55
56impl<T> ProxyDB for LiveUpdateProxyDB<T>
57where
58    T: ProxyDB<Error: Into<BoxError>>,
59{
60    type Error = BoxError;
61
62    async fn get_proxy_if(
63        &self,
64        ctx: rama_net::transport::TransportContext,
65        filter: super::ProxyFilter,
66        predicate: impl super::ProxyQueryPredicate,
67    ) -> Result<super::Proxy, Self::Error> {
68        match self.0.load().deref().deref() {
69            Some(db) => db
70                .get_proxy_if(ctx, filter, predicate)
71                .await
72                .map_err(Into::into),
73            None => Err(OpaqueError::from_display(
74                "live proxy db: proxy db is None: get_proxy_if unable to proceed",
75            )
76            .into()),
77        }
78    }
79
80    async fn get_proxy(
81        &self,
82        ctx: rama_net::transport::TransportContext,
83        filter: super::ProxyFilter,
84    ) -> Result<super::Proxy, Self::Error> {
85        match self.0.load().deref().deref() {
86            Some(db) => db.get_proxy(ctx, filter).await.map_err(Into::into),
87            None => Err(OpaqueError::from_display(
88                "live proxy db: proxy db is None: get_proxy unable to proceed",
89            )
90            .into()),
91        }
92    }
93}
94
95/// Writer to set a new [`ProxyDB`] in the linked [`LiveUpdateProxyDB`].
96///
97/// There can only be one writer [`LiveUpdateProxyDBSetter`] for each
98/// collection of [`LiveUpdateProxyDB`] linked to the same internal data `T`.
99///
100/// See [`proxy_db_updater`] for more details.
101pub struct LiveUpdateProxyDBSetter<T>(Arc<ArcSwap<Option<T>>>);
102
103impl<T> LiveUpdateProxyDBSetter<T> {
104    /// Set the new `T` [`ProxyDB`] to be used for future [`ProxyDB`]
105    /// calls made to the linked [`LiveUpdateProxyDB`] instances.
106    pub fn set(&self, db: T) {
107        self.0.store(Arc::new(Some(db)))
108    }
109}
110
111impl<T: fmt::Debug> fmt::Debug for LiveUpdateProxyDBSetter<T> {
112    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113        f.debug_tuple("LiveUpdateProxyDBSetter")
114            .field(&self.0)
115            .finish()
116    }
117}
118
#[cfg(test)]
mod tests {
    use crate::{Proxy, ProxyFilter};
    use rama_net::{
        asn::Asn,
        transport::{TransportContext, TransportProtocol},
    };
    use rama_utils::str::NonEmptyString;

    use super::*;

    /// Transport context shared by all tests; only the protocol varies.
    fn ctx(protocol: TransportProtocol) -> TransportContext {
        TransportContext {
            protocol,
            app_protocol: None,
            http_version: None,
            authority: "proxy.example.com:1080".parse().unwrap(),
        }
    }

    /// A TCP-capable, non-UDP proxy with the given `id`, used as the inner db.
    fn test_proxy(id: &'static str) -> Proxy {
        Proxy {
            id: NonEmptyString::from_static(id),
            address: "authority".parse().unwrap(),
            tcp: true,
            udp: false,
            http: false,
            https: true,
            socks5: false,
            socks5h: false,
            datacenter: true,
            residential: false,
            mobile: true,
            pool_id: Some("pool_id".into()),
            continent: Some("continent".into()),
            country: Some("country".into()),
            state: Some("state".into()),
            city: Some("city".into()),
            carrier: Some("carrier".into()),
            asn: Some(Asn::from_static(1)),
        }
    }

    #[tokio::test]
    async fn test_empty_live_update_db() {
        let (reader, _) = proxy_db_updater::<Proxy>();
        assert!(reader
            .get_proxy(ctx(TransportProtocol::Tcp), ProxyFilter::default())
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_live_update_db_updated() {
        let (reader, writer) = proxy_db_updater();

        assert!(reader
            .get_proxy(ctx(TransportProtocol::Tcp), ProxyFilter::default())
            .await
            .is_err());

        writer.set(test_proxy("id"));

        assert_eq!(
            "id",
            reader
                .get_proxy(ctx(TransportProtocol::Tcp), ProxyFilter::default())
                .await
                .unwrap()
                .id
        );

        // The inner proxy does not support UDP, so this must still fail.
        assert!(reader
            .get_proxy(ctx(TransportProtocol::Udp), ProxyFilter::default())
            .await
            .is_err());

        assert_eq!(
            "id",
            reader
                .get_proxy(ctx(TransportProtocol::Tcp), ProxyFilter::default())
                .await
                .unwrap()
                .id
        );
    }

    /// Replacing an already-set inner db must be visible to every reader,
    /// including clones created before any db was set.
    #[tokio::test]
    async fn test_live_update_db_replaced_and_shared_by_clones() {
        let (reader, writer) = proxy_db_updater::<Proxy>();
        let cloned_reader = reader.clone();

        writer.set(test_proxy("first"));
        assert_eq!(
            "first",
            reader
                .get_proxy(ctx(TransportProtocol::Tcp), ProxyFilter::default())
                .await
                .unwrap()
                .id
        );
        assert_eq!(
            "first",
            cloned_reader
                .get_proxy(ctx(TransportProtocol::Tcp), ProxyFilter::default())
                .await
                .unwrap()
                .id
        );

        writer.set(test_proxy("second"));
        assert_eq!(
            "second",
            reader
                .get_proxy(ctx(TransportProtocol::Tcp), ProxyFilter::default())
                .await
                .unwrap()
                .id
        );
        assert_eq!(
            "second",
            cloned_reader
                .get_proxy(ctx(TransportProtocol::Tcp), ProxyFilter::default())
                .await
                .unwrap()
                .id
        );
    }
}