1use crate::constant::{
2 EVERYTHING_ENDPOINT, NEWS_API_CLIENT_USER_AGENT, NEWS_API_KEY_ENV, NEWS_API_URI,
3 SOURCES_ENDPOINT, TOP_HEADLINES_ENDPOINT,
4};
5use crate::error::{ApiClientError, ApiClientErrorCode, ApiClientErrorResponse};
6use crate::model::{
7 GetEverythingRequest, GetEverythingResponse, GetSourcesRequest, GetSourcesResponse,
8 GetTopHeadlinesRequest, TopHeadlinesResponse,
9};
10#[cfg(feature = "blocking")]
11use crate::retry::retry_blocking;
12use crate::retry::{retry, RetryStrategy};
13use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
14use serde::{Deserialize, Serialize};
15use std::env;
16use url::Url;
17
/// Error payload returned by the News API in the body of a failed request.
///
/// Deserialized in `parse_error_response_internal`, which maps `code` onto
/// [`ApiClientErrorCode`].
#[derive(Debug, Deserialize, Serialize)]
struct NewsApiErrorResponse {
    /// Status string reported by the API (e.g. `"error"`).
    status: String,
    /// Machine-readable error code such as `"apiKeyInvalid"`, if present.
    code: Option<String>,
    /// Human-readable error description, if present.
    message: Option<String>,
}
24
25#[derive(Clone, Debug)]
26pub struct NewsApiClient<T> {
27 client: T,
28 api_key: String,
29 base_url: Url,
30 retry_strategy: RetryStrategy,
31 max_retries: usize,
32}
33
34pub struct NewsApiClientBuilder {
35 api_key: Option<String>,
36 base_url: Option<Url>,
37 retry_strategy: RetryStrategy,
38 max_retries: usize,
39}
40
41impl Default for NewsApiClientBuilder {
42 fn default() -> Self {
43 Self {
44 api_key: None,
45 base_url: Some(Url::parse(NEWS_API_URI).unwrap()),
46 retry_strategy: RetryStrategy::default(),
47 max_retries: 0,
48 }
49 }
50}
51
52impl NewsApiClientBuilder {
53 pub fn new() -> Self {
54 Self::default()
55 }
56
57 pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
58 self.api_key = Some(api_key.into());
59 self
60 }
61
62 pub fn base_url(mut self, url: impl AsRef<str>) -> Result<Self, url::ParseError> {
63 self.base_url = Some(Url::parse(url.as_ref())?);
64 Ok(self)
65 }
66
67 pub fn retry(mut self, strategy: RetryStrategy, max_retries: usize) -> Self {
68 self.retry_strategy = strategy;
69 self.max_retries = max_retries;
70 self
71 }
72
73 pub fn from_env() -> Self {
74 match env::var(NEWS_API_KEY_ENV) {
75 Ok(api_key) => Self::new().api_key(api_key),
76 Err(_) => panic!("{} is not set", NEWS_API_KEY_ENV),
77 }
78 }
79
80 pub fn build(self) -> Result<NewsApiClient<reqwest::Client>, String> {
81 let api_key = match self.api_key {
82 Some(key) => key,
83 None => match env::var(NEWS_API_KEY_ENV) {
84 Ok(key) => key,
85 Err(_) => {
86 return Err(format!(
87 "API key must be provided either explicitly or via {} environment variable",
88 NEWS_API_KEY_ENV
89 ))
90 }
91 },
92 };
93
94 let base_url = self
95 .base_url
96 .unwrap_or_else(|| Url::parse(NEWS_API_URI).unwrap());
97
98 Ok(NewsApiClient {
99 client: reqwest::Client::new(),
100 api_key,
101 base_url,
102 retry_strategy: self.retry_strategy,
103 max_retries: self.max_retries,
104 })
105 }
106}
107
/// Builder for a blocking [`NewsApiClient`] backed by `reqwest::blocking::Client`.
///
/// Mirrors `NewsApiClientBuilder`; finish with `build`.
#[cfg(feature = "blocking")]
pub struct BlockingNewsApiClientBuilder {
    /// Retry strategy copied into the built client.
    retry_strategy: RetryStrategy,
    /// Retry budget copied into the built client.
    max_retries: usize,
    /// Base URL override; `None` falls back to `NEWS_API_URI` at build time.
    base_url: Option<Url>,
    /// Explicit API key; `None` falls back to the `NEWS_API_KEY_ENV` variable at build time.
    api_key: Option<String>,
}
115
#[cfg(feature = "blocking")]
impl Default for BlockingNewsApiClientBuilder {
    /// Returns a builder with no API key, the default News API base URL,
    /// the default retry strategy and a `max_retries` of 0.
    fn default() -> Self {
        let default_url = Url::parse(NEWS_API_URI).unwrap();
        BlockingNewsApiClientBuilder {
            retry_strategy: RetryStrategy::default(),
            max_retries: 0,
            base_url: Some(default_url),
            api_key: None,
        }
    }
}
127
#[cfg(feature = "blocking")]
impl BlockingNewsApiClientBuilder {
    /// Creates a builder with the settings of its `Default` impl.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the API key explicitly; it takes precedence over the environment.
    pub fn api_key(self, api_key: impl Into<String>) -> Self {
        Self {
            api_key: Some(api_key.into()),
            ..self
        }
    }

    /// Overrides the base URL.
    ///
    /// # Errors
    /// Returns the parse error when `url` is not a valid absolute URL.
    pub fn base_url(self, url: impl AsRef<str>) -> Result<Self, url::ParseError> {
        let parsed = Url::parse(url.as_ref())?;
        Ok(Self {
            base_url: Some(parsed),
            ..self
        })
    }

    /// Sets the retry strategy and retry budget for the built client.
    pub fn retry(self, strategy: RetryStrategy, max_retries: usize) -> Self {
        Self {
            retry_strategy: strategy,
            max_retries,
            ..self
        }
    }

    /// Creates a builder whose API key is read from the `NEWS_API_KEY_ENV` variable.
    ///
    /// # Panics
    /// Panics when the variable is unset or not valid Unicode.
    pub fn from_env() -> Self {
        env::var(NEWS_API_KEY_ENV)
            .map(|key| Self::new().api_key(key))
            .unwrap_or_else(|_| panic!("{} is not set", NEWS_API_KEY_ENV))
    }

    /// Consumes the builder and creates a blocking client.
    ///
    /// The API key comes from `api_key` when set, otherwise from the
    /// `NEWS_API_KEY_ENV` variable. The base URL defaults to `NEWS_API_URI`.
    ///
    /// # Errors
    /// Returns a descriptive message when no API key can be found.
    pub fn build(self) -> Result<NewsApiClient<reqwest::blocking::Client>, String> {
        let api_key = self
            .api_key
            .or_else(|| env::var(NEWS_API_KEY_ENV).ok())
            .ok_or_else(|| {
                format!(
                    "API key must be provided either explicitly or via {} environment variable",
                    NEWS_API_KEY_ENV
                )
            })?;

        let base_url = match self.base_url {
            Some(url) => url,
            None => Url::parse(NEWS_API_URI).unwrap(),
        };

        Ok(NewsApiClient {
            client: reqwest::blocking::Client::new(),
            api_key,
            base_url,
            retry_strategy: self.retry_strategy,
            max_retries: self.max_retries,
        })
    }
}
184
#[cfg(feature = "blocking")]
mod blocking {
    use super::*;
    use reqwest::blocking::Client as BlockingClient;

    impl NewsApiClient<BlockingClient> {
        /// Creates a blocking client with the given API key, the default News API
        /// base URL, the default retry strategy and a `max_retries` of 0.
        pub fn new_blocking(api_key: &str) -> Self {
            NewsApiClient {
                client: BlockingClient::new(),
                api_key: api_key.to_string(),
                base_url: Url::parse(NEWS_API_URI).unwrap(),
                retry_strategy: RetryStrategy::default(),
                max_retries: 0,
            }
        }

        /// Returns a builder for a blocking client.
        pub fn builder_blocking() -> super::BlockingNewsApiClientBuilder {
            super::BlockingNewsApiClientBuilder::new()
        }

        /// Converts a non-success response body into an [`ApiClientError`].
        fn parse_error_response(&self, response_text: String, status_code: u16) -> ApiClientError {
            Self::parse_error_response_internal(response_text, status_code)
        }

        /// Sends one authenticated GET request to `url` and decodes the body.
        ///
        /// A 2xx body is deserialized as `R`; a deserialization failure becomes
        /// `ApiClientError::InvalidRequest`. Any other status is turned into an error
        /// by `parse_error_response`.
        fn send_and_decode<R: serde::de::DeserializeOwned>(
            &self,
            url: Url,
        ) -> Result<R, ApiClientError> {
            log::debug!("Request URL: {}", url.as_str());

            let headers = self.get_request_headers()?;
            let response = self.client.get(url.as_str()).headers(headers).send()?;
            let status = response.status();
            log::debug!("Response status: {:?}", status);

            let response_text = response.text()?;
            if status.is_success() {
                serde_json::from_str::<R>(&response_text).map_err(|e| {
                    ApiClientError::InvalidRequest(format!("Failed to parse response: {}", e))
                })
            } else {
                Err(self.parse_error_response(response_text, status.as_u16()))
            }
        }

        /// Searches all articles (`everything` endpoint), retrying per the client's strategy.
        ///
        /// Borrows the client, so it can be reused for further requests.
        pub fn get_everything(
            &self,
            request: &GetEverythingRequest,
        ) -> Result<GetEverythingResponse, ApiClientError> {
            retry_blocking(self.retry_strategy, self.max_retries, || {
                log::debug!("Request: {:?}", request);
                let mut url = self.base_url.clone();
                Self::get_endpoint_with_query_params_for_everything(&mut url, request);
                self.send_and_decode::<GetEverythingResponse>(url)
            })
        }

        /// Fetches top headlines, retrying per the client's strategy.
        ///
        /// The request is validated once up front: a client-side validation error is
        /// returned immediately instead of going through the retry loop.
        pub fn get_top_headlines(
            &self,
            request: &GetTopHeadlinesRequest,
        ) -> Result<TopHeadlinesResponse, ApiClientError> {
            Self::top_headlines_validate_request(request)?;

            retry_blocking(self.retry_strategy, self.max_retries, || {
                log::debug!("Request: {:?}", request);
                let mut url = self.base_url.clone();
                Self::get_endpoint_with_query_params_for_top_headlines(&mut url, request);
                self.send_and_decode::<TopHeadlinesResponse>(url)
            })
        }

        /// Lists news sources, retrying per the client's strategy.
        pub fn get_sources(
            &self,
            request: &GetSourcesRequest,
        ) -> Result<GetSourcesResponse, ApiClientError> {
            retry_blocking(self.retry_strategy, self.max_retries, || {
                log::debug!("Request: {:?}", request);
                let mut url = self.base_url.clone();
                Self::get_endpoint_with_query_params_for_sources(&mut url, request);
                self.send_and_decode::<GetSourcesResponse>(url)
            })
        }

        /// Returns the client with a new retry strategy and retry budget.
        pub fn with_retry(mut self, strategy: RetryStrategy, max_retries: usize) -> Self {
            self.retry_strategy = strategy;
            self.max_retries = max_retries;
            self
        }
    }
}
316
317impl NewsApiClient<reqwest::Client> {
318 pub fn new(api_key: &str) -> Self {
319 NewsApiClient {
320 client: reqwest::Client::new(),
321 api_key: api_key.to_string(),
322 base_url: Url::parse(NEWS_API_URI).unwrap(),
323 retry_strategy: RetryStrategy::default(),
324 max_retries: 0,
325 }
326 }
327
328 pub fn builder() -> NewsApiClientBuilder {
329 NewsApiClientBuilder::new()
330 }
331
332 pub fn from_env() -> Self {
333 match env::var(NEWS_API_KEY_ENV) {
334 Ok(api_key) => NewsApiClient::new(&api_key),
335 Err(_) => panic!("{} is not set", NEWS_API_KEY_ENV),
336 }
337 }
338
339 fn parse_error_response(&self, response_text: String, status_code: u16) -> ApiClientError {
340 NewsApiClient::<reqwest::Client>::parse_error_response_internal(response_text, status_code)
341 }
342
343 pub async fn get_everything(
344 &self,
345 request: &GetEverythingRequest,
346 ) -> Result<GetEverythingResponse, ApiClientError> {
347 retry(self.retry_strategy, self.max_retries, || async {
348 log::debug!("Request: {:?}", request);
349
350 let mut url = self.base_url.clone();
351 Self::get_endpoint_with_query_params_for_everything(&mut url, request);
352 log::debug!("Request URL: {}", url.as_str());
353
354 let headers = self.get_request_headers()?;
355 let response = self
356 .client
357 .get(url.as_str())
358 .headers(headers)
359 .send()
360 .await?;
361 let status = response.status();
362 log::debug!("Response status: {:?}", status);
363
364 if status.is_success() {
365 let response_text = response.text().await?;
366 match serde_json::from_str::<GetEverythingResponse>(&response_text) {
367 Ok(everything_response) => Ok(everything_response),
368 Err(e) => Err(ApiClientError::InvalidRequest(format!("{}", e))),
369 }
370 } else {
371 let response_text = response.text().await?;
372 Err(self.parse_error_response(response_text, status.as_u16()))
373 }
374 })
375 .await
376 }
377
378 pub async fn get_top_headlines(
379 &self,
380 request: &GetTopHeadlinesRequest,
381 ) -> Result<TopHeadlinesResponse, ApiClientError> {
382 retry(self.retry_strategy, self.max_retries, || async {
383 log::debug!("Request: {:?}", request);
384 Self::top_headlines_validate_request(request)?;
385
386 let mut url = self.base_url.clone();
387 Self::get_endpoint_with_query_params_for_top_headlines(&mut url, request);
388 log::debug!("Request URL: {}", url.as_str());
389
390 let headers = self.get_request_headers()?;
391 let response = self
392 .client
393 .get(url.as_str())
394 .headers(headers)
395 .send()
396 .await?;
397 let status = response.status();
398 log::debug!("Response status: {:?}", status);
399
400 if status.is_success() {
401 let response_text = response.text().await?;
402 match serde_json::from_str::<TopHeadlinesResponse>(&response_text) {
403 Ok(headline_response) => Ok(headline_response),
404 Err(e) => Err(ApiClientError::InvalidRequest(format!(
405 "Failed to parse response: {}",
406 e
407 ))),
408 }
409 } else {
410 let response_text = response.text().await?;
411 Err(self.parse_error_response(response_text, status.as_u16()))
412 }
413 })
414 .await
415 }
416
417 pub async fn get_sources(
418 &self,
419 request: &GetSourcesRequest,
420 ) -> Result<GetSourcesResponse, ApiClientError> {
421 retry(self.retry_strategy, self.max_retries, || async {
422 log::debug!("Request: {:?}", request);
423
424 let mut url = self.base_url.clone();
425 Self::get_endpoint_with_query_params_for_sources(&mut url, request);
426 log::debug!("Request URL: {}", url.as_str());
427
428 let headers = self.get_request_headers()?;
429 let response = self
430 .client
431 .get(url.as_str())
432 .headers(headers)
433 .send()
434 .await?;
435 let status = response.status();
436 log::debug!("Response status: {:?}", status);
437
438 if status.is_success() {
439 let response_text = response.text().await?;
440 match serde_json::from_str::<GetSourcesResponse>(&response_text) {
441 Ok(sources_response) => Ok(sources_response),
442 Err(e) => Err(ApiClientError::InvalidRequest(format!("{}", e))),
443 }
444 } else {
445 let response_text = response.text().await?;
446 Err(self.parse_error_response(response_text, status.as_u16()))
447 }
448 })
449 .await
450 }
451
452 pub fn with_retry(mut self, strategy: RetryStrategy, max_retries: usize) -> Self {
453 self.retry_strategy = strategy;
454 self.max_retries = max_retries;
455 self
456 }
457}
458
#[cfg(feature = "blocking")]
impl NewsApiClient<reqwest::blocking::Client> {
    /// Creates a blocking client using the API key from the `NEWS_API_KEY_ENV` variable.
    ///
    /// # Panics
    /// Panics when the variable is unset or not valid Unicode.
    pub fn from_env_blocking() -> Self {
        env::var(NEWS_API_KEY_ENV)
            .map(|api_key| Self::new_blocking(&api_key))
            .unwrap_or_else(|_| panic!("{} is not set", NEWS_API_KEY_ENV))
    }
}
468
469impl<T> NewsApiClient<T> {
470 fn parse_error_response_internal(response_text: String, status_code: u16) -> ApiClientError {
471 match serde_json::from_str::<NewsApiErrorResponse>(&response_text) {
472 Ok(error_response) => {
473 let error_code = match error_response.code.as_deref() {
474 Some("apiKeyDisabled") => ApiClientErrorCode::ApiKeyDisabled,
475 Some("apiKeyExhausted") => ApiClientErrorCode::ApiKeyExhausted,
476 Some("apiKeyInvalid") => ApiClientErrorCode::ApiKeyInvalid,
477 Some("apiKeyMissing") => ApiClientErrorCode::ApiKeyMissing,
478 Some("parameterInvalid") => ApiClientErrorCode::ParameterInvalid,
479 Some("parametersMissing") => ApiClientErrorCode::ParametersMissing,
480 Some("rateLimited") => ApiClientErrorCode::RateLimited,
481 Some("sourcesTooMany") => ApiClientErrorCode::SourcesTooMany,
482 Some("sourceDoesNotExist") => ApiClientErrorCode::SourceDoesNotExist,
483 _ => {
484 if status_code == 429 {
486 ApiClientErrorCode::RateLimited
487 } else {
488 ApiClientErrorCode::UnexpectedError
489 }
490 }
491 };
492
493 ApiClientError::InvalidResponse(ApiClientErrorResponse {
494 status: error_response.status,
495 code: error_code,
496 message: error_response
497 .message
498 .unwrap_or_else(|| "Unknown error".to_string()),
499 })
500 }
501 Err(_) => {
502 let error_code = if status_code == 429 {
503 ApiClientErrorCode::RateLimited
504 } else {
505 ApiClientErrorCode::UnexpectedError
506 };
507
508 ApiClientError::InvalidResponse(ApiClientErrorResponse {
509 status: "error".to_string(),
510 code: error_code,
511 message: if response_text.contains("too many requests")
512 || response_text.contains("rate limit")
513 {
514 "You have made too many requests. Rate limit exceeded.".to_string()
515 } else {
516 "Failed to parse error response".to_string()
517 },
518 })
519 }
520 }
521 }
522
523 fn get_request_headers(&self) -> Result<HeaderMap, ApiClientError> {
524 let mut headers = HeaderMap::new();
525 headers.insert(
526 AUTHORIZATION,
527 HeaderValue::from_str(&format!("Bearer {}", self.api_key))?,
528 );
529 headers.insert(
530 USER_AGENT,
531 HeaderValue::from_static(NEWS_API_CLIENT_USER_AGENT),
532 );
533 Ok(headers)
534 }
535
536 fn top_headlines_validate_request(
537 request: &GetTopHeadlinesRequest,
538 ) -> Result<(), ApiClientError> {
539 log::debug!("Validating request");
540 if request.get_sources().is_some()
541 && (request.get_country().is_some() || request.get_category().is_some())
542 {
543 return Err(ApiClientError::InvalidRequest(
544 "Cannot specify sources with country or category".to_string(),
545 ));
546 }
547 Ok(())
548 }
549
550 fn get_endpoint_with_query_params_for_top_headlines(
551 url: &mut Url,
552 request: &GetTopHeadlinesRequest,
553 ) {
554 url.set_path(TOP_HEADLINES_ENDPOINT);
555 url.query_pairs_mut().clear();
556
557 for (key, value) in Self::get_top_headlines_query_params(request) {
558 url.query_pairs_mut().append_pair(&key, &value);
559 }
560
561 url.query_pairs_mut().finish();
562 }
563
564 fn get_top_headlines_query_params(request: &GetTopHeadlinesRequest) -> Vec<(String, String)> {
565 let mut query_params = Vec::new();
566
567 if let Some(country) = request.get_country() {
568 query_params.push(("country".to_string(), country.to_string()));
569 }
570
571 if let Some(category) = request.get_category() {
572 query_params.push(("category".to_string(), category.to_string()));
573 }
574
575 if let Some(sources) = request.get_sources() {
576 query_params.push(("sources".to_string(), sources.to_string()));
577 }
578
579 if !request.get_search_term().is_empty() {
580 query_params.push(("q".to_string(), request.get_search_term().to_string()));
581 }
582
583 if *request.get_page_size() > 1 {
584 query_params.push(("pageSize".to_string(), request.get_page_size().to_string()));
585 }
586
587 if *request.get_page() > 1 {
588 query_params.push(("page".to_string(), request.get_page().to_string()));
589 }
590
591 query_params
592 }
593
594 fn get_endpoint_with_query_params_for_everything(
595 url: &mut Url,
596 request: &GetEverythingRequest,
597 ) {
598 url.set_path(EVERYTHING_ENDPOINT);
599 url.query_pairs_mut().clear();
600
601 let query_params = Self::get_everything_query_params(request);
602 for (key, value) in query_params {
603 url.query_pairs_mut().append_pair(&key, &value);
604 }
605
606 url.query_pairs_mut().finish();
607 }
608
609 fn get_everything_query_params(request: &GetEverythingRequest) -> Vec<(String, String)> {
610 let mut query_params = Vec::new();
611
612 query_params.push(("q".to_string(), request.get_search_term().to_string()));
613
614 if let Some(language) = request.get_language() {
615 query_params.push(("language".to_string(), language.to_string().to_lowercase()));
616 }
617
618 if let Some(start_date) = request.get_start_date() {
619 query_params.push(("from".to_string(), start_date.to_rfc3339()));
620 }
621
622 if let Some(end_date) = request.get_end_date() {
623 query_params.push(("to".to_string(), end_date.to_rfc3339()));
624 }
625
626 if *request.get_page_size() > 0 {
627 query_params.push(("pageSize".to_string(), request.get_page_size().to_string()));
628 }
629
630 if *request.get_page() > 1 {
631 query_params.push(("page".to_string(), request.get_page().to_string()));
632 }
633
634 query_params
635 }
636
637 fn get_endpoint_with_query_params_for_sources(url: &mut Url, request: &GetSourcesRequest) {
638 url.set_path(SOURCES_ENDPOINT);
639 url.query_pairs_mut().clear();
640
641 let query_params = Self::get_sources_query_params(request);
642 for (key, value) in query_params {
643 url.query_pairs_mut().append_pair(&key, &value);
644 }
645
646 url.query_pairs_mut().finish();
647 }
648
649 fn get_sources_query_params(request: &GetSourcesRequest) -> Vec<(String, String)> {
650 let mut query_params = Vec::new();
651
652 if let Some(category) = request.get_category() {
653 query_params.push(("category".to_string(), category.to_string()));
654 }
655
656 if let Some(language) = request.get_language() {
657 query_params.push(("language".to_string(), language.to_string().to_lowercase()));
658 }
659
660 if let Some(country) = request.get_country() {
661 query_params.push(("country".to_string(), country.to_string()));
662 }
663
664 query_params
665 }
666}
667
668#[cfg(test)]
669mod tests {
670 use super::*;
671 use crate::model::{Country, Language, NewsCategory};
672 use chrono::{DateTime, Utc};
673 use mockito;
674 use std::collections::HashMap;
675 use std::str::FromStr;
676 use std::time::Duration;
677
678 fn create_test_client() -> NewsApiClient<reqwest::Client> {
679 let api_key = "test-api-key";
680 let mut client = NewsApiClient::new(api_key);
681 let server = mockito::Server::new();
682 let mock_url = server.url();
683 client.base_url = Url::parse(&format!("http://{}", mock_url)).unwrap();
684 client
685 }
686
687 #[test]
688 fn test_parse_error_response() {
689 let error_json =
690 r#"{"status":"error","code":"apiKeyInvalid","message":"Your API key is invalid"}"#;
691 let error = NewsApiClient::<reqwest::Client>::parse_error_response_internal(
692 error_json.to_string(),
693 400,
694 );
695
696 match error {
697 ApiClientError::InvalidResponse(response) => {
698 assert_eq!(response.status, "error");
699 assert_eq!(response.code, ApiClientErrorCode::ApiKeyInvalid);
700 assert_eq!(response.message, "Your API key is invalid");
701 }
702 _ => panic!("Expected InvalidResponse error"),
703 }
704
705 let error_json =
706 r#"{"status":"error","code":"parameterInvalid","message":"Invalid parameter"}"#;
707 let error = NewsApiClient::<reqwest::Client>::parse_error_response_internal(
708 error_json.to_string(),
709 400,
710 );
711
712 match error {
713 ApiClientError::InvalidResponse(response) => {
714 assert_eq!(response.code, ApiClientErrorCode::ParameterInvalid);
715 }
716 _ => panic!("Expected InvalidResponse error"),
717 }
718
719 let error_json = r#"invalid json"#;
720 let error = NewsApiClient::<reqwest::Client>::parse_error_response_internal(
721 error_json.to_string(),
722 400,
723 );
724
725 match error {
726 ApiClientError::InvalidResponse(response) => {
727 assert_eq!(response.code, ApiClientErrorCode::UnexpectedError);
728 }
729 _ => panic!("Expected InvalidResponse error"),
730 }
731 }
732
733 #[test]
734 fn test_get_request_headers() {
735 let client = create_test_client();
736 let headers = client.get_request_headers().unwrap();
737
738 assert_eq!(
739 headers.get(AUTHORIZATION).unwrap().to_str().unwrap(),
740 "Bearer test-api-key"
741 );
742 assert_eq!(
743 headers.get(USER_AGENT).unwrap().to_str().unwrap(),
744 NEWS_API_CLIENT_USER_AGENT
745 );
746 }
747
748 #[test]
749 fn test_top_headlines_validate_request_country_and_category() {
750 let request = GetTopHeadlinesRequest::builder()
751 .country(Country::US)
752 .category(NewsCategory::Business)
753 .search_term(String::new())
754 .page_size(20)
755 .page(1)
756 .build()
757 .unwrap();
758 assert!(NewsApiClient::<reqwest::Client>::top_headlines_validate_request(&request).is_ok());
759 }
760
761 #[test]
762 fn test_top_headlines_validate_request_sources_only() {
763 let request = GetTopHeadlinesRequest::builder()
764 .sources("bbc-news,cnn".to_string())
765 .search_term(String::new())
766 .page_size(20)
767 .page(1)
768 .build()
769 .unwrap();
770 assert!(NewsApiClient::<reqwest::Client>::top_headlines_validate_request(&request).is_ok());
771 }
772
773 #[test]
774 fn test_top_headlines_validate_request_sources_with_country() {
775 let request = GetTopHeadlinesRequest::builder()
776 .sources("bbc-news".to_string())
777 .country(Country::US)
778 .search_term(String::new())
779 .page_size(20)
780 .page(1)
781 .build();
782
783 assert!(request.is_err());
784 }
785
786 #[test]
787 fn test_top_headlines_validate_request_sources_with_category() {
788 let request = GetTopHeadlinesRequest::builder()
789 .sources("bbc-news".to_string())
790 .category(NewsCategory::Business)
791 .search_term(String::new())
792 .page_size(20)
793 .page(1)
794 .build();
795
796 assert!(request.is_err());
797 }
798
799 #[test]
800 fn test_get_top_headlines_query_params() {
801 let request = GetTopHeadlinesRequest::builder()
802 .country(Country::US)
803 .category(NewsCategory::Technology)
804 .search_term("ai".to_string())
805 .page_size(15)
806 .page(2)
807 .build()
808 .unwrap();
809
810 let params = NewsApiClient::<reqwest::Client>::get_top_headlines_query_params(&request);
811 let params_map: HashMap<_, _> = params.into_iter().collect();
812
813 assert_eq!(params_map.get("country").unwrap(), "us");
814 assert_eq!(params_map.get("category").unwrap(), "technology");
815 assert_eq!(params_map.get("q").unwrap(), "ai");
816 assert_eq!(params_map.get("page").unwrap(), "2");
817 assert_eq!(params_map.get("pageSize").unwrap(), "15");
818 }
819
820 #[test]
821 fn test_get_everything_query_params() {
822 let start_date = DateTime::<Utc>::from_str("2023-01-01T00:00:00Z").unwrap();
823 let end_date = DateTime::<Utc>::from_str("2023-01-31T23:59:59Z").unwrap();
824
825 let request = GetEverythingRequest::builder()
826 .search_term(format!("bitcoin"))
827 .language(Language::AR)
828 .start_date(start_date)
829 .end_date(end_date)
830 .page(3)
831 .page_size(20)
832 .build();
833
834 let params = NewsApiClient::<reqwest::Client>::get_everything_query_params(&request);
835 let params_map: HashMap<_, _> = params.into_iter().collect();
836
837 assert_eq!(params_map.get("q").unwrap(), "bitcoin");
838 assert_eq!(params_map.get("language").unwrap(), "ar"); assert_eq!(params_map.get("from").unwrap(), "2023-01-01T00:00:00+00:00");
840 assert_eq!(params_map.get("to").unwrap(), "2023-01-31T23:59:59+00:00");
841 assert_eq!(params_map.get("page").unwrap(), "3");
842 assert_eq!(params_map.get("pageSize").unwrap(), "20");
843 }
844
845 #[tokio::test]
846 async fn test_get_everything_async() {
847 let mock_response = r#"{
848 "status": "ok",
849 "totalResults": 2,
850 "articles": [
851 {
852 "source": {"id": "test-source", "name": "Test Source"},
853 "author": "Test Author",
854 "title": "Test Title",
855 "description": "Test Description",
856 "url": "https://example.com/article1",
857 "urlToImage": "https://example.com/image1.jpg",
858 "publishedAt": "2023-05-01T12:00:00Z",
859 "content": "Test content"
860 },
861 {
862 "source": {"id": "test-source-2", "name": "Test Source 2"},
863 "author": "Test Author 2",
864 "title": "Test Title 2",
865 "description": "Test Description 2",
866 "url": "https://example.com/article2",
867 "urlToImage": "https://example.com/image2.jpg",
868 "publishedAt": "2023-05-02T12:00:00Z",
869 "content": "Test content 2"
870 }
871 ]
872 }"#;
873
874 let mut server = mockito::Server::new_async().await;
875
876 let _m = server
877 .mock("GET", "/v2/everything")
878 .match_query(mockito::Matcher::Any)
879 .with_status(200)
880 .with_header("content-type", "application/json")
881 .with_body(mock_response)
882 .create_async()
883 .await;
884
885 let mut client = NewsApiClient::new("test-api-key");
886 client.base_url = Url::parse(&format!("{}", server.url())).unwrap();
887
888 let request = GetEverythingRequest::builder()
889 .search_term(format!("test"))
890 .build();
891
892 let response = client.get_everything(&request).await.unwrap();
893
894 assert_eq!(response.get_status(), "ok");
895 assert_eq!(*response.get_total_results(), 2);
896 assert_eq!(response.get_articles().len(), 2);
897 assert_eq!(response.get_articles()[0].get_title(), "Test Title");
898 assert_eq!(response.get_articles()[1].get_title(), "Test Title 2");
899 }
900
901 #[tokio::test]
902 async fn test_get_top_headlines_async() {
903 let mock_response = r#"{
904 "status": "ok",
905 "totalResults": 1,
906 "articles": [
907 {
908 "source": {"id": "test-source", "name": "Test Source"},
909 "author": "Test Author",
910 "title": "Breaking News",
911 "description": "Test Description",
912 "url": "https://example.com/article1",
913 "urlToImage": "https://example.com/image1.jpg",
914 "publishedAt": "2023-05-01T12:00:00Z",
915 "content": "Test content"
916 }
917 ]
918 }"#;
919
920 let mut server = mockito::Server::new_async().await;
921 let _m = server
922 .mock("GET", "/v2/top-headlines")
923 .match_query(mockito::Matcher::Any)
924 .with_status(200)
925 .with_header("content-type", "application/json")
926 .with_body(mock_response)
927 .create_async()
928 .await;
929 let mut client = NewsApiClient::new("test-api-key");
930 client.base_url = Url::parse(&format!("{}", server.url())).unwrap();
931
932 let request = GetTopHeadlinesRequest::builder()
933 .country(Country::US)
934 .search_term(String::new())
935 .page_size(20)
936 .page(1)
937 .build()
938 .unwrap();
939
940 let response = client.get_top_headlines(&request).await.unwrap();
941
942 assert_eq!(response.get_status(), "ok");
943 assert_eq!(*response.get_total_results(), 1);
944 assert_eq!(response.get_articles().len(), 1);
945 assert_eq!(response.get_articles()[0].get_title(), "Breaking News");
946 }
947
948 #[tokio::test]
949 async fn test_error_responses_async() {
950 let error_response = r#"{
951 "status": "error",
952 "code": "apiKeyInvalid",
953 "message": "Your API key is invalid or incorrect"
954 }"#;
955
956 let mut server = mockito::Server::new_async().await;
957 let _m = server
958 .mock("GET", "/v2/everything")
959 .match_query(mockito::Matcher::Any)
960 .with_status(401)
961 .with_body(error_response)
962 .create_async()
963 .await;
964
965 let mut client = NewsApiClient::new("test-api-key");
966 client.base_url = Url::parse(&format!("{}", server.url())).unwrap();
967
968 let request = GetEverythingRequest::builder()
969 .search_term(format!("test"))
970 .build();
971
972 let result = client.get_everything(&request).await;
973 assert!(result.is_err());
974
975 match result.unwrap_err() {
976 ApiClientError::InvalidResponse(response) => {
977 assert_eq!(response.code, ApiClientErrorCode::ApiKeyInvalid);
978 }
979 _ => panic!("Expected InvalidResponse error"),
980 }
981 }
982
983 #[cfg(feature = "blocking")]
984 mod blocking_tests {
985 use super::*;
986 use mockito::Mock;
987
988 #[test]
989 fn test_get_everything_blocking() {
990 let mock_response = r#"{
991 "status": "ok",
992 "totalResults": 1,
993 "articles": [
994 {
995 "source": {"id": "test-source", "name": "Test Source"},
996 "author": "Test Author",
997 "title": "Test Title Blocking",
998 "description": "Test Description",
999 "url": "https://example.com/article1",
1000 "urlToImage": "https://example.com/image1.jpg",
1001 "publishedAt": "2023-05-01T12:00:00Z",
1002 "content": "Test content"
1003 }
1004 ]
1005 }"#;
1006 let mut server = mockito::Server::new();
1007 let _m: Mock = server
1008 .mock("GET", "/v2/everything")
1009 .match_query(mockito::Matcher::Any)
1010 .with_status(200)
1011 .with_header("content-type", "application/json")
1012 .with_body(mock_response)
1013 .create();
1014
1015 let mut client = NewsApiClient::new_blocking("test-api-key");
1016 client.base_url = Url::parse(&format!("{}", server.url())).unwrap();
1017 let request = GetEverythingRequest::builder()
1018 .search_term("test".to_string())
1019 .build();
1020 let response = client.get_everything(&request).unwrap();
1021
1022 assert_eq!(response.get_status(), "ok");
1023 assert_eq!(*response.get_total_results(), 1);
1024 assert_eq!(
1025 response.get_articles()[0].get_title(),
1026 "Test Title Blocking"
1027 );
1028 }
1029 }
1030
1031 #[test]
1032 fn test_builder_pattern() {
1033 let client = NewsApiClient::<reqwest::Client>::builder()
1034 .api_key("test-api-key")
1035 .retry(RetryStrategy::Exponential(Duration::from_millis(100)), 3)
1036 .build()
1037 .unwrap();
1038
1039 assert_eq!(client.api_key, "test-api-key");
1040 assert_eq!(client.max_retries, 3);
1041 }
1042
1043 #[test]
1044 fn test_builder_failure() {
1045 let old_value = std::env::var(NEWS_API_KEY_ENV).ok();
1046 std::env::remove_var(NEWS_API_KEY_ENV);
1047 let result = NewsApiClient::builder().build();
1048
1049 if let Some(val) = old_value {
1050 std::env::set_var(NEWS_API_KEY_ENV, val);
1051 }
1052 assert!(result.is_err());
1053 assert_eq!(
1054 result.unwrap_err(),
1055 format!(
1056 "API key must be provided either explicitly or via {} environment variable",
1057 NEWS_API_KEY_ENV
1058 )
1059 );
1060 }
1061
1062 #[test]
1063 fn test_builder_from_env() {
1064 std::env::set_var(NEWS_API_KEY_ENV, "env-api-key");
1065
1066 let client = NewsApiClientBuilder::from_env().build().unwrap();
1067
1068 assert_eq!(client.api_key, "env-api-key");
1069 std::env::remove_var(NEWS_API_KEY_ENV);
1070 }
1071
1072 #[cfg(feature = "blocking")]
1073 #[test]
1074 fn test_blocking_builder_pattern() {
1075 let client = BlockingNewsApiClientBuilder::new()
1076 .api_key("test-api-key")
1077 .retry(RetryStrategy::Constant(Duration::from_secs(1)), 2)
1078 .build()
1079 .unwrap();
1080
1081 assert_eq!(client.api_key, "test-api-key");
1082 assert_eq!(client.max_retries, 2);
1083 }
1084}