1use crate::blocking::BlockingClient;
2use graph_core::identity::{ClientApplication, ForceTokenRefresh};
3use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, USER_AGENT};
4use reqwest::redirect::Policy;
5use reqwest::tls::Version;
6use reqwest::Proxy;
7use reqwest::{Request, Response};
8use std::env::VarError;
9use std::ffi::OsStr;
10use std::fmt::{Debug, Formatter};
11use std::time::Duration;
12use tower::limit::ConcurrencyLimitLayer;
13use tower::retry::RetryLayer;
14use tower::util::BoxCloneService;
15use tower::ServiceExt;
16
17fn user_agent_header_from_env() -> Option<HeaderValue> {
18 let header = std::option_env!("GRAPH_CLIENT_USER_AGENT")?;
19 HeaderValue::from_str(header).ok()
20}
21
/// Optional tower middleware applied by `GraphClientConfiguration::build_tower_service`.
///
/// Every field defaults to `None`, which means the corresponding layer is not added.
#[derive(Default, Clone, Debug)]
struct ServiceLayersConfiguration {
    /// Maximum number of in-flight requests, enforced with `ConcurrencyLimitLayer`.
    concurrency_limit: Option<usize>,
    /// Value handed to the crate's `Attempts` retry policy (presumably the number
    /// of retry attempts — confirm against `tower_services::Attempts`).
    retry: Option<usize>,
    /// `Some(())` enables the crate's `WaitFor` retry policy; this is effectively a flag.
    wait_for_retry_after_headers: Option<()>,
}
28
29#[derive(Clone)]
30struct ClientConfiguration {
31 client_application: Option<Box<dyn ClientApplication>>,
32 headers: HeaderMap,
33 referer: bool,
34 timeout: Option<Duration>,
35 connect_timeout: Option<Duration>,
36 connection_verbose: bool,
37 https_only: bool,
38 min_tls_version: Version,
41 service_layers_configuration: ServiceLayersConfiguration,
42 proxy: Option<Proxy>,
43}
44
45impl ClientConfiguration {
46 pub fn new() -> ClientConfiguration {
47 let mut headers: HeaderMap<HeaderValue> = HeaderMap::with_capacity(2);
48 headers.insert(ACCEPT, HeaderValue::from_static("*/*"));
49
50 if let Some(user_agent) = user_agent_header_from_env() {
51 headers.insert(USER_AGENT, user_agent);
52 }
53
54 ClientConfiguration {
55 client_application: None,
56 headers,
57 referer: true,
58 timeout: None,
59 connect_timeout: None,
60 connection_verbose: false,
61 https_only: true,
62 min_tls_version: Version::TLS_1_2,
63 service_layers_configuration: ServiceLayersConfiguration::default(),
64 proxy: None,
65 }
66 }
67}
68
69impl Debug for ClientConfiguration {
70 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("ClientConfiguration")
72 .field("headers", &self.headers)
73 .field("referer", &self.referer)
74 .field("timeout", &self.timeout)
75 .field("connect_timeout", &self.connect_timeout)
76 .field("https_only", &self.https_only)
77 .field("min_tls_version", &self.min_tls_version)
78 .field("proxy", &self.proxy)
79 .finish()
80 }
81}
82
83#[derive(Clone, Debug)]
84pub struct GraphClientConfiguration {
85 config: ClientConfiguration,
86}
87
88impl GraphClientConfiguration {
89 pub fn new() -> GraphClientConfiguration {
90 GraphClientConfiguration {
91 config: ClientConfiguration::new(),
92 }
93 }
94
95 pub fn access_token<AT: ToString>(mut self, access_token: AT) -> GraphClientConfiguration {
96 self.config.client_application = Some(Box::new(access_token.to_string()));
97 self
98 }
99
100 pub fn client_application<CA: ClientApplication + 'static>(mut self, client_app: CA) -> Self {
101 self.config.client_application = Some(Box::new(client_app));
102 self
103 }
104
105 pub fn default_headers(mut self, headers: HeaderMap) -> GraphClientConfiguration {
106 for (key, value) in headers.iter() {
107 self.config.headers.insert(key, value.clone());
108 }
109 self
110 }
111
112 pub fn referer(mut self, enable: bool) -> GraphClientConfiguration {
116 self.config.referer = enable;
117 self
118 }
119
120 pub fn timeout(mut self, timeout: Duration) -> GraphClientConfiguration {
127 self.config.timeout = Some(timeout);
128 self
129 }
130
131 pub fn connect_timeout(mut self, timeout: Duration) -> GraphClientConfiguration {
140 self.config.connect_timeout = Some(timeout);
141 self
142 }
143
144 pub fn connection_verbose(mut self, verbose: bool) -> GraphClientConfiguration {
151 self.config.connection_verbose = verbose;
152 self
153 }
154
155 pub fn user_agent(mut self, value: HeaderValue) -> GraphClientConfiguration {
156 self.config.headers.insert(USER_AGENT, value);
157 self
158 }
159
160 pub fn min_tls_version(mut self, version: Version) -> GraphClientConfiguration {
163 self.config.min_tls_version = version;
164 self
165 }
166
167 pub fn proxy(mut self, proxy: Proxy) -> GraphClientConfiguration {
171 self.config.proxy = Some(proxy);
172 self
173 }
174
175 #[cfg(feature = "test-util")]
176 pub fn https_only(mut self, https_only: bool) -> GraphClientConfiguration {
177 self.config.https_only = https_only;
178 self
179 }
180
181 pub fn retry(mut self, retry: Option<usize>) -> GraphClientConfiguration {
189 self.config.service_layers_configuration.retry = retry;
190 self
191 }
192
193 pub fn wait_for_retry_after_headers(mut self, retry: bool) -> GraphClientConfiguration {
207 self.config
208 .service_layers_configuration
209 .wait_for_retry_after_headers = match retry {
210 true => Some(()),
211 false => None,
212 };
213 self
214 }
215
216 pub fn concurrency_limit(
223 mut self,
224 concurrency_limit: Option<usize>,
225 ) -> GraphClientConfiguration {
226 self.config.service_layers_configuration.concurrency_limit = concurrency_limit;
227 self
228 }
229
230 pub(crate) fn build_tower_service(
231 &self,
232 client: &reqwest::Client,
233 ) -> BoxCloneService<Request, Response, Box<dyn std::error::Error + Send + Sync>> {
234 tower::ServiceBuilder::new()
235 .option_layer(
236 self.config
237 .service_layers_configuration
238 .retry
239 .map(|num| RetryLayer::new(crate::tower_services::Attempts(num))),
240 )
241 .option_layer(
242 self.config
243 .service_layers_configuration
244 .wait_for_retry_after_headers
245 .map(|_| RetryLayer::new(crate::tower_services::WaitFor())),
246 )
247 .option_layer(
248 self.config
249 .service_layers_configuration
250 .concurrency_limit
251 .map(ConcurrencyLimitLayer::new),
252 )
253 .service(client.clone())
254 .boxed_clone()
255 }
256
257 pub fn build(self) -> Client {
258 let config = self.clone();
259 let headers = self.config.headers.clone();
260 let mut builder = reqwest::ClientBuilder::new()
261 .referer(self.config.referer)
262 .connection_verbose(self.config.connection_verbose)
263 .https_only(self.config.https_only)
264 .min_tls_version(self.config.min_tls_version)
265 .redirect(Policy::limited(2))
266 .default_headers(self.config.headers);
267
268 if let Some(timeout) = self.config.timeout {
269 builder = builder.timeout(timeout);
270 }
271
272 if let Some(connect_timeout) = self.config.connect_timeout {
273 builder = builder.connect_timeout(connect_timeout);
274 }
275
276 if let Some(proxy) = self.config.proxy {
277 builder = builder.proxy(proxy);
278 }
279
280 let client = builder.build().unwrap();
281
282 if let Some(client_application) = self.config.client_application {
283 Client {
284 client_application,
285 inner: client,
286 headers,
287 builder: config,
288 }
289 } else {
290 Client {
291 client_application: Box::<String>::default(),
292 inner: client,
293 headers,
294 builder: config,
295 }
296 }
297 }
298
299 pub(crate) fn build_blocking(self) -> BlockingClient {
300 let headers = self.config.headers.clone();
301 let mut builder = reqwest::blocking::ClientBuilder::new()
302 .referer(self.config.referer)
303 .connection_verbose(self.config.connection_verbose)
304 .https_only(self.config.https_only)
305 .min_tls_version(self.config.min_tls_version)
306 .redirect(Policy::limited(2))
307 .default_headers(self.config.headers);
308
309 if let Some(timeout) = self.config.timeout {
310 builder = builder.timeout(timeout);
311 }
312
313 if let Some(connect_timeout) = self.config.connect_timeout {
314 builder = builder.connect_timeout(connect_timeout);
315 }
316
317 if let Some(proxy) = self.config.proxy {
318 builder = builder.proxy(proxy);
319 }
320
321 let client = builder.build().unwrap();
322
323 if let Some(client_application) = self.config.client_application {
324 BlockingClient {
325 client_application,
326 inner: client,
327 headers,
328 }
329 } else {
330 BlockingClient {
331 client_application: Box::<String>::default(),
332 inner: client,
333 headers,
334 }
335 }
336 }
337}
338
339impl Default for GraphClientConfiguration {
340 fn default() -> Self {
341 GraphClientConfiguration::new()
342 }
343}
344
345#[derive(Clone)]
346pub struct Client {
347 pub(crate) client_application: Box<dyn ClientApplication>,
348 pub(crate) inner: reqwest::Client,
349 pub(crate) headers: HeaderMap,
350 pub(crate) builder: GraphClientConfiguration,
351}
352
353impl Client {
354 pub fn new<CA: ClientApplication + 'static>(client_app: CA) -> Self {
355 GraphClientConfiguration::new()
356 .client_application(client_app)
357 .build()
358 }
359
360 pub fn from_access_token<T: AsRef<str>>(access_token: T) -> Self {
361 GraphClientConfiguration::new()
362 .access_token(access_token.as_ref())
363 .build()
364 }
365
366 pub fn new_env<K: AsRef<OsStr>>(env_var: K) -> Result<Client, VarError> {
369 Ok(GraphClientConfiguration::new()
370 .access_token(std::env::var(env_var)?)
371 .build())
372 }
373
374 pub fn builder() -> GraphClientConfiguration {
375 GraphClientConfiguration::new()
376 }
377
378 pub fn headers(&self) -> &HeaderMap {
379 &self.headers
380 }
381
382 pub fn with_force_token_refresh(&mut self, force_token_refresh: ForceTokenRefresh) {
383 self.client_application
384 .with_force_token_refresh(force_token_refresh);
385 }
386}
387
388impl Default for Client {
389 fn default() -> Self {
390 GraphClientConfiguration::new().build()
391 }
392}
393
394impl Debug for Client {
395 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
396 f.debug_struct("Client")
397 .field("inner", &self.inner)
398 .field("headers", &self.headers)
399 .field("builder", &self.builder)
400 .finish()
401 }
402}
403
404impl From<GraphClientConfiguration> for Client {
405 fn from(value: GraphClientConfiguration) -> Self {
406 value.build()
407 }
408}
409
#[cfg(test)]
mod test {
    use super::*;

    /// The compile-time `User-Agent` (if any) must end up in the client's headers.
    #[test]
    fn compile_time_user_agent_header() {
        let client = GraphClientConfiguration::new()
            .access_token("access_token")
            .build();

        // The header exists only when GRAPH_CLIENT_USER_AGENT was set in the build
        // environment. Compare against that value instead of assuming it is present,
        // so the test does not fail spuriously on builds without the variable.
        let expected = user_agent_header_from_env();
        assert_eq!(
            client.builder.config.headers.get(USER_AGENT),
            expected.as_ref()
        );
        assert_eq!(client.headers().get(USER_AGENT), expected.as_ref());
    }

    /// An explicit `user_agent` call overrides any compile-time value.
    #[test]
    fn update_user_agent_header() {
        let client = GraphClientConfiguration::new()
            .access_token("access_token")
            .user_agent(HeaderValue::from_static("user_agent"))
            .build();

        assert!(client.builder.config.headers.contains_key(USER_AGENT));
        let user_agent_header = client.builder.config.headers.get(USER_AGENT).unwrap();
        assert_eq!("user_agent", user_agent_header.to_str().unwrap());
    }

    /// `default_headers` replaces the built-in `Accept: */*` default.
    #[test]
    fn default_headers_override_existing_values() {
        let mut headers = HeaderMap::new();
        headers.insert(ACCEPT, HeaderValue::from_static("application/json"));

        let client = GraphClientConfiguration::new()
            .access_token("access_token")
            .default_headers(headers)
            .build();

        let accept = client.headers().get(ACCEPT).unwrap();
        assert_eq!("application/json", accept.to_str().unwrap());
        assert_eq!(1, client.headers().get_all(ACCEPT).iter().count());
    }
}