rama_proxy/proxydb/
layer.rs

1use super::{Proxy, ProxyContext, ProxyDB, ProxyFilter, ProxyQueryPredicate};
2use rama_core::{
3    Context, Layer, Service,
4    error::{BoxError, ErrorContext, ErrorExt, OpaqueError},
5};
6use rama_net::{
7    Protocol,
8    address::ProxyAddress,
9    transport::{TransportProtocol, TryRefIntoTransportContext},
10    user::{Basic, ProxyCredential},
11};
12use rama_utils::macros::define_inner_service_accessors;
13use std::fmt;
14
15/// A [`Service`] which selects a [`Proxy`] based on the given [`Context`].
16///
17/// Depending on the [`ProxyFilterMode`] the selection proxies might be optional,
18/// or use the default [`ProxyFilter`] in case none is defined.
19///
20/// A predicate can be used to provide additional filtering on the found proxies,
21/// that otherwise did match the used [`ProxyFilter`].
22///
23/// See [the crate docs](crate) for examples and more info on the usage of this service.
24///
25/// [`Proxy`]: crate::Proxy
26pub struct ProxyDBService<S, D, P, F> {
27    inner: S,
28    db: D,
29    mode: ProxyFilterMode,
30    predicate: P,
31    username_formatter: F,
32    preserve: bool,
33}
34
35#[derive(Debug, Clone, Default)]
36/// The modus operandi to decide how to deal with a missing [`ProxyFilter`] in the [`Context`]
37/// when selecting a [`Proxy`] from the [`ProxyDB`].
38///
39/// More advanced behaviour can be achieved by combining one of these modi
40/// with another (custom) layer prepending the parent.
41pub enum ProxyFilterMode {
42    #[default]
43    /// The [`ProxyFilter`] is optional, and if not present, no proxy is selected.
44    Optional,
45    /// The [`ProxyFilter`] is optional, and if not present, the default [`ProxyFilter`] is used.
46    Default,
47    /// The [`ProxyFilter`] is required, and if not present, an error is returned.
48    Required,
49    /// The [`ProxyFilter`] is optional, and if not present, the provided fallback [`ProxyFilter`] is used.
50    Fallback(ProxyFilter),
51}
52
53impl<S, D, P, F> fmt::Debug for ProxyDBService<S, D, P, F>
54where
55    S: fmt::Debug,
56    D: fmt::Debug,
57    P: fmt::Debug,
58    F: fmt::Debug,
59{
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        f.debug_struct("ProxyDBService")
62            .field("inner", &self.inner)
63            .field("db", &self.db)
64            .field("mode", &self.mode)
65            .field("predicate", &self.predicate)
66            .field("username_formatter", &self.username_formatter)
67            .field("preserve", &self.preserve)
68            .finish()
69    }
70}
71
72impl<S, D, P, F> Clone for ProxyDBService<S, D, P, F>
73where
74    S: Clone,
75    D: Clone,
76    P: Clone,
77    F: Clone,
78{
79    fn clone(&self) -> Self {
80        Self {
81            inner: self.inner.clone(),
82            db: self.db.clone(),
83            mode: self.mode.clone(),
84            predicate: self.predicate.clone(),
85            username_formatter: self.username_formatter.clone(),
86            preserve: self.preserve,
87        }
88    }
89}
90
91impl<S, D> ProxyDBService<S, D, bool, ()> {
92    /// Create a new [`ProxyDBService`] with the given inner [`Service`] and [`ProxyDB`].
93    pub const fn new(inner: S, db: D) -> Self {
94        Self {
95            inner,
96            db,
97            mode: ProxyFilterMode::Optional,
98            predicate: true,
99            username_formatter: (),
100            preserve: false,
101        }
102    }
103}
104
105impl<S, D, P, F> ProxyDBService<S, D, P, F> {
106    /// Set a [`ProxyFilterMode`] to define the behaviour surrounding
107    /// [`ProxyFilter`] usage, e.g. if a proxy filter is required to be available or not,
108    /// or what to do if it is optional and not available.
109    #[must_use]
110    pub fn filter_mode(mut self, mode: ProxyFilterMode) -> Self {
111        self.mode = mode;
112        self
113    }
114
115    /// Set a [`ProxyFilterMode`] to define the behaviour surrounding
116    /// [`ProxyFilter`] usage, e.g. if a proxy filter is required to be available or not,
117    /// or what to do if it is optional and not available.
118    pub fn set_filter_mode(&mut self, mode: ProxyFilterMode) -> &mut Self {
119        self.mode = mode;
120        self
121    }
122
123    /// Define whether or not an existing [`ProxyAddress`] (in the [`Context`])
124    /// should be overwritten or not. By default `preserve=false`,
125    /// meaning we will overwrite the proxy address in case we selected one now.
126    ///
127    /// NOTE even when `preserve=false` it might still be that there's
128    /// a [`ProxyAddress`] in case it was set by a previous layer.
129    #[must_use]
130    pub const fn preserve_proxy(mut self, preserve: bool) -> Self {
131        self.preserve = preserve;
132        self
133    }
134
135    /// Define whether or not an existing [`ProxyAddress`] (in the [`Context`])
136    /// should be overwritten or not. By default `preserve=false`,
137    /// meaning we will overwrite the proxy address in case we selected one now.
138    ///
139    /// NOTE even when `preserve=false` it might still be that there's
140    /// a [`ProxyAddress`] in case it was set by a previous layer.
141    pub fn set_preserve_proxy(&mut self, preserve: bool) -> &mut Self {
142        self.preserve = preserve;
143        self
144    }
145
146    /// Set a [`ProxyQueryPredicate`] that will be used
147    /// to possibly filter out proxies that according to the filters are correct,
148    /// but not according to the predicate.
149    pub fn select_predicate<Predicate>(self, p: Predicate) -> ProxyDBService<S, D, Predicate, F> {
150        ProxyDBService {
151            inner: self.inner,
152            db: self.db,
153            mode: self.mode,
154            predicate: p,
155            username_formatter: self.username_formatter,
156            preserve: self.preserve,
157        }
158    }
159
160    /// Set a [`UsernameFormatter`][crate::UsernameFormatter] that will be used to format
161    /// the username based on the selected [`Proxy`]. This is required
162    /// in case the proxy is a router that accepts or maybe even requires
163    /// username labels to configure proxies further down/up stream.
164    pub fn username_formatter<Formatter>(self, f: Formatter) -> ProxyDBService<S, D, P, Formatter> {
165        ProxyDBService {
166            inner: self.inner,
167            db: self.db,
168            mode: self.mode,
169            predicate: self.predicate,
170            username_formatter: f,
171            preserve: self.preserve,
172        }
173    }
174
175    define_inner_service_accessors!();
176}
177
178impl<S, D, P, F, State, Request> Service<State, Request> for ProxyDBService<S, D, P, F>
179where
180    S: Service<State, Request, Error: Into<BoxError> + Send + Sync + 'static>,
181    D: ProxyDB<Error: Into<BoxError> + Send + Sync + 'static>,
182    P: ProxyQueryPredicate,
183    F: UsernameFormatter<State>,
184    State: Clone + Send + Sync + 'static,
185    Request:
186        TryRefIntoTransportContext<State, Error: Into<BoxError> + Send + 'static> + Send + 'static,
187{
188    type Response = S::Response;
189    type Error = BoxError;
190
191    async fn serve(
192        &self,
193        mut ctx: Context<State>,
194        req: Request,
195    ) -> Result<Self::Response, Self::Error> {
196        if self.preserve && ctx.contains::<ProxyAddress>() {
197            // shortcut in case a proxy address is already set,
198            // and we wish to preserve it
199            return self.inner.serve(ctx, req).await.map_err(Into::into);
200        }
201
202        let maybe_filter = match self.mode {
203            ProxyFilterMode::Optional => ctx.get::<ProxyFilter>().cloned(),
204            ProxyFilterMode::Default => Some(ctx.get_or_insert_default::<ProxyFilter>().clone()),
205            ProxyFilterMode::Required => Some(
206                ctx.get::<ProxyFilter>()
207                    .cloned()
208                    .context("missing proxy filter")?,
209            ),
210            ProxyFilterMode::Fallback(ref filter) => {
211                Some(ctx.get_or_insert_with(|| filter.clone()).clone())
212            }
213        };
214
215        if let Some(filter) = maybe_filter {
216            let proxy_ctx: ProxyContext = (&*ctx
217                .get_or_try_insert_with_ctx(|ctx| req.try_ref_into_transport_ctx(ctx))
218                .map_err(|err| {
219                    OpaqueError::from_boxed(err.into())
220                        .context("proxydb: select proxy: get transport context")
221                })?)
222                .into();
223            let transport_protocol = proxy_ctx.protocol;
224
225            let proxy = self
226                .db
227                .get_proxy_if(proxy_ctx, filter.clone(), self.predicate.clone())
228                .await
229                .map_err(|err| {
230                    OpaqueError::from_std(ProxySelectError {
231                        inner: err.into(),
232                        filter: filter.clone(),
233                    })
234                })?;
235
236            let mut proxy_address = proxy.address.clone();
237
238            // prepare the credential with labels in username if desired
239            proxy_address.credential = proxy_address.credential.take().map(|credential| {
240                match credential {
241                    ProxyCredential::Basic(ref basic) => {
242                        match self.username_formatter.fmt_username(
243                            &ctx,
244                            &proxy,
245                            &filter,
246                            basic.username(),
247                        ) {
248                            Some(username) => ProxyCredential::Basic(Basic::new(
249                                username,
250                                basic.password().to_owned(),
251                            )),
252                            None => credential, // nothing to do
253                        }
254                    }
255                    ProxyCredential::Bearer(_) => credential, // Remark: we can support this in future too if needed
256                }
257            });
258
259            // overwrite the proxy protocol if not set yet
260            if proxy_address.protocol.is_none() {
261                proxy_address.protocol = match transport_protocol {
262                    TransportProtocol::Udp => {
263                        if proxy.socks5 {
264                            Some(Protocol::SOCKS5)
265                        } else if proxy.socks5h {
266                            Some(Protocol::SOCKS5H)
267                        } else {
268                            return Err(OpaqueError::from_display(
269                                "selected udp proxy does not have a valid protocol available (db bug?!)",
270                            )
271                            .into());
272                        }
273                    }
274                    TransportProtocol::Tcp => match proxy_address.authority.port() {
275                        80 | 8080 if proxy.http => Some(Protocol::HTTP),
276                        443 | 8443 if proxy.https => Some(Protocol::HTTPS),
277                        1080 if proxy.socks5 => Some(Protocol::SOCKS5),
278                        1080 if proxy.socks5h => Some(Protocol::SOCKS5H),
279                        _ => {
280                            // speed: Socks5 > Http > Https
281                            if proxy.socks5 {
282                                Some(Protocol::SOCKS5)
283                            } else if proxy.socks5h {
284                                Some(Protocol::SOCKS5H)
285                            } else if proxy.http {
286                                Some(Protocol::HTTP)
287                            } else if proxy.https {
288                                Some(Protocol::HTTPS)
289                            } else {
290                                return Err(OpaqueError::from_display(
291                                "selected tcp proxy does not have a valid protocol available (db bug?!)",
292                            )
293                            .into());
294                            }
295                        }
296                    },
297                };
298            }
299
300            // insert proxy address in context so it will be used
301            ctx.insert(proxy_address);
302
303            // insert the id of the selected proxy
304            ctx.insert(super::ProxyID::from(proxy.id.clone()));
305
306            // insert the entire proxy also in there, for full "Context"
307            ctx.insert(proxy);
308        }
309
310        self.inner.serve(ctx, req).await.map_err(Into::into)
311    }
312}
313
314#[derive(Debug)]
315struct ProxySelectError {
316    inner: BoxError,
317    filter: ProxyFilter,
318}
319
320impl fmt::Display for ProxySelectError {
321    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
322        write!(
323            f,
324            "proxy select error ({}) for filter: {:?}",
325            self.inner, self.filter
326        )
327    }
328}
329
330impl std::error::Error for ProxySelectError {
331    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
332        Some(self.inner.source().unwrap_or_else(|| self.inner.as_ref()))
333    }
334}
335
336/// A [`Layer`] which wraps an inner [`Service`] to select a [`Proxy`] based on the given [`Context`],
337/// and insert, if a [`Proxy`] is selected, it in the [`Context`] for further processing.
338///
339/// See [the crate docs](crate) for examples and more info on the usage of this service.
340pub struct ProxyDBLayer<D, P, F> {
341    db: D,
342    mode: ProxyFilterMode,
343    predicate: P,
344    username_formatter: F,
345    preserve: bool,
346}
347
348impl<D, P, F> fmt::Debug for ProxyDBLayer<D, P, F>
349where
350    D: fmt::Debug,
351    P: fmt::Debug,
352    F: fmt::Debug,
353{
354    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
355        f.debug_struct("ProxyDBLayer")
356            .field("db", &self.db)
357            .field("mode", &self.mode)
358            .field("predicate", &self.predicate)
359            .field("username_formatter", &self.username_formatter)
360            .field("preserve", &self.preserve)
361            .finish()
362    }
363}
364
365impl<D, P, F> Clone for ProxyDBLayer<D, P, F>
366where
367    D: Clone,
368    P: Clone,
369    F: Clone,
370{
371    fn clone(&self) -> Self {
372        Self {
373            db: self.db.clone(),
374            mode: self.mode.clone(),
375            predicate: self.predicate.clone(),
376            username_formatter: self.username_formatter.clone(),
377            preserve: self.preserve,
378        }
379    }
380}
381
382impl<D> ProxyDBLayer<D, bool, ()> {
383    /// Create a new [`ProxyDBLayer`] with the given [`ProxyDB`].
384    pub const fn new(db: D) -> Self {
385        Self {
386            db,
387            mode: ProxyFilterMode::Optional,
388            predicate: true,
389            username_formatter: (),
390            preserve: false,
391        }
392    }
393}
394
395impl<D, P, F> ProxyDBLayer<D, P, F> {
396    /// Set a [`ProxyFilterMode`] to define the behaviour surrounding
397    /// [`ProxyFilter`] usage, e.g. if a proxy filter is required to be available or not,
398    /// or what to do if it is optional and not available.
399    #[must_use]
400    pub fn filter_mode(mut self, mode: ProxyFilterMode) -> Self {
401        self.mode = mode;
402        self
403    }
404
405    /// Define whether or not an existing [`ProxyAddress`] (in the [`Context`])
406    /// should be overwritten or not. By default `preserve=false`,
407    /// meaning we will overwrite the proxy address in case we selected one now.
408    ///
409    /// NOTE even when `preserve=false` it might still be that there's
410    /// a [`ProxyAddress`] in case it was set by a previous layer.
411    #[must_use]
412    pub fn preserve_proxy(mut self, preserve: bool) -> Self {
413        self.preserve = preserve;
414        self
415    }
416
417    /// Set a [`ProxyQueryPredicate`] that will be used
418    /// to possibly filter out proxies that according to the filters are correct,
419    /// but not according to the predicate.
420    #[must_use]
421    pub fn select_predicate<Predicate>(self, p: Predicate) -> ProxyDBLayer<D, Predicate, F> {
422        ProxyDBLayer {
423            db: self.db,
424            mode: self.mode,
425            predicate: p,
426            username_formatter: self.username_formatter,
427            preserve: self.preserve,
428        }
429    }
430
431    /// Set a [`UsernameFormatter`][crate::UsernameFormatter] that will be used to format
432    /// the username based on the selected [`Proxy`]. This is required
433    /// in case the proxy is a router that accepts or maybe even requires
434    /// username labels to configure proxies further down/up stream.
435    #[must_use]
436    pub fn username_formatter<Formatter>(self, f: Formatter) -> ProxyDBLayer<D, P, Formatter> {
437        ProxyDBLayer {
438            db: self.db,
439            mode: self.mode,
440            predicate: self.predicate,
441            username_formatter: f,
442            preserve: self.preserve,
443        }
444    }
445}
446
447impl<S, D, P, F> Layer<S> for ProxyDBLayer<D, P, F>
448where
449    D: Clone,
450    P: Clone,
451    F: Clone,
452{
453    type Service = ProxyDBService<S, D, P, F>;
454
455    fn layer(&self, inner: S) -> Self::Service {
456        ProxyDBService {
457            inner,
458            db: self.db.clone(),
459            mode: self.mode.clone(),
460            predicate: self.predicate.clone(),
461            username_formatter: self.username_formatter.clone(),
462            preserve: self.preserve,
463        }
464    }
465
466    fn into_layer(self, inner: S) -> Self::Service {
467        ProxyDBService {
468            inner,
469            db: self.db,
470            mode: self.mode,
471            predicate: self.predicate,
472            username_formatter: self.username_formatter,
473            preserve: self.preserve,
474        }
475    }
476}
477
478/// Trait that is used to allow the formatting of a username,
479/// e.g. to allow proxy routers to have proxy config labels in the username.
480pub trait UsernameFormatter<S>: Send + Sync + 'static {
481    /// format the username based on the root properties of the given proxy.
482    fn fmt_username(
483        &self,
484        ctx: &Context<S>,
485        proxy: &Proxy,
486        filter: &ProxyFilter,
487        username: &str,
488    ) -> Option<String>;
489}
490
491impl<S> UsernameFormatter<S> for () {
492    fn fmt_username(
493        &self,
494        _ctx: &Context<S>,
495        _proxy: &Proxy,
496        _filter: &ProxyFilter,
497        _username: &str,
498    ) -> Option<String> {
499        None
500    }
501}
502
503impl<F, S> UsernameFormatter<S> for F
504where
505    F: Fn(&Context<S>, &Proxy, &ProxyFilter, &str) -> Option<String> + Send + Sync + 'static,
506{
507    fn fmt_username(
508        &self,
509        ctx: &Context<S>,
510        proxy: &Proxy,
511        filter: &ProxyFilter,
512        username: &str,
513    ) -> Option<String> {
514        (self)(ctx, proxy, filter, username)
515    }
516}
517
518#[cfg(test)]
519mod tests {
520    use super::*;
521    use crate::{MemoryProxyDB, Proxy, ProxyCsvRowReader, StringFilter};
522    use itertools::Itertools;
523    use rama_core::service::service_fn;
524    use rama_http_types::{Body, Request, Version};
525    use rama_net::{
526        Protocol,
527        address::{Authority, ProxyAddress},
528        asn::Asn,
529    };
530    use rama_utils::str::NonEmptyString;
531    use std::{convert::Infallible, str::FromStr, sync::Arc};
532
533    #[tokio::test]
534    async fn test_proxy_db_default_happy_path_example() {
535        let db = MemoryProxyDB::try_from_iter([
536            Proxy {
537                id: NonEmptyString::from_static("42"),
538                address: ProxyAddress::from_str("12.34.12.34:8080").unwrap(),
539                tcp: true,
540                udp: true,
541                http: true,
542                https: true,
543                socks5: true,
544                socks5h: true,
545                datacenter: false,
546                residential: true,
547                mobile: true,
548                pool_id: None,
549                continent: Some("*".into()),
550                country: Some("*".into()),
551                state: Some("*".into()),
552                city: Some("*".into()),
553                carrier: Some("*".into()),
554                asn: Some(Asn::unspecified()),
555            },
556            Proxy {
557                id: NonEmptyString::from_static("100"),
558                address: ProxyAddress::from_str("12.34.12.35:8080").unwrap(),
559                tcp: true,
560                udp: false,
561                http: true,
562                https: true,
563                socks5: false,
564                socks5h: false,
565                datacenter: true,
566                residential: false,
567                mobile: false,
568                pool_id: None,
569                continent: Some("americas".into()),
570                country: Some("US".into()),
571                state: None,
572                city: None,
573                carrier: None,
574                asn: Some(Asn::unspecified()),
575            },
576        ])
577        .unwrap();
578
579        let service = ProxyDBLayer::new(Arc::new(db))
580            .filter_mode(ProxyFilterMode::Default)
581            .into_layer(service_fn(async |ctx: Context<()>, _: Request| {
582                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
583            }));
584
585        let mut ctx = Context::default();
586        ctx.insert(ProxyFilter {
587            country: Some(vec!["BE".into()]),
588            mobile: Some(true),
589            residential: Some(true),
590            ..Default::default()
591        });
592
593        let req = Request::builder()
594            .version(Version::HTTP_3)
595            .method("GET")
596            .uri("https://example.com")
597            .body(Body::empty())
598            .unwrap();
599
600        let proxy_address = service.serve(ctx, req).await.unwrap();
601        assert_eq!(
602            proxy_address.authority,
603            Authority::try_from("12.34.12.34:8080").unwrap()
604        );
605    }
606
607    #[tokio::test]
608    async fn test_proxy_db_single_proxy_example() {
609        let proxy = Proxy {
610            id: NonEmptyString::from_static("42"),
611            address: ProxyAddress::from_str("12.34.12.34:8080").unwrap(),
612            tcp: true,
613            udp: true,
614            http: true,
615            https: true,
616            socks5: true,
617            socks5h: true,
618            datacenter: false,
619            residential: true,
620            mobile: true,
621            pool_id: None,
622            continent: Some("*".into()),
623            country: Some("*".into()),
624            state: Some("*".into()),
625            city: Some("*".into()),
626            carrier: Some("*".into()),
627            asn: Some(Asn::unspecified()),
628        };
629
630        let service = ProxyDBLayer::new(Arc::new(proxy))
631            .filter_mode(ProxyFilterMode::Default)
632            .into_layer(service_fn(async |ctx: Context<()>, _: Request| {
633                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
634            }));
635
636        let mut ctx = Context::default();
637        ctx.insert(ProxyFilter {
638            country: Some(vec!["BE".into()]),
639            mobile: Some(true),
640            residential: Some(true),
641            ..Default::default()
642        });
643
644        let req = Request::builder()
645            .version(Version::HTTP_3)
646            .method("GET")
647            .uri("https://example.com")
648            .body(Body::empty())
649            .unwrap();
650
651        let proxy_address = service.serve(ctx, req).await.unwrap();
652        assert_eq!(
653            proxy_address.authority,
654            Authority::try_from("12.34.12.34:8080").unwrap()
655        );
656    }
657
658    #[tokio::test]
659    async fn test_proxy_db_single_proxy_with_username_formatter() {
660        let proxy = Proxy {
661            id: NonEmptyString::from_static("42"),
662            address: ProxyAddress::from_str("john:secret@12.34.12.34:8080").unwrap(),
663            tcp: true,
664            udp: true,
665            http: true,
666            https: true,
667            socks5: true,
668            socks5h: true,
669            datacenter: false,
670            residential: true,
671            mobile: true,
672            pool_id: Some("routers".into()),
673            continent: Some("*".into()),
674            country: Some("*".into()),
675            state: Some("*".into()),
676            city: Some("*".into()),
677            carrier: Some("*".into()),
678            asn: Some(Asn::unspecified()),
679        };
680
681        let service = ProxyDBLayer::new(Arc::new(proxy))
682            .filter_mode(ProxyFilterMode::Default)
683            .username_formatter(
684                |_ctx: &Context<()>, proxy: &Proxy, filter: &ProxyFilter, username: &str| {
685                    if proxy
686                        .pool_id
687                        .as_ref()
688                        .map(|id| id.as_ref() == "routers")
689                        .unwrap_or_default()
690                    {
691                        use std::fmt::Write;
692
693                        let mut output = String::new();
694
695                        if let Some(countries) = filter.country.as_ref().filter(|t| !t.is_empty()) {
696                            let _ = write!(output, "country-{}", countries[0]);
697                        }
698                        if let Some(states) = filter.state.as_ref().filter(|t| !t.is_empty()) {
699                            let _ = write!(output, "state-{}", states[0]);
700                        }
701
702                        return (!output.is_empty()).then(|| format!("{username}-{output}"));
703                    }
704
705                    None
706                },
707            )
708            .into_layer(service_fn(async |ctx: Context<()>, _: Request| {
709                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
710            }));
711
712        let mut ctx = Context::default();
713        ctx.insert(ProxyFilter {
714            country: Some(vec!["BE".into()]),
715            mobile: Some(true),
716            residential: Some(true),
717            ..Default::default()
718        });
719
720        let req = Request::builder()
721            .version(Version::HTTP_3)
722            .method("GET")
723            .uri("https://example.com")
724            .body(Body::empty())
725            .unwrap();
726
727        let proxy_address = service.serve(ctx, req).await.unwrap();
728        assert_eq!(
729            "socks5://john-country-be:secret@12.34.12.34:8080",
730            proxy_address.to_string()
731        );
732    }
733
734    #[tokio::test]
735    async fn test_proxy_db_default_happy_path_example_transport_layer() {
736        let db = MemoryProxyDB::try_from_iter([
737            Proxy {
738                id: NonEmptyString::from_static("42"),
739                address: ProxyAddress::from_str("12.34.12.34:8080").unwrap(),
740                tcp: true,
741                udp: true,
742                http: true,
743                https: true,
744                socks5: true,
745                socks5h: true,
746                datacenter: false,
747                residential: true,
748                mobile: true,
749                pool_id: None,
750                continent: Some("*".into()),
751                country: Some("*".into()),
752                state: Some("*".into()),
753                city: Some("*".into()),
754                carrier: Some("*".into()),
755                asn: Some(Asn::unspecified()),
756            },
757            Proxy {
758                id: NonEmptyString::from_static("100"),
759                address: ProxyAddress::from_str("12.34.12.35:8080").unwrap(),
760                tcp: true,
761                udp: false,
762                http: true,
763                https: true,
764                socks5: false,
765                socks5h: false,
766                datacenter: true,
767                residential: false,
768                mobile: false,
769                pool_id: None,
770                continent: Some("americas".into()),
771                country: Some("US".into()),
772                state: None,
773                city: None,
774                carrier: None,
775                asn: Some(Asn::unspecified()),
776            },
777        ])
778        .unwrap();
779
780        let service = ProxyDBLayer::new(Arc::new(db))
781            .filter_mode(ProxyFilterMode::Default)
782            .into_layer(service_fn(async |ctx: Context<()>, _| {
783                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
784            }));
785
786        let mut ctx = Context::default();
787        ctx.insert(ProxyFilter {
788            country: Some(vec!["BE".into()]),
789            mobile: Some(true),
790            residential: Some(true),
791            ..Default::default()
792        });
793
794        let req = rama_tcp::client::Request::new("www.example.com:443".parse().unwrap())
795            .with_protocol(Protocol::HTTPS);
796
797        let proxy_address = service.serve(ctx, req).await.unwrap();
798        assert_eq!(
799            proxy_address.authority,
800            Authority::try_from("12.34.12.34:8080").unwrap()
801        );
802    }
803
804    const RAW_CSV_DATA: &str = include_str!("./test_proxydb_rows.csv");
805
806    async fn memproxydb() -> MemoryProxyDB {
807        let mut reader = ProxyCsvRowReader::raw(RAW_CSV_DATA);
808        let mut rows = Vec::new();
809        while let Some(proxy) = reader.next().await.unwrap() {
810            rows.push(proxy);
811        }
812        MemoryProxyDB::try_from_rows(rows).unwrap()
813    }
814
815    #[tokio::test]
816    async fn test_proxy_db_service_preserve_proxy_address() {
817        let db = memproxydb().await;
818
819        let service = ProxyDBLayer::new(Arc::new(db))
820            .preserve_proxy(true)
821            .filter_mode(ProxyFilterMode::Default)
822            .into_layer(service_fn(async |ctx: Context<()>, _: Request| {
823                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
824            }));
825
826        let mut ctx = Context::default();
827        ctx.insert(ProxyAddress::try_from("http://john:secret@1.2.3.4:1234").unwrap());
828
829        let req = Request::builder()
830            .version(Version::HTTP_11)
831            .method("GET")
832            .uri("http://example.com")
833            .body(Body::empty())
834            .unwrap();
835
836        let proxy_address = service.serve(ctx, req).await.unwrap();
837
838        assert_eq!(proxy_address.authority.to_string(), "1.2.3.4:1234");
839    }
840
841    #[tokio::test]
842    async fn test_proxy_db_service_optional() {
843        let db = memproxydb().await;
844
845        let service = ProxyDBLayer::new(Arc::new(db)).into_layer(service_fn(
846            async |ctx: Context<()>, _: Request| {
847                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().cloned())
848            },
849        ));
850
851        for (filter, expected_authority, req) in [
852            (
853                None,
854                None,
855                Request::builder()
856                    .version(Version::HTTP_11)
857                    .method("GET")
858                    .uri("http://example.com")
859                    .body(Body::empty())
860                    .unwrap(),
861            ),
862            (
863                Some(ProxyFilter {
864                    id: Some(NonEmptyString::from_static("3031533634")),
865                    ..Default::default()
866                }),
867                Some("105.150.55.60:4898"),
868                Request::builder()
869                    .version(Version::HTTP_11)
870                    .method("GET")
871                    .uri("http://example.com")
872                    .body(Body::empty())
873                    .unwrap(),
874            ),
875            (
876                Some(ProxyFilter {
877                    country: Some(vec![StringFilter::new("BE")]),
878                    mobile: Some(true),
879                    residential: Some(true),
880                    ..Default::default()
881                }),
882                Some("140.249.154.18:5800"),
883                Request::builder()
884                    .version(Version::HTTP_3)
885                    .method("GET")
886                    .uri("https://example.com")
887                    .body(Body::empty())
888                    .unwrap(),
889            ),
890        ] {
891            let mut ctx = Context::default();
892            ctx.maybe_insert(filter);
893
894            let maybe_proxy_address = service.serve(ctx, req).await.unwrap();
895
896            assert_eq!(
897                maybe_proxy_address.map(|p| p.authority),
898                expected_authority.map(|s| Authority::try_from(s).unwrap())
899            );
900        }
901    }
902
903    #[tokio::test]
904    async fn test_proxy_db_service_default() {
905        let db = memproxydb().await;
906
907        let service = ProxyDBLayer::new(Arc::new(db))
908            .filter_mode(ProxyFilterMode::Default)
909            .into_layer(service_fn(async |ctx: Context<()>, _: Request| {
910                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
911            }));
912
913        for (filter, expected_addresses, req_info) in [
914            (
915                None,
916                "0.20.204.227:8373,104.207.92.167:9387,105.150.55.60:4898,106.213.197.28:9110,113.6.21.212:4525,115.29.251.35:5712,119.146.94.132:7851,129.204.152.130:6524,134.190.189.202:5772,136.186.95.10:7095,137.220.180.169:4929,140.249.154.18:5800,145.57.31.149:6304,151.254.135.9:6961,153.206.209.221:8696,162.97.174.152:1673,169.179.161.206:6843,171.174.56.89:5744,178.189.117.217:6496,182.34.76.182:2374,184.209.230.177:1358,193.188.239.29:3541,193.26.37.125:3780,204.168.216.113:1096,208.224.120.97:7118,209.176.177.182:4311,215.49.63.89:9458,223.234.242.63:7211,230.159.143.41:7296,233.22.59.115:1653,24.155.249.112:2645,247.118.71.100:1033,249.221.15.121:7434,252.69.242.136:4791,253.138.153.41:2640,28.139.151.127:2809,4.20.243.186:9155,42.54.35.118:6846,45.59.69.12:5934,46.247.45.238:3522,54.226.47.54:7442,61.112.212.160:3842,66.142.40.209:4251,66.171.139.181:4449,69.246.162.84:8964,75.43.123.181:7719,76.128.58.167:4797,85.14.163.105:8362,92.227.104.237:6161,97.192.206.72:6067",
917                (Version::HTTP_11, "GET", "http://example.com"),
918            ),
919            (
920                Some(ProxyFilter {
921                    country: Some(vec![StringFilter::new("BE")]),
922                    mobile: Some(true),
923                    residential: Some(true),
924                    ..Default::default()
925                }),
926                "140.249.154.18:5800",
927                (Version::HTTP_3, "GET", "https://example.com"),
928            ),
929        ] {
930            let mut seen_addresses = Vec::new();
931            for _ in 0..5000 {
932                let mut ctx = Context::default();
933                ctx.maybe_insert(filter.clone());
934
935                let req = Request::builder()
936                    .version(req_info.0)
937                    .method(req_info.1)
938                    .uri(req_info.2)
939                    .body(Body::empty())
940                    .unwrap();
941
942                let proxy_address = service.serve(ctx, req).await.unwrap().authority.to_string();
943                if !seen_addresses.contains(&proxy_address) {
944                    seen_addresses.push(proxy_address);
945                }
946            }
947
948            let seen_addresses = seen_addresses.into_iter().sorted().join(",");
949            assert_eq!(seen_addresses, expected_addresses);
950        }
951    }
952
953    #[tokio::test]
954    async fn test_proxy_db_service_fallback() {
955        let db = memproxydb().await;
956
957        let service = ProxyDBLayer::new(Arc::new(db))
958            .filter_mode(ProxyFilterMode::Fallback(ProxyFilter {
959                datacenter: Some(true),
960                residential: Some(false),
961                mobile: Some(false),
962                ..Default::default()
963            }))
964            .into_layer(service_fn(async |ctx: Context<()>, _: Request| {
965                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
966            }));
967
968        for (filter, expected_addresses, req_info) in [
969            (
970                None,
971                "113.6.21.212:4525,119.146.94.132:7851,136.186.95.10:7095,137.220.180.169:4929,247.118.71.100:1033,249.221.15.121:7434,92.227.104.237:6161",
972                (Version::HTTP_11, "GET", "http://example.com"),
973            ),
974            (
975                Some(ProxyFilter {
976                    country: Some(vec![StringFilter::new("BE")]),
977                    mobile: Some(true),
978                    residential: Some(true),
979                    ..Default::default()
980                }),
981                "140.249.154.18:5800",
982                (Version::HTTP_3, "GET", "https://example.com"),
983            ),
984        ] {
985            let mut seen_addresses = Vec::new();
986            for _ in 0..5000 {
987                let mut ctx = Context::default();
988                ctx.maybe_insert(filter.clone());
989
990                let req = Request::builder()
991                    .version(req_info.0)
992                    .method(req_info.1)
993                    .uri(req_info.2)
994                    .body(Body::empty())
995                    .unwrap();
996
997                let proxy_address = service.serve(ctx, req).await.unwrap().authority.to_string();
998                if !seen_addresses.contains(&proxy_address) {
999                    seen_addresses.push(proxy_address);
1000                }
1001            }
1002
1003            let seen_addresses = seen_addresses.into_iter().sorted().join(",");
1004            assert_eq!(seen_addresses, expected_addresses);
1005        }
1006    }
1007
1008    #[tokio::test]
1009    async fn test_proxy_db_service_required() {
1010        let db = memproxydb().await;
1011
1012        let service = ProxyDBLayer::new(Arc::new(db))
1013            .filter_mode(ProxyFilterMode::Required)
1014            .into_layer(service_fn(async |ctx: Context<()>, _: Request| {
1015                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
1016            }));
1017
1018        for (filter, expected_address, req) in [
1019            (
1020                None,
1021                None,
1022                Request::builder()
1023                    .version(Version::HTTP_11)
1024                    .method("GET")
1025                    .uri("http://example.com")
1026                    .body(Body::empty())
1027                    .unwrap(),
1028            ),
1029            (
1030                Some(ProxyFilter {
1031                    country: Some(vec![StringFilter::new("BE")]),
1032                    mobile: Some(true),
1033                    residential: Some(true),
1034                    ..Default::default()
1035                }),
1036                Some("140.249.154.18:5800"),
1037                Request::builder()
1038                    .version(Version::HTTP_3)
1039                    .method("GET")
1040                    .uri("https://example.com")
1041                    .body(Body::empty())
1042                    .unwrap(),
1043            ),
1044            (
1045                Some(ProxyFilter {
1046                    id: Some(NonEmptyString::from_static("FooBar")),
1047                    ..Default::default()
1048                }),
1049                None,
1050                Request::builder()
1051                    .version(Version::HTTP_3)
1052                    .method("GET")
1053                    .uri("https://example.com")
1054                    .body(Body::empty())
1055                    .unwrap(),
1056            ),
1057            (
1058                Some(ProxyFilter {
1059                    id: Some(NonEmptyString::from_static("1316455915")),
1060                    country: Some(vec![StringFilter::new("BE")]),
1061                    mobile: Some(true),
1062                    residential: Some(true),
1063                    ..Default::default()
1064                }),
1065                None,
1066                Request::builder()
1067                    .version(Version::HTTP_3)
1068                    .method("GET")
1069                    .uri("https://example.com")
1070                    .body(Body::empty())
1071                    .unwrap(),
1072            ),
1073        ] {
1074            let mut ctx = Context::default();
1075            ctx.maybe_insert(filter.clone());
1076
1077            let proxy_address_result = service.serve(ctx, req).await;
1078            match expected_address {
1079                Some(expected_address) => {
1080                    assert_eq!(
1081                        proxy_address_result.unwrap().authority,
1082                        Authority::try_from(expected_address).unwrap()
1083                    );
1084                }
1085                None => {
1086                    assert!(proxy_address_result.is_err());
1087                }
1088            }
1089        }
1090    }
1091
1092    #[tokio::test]
1093    async fn test_proxy_db_service_required_with_predicate() {
1094        let db = memproxydb().await;
1095
1096        let service = ProxyDBLayer::new(Arc::new(db))
1097            .filter_mode(ProxyFilterMode::Required)
1098            .select_predicate(|proxy: &Proxy| proxy.mobile)
1099            .into_layer(service_fn(async |ctx: Context<()>, _: Request| {
1100                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
1101            }));
1102
1103        for (filter, expected, req) in [
1104            (
1105                None,
1106                None,
1107                Request::builder()
1108                    .version(Version::HTTP_11)
1109                    .method("GET")
1110                    .uri("http://example.com")
1111                    .body(Body::empty())
1112                    .unwrap(),
1113            ),
1114            (
1115                Some(ProxyFilter {
1116                    country: Some(vec![StringFilter::new("BE")]),
1117                    mobile: Some(true),
1118                    residential: Some(true),
1119                    ..Default::default()
1120                }),
1121                Some("140.249.154.18:5800"),
1122                Request::builder()
1123                    .version(Version::HTTP_3)
1124                    .method("GET")
1125                    .uri("https://example.com")
1126                    .body(Body::empty())
1127                    .unwrap(),
1128            ),
1129            (
1130                Some(ProxyFilter {
1131                    id: Some(NonEmptyString::from_static("FooBar")),
1132                    ..Default::default()
1133                }),
1134                None,
1135                Request::builder()
1136                    .version(Version::HTTP_3)
1137                    .method("GET")
1138                    .uri("https://example.com")
1139                    .body(Body::empty())
1140                    .unwrap(),
1141            ),
1142            (
1143                Some(ProxyFilter {
1144                    id: Some(NonEmptyString::from_static("1316455915")),
1145                    country: Some(vec![StringFilter::new("BE")]),
1146                    mobile: Some(true),
1147                    residential: Some(true),
1148                    ..Default::default()
1149                }),
1150                None,
1151                Request::builder()
1152                    .version(Version::HTTP_3)
1153                    .method("GET")
1154                    .uri("https://example.com")
1155                    .body(Body::empty())
1156                    .unwrap(),
1157            ),
1158            // match found, but due to custom predicate it won't check, given it is not mobile
1159            (
1160                Some(ProxyFilter {
1161                    id: Some(NonEmptyString::from_static("1316455915")),
1162                    ..Default::default()
1163                }),
1164                None,
1165                Request::builder()
1166                    .version(Version::HTTP_3)
1167                    .method("GET")
1168                    .uri("https://example.com")
1169                    .body(Body::empty())
1170                    .unwrap(),
1171            ),
1172        ] {
1173            let mut ctx = Context::default();
1174            ctx.maybe_insert(filter);
1175
1176            let proxy_result = service.serve(ctx, req).await;
1177            match expected {
1178                Some(expected_address) => {
1179                    assert_eq!(
1180                        proxy_result.unwrap().authority,
1181                        Authority::try_from(expected_address).unwrap()
1182                    );
1183                }
1184                None => {
1185                    assert!(proxy_result.is_err());
1186                }
1187            }
1188        }
1189    }
1190}