graph_http/
request_handler.rs

1use crate::blocking::BlockingRequestHandler;
2use crate::internal::{
3    BodyRead, Client, GraphClientConfiguration, HttpResponseBuilderExt, ODataNextLink, ODataQuery,
4    RequestComponents,
5};
6use async_stream::try_stream;
7use futures::Stream;
8use graph_error::{AuthExecutionResult, ErrorMessage, GraphFailure, GraphResult};
9use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE};
10use reqwest::{Request, Response};
11use serde::de::DeserializeOwned;
12use std::collections::VecDeque;
13use std::fmt::Debug;
14use std::time::Duration;
15use tower::util::BoxCloneService;
16use tower::{Service, ServiceExt};
17use url::Url;
18
19pub struct RequestHandler {
20    pub(crate) inner: Client,
21    pub(crate) request_components: RequestComponents,
22    pub(crate) error: Option<GraphFailure>,
23    pub(crate) body: Option<BodyRead>,
24    pub(crate) client_builder: GraphClientConfiguration,
25    pub(crate) service:
26        BoxCloneService<Request, Response, Box<dyn std::error::Error + Send + Sync>>,
27}
28
29impl RequestHandler {
30    pub fn new(
31        inner: Client,
32        mut request_components: RequestComponents,
33        err: Option<GraphFailure>,
34        body: Option<BodyRead>,
35    ) -> RequestHandler {
36        let service = inner.builder.build_tower_service(&inner.inner);
37        let client_builder = inner.builder.clone();
38        let mut original_headers = inner.headers.clone();
39        original_headers.extend(request_components.headers.clone());
40        request_components.headers = original_headers;
41
42        let mut error = None;
43        if let Some(err) = err {
44            let message = err.to_string();
45            error = Some(GraphFailure::PreFlightError {
46                url: Some(request_components.url.clone()),
47                headers: Some(request_components.headers.clone()),
48                error: Some(Box::new(err)),
49                message,
50            });
51        }
52
53        RequestHandler {
54            inner,
55            request_components,
56            error,
57            body,
58            client_builder,
59            service,
60        }
61    }
62
63    pub fn into_blocking(self) -> BlockingRequestHandler {
64        BlockingRequestHandler::new(
65            self.client_builder.build_blocking(),
66            self.request_components,
67            self.error,
68            self.body,
69        )
70    }
71
72    /// Returns true if any errors occurred prior to sending the request.
73    ///
74    /// # Example
75    /// ```rust,ignore
76    /// let client = Graph::new("ACCESS_TOKEN");
77    /// let request_handler = client.groups().list_group();
78    /// println!("{:#?}", request_handler.is_err());
79    /// ```
80    pub fn is_err(&self) -> bool {
81        self.error.is_some()
82    }
83
84    /// Returns any error wrapped in an Option that occurred prior to sending a request
85    ///
86    /// # Example
87    /// ```rust,ignore
88    /// let client = Graph::new("ACCESS_TOKEN");
89    /// let request_handler = client.groups().list_group();
90    /// println!("{:#?}", request_handler.err());
91    /// ```
92    pub fn err(&self) -> Option<&GraphFailure> {
93        self.error.as_ref()
94    }
95
96    #[inline]
97    pub fn url(&self) -> Url {
98        self.request_components.url.clone()
99    }
100
101    #[inline]
102    pub fn query<T: serde::Serialize + ?Sized>(mut self, query: &T) -> Self {
103        if let Err(err) = self.request_components.query(query) {
104            if self.error.is_none() {
105                self.error = Some(err);
106            }
107        }
108
109        if let Some("") = self.request_components.url.query() {
110            self.request_components.url.set_query(None);
111        }
112        self
113    }
114
115    #[inline]
116    pub fn append_query_pair<KV: AsRef<str>>(mut self, key: KV, value: KV) -> Self {
117        self.request_components
118            .url
119            .query_pairs_mut()
120            .append_pair(key.as_ref(), value.as_ref());
121        self
122    }
123
124    #[inline]
125    pub fn extend_path<I: AsRef<str>>(mut self, path: &[I]) -> Self {
126        if let Ok(mut p) = self.request_components.url.path_segments_mut() {
127            p.extend(path);
128        }
129        self
130    }
131
132    /// Insert a header for the request.
133    #[inline]
134    pub fn header<K: Into<HeaderName>, V: Into<HeaderValue>>(
135        mut self,
136        header_name: K,
137        header_value: V,
138    ) -> Self {
139        self.request_components
140            .headers
141            .insert(header_name.into(), header_value.into());
142        self
143    }
144
145    /// Set the headers for the request using reqwest::HeaderMap
146    #[inline]
147    pub fn headers(mut self, header_map: HeaderMap) -> Self {
148        self.request_components.headers.extend(header_map);
149        self
150    }
151
152    /// Get a mutable reference to the headers.
153    #[inline]
154    pub fn headers_mut(&mut self) -> &mut HeaderMap {
155        self.request_components.as_mut()
156    }
157
158    pub fn paging(self) -> Paging {
159        Paging(self)
160    }
161
162    pub(crate) async fn default_request_builder_with_token(
163        &mut self,
164    ) -> AuthExecutionResult<(String, reqwest::RequestBuilder)> {
165        let access_token = self
166            .inner
167            .client_application
168            .get_token_silent_async()
169            .await?;
170
171        let request_builder = self
172            .inner
173            .inner
174            .request(
175                self.request_components.method.clone(),
176                self.request_components.url.clone(),
177            )
178            .bearer_auth(access_token.as_str())
179            .headers(self.request_components.headers.clone());
180
181        if let Some(body) = self.body.take() {
182            if body.has_byte_buf() {
183                self.request_components
184                    .headers
185                    .entry(CONTENT_TYPE)
186                    .or_insert(HeaderValue::from_static("application/octet-stream"));
187            } else if body.has_string_buf() {
188                self.request_components
189                    .headers
190                    .entry(CONTENT_TYPE)
191                    .or_insert(HeaderValue::from_static("application/json"));
192            }
193            return Ok((
194                access_token,
195                request_builder
196                    .body::<reqwest::Body>(body.into())
197                    .headers(self.request_components.headers.clone()),
198            ));
199        }
200        Ok((access_token, request_builder))
201    }
202
203    pub(crate) async fn default_request_builder(&mut self) -> GraphResult<reqwest::RequestBuilder> {
204        let access_token = self
205            .inner
206            .client_application
207            .get_token_silent_async()
208            .await?;
209
210        let request_builder = self
211            .inner
212            .inner
213            .request(
214                self.request_components.method.clone(),
215                self.request_components.url.clone(),
216            )
217            .bearer_auth(access_token.as_str())
218            .headers(self.request_components.headers.clone());
219
220        if let Some(body) = self.body.take() {
221            if body.has_byte_buf() {
222                self.request_components
223                    .headers
224                    .entry(CONTENT_TYPE)
225                    .or_insert(HeaderValue::from_static("application/octet-stream"));
226            } else if body.has_string_buf() {
227                self.request_components
228                    .headers
229                    .entry(CONTENT_TYPE)
230                    .or_insert(HeaderValue::from_static("application/json"));
231            }
232            return Ok(request_builder
233                .body::<reqwest::Body>(body.into())
234                .headers(self.request_components.headers.clone()));
235        }
236
237        Ok(request_builder)
238    }
239
240    /// Builds the request and returns a [`reqwest::RequestBuilder`].
241    #[inline]
242    pub async fn build(mut self) -> GraphResult<reqwest::RequestBuilder> {
243        if let Some(err) = self.error {
244            return Err(err);
245        }
246        self.default_request_builder().await
247    }
248
249    #[inline]
250    pub async fn send(self) -> GraphResult<reqwest::Response> {
251        let mut service = self.service.clone();
252        let request_builder = self.build().await?;
253        let request = request_builder.build()?;
254        service
255            .ready()
256            .await
257            .map_err(GraphFailure::from)?
258            .call(request)
259            .await
260            .map_err(GraphFailure::from)
261    }
262}
263
264impl ODataQuery for RequestHandler {
265    fn append_query_pair<KV: AsRef<str>>(self, key: KV, value: KV) -> Self {
266        self.append_query_pair(key.as_ref(), value.as_ref())
267    }
268}
269
270impl AsRef<Url> for RequestHandler {
271    fn as_ref(&self) -> &Url {
272        self.request_components.as_ref()
273    }
274}
275
276impl AsMut<Url> for RequestHandler {
277    fn as_mut(&mut self) -> &mut Url {
278        self.request_components.as_mut()
279    }
280}
281
282pub type PagingResponse<T> = http::Response<Result<T, ErrorMessage>>;
283pub type PagingResult<T> = GraphResult<PagingResponse<T>>;
284
285pub struct Paging(RequestHandler);
286
287impl Paging {
288    async fn http_response<T: DeserializeOwned>(
289        response: reqwest::Response,
290    ) -> GraphResult<(Option<String>, PagingResponse<T>)> {
291        let status = response.status();
292        let url = response.url().clone();
293        let headers = response.headers().clone();
294        let version = response.version();
295
296        let body: serde_json::Value = response.json().await?;
297        let next_link = body.odata_next_link();
298        let json = body.clone();
299        let body_result: Result<T, ErrorMessage> = serde_json::from_value(body)
300            .map_err(|_| serde_json::from_value(json.clone()).unwrap_or(ErrorMessage::default()));
301
302        let mut builder = http::Response::builder()
303            .url(url)
304            .json(&json)
305            .status(http::StatusCode::from(&status))
306            .version(version);
307
308        for builder_header in builder.headers_mut().iter_mut() {
309            builder_header.extend(headers.clone());
310        }
311
312        Ok((next_link, builder.body(body_result)?))
313    }
314
315    /// Returns all next links as [`VecDeque<http::Response<T>>`]. This method may
316    /// cause significant delay in returning when there is a high volume of next links.
317    ///
318    /// This method is mainly provided for convenience in cases where the caller is sure
319    /// the requests will return successful without issue or where the caller is ok with
320    /// a return delay or does not care if errors occur. It is not recommended to use this
321    /// method in production environments.
322    ///
323    ///
324    /// # Example
325    /// ```rust,ignore
326    /// #[derive(Debug, Serialize, Deserialize)]
327    /// pub struct User {
328    ///     pub(crate) id: Option<String>,
329    ///     #[serde(rename = "userPrincipalName")]
330    ///     user_principal_name: Option<String>,
331    /// }
332    ///
333    /// #[derive(Debug, Serialize, Deserialize)]
334    /// pub struct Users {
335    ///     pub value: Vec<User>,
336    /// }
337    ///
338    /// #[tokio::main]
339    /// async fn main() -> GraphResult<()> {
340    ///     let client = GraphClient::new("ACCESS_TOKEN");
341    ///
342    ///     let deque = client
343    ///         .users()
344    ///         .list_user()
345    ///         .select(&["id", "userPrincipalName"])
346    ///         .paging()
347    ///         .json::<Users>()
348    ///         .await?;
349    ///
350    ///     for response in deque.iter() {
351    ///         let users = response.into_body()?;
352    ///         println!("{users:#?}");
353    ///     }
354    ///     Ok(())
355    /// }
356    ///
357    /// ```
358    pub async fn json<T: DeserializeOwned>(mut self) -> GraphResult<VecDeque<PagingResponse<T>>> {
359        if let Some(err) = self.0.error {
360            return Err(err);
361        }
362
363        let (access_token, request) = self.0.default_request_builder_with_token().await?;
364        let response = request.send().await?;
365
366        let (next, http_response) = Paging::http_response(response).await?;
367        let mut next_link = next;
368        let mut vec = VecDeque::new();
369        vec.push_back(http_response);
370
371        let client = self.0.inner.inner.clone();
372        while let Some(next) = next_link {
373            let response = client
374                .get(next)
375                .bearer_auth(access_token.as_str())
376                .send()
377                .await?;
378
379            let (next, http_response) = Paging::http_response(response).await?;
380
381            next_link = next;
382            vec.push_back(http_response);
383        }
384
385        Ok(vec)
386    }
387
388    fn try_stream<'a, T: DeserializeOwned + 'a>(
389        mut self,
390    ) -> impl Stream<Item = PagingResult<T>> + 'a {
391        try_stream! {
392            let (access_token, request) = self.0.default_request_builder_with_token().await?;
393            let response = request.send().await?;
394            let (next, http_response) = Paging::http_response(response).await?;
395            let mut next_link = next;
396            yield http_response;
397
398            while let Some(url) = next_link {
399                let response = self.0
400                    .inner
401                    .inner
402                    .get(url)
403                    .bearer_auth(access_token.as_str())
404                    .send()
405                    .await?;
406                let (next, http_response) = Paging::http_response(response).await?;
407                next_link = next;
408                yield http_response;
409            }
410        }
411    }
412
413    /// Stream the current request along with any next link requests from the response body.
414    /// Each stream.next() returns a [`GraphResult<http::Response<T>>`].
415    ///
416    /// # Example
417    /// ```rust,ignore
418    /// let mut stream = client
419    ///     .users()
420    ///     .delta()
421    ///     .paging()
422    ///     .stream::<serde_json::Value>()
423    ///     .unwrap();
424    ///
425    ///  while let Some(result) = stream.next().await {
426    ///     println!("{result:#?}");
427    ///  }
428    /// ```
429    pub fn stream<'a, T: DeserializeOwned + 'a>(
430        mut self,
431    ) -> GraphResult<impl Stream<Item = PagingResult<T>> + 'a> {
432        if let Some(err) = self.0.error.take() {
433            return Err(err);
434        }
435
436        Ok(Box::pin(self.try_stream()))
437    }
438
439    /// Get next link responses using a channel Receiver [`tokio::sync::mpsc::Receiver<Option<GraphResult<http::Response<T>>>>`].
440    ///
441    /// By default channels use [`tokio::sync::mpsc::Sender::send_timeout`] with a buffer of 100
442    /// and a timeout of 60. Use [`Paging::channel_timeout`] to set a custom timeout and use
443    /// [`Paging::channel_buffer_timeout`] to set both the buffer and timeout.
444    ///
445    /// # Example
446    ///
447    /// ```rust,ignore
448    /// let client = Graph::new("ACCESS_TOKEN");
449    ///
450    ///  let mut receiver = client
451    ///     .users()
452    ///     .list_user()
453    ///     .top("5")
454    ///     .paging()
455    ///     .channel::<serde_json::Value>()
456    ///     .await?;
457    ///
458    ///  while let Some(result) = receiver.recv().await {
459    ///     let response = result?;
460    ///     println!("{:#?}", response);
461    ///  }
462    /// ```
463    pub async fn channel<T: DeserializeOwned + Debug + Send + 'static>(
464        self,
465    ) -> GraphResult<tokio::sync::mpsc::Receiver<PagingResult<T>>> {
466        self.channel_buffer_timeout(100, Duration::from_secs(60))
467            .await
468    }
469
470    /// Get next link responses using a channel Receiver,
471    /// [`tokio::sync::mpsc::Receiver<Option<GraphResult<http::Response<T>>>>`].
472    /// using a custom timeout [`Duration`]
473    ///
474    /// By default channels use [`tokio::sync::mpsc::Sender::send_timeout`] with a buffer of 100
475    /// and a timeout of 60. Use [`Paging::channel_timeout`] to set a custom timeout and use
476    /// [`Paging::channel_buffer_timeout`] to set both the buffer and timeout.
477    ///
478    /// # Example
479    ///
480    /// ```rust,ignore
481    /// let client = Graph::new("ACCESS_TOKEN");
482    ///
483    ///  let mut receiver = client
484    ///     .users()
485    ///     .list_user()
486    ///     .top("5")
487    ///     .paging()
488    ///     .channel_timeout::<serde_json::Value>(Duration::from_secs(60))
489    ///     .await?;
490    ///
491    ///  while let Some(result) = receiver.recv().await {
492    ///     let response = result?;
493    ///     println!("{:#?}", response);
494    ///  }
495    /// ```
496    pub async fn channel_timeout<T: DeserializeOwned + Debug + Send + 'static>(
497        self,
498        timeout: Duration,
499    ) -> GraphResult<tokio::sync::mpsc::Receiver<PagingResult<T>>> {
500        self.channel_buffer_timeout(100, timeout).await
501    }
502
503    async fn send_channel_request<T: DeserializeOwned>(
504        client: &reqwest::Client,
505        url: &str,
506        access_token: &str,
507    ) -> GraphResult<(Option<String>, PagingResponse<T>)> {
508        let response = client.get(url).bearer_auth(access_token).send().await?;
509
510        Paging::http_response(response).await
511    }
512
513    /// Get next link responses using a channel Receiver,
514    /// [`tokio::sync::mpsc::Receiver<Option<GraphResult<http::Response<T>>>>`].
515    /// with a custom timeout [`Duration`] and buffer.
516    ///
517    /// By default channels use [`tokio::sync::mpsc::Sender::send_timeout`] with a buffer of 100
518    /// and a timeout of 60. Use [`Paging::channel_timeout`] to set a custom timeout and use
519    /// [`Paging::channel_buffer_timeout`] to set both the buffer and timeout.
520    ///
521    /// # Example
522    ///
523    /// ```rust,ignore
524    /// let client = Graph::new("ACCESS_TOKEN");
525    ///
526    ///  let mut receiver = client
527    ///     .users()
528    ///     .list_user()
529    ///     .top("5")
530    ///     .paging()
531    ///     .channel_buffer_timeout::<serde_json::Value>(100, Duration::from_secs(60))
532    ///     .await?;
533    ///
534    ///  while let Some(result) = receiver.recv().await {
535    ///     let response = result?;
536    ///     println!("{:#?}", response);
537    ///  }
538    /// ```
539    #[allow(unused_assignments)] // Issue with Clippy not seeing next_link get assigned.
540    pub async fn channel_buffer_timeout<T: DeserializeOwned + Debug + Send + 'static>(
541        mut self,
542        buffer: usize,
543        timeout: Duration,
544    ) -> GraphResult<tokio::sync::mpsc::Receiver<PagingResult<T>>> {
545        let (sender, receiver) = tokio::sync::mpsc::channel(buffer);
546
547        let (access_token, request) = self.0.default_request_builder_with_token().await?;
548        let response = request.send().await?;
549        let (next, http_response) = Paging::http_response(response).await?;
550        let mut next_link = next;
551        sender
552            .send_timeout(Ok(http_response), timeout)
553            .await
554            .unwrap();
555
556        let client = self.0.inner.inner.clone();
557        tokio::spawn(async move {
558            while let Some(next) = next_link {
559                let result =
560                    Paging::send_channel_request(&client, next.as_str(), access_token.as_str())
561                        .await;
562
563                match result {
564                    Ok((next, response)) => {
565                        next_link = next;
566                        sender.send_timeout(Ok(response), timeout).await.unwrap();
567                    }
568                    Err(err) => {
569                        sender.send_timeout(Err(err), timeout).await.unwrap();
570                        next_link = None;
571                        break;
572                    }
573                }
574            }
575        });
576
577        Ok(receiver)
578    }
579}