rama_proxy/proxydb/
mod.rs

1use rama_core::error::{BoxError, ErrorContext, OpaqueError};
2use rama_net::asn::Asn;
3use rama_utils::str::NonEmptyString;
4use serde::{Deserialize, Serialize};
5use std::fmt;
6
7#[cfg(feature = "live-update")]
8mod update;
9#[cfg(feature = "live-update")]
10#[doc(inline)]
11pub use update::{LiveUpdateProxyDB, LiveUpdateProxyDBSetter, proxy_db_updater};
12
13mod context;
14pub use context::ProxyContext;
15
16mod internal;
17#[doc(inline)]
18pub use internal::Proxy;
19
20#[cfg(feature = "csv")]
21mod csv;
22
23#[cfg(feature = "csv")]
24#[doc(inline)]
25pub use csv::{ProxyCsvRowReader, ProxyCsvRowReaderError, ProxyCsvRowReaderErrorKind};
26
27pub(super) mod layer;
28
29mod str;
30#[doc(inline)]
31pub use str::StringFilter;
32
33#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
34/// `ID` of the selected proxy. To be inserted into the `Context`,
35/// only if that proxy is selected.
36pub struct ProxyID(NonEmptyString);
37
38impl ProxyID {
39    /// View  this [`ProxyID`] as a `str`.
40    #[must_use]
41    pub fn as_str(&self) -> &str {
42        self.0.as_str()
43    }
44}
45
46impl AsRef<str> for ProxyID {
47    fn as_ref(&self) -> &str {
48        self.0.as_ref()
49    }
50}
51
52impl fmt::Display for ProxyID {
53    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54        self.0.fmt(f)
55    }
56}
57
58impl From<NonEmptyString> for ProxyID {
59    fn from(value: NonEmptyString) -> Self {
60        Self(value)
61    }
62}
63
64#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
65/// Filter to select a specific kind of proxy.
66///
67/// If the `id` is specified the other fields are used
68/// as a validator to see if the only possible matching proxy
69/// matches these fields.
70///
71/// If the `id` is not specified, the other fields are used
72/// to select a random proxy from the pool.
73///
74/// Filters can be combined to make combinations with special meaning.
75/// E.g. `datacenter:true, residential:true` is essentially an ISP proxy.
76///
77/// ## Usage
78///
79/// - Use `HeaderConfigLayer` (`rama-http`) to have this proxy filter be given by the http `Request` headers,
80///   which will add the extracted and parsed [`ProxyFilter`] to the [`Context`]'s [`Extensions`].
81/// - Or extract yourself from the username/token validated in the `ProxyAuthLayer` (`rama-http`)
82///   to add it manually to the [`Context`]'s [`Extensions`].
83///
84/// [`Request`]: crate::http::Request
85/// [`Context`]: rama_core::Context
86/// [`Extensions`]: rama_core::context::Extensions
87pub struct ProxyFilter {
88    /// The ID of the proxy to select.
89    pub id: Option<NonEmptyString>,
90
91    /// The ID of the pool from which to select the proxy.
92    #[serde(alias = "pool")]
93    pub pool_id: Option<Vec<StringFilter>>,
94
95    /// The continent of the proxy.
96    pub continent: Option<Vec<StringFilter>>,
97
98    /// The country of the proxy.
99    pub country: Option<Vec<StringFilter>>,
100
101    /// The state of the proxy.
102    pub state: Option<Vec<StringFilter>>,
103
104    /// The city of the proxy.
105    pub city: Option<Vec<StringFilter>>,
106
107    /// Set explicitly to `true` to select a datacenter proxy.
108    pub datacenter: Option<bool>,
109
110    /// Set explicitly to `true` to select a residential proxy.
111    pub residential: Option<bool>,
112
113    /// Set explicitly to `true` to select a mobile proxy.
114    pub mobile: Option<bool>,
115
116    /// The mobile carrier desired.
117    pub carrier: Option<Vec<StringFilter>>,
118
119    ///  Autonomous System Number (ASN).
120    pub asn: Option<Vec<Asn>>,
121}
122
123/// The trait to implement to provide a proxy database to other facilities,
124/// such as connection pools, to provide a proxy based on the given
125/// [`TransportContext`] and [`ProxyFilter`].
126pub trait ProxyDB: Send + Sync + 'static {
127    /// The error type that can be returned by the proxy database
128    ///
129    /// Examples are generic I/O issues or
130    /// even more common if no proxy match could be found.
131    type Error: Send + 'static;
132
133    /// Same as [`Self::get_proxy`] but with a predicate
134    /// to filter out found proxies that do not match the given predicate.
135    fn get_proxy_if(
136        &self,
137        ctx: ProxyContext,
138        filter: ProxyFilter,
139        predicate: impl ProxyQueryPredicate,
140    ) -> impl Future<Output = Result<Proxy, Self::Error>> + Send + '_;
141
142    /// Get a [`Proxy`] based on the given [`ProxyContext`] and [`ProxyFilter`],
143    /// or return an error in case no [`Proxy`] could be returned.
144    fn get_proxy(
145        &self,
146        ctx: ProxyContext,
147        filter: ProxyFilter,
148    ) -> impl Future<Output = Result<Proxy, Self::Error>> + Send + '_ {
149        self.get_proxy_if(ctx, filter, true)
150    }
151}
152
153impl ProxyDB for () {
154    type Error = OpaqueError;
155
156    #[inline]
157    async fn get_proxy_if(
158        &self,
159        _ctx: ProxyContext,
160        _filter: ProxyFilter,
161        _predicate: impl ProxyQueryPredicate,
162    ) -> Result<Proxy, Self::Error> {
163        Err(OpaqueError::from_display(
164            "()::get_proxy_if: no ProxyDB defined",
165        ))
166    }
167
168    #[inline]
169    async fn get_proxy(
170        &self,
171        _ctx: ProxyContext,
172        _filter: ProxyFilter,
173    ) -> Result<Proxy, Self::Error> {
174        Err(OpaqueError::from_display(
175            "()::get_proxy: no ProxyDB defined",
176        ))
177    }
178}
179
180impl<T> ProxyDB for Option<T>
181where
182    T: ProxyDB<Error: Into<BoxError>>,
183{
184    type Error = OpaqueError;
185
186    #[inline]
187    async fn get_proxy_if(
188        &self,
189        ctx: ProxyContext,
190        filter: ProxyFilter,
191        predicate: impl ProxyQueryPredicate,
192    ) -> Result<Proxy, Self::Error> {
193        match self {
194            Some(db) => db
195                .get_proxy_if(ctx, filter, predicate)
196                .await
197                .map_err(|err| OpaqueError::from_boxed(err.into()))
198                .context("Some::get_proxy_if"),
199            None => Err(OpaqueError::from_display(
200                "None::get_proxy_if: no ProxyDB defined",
201            )),
202        }
203    }
204
205    #[inline]
206    async fn get_proxy(
207        &self,
208        ctx: ProxyContext,
209        filter: ProxyFilter,
210    ) -> Result<Proxy, Self::Error> {
211        match self {
212            Some(db) => db
213                .get_proxy(ctx, filter)
214                .await
215                .map_err(|err| OpaqueError::from_boxed(err.into()))
216                .context("Some::get_proxy"),
217            None => Err(OpaqueError::from_display(
218                "None::get_proxy: no ProxyDB defined",
219            )),
220        }
221    }
222}
223
224impl<T> ProxyDB for std::sync::Arc<T>
225where
226    T: ProxyDB,
227{
228    type Error = T::Error;
229
230    #[inline]
231    fn get_proxy_if(
232        &self,
233        ctx: ProxyContext,
234        filter: ProxyFilter,
235        predicate: impl ProxyQueryPredicate,
236    ) -> impl Future<Output = Result<Proxy, Self::Error>> + Send + '_ {
237        (**self).get_proxy_if(ctx, filter, predicate)
238    }
239
240    #[inline]
241    fn get_proxy(
242        &self,
243        ctx: ProxyContext,
244        filter: ProxyFilter,
245    ) -> impl Future<Output = Result<Proxy, Self::Error>> + Send + '_ {
246        (**self).get_proxy(ctx, filter)
247    }
248}
249
250macro_rules! impl_proxydb_either {
251    ($id:ident, $($param:ident),+ $(,)?) => {
252        impl<$($param),+> ProxyDB for rama_core::combinators::$id<$($param),+>
253        where
254            $(
255                $param: ProxyDB<Error: Into<BoxError>>,
256            )+
257    {
258        type Error = BoxError;
259
260        #[inline]
261        async fn get_proxy_if(
262            &self,
263            ctx: ProxyContext,
264            filter: ProxyFilter,
265            predicate: impl ProxyQueryPredicate,
266        ) -> Result<Proxy, Self::Error> {
267            match self {
268                $(
269                    rama_core::combinators::$id::$param(s) => s.get_proxy_if(ctx, filter, predicate).await.map_err(Into::into),
270                )+
271            }
272        }
273
274        #[inline]
275        async fn get_proxy(
276            &self,
277            ctx: ProxyContext,
278            filter: ProxyFilter,
279        ) -> Result<Proxy, Self::Error> {
280            match self {
281                $(
282                    rama_core::combinators::$id::$param(s) => s.get_proxy(ctx, filter).await.map_err(Into::into),
283                )+
284            }
285        }
286        }
287    };
288}
289
290rama_core::combinators::impl_either!(impl_proxydb_either);
291
292/// Trait that is used by the [`ProxyDB`] for providing an optional
293/// filter predicate to rule out returned results.
294pub trait ProxyQueryPredicate: Clone + Send + Sync + 'static {
295    /// Execute the predicate.
296    fn execute(&self, proxy: &Proxy) -> bool;
297}
298
299impl ProxyQueryPredicate for bool {
300    fn execute(&self, _proxy: &Proxy) -> bool {
301        *self
302    }
303}
304
305impl<F> ProxyQueryPredicate for F
306where
307    F: Fn(&Proxy) -> bool + Clone + Send + Sync + 'static,
308{
309    fn execute(&self, proxy: &Proxy) -> bool {
310        (self)(proxy)
311    }
312}
313
314impl ProxyDB for Proxy {
315    type Error = rama_core::error::OpaqueError;
316
317    async fn get_proxy_if(
318        &self,
319        ctx: ProxyContext,
320        filter: ProxyFilter,
321        predicate: impl ProxyQueryPredicate,
322    ) -> Result<Self, Self::Error> {
323        (self.is_match(&ctx, &filter) && predicate.execute(self))
324            .then(|| self.clone())
325            .ok_or_else(|| rama_core::error::OpaqueError::from_display("hardcoded proxy no match"))
326    }
327}
328
329#[cfg(feature = "memory-db")]
330mod memdb {
331    use super::*;
332    use crate::proxydb::internal::ProxyDBErrorKind;
333    use rama_net::transport::TransportProtocol;
334
335    /// A fast in-memory ProxyDatabase that is the default choice for Rama.
336    #[derive(Debug)]
337    pub struct MemoryProxyDB {
338        data: internal::ProxyDB,
339    }
340
341    impl MemoryProxyDB {
342        /// Create a new in-memory proxy database with the given proxies.
343        pub fn try_from_rows(proxies: Vec<Proxy>) -> Result<Self, MemoryProxyDBInsertError> {
344            Ok(Self {
345                data: internal::ProxyDB::from_rows(proxies).map_err(|err| match err.kind() {
346                    ProxyDBErrorKind::DuplicateKey => {
347                        MemoryProxyDBInsertError::duplicate_key(err.into_input())
348                    }
349                    ProxyDBErrorKind::InvalidRow => {
350                        MemoryProxyDBInsertError::invalid_proxy(err.into_input())
351                    }
352                })?,
353            })
354        }
355
356        /// Create a new in-memory proxy database with the given proxies from an iterator.
357        pub fn try_from_iter<I>(proxies: I) -> Result<Self, MemoryProxyDBInsertError>
358        where
359            I: IntoIterator<Item = Proxy>,
360        {
361            Ok(Self {
362                data: internal::ProxyDB::from_iter(proxies).map_err(|err| match err.kind() {
363                    ProxyDBErrorKind::DuplicateKey => {
364                        MemoryProxyDBInsertError::duplicate_key(err.into_input())
365                    }
366                    ProxyDBErrorKind::InvalidRow => {
367                        MemoryProxyDBInsertError::invalid_proxy(err.into_input())
368                    }
369                })?,
370            })
371        }
372
373        /// Return the number of proxies in the database.
374        #[must_use]
375        pub fn len(&self) -> usize {
376            self.data.len()
377        }
378
379        /// Rerturns if the database is empty.
380        #[must_use]
381        pub fn is_empty(&self) -> bool {
382            self.data.is_empty()
383        }
384
385        #[allow(clippy::needless_pass_by_value)]
386        fn query_from_filter(
387            &self,
388            ctx: ProxyContext,
389            filter: ProxyFilter,
390        ) -> internal::ProxyDBQuery<'_> {
391            let mut query = self.data.query();
392
393            for pool_id in filter.pool_id.into_iter().flatten() {
394                query.pool_id(pool_id);
395            }
396            for continent in filter.continent.into_iter().flatten() {
397                query.continent(continent);
398            }
399            for country in filter.country.into_iter().flatten() {
400                query.country(country);
401            }
402            for state in filter.state.into_iter().flatten() {
403                query.state(state);
404            }
405            for city in filter.city.into_iter().flatten() {
406                query.city(city);
407            }
408            for carrier in filter.carrier.into_iter().flatten() {
409                query.carrier(carrier);
410            }
411            for asn in filter.asn.into_iter().flatten() {
412                query.asn(asn);
413            }
414
415            if let Some(value) = filter.datacenter {
416                query.datacenter(value);
417            }
418            if let Some(value) = filter.residential {
419                query.residential(value);
420            }
421            if let Some(value) = filter.mobile {
422                query.mobile(value);
423            }
424
425            match ctx.protocol {
426                TransportProtocol::Tcp => {
427                    query.tcp(true);
428                }
429                TransportProtocol::Udp => {
430                    query.udp(true).socks5(true);
431                }
432            }
433
434            query
435        }
436    }
437
438    // TODO: custom query filters using ProxyQueryPredicate
439    // might be a lot faster for cases where we want to filter a big batch of proxies,
440    // in which case a bitmap could be supported by a future VennDB version...
441    //
442    // Would just need to figure out how to allow this to happen.
443
444    impl ProxyDB for MemoryProxyDB {
445        type Error = MemoryProxyDBQueryError;
446
447        async fn get_proxy_if(
448            &self,
449            ctx: ProxyContext,
450            filter: ProxyFilter,
451            predicate: impl ProxyQueryPredicate,
452        ) -> Result<Proxy, Self::Error> {
453            if let Some(id) = &filter.id {
454                match self.data.get_by_id(id) {
455                    None => Err(MemoryProxyDBQueryError::not_found()),
456                    Some(proxy) => {
457                        if proxy.is_match(&ctx, &filter) && predicate.execute(proxy) {
458                            Ok(proxy.clone())
459                        } else {
460                            Err(MemoryProxyDBQueryError::mismatch())
461                        }
462                    }
463                }
464            } else {
465                let query = self.query_from_filter(ctx, filter);
466                match query
467                    .execute()
468                    .and_then(|result| result.filter(|proxy| predicate.execute(proxy)))
469                    .map(|result| result.any())
470                {
471                    None => Err(MemoryProxyDBQueryError::not_found()),
472                    Some(proxy) => Ok(proxy.clone()),
473                }
474            }
475        }
476    }
477
478    /// The error type that can be returned by [`MemoryProxyDB`] when some of the proxies
479    /// could not be inserted due to a proxy that had a duplicate key or was invalid for some other reason.
480    #[derive(Debug)]
481    pub struct MemoryProxyDBInsertError {
482        kind: MemoryProxyDBInsertErrorKind,
483        proxies: Vec<Proxy>,
484    }
485
486    impl std::fmt::Display for MemoryProxyDBInsertError {
487        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488            match self.kind {
489                MemoryProxyDBInsertErrorKind::DuplicateKey => write!(
490                    f,
491                    "A proxy with the same key already exists in the database"
492                ),
493                MemoryProxyDBInsertErrorKind::InvalidProxy => {
494                    write!(f, "A proxy in the list is invalid for some reason")
495                }
496            }
497        }
498    }
499
500    impl std::error::Error for MemoryProxyDBInsertError {}
501
502    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
503    /// The kind of error that [`MemoryProxyDBInsertError`] represents.
504    pub enum MemoryProxyDBInsertErrorKind {
505        /// Duplicate key found in the proxies.
506        DuplicateKey,
507        /// Invalid proxy found in the proxies.
508        ///
509        /// This could be due to a proxy that is not valid for some reason.
510        /// E.g. a proxy that neither supports http or socks5.
511        InvalidProxy,
512    }
513
514    impl MemoryProxyDBInsertError {
515        fn duplicate_key(proxies: Vec<Proxy>) -> Self {
516            Self {
517                kind: MemoryProxyDBInsertErrorKind::DuplicateKey,
518                proxies,
519            }
520        }
521
522        fn invalid_proxy(proxies: Vec<Proxy>) -> Self {
523            Self {
524                kind: MemoryProxyDBInsertErrorKind::InvalidProxy,
525                proxies,
526            }
527        }
528
529        /// Returns the kind of error that [`MemoryProxyDBInsertError`] represents.
530        #[must_use]
531        pub fn kind(&self) -> MemoryProxyDBInsertErrorKind {
532            self.kind
533        }
534
535        /// Returns the proxies that were not inserted.
536        #[must_use]
537        pub fn proxies(&self) -> &[Proxy] {
538            &self.proxies
539        }
540
541        /// Consumes the error and returns the proxies that were not inserted.
542        #[must_use]
543        pub fn into_proxies(self) -> Vec<Proxy> {
544            self.proxies
545        }
546    }
547
548    /// The error type that can be returned by [`MemoryProxyDB`] when no proxy could be returned.
549    #[derive(Debug)]
550    pub struct MemoryProxyDBQueryError {
551        kind: MemoryProxyDBQueryErrorKind,
552    }
553
554    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
555    /// The kind of error that [`MemoryProxyDBQueryError`] represents.
556    pub enum MemoryProxyDBQueryErrorKind {
557        /// No proxy match could be found.
558        NotFound,
559        /// A proxy looked up by key had a config that did not match the given filters/requirements.
560        Mismatch,
561    }
562
563    impl std::fmt::Display for MemoryProxyDBQueryError {
564        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
565            match self.kind {
566                MemoryProxyDBQueryErrorKind::NotFound => write!(f, "No proxy match could be found"),
567                MemoryProxyDBQueryErrorKind::Mismatch => write!(
568                    f,
569                    "Proxy config did not match the given filters/requirements"
570                ),
571            }
572        }
573    }
574
575    impl std::error::Error for MemoryProxyDBQueryError {}
576
577    impl MemoryProxyDBQueryError {
578        /// Create a new error that indicates no proxy match could be found.
579        #[must_use]
580        pub fn not_found() -> Self {
581            Self {
582                kind: MemoryProxyDBQueryErrorKind::NotFound,
583            }
584        }
585
586        /// Create a new error that indicates a proxy looked up by key had a config that did not match the given filters/requirements.
587        #[must_use]
588        pub fn mismatch() -> Self {
589            Self {
590                kind: MemoryProxyDBQueryErrorKind::Mismatch,
591            }
592        }
593
594        /// Returns the kind of error that [`MemoryProxyDBQueryError`] represents.
595        #[must_use]
596        pub fn kind(&self) -> MemoryProxyDBQueryErrorKind {
597            self.kind
598        }
599    }
600
601    #[cfg(test)]
602    mod tests {
603        use super::*;
604        use itertools::Itertools;
605        use rama_net::address::ProxyAddress;
606        use rama_utils::str::NonEmptyString;
607        use std::str::FromStr;
608
609        const RAW_CSV_DATA: &str = include_str!("./test_proxydb_rows.csv");
610
611        async fn memproxydb() -> MemoryProxyDB {
612            let mut reader = ProxyCsvRowReader::raw(RAW_CSV_DATA);
613            let mut rows = Vec::new();
614            while let Some(proxy) = reader.next().await.unwrap() {
615                rows.push(proxy);
616            }
617            MemoryProxyDB::try_from_rows(rows).unwrap()
618        }
619
620        #[tokio::test]
621        async fn test_load_memproxydb_from_rows() {
622            let db = memproxydb().await;
623            assert_eq!(db.len(), 64);
624        }
625
626        fn h2_proxy_context() -> ProxyContext {
627            ProxyContext {
628                protocol: TransportProtocol::Tcp,
629            }
630        }
631
632        #[tokio::test]
633        async fn test_memproxydb_get_proxy_by_id_found() {
634            let db = memproxydb().await;
635            let ctx = h2_proxy_context();
636            let filter = ProxyFilter {
637                id: Some(NonEmptyString::from_static("3031533634")),
638                ..Default::default()
639            };
640            let proxy = db.get_proxy(ctx, filter).await.unwrap();
641            assert_eq!(proxy.id, "3031533634");
642        }
643
644        #[tokio::test]
645        async fn test_memproxydb_get_proxy_by_id_found_correct_filters() {
646            let db = memproxydb().await;
647            let ctx = h2_proxy_context();
648            let filter = ProxyFilter {
649                id: Some(NonEmptyString::from_static("3031533634")),
650                pool_id: Some(vec![StringFilter::new("poolF")]),
651                country: Some(vec![StringFilter::new("JP")]),
652                city: Some(vec![StringFilter::new("Yokohama")]),
653                datacenter: Some(true),
654                residential: Some(false),
655                mobile: Some(true),
656                carrier: Some(vec![StringFilter::new("Verizon")]),
657                ..Default::default()
658            };
659            let proxy = db.get_proxy(ctx, filter).await.unwrap();
660            assert_eq!(proxy.id, "3031533634");
661        }
662
663        #[tokio::test]
664        async fn test_memproxydb_get_proxy_by_id_not_found() {
665            let db = memproxydb().await;
666            let ctx = h2_proxy_context();
667            let filter = ProxyFilter {
668                id: Some(NonEmptyString::from_static("notfound")),
669                ..Default::default()
670            };
671            let err = db.get_proxy(ctx, filter).await.unwrap_err();
672            assert_eq!(err.kind(), MemoryProxyDBQueryErrorKind::NotFound);
673        }
674
675        #[tokio::test]
676        async fn test_memproxydb_get_proxy_by_id_mismatch_filter() {
677            let db = memproxydb().await;
678            let ctx = h2_proxy_context();
679            let filters = [
680                ProxyFilter {
681                    id: Some(NonEmptyString::from_static("3031533634")),
682                    pool_id: Some(vec![StringFilter::new("poolB")]),
683                    ..Default::default()
684                },
685                ProxyFilter {
686                    id: Some(NonEmptyString::from_static("3031533634")),
687                    country: Some(vec![StringFilter::new("US")]),
688                    ..Default::default()
689                },
690                ProxyFilter {
691                    id: Some(NonEmptyString::from_static("3031533634")),
692                    city: Some(vec![StringFilter::new("New York")]),
693                    ..Default::default()
694                },
695                ProxyFilter {
696                    id: Some(NonEmptyString::from_static("3031533634")),
697                    continent: Some(vec![StringFilter::new("americas")]),
698                    ..Default::default()
699                },
700                ProxyFilter {
701                    id: Some(NonEmptyString::from_static("3732488183")),
702                    state: Some(vec![StringFilter::new("Texas")]),
703                    ..Default::default()
704                },
705                ProxyFilter {
706                    id: Some(NonEmptyString::from_static("3031533634")),
707                    datacenter: Some(false),
708                    ..Default::default()
709                },
710                ProxyFilter {
711                    id: Some(NonEmptyString::from_static("3031533634")),
712                    residential: Some(true),
713                    ..Default::default()
714                },
715                ProxyFilter {
716                    id: Some(NonEmptyString::from_static("3031533634")),
717                    mobile: Some(false),
718                    ..Default::default()
719                },
720                ProxyFilter {
721                    id: Some(NonEmptyString::from_static("3031533634")),
722                    carrier: Some(vec![StringFilter::new("AT&T")]),
723                    ..Default::default()
724                },
725                ProxyFilter {
726                    id: Some(NonEmptyString::from_static("292096733")),
727                    asn: Some(vec![Asn::from_static(1)]),
728                    ..Default::default()
729                },
730            ];
731            for filter in filters.iter() {
732                let err = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap_err();
733                assert_eq!(err.kind(), MemoryProxyDBQueryErrorKind::Mismatch);
734            }
735        }
736
737        fn h3_proxy_context() -> ProxyContext {
738            ProxyContext {
739                protocol: TransportProtocol::Udp,
740            }
741        }
742
743        #[tokio::test]
744        async fn test_memproxydb_get_proxy_by_id_mismatch_req_context() {
745            let db = memproxydb().await;
746            let ctx = h3_proxy_context();
747            let filter = ProxyFilter {
748                id: Some(NonEmptyString::from_static("3031533634")),
749                ..Default::default()
750            };
751            // this proxy does not support socks5 UDP, which is what we need
752            let err = db.get_proxy(ctx, filter).await.unwrap_err();
753            assert_eq!(err.kind(), MemoryProxyDBQueryErrorKind::Mismatch);
754        }
755
756        #[tokio::test]
757        async fn test_memorydb_get_h3_capable_proxies() {
758            let db = memproxydb().await;
759            let ctx = h3_proxy_context();
760            let filter = ProxyFilter::default();
761            let mut found_ids = Vec::new();
762            for _ in 0..5000 {
763                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
764                if found_ids.contains(&proxy.id) {
765                    continue;
766                }
767                assert!(proxy.udp);
768                assert!(proxy.socks5);
769                found_ids.push(proxy.id);
770            }
771            assert_eq!(found_ids.len(), 40);
772            assert_eq!(
773                found_ids.iter().sorted().join(","),
774                r##"1125300915,1259341971,1316455915,153202126,1571861931,1684342915,1742367441,1844412609,1916851007,20647117,2107229589,2261612122,2497865606,2521901221,2560727338,2593294918,2596743625,2745456299,2880295577,2909724448,2950022859,2951529660,3187902553,3269411602,3269465574,3269921904,3481200027,3498810974,362091157,3679054656,3732488183,3836943127,39048766,3951672504,3976711563,4187178960,56402588,724884866,738626121,906390012"##
775            );
776        }
777
778        #[tokio::test]
779        async fn test_memorydb_get_h2_capable_proxies() {
780            let db = memproxydb().await;
781            let ctx = h2_proxy_context();
782            let filter = ProxyFilter::default();
783            let mut found_ids = Vec::new();
784            for _ in 0..5000 {
785                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
786                if found_ids.contains(&proxy.id) {
787                    continue;
788                }
789                assert!(proxy.tcp);
790                found_ids.push(proxy.id);
791            }
792            assert_eq!(found_ids.len(), 50);
793            assert_eq!(
794                found_ids.iter().sorted().join(","),
795                r#"1125300915,1259341971,1264821985,129108927,1316455915,1425588737,1571861931,1810781137,1836040682,1844412609,1885107293,2021561518,2079461709,2107229589,2141152822,2438596154,2497865606,2521901221,2551759475,2560727338,2593294918,2798907087,2854473221,2880295577,2909724448,2912880381,292096733,2951529660,3031533634,3187902553,3269411602,3269465574,339020035,3481200027,3498810974,3503691556,362091157,3679054656,371209663,3861736957,39048766,3976711563,4062553709,49590203,56402588,724884866,738626121,767809962,846528631,906390012"#,
796            );
797        }
798
799        #[tokio::test]
800        async fn test_memorydb_get_any_country_proxies() {
801            let db = memproxydb().await;
802            let ctx = h2_proxy_context();
803            let filter = ProxyFilter {
804                // there are no explicit BE proxies,
805                // so these will only match the proxies that have a wildcard country
806                country: Some(vec!["BE".into()]),
807                ..Default::default()
808            };
809            let mut found_ids = Vec::new();
810            for _ in 0..5000 {
811                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
812                if found_ids.contains(&proxy.id) {
813                    continue;
814                }
815                found_ids.push(proxy.id);
816            }
817            assert_eq!(found_ids.len(), 5);
818            assert_eq!(
819                found_ids.iter().sorted().join(","),
820                r#"2141152822,2593294918,2912880381,371209663,767809962"#,
821            );
822        }
823
824        #[tokio::test]
825        async fn test_memorydb_get_illinois_proxies() {
826            let db = memproxydb().await;
827            let ctx = h2_proxy_context();
828            let filter = ProxyFilter {
829                // this will also work for proxies that have 'any' state
830                state: Some(vec!["illinois".into()]),
831                ..Default::default()
832            };
833            let mut found_ids = Vec::new();
834            for _ in 0..5000 {
835                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
836                if found_ids.contains(&proxy.id) {
837                    continue;
838                }
839                found_ids.push(proxy.id);
840            }
841            assert_eq!(found_ids.len(), 9);
842            assert_eq!(
843                found_ids.iter().sorted().join(","),
844                r#"2141152822,2521901221,2560727338,2593294918,2912880381,292096733,371209663,39048766,767809962"#,
845            );
846        }
847
848        #[tokio::test]
849        async fn test_memorydb_get_asn_proxies() {
850            let db = memproxydb().await;
851            let ctx = h2_proxy_context();
852            let filter = ProxyFilter {
853                // this will also work for proxies that have 'any' ASN
854                asn: Some(vec![Asn::from_static(42)]),
855                ..Default::default()
856            };
857            let mut found_ids = Vec::new();
858            for _ in 0..5000 {
859                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
860                if found_ids.contains(&proxy.id) {
861                    continue;
862                }
863                found_ids.push(proxy.id);
864            }
865            assert_eq!(found_ids.len(), 4);
866            assert_eq!(
867                found_ids.iter().sorted().join(","),
868                r#"2141152822,2912880381,292096733,3481200027"#,
869            );
870        }
871
872        #[tokio::test]
873        async fn test_memorydb_get_h3_capable_mobile_residential_be_asterix_proxies() {
874            let db = memproxydb().await;
875            let ctx = h3_proxy_context();
876            let filter = ProxyFilter {
877                country: Some(vec!["BE".into()]),
878                mobile: Some(true),
879                residential: Some(true),
880                ..Default::default()
881            };
882            for _ in 0..50 {
883                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
884                assert_eq!(proxy.id, "2593294918");
885            }
886        }
887
888        #[tokio::test]
889        async fn test_memorydb_get_blocked_proxies() {
890            let db = memproxydb().await;
891            let ctx = h2_proxy_context();
892            let filter = ProxyFilter::default();
893
894            let mut blocked_proxies = vec![
895                "1125300915",
896                "1259341971",
897                "1264821985",
898                "129108927",
899                "1316455915",
900                "1425588737",
901                "1571861931",
902                "1810781137",
903                "1836040682",
904                "1844412609",
905                "1885107293",
906                "2021561518",
907                "2079461709",
908                "2107229589",
909                "2141152822",
910                "2438596154",
911                "2497865606",
912                "2521901221",
913                "2551759475",
914                "2560727338",
915                "2593294918",
916                "2798907087",
917                "2854473221",
918                "2880295577",
919                "2909724448",
920                "2912880381",
921                "292096733",
922                "2951529660",
923                "3031533634",
924                "3187902553",
925                "3269411602",
926                "3269465574",
927                "339020035",
928                "3481200027",
929                "3498810974",
930                "3503691556",
931                "362091157",
932                "3679054656",
933                "371209663",
934                "3861736957",
935                "39048766",
936                "3976711563",
937                "4062553709",
938                "49590203",
939                "56402588",
940                "724884866",
941                "738626121",
942                "767809962",
943                "846528631",
944                "906390012",
945            ];
946
947            {
948                let blocked_proxies = blocked_proxies.clone();
949
950                assert_eq!(
951                    MemoryProxyDBQueryErrorKind::NotFound,
952                    db.get_proxy_if(ctx.clone(), filter.clone(), move |proxy: &Proxy| {
953                        !blocked_proxies.contains(&proxy.id.as_str())
954                    })
955                    .await
956                    .unwrap_err()
957                    .kind()
958                );
959            }
960
961            let last_proxy_id = blocked_proxies.pop().unwrap();
962
963            let proxy = db
964                .get_proxy_if(ctx, filter.clone(), move |proxy: &Proxy| {
965                    !blocked_proxies.contains(&proxy.id.as_str())
966                })
967                .await
968                .unwrap();
969            assert_eq!(proxy.id, last_proxy_id);
970        }
971
972        #[tokio::test]
973        async fn test_db_proxy_filter_any_use_filter_property() {
974            let db = MemoryProxyDB::try_from_iter([Proxy {
975                id: NonEmptyString::from_static("1"),
976                address: ProxyAddress::from_str("example.com").unwrap(),
977                tcp: true,
978                udp: true,
979                http: true,
980                https: true,
981                socks5: true,
982                socks5h: true,
983                datacenter: true,
984                residential: true,
985                mobile: true,
986                pool_id: Some("*".into()),
987                continent: Some("*".into()),
988                country: Some("*".into()),
989                state: Some("*".into()),
990                city: Some("*".into()),
991                carrier: Some("*".into()),
992                asn: Some(Asn::unspecified()),
993            }])
994            .unwrap();
995
996            let ctx = h2_proxy_context();
997
998            for filter in [
999                ProxyFilter {
1000                    id: Some(NonEmptyString::from_static("1")),
1001                    ..Default::default()
1002                },
1003                ProxyFilter {
1004                    pool_id: Some(vec![StringFilter::new("*")]),
1005                    ..Default::default()
1006                },
1007                ProxyFilter {
1008                    pool_id: Some(vec![StringFilter::new("hq")]),
1009                    ..Default::default()
1010                },
1011                ProxyFilter {
1012                    country: Some(vec![StringFilter::new("*")]),
1013                    ..Default::default()
1014                },
1015                ProxyFilter {
1016                    country: Some(vec![StringFilter::new("US")]),
1017                    ..Default::default()
1018                },
1019                ProxyFilter {
1020                    city: Some(vec![StringFilter::new("*")]),
1021                    ..Default::default()
1022                },
1023                ProxyFilter {
1024                    city: Some(vec![StringFilter::new("NY")]),
1025                    ..Default::default()
1026                },
1027                ProxyFilter {
1028                    carrier: Some(vec![StringFilter::new("*")]),
1029                    ..Default::default()
1030                },
1031                ProxyFilter {
1032                    carrier: Some(vec![StringFilter::new("Telenet")]),
1033                    ..Default::default()
1034                },
1035                ProxyFilter {
1036                    pool_id: Some(vec![StringFilter::new("hq")]),
1037                    country: Some(vec![StringFilter::new("US")]),
1038                    city: Some(vec![StringFilter::new("NY")]),
1039                    carrier: Some(vec![StringFilter::new("AT&T")]),
1040                    ..Default::default()
1041                },
1042            ] {
1043                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
1044                assert!(filter.id.map(|id| proxy.id == id).unwrap_or(true));
1045                assert!(
1046                    filter
1047                        .pool_id
1048                        .map(|pool_id| pool_id.contains(proxy.pool_id.as_ref().unwrap()))
1049                        .unwrap_or(true)
1050                );
1051                assert!(
1052                    filter
1053                        .country
1054                        .map(|country| country.contains(proxy.country.as_ref().unwrap()))
1055                        .unwrap_or(true)
1056                );
1057                assert!(
1058                    filter
1059                        .city
1060                        .map(|city| city.contains(proxy.city.as_ref().unwrap()))
1061                        .unwrap_or(true)
1062                );
1063                assert!(
1064                    filter
1065                        .carrier
1066                        .map(|carrier| carrier.contains(proxy.carrier.as_ref().unwrap()))
1067                        .unwrap_or(true)
1068                );
1069            }
1070        }
1071
1072        #[tokio::test]
1073        async fn test_db_proxy_filter_any_only_matches_any_value() {
1074            let db = MemoryProxyDB::try_from_iter([Proxy {
1075                id: NonEmptyString::from_static("1"),
1076                address: ProxyAddress::from_str("example.com").unwrap(),
1077                tcp: true,
1078                udp: true,
1079                http: true,
1080                https: true,
1081                socks5: true,
1082                socks5h: true,
1083                datacenter: true,
1084                residential: true,
1085                mobile: true,
1086                pool_id: Some("hq".into()),
1087                continent: Some("americas".into()),
1088                country: Some("US".into()),
1089                state: Some("NY".into()),
1090                city: Some("NY".into()),
1091                carrier: Some("AT&T".into()),
1092                asn: Some(Asn::from_static(7018)),
1093            }])
1094            .unwrap();
1095
1096            let ctx = h2_proxy_context();
1097
1098            for filter in [
1099                ProxyFilter {
1100                    pool_id: Some(vec![StringFilter::new("*")]),
1101                    ..Default::default()
1102                },
1103                ProxyFilter {
1104                    continent: Some(vec![StringFilter::new("*")]),
1105                    ..Default::default()
1106                },
1107                ProxyFilter {
1108                    country: Some(vec![StringFilter::new("*")]),
1109                    ..Default::default()
1110                },
1111                ProxyFilter {
1112                    state: Some(vec![StringFilter::new("*")]),
1113                    ..Default::default()
1114                },
1115                ProxyFilter {
1116                    city: Some(vec![StringFilter::new("*")]),
1117                    ..Default::default()
1118                },
1119                ProxyFilter {
1120                    carrier: Some(vec![StringFilter::new("*")]),
1121                    ..Default::default()
1122                },
1123                ProxyFilter {
1124                    asn: Some(vec![Asn::unspecified()]),
1125                    ..Default::default()
1126                },
1127                ProxyFilter {
1128                    pool_id: Some(vec![StringFilter::new("*")]),
1129                    continent: Some(vec![StringFilter::new("*")]),
1130                    country: Some(vec![StringFilter::new("*")]),
1131                    state: Some(vec![StringFilter::new("*")]),
1132                    city: Some(vec![StringFilter::new("*")]),
1133                    carrier: Some(vec![StringFilter::new("*")]),
1134                    asn: Some(vec![Asn::unspecified()]),
1135                    ..Default::default()
1136                },
1137            ] {
1138                let err = match db.get_proxy(ctx.clone(), filter.clone()).await {
1139                    Ok(proxy) => {
1140                        panic!("expected error for filter {filter:?}, not found proxy: {proxy:?}");
1141                    }
1142                    Err(err) => err,
1143                };
1144                assert_eq!(
1145                    MemoryProxyDBQueryErrorKind::NotFound,
1146                    err.kind(),
1147                    "filter: {filter:?}",
1148                );
1149            }
1150        }
1151
1152        #[tokio::test]
1153        async fn test_search_proxy_for_any_of_given_pools() {
1154            let db = MemoryProxyDB::try_from_iter([
1155                Proxy {
1156                    id: NonEmptyString::from_static("1"),
1157                    address: ProxyAddress::from_str("example.com").unwrap(),
1158                    tcp: true,
1159                    udp: true,
1160                    http: true,
1161                    https: true,
1162                    socks5: true,
1163                    socks5h: true,
1164                    datacenter: true,
1165                    residential: true,
1166                    mobile: true,
1167                    pool_id: Some("a".into()),
1168                    continent: Some("americas".into()),
1169                    country: Some("US".into()),
1170                    state: Some("NY".into()),
1171                    city: Some("NY".into()),
1172                    carrier: Some("AT&T".into()),
1173                    asn: Some(Asn::from_static(7018)),
1174                },
1175                Proxy {
1176                    id: NonEmptyString::from_static("2"),
1177                    address: ProxyAddress::from_str("example.com").unwrap(),
1178                    tcp: true,
1179                    udp: true,
1180                    http: true,
1181                    https: true,
1182                    socks5: true,
1183                    socks5h: true,
1184                    datacenter: true,
1185                    residential: true,
1186                    mobile: true,
1187                    pool_id: Some("b".into()),
1188                    continent: Some("americas".into()),
1189                    country: Some("US".into()),
1190                    state: Some("NY".into()),
1191                    city: Some("NY".into()),
1192                    carrier: Some("AT&T".into()),
1193                    asn: Some(Asn::from_static(7018)),
1194                },
1195                Proxy {
1196                    id: NonEmptyString::from_static("3"),
1197                    address: ProxyAddress::from_str("example.com").unwrap(),
1198                    tcp: true,
1199                    udp: true,
1200                    http: true,
1201                    https: true,
1202                    socks5: true,
1203                    socks5h: true,
1204                    datacenter: true,
1205                    residential: true,
1206                    mobile: true,
1207                    pool_id: Some("b".into()),
1208                    continent: Some("americas".into()),
1209                    country: Some("US".into()),
1210                    state: Some("NY".into()),
1211                    city: Some("NY".into()),
1212                    carrier: Some("AT&T".into()),
1213                    asn: Some(Asn::from_static(7018)),
1214                },
1215                Proxy {
1216                    id: NonEmptyString::from_static("4"),
1217                    address: ProxyAddress::from_str("example.com").unwrap(),
1218                    tcp: true,
1219                    udp: true,
1220                    http: true,
1221                    https: true,
1222                    socks5: true,
1223                    socks5h: true,
1224                    datacenter: true,
1225                    residential: true,
1226                    mobile: true,
1227                    pool_id: Some("c".into()),
1228                    continent: Some("americas".into()),
1229                    country: Some("US".into()),
1230                    state: Some("NY".into()),
1231                    city: Some("NY".into()),
1232                    carrier: Some("AT&T".into()),
1233                    asn: Some(Asn::from_static(7018)),
1234                },
1235            ])
1236            .unwrap();
1237
1238            let ctx = h2_proxy_context();
1239
1240            let filter = ProxyFilter {
1241                pool_id: Some(vec![StringFilter::new("a"), StringFilter::new("c")]),
1242                ..Default::default()
1243            };
1244
1245            let mut seen_1 = false;
1246            let mut seen_4 = false;
1247            for _ in 0..100 {
1248                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
1249                match proxy.id.as_str() {
1250                    "1" => seen_1 = true,
1251                    "4" => seen_4 = true,
1252                    _ => panic!("unexpected pool id"),
1253                }
1254            }
1255            assert!(seen_1);
1256            assert!(seen_4);
1257        }
1258
1259        #[tokio::test]
1260        async fn test_deserialize_url_proxy_filter() {
1261            for (input, expected_output) in [
1262                (
1263                    "id=1",
1264                    ProxyFilter {
1265                        id: Some(NonEmptyString::from_static("1")),
1266                        ..Default::default()
1267                    },
1268                ),
1269                (
1270                    "pool=hq&country=us",
1271                    ProxyFilter {
1272                        pool_id: Some(vec![StringFilter::new("hq")]),
1273                        country: Some(vec![StringFilter::new("us")]),
1274                        ..Default::default()
1275                    },
1276                ),
1277                (
1278                    "pool=hq&country=us&country=be",
1279                    ProxyFilter {
1280                        pool_id: Some(vec![StringFilter::new("hq")]),
1281                        country: Some(vec![StringFilter::new("us"), StringFilter::new("be")]),
1282                        ..Default::default()
1283                    },
1284                ),
1285                (
1286                    "pool=a&country=uk&pool=b",
1287                    ProxyFilter {
1288                        pool_id: Some(vec![StringFilter::new("a"), StringFilter::new("b")]),
1289                        country: Some(vec![StringFilter::new("uk")]),
1290                        ..Default::default()
1291                    },
1292                ),
1293                (
1294                    "continent=europe&continent=asia",
1295                    ProxyFilter {
1296                        continent: Some(vec![
1297                            StringFilter::new("europe"),
1298                            StringFilter::new("asia"),
1299                        ]),
1300                        ..Default::default()
1301                    },
1302                ),
1303                (
1304                    "continent=americas&country=us&state=NY&city=buffalo&carrier=AT%26T&asn=7018",
1305                    ProxyFilter {
1306                        continent: Some(vec![StringFilter::new("americas")]),
1307                        country: Some(vec![StringFilter::new("us")]),
1308                        state: Some(vec![StringFilter::new("ny")]),
1309                        city: Some(vec![StringFilter::new("buffalo")]),
1310                        carrier: Some(vec![StringFilter::new("at&t")]),
1311                        asn: Some(vec![Asn::from_static(7018)]),
1312                        ..Default::default()
1313                    },
1314                ),
1315                (
1316                    "asn=1&asn=2",
1317                    ProxyFilter {
1318                        asn: Some(vec![Asn::from_static(1), Asn::from_static(2)]),
1319                        ..Default::default()
1320                    },
1321                ),
1322            ] {
1323                let filter: ProxyFilter = serde_html_form::from_str(input).unwrap();
1324                assert_eq!(filter, expected_output);
1325            }
1326        }
1327    }
1328}
1329
1330#[cfg(feature = "memory-db")]
1331pub use memdb::{
1332    MemoryProxyDB, MemoryProxyDBInsertError, MemoryProxyDBInsertErrorKind, MemoryProxyDBQueryError,
1333    MemoryProxyDBQueryErrorKind,
1334};