rama_proxy/proxydb/
layer.rs

1use super::{Proxy, ProxyDB, ProxyFilter, ProxyQueryPredicate};
2use rama_core::{
3    error::{BoxError, ErrorContext, ErrorExt, OpaqueError},
4    Context, Layer, Service,
5};
6use rama_net::{
7    address::ProxyAddress,
8    transport::{TransportProtocol, TryRefIntoTransportContext},
9    user::{Basic, ProxyCredential},
10    Protocol,
11};
12use rama_utils::macros::define_inner_service_accessors;
13use std::fmt;
14
/// A [`Service`] which selects a [`Proxy`] based on the given [`Context`].
///
/// Depending on the [`ProxyFilterMode`], selecting a proxy might be optional,
/// or fall back to the default [`ProxyFilter`] in case none is defined.
///
/// A predicate can be used to provide additional filtering on the found proxies,
/// that otherwise did match the used [`ProxyFilter`].
///
/// See [the crate docs](crate) for examples and more info on the usage of this service.
///
/// [`Proxy`]: crate::Proxy
pub struct ProxyDBService<S, D, P, F> {
    // Wrapped service, called once proxy selection (if any) has finished.
    inner: S,
    // Database that is queried for a proxy matching the active filter.
    db: D,
    // Decides what happens when the context carries no `ProxyFilter`.
    mode: ProxyFilterMode,
    // Extra predicate handed to the database together with the filter.
    predicate: P,
    // Hook that may rewrite the username of a selected proxy's basic credential.
    username_formatter: F,
    // When `true`, a `ProxyAddress` already present in the context skips selection.
    preserve: bool,
}
34
#[derive(Debug, Clone, Default)]
/// The modus operandi to decide how to deal with a missing [`ProxyFilter`] in the [`Context`]
/// when selecting a [`Proxy`] from the [`ProxyDB`].
///
/// More advanced behaviour can be achieved by combining one of these modi
/// with another (custom) layer prepending the parent.
pub enum ProxyFilterMode {
    #[default]
    /// The [`ProxyFilter`] is optional, and if not present, no proxy is selected.
    Optional,
    /// The [`ProxyFilter`] is optional, and if not present, the default [`ProxyFilter`] is used.
    ///
    /// [`ProxyDBService`] also inserts that default filter into the [`Context`].
    Default,
    /// The [`ProxyFilter`] is required, and if not present, an error is returned.
    Required,
    /// The [`ProxyFilter`] is optional, and if not present, the provided fallback [`ProxyFilter`] is used.
    ///
    /// [`ProxyDBService`] also inserts a clone of the fallback filter into the [`Context`].
    Fallback(ProxyFilter),
}
52
53impl<S, D, P, F> fmt::Debug for ProxyDBService<S, D, P, F>
54where
55    S: fmt::Debug,
56    D: fmt::Debug,
57    P: fmt::Debug,
58    F: fmt::Debug,
59{
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        f.debug_struct("ProxyDBService")
62            .field("inner", &self.inner)
63            .field("db", &self.db)
64            .field("mode", &self.mode)
65            .field("predicate", &self.predicate)
66            .field("username_formatter", &self.username_formatter)
67            .field("preserve", &self.preserve)
68            .finish()
69    }
70}
71
72impl<S, D, P, F> Clone for ProxyDBService<S, D, P, F>
73where
74    S: Clone,
75    D: Clone,
76    P: Clone,
77    F: Clone,
78{
79    fn clone(&self) -> Self {
80        Self {
81            inner: self.inner.clone(),
82            db: self.db.clone(),
83            mode: self.mode.clone(),
84            predicate: self.predicate.clone(),
85            username_formatter: self.username_formatter.clone(),
86            preserve: self.preserve,
87        }
88    }
89}
90
91impl<S, D> ProxyDBService<S, D, bool, ()> {
92    /// Create a new [`ProxyDBService`] with the given inner [`Service`] and [`ProxyDB`].
93    pub const fn new(inner: S, db: D) -> Self {
94        Self {
95            inner,
96            db,
97            mode: ProxyFilterMode::Optional,
98            predicate: true,
99            username_formatter: (),
100            preserve: false,
101        }
102    }
103}
104
105impl<S, D, P, F> ProxyDBService<S, D, P, F> {
106    /// Set a [`ProxyFilterMode`] to define the behaviour surrounding
107    /// [`ProxyFilter`] usage, e.g. if a proxy filter is required to be available or not,
108    /// or what to do if it is optional and not available.
109    pub fn filter_mode(mut self, mode: ProxyFilterMode) -> Self {
110        self.mode = mode;
111        self
112    }
113
114    /// Set a [`ProxyFilterMode`] to define the behaviour surrounding
115    /// [`ProxyFilter`] usage, e.g. if a proxy filter is required to be available or not,
116    /// or what to do if it is optional and not available.
117    pub fn set_filter_mode(&mut self, mode: ProxyFilterMode) -> &mut Self {
118        self.mode = mode;
119        self
120    }
121
122    /// Define whether or not an existing [`ProxyAddress`] (in the [`Context`])
123    /// should be overwritten or not. By default `preserve=false`,
124    /// meaning we will overwrite the proxy address in case we selected one now.
125    ///
126    /// NOTE even when `preserve=false` it might still be that there's
127    /// a [`ProxyAddress`] in case it was set by a previous layer.
128    pub const fn preserve_proxy(mut self, preserve: bool) -> Self {
129        self.preserve = preserve;
130        self
131    }
132
133    /// Define whether or not an existing [`ProxyAddress`] (in the [`Context`])
134    /// should be overwritten or not. By default `preserve=false`,
135    /// meaning we will overwrite the proxy address in case we selected one now.
136    ///
137    /// NOTE even when `preserve=false` it might still be that there's
138    /// a [`ProxyAddress`] in case it was set by a previous layer.
139    pub fn set_preserve_proxy(&mut self, preserve: bool) -> &mut Self {
140        self.preserve = preserve;
141        self
142    }
143
144    /// Set a [`ProxyQueryPredicate`] that will be used
145    /// to possibly filter out proxies that according to the filters are correct,
146    /// but not according to the predicate.
147    pub fn select_predicate<Predicate>(self, p: Predicate) -> ProxyDBService<S, D, Predicate, F> {
148        ProxyDBService {
149            inner: self.inner,
150            db: self.db,
151            mode: self.mode,
152            predicate: p,
153            username_formatter: self.username_formatter,
154            preserve: self.preserve,
155        }
156    }
157
158    /// Set a [`UsernameFormatter`][crate::UsernameFormatter] that will be used to format
159    /// the username based on the selected [`Proxy`]. This is required
160    /// in case the proxy is a router that accepts or maybe even requires
161    /// username labels to configure proxies further down/up stream.
162    pub fn username_formatter<Formatter>(self, f: Formatter) -> ProxyDBService<S, D, P, Formatter> {
163        ProxyDBService {
164            inner: self.inner,
165            db: self.db,
166            mode: self.mode,
167            predicate: self.predicate,
168            username_formatter: f,
169            preserve: self.preserve,
170        }
171    }
172
173    define_inner_service_accessors!();
174}
175
176impl<S, D, P, F, State, Request> Service<State, Request> for ProxyDBService<S, D, P, F>
177where
178    S: Service<State, Request, Error: Into<BoxError> + Send + Sync + 'static>,
179    D: ProxyDB<Error: Into<BoxError> + Send + Sync + 'static>,
180    P: ProxyQueryPredicate,
181    F: UsernameFormatter<State>,
182    State: Clone + Send + Sync + 'static,
183    Request: TryRefIntoTransportContext<State, Error: Into<BoxError> + Send + Sync + 'static>
184        + Send
185        + 'static,
186{
187    type Response = S::Response;
188    type Error = BoxError;
189
190    async fn serve(
191        &self,
192        mut ctx: Context<State>,
193        req: Request,
194    ) -> Result<Self::Response, Self::Error> {
195        if self.preserve && ctx.contains::<ProxyAddress>() {
196            // shortcut in case a proxy address is already set,
197            // and we wish to preserve it
198            return self.inner.serve(ctx, req).await.map_err(Into::into);
199        }
200
201        let maybe_filter = match self.mode {
202            ProxyFilterMode::Optional => ctx.get::<ProxyFilter>().cloned(),
203            ProxyFilterMode::Default => Some(ctx.get_or_insert_default::<ProxyFilter>().clone()),
204            ProxyFilterMode::Required => Some(
205                ctx.get::<ProxyFilter>()
206                    .cloned()
207                    .context("missing proxy filter")?,
208            ),
209            ProxyFilterMode::Fallback(ref filter) => {
210                Some(ctx.get_or_insert_with(|| filter.clone()).clone())
211            }
212        };
213
214        if let Some(filter) = maybe_filter {
215            let transport_ctx = ctx
216                .get_or_try_insert_with_ctx(|ctx| req.try_ref_into_transport_ctx(ctx))
217                .map_err(|err| {
218                    OpaqueError::from_boxed(err.into())
219                        .context("proxydb: select proxy: get transport context")
220                })?
221                .clone();
222            let transport_protocol = transport_ctx.protocol.clone();
223
224            let proxy = self
225                .db
226                .get_proxy_if(transport_ctx, filter.clone(), self.predicate.clone())
227                .await
228                .map_err(|err| {
229                    OpaqueError::from_std(ProxySelectError {
230                        inner: err.into(),
231                        filter: filter.clone(),
232                    })
233                })?;
234
235            let mut proxy_address = proxy.address.clone();
236
237            // prepare the credential with labels in username if desired
238            proxy_address.credential = proxy_address.credential.take().map(|credential| {
239                match credential {
240                    ProxyCredential::Basic(ref basic) => {
241                        match self.username_formatter.fmt_username(
242                            &ctx,
243                            &proxy,
244                            &filter,
245                            basic.username(),
246                        ) {
247                            Some(username) => ProxyCredential::Basic(Basic::new(
248                                username,
249                                basic.password().to_owned(),
250                            )),
251                            None => credential, // nothing to do
252                        }
253                    }
254                    ProxyCredential::Bearer(_) => credential, // Remark: we can support this in future too if needed
255                }
256            });
257
258            // overwrite the proxy protocol if not set yet
259            if proxy_address.protocol.is_none() {
260                proxy_address.protocol = match transport_protocol {
261                    TransportProtocol::Udp => {
262                        if proxy.socks5 {
263                            Some(Protocol::SOCKS5)
264                        } else if proxy.socks5h {
265                            Some(Protocol::SOCKS5H)
266                        } else {
267                            return Err(OpaqueError::from_display(
268                                "selected udp proxy does not have a valid protocol available (db bug?!)",
269                            )
270                            .into());
271                        }
272                    }
273                    TransportProtocol::Tcp => match proxy_address.authority.port() {
274                        80 | 8080 if proxy.http => Some(Protocol::HTTP),
275                        443 | 8443 if proxy.https => Some(Protocol::HTTPS),
276                        1080 if proxy.socks5 => Some(Protocol::SOCKS5),
277                        1080 if proxy.socks5h => Some(Protocol::SOCKS5H),
278                        _ => {
279                            // speed: Socks5 > Http > Https
280                            if proxy.socks5 {
281                                Some(Protocol::SOCKS5)
282                            } else if proxy.socks5h {
283                                Some(Protocol::SOCKS5H)
284                            } else if proxy.http {
285                                Some(Protocol::HTTP)
286                            } else if proxy.https {
287                                Some(Protocol::HTTPS)
288                            } else {
289                                return Err(OpaqueError::from_display(
290                                "selected tcp proxy does not have a valid protocol available (db bug?!)",
291                            )
292                            .into());
293                            }
294                        }
295                    },
296                };
297            }
298
299            // insert proxy address in context so it will be used
300            ctx.insert(proxy_address);
301
302            // insert the id of the selected proxy
303            ctx.insert(super::ProxyID::from(proxy.id.clone()));
304
305            // insert the entire proxy also in there, for full "Context"
306            ctx.insert(proxy);
307        }
308
309        self.inner.serve(ctx, req).await.map_err(Into::into)
310    }
311}
312
/// Error produced when the [`ProxyDB`] fails to return a proxy for a filter.
#[derive(Debug)]
struct ProxySelectError {
    // The error reported by the database lookup.
    inner: BoxError,
    // The filter that was in use for the failed lookup; shown in the message.
    filter: ProxyFilter,
}
318
319impl fmt::Display for ProxySelectError {
320    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
321        write!(
322            f,
323            "proxy select error ({}) for filter: {:?}",
324            self.inner, self.filter
325        )
326    }
327}
328
329impl std::error::Error for ProxySelectError {
330    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
331        Some(self.inner.source().unwrap_or_else(|| self.inner.as_ref()))
332    }
333}
334
/// A [`Layer`] which wraps an inner [`Service`] to select a [`Proxy`] based on the given [`Context`],
/// and, if a [`Proxy`] is selected, insert it in the [`Context`] for further processing.
///
/// See [the crate docs](crate) for examples and more info on the usage of this service.
pub struct ProxyDBLayer<D, P, F> {
    // Database handed (cloned) to every service produced by this layer.
    db: D,
    // Filter mode handed (cloned) to every produced service.
    mode: ProxyFilterMode,
    // Selection predicate handed (cloned) to every produced service.
    predicate: P,
    // Username formatter handed (cloned) to every produced service.
    username_formatter: F,
    // Preserve flag copied into every produced service.
    preserve: bool,
}
346
347impl<D, P, F> fmt::Debug for ProxyDBLayer<D, P, F>
348where
349    D: fmt::Debug,
350    P: fmt::Debug,
351    F: fmt::Debug,
352{
353    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
354        f.debug_struct("ProxyDBLayer")
355            .field("db", &self.db)
356            .field("mode", &self.mode)
357            .field("predicate", &self.predicate)
358            .field("username_formatter", &self.username_formatter)
359            .field("preserve", &self.preserve)
360            .finish()
361    }
362}
363
364impl<D, P, F> Clone for ProxyDBLayer<D, P, F>
365where
366    D: Clone,
367    P: Clone,
368    F: Clone,
369{
370    fn clone(&self) -> Self {
371        Self {
372            db: self.db.clone(),
373            mode: self.mode.clone(),
374            predicate: self.predicate.clone(),
375            username_formatter: self.username_formatter.clone(),
376            preserve: self.preserve,
377        }
378    }
379}
380
381impl<D> ProxyDBLayer<D, bool, ()> {
382    /// Create a new [`ProxyDBLayer`] with the given [`ProxyDB`].
383    pub const fn new(db: D) -> Self {
384        Self {
385            db,
386            mode: ProxyFilterMode::Optional,
387            predicate: true,
388            username_formatter: (),
389            preserve: false,
390        }
391    }
392}
393
394impl<D, P, F> ProxyDBLayer<D, P, F> {
395    /// Set a [`ProxyFilterMode`] to define the behaviour surrounding
396    /// [`ProxyFilter`] usage, e.g. if a proxy filter is required to be available or not,
397    /// or what to do if it is optional and not available.
398    pub fn filter_mode(mut self, mode: ProxyFilterMode) -> Self {
399        self.mode = mode;
400        self
401    }
402
403    /// Define whether or not an existing [`ProxyAddress`] (in the [`Context`])
404    /// should be overwritten or not. By default `preserve=false`,
405    /// meaning we will overwrite the proxy address in case we selected one now.
406    ///
407    /// NOTE even when `preserve=false` it might still be that there's
408    /// a [`ProxyAddress`] in case it was set by a previous layer.
409    pub fn preserve_proxy(mut self, preserve: bool) -> Self {
410        self.preserve = preserve;
411        self
412    }
413
414    /// Set a [`ProxyQueryPredicate`] that will be used
415    /// to possibly filter out proxies that according to the filters are correct,
416    /// but not according to the predicate.
417    pub fn select_predicate<Predicate>(self, p: Predicate) -> ProxyDBLayer<D, Predicate, F> {
418        ProxyDBLayer {
419            db: self.db,
420            mode: self.mode,
421            predicate: p,
422            username_formatter: self.username_formatter,
423            preserve: self.preserve,
424        }
425    }
426
427    /// Set a [`UsernameFormatter`][crate::UsernameFormatter] that will be used to format
428    /// the username based on the selected [`Proxy`]. This is required
429    /// in case the proxy is a router that accepts or maybe even requires
430    /// username labels to configure proxies further down/up stream.
431    pub fn username_formatter<Formatter>(self, f: Formatter) -> ProxyDBLayer<D, P, Formatter> {
432        ProxyDBLayer {
433            db: self.db,
434            mode: self.mode,
435            predicate: self.predicate,
436            username_formatter: f,
437            preserve: self.preserve,
438        }
439    }
440}
441
442impl<S, D, P, F> Layer<S> for ProxyDBLayer<D, P, F>
443where
444    D: Clone,
445    P: Clone,
446    F: Clone,
447{
448    type Service = ProxyDBService<S, D, P, F>;
449
450    fn layer(&self, inner: S) -> Self::Service {
451        ProxyDBService {
452            inner,
453            db: self.db.clone(),
454            mode: self.mode.clone(),
455            predicate: self.predicate.clone(),
456            username_formatter: self.username_formatter.clone(),
457            preserve: self.preserve,
458        }
459    }
460}
461
/// Trait that is used to allow the formatting of a username,
/// e.g. to allow proxy routers to have proxy config labels in the username.
///
/// [`ProxyDBService`] calls it for a selected proxy whose address carries a basic credential.
pub trait UsernameFormatter<S>: Send + Sync + 'static {
    /// Format the username based on the root properties of the given proxy.
    ///
    /// - `ctx`: the [`Context`] of the request being served;
    /// - `proxy`: the [`Proxy`] that was selected;
    /// - `filter`: the [`ProxyFilter`] that was used to select it;
    /// - `username`: the username of the proxy's current basic credential.
    ///
    /// Return `Some(username)` to replace the username (the password is kept),
    /// or `None` to leave the credential untouched.
    fn fmt_username(
        &self,
        ctx: &Context<S>,
        proxy: &Proxy,
        filter: &ProxyFilter,
        username: &str,
    ) -> Option<String>;
}
474
/// The unit type is the no-op formatter: it never changes the username.
impl<S> UsernameFormatter<S> for () {
    /// Always returns `None`, ignoring every argument.
    fn fmt_username(
        &self,
        _ctx: &Context<S>,
        _proxy: &Proxy,
        _filter: &ProxyFilter,
        _username: &str,
    ) -> Option<String> {
        None
    }
}
486
487impl<F, S> UsernameFormatter<S> for F
488where
489    F: Fn(&Context<S>, &Proxy, &ProxyFilter, &str) -> Option<String> + Send + Sync + 'static,
490{
491    fn fmt_username(
492        &self,
493        ctx: &Context<S>,
494        proxy: &Proxy,
495        filter: &ProxyFilter,
496        username: &str,
497    ) -> Option<String> {
498        (self)(ctx, proxy, filter, username)
499    }
500}
501
502#[cfg(test)]
503mod tests {
504    use super::*;
505    use crate::{MemoryProxyDB, Proxy, ProxyCsvRowReader, StringFilter};
506    use itertools::Itertools;
507    use rama_core::service::service_fn;
508    use rama_http_types::{Body, Request, Version};
509    use rama_net::{
510        address::{Authority, ProxyAddress},
511        asn::Asn,
512        Protocol,
513    };
514    use rama_utils::str::NonEmptyString;
515    use std::{convert::Infallible, str::FromStr, sync::Arc};
516
    // Happy path with an in-memory db of two proxies and `ProxyFilterMode::Default`:
    // a filter for mobile + residential proxies in "BE" must yield proxy `42`.
    #[tokio::test]
    async fn test_proxy_db_default_happy_path_example() {
        let db = MemoryProxyDB::try_from_iter([
            // Wildcard-location proxy that is both mobile and residential.
            Proxy {
                id: NonEmptyString::from_static("42"),
                address: ProxyAddress::from_str("12.34.12.34:8080").unwrap(),
                tcp: true,
                udp: true,
                http: true,
                https: true,
                socks5: true,
                socks5h: true,
                datacenter: false,
                residential: true,
                mobile: true,
                pool_id: None,
                continent: Some("*".into()),
                country: Some("*".into()),
                state: Some("*".into()),
                city: Some("*".into()),
                carrier: Some("*".into()),
                asn: Some(Asn::unspecified()),
            },
            // US datacenter proxy: neither mobile nor residential.
            Proxy {
                id: NonEmptyString::from_static("100"),
                address: ProxyAddress::from_str("12.34.12.35:8080").unwrap(),
                tcp: true,
                udp: false,
                http: true,
                https: true,
                socks5: false,
                socks5h: false,
                datacenter: true,
                residential: false,
                mobile: false,
                pool_id: None,
                continent: Some("americas".into()),
                country: Some("US".into()),
                state: None,
                city: None,
                carrier: None,
                asn: Some(Asn::unspecified()),
            },
        ])
        .unwrap();

        // The inner service echoes back the proxy address found in the context.
        let service = ProxyDBLayer::new(Arc::new(db))
            .filter_mode(ProxyFilterMode::Default)
            .layer(service_fn(|ctx: Context<()>, _: Request| async move {
                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
            }));

        let mut ctx = Context::default();
        ctx.insert(ProxyFilter {
            country: Some(vec!["BE".into()]),
            mobile: Some(true),
            residential: Some(true),
            ..Default::default()
        });

        let req = Request::builder()
            .version(Version::HTTP_3)
            .method("GET")
            .uri("https://example.com")
            .body(Body::empty())
            .unwrap();

        let proxy_address = service.serve(ctx, req).await.unwrap();
        assert_eq!(
            proxy_address.authority,
            Authority::try_from("12.34.12.34:8080").unwrap()
        );
    }
590
    // Same scenario as the happy path, but a single `Proxy` value is used
    // directly as the database instead of a `MemoryProxyDB`.
    #[tokio::test]
    async fn test_proxy_db_single_proxy_example() {
        let proxy = Proxy {
            id: NonEmptyString::from_static("42"),
            address: ProxyAddress::from_str("12.34.12.34:8080").unwrap(),
            tcp: true,
            udp: true,
            http: true,
            https: true,
            socks5: true,
            socks5h: true,
            datacenter: false,
            residential: true,
            mobile: true,
            pool_id: None,
            continent: Some("*".into()),
            country: Some("*".into()),
            state: Some("*".into()),
            city: Some("*".into()),
            carrier: Some("*".into()),
            asn: Some(Asn::unspecified()),
        };

        // The inner service echoes back the proxy address found in the context.
        let service = ProxyDBLayer::new(Arc::new(proxy))
            .filter_mode(ProxyFilterMode::Default)
            .layer(service_fn(|ctx: Context<()>, _: Request| async move {
                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
            }));

        let mut ctx = Context::default();
        ctx.insert(ProxyFilter {
            country: Some(vec!["BE".into()]),
            mobile: Some(true),
            residential: Some(true),
            ..Default::default()
        });

        let req = Request::builder()
            .version(Version::HTTP_3)
            .method("GET")
            .uri("https://example.com")
            .body(Body::empty())
            .unwrap();

        let proxy_address = service.serve(ctx, req).await.unwrap();
        assert_eq!(
            proxy_address.authority,
            Authority::try_from("12.34.12.34:8080").unwrap()
        );
    }
641
    // A closure-based username formatter appends filter labels to the username
    // of a proxy from the "routers" pool; the password must stay untouched.
    #[tokio::test]
    async fn test_proxy_db_single_proxy_with_username_formatter() {
        let proxy = Proxy {
            id: NonEmptyString::from_static("42"),
            address: ProxyAddress::from_str("john:secret@12.34.12.34:8080").unwrap(),
            tcp: true,
            udp: true,
            http: true,
            https: true,
            socks5: true,
            socks5h: true,
            datacenter: false,
            residential: true,
            mobile: true,
            pool_id: Some("routers".into()),
            continent: Some("*".into()),
            country: Some("*".into()),
            state: Some("*".into()),
            city: Some("*".into()),
            carrier: Some("*".into()),
            asn: Some(Asn::unspecified()),
        };

        let service = ProxyDBLayer::new(Arc::new(proxy))
            .filter_mode(ProxyFilterMode::Default)
            .username_formatter(
                |_ctx: &Context<()>, proxy: &Proxy, filter: &ProxyFilter, username: &str| {
                    // Only proxies of the "routers" pool get labels in their username.
                    if proxy
                        .pool_id
                        .as_ref()
                        .map(|id| id.as_ref() == "routers")
                        .unwrap_or_default()
                    {
                        use std::fmt::Write;

                        let mut output = String::new();

                        // Only the first country / state of the filter is used as a label.
                        if let Some(countries) = filter.country.as_ref().filter(|t| !t.is_empty()) {
                            let _ = write!(output, "country-{}", countries[0]);
                        }
                        if let Some(states) = filter.state.as_ref().filter(|t| !t.is_empty()) {
                            let _ = write!(output, "state-{}", states[0]);
                        }

                        // Without any label the username is left as it is.
                        return (!output.is_empty()).then(|| format!("{username}-{output}"));
                    }

                    None
                },
            )
            .layer(service_fn(|ctx: Context<()>, _: Request| async move {
                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
            }));

        let mut ctx = Context::default();
        ctx.insert(ProxyFilter {
            country: Some(vec!["BE".into()]),
            mobile: Some(true),
            residential: Some(true),
            ..Default::default()
        });

        let req = Request::builder()
            .version(Version::HTTP_3)
            .method("GET")
            .uri("https://example.com")
            .body(Body::empty())
            .unwrap();

        let proxy_address = service.serve(ctx, req).await.unwrap();
        // NOTE(review): the lower-case "be" presumably comes from the string filter
        // normalising "BE", and "socks5" from the protocol chosen by the service
        // because the address itself has none — confirm against those types.
        assert_eq!(
            "socks5://john-country-be:secret@12.34.12.34:8080",
            proxy_address.to_string()
        );
    }
717
    // Happy path again, but driven by a transport-level (tcp client) request
    // instead of an http request.
    #[tokio::test]
    async fn test_proxy_db_default_happy_path_example_transport_layer() {
        let db = MemoryProxyDB::try_from_iter([
            // Wildcard-location proxy that is both mobile and residential.
            Proxy {
                id: NonEmptyString::from_static("42"),
                address: ProxyAddress::from_str("12.34.12.34:8080").unwrap(),
                tcp: true,
                udp: true,
                http: true,
                https: true,
                socks5: true,
                socks5h: true,
                datacenter: false,
                residential: true,
                mobile: true,
                pool_id: None,
                continent: Some("*".into()),
                country: Some("*".into()),
                state: Some("*".into()),
                city: Some("*".into()),
                carrier: Some("*".into()),
                asn: Some(Asn::unspecified()),
            },
            // US datacenter proxy: neither mobile nor residential.
            Proxy {
                id: NonEmptyString::from_static("100"),
                address: ProxyAddress::from_str("12.34.12.35:8080").unwrap(),
                tcp: true,
                udp: false,
                http: true,
                https: true,
                socks5: false,
                socks5h: false,
                datacenter: true,
                residential: false,
                mobile: false,
                pool_id: None,
                continent: Some("americas".into()),
                country: Some("US".into()),
                state: None,
                city: None,
                carrier: None,
                asn: Some(Asn::unspecified()),
            },
        ])
        .unwrap();

        // The request type of the inner service is inferred from the call below.
        let service = ProxyDBLayer::new(Arc::new(db))
            .filter_mode(ProxyFilterMode::Default)
            .layer(service_fn(|ctx: Context<()>, _| async move {
                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
            }));

        let mut ctx = Context::default();
        ctx.insert(ProxyFilter {
            country: Some(vec!["BE".into()]),
            mobile: Some(true),
            residential: Some(true),
            ..Default::default()
        });

        let req = rama_tcp::client::Request::new("www.example.com:443".parse().unwrap())
            .with_protocol(Protocol::HTTPS);

        let proxy_address = service.serve(ctx, req).await.unwrap();
        assert_eq!(
            proxy_address.authority,
            Authority::try_from("12.34.12.34:8080").unwrap()
        );
    }
787
    // Proxy rows shared by the tests below, embedded at compile time.
    const RAW_CSV_DATA: &str = include_str!("./test_proxydb_rows.csv");

    // Build a `MemoryProxyDB` from every row of `RAW_CSV_DATA`.
    // Panics when a row cannot be read or the database cannot be built.
    async fn memproxydb() -> MemoryProxyDB {
        let mut reader = ProxyCsvRowReader::raw(RAW_CSV_DATA);
        let mut rows = Vec::new();
        // The reader is async, so the rows are drained one by one.
        while let Some(proxy) = reader.next().await.unwrap() {
            rows.push(proxy);
        }
        MemoryProxyDB::try_from_rows(rows).unwrap()
    }
798
    // With `preserve_proxy(true)` a proxy address that is already in the context
    // must reach the inner service unchanged, even in `Default` filter mode.
    #[tokio::test]
    async fn test_proxy_db_service_preserve_proxy_address() {
        let db = memproxydb().await;

        let service = ProxyDBLayer::new(Arc::new(db))
            .preserve_proxy(true)
            .filter_mode(ProxyFilterMode::Default)
            .layer(service_fn(|ctx: Context<()>, _: Request| async move {
                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
            }));

        // Pre-existing proxy address that is expected to be preserved.
        let mut ctx = Context::default();
        ctx.insert(ProxyAddress::try_from("http://john:secret@1.2.3.4:1234").unwrap());

        let req = Request::builder()
            .version(Version::HTTP_11)
            .method("GET")
            .uri("http://example.com")
            .body(Body::empty())
            .unwrap();

        let proxy_address = service.serve(ctx, req).await.unwrap();

        assert_eq!(proxy_address.authority.to_string(), "1.2.3.4:1234");
    }
824
    // In the default (`Optional`) mode no proxy is selected without a filter,
    // while a filter in the context leads to a matching proxy from the csv data.
    #[tokio::test]
    async fn test_proxy_db_service_optional() {
        let db = memproxydb().await;

        // The inner service reports the proxy address, if any, found in the context.
        let service = ProxyDBLayer::new(Arc::new(db)).layer(service_fn(
            |ctx: Context<()>, _: Request| async move {
                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().cloned())
            },
        ));

        // Each case: (filter to insert, expected proxy authority, request).
        // The expected authorities depend on the rows in `test_proxydb_rows.csv`.
        for (filter, expected_authority, req) in [
            // No filter: no proxy address is expected.
            (
                None,
                None,
                Request::builder()
                    .version(Version::HTTP_11)
                    .method("GET")
                    .uri("http://example.com")
                    .body(Body::empty())
                    .unwrap(),
            ),
            // Filter on a proxy id.
            (
                Some(ProxyFilter {
                    id: Some(NonEmptyString::from_static("3031533634")),
                    ..Default::default()
                }),
                Some("105.150.55.60:4898"),
                Request::builder()
                    .version(Version::HTTP_11)
                    .method("GET")
                    .uri("http://example.com")
                    .body(Body::empty())
                    .unwrap(),
            ),
            // Filter on country, mobile and residential, using an HTTP/3 request.
            (
                Some(ProxyFilter {
                    country: Some(vec![StringFilter::new("BE")]),
                    mobile: Some(true),
                    residential: Some(true),
                    ..Default::default()
                }),
                Some("140.249.154.18:5800"),
                Request::builder()
                    .version(Version::HTTP_3)
                    .method("GET")
                    .uri("https://example.com")
                    .body(Body::empty())
                    .unwrap(),
            ),
        ] {
            let mut ctx = Context::default();
            ctx.maybe_insert(filter);

            let maybe_proxy_address = service.serve(ctx, req).await.unwrap();

            assert_eq!(
                maybe_proxy_address.map(|p| p.authority),
                expected_authority.map(|s| Authority::try_from(s).unwrap())
            );
        }
    }
886
887    #[tokio::test]
888    async fn test_proxy_db_service_default() {
889        let db = memproxydb().await;
890
891        let service = ProxyDBLayer::new(Arc::new(db))
892            .filter_mode(ProxyFilterMode::Default)
893            .layer(service_fn(|ctx: Context<()>, _: Request| async move {
894                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
895            }));
896
897        for (filter, expected_addresses, req_info) in [
898            (None, "0.20.204.227:8373,104.207.92.167:9387,105.150.55.60:4898,106.213.197.28:9110,113.6.21.212:4525,115.29.251.35:5712,119.146.94.132:7851,129.204.152.130:6524,134.190.189.202:5772,136.186.95.10:7095,137.220.180.169:4929,140.249.154.18:5800,145.57.31.149:6304,151.254.135.9:6961,153.206.209.221:8696,162.97.174.152:1673,169.179.161.206:6843,171.174.56.89:5744,178.189.117.217:6496,182.34.76.182:2374,184.209.230.177:1358,193.188.239.29:3541,193.26.37.125:3780,204.168.216.113:1096,208.224.120.97:7118,209.176.177.182:4311,215.49.63.89:9458,223.234.242.63:7211,230.159.143.41:7296,233.22.59.115:1653,24.155.249.112:2645,247.118.71.100:1033,249.221.15.121:7434,252.69.242.136:4791,253.138.153.41:2640,28.139.151.127:2809,4.20.243.186:9155,42.54.35.118:6846,45.59.69.12:5934,46.247.45.238:3522,54.226.47.54:7442,61.112.212.160:3842,66.142.40.209:4251,66.171.139.181:4449,69.246.162.84:8964,75.43.123.181:7719,76.128.58.167:4797,85.14.163.105:8362,92.227.104.237:6161,97.192.206.72:6067", (Version::HTTP_11, "GET", "http://example.com")),
899            (
900                Some(ProxyFilter {
901                    country: Some(vec![StringFilter::new("BE")]),
902                    mobile: Some(true),
903                    residential: Some(true),
904                    ..Default::default()
905                }),
906                "140.249.154.18:5800",
907                (Version::HTTP_3, "GET", "https://example.com"),
908            ),
909        ] {
910            let mut seen_addresses = Vec::new();
911            for _ in 0..5000 {
912                let mut ctx = Context::default();
913                ctx.maybe_insert(filter.clone());
914
915                let req = Request::builder()
916                    .version(req_info.0)
917                    .method(req_info.1)
918                    .uri(req_info.2)
919                    .body(Body::empty())
920                    .unwrap();
921
922                let proxy_address = service.serve(ctx, req).await.unwrap().authority.to_string();
923                if !seen_addresses.contains(&proxy_address) {
924                    seen_addresses.push(proxy_address);
925                }
926            }
927
928            let seen_addresses = seen_addresses.into_iter().sorted().join(",");
929            assert_eq!(seen_addresses, expected_addresses);
930        }
931    }
932
933    #[tokio::test]
934    async fn test_proxy_db_service_fallback() {
935        let db = memproxydb().await;
936
937        let service = ProxyDBLayer::new(Arc::new(db))
938            .filter_mode(ProxyFilterMode::Fallback(ProxyFilter {
939                datacenter: Some(true),
940                residential: Some(false),
941                mobile: Some(false),
942                ..Default::default()
943            }))
944            .layer(service_fn(|ctx: Context<()>, _: Request| async move {
945                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
946            }));
947
948        for (filter, expected_addresses, req_info) in [
949            (
950                None,
951                "113.6.21.212:4525,119.146.94.132:7851,136.186.95.10:7095,137.220.180.169:4929,247.118.71.100:1033,249.221.15.121:7434,92.227.104.237:6161",
952                (Version::HTTP_11, "GET", "http://example.com"),
953            ),
954            (
955                Some(ProxyFilter {
956                    country: Some(vec![StringFilter::new("BE")]),
957                    mobile: Some(true),
958                    residential: Some(true),
959                    ..Default::default()
960                }),
961                "140.249.154.18:5800",
962                (Version::HTTP_3, "GET", "https://example.com"),
963            ),
964        ] {
965            let mut seen_addresses = Vec::new();
966            for _ in 0..5000 {
967                let mut ctx = Context::default();
968                ctx.maybe_insert(filter.clone());
969
970                let req = Request::builder()
971                    .version(req_info.0)
972                    .method(req_info.1)
973                    .uri(req_info.2)
974                    .body(Body::empty())
975                    .unwrap();
976
977                let proxy_address = service.serve(ctx, req).await.unwrap().authority.to_string();
978                if !seen_addresses.contains(&proxy_address) {
979                    seen_addresses.push(proxy_address);
980                }
981            }
982
983            let seen_addresses = seen_addresses.into_iter().sorted().join(",");
984            assert_eq!(seen_addresses, expected_addresses);
985        }
986    }
987
988    #[tokio::test]
989    async fn test_proxy_db_service_required() {
990        let db = memproxydb().await;
991
992        let service = ProxyDBLayer::new(Arc::new(db))
993            .filter_mode(ProxyFilterMode::Required)
994            .layer(service_fn(|ctx: Context<()>, _: Request| async move {
995                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
996            }));
997
998        for (filter, expected_address, req) in [
999            (
1000                None,
1001                None,
1002                Request::builder()
1003                    .version(Version::HTTP_11)
1004                    .method("GET")
1005                    .uri("http://example.com")
1006                    .body(Body::empty())
1007                    .unwrap(),
1008            ),
1009            (
1010                Some(ProxyFilter {
1011                    country: Some(vec![StringFilter::new("BE")]),
1012                    mobile: Some(true),
1013                    residential: Some(true),
1014                    ..Default::default()
1015                }),
1016                Some("140.249.154.18:5800"),
1017                Request::builder()
1018                    .version(Version::HTTP_3)
1019                    .method("GET")
1020                    .uri("https://example.com")
1021                    .body(Body::empty())
1022                    .unwrap(),
1023            ),
1024            (
1025                Some(ProxyFilter {
1026                    id: Some(NonEmptyString::from_static("FooBar")),
1027                    ..Default::default()
1028                }),
1029                None,
1030                Request::builder()
1031                    .version(Version::HTTP_3)
1032                    .method("GET")
1033                    .uri("https://example.com")
1034                    .body(Body::empty())
1035                    .unwrap(),
1036            ),
1037            (
1038                Some(ProxyFilter {
1039                    id: Some(NonEmptyString::from_static("1316455915")),
1040                    country: Some(vec![StringFilter::new("BE")]),
1041                    mobile: Some(true),
1042                    residential: Some(true),
1043                    ..Default::default()
1044                }),
1045                None,
1046                Request::builder()
1047                    .version(Version::HTTP_3)
1048                    .method("GET")
1049                    .uri("https://example.com")
1050                    .body(Body::empty())
1051                    .unwrap(),
1052            ),
1053        ] {
1054            let mut ctx = Context::default();
1055            ctx.maybe_insert(filter.clone());
1056
1057            let proxy_address_result = service.serve(ctx, req).await;
1058            match expected_address {
1059                Some(expected_address) => {
1060                    assert_eq!(
1061                        proxy_address_result.unwrap().authority,
1062                        Authority::try_from(expected_address).unwrap()
1063                    );
1064                }
1065                None => {
1066                    assert!(proxy_address_result.is_err());
1067                }
1068            }
1069        }
1070    }
1071
1072    #[tokio::test]
1073    async fn test_proxy_db_service_required_with_predicate() {
1074        let db = memproxydb().await;
1075
1076        let service = ProxyDBLayer::new(Arc::new(db))
1077            .filter_mode(ProxyFilterMode::Required)
1078            .select_predicate(|proxy: &Proxy| proxy.mobile)
1079            .layer(service_fn(|ctx: Context<()>, _: Request| async move {
1080                Ok::<_, Infallible>(ctx.get::<ProxyAddress>().unwrap().clone())
1081            }));
1082
1083        for (filter, expected, req) in [
1084            (
1085                None,
1086                None,
1087                Request::builder()
1088                    .version(Version::HTTP_11)
1089                    .method("GET")
1090                    .uri("http://example.com")
1091                    .body(Body::empty())
1092                    .unwrap(),
1093            ),
1094            (
1095                Some(ProxyFilter {
1096                    country: Some(vec![StringFilter::new("BE")]),
1097                    mobile: Some(true),
1098                    residential: Some(true),
1099                    ..Default::default()
1100                }),
1101                Some("140.249.154.18:5800"),
1102                Request::builder()
1103                    .version(Version::HTTP_3)
1104                    .method("GET")
1105                    .uri("https://example.com")
1106                    .body(Body::empty())
1107                    .unwrap(),
1108            ),
1109            (
1110                Some(ProxyFilter {
1111                    id: Some(NonEmptyString::from_static("FooBar")),
1112                    ..Default::default()
1113                }),
1114                None,
1115                Request::builder()
1116                    .version(Version::HTTP_3)
1117                    .method("GET")
1118                    .uri("https://example.com")
1119                    .body(Body::empty())
1120                    .unwrap(),
1121            ),
1122            (
1123                Some(ProxyFilter {
1124                    id: Some(NonEmptyString::from_static("1316455915")),
1125                    country: Some(vec![StringFilter::new("BE")]),
1126                    mobile: Some(true),
1127                    residential: Some(true),
1128                    ..Default::default()
1129                }),
1130                None,
1131                Request::builder()
1132                    .version(Version::HTTP_3)
1133                    .method("GET")
1134                    .uri("https://example.com")
1135                    .body(Body::empty())
1136                    .unwrap(),
1137            ),
1138            // match found, but due to custom predicate it won't check, given it is not mobile
1139            (
1140                Some(ProxyFilter {
1141                    id: Some(NonEmptyString::from_static("1316455915")),
1142                    ..Default::default()
1143                }),
1144                None,
1145                Request::builder()
1146                    .version(Version::HTTP_3)
1147                    .method("GET")
1148                    .uri("https://example.com")
1149                    .body(Body::empty())
1150                    .unwrap(),
1151            ),
1152        ] {
1153            let mut ctx = Context::default();
1154            ctx.maybe_insert(filter);
1155
1156            let proxy_result = service.serve(ctx, req).await;
1157            match expected {
1158                Some(expected_address) => {
1159                    assert_eq!(
1160                        proxy_result.unwrap().authority,
1161                        Authority::try_from(expected_address).unwrap()
1162                    );
1163                }
1164                None => {
1165                    assert!(proxy_result.is_err());
1166                }
1167            }
1168        }
1169    }
1170}