rama_proxy/proxydb/
update.rs

1use super::ProxyDB;
2use arc_swap::ArcSwap;
3use rama_core::error::{BoxError, OpaqueError};
4use std::{fmt, ops::Deref, sync::Arc};
5
6/// Create a new [`ProxyDB`] updater which allows you to have a (typically in-memory) [`ProxyDB`]
7/// which you can update live.
8///
9/// This construct returns a pair of:
10///
11/// - [`LiveUpdateProxyDB`]: to be used as the [`ProxyDB`] instead of the inner `T`, dubbed the "reader";
12/// - [`LiveUpdateProxyDBSetter`]: to be used as the _only_ way to set the inner `T` as many time as you wish, dubbed the "writer".
13///
14/// Note that the inner `T` is not yet created when this construct returns this pair.
15/// Until you actually called [`LiveUpdateProxyDBSetter::set`] with the inner `T` [`ProxyDB`],
16/// any [`ProxyDB`] trait method call to [`LiveUpdateProxyDB`] will fail.
17///
18/// It is therefore recommended that you immediately set the inner `T` [`ProxyDB`] upon
19/// receiving the reader/writer pair, prior to starting to actually use the [`ProxyDB`]
20/// in your rama service stack.
21///
22/// This goal of this updater is to be fast for reading (getting proxies),
23/// and slow for the infrequent updates (setting the proxy db). As such it is recommended
24/// to not update the [`ProxyDB`] to frequent. An example use case for this updater
25/// could be to update your in-memory proxy database every 15 minutes, by populating it from
26/// a shared external database (e.g. MySQL`). Failures to create a new `T` ProxyDB should be handled
27/// by the Writer, and can be as simple as just logging it and move on without an update.
28pub fn proxy_db_updater<T>() -> (LiveUpdateProxyDB<T>, LiveUpdateProxyDBSetter<T>)
29where
30    T: ProxyDB<Error: Into<BoxError>>,
31{
32    let data = Arc::new(ArcSwap::from_pointee(None));
33    let reader = LiveUpdateProxyDB(data.clone());
34    let writer = LiveUpdateProxyDBSetter(data);
35    (reader, writer)
36}
37
38/// A wrapper around a `T` [`ProxyDB`] which can be updated
39/// through the _only_ linked writer [`LiveUpdateProxyDBSetter`].
40///
41/// See [`proxy_db_updater`] for more details.
42pub struct LiveUpdateProxyDB<T>(Arc<ArcSwap<Option<T>>>);
43
44impl<T: fmt::Debug> fmt::Debug for LiveUpdateProxyDB<T> {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        f.debug_tuple("LiveUpdateProxyDB").field(&self.0).finish()
47    }
48}
49
50impl<T> Clone for LiveUpdateProxyDB<T> {
51    fn clone(&self) -> Self {
52        Self(self.0.clone())
53    }
54}
55
56impl<T> ProxyDB for LiveUpdateProxyDB<T>
57where
58    T: ProxyDB<Error: Into<BoxError>>,
59{
60    type Error = BoxError;
61
62    async fn get_proxy_if(
63        &self,
64        ctx: super::ProxyContext,
65        filter: super::ProxyFilter,
66        predicate: impl super::ProxyQueryPredicate,
67    ) -> Result<super::Proxy, Self::Error> {
68        match self.0.load().deref().deref() {
69            Some(db) => db
70                .get_proxy_if(ctx, filter, predicate)
71                .await
72                .map_err(Into::into),
73            None => Err(OpaqueError::from_display(
74                "live proxy db: proxy db is None: get_proxy_if unable to proceed",
75            )
76            .into()),
77        }
78    }
79
80    async fn get_proxy(
81        &self,
82        ctx: super::ProxyContext,
83        filter: super::ProxyFilter,
84    ) -> Result<super::Proxy, Self::Error> {
85        match self.0.load().deref().deref() {
86            Some(db) => db.get_proxy(ctx, filter).await.map_err(Into::into),
87            None => Err(OpaqueError::from_display(
88                "live proxy db: proxy db is None: get_proxy unable to proceed",
89            )
90            .into()),
91        }
92    }
93}
94
95/// Writer to set a new [`ProxyDB`] in the linked [`LiveUpdateProxyDB`].
96///
97/// There can only be one writer [`LiveUpdateProxyDBSetter`] for each
98/// collection of [`LiveUpdateProxyDB`] linked to the same internal data `T`.
99///
100/// See [`proxy_db_updater`] for more details.
101pub struct LiveUpdateProxyDBSetter<T>(Arc<ArcSwap<Option<T>>>);
102
103impl<T> LiveUpdateProxyDBSetter<T> {
104    /// Set the new `T` [`ProxyDB`] to be used for future [`ProxyDB`]
105    /// calls made to the linked [`LiveUpdateProxyDB`] instances.
106    pub fn set(&self, db: T) {
107        self.0.store(Arc::new(Some(db)))
108    }
109}
110
111impl<T: fmt::Debug> fmt::Debug for LiveUpdateProxyDBSetter<T> {
112    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113        f.debug_tuple("LiveUpdateProxyDBSetter")
114            .field(&self.0)
115            .finish()
116    }
117}
118
#[cfg(test)]
mod tests {
    use crate::{Proxy, ProxyFilter, proxydb::ProxyContext};
    use rama_net::{asn::Asn, transport::TransportProtocol};
    use rama_utils::str::NonEmptyString;

    use super::*;

    /// Lookup context for the given transport protocol.
    fn ctx(protocol: TransportProtocol) -> ProxyContext {
        ProxyContext { protocol }
    }

    /// A fully-populated proxy with id `"id"`, `tcp: true` and `udp: false`.
    fn tcp_only_proxy() -> Proxy {
        Proxy {
            id: NonEmptyString::from_static("id"),
            address: "authority".parse().unwrap(),
            tcp: true,
            udp: false,
            http: false,
            https: true,
            socks5: false,
            socks5h: false,
            datacenter: true,
            residential: false,
            mobile: true,
            pool_id: Some("pool_id".into()),
            continent: Some("continent".into()),
            country: Some("country".into()),
            state: Some("state".into()),
            city: Some("city".into()),
            carrier: Some("carrier".into()),
            asn: Some(Asn::from_static(1)),
        }
    }

    #[tokio::test]
    async fn test_empty_live_update_db() {
        let (reader, _) = proxy_db_updater::<Proxy>();
        let outcome = reader
            .get_proxy(ctx(TransportProtocol::Tcp), ProxyFilter::default())
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_live_update_db_updated() {
        let (reader, writer) = proxy_db_updater();

        // Nothing has been set yet, so the lookup must fail.
        let unset = reader
            .get_proxy(ctx(TransportProtocol::Tcp), ProxyFilter::default())
            .await;
        assert!(unset.is_err());

        writer.set(tcp_only_proxy());

        let tcp = reader
            .get_proxy(ctx(TransportProtocol::Tcp), ProxyFilter::default())
            .await
            .unwrap();
        assert_eq!("id", tcp.id);

        // The stored proxy has `udp: false`; a UDP lookup is expected to fail.
        let udp = reader
            .get_proxy(ctx(TransportProtocol::Udp), ProxyFilter::default())
            .await;
        assert!(udp.is_err());

        // The db stays usable after a failed lookup.
        let tcp_again = reader
            .get_proxy(ctx(TransportProtocol::Tcp), ProxyFilter::default())
            .await
            .unwrap();
        assert_eq!("id", tcp_again.id);
    }
}