1#![warn(missing_docs)]
2
3mod error;
54mod event;
55mod program;
56mod report;
57mod resource;
58mod target;
59mod timeline;
60mod ven;
61
62use axum::async_trait;
63use openleadr_wire::{event::EventId, Event, Ven};
64use std::{
65 fmt::Debug,
66 future::Future,
67 sync::Arc,
68 time::{Duration, Instant},
69};
70use tokio::sync::RwLock;
71
72use reqwest::{Method, RequestBuilder, Response};
73use url::Url;
74
75pub use error::*;
76pub use event::*;
77pub use program::*;
78pub use report::*;
79pub use resource::*;
80pub use target::*;
81pub use timeline::*;
82pub use ven::*;
83
84use crate::error::Result;
85use openleadr_wire::ven::{VenContent, VenId};
86pub(crate) use openleadr_wire::{
87 event::EventContent,
88 program::{ProgramContent, ProgramId},
89 target::TargetType,
90 Program,
91};
92
93#[async_trait]
94pub trait HttpClient: Debug {
99 #[allow(missing_docs)]
100 fn request_builder(&self, method: Method, url: Url) -> RequestBuilder;
101 #[allow(missing_docs)]
102 async fn send(&self, req: RequestBuilder) -> reqwest::Result<Response>;
103}
104
105#[derive(Debug, Clone)]
112pub struct Client {
113 client_ref: Arc<ClientRef>,
114}
115
/// Credentials for the OAuth 2.0 client-credentials grant used to obtain
/// bearer tokens.
pub struct ClientCredentials {
    /// OAuth client identifier.
    pub client_id: String,
    // Not public, and intentionally left out of the `Debug` output.
    client_secret: String,
    /// How long before a cached token's expiry a fresh token is requested.
    pub refresh_margin: Duration,
    /// Token lifetime assumed when the token response carries no `expires_in`.
    pub default_credential_expires_in: Duration,
}
132
133impl Debug for ClientCredentials {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 f.debug_struct(std::any::type_name::<Self>())
136 .field("client_id", &self.client_id)
137 .field("refresh_margin", &self.refresh_margin)
138 .field(
139 "default_credential_expires_in",
140 &self.default_credential_expires_in,
141 )
142 .finish_non_exhaustive()
143 }
144}
145
146impl ClientCredentials {
147 pub fn new(client_id: String, client_secret: String) -> Self {
152 Self {
153 client_id,
154 client_secret,
155 refresh_margin: Duration::from_secs(60),
156 default_credential_expires_in: Duration::from_secs(3600),
157 }
158 }
159}
160
/// A cached bearer token plus the timing data needed to decide when it must be
/// refreshed.
struct AuthToken {
    /// Raw access token, sent as the bearer credential on VTN requests.
    token: String,
    /// Token lifetime reported by the server, or the configured fallback.
    expires_in: Duration,
    /// Instant captured just before the token request was sent.
    since: Instant,
}
166
167impl Debug for AuthToken {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 f.debug_struct(std::any::type_name::<Self>())
170 .field("expires_in", &self.expires_in)
171 .field("since", &self.since)
172 .finish_non_exhaustive()
173 }
174}
175
176#[derive(Debug)]
177struct ClientRef {
178 client: Box<dyn HttpClient + Send + Sync>,
179 vtn_base_url: Url,
180 oauth_base_url: Url,
181 default_page_size: usize,
182 auth_data: Option<ClientCredentials>,
183 auth_token: RwLock<Option<AuthToken>>,
184}
185
186impl ClientRef {
187 async fn ensure_auth(&self) -> Result<()> {
193 let Some(auth_data) = &self.auth_data else {
195 return Ok(());
196 };
197
198 if let Some(token) = self.auth_token.read().await.as_ref() {
200 if token.since.elapsed() < token.expires_in - auth_data.refresh_margin {
201 return Ok(());
202 }
203 }
204
205 #[derive(serde::Serialize)]
206 struct AccessTokenRequest {
207 grant_type: &'static str,
208 #[serde(skip_serializing_if = "Option::is_none")]
209 scope: Option<String>,
210 #[serde(skip_serializing_if = "Option::is_none")]
211 client_id: Option<String>,
212 #[serde(skip_serializing_if = "Option::is_none")]
213 client_secret: Option<String>,
214 }
215
216 let request = self
218 .client
219 .request_builder(Method::POST, self.oauth_base_url.clone())
220 .form(&AccessTokenRequest {
221 grant_type: "client_credentials",
222 scope: None,
223 client_id: None,
224 client_secret: None,
225 });
226 let request = request.basic_auth(&auth_data.client_id, Some(&auth_data.client_secret));
227 let request = request.header("Accept", "application/json");
228 let since = Instant::now();
229 let res = self.client.send(request).await?;
230 if !res.status().is_success() {
231 let problem = res.json::<openleadr_wire::oauth::OAuthError>().await?;
232 return Err(Error::AuthProblem(problem));
233 }
234
235 #[derive(Debug, serde::Deserialize)]
236 struct AuthResult {
237 access_token: String,
238 token_type: String,
239 #[serde(default)]
240 expires_in: Option<u64>,
241 }
249
250 let auth_result = res.json::<AuthResult>().await?;
251 if auth_result.token_type.to_lowercase() != "bearer" {
252 return Err(Error::OAuthTokenNotBearer);
253 }
254 let token = AuthToken {
255 token: auth_result.access_token,
256 expires_in: auth_result
257 .expires_in
258 .map(Duration::from_secs)
259 .unwrap_or(auth_data.default_credential_expires_in),
260 since,
261 };
262
263 *self.auth_token.write().await = Some(token);
264 Ok(())
265 }
266
267 async fn request<T: serde::de::DeserializeOwned>(
268 &self,
269 mut request: RequestBuilder,
270 query: &[(&str, &str)],
271 ) -> Result<T> {
272 self.ensure_auth().await?;
273 request = request.header("Accept", "application/json");
274 if !query.is_empty() {
275 request = request.query(&query);
276 }
277
278 {
280 let token = self.auth_token.read().await;
281 if let Some(token) = token.as_ref() {
282 request = request.bearer_auth(&token.token);
283 }
284 }
285 let res = self.client.send(request).await?;
286
287 if !res.status().is_success() {
289 let problem = res.json::<openleadr_wire::problem::Problem>().await?;
290 return Err(crate::error::Error::from(problem));
291 }
292
293 Ok(res.json().await?)
294 }
295
296 async fn get<T: serde::de::DeserializeOwned>(
297 &self,
298 path: &str,
299 query: &[(&str, &str)],
300 ) -> Result<T> {
301 let url = self.vtn_base_url.join(path)?;
302 let request = self.client.request_builder(Method::GET, url);
303 self.request(request, query).await
304 }
305
306 async fn post<S, T>(&self, path: &str, body: &S) -> Result<T>
307 where
308 S: serde::ser::Serialize + Sync,
309 T: serde::de::DeserializeOwned,
310 {
311 let url = self.vtn_base_url.join(path)?;
312 let request = self.client.request_builder(Method::POST, url).json(body);
313 self.request(request, &[]).await
314 }
315
316 async fn put<S, T>(&self, path: &str, body: &S) -> Result<T>
317 where
318 S: serde::ser::Serialize + Sync,
319 T: serde::de::DeserializeOwned,
320 {
321 let url = self.vtn_base_url.join(path)?;
322 let request = self.client.request_builder(Method::PUT, url).json(body);
323 self.request(request, &[]).await
324 }
325
326 async fn delete<T>(&self, path: &str) -> Result<T>
327 where
328 T: serde::de::DeserializeOwned,
329 {
330 let url = self.vtn_base_url.join(path)?;
331 let request = self.client.request_builder(Method::DELETE, url);
332 self.request(request, &[]).await
333 }
334
335 fn default_page_size(&self) -> usize {
336 self.default_page_size
337 }
338
339 async fn iterate_pages<T, Fut>(
340 &self,
341 single_page_req: impl Fn(usize, usize) -> Fut,
342 ) -> Result<Vec<T>>
343 where
344 Fut: Future<Output = Result<Vec<T>>>,
345 {
346 let page_size = self.default_page_size();
347 let mut items = vec![];
348 let mut page = 0;
349 loop {
351 let received = single_page_req(page * page_size, page_size).await?;
352 let received_all = received.len() < page_size;
353 for item in received {
354 items.push(item);
355 }
356
357 if received_all {
358 break;
359 } else {
360 page += 1;
361 }
362 }
363
364 Ok(items)
365 }
366}
367
368#[derive(Debug)]
369struct ReqwestClientRef {
370 client: reqwest::Client,
371}
372
373#[async_trait]
374impl HttpClient for ReqwestClientRef {
375 fn request_builder(&self, method: Method, url: Url) -> RequestBuilder {
376 self.client.request(method, url)
377 }
378
379 async fn send(&self, req: RequestBuilder) -> std::result::Result<Response, reqwest::Error> {
380 req.send().await
381 }
382}
383
/// Offset/limit pagination for the single-page listing methods.
///
/// The values are sent to the VTN as the `skip` and `limit` query parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PaginationOptions {
    /// Number of items to skip before the first returned item.
    pub skip: usize,
    /// Maximum number of items requested in one page.
    pub limit: usize,
}
395
396#[derive(Debug, Clone)]
410pub enum Filter<'a, S: AsRef<str>> {
411 None,
413 By(TargetType, &'a [S]),
418}
419
420impl<'a> Filter<'a, &'static str> {
421 pub const fn none() -> Filter<'a, &'static str> {
423 Filter::None
424 }
425}
426
427impl<'a, S: AsRef<str>> Filter<'a, S> {
428 pub(crate) fn to_query_params(&'a self) -> Vec<(&'a str, &'a str)> {
429 let mut query = vec![];
430 if let Filter::By(ref target_label, target_values) = self {
431 query.push(("targetType", target_label.as_str()));
432
433 for target_value in *target_values {
434 query.push(("targetValues", target_value.as_ref()));
435 }
436 }
437 query
438 }
439}
440
441impl Client {
442 pub fn with_url(base_url: Url, auth: Option<ClientCredentials>) -> Self {
448 let client = reqwest::Client::new();
449 let oauth_base_url = base_url.join("/auth/token").unwrap();
450 Self::with_details(base_url, oauth_base_url, client, auth)
451 }
452
453 pub fn with_details(
461 vtn_base_url: Url,
462 oauth_base_url: Url,
463 client: reqwest::Client,
464 auth: Option<ClientCredentials>,
465 ) -> Self {
466 Self::with_http_client(
467 vtn_base_url,
468 oauth_base_url,
469 Box::new(ReqwestClientRef { client }),
470 auth,
471 )
472 }
473
474 pub fn with_http_client(
480 vtn_base_url: Url,
481 oauth_base_url: Url,
482 client: Box<dyn HttpClient + Send + Sync>,
483 auth: Option<ClientCredentials>,
484 ) -> Self {
485 let client_ref = ClientRef {
486 client,
487 vtn_base_url,
488 oauth_base_url,
489 default_page_size: 50,
490 auth_data: auth,
491 auth_token: RwLock::new(None),
492 };
493 Self::new(client_ref)
494 }
495
496 fn new(client_ref: ClientRef) -> Self {
497 Client {
498 client_ref: Arc::new(client_ref),
499 }
500 }
501
502 pub async fn create_program(&self, program_content: ProgramContent) -> Result<ProgramClient> {
504 let program = self.client_ref.post("programs", &program_content).await?;
505 Ok(ProgramClient::from_program(self.clone(), program))
506 }
507
508 pub async fn get_programs(
510 &self,
511 filter: Filter<'_, impl AsRef<str>>,
512 pagination: PaginationOptions,
513 ) -> Result<Vec<ProgramClient>> {
514 let skip_str = pagination.skip.to_string();
516 let limit_str = pagination.limit.to_string();
517 let mut query: Vec<(&str, &str)> = vec![("skip", &skip_str), ("limit", &limit_str)];
519
520 query.extend_from_slice(filter.to_query_params().as_slice());
521
522 let programs: Vec<Program> = self.client_ref.get("programs", &query).await?;
524 Ok(programs
525 .into_iter()
526 .map(|program| ProgramClient::from_program(self.clone(), program))
527 .collect())
528 }
529
530 pub async fn get_program_list(
534 &self,
535 filter: Filter<'_, impl AsRef<str> + Clone>,
536 ) -> Result<Vec<ProgramClient>> {
537 self.client_ref
538 .iterate_pages(|skip, limit| {
539 self.get_programs(filter.clone(), PaginationOptions { skip, limit })
540 })
541 .await
542 }
543
544 pub async fn get_program_by_id(&self, id: &ProgramId) -> Result<ProgramClient> {
546 let program = self
547 .client_ref
548 .get(&format!("programs/{}", id.as_str()), &[])
549 .await?;
550
551 Ok(ProgramClient::from_program(self.clone(), program))
552 }
553
554 pub async fn get_events(
558 &self,
559 program_id: Option<&ProgramId>,
560 filter: Filter<'_, impl AsRef<str>>,
561 pagination: PaginationOptions,
562 ) -> Result<Vec<EventClient>> {
563 let skip_str = pagination.skip.to_string();
565 let limit_str = pagination.limit.to_string();
566 let mut query: Vec<(&str, &str)> = vec![("skip", &skip_str), ("limit", &limit_str)];
568
569 query.extend_from_slice(filter.to_query_params().as_slice());
570
571 if let Some(program_id) = program_id {
572 query.push(("programID", program_id.as_str()));
573 }
574
575 let events: Vec<Event> = self.client_ref.get("events", &query).await?;
577 Ok(events
578 .into_iter()
579 .map(|event| EventClient::from_event(self.client_ref.clone(), event))
580 .collect())
581 }
582
583 pub async fn get_event_list(
587 &self,
588 program_id: Option<&ProgramId>,
589 filter: Filter<'_, impl AsRef<str> + Clone>,
590 ) -> Result<Vec<EventClient>> {
591 self.client_ref
592 .iterate_pages(|skip, limit| {
593 self.get_events(
594 program_id,
595 filter.clone(),
596 PaginationOptions { skip, limit },
597 )
598 })
599 .await
600 }
601
602 pub async fn get_event_by_id(&self, id: &EventId) -> Result<EventClient> {
604 let event = self
605 .client_ref
606 .get(&format!("events/{}", id.as_str()), &[])
607 .await?;
608
609 Ok(EventClient::from_event(self.client_ref.clone(), event))
610 }
611
612 pub async fn create_ven(&self, ven: VenContent) -> Result<VenClient> {
614 let ven = self.client_ref.post("vens", &ven).await?;
615 Ok(VenClient::from_ven(self.client_ref.clone(), ven))
616 }
617
618 async fn get_vens(
619 &self,
620 skip: usize,
621 limit: usize,
622 filter: Filter<'_, impl AsRef<str>>,
623 ) -> Result<Vec<VenClient>> {
624 let skip_str = skip.to_string();
625 let limit_str = limit.to_string();
626 let mut query: Vec<(&str, &str)> = vec![("skip", &skip_str), ("limit", &limit_str)];
627
628 query.extend_from_slice(filter.to_query_params().as_slice());
629
630 let vens: Vec<Ven> = self.client_ref.get("vens", &query).await?;
632 Ok(vens
633 .into_iter()
634 .map(|ven| VenClient::from_ven(self.client_ref.clone(), ven))
635 .collect())
636 }
637
638 pub async fn get_ven_list(
642 &self,
643 filter: Filter<'_, impl AsRef<str> + Clone>,
644 ) -> Result<Vec<VenClient>> {
645 self.client_ref
646 .iterate_pages(|skip, limit| self.get_vens(skip, limit, filter.clone()))
647 .await
648 }
649
650 pub async fn get_ven_by_id(&self, id: &VenId) -> Result<VenClient> {
652 let ven = self
653 .client_ref
654 .get(&format!("vens/{}", id.as_str()), &[])
655 .await?;
656 Ok(VenClient::from_ven(self.client_ref.clone(), ven))
657 }
658
659 pub async fn get_ven_by_name(&self, name: &str) -> Result<VenClient> {
662 let mut vens: Vec<Ven> = self.client_ref.get("vens", &[("venName", name)]).await?;
663 match vens[..] {
664 [] => Err(Error::ObjectNotFound),
665 [_] => Ok(VenClient::from_ven(self.client_ref.clone(), vens.remove(0))),
666 [..] => Err(Error::DuplicateObject),
667 }
668 }
669}