// actix_web_location/extractors.rs

1use std::sync::Arc;
2
3use crate::{domain::Location, error::Error, providers::Provider};
4use anyhow::anyhow;
5use futures::{future::LocalBoxFuture, FutureExt};
6use lazy_static::lazy_static;
7
8#[cfg(feature = "actix-web-v3")]
9use actix_web_3::{dev, web, FromRequest, HttpRequest};
10
11#[cfg(feature = "actix-web-v4")]
12use actix_web_4::{dev, web, FromRequest, HttpRequest};
13
14impl FromRequest for Location {
15    #[cfg(feature = "actix-web-v3")]
16    type Config = LocationConfig;
17
18    type Error = Error;
19
20    type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
21
22    fn from_request(req: &HttpRequest, _payload: &mut dev::Payload) -> Self::Future {
23        let req = req.clone();
24        async move {
25            let config = LocationConfig::from_req(&req).clone();
26            let mut result: Option<Result<Self, Self::Error>> = None;
27            for provider in config.providers {
28                if let Ok(Some(location)) = provider.get_location(&req).await {
29                    #[cfg(feature = "cadence")]
30                    {
31                        if let Some(metrics) = config.metrics.as_ref() {
32                            if provider.expect_city() && location.city.is_none() {
33                                metrics
34                                    .incr_with_tags("location.unknown.city")
35                                    .with_tag("provider", provider.name())
36                                    .try_send()
37                                    .ok();
38                            }
39                            if provider.expect_region() && location.region.is_none() {
40                                metrics
41                                    .incr_with_tags("location.unknown.region")
42                                    .with_tag("provider", provider.name())
43                                    .try_send()
44                                    .ok();
45                            }
46                            if provider.expect_country() && location.country.is_none() {
47                                metrics
48                                    .incr_with_tags("location.unknown.country")
49                                    .with_tag("provider", provider.name())
50                                    .try_send()
51                                    .ok();
52                            }
53                        }
54                    }
55
56                    result = Some(Ok(location));
57
58                    break;
59                }
60            }
61
62            #[cfg(feature = "cadence")]
63            let metrics = config.metrics.as_ref();
64
65            result.unwrap_or_else(|| {
66                #[cfg(feature = "cadence")]
67                {
68                    if let Some(metrics) = metrics {
69                        metrics
70                            .incr_with_tags("location.unknown.city")
71                            .with_tag("provider", "none")
72                            .try_send()
73                            .ok();
74                        metrics
75                            .incr_with_tags("location.unknown.region")
76                            .with_tag("provider", "none")
77                            .try_send()
78                            .ok();
79                        metrics
80                            .incr_with_tags("location.unknown.country")
81                            .with_tag("provider", "none")
82                            .try_send()
83                            .ok();
84                    }
85                }
86
87                Location::build()
88                    .provider("none".to_string())
89                    .finish()
90                    .map_err(|_| Error::Http(anyhow!("Bug when processing default result")))
91            })
92        }
93        .boxed_local()
94    }
95}
96
97/// Configuration for how to determine location from a request.
98#[derive(Clone, Default)]
99pub struct LocationConfig {
100    /// The provider to request location information from.
101    providers: Vec<Arc<Box<dyn Provider>>>,
102
103    /// An optional sink to send metrics to.
104    #[cfg(feature = "cadence")]
105    metrics: Option<Arc<dyn cadence::CountedExt + Send + Sync>>,
106}
107
108lazy_static! {
109    static ref DEFAULT_LOCATION_CONFIG: LocationConfig = LocationConfig::default();
110}
111
112impl LocationConfig {
113    /// Add a provider to this configuration. It will be wrapped into an `Arc<Box<T>>`.
114    pub fn with_provider<P: Provider + 'static>(mut self, provider: P) -> Self {
115        self.providers.push(Arc::new(Box::new(provider)));
116        self
117    }
118
119    /// Add a metrics sink to this configuration. It will be wrapped into an `Arc<Option<Box<T>>>`.
120    #[cfg(feature = "cadence")]
121    pub fn with_metrics<M: cadence::CountedExt + Send + Sync + 'static>(
122        mut self,
123        metrics: Arc<M>,
124    ) -> Self {
125        self.metrics = Some(metrics);
126        self
127    }
128
129    fn from_req(req: &HttpRequest) -> &Self {
130        req.app_data::<Self>()
131            .or_else(|| req.app_data::<web::Data<Self>>().map(|d| d.as_ref()))
132            .unwrap_or(&DEFAULT_LOCATION_CONFIG)
133    }
134}
135
#[cfg(test)]
mod tests {
    use crate::{providers::FallbackProvider, Location, LocationConfig};

    #[cfg(not(feature = "actix-web-v4"))]
    use actix_web_3::{dev::Payload, test::TestRequest, FromRequest};
    #[cfg(feature = "actix-web-v4")]
    use actix_web_4::{dev::Payload, test::TestRequest, FromRequest};

    /// Run the `Location` extractor on a test request that carries `config` as app data.
    async fn extract(config: LocationConfig) -> Location {
        let request = TestRequest::default().app_data(config).to_http_request();
        Location::from_request(&request, &mut Payload::None)
            .await
            .expect("error getting request")
    }

    #[actix_rt::test]
    async fn default_config() {
        let expected = Location {
            country: None,
            region: None,
            city: None,
            dma: None,
            provider: "none".to_string(),
        };

        let actual = extract(LocationConfig::default()).await;

        assert_eq!(actual, expected);
    }

    #[actix_rt::test]
    async fn with_provider() {
        let fallback = FallbackProvider::new(
            Location::build()
                .country("CA".to_string())
                .region("ON".to_string())
                .city("Toronto".to_string()),
        );
        let expected = Location {
            country: Some("CA".to_string()),
            region: Some("ON".to_string()),
            city: Some("Toronto".to_string()),
            dma: None,
            provider: "fallback".to_string(),
        };

        let actual = extract(LocationConfig::default().with_provider(fallback)).await;

        assert_eq!(actual, expected);
    }

    // TODO test metrics
}