Skip to main content

ralertsinua_http/
client.rs

1//! The client implementation for the reqwest HTTP client, which is async
2//! @borrows https://github.com/ramsayleung/rspotify/blob/master/rspotify-http/src/reqwest.rs
3
4use async_trait::async_trait;
5use bytes::Bytes;
6use ralertsinua_models::*;
7use reqwest::{
8    header::{HeaderMap, HeaderValue},
9    Client, ClientBuilder, Method, RequestBuilder, Response, StatusCode,
10};
11use serde::Deserialize;
12use std::fmt;
13use std::{collections::HashMap, sync::Arc};
14
15#[cfg(feature = "cache")]
16use crate::cache::*;
17use crate::error::*;
18
19type Query<'a> = HashMap<&'a str, &'a str>;
20type Result<T> = miette::Result<T, ApiError>;
21
22pub const API_BASE_URL: &str = "https://api.alerts.in.ua";
23pub const API_VERSION: &str = "/v1";
24pub const API_CACHE_SIZE: usize = 1000;
25
26pub struct AlertsInUaClient {
27    base_url: String,
28    token: String,
29    client: Client,
30    #[cfg(feature = "cache")]
31    cache_manager: Arc<dyn CacheManagerSync>,
32}
33
34impl std::fmt::Debug for AlertsInUaClient {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        write!(f, "AlertsInUaClient {{ base_url: {}, token: {}, client: {:?}, cache_manager: {:?} }}", self.base_url, self.token, self.client, "CACacheManager")
37    }
38}
39
40impl AlertsInUaClient {
41    const APP_USER_AGENT: &'static str =
42        concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
43
44    pub fn new(base_url: &str, token: &str) -> Self {
45        let base_url = base_url.into();
46        let token = token.into();
47        let client = ClientBuilder::new()
48            .timeout(std::time::Duration::from_secs(10))
49            .user_agent(Self::APP_USER_AGENT)
50            .build()
51            // building with these options cannot fail
52            .unwrap();
53
54        let cache_manager = Arc::new(CacheManagerQuick::new(API_CACHE_SIZE));
55
56        Self {
57            base_url,
58            token,
59            client,
60            #[cfg(feature = "cache")]
61            cache_manager,
62        }
63    }
64}
65
66impl AlertsInUaClient {
67    fn get_api_url(&self, url: &str) -> String {
68        format!("{}{}{}", self.base_url, API_VERSION, url)
69    }
70
71    async fn request<R, D>(&self, method: Method, url: &str, add_data: D) -> Result<R>
72    where
73        R: for<'de> Deserialize<'de>,
74        D: Fn(RequestBuilder) -> RequestBuilder,
75    {
76        let mut last_modified = String::new();
77        let mut cached_data: Bytes = Bytes::new();
78        // Build full URL
79        let url = self.get_api_url(url);
80        let mut req = self.client.request(method.clone(), &url);
81        // Enable HTTP bearer authentication.
82        req = req.bearer_auth(&self.token);
83        // Get last_modified from cache
84        let mut headers = HeaderMap::new();
85        // Set the headers
86        headers.insert("Accept", HeaderValue::from_static("application/json"));
87
88        if cfg!(feature = "cache") {
89            if let Some(CacheEntry(bytes, lm)) = self.cache_manager.get(&url)? {
90                last_modified = lm;
91                cached_data = bytes;
92            }
93            // Here we set the If-Modified-Since header from the last_modified
94            headers.insert(
95                "If-Modified-Since",
96                last_modified.parse().map_err(http::Error::from)?,
97            );
98        }
99
100        req = req.headers(headers);
101        // Configuring the request for the specific type (get/post/put/delete)
102        req = add_data(req);
103        // Finally performing the request and handling the response
104        log::trace!(target: env!("CARGO_PKG_NAME"), "Request {:?}", req);
105        let res: Response = req.send().await.inspect_err(|e| {
106            log::error!(target: env!("CARGO_PKG_NAME"),  "Error making request: {:?}", e);
107        })?;
108        log::trace!(target: env!("CARGO_PKG_NAME"), "Response {:?}", res);
109        // Making sure that the status code is OK
110        if let Err(err) = res.error_for_status_ref() {
111            let err = match err.status() {
112                Some(StatusCode::BAD_REQUEST) => Err(ApiError::InvalidParameterException),
113                Some(StatusCode::UNAUTHORIZED) => Err(ApiError::UnauthorizedError(err)),
114                Some(StatusCode::FORBIDDEN) => Err(ApiError::InvalidParameterException),
115                Some(StatusCode::METHOD_NOT_ALLOWED) | Some(StatusCode::NOT_FOUND) => {
116                    Err(ApiError::InvalidURL(err))
117                }
118                Some(StatusCode::TOO_MANY_REQUESTS) => Err(ApiError::RateLimitError),
119                Some(StatusCode::INTERNAL_SERVER_ERROR) => {
120                    Err(ApiError::InternalServerError)
121                }
122                _ => Err(ApiError::Unknown(err)),
123            };
124
125            return err;
126        }
127
128        last_modified = format!("{:?}", res.headers().get("Last-Modified").unwrap());
129        // -------------------------------------------------------------
130        let data: Bytes = match res.status() {
131            #[cfg(feature = "cache")]
132            StatusCode::NOT_MODIFIED => {
133                log::trace!(target: env!("CARGO_PKG_NAME"), "Response status '304 Not Modified', return cached data");
134                cached_data
135            }
136            _ => {
137                let bytes = res.bytes().await?;
138                if cfg!(feature = "cache") {
139                    // Save the data to the cache
140                    self.cache_manager
141                        .put(&url, &last_modified, bytes.clone())
142                        .inspect_err(|e| {
143                            log::error!("Error writing to cache: {:?}", e);
144                        })?;
145                }
146
147                bytes
148            }
149        };
150
151        // Return deserialized data
152        Ok(serde_json::from_slice(&data)?)
153    }
154}
155
156/// This trait represents the interface to be implemented for an HTTP client,
157/// which is kept separate from the Spotify client for cleaner code. Thus, it
158/// also requires other basic traits that are needed for the Spotify client.
159///
160/// When a request doesn't need to pass parameters, the empty or default value
161/// of the payload type should be passed, like `json!({})` or `Query::new()`.
162/// This avoids using `Option<T>` because `Value` itself may be null in other
163/// different ways (`Value::Null`, an empty `Value::Object`...), so this removes
164/// redundancy and edge cases (a `Some(Value::Null), for example, doesn't make
165/// much sense).
166pub trait BaseHttpClient: Send + fmt::Debug {
167    // This internal function should always be given an object value in JSON.
168    #[allow(async_fn_in_trait)]
169    async fn get<R>(&self, url: &str, payload: &Query) -> Result<R>
170    where
171        R: for<'de> Deserialize<'de>;
172}
173
174impl BaseHttpClient for AlertsInUaClient {
175    #[inline]
176    async fn get<R>(&self, url: &str, _payload: &Query<'_>) -> Result<R>
177    where
178        R: for<'de> Deserialize<'de>,
179    {
180        self.request(Method::GET, url, |r| r).await
181    }
182}
183
184/// The API for the AlertsInUaClient
185#[async_trait]
186pub trait AlertsInUaApi: fmt::Debug {
187    async fn get_active_alerts(&self) -> Result<Alerts>;
188
189    async fn get_alerts_history(&self, location_aid: &i8, period: &str) -> Result<Alerts>;
190
191    async fn get_air_raid_alert_status(&self, location_aid: &i8) -> Result<String>;
192
193    async fn get_air_raid_alert_statuses_by_location(
194        &self,
195    ) -> Result<AirRaidAlertOblastStatuses>;
196}
197
198#[async_trait]
199impl AlertsInUaApi for AlertsInUaClient {
200    async fn get_active_alerts(&self) -> Result<Alerts> {
201        let url = "/alerts/active.json";
202        self.get(url, &Query::default()).await
203    }
204
205    async fn get_alerts_history(&self, location_aid: &i8, period: &str) -> Result<Alerts> {
206        let url = format!("/locations/{}/alerts/{}.json", location_aid, period);
207        self.get(&url, &Query::default()).await
208    }
209
210    async fn get_air_raid_alert_status(&self, location_aid: &i8) -> Result<String> {
211        let url = format!("/iot/active_air_raid_alerts/{}.json", location_aid);
212        self.get(&url, &Query::default()).await
213    }
214
215    async fn get_air_raid_alert_statuses_by_location(
216        &self,
217    ) -> Result<AirRaidAlertOblastStatuses> {
218        let url = "/iot/active_air_raid_alerts_by_oblast.json";
219        let data: String = self.get(url, &Query::default()).await?;
220        let result = AirRaidAlertOblastStatuses::new(data, Some(true));
221        Ok(result)
222    }
223}
224
225// The existence of this function makes the compiler catch if the Buf
226// trait is "object-safe" or not.
227fn _assert_trait_object(_: &dyn AlertsInUaApi) {}
228
#[cfg(test)]
mod tests {
    use super::*;
    // Kept available for mock-based tests; not referenced by the tests below.
    #[allow(unused_imports)]
    use mockall::predicate::*;
    use mockito::Server as MockServer;
    use serde_json::json;
    use std::sync::Arc;

    /// `Last-Modified` header value served by every mocked endpoint.
    const MOCK_LAST_MODIFIED: &str = "Tue, 14 May 2024 18:18:18 GMT";

    /// The concrete client can be used as a `dyn AlertsInUaApi` trait object.
    #[test]
    fn test_trait() {
        let api: Arc<dyn AlertsInUaApi> = Arc::new(AlertsInUaClient::new("", ""));
        println!("{:?}", api);
    }

    /// Endpoint paths are joined to the base URL with the API version prefix.
    #[test]
    fn test_get_api_url() {
        let api = AlertsInUaClient::new("https://api.alerts.in.ua", "token");
        assert_eq!(
            api.get_api_url("/alerts/active.json"),
            "https://api.alerts.in.ua/v1/alerts/active.json"
        );
    }

    /// `get_active_alerts` deserializes the JSON body served by the API.
    #[tokio::test]
    async fn test_get_active_alerts() -> Result<()> {
        let mut mock_server = MockServer::new_async().await;
        let api = AlertsInUaClient::new(mock_server.url().as_str(), "token");
        let endpoint = mock_server
            .mock("GET", mockito::Matcher::Any)
            .with_header("Last-Modified", MOCK_LAST_MODIFIED)
            .with_body(r#"{"alerts":[],"disclaimer":"","meta":{"last_updated_at":"2024/05/06 10:02:45 +0000"}}"#)
            .create_async()
            .await;

        let actual = api.get_active_alerts().await?;

        endpoint.assert();
        let expected: Alerts = serde_json::from_value(json!({
            "alerts": [],
            "disclaimer": "",
            "meta": { "last_updated_at": "2024/05/06 10:02:45 +0000" }
        }))
        .unwrap();
        assert_eq!(actual, expected);

        Ok(())
    }

    /// The by-oblast endpoint's string body is turned into `AirRaidAlertOblastStatuses`.
    #[tokio::test]
    async fn test_get_air_raid_alert_statuses_by_location() -> Result<()> {
        let mut mock_server = MockServer::new_async().await;
        let api = AlertsInUaClient::new(mock_server.url().as_str(), "token");
        let endpoint = mock_server
            .mock("GET", mockito::Matcher::Any)
            .with_header("Last-Modified", MOCK_LAST_MODIFIED)
            .with_body(r#""ANNAANNANNNPANANANNNNAANNNN""#)
            .create_async()
            .await;

        let _statuses = api.get_air_raid_alert_statuses_by_location().await?;

        endpoint.assert();
        // FIXME: assert on the parsed statuses; the intended check was
        // assert_eq!(&*statuses, "ANNAANNANNNPANANANNNNAANNNN");

        Ok(())
    }
}