rama_proxy/proxydb/
mod.rs

1use rama_core::error::{BoxError, ErrorContext, OpaqueError};
2use rama_net::asn::Asn;
3use rama_utils::str::NonEmptyString;
4use serde::{Deserialize, Serialize};
5use std::fmt;
6
7#[cfg(feature = "live-update")]
8mod update;
9#[cfg(feature = "live-update")]
10#[doc(inline)]
11pub use update::{LiveUpdateProxyDB, LiveUpdateProxyDBSetter, proxy_db_updater};
12
13mod context;
14pub use context::ProxyContext;
15
16mod internal;
17#[doc(inline)]
18pub use internal::Proxy;
19
20#[cfg(feature = "csv")]
21mod csv;
22
23#[cfg(feature = "csv")]
24#[doc(inline)]
25pub use csv::{ProxyCsvRowReader, ProxyCsvRowReaderError, ProxyCsvRowReaderErrorKind};
26
27pub(super) mod layer;
28
29mod str;
30#[doc(inline)]
31pub use str::StringFilter;
32
33#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
34/// `ID` of the selected proxy. To be inserted into the `Context`,
35/// only if that proxy is selected.
36pub struct ProxyID(NonEmptyString);
37
38impl ProxyID {
39    /// View  this [`ProxyID`] as a `str`.
40    pub fn as_str(&self) -> &str {
41        self.0.as_str()
42    }
43}
44
45impl AsRef<str> for ProxyID {
46    fn as_ref(&self) -> &str {
47        self.0.as_ref()
48    }
49}
50
51impl fmt::Display for ProxyID {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        self.0.fmt(f)
54    }
55}
56
57impl From<NonEmptyString> for ProxyID {
58    fn from(value: NonEmptyString) -> Self {
59        Self(value)
60    }
61}
62
63#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
64/// Filter to select a specific kind of proxy.
65///
66/// If the `id` is specified the other fields are used
67/// as a validator to see if the only possible matching proxy
68/// matches these fields.
69///
70/// If the `id` is not specified, the other fields are used
71/// to select a random proxy from the pool.
72///
73/// Filters can be combined to make combinations with special meaning.
74/// E.g. `datacenter:true, residential:true` is essentially an ISP proxy.
75///
76/// ## Usage
77///
78/// - Use `HeaderConfigLayer` (`rama-http`) to have this proxy filter be given by the http `Request` headers,
79///   which will add the extracted and parsed [`ProxyFilter`] to the [`Context`]'s [`Extensions`].
80/// - Or extract yourself from the username/token validated in the `ProxyAuthLayer` (`rama-http`)
81///   to add it manually to the [`Context`]'s [`Extensions`].
82///
83/// [`Request`]: crate::http::Request
84/// [`Context`]: rama_core::Context
85/// [`Extensions`]: rama_core::context::Extensions
86pub struct ProxyFilter {
87    /// The ID of the proxy to select.
88    pub id: Option<NonEmptyString>,
89
90    /// The ID of the pool from which to select the proxy.
91    #[serde(alias = "pool")]
92    pub pool_id: Option<Vec<StringFilter>>,
93
94    /// The continent of the proxy.
95    pub continent: Option<Vec<StringFilter>>,
96
97    /// The country of the proxy.
98    pub country: Option<Vec<StringFilter>>,
99
100    /// The state of the proxy.
101    pub state: Option<Vec<StringFilter>>,
102
103    /// The city of the proxy.
104    pub city: Option<Vec<StringFilter>>,
105
106    /// Set explicitly to `true` to select a datacenter proxy.
107    pub datacenter: Option<bool>,
108
109    /// Set explicitly to `true` to select a residential proxy.
110    pub residential: Option<bool>,
111
112    /// Set explicitly to `true` to select a mobile proxy.
113    pub mobile: Option<bool>,
114
115    /// The mobile carrier desired.
116    pub carrier: Option<Vec<StringFilter>>,
117
118    ///  Autonomous System Number (ASN).
119    pub asn: Option<Vec<Asn>>,
120}
121
122/// The trait to implement to provide a proxy database to other facilities,
123/// such as connection pools, to provide a proxy based on the given
124/// [`TransportContext`] and [`ProxyFilter`].
125pub trait ProxyDB: Send + Sync + 'static {
126    /// The error type that can be returned by the proxy database
127    ///
128    /// Examples are generic I/O issues or
129    /// even more common if no proxy match could be found.
130    type Error: Send + 'static;
131
132    /// Same as [`Self::get_proxy`] but with a predicate
133    /// to filter out found proxies that do not match the given predicate.
134    fn get_proxy_if(
135        &self,
136        ctx: ProxyContext,
137        filter: ProxyFilter,
138        predicate: impl ProxyQueryPredicate,
139    ) -> impl Future<Output = Result<Proxy, Self::Error>> + Send + '_;
140
141    /// Get a [`Proxy`] based on the given [`ProxyContext`] and [`ProxyFilter`],
142    /// or return an error in case no [`Proxy`] could be returned.
143    fn get_proxy(
144        &self,
145        ctx: ProxyContext,
146        filter: ProxyFilter,
147    ) -> impl Future<Output = Result<Proxy, Self::Error>> + Send + '_ {
148        self.get_proxy_if(ctx, filter, true)
149    }
150}
151
152impl ProxyDB for () {
153    type Error = OpaqueError;
154
155    #[inline]
156    async fn get_proxy_if(
157        &self,
158        _ctx: ProxyContext,
159        _filter: ProxyFilter,
160        _predicate: impl ProxyQueryPredicate,
161    ) -> Result<Proxy, Self::Error> {
162        Err(OpaqueError::from_display(
163            "()::get_proxy_if: no ProxyDB defined",
164        ))
165    }
166
167    #[inline]
168    async fn get_proxy(
169        &self,
170        _ctx: ProxyContext,
171        _filter: ProxyFilter,
172    ) -> Result<Proxy, Self::Error> {
173        Err(OpaqueError::from_display(
174            "()::get_proxy: no ProxyDB defined",
175        ))
176    }
177}
178
179impl<T> ProxyDB for Option<T>
180where
181    T: ProxyDB<Error: Into<BoxError>>,
182{
183    type Error = OpaqueError;
184
185    #[inline]
186    async fn get_proxy_if(
187        &self,
188        ctx: ProxyContext,
189        filter: ProxyFilter,
190        predicate: impl ProxyQueryPredicate,
191    ) -> Result<Proxy, Self::Error> {
192        match self {
193            Some(db) => db
194                .get_proxy_if(ctx, filter, predicate)
195                .await
196                .map_err(|err| OpaqueError::from_boxed(err.into()))
197                .context("Some::get_proxy_if"),
198            None => Err(OpaqueError::from_display(
199                "None::get_proxy_if: no ProxyDB defined",
200            )),
201        }
202    }
203
204    #[inline]
205    async fn get_proxy(
206        &self,
207        ctx: ProxyContext,
208        filter: ProxyFilter,
209    ) -> Result<Proxy, Self::Error> {
210        match self {
211            Some(db) => db
212                .get_proxy(ctx, filter)
213                .await
214                .map_err(|err| OpaqueError::from_boxed(err.into()))
215                .context("Some::get_proxy"),
216            None => Err(OpaqueError::from_display(
217                "None::get_proxy: no ProxyDB defined",
218            )),
219        }
220    }
221}
222
223impl<T> ProxyDB for std::sync::Arc<T>
224where
225    T: ProxyDB,
226{
227    type Error = T::Error;
228
229    #[inline]
230    fn get_proxy_if(
231        &self,
232        ctx: ProxyContext,
233        filter: ProxyFilter,
234        predicate: impl ProxyQueryPredicate,
235    ) -> impl Future<Output = Result<Proxy, Self::Error>> + Send + '_ {
236        (**self).get_proxy_if(ctx, filter, predicate)
237    }
238
239    #[inline]
240    fn get_proxy(
241        &self,
242        ctx: ProxyContext,
243        filter: ProxyFilter,
244    ) -> impl Future<Output = Result<Proxy, Self::Error>> + Send + '_ {
245        (**self).get_proxy(ctx, filter)
246    }
247}
248
249macro_rules! impl_proxydb_either {
250    ($id:ident, $($param:ident),+ $(,)?) => {
251        impl<$($param),+> ProxyDB for rama_core::combinators::$id<$($param),+>
252        where
253            $(
254                $param: ProxyDB<Error: Into<BoxError>>,
255            )+
256    {
257        type Error = BoxError;
258
259        #[inline]
260        async fn get_proxy_if(
261            &self,
262            ctx: ProxyContext,
263            filter: ProxyFilter,
264            predicate: impl ProxyQueryPredicate,
265        ) -> Result<Proxy, Self::Error> {
266            match self {
267                $(
268                    rama_core::combinators::$id::$param(s) => s.get_proxy_if(ctx, filter, predicate).await.map_err(Into::into),
269                )+
270            }
271        }
272
273        #[inline]
274        async fn get_proxy(
275            &self,
276            ctx: ProxyContext,
277            filter: ProxyFilter,
278        ) -> Result<Proxy, Self::Error> {
279            match self {
280                $(
281                    rama_core::combinators::$id::$param(s) => s.get_proxy(ctx, filter).await.map_err(Into::into),
282                )+
283            }
284        }
285        }
286    };
287}
288
289rama_core::combinators::impl_either!(impl_proxydb_either);
290
291/// Trait that is used by the [`ProxyDB`] for providing an optional
292/// filter predicate to rule out returned results.
293pub trait ProxyQueryPredicate: Clone + Send + Sync + 'static {
294    /// Execute the predicate.
295    fn execute(&self, proxy: &Proxy) -> bool;
296}
297
298impl ProxyQueryPredicate for bool {
299    fn execute(&self, _proxy: &Proxy) -> bool {
300        *self
301    }
302}
303
304impl<F> ProxyQueryPredicate for F
305where
306    F: Fn(&Proxy) -> bool + Clone + Send + Sync + 'static,
307{
308    fn execute(&self, proxy: &Proxy) -> bool {
309        (self)(proxy)
310    }
311}
312
313impl ProxyDB for Proxy {
314    type Error = rama_core::error::OpaqueError;
315
316    async fn get_proxy_if(
317        &self,
318        ctx: ProxyContext,
319        filter: ProxyFilter,
320        predicate: impl ProxyQueryPredicate,
321    ) -> Result<Proxy, Self::Error> {
322        (self.is_match(&ctx, &filter) && predicate.execute(self))
323            .then(|| self.clone())
324            .ok_or_else(|| rama_core::error::OpaqueError::from_display("hardcoded proxy no match"))
325    }
326}
327
328#[cfg(feature = "memory-db")]
329mod memdb {
330    use super::*;
331    use crate::proxydb::internal::ProxyDBErrorKind;
332    use rama_net::transport::TransportProtocol;
333
334    /// A fast in-memory ProxyDatabase that is the default choice for Rama.
335    #[derive(Debug)]
336    pub struct MemoryProxyDB {
337        data: internal::ProxyDB,
338    }
339
340    impl MemoryProxyDB {
341        /// Create a new in-memory proxy database with the given proxies.
342        pub fn try_from_rows(proxies: Vec<Proxy>) -> Result<Self, MemoryProxyDBInsertError> {
343            Ok(MemoryProxyDB {
344                data: internal::ProxyDB::from_rows(proxies).map_err(|err| match err.kind() {
345                    ProxyDBErrorKind::DuplicateKey => {
346                        MemoryProxyDBInsertError::duplicate_key(err.into_input())
347                    }
348                    ProxyDBErrorKind::InvalidRow => {
349                        MemoryProxyDBInsertError::invalid_proxy(err.into_input())
350                    }
351                })?,
352            })
353        }
354
355        /// Create a new in-memory proxy database with the given proxies from an iterator.
356        pub fn try_from_iter<I>(proxies: I) -> Result<Self, MemoryProxyDBInsertError>
357        where
358            I: IntoIterator<Item = Proxy>,
359        {
360            Ok(MemoryProxyDB {
361                data: internal::ProxyDB::from_iter(proxies).map_err(|err| match err.kind() {
362                    ProxyDBErrorKind::DuplicateKey => {
363                        MemoryProxyDBInsertError::duplicate_key(err.into_input())
364                    }
365                    ProxyDBErrorKind::InvalidRow => {
366                        MemoryProxyDBInsertError::invalid_proxy(err.into_input())
367                    }
368                })?,
369            })
370        }
371
372        /// Return the number of proxies in the database.
373        pub fn len(&self) -> usize {
374            self.data.len()
375        }
376
377        /// Rerturns if the database is empty.
378        pub fn is_empty(&self) -> bool {
379            self.data.is_empty()
380        }
381
382        fn query_from_filter(
383            &self,
384            ctx: ProxyContext,
385            filter: ProxyFilter,
386        ) -> internal::ProxyDBQuery {
387            let mut query = self.data.query();
388
389            for pool_id in filter.pool_id.into_iter().flatten() {
390                query.pool_id(pool_id);
391            }
392            for continent in filter.continent.into_iter().flatten() {
393                query.continent(continent);
394            }
395            for country in filter.country.into_iter().flatten() {
396                query.country(country);
397            }
398            for state in filter.state.into_iter().flatten() {
399                query.state(state);
400            }
401            for city in filter.city.into_iter().flatten() {
402                query.city(city);
403            }
404            for carrier in filter.carrier.into_iter().flatten() {
405                query.carrier(carrier);
406            }
407            for asn in filter.asn.into_iter().flatten() {
408                query.asn(asn);
409            }
410
411            if let Some(value) = filter.datacenter {
412                query.datacenter(value);
413            }
414            if let Some(value) = filter.residential {
415                query.residential(value);
416            }
417            if let Some(value) = filter.mobile {
418                query.mobile(value);
419            }
420
421            match ctx.protocol {
422                TransportProtocol::Tcp => {
423                    query.tcp(true);
424                }
425                TransportProtocol::Udp => {
426                    query.udp(true).socks5(true);
427                }
428            }
429
430            query
431        }
432    }
433
434    // TODO: custom query filters using ProxyQueryPredicate
435    // might be a lot faster for cases where we want to filter a big batch of proxies,
436    // in which case a bitmap could be supported by a future VennDB version...
437    //
438    // Would just need to figure out how to allow this to happen.
439
440    impl ProxyDB for MemoryProxyDB {
441        type Error = MemoryProxyDBQueryError;
442
443        async fn get_proxy_if(
444            &self,
445            ctx: ProxyContext,
446            filter: ProxyFilter,
447            predicate: impl ProxyQueryPredicate,
448        ) -> Result<Proxy, Self::Error> {
449            match &filter.id {
450                Some(id) => match self.data.get_by_id(id) {
451                    None => Err(MemoryProxyDBQueryError::not_found()),
452                    Some(proxy) => {
453                        if proxy.is_match(&ctx, &filter) && predicate.execute(proxy) {
454                            Ok(proxy.clone())
455                        } else {
456                            Err(MemoryProxyDBQueryError::mismatch())
457                        }
458                    }
459                },
460                None => {
461                    let query = self.query_from_filter(ctx, filter.clone());
462                    match query
463                        .execute()
464                        .and_then(|result| result.filter(|proxy| predicate.execute(proxy)))
465                        .map(|result| result.any())
466                    {
467                        None => Err(MemoryProxyDBQueryError::not_found()),
468                        Some(proxy) => Ok(proxy.clone()),
469                    }
470                }
471            }
472        }
473    }
474
475    /// The error type that can be returned by [`MemoryProxyDB`] when some of the proxies
476    /// could not be inserted due to a proxy that had a duplicate key or was invalid for some other reason.
477    #[derive(Debug)]
478    pub struct MemoryProxyDBInsertError {
479        kind: MemoryProxyDBInsertErrorKind,
480        proxies: Vec<Proxy>,
481    }
482
483    impl std::fmt::Display for MemoryProxyDBInsertError {
484        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
485            match self.kind {
486                MemoryProxyDBInsertErrorKind::DuplicateKey => write!(
487                    f,
488                    "A proxy with the same key already exists in the database"
489                ),
490                MemoryProxyDBInsertErrorKind::InvalidProxy => {
491                    write!(f, "A proxy in the list is invalid for some reason")
492                }
493            }
494        }
495    }
496
497    impl std::error::Error for MemoryProxyDBInsertError {}
498
499    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
500    /// The kind of error that [`MemoryProxyDBInsertError`] represents.
501    pub enum MemoryProxyDBInsertErrorKind {
502        /// Duplicate key found in the proxies.
503        DuplicateKey,
504        /// Invalid proxy found in the proxies.
505        ///
506        /// This could be due to a proxy that is not valid for some reason.
507        /// E.g. a proxy that neither supports http or socks5.
508        InvalidProxy,
509    }
510
511    impl MemoryProxyDBInsertError {
512        fn duplicate_key(proxies: Vec<Proxy>) -> Self {
513            MemoryProxyDBInsertError {
514                kind: MemoryProxyDBInsertErrorKind::DuplicateKey,
515                proxies,
516            }
517        }
518
519        fn invalid_proxy(proxies: Vec<Proxy>) -> Self {
520            MemoryProxyDBInsertError {
521                kind: MemoryProxyDBInsertErrorKind::InvalidProxy,
522                proxies,
523            }
524        }
525
526        /// Returns the kind of error that [`MemoryProxyDBInsertError`] represents.
527        pub fn kind(&self) -> MemoryProxyDBInsertErrorKind {
528            self.kind
529        }
530
531        /// Returns the proxies that were not inserted.
532        pub fn proxies(&self) -> &[Proxy] {
533            &self.proxies
534        }
535
536        /// Consumes the error and returns the proxies that were not inserted.
537        pub fn into_proxies(self) -> Vec<Proxy> {
538            self.proxies
539        }
540    }
541
542    /// The error type that can be returned by [`MemoryProxyDB`] when no proxy could be returned.
543    #[derive(Debug)]
544    pub struct MemoryProxyDBQueryError {
545        kind: MemoryProxyDBQueryErrorKind,
546    }
547
548    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
549    /// The kind of error that [`MemoryProxyDBQueryError`] represents.
550    pub enum MemoryProxyDBQueryErrorKind {
551        /// No proxy match could be found.
552        NotFound,
553        /// A proxy looked up by key had a config that did not match the given filters/requirements.
554        Mismatch,
555    }
556
557    impl std::fmt::Display for MemoryProxyDBQueryError {
558        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
559            match self.kind {
560                MemoryProxyDBQueryErrorKind::NotFound => write!(f, "No proxy match could be found"),
561                MemoryProxyDBQueryErrorKind::Mismatch => write!(
562                    f,
563                    "Proxy config did not match the given filters/requirements"
564                ),
565            }
566        }
567    }
568
569    impl std::error::Error for MemoryProxyDBQueryError {}
570
571    impl MemoryProxyDBQueryError {
572        /// Create a new error that indicates no proxy match could be found.
573        pub fn not_found() -> Self {
574            MemoryProxyDBQueryError {
575                kind: MemoryProxyDBQueryErrorKind::NotFound,
576            }
577        }
578
579        /// Create a new error that indicates a proxy looked up by key had a config that did not match the given filters/requirements.
580        pub fn mismatch() -> Self {
581            MemoryProxyDBQueryError {
582                kind: MemoryProxyDBQueryErrorKind::Mismatch,
583            }
584        }
585
586        /// Returns the kind of error that [`MemoryProxyDBQueryError`] represents.
587        pub fn kind(&self) -> MemoryProxyDBQueryErrorKind {
588            self.kind
589        }
590    }
591
592    #[cfg(test)]
593    mod tests {
594        use super::*;
595        use itertools::Itertools;
596        use rama_net::address::ProxyAddress;
597        use rama_utils::str::NonEmptyString;
598        use std::str::FromStr;
599
600        const RAW_CSV_DATA: &str = include_str!("./test_proxydb_rows.csv");
601
602        async fn memproxydb() -> MemoryProxyDB {
603            let mut reader = ProxyCsvRowReader::raw(RAW_CSV_DATA);
604            let mut rows = Vec::new();
605            while let Some(proxy) = reader.next().await.unwrap() {
606                rows.push(proxy);
607            }
608            MemoryProxyDB::try_from_rows(rows).unwrap()
609        }
610
611        #[tokio::test]
612        async fn test_load_memproxydb_from_rows() {
613            let db = memproxydb().await;
614            assert_eq!(db.len(), 64);
615        }
616
617        fn h2_proxy_context() -> ProxyContext {
618            ProxyContext {
619                protocol: TransportProtocol::Tcp,
620            }
621        }
622
623        #[tokio::test]
624        async fn test_memproxydb_get_proxy_by_id_found() {
625            let db = memproxydb().await;
626            let ctx = h2_proxy_context();
627            let filter = ProxyFilter {
628                id: Some(NonEmptyString::from_static("3031533634")),
629                ..Default::default()
630            };
631            let proxy = db.get_proxy(ctx, filter).await.unwrap();
632            assert_eq!(proxy.id, "3031533634");
633        }
634
635        #[tokio::test]
636        async fn test_memproxydb_get_proxy_by_id_found_correct_filters() {
637            let db = memproxydb().await;
638            let ctx = h2_proxy_context();
639            let filter = ProxyFilter {
640                id: Some(NonEmptyString::from_static("3031533634")),
641                pool_id: Some(vec![StringFilter::new("poolF")]),
642                country: Some(vec![StringFilter::new("JP")]),
643                city: Some(vec![StringFilter::new("Yokohama")]),
644                datacenter: Some(true),
645                residential: Some(false),
646                mobile: Some(true),
647                carrier: Some(vec![StringFilter::new("Verizon")]),
648                ..Default::default()
649            };
650            let proxy = db.get_proxy(ctx, filter).await.unwrap();
651            assert_eq!(proxy.id, "3031533634");
652        }
653
654        #[tokio::test]
655        async fn test_memproxydb_get_proxy_by_id_not_found() {
656            let db = memproxydb().await;
657            let ctx = h2_proxy_context();
658            let filter = ProxyFilter {
659                id: Some(NonEmptyString::from_static("notfound")),
660                ..Default::default()
661            };
662            let err = db.get_proxy(ctx, filter).await.unwrap_err();
663            assert_eq!(err.kind(), MemoryProxyDBQueryErrorKind::NotFound);
664        }
665
666        #[tokio::test]
667        async fn test_memproxydb_get_proxy_by_id_mismatch_filter() {
668            let db = memproxydb().await;
669            let ctx = h2_proxy_context();
670            let filters = [
671                ProxyFilter {
672                    id: Some(NonEmptyString::from_static("3031533634")),
673                    pool_id: Some(vec![StringFilter::new("poolB")]),
674                    ..Default::default()
675                },
676                ProxyFilter {
677                    id: Some(NonEmptyString::from_static("3031533634")),
678                    country: Some(vec![StringFilter::new("US")]),
679                    ..Default::default()
680                },
681                ProxyFilter {
682                    id: Some(NonEmptyString::from_static("3031533634")),
683                    city: Some(vec![StringFilter::new("New York")]),
684                    ..Default::default()
685                },
686                ProxyFilter {
687                    id: Some(NonEmptyString::from_static("3031533634")),
688                    continent: Some(vec![StringFilter::new("americas")]),
689                    ..Default::default()
690                },
691                ProxyFilter {
692                    id: Some(NonEmptyString::from_static("3732488183")),
693                    state: Some(vec![StringFilter::new("Texas")]),
694                    ..Default::default()
695                },
696                ProxyFilter {
697                    id: Some(NonEmptyString::from_static("3031533634")),
698                    datacenter: Some(false),
699                    ..Default::default()
700                },
701                ProxyFilter {
702                    id: Some(NonEmptyString::from_static("3031533634")),
703                    residential: Some(true),
704                    ..Default::default()
705                },
706                ProxyFilter {
707                    id: Some(NonEmptyString::from_static("3031533634")),
708                    mobile: Some(false),
709                    ..Default::default()
710                },
711                ProxyFilter {
712                    id: Some(NonEmptyString::from_static("3031533634")),
713                    carrier: Some(vec![StringFilter::new("AT&T")]),
714                    ..Default::default()
715                },
716                ProxyFilter {
717                    id: Some(NonEmptyString::from_static("292096733")),
718                    asn: Some(vec![Asn::from_static(1)]),
719                    ..Default::default()
720                },
721            ];
722            for filter in filters.iter() {
723                let err = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap_err();
724                assert_eq!(err.kind(), MemoryProxyDBQueryErrorKind::Mismatch);
725            }
726        }
727
728        fn h3_proxy_context() -> ProxyContext {
729            ProxyContext {
730                protocol: TransportProtocol::Udp,
731            }
732        }
733
734        #[tokio::test]
735        async fn test_memproxydb_get_proxy_by_id_mismatch_req_context() {
736            let db = memproxydb().await;
737            let ctx = h3_proxy_context();
738            let filter = ProxyFilter {
739                id: Some(NonEmptyString::from_static("3031533634")),
740                ..Default::default()
741            };
742            // this proxy does not support socks5 UDP, which is what we need
743            let err = db.get_proxy(ctx, filter).await.unwrap_err();
744            assert_eq!(err.kind(), MemoryProxyDBQueryErrorKind::Mismatch);
745        }
746
747        #[tokio::test]
748        async fn test_memorydb_get_h3_capable_proxies() {
749            let db = memproxydb().await;
750            let ctx = h3_proxy_context();
751            let filter = ProxyFilter::default();
752            let mut found_ids = Vec::new();
753            for _ in 0..5000 {
754                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
755                if found_ids.contains(&proxy.id) {
756                    continue;
757                }
758                assert!(proxy.udp);
759                assert!(proxy.socks5);
760                found_ids.push(proxy.id);
761            }
762            assert_eq!(found_ids.len(), 40);
763            assert_eq!(
764                found_ids.iter().sorted().join(","),
765                r##"1125300915,1259341971,1316455915,153202126,1571861931,1684342915,1742367441,1844412609,1916851007,20647117,2107229589,2261612122,2497865606,2521901221,2560727338,2593294918,2596743625,2745456299,2880295577,2909724448,2950022859,2951529660,3187902553,3269411602,3269465574,3269921904,3481200027,3498810974,362091157,3679054656,3732488183,3836943127,39048766,3951672504,3976711563,4187178960,56402588,724884866,738626121,906390012"##
766            );
767        }
768
769        #[tokio::test]
770        async fn test_memorydb_get_h2_capable_proxies() {
771            let db = memproxydb().await;
772            let ctx = h2_proxy_context();
773            let filter = ProxyFilter::default();
774            let mut found_ids = Vec::new();
775            for _ in 0..5000 {
776                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
777                if found_ids.contains(&proxy.id) {
778                    continue;
779                }
780                assert!(proxy.tcp);
781                found_ids.push(proxy.id);
782            }
783            assert_eq!(found_ids.len(), 50);
784            assert_eq!(
785                found_ids.iter().sorted().join(","),
786                r#"1125300915,1259341971,1264821985,129108927,1316455915,1425588737,1571861931,1810781137,1836040682,1844412609,1885107293,2021561518,2079461709,2107229589,2141152822,2438596154,2497865606,2521901221,2551759475,2560727338,2593294918,2798907087,2854473221,2880295577,2909724448,2912880381,292096733,2951529660,3031533634,3187902553,3269411602,3269465574,339020035,3481200027,3498810974,3503691556,362091157,3679054656,371209663,3861736957,39048766,3976711563,4062553709,49590203,56402588,724884866,738626121,767809962,846528631,906390012"#,
787            );
788        }
789
790        #[tokio::test]
791        async fn test_memorydb_get_any_country_proxies() {
792            let db = memproxydb().await;
793            let ctx = h2_proxy_context();
794            let filter = ProxyFilter {
795                // there are no explicit BE proxies,
796                // so these will only match the proxies that have a wildcard country
797                country: Some(vec!["BE".into()]),
798                ..Default::default()
799            };
800            let mut found_ids = Vec::new();
801            for _ in 0..5000 {
802                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
803                if found_ids.contains(&proxy.id) {
804                    continue;
805                }
806                found_ids.push(proxy.id);
807            }
808            assert_eq!(found_ids.len(), 5);
809            assert_eq!(
810                found_ids.iter().sorted().join(","),
811                r#"2141152822,2593294918,2912880381,371209663,767809962"#,
812            );
813        }
814
815        #[tokio::test]
816        async fn test_memorydb_get_illinois_proxies() {
817            let db = memproxydb().await;
818            let ctx = h2_proxy_context();
819            let filter = ProxyFilter {
820                // this will also work for proxies that have 'any' state
821                state: Some(vec!["illinois".into()]),
822                ..Default::default()
823            };
824            let mut found_ids = Vec::new();
825            for _ in 0..5000 {
826                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
827                if found_ids.contains(&proxy.id) {
828                    continue;
829                }
830                found_ids.push(proxy.id);
831            }
832            assert_eq!(found_ids.len(), 9);
833            assert_eq!(
834                found_ids.iter().sorted().join(","),
835                r#"2141152822,2521901221,2560727338,2593294918,2912880381,292096733,371209663,39048766,767809962"#,
836            );
837        }
838
839        #[tokio::test]
840        async fn test_memorydb_get_asn_proxies() {
841            let db = memproxydb().await;
842            let ctx = h2_proxy_context();
843            let filter = ProxyFilter {
844                // this will also work for proxies that have 'any' ASN
845                asn: Some(vec![Asn::from_static(42)]),
846                ..Default::default()
847            };
848            let mut found_ids = Vec::new();
849            for _ in 0..5000 {
850                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
851                if found_ids.contains(&proxy.id) {
852                    continue;
853                }
854                found_ids.push(proxy.id);
855            }
856            assert_eq!(found_ids.len(), 4);
857            assert_eq!(
858                found_ids.iter().sorted().join(","),
859                r#"2141152822,2912880381,292096733,3481200027"#,
860            );
861        }
862
863        #[tokio::test]
864        async fn test_memorydb_get_h3_capable_mobile_residential_be_asterix_proxies() {
865            let db = memproxydb().await;
866            let ctx = h3_proxy_context();
867            let filter = ProxyFilter {
868                country: Some(vec!["BE".into()]),
869                mobile: Some(true),
870                residential: Some(true),
871                ..Default::default()
872            };
873            for _ in 0..50 {
874                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
875                assert_eq!(proxy.id, "2593294918");
876            }
877        }
878
879        #[tokio::test]
880        async fn test_memorydb_get_blocked_proxies() {
881            let db = memproxydb().await;
882            let ctx = h2_proxy_context();
883            let filter = ProxyFilter::default();
884
885            let mut blocked_proxies = vec![
886                "1125300915",
887                "1259341971",
888                "1264821985",
889                "129108927",
890                "1316455915",
891                "1425588737",
892                "1571861931",
893                "1810781137",
894                "1836040682",
895                "1844412609",
896                "1885107293",
897                "2021561518",
898                "2079461709",
899                "2107229589",
900                "2141152822",
901                "2438596154",
902                "2497865606",
903                "2521901221",
904                "2551759475",
905                "2560727338",
906                "2593294918",
907                "2798907087",
908                "2854473221",
909                "2880295577",
910                "2909724448",
911                "2912880381",
912                "292096733",
913                "2951529660",
914                "3031533634",
915                "3187902553",
916                "3269411602",
917                "3269465574",
918                "339020035",
919                "3481200027",
920                "3498810974",
921                "3503691556",
922                "362091157",
923                "3679054656",
924                "371209663",
925                "3861736957",
926                "39048766",
927                "3976711563",
928                "4062553709",
929                "49590203",
930                "56402588",
931                "724884866",
932                "738626121",
933                "767809962",
934                "846528631",
935                "906390012",
936            ];
937
938            {
939                let blocked_proxies = blocked_proxies.clone();
940
941                assert_eq!(
942                    MemoryProxyDBQueryErrorKind::NotFound,
943                    db.get_proxy_if(ctx.clone(), filter.clone(), move |proxy: &Proxy| {
944                        !blocked_proxies.contains(&proxy.id.as_str())
945                    })
946                    .await
947                    .unwrap_err()
948                    .kind()
949                );
950            }
951
952            let last_proxy_id = blocked_proxies.pop().unwrap();
953
954            let proxy = db
955                .get_proxy_if(ctx, filter.clone(), move |proxy: &Proxy| {
956                    !blocked_proxies.contains(&proxy.id.as_str())
957                })
958                .await
959                .unwrap();
960            assert_eq!(proxy.id, last_proxy_id);
961        }
962
963        #[tokio::test]
964        async fn test_db_proxy_filter_any_use_filter_property() {
965            let db = MemoryProxyDB::try_from_iter([Proxy {
966                id: NonEmptyString::from_static("1"),
967                address: ProxyAddress::from_str("example.com").unwrap(),
968                tcp: true,
969                udp: true,
970                http: true,
971                https: true,
972                socks5: true,
973                socks5h: true,
974                datacenter: true,
975                residential: true,
976                mobile: true,
977                pool_id: Some("*".into()),
978                continent: Some("*".into()),
979                country: Some("*".into()),
980                state: Some("*".into()),
981                city: Some("*".into()),
982                carrier: Some("*".into()),
983                asn: Some(Asn::unspecified()),
984            }])
985            .unwrap();
986
987            let ctx = h2_proxy_context();
988
989            for filter in [
990                ProxyFilter {
991                    id: Some(NonEmptyString::from_static("1")),
992                    ..Default::default()
993                },
994                ProxyFilter {
995                    pool_id: Some(vec![StringFilter::new("*")]),
996                    ..Default::default()
997                },
998                ProxyFilter {
999                    pool_id: Some(vec![StringFilter::new("hq")]),
1000                    ..Default::default()
1001                },
1002                ProxyFilter {
1003                    country: Some(vec![StringFilter::new("*")]),
1004                    ..Default::default()
1005                },
1006                ProxyFilter {
1007                    country: Some(vec![StringFilter::new("US")]),
1008                    ..Default::default()
1009                },
1010                ProxyFilter {
1011                    city: Some(vec![StringFilter::new("*")]),
1012                    ..Default::default()
1013                },
1014                ProxyFilter {
1015                    city: Some(vec![StringFilter::new("NY")]),
1016                    ..Default::default()
1017                },
1018                ProxyFilter {
1019                    carrier: Some(vec![StringFilter::new("*")]),
1020                    ..Default::default()
1021                },
1022                ProxyFilter {
1023                    carrier: Some(vec![StringFilter::new("Telenet")]),
1024                    ..Default::default()
1025                },
1026                ProxyFilter {
1027                    pool_id: Some(vec![StringFilter::new("hq")]),
1028                    country: Some(vec![StringFilter::new("US")]),
1029                    city: Some(vec![StringFilter::new("NY")]),
1030                    carrier: Some(vec![StringFilter::new("AT&T")]),
1031                    ..Default::default()
1032                },
1033            ] {
1034                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
1035                assert!(filter.id.map(|id| proxy.id == id).unwrap_or(true));
1036                assert!(
1037                    filter
1038                        .pool_id
1039                        .map(|pool_id| pool_id.contains(proxy.pool_id.as_ref().unwrap()))
1040                        .unwrap_or(true)
1041                );
1042                assert!(
1043                    filter
1044                        .country
1045                        .map(|country| country.contains(proxy.country.as_ref().unwrap()))
1046                        .unwrap_or(true)
1047                );
1048                assert!(
1049                    filter
1050                        .city
1051                        .map(|city| city.contains(proxy.city.as_ref().unwrap()))
1052                        .unwrap_or(true)
1053                );
1054                assert!(
1055                    filter
1056                        .carrier
1057                        .map(|carrier| carrier.contains(proxy.carrier.as_ref().unwrap()))
1058                        .unwrap_or(true)
1059                );
1060            }
1061        }
1062
1063        #[tokio::test]
1064        async fn test_db_proxy_filter_any_only_matches_any_value() {
1065            let db = MemoryProxyDB::try_from_iter([Proxy {
1066                id: NonEmptyString::from_static("1"),
1067                address: ProxyAddress::from_str("example.com").unwrap(),
1068                tcp: true,
1069                udp: true,
1070                http: true,
1071                https: true,
1072                socks5: true,
1073                socks5h: true,
1074                datacenter: true,
1075                residential: true,
1076                mobile: true,
1077                pool_id: Some("hq".into()),
1078                continent: Some("americas".into()),
1079                country: Some("US".into()),
1080                state: Some("NY".into()),
1081                city: Some("NY".into()),
1082                carrier: Some("AT&T".into()),
1083                asn: Some(Asn::from_static(7018)),
1084            }])
1085            .unwrap();
1086
1087            let ctx = h2_proxy_context();
1088
1089            for filter in [
1090                ProxyFilter {
1091                    pool_id: Some(vec![StringFilter::new("*")]),
1092                    ..Default::default()
1093                },
1094                ProxyFilter {
1095                    continent: Some(vec![StringFilter::new("*")]),
1096                    ..Default::default()
1097                },
1098                ProxyFilter {
1099                    country: Some(vec![StringFilter::new("*")]),
1100                    ..Default::default()
1101                },
1102                ProxyFilter {
1103                    state: Some(vec![StringFilter::new("*")]),
1104                    ..Default::default()
1105                },
1106                ProxyFilter {
1107                    city: Some(vec![StringFilter::new("*")]),
1108                    ..Default::default()
1109                },
1110                ProxyFilter {
1111                    carrier: Some(vec![StringFilter::new("*")]),
1112                    ..Default::default()
1113                },
1114                ProxyFilter {
1115                    asn: Some(vec![Asn::unspecified()]),
1116                    ..Default::default()
1117                },
1118                ProxyFilter {
1119                    pool_id: Some(vec![StringFilter::new("*")]),
1120                    continent: Some(vec![StringFilter::new("*")]),
1121                    country: Some(vec![StringFilter::new("*")]),
1122                    state: Some(vec![StringFilter::new("*")]),
1123                    city: Some(vec![StringFilter::new("*")]),
1124                    carrier: Some(vec![StringFilter::new("*")]),
1125                    asn: Some(vec![Asn::unspecified()]),
1126                    ..Default::default()
1127                },
1128            ] {
1129                let err = match db.get_proxy(ctx.clone(), filter.clone()).await {
1130                    Ok(proxy) => {
1131                        panic!(
1132                            "expected error for filter {:?}, not found proxy: {:?}",
1133                            filter, proxy
1134                        );
1135                    }
1136                    Err(err) => err,
1137                };
1138                assert_eq!(
1139                    MemoryProxyDBQueryErrorKind::NotFound,
1140                    err.kind(),
1141                    "filter: {:?}",
1142                    filter
1143                );
1144            }
1145        }
1146
1147        #[tokio::test]
1148        async fn test_search_proxy_for_any_of_given_pools() {
1149            let db = MemoryProxyDB::try_from_iter([
1150                Proxy {
1151                    id: NonEmptyString::from_static("1"),
1152                    address: ProxyAddress::from_str("example.com").unwrap(),
1153                    tcp: true,
1154                    udp: true,
1155                    http: true,
1156                    https: true,
1157                    socks5: true,
1158                    socks5h: true,
1159                    datacenter: true,
1160                    residential: true,
1161                    mobile: true,
1162                    pool_id: Some("a".into()),
1163                    continent: Some("americas".into()),
1164                    country: Some("US".into()),
1165                    state: Some("NY".into()),
1166                    city: Some("NY".into()),
1167                    carrier: Some("AT&T".into()),
1168                    asn: Some(Asn::from_static(7018)),
1169                },
1170                Proxy {
1171                    id: NonEmptyString::from_static("2"),
1172                    address: ProxyAddress::from_str("example.com").unwrap(),
1173                    tcp: true,
1174                    udp: true,
1175                    http: true,
1176                    https: true,
1177                    socks5: true,
1178                    socks5h: true,
1179                    datacenter: true,
1180                    residential: true,
1181                    mobile: true,
1182                    pool_id: Some("b".into()),
1183                    continent: Some("americas".into()),
1184                    country: Some("US".into()),
1185                    state: Some("NY".into()),
1186                    city: Some("NY".into()),
1187                    carrier: Some("AT&T".into()),
1188                    asn: Some(Asn::from_static(7018)),
1189                },
1190                Proxy {
1191                    id: NonEmptyString::from_static("3"),
1192                    address: ProxyAddress::from_str("example.com").unwrap(),
1193                    tcp: true,
1194                    udp: true,
1195                    http: true,
1196                    https: true,
1197                    socks5: true,
1198                    socks5h: true,
1199                    datacenter: true,
1200                    residential: true,
1201                    mobile: true,
1202                    pool_id: Some("b".into()),
1203                    continent: Some("americas".into()),
1204                    country: Some("US".into()),
1205                    state: Some("NY".into()),
1206                    city: Some("NY".into()),
1207                    carrier: Some("AT&T".into()),
1208                    asn: Some(Asn::from_static(7018)),
1209                },
1210                Proxy {
1211                    id: NonEmptyString::from_static("4"),
1212                    address: ProxyAddress::from_str("example.com").unwrap(),
1213                    tcp: true,
1214                    udp: true,
1215                    http: true,
1216                    https: true,
1217                    socks5: true,
1218                    socks5h: true,
1219                    datacenter: true,
1220                    residential: true,
1221                    mobile: true,
1222                    pool_id: Some("c".into()),
1223                    continent: Some("americas".into()),
1224                    country: Some("US".into()),
1225                    state: Some("NY".into()),
1226                    city: Some("NY".into()),
1227                    carrier: Some("AT&T".into()),
1228                    asn: Some(Asn::from_static(7018)),
1229                },
1230            ])
1231            .unwrap();
1232
1233            let ctx = h2_proxy_context();
1234
1235            let filter = ProxyFilter {
1236                pool_id: Some(vec![StringFilter::new("a"), StringFilter::new("c")]),
1237                ..Default::default()
1238            };
1239
1240            let mut seen_1 = false;
1241            let mut seen_4 = false;
1242            for _ in 0..100 {
1243                let proxy = db.get_proxy(ctx.clone(), filter.clone()).await.unwrap();
1244                match proxy.id.as_str() {
1245                    "1" => seen_1 = true,
1246                    "4" => seen_4 = true,
1247                    _ => panic!("unexpected pool id"),
1248                }
1249            }
1250            assert!(seen_1);
1251            assert!(seen_4);
1252        }
1253
1254        #[tokio::test]
1255        async fn test_deserialize_url_proxy_filter() {
1256            for (input, expected_output) in [
1257                (
1258                    "id=1",
1259                    ProxyFilter {
1260                        id: Some(NonEmptyString::from_static("1")),
1261                        ..Default::default()
1262                    },
1263                ),
1264                (
1265                    "pool=hq&country=us",
1266                    ProxyFilter {
1267                        pool_id: Some(vec![StringFilter::new("hq")]),
1268                        country: Some(vec![StringFilter::new("us")]),
1269                        ..Default::default()
1270                    },
1271                ),
1272                (
1273                    "pool=hq&country=us&country=be",
1274                    ProxyFilter {
1275                        pool_id: Some(vec![StringFilter::new("hq")]),
1276                        country: Some(vec![StringFilter::new("us"), StringFilter::new("be")]),
1277                        ..Default::default()
1278                    },
1279                ),
1280                (
1281                    "pool=a&country=uk&pool=b",
1282                    ProxyFilter {
1283                        pool_id: Some(vec![StringFilter::new("a"), StringFilter::new("b")]),
1284                        country: Some(vec![StringFilter::new("uk")]),
1285                        ..Default::default()
1286                    },
1287                ),
1288                (
1289                    "continent=europe&continent=asia",
1290                    ProxyFilter {
1291                        continent: Some(vec![
1292                            StringFilter::new("europe"),
1293                            StringFilter::new("asia"),
1294                        ]),
1295                        ..Default::default()
1296                    },
1297                ),
1298                (
1299                    "continent=americas&country=us&state=NY&city=buffalo&carrier=AT%26T&asn=7018",
1300                    ProxyFilter {
1301                        continent: Some(vec![StringFilter::new("americas")]),
1302                        country: Some(vec![StringFilter::new("us")]),
1303                        state: Some(vec![StringFilter::new("ny")]),
1304                        city: Some(vec![StringFilter::new("buffalo")]),
1305                        carrier: Some(vec![StringFilter::new("at&t")]),
1306                        asn: Some(vec![Asn::from_static(7018)]),
1307                        ..Default::default()
1308                    },
1309                ),
1310                (
1311                    "asn=1&asn=2",
1312                    ProxyFilter {
1313                        asn: Some(vec![Asn::from_static(1), Asn::from_static(2)]),
1314                        ..Default::default()
1315                    },
1316                ),
1317            ] {
1318                let filter: ProxyFilter = serde_html_form::from_str(input).unwrap();
1319                assert_eq!(filter, expected_output);
1320            }
1321        }
1322    }
1323}
1324
1325#[cfg(feature = "memory-db")]
1326pub use memdb::{
1327    MemoryProxyDB, MemoryProxyDBInsertError, MemoryProxyDBInsertErrorKind, MemoryProxyDBQueryError,
1328    MemoryProxyDBQueryErrorKind,
1329};