rama_proxy/proxydb/
layer.rs

1use super::{Proxy, ProxyContext, ProxyDB, ProxyFilter, ProxyQueryPredicate};
2use rama_core::{
3    Context, Layer, Service,
4    error::{BoxError, ErrorContext, ErrorExt, OpaqueError},
5};
6use rama_net::{
7    Protocol,
8    address::ProxyAddress,
9    transport::{TransportProtocol, TryRefIntoTransportContext},
10    user::{Basic, ProxyCredential},
11};
12use rama_utils::macros::define_inner_service_accessors;
13use std::fmt;
14
15/// A [`Service`] which selects a [`Proxy`] based on the given [`Context`].
16///
17/// Depending on the [`ProxyFilterMode`] the selection proxies might be optional,
18/// or use the default [`ProxyFilter`] in case none is defined.
19///
20/// A predicate can be used to provide additional filtering on the found proxies,
21/// that otherwise did match the used [`ProxyFilter`].
22///
23/// See [the crate docs](crate) for examples and more info on the usage of this service.
24///
25/// [`Proxy`]: crate::Proxy
26pub struct ProxyDBService<S, D, P, F> {
27    inner: S,
28    db: D,
29    mode: ProxyFilterMode,
30    predicate: P,
31    username_formatter: F,
32    preserve: bool,
33}
34
35#[derive(Debug, Clone, Default)]
36/// The modus operandi to decide how to deal with a missing [`ProxyFilter`] in the [`Context`]
37/// when selecting a [`Proxy`] from the [`ProxyDB`].
38///
39/// More advanced behaviour can be achieved by combining one of these modi
40/// with another (custom) layer prepending the parent.
41pub enum ProxyFilterMode {
42    #[default]
43    /// The [`ProxyFilter`] is optional, and if not present, no proxy is selected.
44    Optional,
45    /// The [`ProxyFilter`] is optional, and if not present, the default [`ProxyFilter`] is used.
46    Default,
47    /// The [`ProxyFilter`] is required, and if not present, an error is returned.
48    Required,
49    /// The [`ProxyFilter`] is optional, and if not present, the provided fallback [`ProxyFilter`] is used.
50    Fallback(ProxyFilter),
51}
52
53impl<S, D, P, F> fmt::Debug for ProxyDBService<S, D, P, F>
54where
55    S: fmt::Debug,
56    D: fmt::Debug,
57    P: fmt::Debug,
58    F: fmt::Debug,
59{
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        f.debug_struct("ProxyDBService")
62            .field("inner", &self.inner)
63            .field("db", &self.db)
64            .field("mode", &self.mode)
65            .field("predicate", &self.predicate)
66            .field("username_formatter", &self.username_formatter)
67            .field("preserve", &self.preserve)
68            .finish()
69    }
70}
71
72impl<S, D, P, F> Clone for ProxyDBService<S, D, P, F>
73where
74    S: Clone,
75    D: Clone,
76    P: Clone,
77    F: Clone,
78{
79    fn clone(&self) -> Self {
80        Self {
81            inner: self.inner.clone(),
82            db: self.db.clone(),
83            mode: self.mode.clone(),
84            predicate: self.predicate.clone(),
85            username_formatter: self.username_formatter.clone(),
86            preserve: self.preserve,
87        }
88    }
89}
90
91impl<S, D> ProxyDBService<S, D, bool, ()> {
92    /// Create a new [`ProxyDBService`] with the given inner [`Service`] and [`ProxyDB`].
93    pub const fn new(inner: S, db: D) -> Self {
94        Self {
95            inner,
96            db,
97            mode: ProxyFilterMode::Optional,
98            predicate: true,
99            username_formatter: (),
100            preserve: false,
101        }
102    }
103}
104
105impl<S, D, P, F> ProxyDBService<S, D, P, F> {
106    /// Set a [`ProxyFilterMode`] to define the behaviour surrounding
107    /// [`ProxyFilter`] usage, e.g. if a proxy filter is required to be available or not,
108    /// or what to do if it is optional and not available.
109    pub fn filter_mode(mut self, mode: ProxyFilterMode) -> Self {
110        self.mode = mode;
111        self
112    }
113
114    /// Set a [`ProxyFilterMode`] to define the behaviour surrounding
115    /// [`ProxyFilter`] usage, e.g. if a proxy filter is required to be available or not,
116    /// or what to do if it is optional and not available.
117    pub fn set_filter_mode(&mut self, mode: ProxyFilterMode) -> &mut Self {
118        self.mode = mode;
119        self
120    }
121
122    /// Define whether or not an existing [`ProxyAddress`] (in the [`Context`])
123    /// should be overwritten or not. By default `preserve=false`,
124    /// meaning we will overwrite the proxy address in case we selected one now.
125    ///
126    /// NOTE even when `preserve=false` it might still be that there's
127    /// a [`ProxyAddress`] in case it was set by a previous layer.
128    pub const fn preserve_proxy(mut self, preserve: bool) -> Self {
129        self.preserve = preserve;
130        self
131    }
132
133    /// Define whether or not an existing [`ProxyAddress`] (in the [`Context`])
134    /// should be overwritten or not. By default `preserve=false`,
135    /// meaning we will overwrite the proxy address in case we selected one now.
136    ///
137    /// NOTE even when `preserve=false` it might still be that there's
138    /// a [`ProxyAddress`] in case it was set by a previous layer.
139    pub fn set_preserve_proxy(&mut self, preserve: bool) -> &mut Self {
140        self.preserve = preserve;
141        self
142    }
143
144    /// Set a [`ProxyQueryPredicate`] that will be used
145    /// to possibly filter out proxies that according to the filters are correct,
146    /// but not according to the predicate.
147    pub fn select_predicate<Predicate>(self, p: Predicate) -> ProxyDBService<S, D, Predicate, F> {
148        ProxyDBService {
149            inner: self.inner,
150            db: self.db,
151            mode: self.mode,
152            predicate: p,
153            username_formatter: self.username_formatter,
154            preserve: self.preserve,
155        }
156    }
157
158    /// Set a [`UsernameFormatter`][crate::UsernameFormatter] that will be used to format
159    /// the username based on the selected [`Proxy`]. This is required
160    /// in case the proxy is a router that accepts or maybe even requires
161    /// username labels to configure proxies further down/up stream.
162    pub fn username_formatter<Formatter>(self, f: Formatter) -> ProxyDBService<S, D, P, Formatter> {
163        ProxyDBService {
164            inner: self.inner,
165            db: self.db,
166            mode: self.mode,
167            predicate: self.predicate,
168            username_formatter: f,
169            preserve: self.preserve,
170        }
171    }
172
173    define_inner_service_accessors!();
174}
175
176impl<S, D, P, F, State, Request> Service<State, Request> for ProxyDBService<S, D, P, F>
177where
178    S: Service<State, Request, Error: Into<BoxError> + Send + Sync + 'static>,
179    D: ProxyDB<Error: Into<BoxError> + Send + Sync + 'static>,
180    P: ProxyQueryPredicate,
181    F: UsernameFormatter<State>,
182    State: Clone + Send + Sync + 'static,
183    Request: TryRefIntoTransportContext<State, Error: Into<BoxError> + Send + Sync + 'static>
184        + Send
185        + 'static,
186{
187    type Response = S::Response;
188    type Error = BoxError;
189
190    async fn serve(
191        &self,
192        mut ctx: Context<State>,
193        req: Request,
194    ) -> Result<Self::Response, Self::Error> {
195        if self.preserve && ctx.contains::<ProxyAddress>() {
196            // shortcut in case a proxy address is already set,
197            // and we wish to preserve it
198            return self.inner.serve(ctx, req).await.map_err(Into::into);
199        }
200
201        let maybe_filter = match self.mode {
202            ProxyFilterMode::Optional => ctx.get::<ProxyFilter>().cloned(),
203            ProxyFilterMode::Default => Some(ctx.get_or_insert_default::<ProxyFilter>().clone()),
204            ProxyFilterMode::Required => Some(
205                ctx.get::<ProxyFilter>()
206                    .cloned()
207                    .context("missing proxy filter")?,
208            ),
209            ProxyFilterMode::Fallback(ref filter) => {
210                Some(ctx.get_or_insert_with(|| filter.clone()).clone())
211            }
212        };
213
214        if let Some(filter) = maybe_filter {
215            let proxy_ctx: ProxyContext = (&*ctx
216                .get_or_try_insert_with_ctx(|ctx| req.try_ref_into_transport_ctx(ctx))
217                .map_err(|err| {
218                    OpaqueError::from_boxed(err.into())
219                        .context("proxydb: select proxy: get transport context")
220                })?)
221                .into();
222            let transport_protocol = proxy_ctx.protocol;
223
224            let proxy = self
225                .db
226                .get_proxy_if(proxy_ctx, filter.clone(), self.predicate.clone())
227                .await
228                .map_err(|err| {
229                    OpaqueError::from_std(ProxySelectError {
230                        inner: err.into(),
231                        filter: filter.clone(),
232                    })
233                })?;
234
235            let mut proxy_address = proxy.address.clone();
236
237            // prepare the credential with labels in username if desired
238            proxy_address.credential = proxy_address.credential.take().map(|credential| {
239                match credential {
240                    ProxyCredential::Basic(ref basic) => {
241                        match self.username_formatter.fmt_username(
242                            &ctx,
243                            &proxy,
244                            &filter,
245                            basic.username(),
246                        ) {
247                            Some(username) => ProxyCredential::Basic(Basic::new(
248                                username,
249                                basic.password().to_owned(),
250                            )),
251                            None => credential, // nothing to do
252                        }
253                    }
254                    ProxyCredential::Bearer(_) => credential, // Remark: we can support this in future too if needed
255                }
256            });
257
258            // overwrite the proxy protocol if not set yet
259            if proxy_address.protocol.is_none() {
260                proxy_address.protocol = match transport_protocol {
261                    TransportProtocol::Udp => {
262                        if proxy.socks5 {
263                            Some(Protocol::SOCKS5)
264                        } else if proxy.socks5h {
265                            Some(Protocol::SOCKS5H)
266                        } else {
267                            return Err(OpaqueError::from_display(
268                                "selected udp proxy does not have a valid protocol available (db bug?!)",
269                            )
270                            .into());
271                        }
272                    }
273                    TransportProtocol::Tcp => match proxy_address.authority.port() {
274                        80 | 8080 if proxy.http => Some(Protocol::HTTP),
275                        443 | 8443 if proxy.https => Some(Protocol::HTTPS),
276                        1080 if proxy.socks5 => Some(Protocol::SOCKS5),
277                        1080 if proxy.socks5h => Some(Protocol::SOCKS5H),
278                        _ => {
279                            // speed: Socks5 > Http > Https
280                            if proxy.socks5 {
281                                Some(Protocol::SOCKS5)
282                            } else if proxy.socks5h {
283                                Some(Protocol::SOCKS5H)
284                            } else if proxy.http {
285                                Some(Protocol::HTTP)
286                            } else if proxy.https {
287                                Some(Protocol::HTTPS)
288                            } else {
289                                return Err(OpaqueError::from_display(
290                                "selected tcp proxy does not have a valid protocol available (db bug?!)",
291                            )
292                            .into());
293                            }
294                        }
295                    },
296                };
297            }
298
299            // insert proxy address in context so it will be used
300            ctx.insert(proxy_address);
301
302            // insert the id of the selected proxy
303            ctx.insert(super::ProxyID::from(proxy.id.clone()));
304
305            // insert the entire proxy also in there, for full "Context"
306            ctx.insert(proxy);
307        }
308
309        self.inner.serve(ctx, req).await.map_err(Into::into)
310    }
311}
312
313#[derive(Debug)]
314struct ProxySelectError {
315    inner: BoxError,
316    filter: ProxyFilter,
317}
318
319impl fmt::Display for ProxySelectError {
320    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
321        write!(
322            f,
323            "proxy select error ({}) for filter: {:?}",
324            self.inner, self.filter
325        )
326    }
327}
328
329impl std::error::Error for ProxySelectError {
330    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
331        Some(self.inner.source().unwrap_or_else(|| self.inner.as_ref()))
332    }
333}
334
335/// A [`Layer`] which wraps an inner [`Service`] to select a [`Proxy`] based on the given [`Context`],
336/// and insert, if a [`Proxy`] is selected, it in the [`Context`] for further processing.
337///
338/// See [the crate docs](crate) for examples and more info on the usage of this service.
339pub struct ProxyDBLayer<D, P, F> {
340    db: D,
341    mode: ProxyFilterMode,
342    predicate: P,
343    username_formatter: F,
344    preserve: bool,
345}
346
347impl<D, P, F> fmt::Debug for ProxyDBLayer<D, P, F>
348where
349    D: fmt::Debug,
350    P: fmt::Debug,
351    F: fmt::Debug,
352{
353    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
354        f.debug_struct("ProxyDBLayer")
355            .field("db", &self.db)
356            .field("mode", &self.mode)
357            .field("predicate", &self.predicate)
358            .field("username_formatter", &self.username_formatter)
359            .field("preserve", &self.preserve)
360            .finish()
361    }
362}
363
364impl<D, P, F> Clone for ProxyDBLayer<D, P, F>
365where
366    D: Clone,
367    P: Clone,
368    F: Clone,
369{
370    fn clone(&self) -> Self {
371        Self {
372            db: self.db.clone(),
373            mode: self.mode.clone(),
374            predicate: self.predicate.clone(),
375            username_formatter: self.username_formatter.clone(),
376            preserve: self.preserve,
377        }
378    }
379}
380
381impl<D> ProxyDBLayer<D, bool, ()> {
382    /// Create a new [`ProxyDBLayer`] with the given [`ProxyDB`].
383    pub const fn new(db: D) -> Self {
384        Self {
385            db,
386            mode: ProxyFilterMode::Optional,
387            predicate: true,
388            username_formatter: (),
389            preserve: false,
390        }
391    }
392}
393
394impl<D, P, F> ProxyDBLayer<D, P, F> {
395    /// Set a [`ProxyFilterMode`] to define the behaviour surrounding
396    /// [`ProxyFilter`] usage, e.g. if a proxy filter is required to be available or not,
397    /// or what to do if it is optional and not available.
398    pub fn filter_mode(mut self, mode: ProxyFilterMode) -> Self {
399        self.mode = mode;
400        self
401    }
402
403    /// Define whether or not an existing [`ProxyAddress`] (in the [`Context`])
404    /// should be overwritten or not. By default `preserve=false`,
405    /// meaning we will overwrite the proxy address in case we selected one now.
406    ///
407    /// NOTE even when `preserve=false` it might still be that there's
408    /// a [`ProxyAddress`] in case it was set by a previous layer.
409    pub fn preserve_proxy(mut self, preserve: bool) -> Self {
410        self.preserve = preserve;
411        self
412    }
413
414    /// Set a [`ProxyQueryPredicate`] that will be used
415    /// to possibly filter out proxies that according to the filters are correct,
416    /// but not according to the predicate.
417    pub fn select_predicate<Predicate>(self, p: Predicate) -> ProxyDBLayer<D, Predicate, F> {
418        ProxyDBLayer {
419            db: self.db,
420            mode: self.mode,
421            predicate: p,
422            username_formatter: self.username_formatter,
423            preserve: self.preserve,
424        }
425    }
426
427    /// Set a [`UsernameFormatter`][crate::UsernameFormatter] that will be used to format
428    /// the username based on the selected [`Proxy`]. This is required
429    /// in case the proxy is a router that accepts or maybe even requires
430    /// username labels to configure proxies further down/up stream.
431    pub fn username_formatter<Formatter>(self, f: Formatter) -> ProxyDBLayer<D, P, Formatter> {
432        ProxyDBLayer {
433            db: self.db,
434            mode: self.mode,
435            predicate: self.predicate,
436            username_formatter: f,
437            preserve: self.preserve,
438        }
439    }
440}
441
442impl<S, D, P, F> Layer<S> for ProxyDBLayer<D, P, F>
443where
444    D: Clone,
445    P: Clone,
446    F: Clone,
447{
448    type Service = ProxyDBService<S, D, P, F>;
449
450    fn layer(&self, inner: S) -> Self::Service {
451        ProxyDBService {
452            inner,
453            db: self.db.clone(),
454            mode: self.mode.clone(),
455            predicate: self.predicate.clone(),
456            username_formatter: self.username_formatter.clone(),
457            preserve: self.preserve,
458        }
459    }
460
461    fn into_layer(self, inner: S) -> Self::Service {
462        ProxyDBService {
463            inner,
464            db: self.db,
465            mode: self.mode,
466            predicate: self.predicate,
467            username_formatter: self.username_formatter,
468            preserve: self.preserve,
469        }
470    }
471}
472
473/// Trait that is used to allow the formatting of a username,
474/// e.g. to allow proxy routers to have proxy config labels in the username.
475pub trait UsernameFormatter<S>: Send + Sync + 'static {
476    /// format the username based on the root properties of the given proxy.
477    fn fmt_username(
478        &self,
479        ctx: &Context<S>,
480        proxy: &Proxy,
481        filter: &ProxyFilter,
482        username: &str,
483    ) -> Option<String>;
484}
485
486impl<S> UsernameFormatter<S> for () {
487    fn fmt_username(
488        &self,
489        _ctx: &Context<S>,
490        _proxy: &Proxy,
491        _filter: &ProxyFilter,
492        _username: &str,
493    ) -> Option<String> {
494        None
495    }
496}
497
498impl<F, S> UsernameFormatter<S> for F
499where
500    F: Fn(&Context<S>, &Proxy, &ProxyFilter, &str) -> Option<String> + Send + Sync + 'static,
501{
502    fn fmt_username(
503        &self,
504        ctx: &Context<S>,
505        proxy: &Proxy,
506        filter: &ProxyFilter,
507        username: &str,
508    ) -> Option<String> {
509        (self)(ctx, proxy, filter, username)
510    }
511}
512
513#[cfg(test)]
514mod tests {
515    use super::*;
516    use crate::{MemoryProxyDB, Proxy, ProxyCsvRowReader, StringFilter};
517    use itertools::Itertools;
518    use rama_core::service::service_fn;
519    use rama_http_types::{Body, Request, Version};
520    use rama_net::{
521        Protocol,
522        address::{Authority, ProxyAddress},
523        asn::Asn,
524    };
525    use rama_utils::str::NonEmptyString;
526    use std::{convert::Infallible, str::FromStr, sync::Arc};
527
528    #[tokio::test]
529    async fn test_proxy_db_default_happy_path_example() {
530        let db = MemoryProxyDB::try_from_iter([
531            Proxy {
532                id: NonEmptyString::from_static("42"),
533                address: ProxyAddress::from_str("12.34.12.34:8080").unwrap(),
534                tcp: true,
535                udp: true,
536                http: true,
537                https: true,
538                socks5: true,
539                socks5h: true,
540                datacenter: false,
541                residential: true,
542                mobile: true,
543                pool_id: None,
544                continent: Some("*".into()),
545                country: Some("*".into()),
546                state: Some("*".into()),
547                city: Some("*".into()),
548                carrier: Some("*".into()),
549                asn: Some(Asn::unspecified()),
550            },
551            Proxy {
552                id: NonEmptyString::from_static("100"),
553                address: ProxyAddress::from_str("12.34.12.35:8080").unwrap(),
554                tcp: true,
555                udp: false,
556                http: true,
557                https: true,
558                socks5: false,
559                socks5h: false,
560                datacenter: true,
561                residential: false,
562                mobile: false,
563                pool_id: None,
564                continent: Some("americas".into()),
565                country: Some("US".into()),
566                state: None,
567                city: None,
568                carrier: None,
569                asn: Some(Asn::unspecified()),
570            },
571        ])
572        .unwrap();
573
574        let service = ProxyDBLayer::new(Arc::new(db))
575            .filter_mode(ProxyFilterMode::Default)
576            .into_layer(service_fn(async |ctx: Context<()>, _: Request| {
577                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
578            }));
579
580        let mut ctx = Context::default();
581        ctx.insert(ProxyFilter {
582            country: Some(vec!["BE".into()]),
583            mobile: Some(true),
584            residential: Some(true),
585            ..Default::default()
586        });
587
588        let req = Request::builder()
589            .version(Version::HTTP_3)
590            .method("GET")
591            .uri("https://example.com")
592            .body(Body::empty())
593            .unwrap();
594
595        let proxy_address = service.serve(ctx, req).await.unwrap();
596        assert_eq!(
597            proxy_address.authority,
598            Authority::try_from("12.34.12.34:8080").unwrap()
599        );
600    }
601
602    #[tokio::test]
603    async fn test_proxy_db_single_proxy_example() {
604        let proxy = Proxy {
605            id: NonEmptyString::from_static("42"),
606            address: ProxyAddress::from_str("12.34.12.34:8080").unwrap(),
607            tcp: true,
608            udp: true,
609            http: true,
610            https: true,
611            socks5: true,
612            socks5h: true,
613            datacenter: false,
614            residential: true,
615            mobile: true,
616            pool_id: None,
617            continent: Some("*".into()),
618            country: Some("*".into()),
619            state: Some("*".into()),
620            city: Some("*".into()),
621            carrier: Some("*".into()),
622            asn: Some(Asn::unspecified()),
623        };
624
625        let service = ProxyDBLayer::new(Arc::new(proxy))
626            .filter_mode(ProxyFilterMode::Default)
627            .into_layer(service_fn(async |ctx: Context<()>, _: Request| {
628                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
629            }));
630
631        let mut ctx = Context::default();
632        ctx.insert(ProxyFilter {
633            country: Some(vec!["BE".into()]),
634            mobile: Some(true),
635            residential: Some(true),
636            ..Default::default()
637        });
638
639        let req = Request::builder()
640            .version(Version::HTTP_3)
641            .method("GET")
642            .uri("https://example.com")
643            .body(Body::empty())
644            .unwrap();
645
646        let proxy_address = service.serve(ctx, req).await.unwrap();
647        assert_eq!(
648            proxy_address.authority,
649            Authority::try_from("12.34.12.34:8080").unwrap()
650        );
651    }
652
653    #[tokio::test]
654    async fn test_proxy_db_single_proxy_with_username_formatter() {
655        let proxy = Proxy {
656            id: NonEmptyString::from_static("42"),
657            address: ProxyAddress::from_str("john:secret@12.34.12.34:8080").unwrap(),
658            tcp: true,
659            udp: true,
660            http: true,
661            https: true,
662            socks5: true,
663            socks5h: true,
664            datacenter: false,
665            residential: true,
666            mobile: true,
667            pool_id: Some("routers".into()),
668            continent: Some("*".into()),
669            country: Some("*".into()),
670            state: Some("*".into()),
671            city: Some("*".into()),
672            carrier: Some("*".into()),
673            asn: Some(Asn::unspecified()),
674        };
675
676        let service = ProxyDBLayer::new(Arc::new(proxy))
677            .filter_mode(ProxyFilterMode::Default)
678            .username_formatter(
679                |_ctx: &Context<()>, proxy: &Proxy, filter: &ProxyFilter, username: &str| {
680                    if proxy
681                        .pool_id
682                        .as_ref()
683                        .map(|id| id.as_ref() == "routers")
684                        .unwrap_or_default()
685                    {
686                        use std::fmt::Write;
687
688                        let mut output = String::new();
689
690                        if let Some(countries) = filter.country.as_ref().filter(|t| !t.is_empty()) {
691                            let _ = write!(output, "country-{}", countries[0]);
692                        }
693                        if let Some(states) = filter.state.as_ref().filter(|t| !t.is_empty()) {
694                            let _ = write!(output, "state-{}", states[0]);
695                        }
696
697                        return (!output.is_empty()).then(|| format!("{username}-{output}"));
698                    }
699
700                    None
701                },
702            )
703            .into_layer(service_fn(async |ctx: Context<()>, _: Request| {
704                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
705            }));
706
707        let mut ctx = Context::default();
708        ctx.insert(ProxyFilter {
709            country: Some(vec!["BE".into()]),
710            mobile: Some(true),
711            residential: Some(true),
712            ..Default::default()
713        });
714
715        let req = Request::builder()
716            .version(Version::HTTP_3)
717            .method("GET")
718            .uri("https://example.com")
719            .body(Body::empty())
720            .unwrap();
721
722        let proxy_address = service.serve(ctx, req).await.unwrap();
723        assert_eq!(
724            "socks5://john-country-be:secret@12.34.12.34:8080",
725            proxy_address.to_string()
726        );
727    }
728
729    #[tokio::test]
730    async fn test_proxy_db_default_happy_path_example_transport_layer() {
731        let db = MemoryProxyDB::try_from_iter([
732            Proxy {
733                id: NonEmptyString::from_static("42"),
734                address: ProxyAddress::from_str("12.34.12.34:8080").unwrap(),
735                tcp: true,
736                udp: true,
737                http: true,
738                https: true,
739                socks5: true,
740                socks5h: true,
741                datacenter: false,
742                residential: true,
743                mobile: true,
744                pool_id: None,
745                continent: Some("*".into()),
746                country: Some("*".into()),
747                state: Some("*".into()),
748                city: Some("*".into()),
749                carrier: Some("*".into()),
750                asn: Some(Asn::unspecified()),
751            },
752            Proxy {
753                id: NonEmptyString::from_static("100"),
754                address: ProxyAddress::from_str("12.34.12.35:8080").unwrap(),
755                tcp: true,
756                udp: false,
757                http: true,
758                https: true,
759                socks5: false,
760                socks5h: false,
761                datacenter: true,
762                residential: false,
763                mobile: false,
764                pool_id: None,
765                continent: Some("americas".into()),
766                country: Some("US".into()),
767                state: None,
768                city: None,
769                carrier: None,
770                asn: Some(Asn::unspecified()),
771            },
772        ])
773        .unwrap();
774
775        let service = ProxyDBLayer::new(Arc::new(db))
776            .filter_mode(ProxyFilterMode::Default)
777            .into_layer(service_fn(async |ctx: Context<()>, _| {
778                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
779            }));
780
781        let mut ctx = Context::default();
782        ctx.insert(ProxyFilter {
783            country: Some(vec!["BE".into()]),
784            mobile: Some(true),
785            residential: Some(true),
786            ..Default::default()
787        });
788
789        let req = rama_tcp::client::Request::new("www.example.com:443".parse().unwrap())
790            .with_protocol(Protocol::HTTPS);
791
792        let proxy_address = service.serve(ctx, req).await.unwrap();
793        assert_eq!(
794            proxy_address.authority,
795            Authority::try_from("12.34.12.34:8080").unwrap()
796        );
797    }
798
799    const RAW_CSV_DATA: &str = include_str!("./test_proxydb_rows.csv");
800
801    async fn memproxydb() -> MemoryProxyDB {
802        let mut reader = ProxyCsvRowReader::raw(RAW_CSV_DATA);
803        let mut rows = Vec::new();
804        while let Some(proxy) = reader.next().await.unwrap() {
805            rows.push(proxy);
806        }
807        MemoryProxyDB::try_from_rows(rows).unwrap()
808    }
809
810    #[tokio::test]
811    async fn test_proxy_db_service_preserve_proxy_address() {
812        let db = memproxydb().await;
813
814        let service = ProxyDBLayer::new(Arc::new(db))
815            .preserve_proxy(true)
816            .filter_mode(ProxyFilterMode::Default)
817            .into_layer(service_fn(async |ctx: Context<()>, _: Request| {
818                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
819            }));
820
821        let mut ctx = Context::default();
822        ctx.insert(ProxyAddress::try_from("http://john:secret@1.2.3.4:1234").unwrap());
823
824        let req = Request::builder()
825            .version(Version::HTTP_11)
826            .method("GET")
827            .uri("http://example.com")
828            .body(Body::empty())
829            .unwrap();
830
831        let proxy_address = service.serve(ctx, req).await.unwrap();
832
833        assert_eq!(proxy_address.authority.to_string(), "1.2.3.4:1234");
834    }
835
836    #[tokio::test]
837    async fn test_proxy_db_service_optional() {
838        let db = memproxydb().await;
839
840        let service = ProxyDBLayer::new(Arc::new(db)).into_layer(service_fn(
841            async |ctx: Context<()>, _: Request| {
842                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().cloned())
843            },
844        ));
845
846        for (filter, expected_authority, req) in [
847            (
848                None,
849                None,
850                Request::builder()
851                    .version(Version::HTTP_11)
852                    .method("GET")
853                    .uri("http://example.com")
854                    .body(Body::empty())
855                    .unwrap(),
856            ),
857            (
858                Some(ProxyFilter {
859                    id: Some(NonEmptyString::from_static("3031533634")),
860                    ..Default::default()
861                }),
862                Some("105.150.55.60:4898"),
863                Request::builder()
864                    .version(Version::HTTP_11)
865                    .method("GET")
866                    .uri("http://example.com")
867                    .body(Body::empty())
868                    .unwrap(),
869            ),
870            (
871                Some(ProxyFilter {
872                    country: Some(vec![StringFilter::new("BE")]),
873                    mobile: Some(true),
874                    residential: Some(true),
875                    ..Default::default()
876                }),
877                Some("140.249.154.18:5800"),
878                Request::builder()
879                    .version(Version::HTTP_3)
880                    .method("GET")
881                    .uri("https://example.com")
882                    .body(Body::empty())
883                    .unwrap(),
884            ),
885        ] {
886            let mut ctx = Context::default();
887            ctx.maybe_insert(filter);
888
889            let maybe_proxy_address = service.serve(ctx, req).await.unwrap();
890
891            assert_eq!(
892                maybe_proxy_address.map(|p| p.authority),
893                expected_authority.map(|s| Authority::try_from(s).unwrap())
894            );
895        }
896    }
897
898    #[tokio::test]
899    async fn test_proxy_db_service_default() {
900        let db = memproxydb().await;
901
902        let service = ProxyDBLayer::new(Arc::new(db))
903            .filter_mode(ProxyFilterMode::Default)
904            .into_layer(service_fn(async |ctx: Context<()>, _: Request| {
905                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
906            }));
907
908        for (filter, expected_addresses, req_info) in [
909            (
910                None,
911                "0.20.204.227:8373,104.207.92.167:9387,105.150.55.60:4898,106.213.197.28:9110,113.6.21.212:4525,115.29.251.35:5712,119.146.94.132:7851,129.204.152.130:6524,134.190.189.202:5772,136.186.95.10:7095,137.220.180.169:4929,140.249.154.18:5800,145.57.31.149:6304,151.254.135.9:6961,153.206.209.221:8696,162.97.174.152:1673,169.179.161.206:6843,171.174.56.89:5744,178.189.117.217:6496,182.34.76.182:2374,184.209.230.177:1358,193.188.239.29:3541,193.26.37.125:3780,204.168.216.113:1096,208.224.120.97:7118,209.176.177.182:4311,215.49.63.89:9458,223.234.242.63:7211,230.159.143.41:7296,233.22.59.115:1653,24.155.249.112:2645,247.118.71.100:1033,249.221.15.121:7434,252.69.242.136:4791,253.138.153.41:2640,28.139.151.127:2809,4.20.243.186:9155,42.54.35.118:6846,45.59.69.12:5934,46.247.45.238:3522,54.226.47.54:7442,61.112.212.160:3842,66.142.40.209:4251,66.171.139.181:4449,69.246.162.84:8964,75.43.123.181:7719,76.128.58.167:4797,85.14.163.105:8362,92.227.104.237:6161,97.192.206.72:6067",
912                (Version::HTTP_11, "GET", "http://example.com"),
913            ),
914            (
915                Some(ProxyFilter {
916                    country: Some(vec![StringFilter::new("BE")]),
917                    mobile: Some(true),
918                    residential: Some(true),
919                    ..Default::default()
920                }),
921                "140.249.154.18:5800",
922                (Version::HTTP_3, "GET", "https://example.com"),
923            ),
924        ] {
925            let mut seen_addresses = Vec::new();
926            for _ in 0..5000 {
927                let mut ctx = Context::default();
928                ctx.maybe_insert(filter.clone());
929
930                let req = Request::builder()
931                    .version(req_info.0)
932                    .method(req_info.1)
933                    .uri(req_info.2)
934                    .body(Body::empty())
935                    .unwrap();
936
937                let proxy_address = service.serve(ctx, req).await.unwrap().authority.to_string();
938                if !seen_addresses.contains(&proxy_address) {
939                    seen_addresses.push(proxy_address);
940                }
941            }
942
943            let seen_addresses = seen_addresses.into_iter().sorted().join(",");
944            assert_eq!(seen_addresses, expected_addresses);
945        }
946    }
947
948    #[tokio::test]
949    async fn test_proxy_db_service_fallback() {
950        let db = memproxydb().await;
951
952        let service = ProxyDBLayer::new(Arc::new(db))
953            .filter_mode(ProxyFilterMode::Fallback(ProxyFilter {
954                datacenter: Some(true),
955                residential: Some(false),
956                mobile: Some(false),
957                ..Default::default()
958            }))
959            .into_layer(service_fn(async |ctx: Context<()>, _: Request| {
960                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
961            }));
962
963        for (filter, expected_addresses, req_info) in [
964            (
965                None,
966                "113.6.21.212:4525,119.146.94.132:7851,136.186.95.10:7095,137.220.180.169:4929,247.118.71.100:1033,249.221.15.121:7434,92.227.104.237:6161",
967                (Version::HTTP_11, "GET", "http://example.com"),
968            ),
969            (
970                Some(ProxyFilter {
971                    country: Some(vec![StringFilter::new("BE")]),
972                    mobile: Some(true),
973                    residential: Some(true),
974                    ..Default::default()
975                }),
976                "140.249.154.18:5800",
977                (Version::HTTP_3, "GET", "https://example.com"),
978            ),
979        ] {
980            let mut seen_addresses = Vec::new();
981            for _ in 0..5000 {
982                let mut ctx = Context::default();
983                ctx.maybe_insert(filter.clone());
984
985                let req = Request::builder()
986                    .version(req_info.0)
987                    .method(req_info.1)
988                    .uri(req_info.2)
989                    .body(Body::empty())
990                    .unwrap();
991
992                let proxy_address = service.serve(ctx, req).await.unwrap().authority.to_string();
993                if !seen_addresses.contains(&proxy_address) {
994                    seen_addresses.push(proxy_address);
995                }
996            }
997
998            let seen_addresses = seen_addresses.into_iter().sorted().join(",");
999            assert_eq!(seen_addresses, expected_addresses);
1000        }
1001    }
1002
1003    #[tokio::test]
1004    async fn test_proxy_db_service_required() {
1005        let db = memproxydb().await;
1006
1007        let service = ProxyDBLayer::new(Arc::new(db))
1008            .filter_mode(ProxyFilterMode::Required)
1009            .into_layer(service_fn(async |ctx: Context<()>, _: Request| {
1010                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
1011            }));
1012
1013        for (filter, expected_address, req) in [
1014            (
1015                None,
1016                None,
1017                Request::builder()
1018                    .version(Version::HTTP_11)
1019                    .method("GET")
1020                    .uri("http://example.com")
1021                    .body(Body::empty())
1022                    .unwrap(),
1023            ),
1024            (
1025                Some(ProxyFilter {
1026                    country: Some(vec![StringFilter::new("BE")]),
1027                    mobile: Some(true),
1028                    residential: Some(true),
1029                    ..Default::default()
1030                }),
1031                Some("140.249.154.18:5800"),
1032                Request::builder()
1033                    .version(Version::HTTP_3)
1034                    .method("GET")
1035                    .uri("https://example.com")
1036                    .body(Body::empty())
1037                    .unwrap(),
1038            ),
1039            (
1040                Some(ProxyFilter {
1041                    id: Some(NonEmptyString::from_static("FooBar")),
1042                    ..Default::default()
1043                }),
1044                None,
1045                Request::builder()
1046                    .version(Version::HTTP_3)
1047                    .method("GET")
1048                    .uri("https://example.com")
1049                    .body(Body::empty())
1050                    .unwrap(),
1051            ),
1052            (
1053                Some(ProxyFilter {
1054                    id: Some(NonEmptyString::from_static("1316455915")),
1055                    country: Some(vec![StringFilter::new("BE")]),
1056                    mobile: Some(true),
1057                    residential: Some(true),
1058                    ..Default::default()
1059                }),
1060                None,
1061                Request::builder()
1062                    .version(Version::HTTP_3)
1063                    .method("GET")
1064                    .uri("https://example.com")
1065                    .body(Body::empty())
1066                    .unwrap(),
1067            ),
1068        ] {
1069            let mut ctx = Context::default();
1070            ctx.maybe_insert(filter.clone());
1071
1072            let proxy_address_result = service.serve(ctx, req).await;
1073            match expected_address {
1074                Some(expected_address) => {
1075                    assert_eq!(
1076                        proxy_address_result.unwrap().authority,
1077                        Authority::try_from(expected_address).unwrap()
1078                    );
1079                }
1080                None => {
1081                    assert!(proxy_address_result.is_err());
1082                }
1083            }
1084        }
1085    }
1086
1087    #[tokio::test]
1088    async fn test_proxy_db_service_required_with_predicate() {
1089        let db = memproxydb().await;
1090
1091        let service = ProxyDBLayer::new(Arc::new(db))
1092            .filter_mode(ProxyFilterMode::Required)
1093            .select_predicate(|proxy: &Proxy| proxy.mobile)
1094            .into_layer(service_fn(async |ctx: Context<()>, _: Request| {
1095                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
1096            }));
1097
1098        for (filter, expected, req) in [
1099            (
1100                None,
1101                None,
1102                Request::builder()
1103                    .version(Version::HTTP_11)
1104                    .method("GET")
1105                    .uri("http://example.com")
1106                    .body(Body::empty())
1107                    .unwrap(),
1108            ),
1109            (
1110                Some(ProxyFilter {
1111                    country: Some(vec![StringFilter::new("BE")]),
1112                    mobile: Some(true),
1113                    residential: Some(true),
1114                    ..Default::default()
1115                }),
1116                Some("140.249.154.18:5800"),
1117                Request::builder()
1118                    .version(Version::HTTP_3)
1119                    .method("GET")
1120                    .uri("https://example.com")
1121                    .body(Body::empty())
1122                    .unwrap(),
1123            ),
1124            (
1125                Some(ProxyFilter {
1126                    id: Some(NonEmptyString::from_static("FooBar")),
1127                    ..Default::default()
1128                }),
1129                None,
1130                Request::builder()
1131                    .version(Version::HTTP_3)
1132                    .method("GET")
1133                    .uri("https://example.com")
1134                    .body(Body::empty())
1135                    .unwrap(),
1136            ),
1137            (
1138                Some(ProxyFilter {
1139                    id: Some(NonEmptyString::from_static("1316455915")),
1140                    country: Some(vec![StringFilter::new("BE")]),
1141                    mobile: Some(true),
1142                    residential: Some(true),
1143                    ..Default::default()
1144                }),
1145                None,
1146                Request::builder()
1147                    .version(Version::HTTP_3)
1148                    .method("GET")
1149                    .uri("https://example.com")
1150                    .body(Body::empty())
1151                    .unwrap(),
1152            ),
1153            // match found, but due to custom predicate it won't check, given it is not mobile
1154            (
1155                Some(ProxyFilter {
1156                    id: Some(NonEmptyString::from_static("1316455915")),
1157                    ..Default::default()
1158                }),
1159                None,
1160                Request::builder()
1161                    .version(Version::HTTP_3)
1162                    .method("GET")
1163                    .uri("https://example.com")
1164                    .body(Body::empty())
1165                    .unwrap(),
1166            ),
1167        ] {
1168            let mut ctx = Context::default();
1169            ctx.maybe_insert(filter);
1170
1171            let proxy_result = service.serve(ctx, req).await;
1172            match expected {
1173                Some(expected_address) => {
1174                    assert_eq!(
1175                        proxy_result.unwrap().authority,
1176                        Authority::try_from(expected_address).unwrap()
1177                    );
1178                }
1179                None => {
1180                    assert!(proxy_result.is_err());
1181                }
1182            }
1183        }
1184    }
1185}