oembed_rs/
request.rs

1use std::collections::HashMap;
2
3use lazy_static::lazy_static;
4use reqwest::{header, IntoUrl};
5
6use crate::{error::Error, spec::EmbedResponse};
7
8lazy_static! {
9    static ref DEFAULT_CLIENT: reqwest::Client = reqwest::Client::new();
10}
11
/// Request for fetching oEmbed data
///
/// See the [oembed specification](https://oembed.com/#section2.2) for more information
///
/// `Debug`, `Clone` and `PartialEq` are derived so requests can be logged,
/// reused and compared. Every field is a borrow or a plain value, so these are cheap.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConsumerRequest<'a> {
    /// URL of the resource to fetch embed data for; sent as the `url` query parameter.
    pub url: &'a str,
    /// Optional maximum width, sent as the `maxwidth` query parameter when set.
    pub max_width: Option<i32>,
    /// Optional maximum height, sent as the `maxheight` query parameter when set.
    pub max_height: Option<i32>,
    /// Optional extra query parameters appended to the request verbatim
    /// (e.g. provider-specific options).
    pub params: Option<HashMap<&'a str, &'a str>>,
}
22
23/// oEmbed client
24#[derive(Clone)]
25pub struct Client(reqwest::Client);
26
27impl Client {
28    pub fn new(client: reqwest::Client) -> Self {
29        Self(client)
30    }
31
32    /// Fetch oEmbed data from the endpoint of a provider
33    pub async fn fetch(
34        &self,
35        endpoint: impl IntoUrl,
36        request: ConsumerRequest<'_>,
37    ) -> Result<EmbedResponse, Error> {
38        let mut url = endpoint.into_url()?;
39
40        {
41            let mut query = url.query_pairs_mut();
42
43            query.append_pair("url", request.url);
44
45            if let Some(max_width) = request.max_width {
46                query.append_pair("maxwidth", &max_width.to_string());
47            }
48
49            if let Some(max_height) = request.max_height {
50                query.append_pair("maxheight", &max_height.to_string());
51            }
52
53            if let Some(params) = request.params {
54                for (key, value) in params {
55                    query.append_pair(key, value);
56                }
57            }
58
59            query.finish();
60        }
61
62        Ok(self
63            .0
64            .get(url)
65            .header(header::USER_AGENT, "crates/oembed-rs")
66            .send()
67            .await?
68            .error_for_status()?
69            .json()
70            .await
71            .map(|mut response: EmbedResponse| {
72                // Remove the `type` field from the extra fields as we use #[serde(flatten)] twice
73                response.extra.remove("type");
74                response
75            })?)
76    }
77}
78
79/// Fetch oEmbed data from the endpoint of a provider
80pub async fn fetch(
81    endpoint: impl IntoUrl,
82    request: ConsumerRequest<'_>,
83) -> Result<EmbedResponse, Error> {
84    Client::new(DEFAULT_CLIENT.clone())
85        .fetch(endpoint, request)
86        .await
87}
88
#[cfg(test)]
mod tests {
    use mockito::Server;

    use super::*;

    #[tokio::test]
    async fn test_fetch_success() {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("GET", "/?url=https%3A%2F%2Fexample.com")
            .with_status(200)
            .with_body(r#"{"version": "1.0", "type": "link"}"#)
            .with_header("content-type", "application/json")
            .create_async()
            .await;

        // `expect` (rather than `.ok()`) keeps the underlying error in the
        // failure message if the request unexpectedly fails.
        let response = fetch(
            server.url(),
            ConsumerRequest {
                url: "https://example.com",
                ..ConsumerRequest::default()
            },
        )
        .await
        .expect("fetch should succeed");

        assert_eq!(
            response,
            EmbedResponse {
                oembed_type: crate::EmbedType::Link,
                version: "1.0".to_string(),
                title: None,
                author_name: None,
                author_url: None,
                provider_name: None,
                provider_url: None,
                cache_age: None,
                thumbnail_url: None,
                thumbnail_width: None,
                thumbnail_height: None,
                extra: HashMap::default(),
            }
        );

        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_fetch_with_dimensions() {
        let mut server = Server::new_async().await;

        // The mock only matches this exact query string. A wrong parameter name
        // or order makes mockito reply with an error status, and `expect` fails.
        let mock = server
            .mock(
                "GET",
                "/?url=https%3A%2F%2Fexample.com&maxwidth=320&maxheight=240",
            )
            .with_status(200)
            .with_body(r#"{"version": "1.0", "type": "link"}"#)
            .with_header("content-type", "application/json")
            .create_async()
            .await;

        let response = fetch(
            server.url(),
            ConsumerRequest {
                url: "https://example.com",
                max_width: Some(320),
                max_height: Some(240),
                ..ConsumerRequest::default()
            },
        )
        .await
        .expect("fetch with maxwidth/maxheight should succeed");

        assert_eq!(response.oembed_type, crate::EmbedType::Link);
        assert_eq!(response.version, "1.0");

        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_fetch_error() {
        let mut server = Server::new_async().await;

        let mock = server
            .mock("GET", "/?url=https%3A%2F%2Fexample.com")
            .with_status(404)
            .create_async()
            .await;

        let result = fetch(
            server.url(),
            ConsumerRequest {
                url: "https://example.com",
                ..ConsumerRequest::default()
            },
        )
        .await;

        if let Err(Error::Reqwest(err)) = result {
            assert_eq!(err.status(), Some(reqwest::StatusCode::NOT_FOUND))
        } else {
            panic!("unexpected result: {:?}", result);
        }

        mock.assert_async().await;
    }
}