rama_proxy/proxydb/
mod.rs

1use rama_core::error::{BoxError, ErrorContext, OpaqueError};
2use rama_net::{asn::Asn, transport::TransportContext};
3use rama_utils::str::NonEmptyString;
4use serde::{Deserialize, Serialize};
5use std::{fmt, future::Future};
6
7#[cfg(feature = "live-update")]
8mod update;
9#[cfg(feature = "live-update")]
10#[doc(inline)]
11pub use update::{proxy_db_updater, LiveUpdateProxyDB, LiveUpdateProxyDBSetter};
12
13mod internal;
14#[doc(inline)]
15pub use internal::Proxy;
16
17#[cfg(feature = "csv")]
18mod csv;
19
20#[cfg(feature = "csv")]
21#[doc(inline)]
22pub use csv::{ProxyCsvRowReader, ProxyCsvRowReaderError, ProxyCsvRowReaderErrorKind};
23
24pub(super) mod layer;
25
26mod str;
27#[doc(inline)]
28pub use str::StringFilter;
29
30#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
31/// `ID` of the selected proxy. To be inserted into the `Context`,
32/// only if that proxy is selected.
33pub struct ProxyID(NonEmptyString);
34
35impl ProxyID {
36    /// View  this [`ProxyID`] as a `str`.
37    pub fn as_str(&self) -> &str {
38        self.0.as_str()
39    }
40}
41
42impl AsRef<str> for ProxyID {
43    fn as_ref(&self) -> &str {
44        self.0.as_ref()
45    }
46}
47
48impl fmt::Display for ProxyID {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        self.0.fmt(f)
51    }
52}
53
54impl From<NonEmptyString> for ProxyID {
55    fn from(value: NonEmptyString) -> Self {
56        Self(value)
57    }
58}
59
60#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
61/// Filter to select a specific kind of proxy.
62///
63/// If the `id` is specified the other fields are used
64/// as a validator to see if the only possible matching proxy
65/// matches these fields.
66///
67/// If the `id` is not specified, the other fields are used
68/// to select a random proxy from the pool.
69///
70/// Filters can be combined to make combinations with special meaning.
71/// E.g. `datacenter:true, residential:true` is essentially an ISP proxy.
72///
73/// ## Usage
74///
75/// - Use `HeaderConfigLayer` (`rama-http`) to have this proxy filter be given by the http `Request` headers,
76///   which will add the extracted and parsed [`ProxyFilter`] to the [`Context`]'s [`Extensions`].
77/// - Or extract yourself from the username/token validated in the `ProxyAuthLayer` (`rama-http`)
78///   to add it manually to the [`Context`]'s [`Extensions`].
79///
80/// [`Request`]: crate::http::Request
81/// [`Context`]: rama_core::Context
82/// [`Extensions`]: rama_core::context::Extensions
83pub struct ProxyFilter {
84    /// The ID of the proxy to select.
85    pub id: Option<NonEmptyString>,
86
87    /// The ID of the pool from which to select the proxy.
88    #[serde(alias = "pool")]
89    pub pool_id: Option<Vec<StringFilter>>,
90
91    /// The continent of the proxy.
92    pub continent: Option<Vec<StringFilter>>,
93
94    /// The country of the proxy.
95    pub country: Option<Vec<StringFilter>>,
96
97    /// The state of the proxy.
98    pub state: Option<Vec<StringFilter>>,
99
100    /// The city of the proxy.
101    pub city: Option<Vec<StringFilter>>,
102
103    /// Set explicitly to `true` to select a datacenter proxy.
104    pub datacenter: Option<bool>,
105
106    /// Set explicitly to `true` to select a residential proxy.
107    pub residential: Option<bool>,
108
109    /// Set explicitly to `true` to select a mobile proxy.
110    pub mobile: Option<bool>,
111
112    /// The mobile carrier desired.
113    pub carrier: Option<Vec<StringFilter>>,
114
115    ///  Autonomous System Number (ASN).
116    pub asn: Option<Vec<Asn>>,
117}
118
/// The trait to implement to provide a proxy database to other facilities,
/// such as connection pools, to provide a proxy based on the given
/// [`TransportContext`] and [`ProxyFilter`].
pub trait ProxyDB: Send + Sync + 'static {
    /// The error type that can be returned by the proxy database.
    ///
    /// Examples are generic I/O issues or, even more commonly,
    /// the fact that no matching proxy could be found.
    type Error: Send + Sync + 'static;

    /// Same as [`Self::get_proxy`] but with a predicate
    /// to filter out found proxies that do not match the given predicate.
    fn get_proxy_if(
        &self,
        ctx: TransportContext,
        filter: ProxyFilter,
        predicate: impl ProxyQueryPredicate,
    ) -> impl Future<Output = Result<Proxy, Self::Error>> + Send + '_;

    /// Get a [`Proxy`] based on the given [`TransportContext`] and [`ProxyFilter`],
    /// or return an error in case no [`Proxy`] could be returned.
    ///
    /// The default implementation calls [`Self::get_proxy_if`] with `true`
    /// as the predicate.
    fn get_proxy(
        &self,
        ctx: TransportContext,
        filter: ProxyFilter,
    ) -> impl Future<Output = Result<Proxy, Self::Error>> + Send + '_ {
        self.get_proxy_if(ctx, filter, true)
    }
}
148
149impl ProxyDB for () {
150    type Error = OpaqueError;
151
152    #[inline]
153    async fn get_proxy_if(
154        &self,
155        _ctx: TransportContext,
156        _filter: ProxyFilter,
157        _predicate: impl ProxyQueryPredicate,
158    ) -> Result<Proxy, Self::Error> {
159        Err(OpaqueError::from_display(
160            "()::get_proxy_if: no ProxyDB defined",
161        ))
162    }
163
164    #[inline]
165    async fn get_proxy(
166        &self,
167        _ctx: TransportContext,
168        _filter: ProxyFilter,
169    ) -> Result<Proxy, Self::Error> {
170        Err(OpaqueError::from_display(
171            "()::get_proxy: no ProxyDB defined",
172        ))
173    }
174}
175
176impl<T> ProxyDB for Option<T>
177where
178    T: ProxyDB<Error: Into<BoxError>>,
179{
180    type Error = OpaqueError;
181
182    #[inline]
183    async fn get_proxy_if(
184        &self,
185        ctx: TransportContext,
186        filter: ProxyFilter,
187        predicate: impl ProxyQueryPredicate,
188    ) -> Result<Proxy, Self::Error> {
189        match self {
190            Some(db) => db
191                .get_proxy_if(ctx, filter, predicate)
192                .await
193                .map_err(|err| OpaqueError::from_boxed(err.into()))
194                .context("Some::get_proxy_if"),
195            None => Err(OpaqueError::from_display(
196                "None::get_proxy_if: no ProxyDB defined",
197            )),
198        }
199    }
200
201    #[inline]
202    async fn get_proxy(
203        &self,
204        ctx: TransportContext,
205        filter: ProxyFilter,
206    ) -> Result<Proxy, Self::Error> {
207        match self {
208            Some(db) => db
209                .get_proxy(ctx, filter)
210                .await
211                .map_err(|err| OpaqueError::from_boxed(err.into()))
212                .context("Some::get_proxy"),
213            None => Err(OpaqueError::from_display(
214                "None::get_proxy: no ProxyDB defined",
215            )),
216        }
217    }
218}
219
220impl<T> ProxyDB for std::sync::Arc<T>
221where
222    T: ProxyDB,
223{
224    type Error = T::Error;
225
226    #[inline]
227    fn get_proxy_if(
228        &self,
229        ctx: TransportContext,
230        filter: ProxyFilter,
231        predicate: impl ProxyQueryPredicate,
232    ) -> impl Future<Output = Result<Proxy, Self::Error>> + Send + '_ {
233        (**self).get_proxy_if(ctx, filter, predicate)
234    }
235
236    #[inline]
237    fn get_proxy(
238        &self,
239        ctx: TransportContext,
240        filter: ProxyFilter,
241    ) -> impl Future<Output = Result<Proxy, Self::Error>> + Send + '_ {
242        (**self).get_proxy(ctx, filter)
243    }
244}
245
246macro_rules! impl_proxydb_either {
247    ($id:ident, $($param:ident),+ $(,)?) => {
248        impl<$($param),+> ProxyDB for rama_core::combinators::$id<$($param),+>
249        where
250            $(
251                $param: ProxyDB<Error: Into<BoxError>>,
252            )+
253    {
254        type Error = BoxError;
255
256        #[inline]
257        async fn get_proxy_if(
258            &self,
259            ctx: TransportContext,
260            filter: ProxyFilter,
261            predicate: impl ProxyQueryPredicate,
262        ) -> Result<Proxy, Self::Error> {
263            match self {
264                $(
265                    rama_core::combinators::$id::$param(s) => s.get_proxy_if(ctx, filter, predicate).await.map_err(Into::into),
266                )+
267            }
268        }
269
270        #[inline]
271        async fn get_proxy(
272            &self,
273            ctx: TransportContext,
274            filter: ProxyFilter,
275        ) -> Result<Proxy, Self::Error> {
276            match self {
277                $(
278                    rama_core::combinators::$id::$param(s) => s.get_proxy(ctx, filter).await.map_err(Into::into),
279                )+
280            }
281        }
282        }
283    };
284}
285
286rama_core::combinators::impl_either!(impl_proxydb_either);
287
/// Trait that is used by the [`ProxyDB`] for providing an optional
/// filter predicate to rule out returned results.
pub trait ProxyQueryPredicate: Clone + Send + Sync + 'static {
    /// Execute the predicate against the given candidate `proxy`.
    ///
    /// The callers in this module only accept a proxy for which this returns `true`.
    fn execute(&self, proxy: &Proxy) -> bool;
}
294
295impl ProxyQueryPredicate for bool {
296    fn execute(&self, _proxy: &Proxy) -> bool {
297        *self
298    }
299}
300
301impl<F> ProxyQueryPredicate for F
302where
303    F: Fn(&Proxy) -> bool + Clone + Send + Sync + 'static,
304{
305    fn execute(&self, proxy: &Proxy) -> bool {
306        (self)(proxy)
307    }
308}
309
310impl ProxyDB for Proxy {
311    type Error = rama_core::error::OpaqueError;
312
313    async fn get_proxy_if(
314        &self,
315        ctx: TransportContext,
316        filter: ProxyFilter,
317        predicate: impl ProxyQueryPredicate,
318    ) -> Result<Proxy, Self::Error> {
319        (self.is_match(&ctx, &filter) && predicate.execute(self))
320            .then(|| self.clone())
321            .ok_or_else(|| rama_core::error::OpaqueError::from_display("hardcoded proxy no match"))
322    }
323}
324
325#[cfg(feature = "memory-db")]
326mod memdb {
327    use super::*;
328    use crate::proxydb::internal::ProxyDBErrorKind;
329    use rama_net::transport::TransportProtocol;
330
    /// A fast in-memory ProxyDatabase that is the default choice for Rama.
    #[derive(Debug)]
    pub struct MemoryProxyDB {
        // The underlying `internal::ProxyDB` that stores and queries the proxies.
        data: internal::ProxyDB,
    }
336
337    impl MemoryProxyDB {
338        /// Create a new in-memory proxy database with the given proxies.
339        pub fn try_from_rows(proxies: Vec<Proxy>) -> Result<Self, MemoryProxyDBInsertError> {
340            Ok(MemoryProxyDB {
341                data: internal::ProxyDB::from_rows(proxies).map_err(|err| match err.kind() {
342                    ProxyDBErrorKind::DuplicateKey => {
343                        MemoryProxyDBInsertError::duplicate_key(err.into_input())
344                    }
345                    ProxyDBErrorKind::InvalidRow => {
346                        MemoryProxyDBInsertError::invalid_proxy(err.into_input())
347                    }
348                })?,
349            })
350        }
351
352        /// Create a new in-memory proxy database with the given proxies from an iterator.
353        pub fn try_from_iter<I>(proxies: I) -> Result<Self, MemoryProxyDBInsertError>
354        where
355            I: IntoIterator<Item = Proxy>,
356        {
357            Ok(MemoryProxyDB {
358                data: internal::ProxyDB::from_iter(proxies).map_err(|err| match err.kind() {
359                    ProxyDBErrorKind::DuplicateKey => {
360                        MemoryProxyDBInsertError::duplicate_key(err.into_input())
361                    }
362                    ProxyDBErrorKind::InvalidRow => {
363                        MemoryProxyDBInsertError::invalid_proxy(err.into_input())
364                    }
365                })?,
366            })
367        }
368
369        /// Return the number of proxies in the database.
370        pub fn len(&self) -> usize {
371            self.data.len()
372        }
373
374        /// Rerturns if the database is empty.
375        pub fn is_empty(&self) -> bool {
376            self.data.is_empty()
377        }
378
379        fn query_from_filter(
380            &self,
381            ctx: TransportContext,
382            filter: ProxyFilter,
383        ) -> internal::ProxyDBQuery {
384            let mut query = self.data.query();
385
386            for pool_id in filter.pool_id.into_iter().flatten() {
387                query.pool_id(pool_id);
388            }
389            for continent in filter.continent.into_iter().flatten() {
390                query.continent(continent);
391            }
392            for country in filter.country.into_iter().flatten() {
393                query.country(country);
394            }
395            for state in filter.state.into_iter().flatten() {
396                query.state(state);
397            }
398            for city in filter.city.into_iter().flatten() {
399                query.city(city);
400            }
401            for carrier in filter.carrier.into_iter().flatten() {
402                query.carrier(carrier);
403            }
404            for asn in filter.asn.into_iter().flatten() {
405                query.asn(asn);
406            }
407
408            if let Some(value) = filter.datacenter {
409                query.datacenter(value);
410            }
411            if let Some(value) = filter.residential {
412                query.residential(value);
413            }
414            if let Some(value) = filter.mobile {
415                query.mobile(value);
416            }
417
418            match ctx.protocol {
419                TransportProtocol::Tcp => {
420                    query.tcp(true);
421                }
422                TransportProtocol::Udp => {
423                    query.udp(true).socks5(true);
424                }
425            }
426
427            query
428        }
429    }
430
431    // TODO: custom query filters using ProxyQueryPredicate
432    // might be a lot faster for cases where we want to filter a big batch of proxies,
433    // in which case a bitmap could be supported by a future VennDB version...
434    //
435    // Would just need to figure out how to allow this to happen.
436
437    impl ProxyDB for MemoryProxyDB {
438        type Error = MemoryProxyDBQueryError;
439
440        async fn get_proxy_if(
441            &self,
442            ctx: TransportContext,
443            filter: ProxyFilter,
444            predicate: impl ProxyQueryPredicate,
445        ) -> Result<Proxy, Self::Error> {
446            match &filter.id {
447                Some(id) => match self.data.get_by_id(id) {
448                    None => Err(MemoryProxyDBQueryError::not_found()),
449                    Some(proxy) => {
450                        if proxy.is_match(&ctx, &filter) && predicate.execute(proxy) {
451                            Ok(proxy.clone())
452                        } else {
453                            Err(MemoryProxyDBQueryError::mismatch())
454                        }
455                    }
456                },
457                None => {
458                    let query = self.query_from_filter(ctx, filter.clone());
459                    match query
460                        .execute()
461                        .and_then(|result| result.filter(|proxy| predicate.execute(proxy)))
462                        .map(|result| result.any())
463                    {
464                        None => Err(MemoryProxyDBQueryError::not_found()),
465                        Some(proxy) => Ok(proxy.clone()),
466                    }
467                }
468            }
469        }
470    }
471
    /// The error type that can be returned by [`MemoryProxyDB`] when some of the proxies
    /// could not be inserted due to a proxy that had a duplicate key or was invalid for some other reason.
    #[derive(Debug)]
    pub struct MemoryProxyDBInsertError {
        // Why the insert failed.
        kind: MemoryProxyDBInsertErrorKind,
        // The proxies that were not inserted.
        proxies: Vec<Proxy>,
    }
479
480    impl std::fmt::Display for MemoryProxyDBInsertError {
481        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
482            match self.kind {
483                MemoryProxyDBInsertErrorKind::DuplicateKey => write!(
484                    f,
485                    "A proxy with the same key already exists in the database"
486                ),
487                MemoryProxyDBInsertErrorKind::InvalidProxy => {
488                    write!(f, "A proxy in the list is invalid for some reason")
489                }
490            }
491        }
492    }
493
    // Marks the insert error as a standard error type; no trait methods are overridden.
    impl std::error::Error for MemoryProxyDBInsertError {}
495
496    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
497    /// The kind of error that [`MemoryProxyDBInsertError`] represents.
498    pub enum MemoryProxyDBInsertErrorKind {
499        /// Duplicate key found in the proxies.
500        DuplicateKey,
501        /// Invalid proxy found in the proxies.
502        ///
503        /// This could be due to a proxy that is not valid for some reason.
504        /// E.g. a proxy that neither supports http or socks5.
505        InvalidProxy,
506    }
507
508    impl MemoryProxyDBInsertError {
509        fn duplicate_key(proxies: Vec<Proxy>) -> Self {
510            MemoryProxyDBInsertError {
511                kind: MemoryProxyDBInsertErrorKind::DuplicateKey,
512                proxies,
513            }
514        }
515
516        fn invalid_proxy(proxies: Vec<Proxy>) -> Self {
517            MemoryProxyDBInsertError {
518                kind: MemoryProxyDBInsertErrorKind::InvalidProxy,
519                proxies,
520            }
521        }
522
523        /// Returns the kind of error that [`MemoryProxyDBInsertError`] represents.
524        pub fn kind(&self) -> MemoryProxyDBInsertErrorKind {
525            self.kind
526        }
527
528        /// Returns the proxies that were not inserted.
529        pub fn proxies(&self) -> &[Proxy] {
530            &self.proxies
531        }
532
533        /// Consumes the error and returns the proxies that were not inserted.
534        pub fn into_proxies(self) -> Vec<Proxy> {
535            self.proxies
536        }
537    }
538
    /// The error type that can be returned by [`MemoryProxyDB`] when no proxy could be returned.
    #[derive(Debug)]
    pub struct MemoryProxyDBQueryError {
        // Why no proxy could be returned.
        kind: MemoryProxyDBQueryErrorKind,
    }
544
545    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
546    /// The kind of error that [`MemoryProxyDBQueryError`] represents.
547    pub enum MemoryProxyDBQueryErrorKind {
548        /// No proxy match could be found.
549        NotFound,
550        /// A proxy looked up by key had a config that did not match the given filters/requirements.
551        Mismatch,
552    }
553
554    impl std::fmt::Display for MemoryProxyDBQueryError {
555        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
556            match self.kind {
557                MemoryProxyDBQueryErrorKind::NotFound => write!(f, "No proxy match could be found"),
558                MemoryProxyDBQueryErrorKind::Mismatch => write!(
559                    f,
560                    "Proxy config did not match the given filters/requirements"
561                ),
562            }
563        }
564    }
565
    // Marks the query error as a standard error type; no trait methods are overridden.
    impl std::error::Error for MemoryProxyDBQueryError {}
567
568    impl MemoryProxyDBQueryError {
569        /// Create a new error that indicates no proxy match could be found.
570        pub fn not_found() -> Self {
571            MemoryProxyDBQueryError {
572                kind: MemoryProxyDBQueryErrorKind::NotFound,
573            }
574        }
575
576        /// Create a new error that indicates a proxy looked up by key had a config that did not match the given filters/requirements.
577        pub fn mismatch() -> Self {
578            MemoryProxyDBQueryError {
579                kind: MemoryProxyDBQueryErrorKind::Mismatch,
580            }
581        }
582
583        /// Returns the kind of error that [`MemoryProxyDBQueryError`] represents.
584        pub fn kind(&self) -> MemoryProxyDBQueryErrorKind {
585            self.kind
586        }
587    }
588
589    #[cfg(test)]
590    mod tests {
591        use super::*;
592        use itertools::Itertools;
593        use rama_net::{address::ProxyAddress, Protocol};
594        use rama_utils::str::NonEmptyString;
595        use std::str::FromStr;
596
597        const RAW_CSV_DATA: &str = include_str!("./test_proxydb_rows.csv");
598
599        async fn memproxydb() -> MemoryProxyDB {
600            let mut reader = ProxyCsvRowReader::raw(RAW_CSV_DATA);
601            let mut rows = Vec::new();
602            while let Some(proxy) = reader.next().await.unwrap() {
603                rows.push(proxy);
604            }
605            MemoryProxyDB::try_from_rows(rows).unwrap()
606        }
607
608        #[tokio::test]
609        async fn test_load_memproxydb_from_rows() {
610            let db = memproxydb().await;
611            assert_eq!(db.len(), 64);
612        }
613
614        fn h2_transport_context() -> TransportContext {
615            TransportContext {
616                protocol: TransportProtocol::Tcp,
617                app_protocol: Some(Protocol::HTTPS),
618                http_version: None,
619                authority: "localhost:8443".try_into().unwrap(),
620            }
621        }
622
623        #[tokio::test]
624        async fn test_memproxydb_get_proxy_by_id_found() {
625            let db = memproxydb().await;
626            let ctx = h2_transport_context();
627            let filter = ProxyFilter {
628                id: Some(NonEmptyString::from_static("3031533634")),
629                ..Default::default()
630            };
631            let proxy = db.get_proxy(ctx, filter).await.unwrap();
632            assert_eq!(proxy.id, "3031533634");
633        }
634
635        #[tokio::test]
636        async fn test_memproxydb_get_proxy_by_id_found_correct_filters() {
637            let db = memproxydb().await;
638            let ctx = h2_transport_context();
639            let filter = ProxyFilter {
640                id: Some(NonEmptyString::from_static("3031533634")),
641                pool_id: Some(vec![StringFilter::new("poolF")]),
642                country: Some(vec![StringFilter::new("JP")]),
643                city: Some(vec![StringFilter::new("Yokohama")]),
644                datacenter: Some(true),
645                residential: Some(false),
646                mobile: Some(true),
647                carrier: Some(vec![StringFilter::new("Verizon")]),
648                ..Default::default()
649            };
650            let proxy = db.get_proxy(ctx, filter).await.unwrap();
651            assert_eq!(proxy.id, "3031533634");
652        }
653
654        #[tokio::test]
655        async fn test_memproxydb_get_proxy_by_id_not_found() {
656            let db = memproxydb().await;
657            let ctx = h2_transport_context();
658            let filter = ProxyFilter {
659                id: Some(NonEmptyString::from_static("notfound")),
660                ..Default::default()
661            };
662            let err = db.get_proxy(ctx, filter).await.unwrap_err();
663            assert_eq!(err.kind(), MemoryProxyDBQueryErrorKind::NotFound);
664        }
665
666        #[tokio::test]
667        async fn test_memproxydb_get_proxy_by_id_mismatch_filter() {
668            let db = memproxydb().await;
669            let ctx = h2_transport_context();
670            let filters = [
671                ProxyFilter {
672                    id: Some(NonEmptyString::from_static("3031533634")),
673                    pool_id: Some(vec![StringFilter::new("poolB")]),
674                    ..Default::default()
675                },
676                ProxyFilter {
677                    id: Some(NonEmptyString::from_static("3031533634")),
678                    country: Some(vec![StringFilter::new("US")]),
679                    ..Default::default()
680                },
681                ProxyFilter {
682                    id: Some(NonEmptyString::from_static("3031533634")),
683                    city: Some(vec![StringFilter::new("New York")]),
684                    ..Default::default()
685                },
686                ProxyFilter {
687                    id: Some(NonEmptyString::from_static("3031533634")),
688                    continent: Some(vec![StringFilter::new("americas")]),
689                    ..Default::default()
690                },
691                ProxyFilter {
692                    id: Some(NonEmptyString::from_static("3732488183")),
693                    state: Some(vec![StringFilter::new("Texas")]),
694                    ..Default::default()
695                },
696                ProxyFilter {
697                    id: Some(NonEmptyString::from_static("3031533634")),
698                    datacenter: Some(false),
699                    ..Default::default()
700                },
701                ProxyFilter {
702                    id: Some(NonEmptyString::from_static("3031533634")),
703                    residential: Some(true),
704                    ..Default::default()
705                },
706                ProxyFilter {
707                    id: Some(NonEmptyString::from_static("3031533634")),
708                    mobile: Some(false),
709                    ..Default::default()
710                },
711                ProxyFilter {
712                    id: Some(NonEmptyString::from_static("3031533634")),
713                    carrier: Some(vec![StringFilter::new("AT&T")]),
714                    ..Default::default()
715                },
716                ProxyFilter {
717                    id: Some(NonEmptyString::from_static("292096733")),
718                    asn: Some(vec![Asn::from_static(1)]),
719                    ..Default::default()
720                },
721            ];
722            for filter in filters.iter() {
723                let err = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap_err();
724                assert_eq!(err.kind(), MemoryProxyDBQueryErrorKind::Mismatch);
725            }
726        }
727
728        fn h3_transport_context() -> TransportContext {
729            TransportContext {
730                protocol: TransportProtocol::Udp,
731                app_protocol: Some(Protocol::HTTPS),
732                http_version: None,
733                authority: "localhost:8443".try_into().unwrap(),
734            }
735        }
736
737        #[tokio::test]
738        async fn test_memproxydb_get_proxy_by_id_mismatch_req_context() {
739            let db = memproxydb().await;
740            let ctx = h3_transport_context();
741            let filter = ProxyFilter {
742                id: Some(NonEmptyString::from_static("3031533634")),
743                ..Default::default()
744            };
745            // this proxy does not support socks5 UDP, which is what we need
746            let err = db.get_proxy(ctx, filter).await.unwrap_err();
747            assert_eq!(err.kind(), MemoryProxyDBQueryErrorKind::Mismatch);
748        }
749
750        #[tokio::test]
751        async fn test_memorydb_get_h3_capable_proxies() {
752            let db = memproxydb().await;
753            let ctx = h3_transport_context();
754            let filter = ProxyFilter::default();
755            let mut found_ids = Vec::new();
756            for _ in 0..5000 {
757                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
758                if found_ids.contains(&proxy.id) {
759                    continue;
760                }
761                assert!(proxy.udp);
762                assert!(proxy.socks5);
763                found_ids.push(proxy.id);
764            }
765            assert_eq!(found_ids.len(), 40);
766            assert_eq!(
767                found_ids.iter().sorted().join(","),
768                r##"1125300915,1259341971,1316455915,153202126,1571861931,1684342915,1742367441,1844412609,1916851007,20647117,2107229589,2261612122,2497865606,2521901221,2560727338,2593294918,2596743625,2745456299,2880295577,2909724448,2950022859,2951529660,3187902553,3269411602,3269465574,3269921904,3481200027,3498810974,362091157,3679054656,3732488183,3836943127,39048766,3951672504,3976711563,4187178960,56402588,724884866,738626121,906390012"##
769            );
770        }
771
772        #[tokio::test]
773        async fn test_memorydb_get_h2_capable_proxies() {
774            let db = memproxydb().await;
775            let ctx = h2_transport_context();
776            let filter = ProxyFilter::default();
777            let mut found_ids = Vec::new();
778            for _ in 0..5000 {
779                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
780                if found_ids.contains(&proxy.id) {
781                    continue;
782                }
783                assert!(proxy.tcp);
784                found_ids.push(proxy.id);
785            }
786            assert_eq!(found_ids.len(), 50);
787            assert_eq!(
788                found_ids.iter().sorted().join(","),
789                r#"1125300915,1259341971,1264821985,129108927,1316455915,1425588737,1571861931,1810781137,1836040682,1844412609,1885107293,2021561518,2079461709,2107229589,2141152822,2438596154,2497865606,2521901221,2551759475,2560727338,2593294918,2798907087,2854473221,2880295577,2909724448,2912880381,292096733,2951529660,3031533634,3187902553,3269411602,3269465574,339020035,3481200027,3498810974,3503691556,362091157,3679054656,371209663,3861736957,39048766,3976711563,4062553709,49590203,56402588,724884866,738626121,767809962,846528631,906390012"#,
790            );
791        }
792
793        #[tokio::test]
794        async fn test_memorydb_get_any_country_proxies() {
795            let db = memproxydb().await;
796            let ctx = h2_transport_context();
797            let filter = ProxyFilter {
798                // there are no explicit BE proxies,
799                // so these will only match the proxies that have a wildcard country
800                country: Some(vec!["BE".into()]),
801                ..Default::default()
802            };
803            let mut found_ids = Vec::new();
804            for _ in 0..5000 {
805                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
806                if found_ids.contains(&proxy.id) {
807                    continue;
808                }
809                found_ids.push(proxy.id);
810            }
811            assert_eq!(found_ids.len(), 5);
812            assert_eq!(
813                found_ids.iter().sorted().join(","),
814                r#"2141152822,2593294918,2912880381,371209663,767809962"#,
815            );
816        }
817
818        #[tokio::test]
819        async fn test_memorydb_get_illinois_proxies() {
820            let db = memproxydb().await;
821            let ctx = h2_transport_context();
822            let filter = ProxyFilter {
823                // this will also work for proxies that have 'any' state
824                state: Some(vec!["illinois".into()]),
825                ..Default::default()
826            };
827            let mut found_ids = Vec::new();
828            for _ in 0..5000 {
829                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
830                if found_ids.contains(&proxy.id) {
831                    continue;
832                }
833                found_ids.push(proxy.id);
834            }
835            assert_eq!(found_ids.len(), 9);
836            assert_eq!(
837                found_ids.iter().sorted().join(","),
838                r#"2141152822,2521901221,2560727338,2593294918,2912880381,292096733,371209663,39048766,767809962"#,
839            );
840        }
841
842        #[tokio::test]
843        async fn test_memorydb_get_asn_proxies() {
844            let db = memproxydb().await;
845            let ctx = h2_transport_context();
846            let filter = ProxyFilter {
847                // this will also work for proxies that have 'any' ASN
848                asn: Some(vec![Asn::from_static(42)]),
849                ..Default::default()
850            };
851            let mut found_ids = Vec::new();
852            for _ in 0..5000 {
853                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
854                if found_ids.contains(&proxy.id) {
855                    continue;
856                }
857                found_ids.push(proxy.id);
858            }
859            assert_eq!(found_ids.len(), 4);
860            assert_eq!(
861                found_ids.iter().sorted().join(","),
862                r#"2141152822,2912880381,292096733,3481200027"#,
863            );
864        }
865
866        #[tokio::test]
867        async fn test_memorydb_get_h3_capable_mobile_residential_be_asterix_proxies() {
868            let db = memproxydb().await;
869            let ctx = h3_transport_context();
870            let filter = ProxyFilter {
871                country: Some(vec!["BE".into()]),
872                mobile: Some(true),
873                residential: Some(true),
874                ..Default::default()
875            };
876            for _ in 0..50 {
877                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
878                assert_eq!(proxy.id, "2593294918");
879            }
880        }
881
882        #[tokio::test]
883        async fn test_memorydb_get_blocked_proxies() {
884            let db = memproxydb().await;
885            let ctx = h2_transport_context();
886            let filter = ProxyFilter::default();
887
888            let mut blocked_proxies = vec![
889                "1125300915",
890                "1259341971",
891                "1264821985",
892                "129108927",
893                "1316455915",
894                "1425588737",
895                "1571861931",
896                "1810781137",
897                "1836040682",
898                "1844412609",
899                "1885107293",
900                "2021561518",
901                "2079461709",
902                "2107229589",
903                "2141152822",
904                "2438596154",
905                "2497865606",
906                "2521901221",
907                "2551759475",
908                "2560727338",
909                "2593294918",
910                "2798907087",
911                "2854473221",
912                "2880295577",
913                "2909724448",
914                "2912880381",
915                "292096733",
916                "2951529660",
917                "3031533634",
918                "3187902553",
919                "3269411602",
920                "3269465574",
921                "339020035",
922                "3481200027",
923                "3498810974",
924                "3503691556",
925                "362091157",
926                "3679054656",
927                "371209663",
928                "3861736957",
929                "39048766",
930                "3976711563",
931                "4062553709",
932                "49590203",
933                "56402588",
934                "724884866",
935                "738626121",
936                "767809962",
937                "846528631",
938                "906390012",
939            ];
940
941            {
942                let blocked_proxies = blocked_proxies.clone();
943
944                assert_eq!(
945                    MemoryProxyDBQueryErrorKind::NotFound,
946                    db.get_proxy_if(ctx.clone(), filter.clone(), move |proxy: &Proxy| {
947                        !blocked_proxies.contains(&proxy.id.as_str())
948                    })
949                    .await
950                    .unwrap_err()
951                    .kind()
952                );
953            }
954
955            let last_proxy_id = blocked_proxies.pop().unwrap();
956
957            let proxy = db
958                .get_proxy_if(ctx, filter.clone(), move |proxy: &Proxy| {
959                    !blocked_proxies.contains(&proxy.id.as_str())
960                })
961                .await
962                .unwrap();
963            assert_eq!(proxy.id, last_proxy_id);
964        }
965
966        #[tokio::test]
967        async fn test_db_proxy_filter_any_use_filter_property() {
968            let db = MemoryProxyDB::try_from_iter([Proxy {
969                id: NonEmptyString::from_static("1"),
970                address: ProxyAddress::from_str("example.com").unwrap(),
971                tcp: true,
972                udp: true,
973                http: true,
974                https: true,
975                socks5: true,
976                socks5h: true,
977                datacenter: true,
978                residential: true,
979                mobile: true,
980                pool_id: Some("*".into()),
981                continent: Some("*".into()),
982                country: Some("*".into()),
983                state: Some("*".into()),
984                city: Some("*".into()),
985                carrier: Some("*".into()),
986                asn: Some(Asn::unspecified()),
987            }])
988            .unwrap();
989
990            let ctx = h2_transport_context();
991
992            for filter in [
993                ProxyFilter {
994                    id: Some(NonEmptyString::from_static("1")),
995                    ..Default::default()
996                },
997                ProxyFilter {
998                    pool_id: Some(vec![StringFilter::new("*")]),
999                    ..Default::default()
1000                },
1001                ProxyFilter {
1002                    pool_id: Some(vec![StringFilter::new("hq")]),
1003                    ..Default::default()
1004                },
1005                ProxyFilter {
1006                    country: Some(vec![StringFilter::new("*")]),
1007                    ..Default::default()
1008                },
1009                ProxyFilter {
1010                    country: Some(vec![StringFilter::new("US")]),
1011                    ..Default::default()
1012                },
1013                ProxyFilter {
1014                    city: Some(vec![StringFilter::new("*")]),
1015                    ..Default::default()
1016                },
1017                ProxyFilter {
1018                    city: Some(vec![StringFilter::new("NY")]),
1019                    ..Default::default()
1020                },
1021                ProxyFilter {
1022                    carrier: Some(vec![StringFilter::new("*")]),
1023                    ..Default::default()
1024                },
1025                ProxyFilter {
1026                    carrier: Some(vec![StringFilter::new("Telenet")]),
1027                    ..Default::default()
1028                },
1029                ProxyFilter {
1030                    pool_id: Some(vec![StringFilter::new("hq")]),
1031                    country: Some(vec![StringFilter::new("US")]),
1032                    city: Some(vec![StringFilter::new("NY")]),
1033                    carrier: Some(vec![StringFilter::new("AT&T")]),
1034                    ..Default::default()
1035                },
1036            ] {
1037                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
1038                assert!(filter.id.map(|id| proxy.id == id).unwrap_or(true));
1039                assert!(filter
1040                    .pool_id
1041                    .map(|pool_id| pool_id.contains(proxy.pool_id.as_ref().unwrap()))
1042                    .unwrap_or(true));
1043                assert!(filter
1044                    .country
1045                    .map(|country| country.contains(proxy.country.as_ref().unwrap()))
1046                    .unwrap_or(true));
1047                assert!(filter
1048                    .city
1049                    .map(|city| city.contains(proxy.city.as_ref().unwrap()))
1050                    .unwrap_or(true));
1051                assert!(filter
1052                    .carrier
1053                    .map(|carrier| carrier.contains(proxy.carrier.as_ref().unwrap()))
1054                    .unwrap_or(true));
1055            }
1056        }
1057
1058        #[tokio::test]
1059        async fn test_db_proxy_filter_any_only_matches_any_value() {
1060            let db = MemoryProxyDB::try_from_iter([Proxy {
1061                id: NonEmptyString::from_static("1"),
1062                address: ProxyAddress::from_str("example.com").unwrap(),
1063                tcp: true,
1064                udp: true,
1065                http: true,
1066                https: true,
1067                socks5: true,
1068                socks5h: true,
1069                datacenter: true,
1070                residential: true,
1071                mobile: true,
1072                pool_id: Some("hq".into()),
1073                continent: Some("americas".into()),
1074                country: Some("US".into()),
1075                state: Some("NY".into()),
1076                city: Some("NY".into()),
1077                carrier: Some("AT&T".into()),
1078                asn: Some(Asn::from_static(7018)),
1079            }])
1080            .unwrap();
1081
1082            let ctx = h2_transport_context();
1083
1084            for filter in [
1085                ProxyFilter {
1086                    pool_id: Some(vec![StringFilter::new("*")]),
1087                    ..Default::default()
1088                },
1089                ProxyFilter {
1090                    continent: Some(vec![StringFilter::new("*")]),
1091                    ..Default::default()
1092                },
1093                ProxyFilter {
1094                    country: Some(vec![StringFilter::new("*")]),
1095                    ..Default::default()
1096                },
1097                ProxyFilter {
1098                    state: Some(vec![StringFilter::new("*")]),
1099                    ..Default::default()
1100                },
1101                ProxyFilter {
1102                    city: Some(vec![StringFilter::new("*")]),
1103                    ..Default::default()
1104                },
1105                ProxyFilter {
1106                    carrier: Some(vec![StringFilter::new("*")]),
1107                    ..Default::default()
1108                },
1109                ProxyFilter {
1110                    asn: Some(vec![Asn::unspecified()]),
1111                    ..Default::default()
1112                },
1113                ProxyFilter {
1114                    pool_id: Some(vec![StringFilter::new("*")]),
1115                    continent: Some(vec![StringFilter::new("*")]),
1116                    country: Some(vec![StringFilter::new("*")]),
1117                    state: Some(vec![StringFilter::new("*")]),
1118                    city: Some(vec![StringFilter::new("*")]),
1119                    carrier: Some(vec![StringFilter::new("*")]),
1120                    asn: Some(vec![Asn::unspecified()]),
1121                    ..Default::default()
1122                },
1123            ] {
1124                let err = match db.get_proxy(ctx.clone(), filter.clone()).await {
1125                    Ok(proxy) => {
1126                        panic!(
1127                            "expected error for filter {:?}, not found proxy: {:?}",
1128                            filter, proxy
1129                        );
1130                    }
1131                    Err(err) => err,
1132                };
1133                assert_eq!(
1134                    MemoryProxyDBQueryErrorKind::NotFound,
1135                    err.kind(),
1136                    "filter: {:?}",
1137                    filter
1138                );
1139            }
1140        }
1141
1142        #[tokio::test]
1143        async fn test_search_proxy_for_any_of_given_pools() {
1144            let db = MemoryProxyDB::try_from_iter([
1145                Proxy {
1146                    id: NonEmptyString::from_static("1"),
1147                    address: ProxyAddress::from_str("example.com").unwrap(),
1148                    tcp: true,
1149                    udp: true,
1150                    http: true,
1151                    https: true,
1152                    socks5: true,
1153                    socks5h: true,
1154                    datacenter: true,
1155                    residential: true,
1156                    mobile: true,
1157                    pool_id: Some("a".into()),
1158                    continent: Some("americas".into()),
1159                    country: Some("US".into()),
1160                    state: Some("NY".into()),
1161                    city: Some("NY".into()),
1162                    carrier: Some("AT&T".into()),
1163                    asn: Some(Asn::from_static(7018)),
1164                },
1165                Proxy {
1166                    id: NonEmptyString::from_static("2"),
1167                    address: ProxyAddress::from_str("example.com").unwrap(),
1168                    tcp: true,
1169                    udp: true,
1170                    http: true,
1171                    https: true,
1172                    socks5: true,
1173                    socks5h: true,
1174                    datacenter: true,
1175                    residential: true,
1176                    mobile: true,
1177                    pool_id: Some("b".into()),
1178                    continent: Some("americas".into()),
1179                    country: Some("US".into()),
1180                    state: Some("NY".into()),
1181                    city: Some("NY".into()),
1182                    carrier: Some("AT&T".into()),
1183                    asn: Some(Asn::from_static(7018)),
1184                },
1185                Proxy {
1186                    id: NonEmptyString::from_static("3"),
1187                    address: ProxyAddress::from_str("example.com").unwrap(),
1188                    tcp: true,
1189                    udp: true,
1190                    http: true,
1191                    https: true,
1192                    socks5: true,
1193                    socks5h: true,
1194                    datacenter: true,
1195                    residential: true,
1196                    mobile: true,
1197                    pool_id: Some("b".into()),
1198                    continent: Some("americas".into()),
1199                    country: Some("US".into()),
1200                    state: Some("NY".into()),
1201                    city: Some("NY".into()),
1202                    carrier: Some("AT&T".into()),
1203                    asn: Some(Asn::from_static(7018)),
1204                },
1205                Proxy {
1206                    id: NonEmptyString::from_static("4"),
1207                    address: ProxyAddress::from_str("example.com").unwrap(),
1208                    tcp: true,
1209                    udp: true,
1210                    http: true,
1211                    https: true,
1212                    socks5: true,
1213                    socks5h: true,
1214                    datacenter: true,
1215                    residential: true,
1216                    mobile: true,
1217                    pool_id: Some("c".into()),
1218                    continent: Some("americas".into()),
1219                    country: Some("US".into()),
1220                    state: Some("NY".into()),
1221                    city: Some("NY".into()),
1222                    carrier: Some("AT&T".into()),
1223                    asn: Some(Asn::from_static(7018)),
1224                },
1225            ])
1226            .unwrap();
1227
1228            let ctx = h2_transport_context();
1229
1230            let filter = ProxyFilter {
1231                pool_id: Some(vec![StringFilter::new("a"), StringFilter::new("c")]),
1232                ..Default::default()
1233            };
1234
1235            let mut seen_1 = false;
1236            let mut seen_4 = false;
1237            for _ in 0..100 {
1238                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
1239                match proxy.id.as_str() {
1240                    "1" => seen_1 = true,
1241                    "4" => seen_4 = true,
1242                    _ => panic!("unexpected pool id"),
1243                }
1244            }
1245            assert!(seen_1);
1246            assert!(seen_4);
1247        }
1248
1249        #[tokio::test]
1250        async fn test_deserialize_url_proxy_filter() {
1251            for (input, expected_output) in [
1252                (
1253                    "id=1",
1254                    ProxyFilter {
1255                        id: Some(NonEmptyString::from_static("1")),
1256                        ..Default::default()
1257                    },
1258                ),
1259                (
1260                    "pool=hq&country=us",
1261                    ProxyFilter {
1262                        pool_id: Some(vec![StringFilter::new("hq")]),
1263                        country: Some(vec![StringFilter::new("us")]),
1264                        ..Default::default()
1265                    },
1266                ),
1267                (
1268                    "pool=hq&country=us&country=be",
1269                    ProxyFilter {
1270                        pool_id: Some(vec![StringFilter::new("hq")]),
1271                        country: Some(vec![StringFilter::new("us"), StringFilter::new("be")]),
1272                        ..Default::default()
1273                    },
1274                ),
1275                (
1276                    "pool=a&country=uk&pool=b",
1277                    ProxyFilter {
1278                        pool_id: Some(vec![StringFilter::new("a"), StringFilter::new("b")]),
1279                        country: Some(vec![StringFilter::new("uk")]),
1280                        ..Default::default()
1281                    },
1282                ),
1283                (
1284                    "continent=europe&continent=asia",
1285                    ProxyFilter {
1286                        continent: Some(vec![
1287                            StringFilter::new("europe"),
1288                            StringFilter::new("asia"),
1289                        ]),
1290                        ..Default::default()
1291                    },
1292                ),
1293                (
1294                    "continent=americas&country=us&state=NY&city=buffalo&carrier=AT%26T&asn=7018",
1295                    ProxyFilter {
1296                        continent: Some(vec![StringFilter::new("americas")]),
1297                        country: Some(vec![StringFilter::new("us")]),
1298                        state: Some(vec![StringFilter::new("ny")]),
1299                        city: Some(vec![StringFilter::new("buffalo")]),
1300                        carrier: Some(vec![StringFilter::new("at&t")]),
1301                        asn: Some(vec![Asn::from_static(7018)]),
1302                        ..Default::default()
1303                    },
1304                ),
1305                (
1306                    "asn=1&asn=2",
1307                    ProxyFilter {
1308                        asn: Some(vec![Asn::from_static(1), Asn::from_static(2)]),
1309                        ..Default::default()
1310                    },
1311                ),
1312            ] {
1313                let filter: ProxyFilter = serde_html_form::from_str(input).unwrap();
1314                assert_eq!(filter, expected_output);
1315            }
1316        }
1317    }
1318}
1319
1320#[cfg(feature = "memory-db")]
1321pub use memdb::{
1322    MemoryProxyDB, MemoryProxyDBInsertError, MemoryProxyDBInsertErrorKind, MemoryProxyDBQueryError,
1323    MemoryProxyDBQueryErrorKind,
1324};