1use std::collections::HashMap;
2
3use lazy_static::lazy_static;
4use reqwest::{header, IntoUrl};
5
6use crate::{error::Error, spec::EmbedResponse};
7
8lazy_static! {
9 static ref DEFAULT_CLIENT: reqwest::Client = reqwest::Client::new();
10}
11
/// Parameters of a single oEmbed consumer request.
///
/// Only `url` is always sent; the remaining fields are optional, so the usual
/// way to build a request is struct-update syntax on top of
/// `ConsumerRequest::default()`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConsumerRequest<'a> {
    /// Address of the resource to embed; sent as the `url` query parameter.
    pub url: &'a str,
    /// When set, sent as the `maxwidth` query parameter.
    pub max_width: Option<i32>,
    /// When set, sent as the `maxheight` query parameter.
    pub max_height: Option<i32>,
    /// Additional key/value pairs appended to the query string.
    pub params: Option<HashMap<&'a str, &'a str>>,
}
22
23#[derive(Clone)]
25pub struct Client(reqwest::Client);
26
27impl Client {
28 pub fn new(client: reqwest::Client) -> Self {
29 Self(client)
30 }
31
32 pub async fn fetch(
34 &self,
35 endpoint: impl IntoUrl,
36 request: ConsumerRequest<'_>,
37 ) -> Result<EmbedResponse, Error> {
38 let mut url = endpoint.into_url()?;
39
40 {
41 let mut query = url.query_pairs_mut();
42
43 query.append_pair("url", request.url);
44
45 if let Some(max_width) = request.max_width {
46 query.append_pair("maxwidth", &max_width.to_string());
47 }
48
49 if let Some(max_height) = request.max_height {
50 query.append_pair("maxheight", &max_height.to_string());
51 }
52
53 if let Some(params) = request.params {
54 for (key, value) in params {
55 query.append_pair(key, value);
56 }
57 }
58
59 query.finish();
60 }
61
62 Ok(self
63 .0
64 .get(url)
65 .header(header::USER_AGENT, "crates/oembed-rs")
66 .send()
67 .await?
68 .error_for_status()?
69 .json()
70 .await
71 .map(|mut response: EmbedResponse| {
72 response.extra.remove("type");
74 response
75 })?)
76 }
77}
78
79pub async fn fetch(
81 endpoint: impl IntoUrl,
82 request: ConsumerRequest<'_>,
83) -> Result<EmbedResponse, Error> {
84 Client::new(DEFAULT_CLIENT.clone())
85 .fetch(endpoint, request)
86 .await
87}
88
#[cfg(test)]
mod tests {
    use mockito::Server;

    use super::*;

    /// Resource URL requested by every test.
    const TARGET: &str = "https://example.com";

    /// Path and query the mock server expects for `TARGET`.
    const EXPECTED_PATH: &str = "/?url=https%3A%2F%2Fexample.com";

    /// Builds a request for `TARGET` with every optional field left unset.
    fn example_request() -> ConsumerRequest<'static> {
        ConsumerRequest {
            url: TARGET,
            ..ConsumerRequest::default()
        }
    }

    /// A 200 response with a minimal JSON body decodes into an `EmbedResponse`
    /// whose `extra` map is empty.
    #[tokio::test]
    async fn test_fetch_success() {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("GET", EXPECTED_PATH)
            .with_status(200)
            .with_body(r#"{"version": "1.0", "type": "link"}"#)
            .with_header("content-type", "application/json")
            .create_async()
            .await;

        let expected = EmbedResponse {
            oembed_type: crate::EmbedType::Link,
            version: "1.0".to_string(),
            title: None,
            author_name: None,
            author_url: None,
            provider_name: None,
            provider_url: None,
            cache_age: None,
            thumbnail_url: None,
            thumbnail_width: None,
            thumbnail_height: None,
            extra: HashMap::default(),
        };

        let fetched = fetch(server.url(), example_request()).await.ok();
        assert_eq!(fetched, Some(expected));

        mock.assert_async().await;
    }

    /// A 404 from the server surfaces as `Error::Reqwest` carrying that status.
    #[tokio::test]
    async fn test_fetch_error() {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("GET", EXPECTED_PATH)
            .with_status(404)
            .create_async()
            .await;

        let outcome = fetch(server.url(), example_request()).await;

        match outcome {
            Err(Error::Reqwest(err)) => {
                assert_eq!(err.status(), Some(reqwest::StatusCode::NOT_FOUND));
            }
            other => panic!("unexpected result: {:?}", other),
        }

        mock.assert_async().await;
    }
}