1#![warn(missing_docs)]
2#![deny(rustdoc::broken_intra_doc_links)]
3#![deny(rustdoc::private_intra_doc_links)]
4
5mod error;
56mod event;
57#[cfg(feature = "mdns")]
58mod mdns;
59mod program;
60mod report;
61mod resource;
62mod timeline;
63mod ven;
64
65use async_trait::async_trait;
66use openleadr_wire::{event::EventId, Event, Ven};
67use std::{
68 fmt::Debug,
69 future::Future,
70 marker::PhantomData,
71 sync::Arc,
72 time::{Duration, Instant},
73};
74use tokio::sync::RwLock;
75
76use reqwest::{Method, RequestBuilder, Response};
77use url::Url;
78
79pub use error::*;
80pub use event::*;
81#[cfg(feature = "mdns")]
82pub use mdns::*;
83pub use program::*;
84pub use report::*;
85pub use resource::*;
86pub use timeline::*;
87pub use ven::*;
88
89use crate::error::Result;
90use openleadr_wire::{
91 program::{ProgramId, ProgramRequest},
92 ven::{BlVenRequest, VenId, VenRequest, VenVenRequest},
93 Program,
94};
95
96#[async_trait]
97pub trait HttpClient: Debug {
102 #[allow(missing_docs)]
103 fn request_builder(&self, method: Method, url: Url) -> RequestBuilder;
104 #[allow(missing_docs)]
105 async fn send(&self, req: RequestBuilder) -> reqwest::Result<Response>;
106}
107
108#[derive(Debug, Clone)]
115pub struct Client<K> {
116 client_ref: Arc<ClientRef<K>>,
117}
118
/// Credentials for the OAuth2 client-credentials flow.
///
/// The secret is kept private and is not part of the `Debug` output.
pub struct ClientCredentials {
    /// Identifier of the OAuth client, sent as the basic-auth user name.
    pub client_id: String,
    client_secret: String,
    /// Safety margin before token expiry.
    ///
    /// A cached token is reused only while its age is below its lifetime minus this margin;
    /// after that a new token is requested.
    pub refresh_margin: Duration,
    /// Token lifetime assumed when the token response carries no `expires_in` value.
    pub default_credential_expires_in: Duration,
}
135
136impl Debug for ClientCredentials {
137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138 f.debug_struct(std::any::type_name::<Self>())
139 .field("client_id", &self.client_id)
140 .field("refresh_margin", &self.refresh_margin)
141 .field(
142 "default_credential_expires_in",
143 &self.default_credential_expires_in,
144 )
145 .finish_non_exhaustive()
146 }
147}
148
149impl ClientCredentials {
150 pub fn new(client_id: String, client_secret: String) -> Self {
155 Self {
156 client_id,
157 client_secret,
158 refresh_margin: Duration::from_secs(60),
159 default_credential_expires_in: Duration::from_secs(3600),
160 }
161 }
162}
163
/// An access token plus the bookkeeping needed to decide when to refresh it.
struct AuthToken {
    /// Token value, attached to requests as a bearer token.
    token: String,
    /// Token lifetime reported by the server, or the configured default.
    expires_in: Duration,
    /// Moment right before the token request was sent; the token's age is measured from here.
    since: Instant,
}
169
170impl Debug for AuthToken {
171 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 f.debug_struct(std::any::type_name::<Self>())
173 .field("expires_in", &self.expires_in)
174 .field("since", &self.since)
175 .finish_non_exhaustive()
176 }
177}
178
/// Marker trait for the kinds of client that can talk to a VTN.
///
/// It is used as the type parameter of `Client` to select kind-specific operations.
pub trait ClientKind: Copy {}

/// Client kind for business-logic parties; its `create_ven` takes a `BlVenRequest`.
#[derive(Copy, Clone, Debug)]
pub struct BusinessLogic;

impl ClientKind for BusinessLogic {}

/// Client kind for virtual end nodes; its `create_ven` takes a `VenVenRequest`.
#[derive(Copy, Clone, Debug)]
pub struct VirtualEndNode;

impl ClientKind for VirtualEndNode {}
190
191#[derive(Debug)]
192struct ClientRef<K> {
193 client: Box<dyn HttpClient + Send + Sync>,
194 vtn_base_url: Url,
195 oauth_base_url: Url,
196 default_page_size: usize,
197 auth_data: Option<ClientCredentials>,
198 auth_token: RwLock<Option<AuthToken>>,
199 phantom: PhantomData<K>,
200}
201
202impl<K: ClientKind> ClientRef<K> {
203 async fn ensure_auth(&self) -> Result<()> {
209 let Some(auth_data) = &self.auth_data else {
211 return Ok(());
212 };
213
214 if let Some(token) = self.auth_token.read().await.as_ref() {
216 if token.since.elapsed() < token.expires_in - auth_data.refresh_margin {
217 return Ok(());
218 }
219 }
220
221 #[derive(serde::Serialize)]
222 struct AccessTokenRequest {
223 grant_type: &'static str,
224 #[serde(skip_serializing_if = "Option::is_none")]
225 scope: Option<String>,
226 #[serde(skip_serializing_if = "Option::is_none")]
227 client_id: Option<String>,
228 #[serde(skip_serializing_if = "Option::is_none")]
229 client_secret: Option<String>,
230 }
231
232 let request = self
234 .client
235 .request_builder(Method::POST, self.oauth_base_url.clone())
236 .form(&AccessTokenRequest {
237 grant_type: "client_credentials",
238 scope: None,
239 client_id: None,
240 client_secret: None,
241 });
242 let request = request.basic_auth(&auth_data.client_id, Some(&auth_data.client_secret));
243 let request = request.header("Accept", "application/json");
244 let since = Instant::now();
245 let res = self.client.send(request).await?;
246 if !res.status().is_success() {
247 let problem = res.json::<openleadr_wire::oauth::OAuthError>().await?;
248 return Err(Error::AuthProblem(problem));
249 }
250
251 #[derive(Debug, serde::Deserialize)]
252 struct AuthResult {
253 access_token: String,
254 token_type: String,
255 #[serde(default)]
256 expires_in: Option<u64>,
257 }
265
266 let auth_result = res.json::<AuthResult>().await?;
267 if auth_result.token_type.to_lowercase() != "bearer" {
268 return Err(Error::OAuthTokenNotBearer);
269 }
270 let token = AuthToken {
271 token: auth_result.access_token,
272 expires_in: auth_result
273 .expires_in
274 .map(Duration::from_secs)
275 .unwrap_or(auth_data.default_credential_expires_in),
276 since,
277 };
278
279 *self.auth_token.write().await = Some(token);
280 Ok(())
281 }
282
283 async fn request<T: serde::de::DeserializeOwned>(
284 &self,
285 mut request: RequestBuilder,
286 query: &[(&str, &str)],
287 ) -> Result<T> {
288 self.ensure_auth().await?;
289 request = request.header("Accept", "application/json");
290 if !query.is_empty() {
291 request = request.query(&query);
292 }
293
294 {
296 let token = self.auth_token.read().await;
297 if let Some(token) = token.as_ref() {
298 request = request.bearer_auth(&token.token);
299 }
300 }
301 let res = self.client.send(request).await?;
302
303 if !res.status().is_success() {
305 let problem = res.json::<openleadr_wire::problem::Problem>().await?;
306 return Err(crate::error::Error::from(problem));
307 }
308
309 Ok(res.json().await?)
310 }
311
312 async fn get<T: serde::de::DeserializeOwned>(
313 &self,
314 path: &str,
315 query: &[(&str, &str)],
316 ) -> Result<T> {
317 let url = self.vtn_base_url.join(path)?;
318 let request = self.client.request_builder(Method::GET, url);
319 self.request(request, query).await
320 }
321
322 async fn post<S, T>(&self, path: &str, body: &S) -> Result<T>
323 where
324 S: serde::ser::Serialize + Sync,
325 T: serde::de::DeserializeOwned,
326 {
327 let url = self.vtn_base_url.join(path)?;
328 let request = self.client.request_builder(Method::POST, url).json(body);
329 self.request(request, &[]).await
330 }
331
332 async fn put<S, T>(&self, path: &str, body: &S) -> Result<T>
333 where
334 S: serde::ser::Serialize + Sync,
335 T: serde::de::DeserializeOwned,
336 {
337 let url = self.vtn_base_url.join(path)?;
338 let request = self.client.request_builder(Method::PUT, url).json(body);
339 self.request(request, &[]).await
340 }
341
342 async fn delete<T>(&self, path: &str) -> Result<T>
343 where
344 T: serde::de::DeserializeOwned,
345 {
346 let url = self.vtn_base_url.join(path)?;
347 let request = self.client.request_builder(Method::DELETE, url);
348 self.request(request, &[]).await
349 }
350
351 fn default_page_size(&self) -> usize {
352 self.default_page_size
353 }
354
355 async fn iterate_pages<T, Fut>(
356 &self,
357 single_page_req: impl Fn(usize, usize) -> Fut,
358 ) -> Result<Vec<T>>
359 where
360 Fut: Future<Output = Result<Vec<T>>>,
361 {
362 let page_size = self.default_page_size();
363 let mut items = vec![];
364 let mut page = 0;
365 loop {
367 let received = single_page_req(page * page_size, page_size).await?;
368 let received_all = received.len() < page_size;
369 for item in received {
370 items.push(item);
371 }
372
373 if received_all {
374 break;
375 } else {
376 page += 1;
377 }
378 }
379
380 Ok(items)
381 }
382}
383
384#[derive(Debug)]
385struct ReqwestClientRef {
386 client: reqwest::Client,
387}
388
389#[async_trait]
390impl HttpClient for ReqwestClientRef {
391 fn request_builder(&self, method: Method, url: Url) -> RequestBuilder {
392 self.client.request(method, url)
393 }
394
395 async fn send(&self, req: RequestBuilder) -> std::result::Result<Response, reqwest::Error> {
396 req.send().await
397 }
398}
399
/// Pagination parameters for list requests.
///
/// Sent to the VTN as the `skip` and `limit` query parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PaginationOptions {
    /// Value of the `skip` query parameter: number of items to skip.
    pub skip: usize,
    /// Value of the `limit` query parameter: maximum number of items requested.
    pub limit: usize,
}
411
/// Target filter for list requests.
///
/// `Filter::By` is sent to the VTN as one `targets` query parameter per value;
/// `Filter::None` adds no parameters.
#[derive(Clone, Debug)]
pub enum Filter<'a, S: AsRef<str>> {
    /// Do not filter by target.
    None,
    /// Filter on the given target values.
    By(&'a [S]),
}
431
432impl<'a> Filter<'a, &'static str> {
433 pub const fn none() -> Filter<'a, &'static str> {
435 Filter::None
436 }
437}
438
439impl<'a, S: AsRef<str>> Filter<'a, S> {
440 pub(crate) fn to_query_params(&'a self) -> Vec<(&'a str, &'a str)> {
441 let mut query = vec![];
442 if let Filter::By(target_values) = self {
443 for target_value in *target_values {
444 query.push(("targets", target_value.as_ref()));
445 }
446 }
447 query
448 }
449}
450
451impl Client<VirtualEndNode> {
452 pub async fn create_ven(&self, ven: VenVenRequest) -> Result<VenClient<VirtualEndNode>> {
454 let ven = self
455 .client_ref
456 .post("vens", &VenRequest::VenVenRequest(ven))
457 .await?;
458 Ok(VenClient::from_ven(self.client_ref.clone(), ven))
459 }
460}
461
462impl Client<BusinessLogic> {
463 pub async fn create_ven(&self, ven: BlVenRequest) -> Result<VenClient<BusinessLogic>> {
465 let ven = self
466 .client_ref
467 .post("vens", &VenRequest::BlVenRequest(ven))
468 .await?;
469 Ok(VenClient::from_ven(self.client_ref.clone(), ven))
470 }
471}
472
473impl<K: ClientKind> Client<K> {
474 pub fn with_url(base_url: Url, auth: Option<ClientCredentials>) -> Self {
480 let client = reqwest::Client::new();
481 let oauth_base_url = base_url.join("/auth/token").unwrap();
482 Self::with_details(base_url, oauth_base_url, client, auth)
483 }
484
485 pub fn with_details(
493 vtn_base_url: Url,
494 oauth_base_url: Url,
495 client: reqwest::Client,
496 auth: Option<ClientCredentials>,
497 ) -> Self {
498 Self::with_http_client(
499 vtn_base_url,
500 oauth_base_url,
501 Box::new(ReqwestClientRef { client }),
502 auth,
503 )
504 }
505
506 pub fn with_http_client(
512 vtn_base_url: Url,
513 oauth_base_url: Url,
514 client: Box<dyn HttpClient + Send + Sync>,
515 auth: Option<ClientCredentials>,
516 ) -> Self {
517 let client_ref = ClientRef {
518 client,
519 vtn_base_url,
520 oauth_base_url,
521 default_page_size: 50,
522 auth_data: auth,
523 auth_token: RwLock::new(None),
524 phantom: PhantomData::<K>,
525 };
526 Self::new(client_ref)
527 }
528
529 fn new(client_ref: ClientRef<K>) -> Self {
530 Client {
531 client_ref: Arc::new(client_ref),
532 }
533 }
534
535 pub async fn create_program(
537 &self,
538 program_content: ProgramRequest,
539 ) -> Result<ProgramClient<K>> {
540 let program = self.client_ref.post("programs", &program_content).await?;
541 Ok(ProgramClient::from_program(self.clone(), program))
542 }
543
544 pub async fn get_programs(
546 &self,
547 filter: Filter<'_, impl AsRef<str>>,
548 pagination: PaginationOptions,
549 ) -> Result<Vec<ProgramClient<K>>> {
550 let skip_str = pagination.skip.to_string();
552 let limit_str = pagination.limit.to_string();
553 let mut query: Vec<(&str, &str)> = vec![("skip", &skip_str), ("limit", &limit_str)];
555
556 query.extend_from_slice(filter.to_query_params().as_slice());
557
558 let programs: Vec<Program> = self.client_ref.get("programs", &query).await?;
560 Ok(programs
561 .into_iter()
562 .map(|program| ProgramClient::from_program(self.clone(), program))
563 .collect())
564 }
565
566 pub async fn get_program_list(
570 &self,
571 filter: Filter<'_, impl AsRef<str> + Clone>,
572 ) -> Result<Vec<ProgramClient<K>>> {
573 self.client_ref
574 .iterate_pages(|skip, limit| {
575 self.get_programs(filter.clone(), PaginationOptions { skip, limit })
576 })
577 .await
578 }
579
580 pub async fn get_program_by_id(&self, id: &ProgramId) -> Result<ProgramClient<K>> {
582 let program = self
583 .client_ref
584 .get(&format!("programs/{}", id.as_str()), &[])
585 .await?;
586
587 Ok(ProgramClient::from_program(self.clone(), program))
588 }
589
590 pub async fn get_events(
594 &self,
595 program_id: Option<&ProgramId>,
596 filter: Filter<'_, impl AsRef<str>>,
597 pagination: PaginationOptions,
598 ) -> Result<Vec<EventClient<K>>> {
599 let skip_str = pagination.skip.to_string();
601 let limit_str = pagination.limit.to_string();
602 let mut query: Vec<(&str, &str)> = vec![("skip", &skip_str), ("limit", &limit_str)];
604
605 query.extend_from_slice(filter.to_query_params().as_slice());
606
607 if let Some(program_id) = program_id {
608 query.push(("programID", program_id.as_str()));
609 }
610
611 let events: Vec<Event> = self.client_ref.get("events", &query).await?;
613 Ok(events
614 .into_iter()
615 .map(|event| EventClient::from_event(self.client_ref.clone(), event))
616 .collect())
617 }
618
619 pub async fn get_event_list(
623 &self,
624 program_id: Option<&ProgramId>,
625 filter: Filter<'_, impl AsRef<str> + Clone>,
626 ) -> Result<Vec<EventClient<K>>> {
627 self.client_ref
628 .iterate_pages(|skip, limit| {
629 self.get_events(
630 program_id,
631 filter.clone(),
632 PaginationOptions { skip, limit },
633 )
634 })
635 .await
636 }
637
638 pub async fn get_event_by_id(&self, id: &EventId) -> Result<EventClient<K>> {
640 let event = self
641 .client_ref
642 .get(&format!("events/{}", id.as_str()), &[])
643 .await?;
644
645 Ok(EventClient::from_event(self.client_ref.clone(), event))
646 }
647
648 async fn get_vens(
649 &self,
650 skip: usize,
651 limit: usize,
652 filter: Filter<'_, impl AsRef<str>>,
653 ) -> Result<Vec<VenClient<K>>> {
654 let skip_str = skip.to_string();
655 let limit_str = limit.to_string();
656 let mut query: Vec<(&str, &str)> = vec![("skip", &skip_str), ("limit", &limit_str)];
657
658 query.extend_from_slice(filter.to_query_params().as_slice());
659
660 let vens: Vec<Ven> = self.client_ref.get("vens", &query).await?;
662 Ok(vens
663 .into_iter()
664 .map(|ven| VenClient::from_ven(self.client_ref.clone(), ven))
665 .collect())
666 }
667
668 pub async fn get_ven_list(
672 &self,
673 filter: Filter<'_, impl AsRef<str> + Clone>,
674 ) -> Result<Vec<VenClient<K>>> {
675 self.client_ref
676 .iterate_pages(|skip, limit| self.get_vens(skip, limit, filter.clone()))
677 .await
678 }
679
680 pub async fn get_ven_by_id(&self, id: &VenId) -> Result<VenClient<K>> {
682 let ven = self
683 .client_ref
684 .get(&format!("vens/{}", id.as_str()), &[])
685 .await?;
686 Ok(VenClient::from_ven(self.client_ref.clone(), ven))
687 }
688
689 pub async fn get_ven_by_name(&self, name: &str) -> Result<VenClient<K>> {
692 let mut vens: Vec<Ven> = self.client_ref.get("vens", &[("venName", name)]).await?;
693 match vens[..] {
694 [] => Err(Error::ObjectNotFound),
695 [_] => Ok(VenClient::from_ven(self.client_ref.clone(), vens.remove(0))),
696 [..] => Err(Error::DuplicateObject),
697 }
698 }
699}