1use crate::blocking::BlockingRequestHandler;
2use crate::internal::{
3 BodyRead, Client, GraphClientConfiguration, HttpResponseBuilderExt, ODataNextLink, ODataQuery,
4 RequestComponents,
5};
6use async_stream::try_stream;
7use futures::Stream;
8use graph_error::{AuthExecutionResult, ErrorMessage, GraphFailure, GraphResult};
9use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE};
10use reqwest::{Request, Response};
11use serde::de::DeserializeOwned;
12use std::collections::VecDeque;
13use std::fmt::Debug;
14use std::time::Duration;
15use tower::util::BoxCloneService;
16use tower::{Service, ServiceExt};
17use url::Url;
18
/// An asynchronous request being assembled for the Graph API client.
///
/// Holds the client, the request's method/URL/headers, an optional body, and any error
/// recorded while the request was being assembled. A recorded error is returned by
/// `build`/`send` instead of issuing the request.
pub struct RequestHandler {
    /// Client used to issue the request (HTTP client, default headers, token source, config).
    pub(crate) inner: Client,
    /// Method, URL and headers of the request being built.
    pub(crate) request_components: RequestComponents,
    /// First error recorded while assembling the request, if any.
    pub(crate) error: Option<GraphFailure>,
    /// Request body; moved into the outgoing request when it is built.
    pub(crate) body: Option<BodyRead>,
    /// Copy of the client configuration, kept so a blocking handler can be derived.
    pub(crate) client_builder: GraphClientConfiguration,
    /// Tower service stack that `send` dispatches the built request through.
    pub(crate) service:
        BoxCloneService<Request, Response, Box<dyn std::error::Error + Send + Sync>>,
}
28
29impl RequestHandler {
30 pub fn new(
31 inner: Client,
32 mut request_components: RequestComponents,
33 err: Option<GraphFailure>,
34 body: Option<BodyRead>,
35 ) -> RequestHandler {
36 let service = inner.builder.build_tower_service(&inner.inner);
37 let client_builder = inner.builder.clone();
38 let mut original_headers = inner.headers.clone();
39 original_headers.extend(request_components.headers.clone());
40 request_components.headers = original_headers;
41
42 let mut error = None;
43 if let Some(err) = err {
44 let message = err.to_string();
45 error = Some(GraphFailure::PreFlightError {
46 url: Some(request_components.url.clone()),
47 headers: Some(request_components.headers.clone()),
48 error: Some(Box::new(err)),
49 message,
50 });
51 }
52
53 RequestHandler {
54 inner,
55 request_components,
56 error,
57 body,
58 client_builder,
59 service,
60 }
61 }
62
63 pub fn into_blocking(self) -> BlockingRequestHandler {
64 BlockingRequestHandler::new(
65 self.client_builder.build_blocking(),
66 self.request_components,
67 self.error,
68 self.body,
69 )
70 }
71
72 pub fn is_err(&self) -> bool {
81 self.error.is_some()
82 }
83
84 pub fn err(&self) -> Option<&GraphFailure> {
93 self.error.as_ref()
94 }
95
96 #[inline]
97 pub fn url(&self) -> Url {
98 self.request_components.url.clone()
99 }
100
101 #[inline]
102 pub fn query<T: serde::Serialize + ?Sized>(mut self, query: &T) -> Self {
103 if let Err(err) = self.request_components.query(query) {
104 if self.error.is_none() {
105 self.error = Some(err);
106 }
107 }
108
109 if let Some("") = self.request_components.url.query() {
110 self.request_components.url.set_query(None);
111 }
112 self
113 }
114
115 #[inline]
116 pub fn append_query_pair<KV: AsRef<str>>(mut self, key: KV, value: KV) -> Self {
117 self.request_components
118 .url
119 .query_pairs_mut()
120 .append_pair(key.as_ref(), value.as_ref());
121 self
122 }
123
124 #[inline]
125 pub fn extend_path<I: AsRef<str>>(mut self, path: &[I]) -> Self {
126 if let Ok(mut p) = self.request_components.url.path_segments_mut() {
127 p.extend(path);
128 }
129 self
130 }
131
132 #[inline]
134 pub fn header<K: Into<HeaderName>, V: Into<HeaderValue>>(
135 mut self,
136 header_name: K,
137 header_value: V,
138 ) -> Self {
139 self.request_components
140 .headers
141 .insert(header_name.into(), header_value.into());
142 self
143 }
144
145 #[inline]
147 pub fn headers(mut self, header_map: HeaderMap) -> Self {
148 self.request_components.headers.extend(header_map);
149 self
150 }
151
152 #[inline]
154 pub fn headers_mut(&mut self) -> &mut HeaderMap {
155 self.request_components.as_mut()
156 }
157
158 pub fn paging(self) -> Paging {
159 Paging(self)
160 }
161
162 pub(crate) async fn default_request_builder_with_token(
163 &mut self,
164 ) -> AuthExecutionResult<(String, reqwest::RequestBuilder)> {
165 let access_token = self
166 .inner
167 .client_application
168 .get_token_silent_async()
169 .await?;
170
171 let request_builder = self
172 .inner
173 .inner
174 .request(
175 self.request_components.method.clone(),
176 self.request_components.url.clone(),
177 )
178 .bearer_auth(access_token.as_str())
179 .headers(self.request_components.headers.clone());
180
181 if let Some(body) = self.body.take() {
182 if body.has_byte_buf() {
183 self.request_components
184 .headers
185 .entry(CONTENT_TYPE)
186 .or_insert(HeaderValue::from_static("application/octet-stream"));
187 } else if body.has_string_buf() {
188 self.request_components
189 .headers
190 .entry(CONTENT_TYPE)
191 .or_insert(HeaderValue::from_static("application/json"));
192 }
193 return Ok((
194 access_token,
195 request_builder
196 .body::<reqwest::Body>(body.into())
197 .headers(self.request_components.headers.clone()),
198 ));
199 }
200 Ok((access_token, request_builder))
201 }
202
203 pub(crate) async fn default_request_builder(&mut self) -> GraphResult<reqwest::RequestBuilder> {
204 let access_token = self
205 .inner
206 .client_application
207 .get_token_silent_async()
208 .await?;
209
210 let request_builder = self
211 .inner
212 .inner
213 .request(
214 self.request_components.method.clone(),
215 self.request_components.url.clone(),
216 )
217 .bearer_auth(access_token.as_str())
218 .headers(self.request_components.headers.clone());
219
220 if let Some(body) = self.body.take() {
221 if body.has_byte_buf() {
222 self.request_components
223 .headers
224 .entry(CONTENT_TYPE)
225 .or_insert(HeaderValue::from_static("application/octet-stream"));
226 } else if body.has_string_buf() {
227 self.request_components
228 .headers
229 .entry(CONTENT_TYPE)
230 .or_insert(HeaderValue::from_static("application/json"));
231 }
232 return Ok(request_builder
233 .body::<reqwest::Body>(body.into())
234 .headers(self.request_components.headers.clone()));
235 }
236
237 Ok(request_builder)
238 }
239
240 #[inline]
242 pub async fn build(mut self) -> GraphResult<reqwest::RequestBuilder> {
243 if let Some(err) = self.error {
244 return Err(err);
245 }
246 self.default_request_builder().await
247 }
248
249 #[inline]
250 pub async fn send(self) -> GraphResult<reqwest::Response> {
251 let mut service = self.service.clone();
252 let request_builder = self.build().await?;
253 let request = request_builder.build()?;
254 service
255 .ready()
256 .await
257 .map_err(GraphFailure::from)?
258 .call(request)
259 .await
260 .map_err(GraphFailure::from)
261 }
262}
263
264impl ODataQuery for RequestHandler {
265 fn append_query_pair<KV: AsRef<str>>(self, key: KV, value: KV) -> Self {
266 self.append_query_pair(key.as_ref(), value.as_ref())
267 }
268}
269
270impl AsRef<Url> for RequestHandler {
271 fn as_ref(&self) -> &Url {
272 self.request_components.as_ref()
273 }
274}
275
276impl AsMut<Url> for RequestHandler {
277 fn as_mut(&mut self) -> &mut Url {
278 self.request_components.as_mut()
279 }
280}
281
/// One page of a paged response: HTTP metadata plus a body that is either the deserialized
/// `T` or an [`ErrorMessage`] when the page did not deserialize as `T`.
pub type PagingResponse<T> = http::Response<Result<T, ErrorMessage>>;
/// Outcome of fetching one page: the converted page or the failure that prevented it.
pub type PagingResult<T> = GraphResult<PagingResponse<T>>;

/// Pager over a [`RequestHandler`] that keeps requesting the next-link URL reported by each
/// page (via `odata_next_link`) until none is returned.
pub struct Paging(RequestHandler);
286
287impl Paging {
288 async fn http_response<T: DeserializeOwned>(
289 response: reqwest::Response,
290 ) -> GraphResult<(Option<String>, PagingResponse<T>)> {
291 let status = response.status();
292 let url = response.url().clone();
293 let headers = response.headers().clone();
294 let version = response.version();
295
296 let body: serde_json::Value = response.json().await?;
297 let next_link = body.odata_next_link();
298 let json = body.clone();
299 let body_result: Result<T, ErrorMessage> = serde_json::from_value(body)
300 .map_err(|_| serde_json::from_value(json.clone()).unwrap_or(ErrorMessage::default()));
301
302 let mut builder = http::Response::builder()
303 .url(url)
304 .json(&json)
305 .status(http::StatusCode::from(&status))
306 .version(version);
307
308 for builder_header in builder.headers_mut().iter_mut() {
309 builder_header.extend(headers.clone());
310 }
311
312 Ok((next_link, builder.body(body_result)?))
313 }
314
315 pub async fn json<T: DeserializeOwned>(mut self) -> GraphResult<VecDeque<PagingResponse<T>>> {
359 if let Some(err) = self.0.error {
360 return Err(err);
361 }
362
363 let (access_token, request) = self.0.default_request_builder_with_token().await?;
364 let response = request.send().await?;
365
366 let (next, http_response) = Paging::http_response(response).await?;
367 let mut next_link = next;
368 let mut vec = VecDeque::new();
369 vec.push_back(http_response);
370
371 let client = self.0.inner.inner.clone();
372 while let Some(next) = next_link {
373 let response = client
374 .get(next)
375 .bearer_auth(access_token.as_str())
376 .send()
377 .await?;
378
379 let (next, http_response) = Paging::http_response(response).await?;
380
381 next_link = next;
382 vec.push_back(http_response);
383 }
384
385 Ok(vec)
386 }
387
388 fn try_stream<'a, T: DeserializeOwned + 'a>(
389 mut self,
390 ) -> impl Stream<Item = PagingResult<T>> + 'a {
391 try_stream! {
392 let (access_token, request) = self.0.default_request_builder_with_token().await?;
393 let response = request.send().await?;
394 let (next, http_response) = Paging::http_response(response).await?;
395 let mut next_link = next;
396 yield http_response;
397
398 while let Some(url) = next_link {
399 let response = self.0
400 .inner
401 .inner
402 .get(url)
403 .bearer_auth(access_token.as_str())
404 .send()
405 .await?;
406 let (next, http_response) = Paging::http_response(response).await?;
407 next_link = next;
408 yield http_response;
409 }
410 }
411 }
412
413 pub fn stream<'a, T: DeserializeOwned + 'a>(
430 mut self,
431 ) -> GraphResult<impl Stream<Item = PagingResult<T>> + 'a> {
432 if let Some(err) = self.0.error.take() {
433 return Err(err);
434 }
435
436 Ok(Box::pin(self.try_stream()))
437 }
438
439 pub async fn channel<T: DeserializeOwned + Debug + Send + 'static>(
464 self,
465 ) -> GraphResult<tokio::sync::mpsc::Receiver<PagingResult<T>>> {
466 self.channel_buffer_timeout(100, Duration::from_secs(60))
467 .await
468 }
469
470 pub async fn channel_timeout<T: DeserializeOwned + Debug + Send + 'static>(
497 self,
498 timeout: Duration,
499 ) -> GraphResult<tokio::sync::mpsc::Receiver<PagingResult<T>>> {
500 self.channel_buffer_timeout(100, timeout).await
501 }
502
503 async fn send_channel_request<T: DeserializeOwned>(
504 client: &reqwest::Client,
505 url: &str,
506 access_token: &str,
507 ) -> GraphResult<(Option<String>, PagingResponse<T>)> {
508 let response = client.get(url).bearer_auth(access_token).send().await?;
509
510 Paging::http_response(response).await
511 }
512
513 #[allow(unused_assignments)] pub async fn channel_buffer_timeout<T: DeserializeOwned + Debug + Send + 'static>(
541 mut self,
542 buffer: usize,
543 timeout: Duration,
544 ) -> GraphResult<tokio::sync::mpsc::Receiver<PagingResult<T>>> {
545 let (sender, receiver) = tokio::sync::mpsc::channel(buffer);
546
547 let (access_token, request) = self.0.default_request_builder_with_token().await?;
548 let response = request.send().await?;
549 let (next, http_response) = Paging::http_response(response).await?;
550 let mut next_link = next;
551 sender
552 .send_timeout(Ok(http_response), timeout)
553 .await
554 .unwrap();
555
556 let client = self.0.inner.inner.clone();
557 tokio::spawn(async move {
558 while let Some(next) = next_link {
559 let result =
560 Paging::send_channel_request(&client, next.as_str(), access_token.as_str())
561 .await;
562
563 match result {
564 Ok((next, response)) => {
565 next_link = next;
566 sender.send_timeout(Ok(response), timeout).await.unwrap();
567 }
568 Err(err) => {
569 sender.send_timeout(Err(err), timeout).await.unwrap();
570 next_link = None;
571 break;
572 }
573 }
574 }
575 });
576
577 Ok(receiver)
578 }
579}