newsapi_rs/
client.rs

1use crate::constant::{
2    EVERYTHING_ENDPOINT, NEWS_API_CLIENT_USER_AGENT, NEWS_API_KEY_ENV, NEWS_API_URI,
3    SOURCES_ENDPOINT, TOP_HEADLINES_ENDPOINT,
4};
5use crate::error::{ApiClientError, ApiClientErrorCode, ApiClientErrorResponse};
6use crate::model::{
7    GetEverythingRequest, GetEverythingResponse, GetSourcesRequest, GetSourcesResponse,
8    GetTopHeadlinesRequest, TopHeadlinesResponse,
9};
10#[cfg(feature = "blocking")]
11use crate::retry::retry_blocking;
12use crate::retry::{retry, RetryStrategy};
13use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
14use serde::{Deserialize, Serialize};
15use std::env;
16use url::Url;
17
18#[derive(Debug, Deserialize, Serialize)]
19struct NewsApiErrorResponse {
20    status: String,
21    code: Option<String>,
22    message: Option<String>,
23}
24
25#[derive(Clone, Debug)]
26pub struct NewsApiClient<T> {
27    client: T,
28    api_key: String,
29    base_url: Url,
30    retry_strategy: RetryStrategy,
31    max_retries: usize,
32}
33
34pub struct NewsApiClientBuilder {
35    api_key: Option<String>,
36    base_url: Option<Url>,
37    retry_strategy: RetryStrategy,
38    max_retries: usize,
39}
40
41impl Default for NewsApiClientBuilder {
42    fn default() -> Self {
43        Self {
44            api_key: None,
45            base_url: Some(Url::parse(NEWS_API_URI).unwrap()),
46            retry_strategy: RetryStrategy::default(),
47            max_retries: 0,
48        }
49    }
50}
51
52impl NewsApiClientBuilder {
53    pub fn new() -> Self {
54        Self::default()
55    }
56
57    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
58        self.api_key = Some(api_key.into());
59        self
60    }
61
62    pub fn base_url(mut self, url: impl AsRef<str>) -> Result<Self, url::ParseError> {
63        self.base_url = Some(Url::parse(url.as_ref())?);
64        Ok(self)
65    }
66
67    pub fn retry(mut self, strategy: RetryStrategy, max_retries: usize) -> Self {
68        self.retry_strategy = strategy;
69        self.max_retries = max_retries;
70        self
71    }
72
73    pub fn from_env() -> Self {
74        match env::var(NEWS_API_KEY_ENV) {
75            Ok(api_key) => Self::new().api_key(api_key),
76            Err(_) => panic!("{} is not set", NEWS_API_KEY_ENV),
77        }
78    }
79
80    pub fn build(self) -> Result<NewsApiClient<reqwest::Client>, String> {
81        let api_key = match self.api_key {
82            Some(key) => key,
83            None => match env::var(NEWS_API_KEY_ENV) {
84                Ok(key) => key,
85                Err(_) => {
86                    return Err(format!(
87                        "API key must be provided either explicitly or via {} environment variable",
88                        NEWS_API_KEY_ENV
89                    ))
90                }
91            },
92        };
93
94        let base_url = self
95            .base_url
96            .unwrap_or_else(|| Url::parse(NEWS_API_URI).unwrap());
97
98        Ok(NewsApiClient {
99            client: reqwest::Client::new(),
100            api_key,
101            base_url,
102            retry_strategy: self.retry_strategy,
103            max_retries: self.max_retries,
104        })
105    }
106}
107
/// Builder for a blocking [`NewsApiClient`] backed by
/// `reqwest::blocking::Client`.
///
/// Options left unset fall back to defaults when `build` is called.
#[cfg(feature = "blocking")]
pub struct BlockingNewsApiClientBuilder {
    /// Explicit API key; when `None`, `build` reads it from the environment.
    api_key: Option<String>,
    /// Base URL override.
    base_url: Option<Url>,
    /// Retry strategy copied into the built client.
    retry_strategy: RetryStrategy,
    /// Maximum retry count copied into the built client.
    max_retries: usize,
}
115
#[cfg(feature = "blocking")]
impl Default for BlockingNewsApiClientBuilder {
    /// Starts with no API key, `NEWS_API_URI` as the base URL, the default
    /// retry strategy and zero retries.
    fn default() -> Self {
        let default_url = Url::parse(NEWS_API_URI).unwrap();
        BlockingNewsApiClientBuilder {
            api_key: None,
            base_url: Some(default_url),
            retry_strategy: Default::default(),
            max_retries: 0,
        }
    }
}
127
#[cfg(feature = "blocking")]
impl BlockingNewsApiClientBuilder {
    /// Creates a builder with the settings from its `Default` impl.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the API key explicitly.
    ///
    /// An explicit key takes precedence over the `NEWS_API_KEY_ENV`
    /// environment variable that [`Self::build`] otherwise consults.
    #[must_use]
    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Overrides the base URL that requests are sent to.
    ///
    /// # Errors
    ///
    /// Returns the error from [`Url::parse`] if `url` cannot be parsed.
    pub fn base_url(mut self, url: impl AsRef<str>) -> Result<Self, url::ParseError> {
        self.base_url = Some(Url::parse(url.as_ref())?);
        Ok(self)
    }

    /// Sets the retry strategy and maximum retry count for the built client.
    #[must_use]
    pub fn retry(mut self, strategy: RetryStrategy, max_retries: usize) -> Self {
        self.retry_strategy = strategy;
        self.max_retries = max_retries;
        self
    }

    /// Creates a builder whose API key is read from the environment variable
    /// named by `NEWS_API_KEY_ENV`.
    ///
    /// # Panics
    ///
    /// Panics if that variable cannot be read (unset or not valid Unicode).
    /// `BlockingNewsApiClientBuilder::new().build()` performs the same lookup
    /// but reports a missing key as an `Err` instead.
    pub fn from_env() -> Self {
        match env::var(NEWS_API_KEY_ENV) {
            Ok(api_key) => Self::new().api_key(api_key),
            Err(_) => panic!("{} is not set", NEWS_API_KEY_ENV),
        }
    }

    /// Builds a blocking [`NewsApiClient`].
    ///
    /// # Errors
    ///
    /// Returns a message in two cases:
    /// - no API key was set and the `NEWS_API_KEY_ENV` variable cannot be read;
    /// - the underlying `reqwest::blocking::Client` fails to build.
    ///
    /// # Panics
    ///
    /// Per reqwest's documentation, constructing a blocking client from
    /// within an async runtime panics.
    pub fn build(self) -> Result<NewsApiClient<reqwest::blocking::Client>, String> {
        let api_key = match self.api_key {
            Some(key) => key,
            None => env::var(NEWS_API_KEY_ENV).map_err(|_| {
                format!(
                    "API key must be provided either explicitly or via {} environment variable",
                    NEWS_API_KEY_ENV
                )
            })?,
        };

        let base_url = self
            .base_url
            .unwrap_or_else(|| Url::parse(NEWS_API_URI).unwrap());

        // `Client::new()` panics if the TLS backend or the resolver cannot be
        // initialised. The builder reports that as an error, which is what a
        // method returning `Result` should do.
        let client = reqwest::blocking::Client::builder()
            .build()
            .map_err(|e| format!("Failed to build HTTP client: {}", e))?;

        Ok(NewsApiClient {
            client,
            api_key,
            base_url,
            retry_strategy: self.retry_strategy,
            max_retries: self.max_retries,
        })
    }
}
184
#[cfg(feature = "blocking")]
mod blocking {
    use super::*;
    use reqwest::blocking::Client as BlockingClient;

    impl NewsApiClient<BlockingClient> {
        /// Creates a blocking client for `api_key`.
        ///
        /// Uses `NEWS_API_URI` as the base URL, the default retry strategy and
        /// `max_retries` set to 0.
        pub fn new_blocking(api_key: &str) -> Self {
            NewsApiClient {
                client: BlockingClient::new(),
                api_key: api_key.to_string(),
                base_url: Url::parse(NEWS_API_URI).unwrap(),
                retry_strategy: RetryStrategy::default(),
                max_retries: 0,
            }
        }

        /// Returns a builder for configuring a blocking client.
        pub fn builder_blocking() -> super::BlockingNewsApiClientBuilder {
            super::BlockingNewsApiClientBuilder::new()
        }

        /// Converts a non-success response body into an [`ApiClientError`].
        fn parse_error_response(&self, response_text: String, status_code: u16) -> ApiClientError {
            Self::parse_error_response_internal(response_text, status_code)
        }

        /// Sends one authenticated GET to `url` and decodes a successful JSON body as `R`.
        ///
        /// Non-success statuses go through [`Self::parse_error_response`]. A
        /// success body that fails to decode yields `ApiClientError::InvalidRequest`.
        fn send_get_blocking<R: serde::de::DeserializeOwned>(
            &self,
            url: &Url,
        ) -> Result<R, ApiClientError> {
            log::debug!("Request URL: {}", url.as_str());

            let headers = self.get_request_headers()?;
            let response = self.client.get(url.as_str()).headers(headers).send()?;
            let status = response.status();
            log::debug!("Response status: {:?}", status);

            // Both branches need the body text, so read it once.
            let response_text = response.text()?;
            if status.is_success() {
                serde_json::from_str::<R>(&response_text).map_err(|e| {
                    ApiClientError::InvalidRequest(format!("Failed to parse response: {}", e))
                })
            } else {
                Err(self.parse_error_response(response_text, status.as_u16()))
            }
        }

        /// Queries the "everything" endpoint with `request`.
        ///
        /// Retries via `retry_blocking` using this client's retry settings.
        /// Takes `&self` so a single client can issue any number of requests.
        pub fn get_everything(
            &self,
            request: &GetEverythingRequest,
        ) -> Result<GetEverythingResponse, ApiClientError> {
            retry_blocking(self.retry_strategy, self.max_retries, || {
                log::debug!("Request: {:?}", request);

                let mut url = self.base_url.clone();
                Self::get_endpoint_with_query_params_for_everything(&mut url, request);
                self.send_get_blocking::<GetEverythingResponse>(&url)
            })
        }

        /// Queries the top-headlines endpoint with `request`.
        ///
        /// # Errors
        ///
        /// Returns `ApiClientError::InvalidRequest` without sending anything
        /// if `sources` is combined with `country` or `category`.
        pub fn get_top_headlines(
            &self,
            request: &GetTopHeadlinesRequest,
        ) -> Result<TopHeadlinesResponse, ApiClientError> {
            // Validation is deterministic, so it runs once, before (and
            // independently of) the retry loop.
            Self::top_headlines_validate_request(request)?;

            retry_blocking(self.retry_strategy, self.max_retries, || {
                log::debug!("Request: {:?}", request);

                let mut url = self.base_url.clone();
                Self::get_endpoint_with_query_params_for_top_headlines(&mut url, request);
                self.send_get_blocking::<TopHeadlinesResponse>(&url)
            })
        }

        /// Queries the sources endpoint with `request`.
        ///
        /// Retries via `retry_blocking` using this client's retry settings.
        pub fn get_sources(
            &self,
            request: &GetSourcesRequest,
        ) -> Result<GetSourcesResponse, ApiClientError> {
            retry_blocking(self.retry_strategy, self.max_retries, || {
                log::debug!("Request: {:?}", request);

                let mut url = self.base_url.clone();
                Self::get_endpoint_with_query_params_for_sources(&mut url, request);
                self.send_get_blocking::<GetSourcesResponse>(&url)
            })
        }

        /// Returns this client with a new retry strategy and maximum retry count.
        pub fn with_retry(mut self, strategy: RetryStrategy, max_retries: usize) -> Self {
            self.retry_strategy = strategy;
            self.max_retries = max_retries;
            self
        }
    }
}
316
317impl NewsApiClient<reqwest::Client> {
318    pub fn new(api_key: &str) -> Self {
319        NewsApiClient {
320            client: reqwest::Client::new(),
321            api_key: api_key.to_string(),
322            base_url: Url::parse(NEWS_API_URI).unwrap(),
323            retry_strategy: RetryStrategy::default(),
324            max_retries: 0,
325        }
326    }
327
328    pub fn builder() -> NewsApiClientBuilder {
329        NewsApiClientBuilder::new()
330    }
331
332    pub fn from_env() -> Self {
333        match env::var(NEWS_API_KEY_ENV) {
334            Ok(api_key) => NewsApiClient::new(&api_key),
335            Err(_) => panic!("{} is not set", NEWS_API_KEY_ENV),
336        }
337    }
338
339    fn parse_error_response(&self, response_text: String, status_code: u16) -> ApiClientError {
340        NewsApiClient::<reqwest::Client>::parse_error_response_internal(response_text, status_code)
341    }
342
343    pub async fn get_everything(
344        &self,
345        request: &GetEverythingRequest,
346    ) -> Result<GetEverythingResponse, ApiClientError> {
347        retry(self.retry_strategy, self.max_retries, || async {
348            log::debug!("Request: {:?}", request);
349
350            let mut url = self.base_url.clone();
351            Self::get_endpoint_with_query_params_for_everything(&mut url, request);
352            log::debug!("Request URL: {}", url.as_str());
353
354            let headers = self.get_request_headers()?;
355            let response = self
356                .client
357                .get(url.as_str())
358                .headers(headers)
359                .send()
360                .await?;
361            let status = response.status();
362            log::debug!("Response status: {:?}", status);
363
364            if status.is_success() {
365                let response_text = response.text().await?;
366                match serde_json::from_str::<GetEverythingResponse>(&response_text) {
367                    Ok(everything_response) => Ok(everything_response),
368                    Err(e) => Err(ApiClientError::InvalidRequest(format!("{}", e))),
369                }
370            } else {
371                let response_text = response.text().await?;
372                Err(self.parse_error_response(response_text, status.as_u16()))
373            }
374        })
375        .await
376    }
377
378    pub async fn get_top_headlines(
379        &self,
380        request: &GetTopHeadlinesRequest,
381    ) -> Result<TopHeadlinesResponse, ApiClientError> {
382        retry(self.retry_strategy, self.max_retries, || async {
383            log::debug!("Request: {:?}", request);
384            Self::top_headlines_validate_request(request)?;
385
386            let mut url = self.base_url.clone();
387            Self::get_endpoint_with_query_params_for_top_headlines(&mut url, request);
388            log::debug!("Request URL: {}", url.as_str());
389
390            let headers = self.get_request_headers()?;
391            let response = self
392                .client
393                .get(url.as_str())
394                .headers(headers)
395                .send()
396                .await?;
397            let status = response.status();
398            log::debug!("Response status: {:?}", status);
399
400            if status.is_success() {
401                let response_text = response.text().await?;
402                match serde_json::from_str::<TopHeadlinesResponse>(&response_text) {
403                    Ok(headline_response) => Ok(headline_response),
404                    Err(e) => Err(ApiClientError::InvalidRequest(format!(
405                        "Failed to parse response: {}",
406                        e
407                    ))),
408                }
409            } else {
410                let response_text = response.text().await?;
411                Err(self.parse_error_response(response_text, status.as_u16()))
412            }
413        })
414        .await
415    }
416
417    pub async fn get_sources(
418        &self,
419        request: &GetSourcesRequest,
420    ) -> Result<GetSourcesResponse, ApiClientError> {
421        retry(self.retry_strategy, self.max_retries, || async {
422            log::debug!("Request: {:?}", request);
423
424            let mut url = self.base_url.clone();
425            Self::get_endpoint_with_query_params_for_sources(&mut url, request);
426            log::debug!("Request URL: {}", url.as_str());
427
428            let headers = self.get_request_headers()?;
429            let response = self
430                .client
431                .get(url.as_str())
432                .headers(headers)
433                .send()
434                .await?;
435            let status = response.status();
436            log::debug!("Response status: {:?}", status);
437
438            if status.is_success() {
439                let response_text = response.text().await?;
440                match serde_json::from_str::<GetSourcesResponse>(&response_text) {
441                    Ok(sources_response) => Ok(sources_response),
442                    Err(e) => Err(ApiClientError::InvalidRequest(format!("{}", e))),
443                }
444            } else {
445                let response_text = response.text().await?;
446                Err(self.parse_error_response(response_text, status.as_u16()))
447            }
448        })
449        .await
450    }
451
452    pub fn with_retry(mut self, strategy: RetryStrategy, max_retries: usize) -> Self {
453        self.retry_strategy = strategy;
454        self.max_retries = max_retries;
455        self
456    }
457}
458
#[cfg(feature = "blocking")]
impl NewsApiClient<reqwest::blocking::Client> {
    /// Creates a blocking client whose API key comes from the environment
    /// variable named by `NEWS_API_KEY_ENV`.
    ///
    /// # Panics
    ///
    /// Panics when that variable cannot be read.
    pub fn from_env_blocking() -> Self {
        let api_key =
            env::var(NEWS_API_KEY_ENV).unwrap_or_else(|_| panic!("{} is not set", NEWS_API_KEY_ENV));
        Self::new_blocking(&api_key)
    }
}
468
469impl<T> NewsApiClient<T> {
470    fn parse_error_response_internal(response_text: String, status_code: u16) -> ApiClientError {
471        match serde_json::from_str::<NewsApiErrorResponse>(&response_text) {
472            Ok(error_response) => {
473                let error_code = match error_response.code.as_deref() {
474                    Some("apiKeyDisabled") => ApiClientErrorCode::ApiKeyDisabled,
475                    Some("apiKeyExhausted") => ApiClientErrorCode::ApiKeyExhausted,
476                    Some("apiKeyInvalid") => ApiClientErrorCode::ApiKeyInvalid,
477                    Some("apiKeyMissing") => ApiClientErrorCode::ApiKeyMissing,
478                    Some("parameterInvalid") => ApiClientErrorCode::ParameterInvalid,
479                    Some("parametersMissing") => ApiClientErrorCode::ParametersMissing,
480                    Some("rateLimited") => ApiClientErrorCode::RateLimited,
481                    Some("sourcesTooMany") => ApiClientErrorCode::SourcesTooMany,
482                    Some("sourceDoesNotExist") => ApiClientErrorCode::SourceDoesNotExist,
483                    _ => {
484                        // Check for rate limiting based on status code
485                        if status_code == 429 {
486                            ApiClientErrorCode::RateLimited
487                        } else {
488                            ApiClientErrorCode::UnexpectedError
489                        }
490                    }
491                };
492
493                ApiClientError::InvalidResponse(ApiClientErrorResponse {
494                    status: error_response.status,
495                    code: error_code,
496                    message: error_response
497                        .message
498                        .unwrap_or_else(|| "Unknown error".to_string()),
499                })
500            }
501            Err(_) => {
502                let error_code = if status_code == 429 {
503                    ApiClientErrorCode::RateLimited
504                } else {
505                    ApiClientErrorCode::UnexpectedError
506                };
507
508                ApiClientError::InvalidResponse(ApiClientErrorResponse {
509                    status: "error".to_string(),
510                    code: error_code,
511                    message: if response_text.contains("too many requests")
512                        || response_text.contains("rate limit")
513                    {
514                        "You have made too many requests. Rate limit exceeded.".to_string()
515                    } else {
516                        "Failed to parse error response".to_string()
517                    },
518                })
519            }
520        }
521    }
522
523    fn get_request_headers(&self) -> Result<HeaderMap, ApiClientError> {
524        let mut headers = HeaderMap::new();
525        headers.insert(
526            AUTHORIZATION,
527            HeaderValue::from_str(&format!("Bearer {}", self.api_key))?,
528        );
529        headers.insert(
530            USER_AGENT,
531            HeaderValue::from_static(NEWS_API_CLIENT_USER_AGENT),
532        );
533        Ok(headers)
534    }
535
536    fn top_headlines_validate_request(
537        request: &GetTopHeadlinesRequest,
538    ) -> Result<(), ApiClientError> {
539        log::debug!("Validating request");
540        if request.get_sources().is_some()
541            && (request.get_country().is_some() || request.get_category().is_some())
542        {
543            return Err(ApiClientError::InvalidRequest(
544                "Cannot specify sources with country or category".to_string(),
545            ));
546        }
547        Ok(())
548    }
549
550    fn get_endpoint_with_query_params_for_top_headlines(
551        url: &mut Url,
552        request: &GetTopHeadlinesRequest,
553    ) {
554        url.set_path(TOP_HEADLINES_ENDPOINT);
555        url.query_pairs_mut().clear();
556
557        for (key, value) in Self::get_top_headlines_query_params(request) {
558            url.query_pairs_mut().append_pair(&key, &value);
559        }
560
561        url.query_pairs_mut().finish();
562    }
563
564    fn get_top_headlines_query_params(request: &GetTopHeadlinesRequest) -> Vec<(String, String)> {
565        let mut query_params = Vec::new();
566
567        if let Some(country) = request.get_country() {
568            query_params.push(("country".to_string(), country.to_string()));
569        }
570
571        if let Some(category) = request.get_category() {
572            query_params.push(("category".to_string(), category.to_string()));
573        }
574
575        if let Some(sources) = request.get_sources() {
576            query_params.push(("sources".to_string(), sources.to_string()));
577        }
578
579        if !request.get_search_term().is_empty() {
580            query_params.push(("q".to_string(), request.get_search_term().to_string()));
581        }
582
583        if *request.get_page_size() > 1 {
584            query_params.push(("pageSize".to_string(), request.get_page_size().to_string()));
585        }
586
587        if *request.get_page() > 1 {
588            query_params.push(("page".to_string(), request.get_page().to_string()));
589        }
590
591        query_params
592    }
593
594    fn get_endpoint_with_query_params_for_everything(
595        url: &mut Url,
596        request: &GetEverythingRequest,
597    ) {
598        url.set_path(EVERYTHING_ENDPOINT);
599        url.query_pairs_mut().clear();
600
601        let query_params = Self::get_everything_query_params(request);
602        for (key, value) in query_params {
603            url.query_pairs_mut().append_pair(&key, &value);
604        }
605
606        url.query_pairs_mut().finish();
607    }
608
609    fn get_everything_query_params(request: &GetEverythingRequest) -> Vec<(String, String)> {
610        let mut query_params = Vec::new();
611
612        query_params.push(("q".to_string(), request.get_search_term().to_string()));
613
614        if let Some(language) = request.get_language() {
615            query_params.push(("language".to_string(), language.to_string().to_lowercase()));
616        }
617
618        if let Some(start_date) = request.get_start_date() {
619            query_params.push(("from".to_string(), start_date.to_rfc3339()));
620        }
621
622        if let Some(end_date) = request.get_end_date() {
623            query_params.push(("to".to_string(), end_date.to_rfc3339()));
624        }
625
626        if *request.get_page_size() > 0 {
627            query_params.push(("pageSize".to_string(), request.get_page_size().to_string()));
628        }
629
630        if *request.get_page() > 1 {
631            query_params.push(("page".to_string(), request.get_page().to_string()));
632        }
633
634        query_params
635    }
636
637    fn get_endpoint_with_query_params_for_sources(url: &mut Url, request: &GetSourcesRequest) {
638        url.set_path(SOURCES_ENDPOINT);
639        url.query_pairs_mut().clear();
640
641        let query_params = Self::get_sources_query_params(request);
642        for (key, value) in query_params {
643            url.query_pairs_mut().append_pair(&key, &value);
644        }
645
646        url.query_pairs_mut().finish();
647    }
648
649    fn get_sources_query_params(request: &GetSourcesRequest) -> Vec<(String, String)> {
650        let mut query_params = Vec::new();
651
652        if let Some(category) = request.get_category() {
653            query_params.push(("category".to_string(), category.to_string()));
654        }
655
656        if let Some(language) = request.get_language() {
657            query_params.push(("language".to_string(), language.to_string().to_lowercase()));
658        }
659
660        if let Some(country) = request.get_country() {
661            query_params.push(("country".to_string(), country.to_string()));
662        }
663
664        query_params
665    }
666}
667
668#[cfg(test)]
669mod tests {
670    use super::*;
671    use crate::model::{Country, Language, NewsCategory};
672    use chrono::{DateTime, Utc};
673    use mockito;
674    use std::collections::HashMap;
675    use std::str::FromStr;
676    use std::time::Duration;
677
678    fn create_test_client() -> NewsApiClient<reqwest::Client> {
679        let api_key = "test-api-key";
680        let mut client = NewsApiClient::new(api_key);
681        let server = mockito::Server::new();
682        let mock_url = server.url();
683        client.base_url = Url::parse(&format!("http://{}", mock_url)).unwrap();
684        client
685    }
686
687    #[test]
688    fn test_parse_error_response() {
689        let error_json =
690            r#"{"status":"error","code":"apiKeyInvalid","message":"Your API key is invalid"}"#;
691        let error = NewsApiClient::<reqwest::Client>::parse_error_response_internal(
692            error_json.to_string(),
693            400,
694        );
695
696        match error {
697            ApiClientError::InvalidResponse(response) => {
698                assert_eq!(response.status, "error");
699                assert_eq!(response.code, ApiClientErrorCode::ApiKeyInvalid);
700                assert_eq!(response.message, "Your API key is invalid");
701            }
702            _ => panic!("Expected InvalidResponse error"),
703        }
704
705        let error_json =
706            r#"{"status":"error","code":"parameterInvalid","message":"Invalid parameter"}"#;
707        let error = NewsApiClient::<reqwest::Client>::parse_error_response_internal(
708            error_json.to_string(),
709            400,
710        );
711
712        match error {
713            ApiClientError::InvalidResponse(response) => {
714                assert_eq!(response.code, ApiClientErrorCode::ParameterInvalid);
715            }
716            _ => panic!("Expected InvalidResponse error"),
717        }
718
719        let error_json = r#"invalid json"#;
720        let error = NewsApiClient::<reqwest::Client>::parse_error_response_internal(
721            error_json.to_string(),
722            400,
723        );
724
725        match error {
726            ApiClientError::InvalidResponse(response) => {
727                assert_eq!(response.code, ApiClientErrorCode::UnexpectedError);
728            }
729            _ => panic!("Expected InvalidResponse error"),
730        }
731    }
732
733    #[test]
734    fn test_get_request_headers() {
735        let client = create_test_client();
736        let headers = client.get_request_headers().unwrap();
737
738        assert_eq!(
739            headers.get(AUTHORIZATION).unwrap().to_str().unwrap(),
740            "Bearer test-api-key"
741        );
742        assert_eq!(
743            headers.get(USER_AGENT).unwrap().to_str().unwrap(),
744            NEWS_API_CLIENT_USER_AGENT
745        );
746    }
747
748    #[test]
749    fn test_top_headlines_validate_request_country_and_category() {
750        let request = GetTopHeadlinesRequest::builder()
751            .country(Country::US)
752            .category(NewsCategory::Business)
753            .search_term(String::new())
754            .page_size(20)
755            .page(1)
756            .build()
757            .unwrap();
758        assert!(NewsApiClient::<reqwest::Client>::top_headlines_validate_request(&request).is_ok());
759    }
760
761    #[test]
762    fn test_top_headlines_validate_request_sources_only() {
763        let request = GetTopHeadlinesRequest::builder()
764            .sources("bbc-news,cnn".to_string())
765            .search_term(String::new())
766            .page_size(20)
767            .page(1)
768            .build()
769            .unwrap();
770        assert!(NewsApiClient::<reqwest::Client>::top_headlines_validate_request(&request).is_ok());
771    }
772
773    #[test]
774    fn test_top_headlines_validate_request_sources_with_country() {
775        let request = GetTopHeadlinesRequest::builder()
776            .sources("bbc-news".to_string())
777            .country(Country::US)
778            .search_term(String::new())
779            .page_size(20)
780            .page(1)
781            .build();
782
783        assert!(request.is_err());
784    }
785
786    #[test]
787    fn test_top_headlines_validate_request_sources_with_category() {
788        let request = GetTopHeadlinesRequest::builder()
789            .sources("bbc-news".to_string())
790            .category(NewsCategory::Business)
791            .search_term(String::new())
792            .page_size(20)
793            .page(1)
794            .build();
795
796        assert!(request.is_err());
797    }
798
799    #[test]
800    fn test_get_top_headlines_query_params() {
801        let request = GetTopHeadlinesRequest::builder()
802            .country(Country::US)
803            .category(NewsCategory::Technology)
804            .search_term("ai".to_string())
805            .page_size(15)
806            .page(2)
807            .build()
808            .unwrap();
809
810        let params = NewsApiClient::<reqwest::Client>::get_top_headlines_query_params(&request);
811        let params_map: HashMap<_, _> = params.into_iter().collect();
812
813        assert_eq!(params_map.get("country").unwrap(), "us");
814        assert_eq!(params_map.get("category").unwrap(), "technology");
815        assert_eq!(params_map.get("q").unwrap(), "ai");
816        assert_eq!(params_map.get("page").unwrap(), "2");
817        assert_eq!(params_map.get("pageSize").unwrap(), "15");
818    }
819
820    #[test]
821    fn test_get_everything_query_params() {
822        let start_date = DateTime::<Utc>::from_str("2023-01-01T00:00:00Z").unwrap();
823        let end_date = DateTime::<Utc>::from_str("2023-01-31T23:59:59Z").unwrap();
824
825        let request = GetEverythingRequest::builder()
826            .search_term(format!("bitcoin"))
827            .language(Language::AR)
828            .start_date(start_date)
829            .end_date(end_date)
830            .page(3)
831            .page_size(20)
832            .build();
833
834        let params = NewsApiClient::<reqwest::Client>::get_everything_query_params(&request);
835        let params_map: HashMap<_, _> = params.into_iter().collect();
836
837        assert_eq!(params_map.get("q").unwrap(), "bitcoin");
838        assert_eq!(params_map.get("language").unwrap(), "ar"); // Fix expectation to "ar" instead of "en"
839        assert_eq!(params_map.get("from").unwrap(), "2023-01-01T00:00:00+00:00");
840        assert_eq!(params_map.get("to").unwrap(), "2023-01-31T23:59:59+00:00");
841        assert_eq!(params_map.get("page").unwrap(), "3");
842        assert_eq!(params_map.get("pageSize").unwrap(), "20");
843    }
844
845    #[tokio::test]
846    async fn test_get_everything_async() {
847        let mock_response = r#"{
848            "status": "ok",
849            "totalResults": 2,
850            "articles": [
851                {
852                    "source": {"id": "test-source", "name": "Test Source"},
853                    "author": "Test Author",
854                    "title": "Test Title",
855                    "description": "Test Description",
856                    "url": "https://example.com/article1",
857                    "urlToImage": "https://example.com/image1.jpg",
858                    "publishedAt": "2023-05-01T12:00:00Z",
859                    "content": "Test content"
860                },
861                {
862                    "source": {"id": "test-source-2", "name": "Test Source 2"},
863                    "author": "Test Author 2",
864                    "title": "Test Title 2",
865                    "description": "Test Description 2",
866                    "url": "https://example.com/article2",
867                    "urlToImage": "https://example.com/image2.jpg",
868                    "publishedAt": "2023-05-02T12:00:00Z",
869                    "content": "Test content 2"
870                }
871            ]
872        }"#;
873
874        let mut server = mockito::Server::new_async().await;
875
876        let _m = server
877            .mock("GET", "/v2/everything")
878            .match_query(mockito::Matcher::Any)
879            .with_status(200)
880            .with_header("content-type", "application/json")
881            .with_body(mock_response)
882            .create_async()
883            .await;
884
885        let mut client = NewsApiClient::new("test-api-key");
886        client.base_url = Url::parse(&format!("{}", server.url())).unwrap();
887
888        let request = GetEverythingRequest::builder()
889            .search_term(format!("test"))
890            .build();
891
892        let response = client.get_everything(&request).await.unwrap();
893
894        assert_eq!(response.get_status(), "ok");
895        assert_eq!(*response.get_total_results(), 2);
896        assert_eq!(response.get_articles().len(), 2);
897        assert_eq!(response.get_articles()[0].get_title(), "Test Title");
898        assert_eq!(response.get_articles()[1].get_title(), "Test Title 2");
899    }
900
901    #[tokio::test]
902    async fn test_get_top_headlines_async() {
903        let mock_response = r#"{
904            "status": "ok",
905            "totalResults": 1,
906            "articles": [
907                {
908                    "source": {"id": "test-source", "name": "Test Source"},
909                    "author": "Test Author",
910                    "title": "Breaking News",
911                    "description": "Test Description",
912                    "url": "https://example.com/article1",
913                    "urlToImage": "https://example.com/image1.jpg",
914                    "publishedAt": "2023-05-01T12:00:00Z",
915                    "content": "Test content"
916                }
917            ]
918        }"#;
919
920        let mut server = mockito::Server::new_async().await;
921        let _m = server
922            .mock("GET", "/v2/top-headlines")
923            .match_query(mockito::Matcher::Any)
924            .with_status(200)
925            .with_header("content-type", "application/json")
926            .with_body(mock_response)
927            .create_async()
928            .await;
929        let mut client = NewsApiClient::new("test-api-key");
930        client.base_url = Url::parse(&format!("{}", server.url())).unwrap();
931
932        let request = GetTopHeadlinesRequest::builder()
933            .country(Country::US)
934            .search_term(String::new())
935            .page_size(20)
936            .page(1)
937            .build()
938            .unwrap();
939
940        let response = client.get_top_headlines(&request).await.unwrap();
941
942        assert_eq!(response.get_status(), "ok");
943        assert_eq!(*response.get_total_results(), 1);
944        assert_eq!(response.get_articles().len(), 1);
945        assert_eq!(response.get_articles()[0].get_title(), "Breaking News");
946    }
947
948    #[tokio::test]
949    async fn test_error_responses_async() {
950        let error_response = r#"{
951            "status": "error",
952            "code": "apiKeyInvalid",
953            "message": "Your API key is invalid or incorrect"
954        }"#;
955
956        let mut server = mockito::Server::new_async().await;
957        let _m = server
958            .mock("GET", "/v2/everything")
959            .match_query(mockito::Matcher::Any)
960            .with_status(401)
961            .with_body(error_response)
962            .create_async()
963            .await;
964
965        let mut client = NewsApiClient::new("test-api-key");
966        client.base_url = Url::parse(&format!("{}", server.url())).unwrap();
967
968        let request = GetEverythingRequest::builder()
969            .search_term(format!("test"))
970            .build();
971
972        let result = client.get_everything(&request).await;
973        assert!(result.is_err());
974
975        match result.unwrap_err() {
976            ApiClientError::InvalidResponse(response) => {
977                assert_eq!(response.code, ApiClientErrorCode::ApiKeyInvalid);
978            }
979            _ => panic!("Expected InvalidResponse error"),
980        }
981    }
982
983    #[cfg(feature = "blocking")]
984    mod blocking_tests {
985        use super::*;
986        use mockito::Mock;
987
988        #[test]
989        fn test_get_everything_blocking() {
990            let mock_response = r#"{
991                "status": "ok",
992                "totalResults": 1,
993                "articles": [
994                    {
995                        "source": {"id": "test-source", "name": "Test Source"},
996                        "author": "Test Author",
997                        "title": "Test Title Blocking",
998                        "description": "Test Description",
999                        "url": "https://example.com/article1",
1000                        "urlToImage": "https://example.com/image1.jpg",
1001                        "publishedAt": "2023-05-01T12:00:00Z",
1002                        "content": "Test content"
1003                    }
1004                ]
1005            }"#;
1006            let mut server = mockito::Server::new();
1007            let _m: Mock = server
1008                .mock("GET", "/v2/everything")
1009                .match_query(mockito::Matcher::Any)
1010                .with_status(200)
1011                .with_header("content-type", "application/json")
1012                .with_body(mock_response)
1013                .create();
1014
1015            let mut client = NewsApiClient::new_blocking("test-api-key");
1016            client.base_url = Url::parse(&format!("{}", server.url())).unwrap();
1017            let request = GetEverythingRequest::builder()
1018                .search_term("test".to_string())
1019                .build();
1020            let response = client.get_everything(&request).unwrap();
1021
1022            assert_eq!(response.get_status(), "ok");
1023            assert_eq!(*response.get_total_results(), 1);
1024            assert_eq!(
1025                response.get_articles()[0].get_title(),
1026                "Test Title Blocking"
1027            );
1028        }
1029    }
1030
1031    #[test]
1032    fn test_builder_pattern() {
1033        let client = NewsApiClient::<reqwest::Client>::builder()
1034            .api_key("test-api-key")
1035            .retry(RetryStrategy::Exponential(Duration::from_millis(100)), 3)
1036            .build()
1037            .unwrap();
1038
1039        assert_eq!(client.api_key, "test-api-key");
1040        assert_eq!(client.max_retries, 3);
1041    }
1042
1043    #[test]
1044    fn test_builder_failure() {
1045        let old_value = std::env::var(NEWS_API_KEY_ENV).ok();
1046        std::env::remove_var(NEWS_API_KEY_ENV);
1047        let result = NewsApiClient::builder().build();
1048
1049        if let Some(val) = old_value {
1050            std::env::set_var(NEWS_API_KEY_ENV, val);
1051        }
1052        assert!(result.is_err());
1053        assert_eq!(
1054            result.unwrap_err(),
1055            format!(
1056                "API key must be provided either explicitly or via {} environment variable",
1057                NEWS_API_KEY_ENV
1058            )
1059        );
1060    }
1061
1062    #[test]
1063    fn test_builder_from_env() {
1064        std::env::set_var(NEWS_API_KEY_ENV, "env-api-key");
1065
1066        let client = NewsApiClientBuilder::from_env().build().unwrap();
1067
1068        assert_eq!(client.api_key, "env-api-key");
1069        std::env::remove_var(NEWS_API_KEY_ENV);
1070    }
1071
1072    #[cfg(feature = "blocking")]
1073    #[test]
1074    fn test_blocking_builder_pattern() {
1075        let client = BlockingNewsApiClientBuilder::new()
1076            .api_key("test-api-key")
1077            .retry(RetryStrategy::Constant(Duration::from_secs(1)), 2)
1078            .build()
1079            .unwrap();
1080
1081        assert_eq!(client.api_key, "test-api-key");
1082        assert_eq!(client.max_retries, 2);
1083    }
1084}