meilisearch_sdk/
reqwest.rs

1use std::{
2    pin::Pin,
3    task::{Context, Poll},
4};
5
6use async_trait::async_trait;
7use bytes::{Bytes, BytesMut};
8use futures_core::Stream;
9use futures_io::AsyncRead;
10use pin_project_lite::pin_project;
11use serde::{de::DeserializeOwned, Serialize};
12
13use crate::{
14    errors::Error,
15    request::{parse_response, HttpClient, Method},
16};
17
18#[derive(Debug, Clone, Default)]
19pub struct ReqwestClient {
20    client: reqwest::Client,
21}
22
23impl ReqwestClient {
24    pub fn new(api_key: Option<&str>) -> Result<Self, Error> {
25        use reqwest::{header, ClientBuilder};
26
27        let builder = ClientBuilder::new();
28        let mut headers = header::HeaderMap::new();
29        #[cfg(not(target_arch = "wasm32"))]
30        headers.insert(
31            header::USER_AGENT,
32            header::HeaderValue::from_str(&qualified_version()).unwrap(),
33        );
34        #[cfg(target_arch = "wasm32")]
35        headers.insert(
36            header::HeaderName::from_static("x-meilisearch-client"),
37            header::HeaderValue::from_str(&qualified_version()).unwrap(),
38        );
39
40        if let Some(api_key) = api_key {
41            headers.insert(
42                header::AUTHORIZATION,
43                header::HeaderValue::from_str(&format!("Bearer {api_key}")).unwrap(),
44            );
45        }
46
47        let builder = builder.default_headers(headers);
48        let client = builder.build()?;
49
50        Ok(ReqwestClient { client })
51    }
52}
53
54#[cfg_attr(feature = "futures-unsend", async_trait(?Send))]
55#[cfg_attr(not(feature = "futures-unsend"), async_trait)]
56impl HttpClient for ReqwestClient {
57    async fn stream_request<
58        Query: Serialize + Send + Sync,
59        Body: futures_io::AsyncRead + Send + Sync + 'static,
60        Output: DeserializeOwned + 'static,
61    >(
62        &self,
63        url: &str,
64        method: Method<Query, Body>,
65        content_type: &str,
66        expected_status_code: u16,
67    ) -> Result<Output, Error> {
68        use reqwest::header;
69
70        let query = method.query();
71        let query = yaup::to_string(query)?;
72
73        let url = if query.is_empty() {
74            url.to_string()
75        } else {
76            format!("{url}{query}")
77        };
78
79        let mut request = self.client.request(verb(&method), &url);
80
81        if let Some(body) = method.into_body() {
82            // TODO: Currently reqwest doesn't support streaming data in wasm so we need to collect everything in RAM
83            #[cfg(not(target_arch = "wasm32"))]
84            {
85                let stream = ReaderStream::new(body);
86                let body = reqwest::Body::wrap_stream(stream);
87
88                request = request
89                    .header(header::CONTENT_TYPE, content_type)
90                    .body(body);
91            }
92            #[cfg(target_arch = "wasm32")]
93            {
94                use futures_util::AsyncReadExt;
95
96                let mut buf = Vec::new();
97                let mut body = std::pin::pin!(body);
98                body.read_to_end(&mut buf)
99                    .await
100                    .map_err(|err| Error::Other(Box::new(err)))?;
101                request = request.header(header::CONTENT_TYPE, content_type).body(buf);
102            }
103        }
104
105        let response = self.client.execute(request.build()?).await?;
106        let status = response.status().as_u16();
107        let mut body = response.text().await?;
108
109        if body.is_empty() {
110            body = "null".to_string();
111        }
112
113        parse_response(status, expected_status_code, &body, url.to_string())
114    }
115
116    fn is_tokio(&self) -> bool {
117        true
118    }
119}
120
121fn verb<Q, B>(method: &Method<Q, B>) -> reqwest::Method {
122    match method {
123        Method::Get { .. } => reqwest::Method::GET,
124        Method::Delete { .. } => reqwest::Method::DELETE,
125        Method::Post { .. } => reqwest::Method::POST,
126        Method::Put { .. } => reqwest::Method::PUT,
127        Method::Patch { .. } => reqwest::Method::PATCH,
128    }
129}
130
/// Returns the identification string sent with every request, in the form
/// `Meilisearch Rust (v<version>)`.
///
/// `<version>` is the `CARGO_PKG_VERSION` captured at compile time, or
/// `unknown` when that variable was not set during the build.
pub fn qualified_version() -> String {
    let version = match option_env!("CARGO_PKG_VERSION") {
        Some(version) => version,
        None => "unknown",
    };

    let mut qualified = String::from("Meilisearch Rust (v");
    qualified.push_str(version);
    qualified.push(')');
    qualified
}
136
137pin_project! {
138    #[derive(Debug)]
139    pub struct ReaderStream<R: AsyncRead> {
140        #[pin]
141        reader: R,
142        buf: BytesMut,
143        capacity: usize,
144    }
145}
146
147impl<R: AsyncRead> ReaderStream<R> {
148    pub fn new(reader: R) -> Self {
149        Self {
150            reader,
151            buf: BytesMut::new(),
152            // 8KiB of capacity, the default capacity used by `BufReader` in the std
153            capacity: 8 * 1024 * 1024,
154        }
155    }
156}
157
158impl<R: AsyncRead> Stream for ReaderStream<R> {
159    type Item = std::io::Result<Bytes>;
160
161    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
162        let this = self.as_mut().project();
163
164        if this.buf.capacity() == 0 {
165            this.buf.resize(*this.capacity, 0);
166        }
167
168        match AsyncRead::poll_read(this.reader, cx, this.buf) {
169            Poll::Pending => Poll::Pending,
170            Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))),
171            Poll::Ready(Ok(0)) => Poll::Ready(None),
172            Poll::Ready(Ok(i)) => {
173                let chunk = this.buf.split_to(i);
174                Poll::Ready(Some(Ok(chunk.freeze())))
175            }
176        }
177    }
178}