Skip to main content

openleadr_client/
lib.rs

1#![warn(missing_docs)]
2#![deny(rustdoc::broken_intra_doc_links)]
3#![deny(rustdoc::private_intra_doc_links)]
4
5//! # OpenADR 3.1 VEN client
6//!
7//! This is a client library to interact with an OpenADR 3.1 compliant VTN server.
8//! It mainly wraps the HTTP REST interface into an easy-to-use Rust API.
9//!
10//! Basic usage
11//! ```no_run
12//! # use openleadr_client::{Client, ClientCredentials, BusinessLogic};
13//! # use openleadr_wire::event::{EventInterval, EventType, EventValuesMap, Priority};
14//! # use openleadr_wire::program::ProgramRequest;
15//! # use openleadr_wire::values_map::Value;
16//! # tokio_test::block_on(async {
17//! let credentials =
18//!     ClientCredentials::new("client_id".to_string(), "client_secret".to_string());
19//! let client = Client::<BusinessLogic>::with_url(
20//!     "https://your-vtn.com".parse().unwrap(),
21//!     Some(credentials),
22//! );
23//! let new_program = ProgramRequest::new("example-program-name".to_string());
24//! let example_program = client.create_program(new_program).await.unwrap();
25//! let mut new_event = example_program.new_event(vec![EventInterval {
26//!     id: 0,
27//!     interval_period: None,
28//!     payloads: vec![EventValuesMap {
29//!         value_type: EventType::Price,
30//!         values: vec![Value::Number(1.23)],
31//!     }],
32//! }]);
33//! new_event.priority = Priority::new(10);
34//! new_event.event_name = Some("Some descriptive name".to_string());
35//! example_program.create_event(new_event).await.unwrap();
36//! # })
37//! ```
38//!
39//! If you want to use a separate OAuth provider that is not built into the VTN server,
40//! you can do so as well.
41//! ```no_run
42//! # use openleadr_client::{Client, ClientCredentials, BusinessLogic};
43//! # let credentials =
44//! #    ClientCredentials::new("client_id".to_string(), "client_secret".to_string());
45//! // optionally, you can build special configuration into the reqwest client here as well
46//! let reqwest_client = reqwest::Client::new();
47//! let client = Client::<BusinessLogic>::with_details(
48//!     "https://your-vtn.com".parse().unwrap(),
49//!     "https://your-oauth-provider.com".parse().unwrap(),
50//!     reqwest_client,
51//!     Some(credentials),
52//! );
53//! ```
54
55mod error;
56mod event;
57#[cfg(feature = "mdns")]
58mod mdns;
59mod program;
60mod report;
61mod resource;
62mod timeline;
63mod ven;
64
65use async_trait::async_trait;
66use openleadr_wire::{event::EventId, Event, Ven};
67use std::{
68    fmt::Debug,
69    future::Future,
70    marker::PhantomData,
71    sync::Arc,
72    time::{Duration, Instant},
73};
74use tokio::sync::RwLock;
75
76use reqwest::{Method, RequestBuilder, Response};
77use url::Url;
78
79pub use error::*;
80pub use event::*;
81#[cfg(feature = "mdns")]
82pub use mdns::*;
83pub use program::*;
84pub use report::*;
85pub use resource::*;
86pub use timeline::*;
87pub use ven::*;
88
89use crate::error::Result;
90use openleadr_wire::{
91    program::{ProgramId, ProgramRequest},
92    ven::{BlVenRequest, VenId, VenRequest, VenVenRequest},
93    Program,
94};
95
96#[async_trait]
97/// Abstracts the implementation used for actual requests.
98///
99/// This is used for testing purposes such that we don't need
100/// to run an actual server instance but instead directly call into the axum router
101pub trait HttpClient: Debug {
102    #[allow(missing_docs)]
103    fn request_builder(&self, method: Method, url: Url) -> RequestBuilder;
104    #[allow(missing_docs)]
105    async fn send(&self, req: RequestBuilder) -> reqwest::Result<Response>;
106}
107
108/// Client for managing top-level entities on a VTN, i.e., programs and VENs.
109///
110/// Can be used to implement both, the VEN and the business logic.
111///
112/// If using the VTN of this project with the built-in OAuth authentication provider,
113/// the [`Client`] also allows managing the users.
114#[derive(Debug, Clone)]
115pub struct Client<K> {
116    client_ref: Arc<ClientRef<K>>,
117}
118
/// Credentials necessary for authentication at the VTN
pub struct ClientCredentials {
    /// Client identifier that is presented together with the secret when requesting a token
    pub client_id: String,
    // Kept private and deliberately left out of the `Debug` output below.
    client_secret: String,
    /// Margin to refresh the authentication token with the client_id and client_secret
    /// before it expires.
    /// This is helpful to prevent an "unauthorized"
    /// due to small differences in client/server times and network latency
    ///
    /// **Default:** 60 sec
    pub refresh_margin: Duration,
    /// Time the authorization token is typically valid for.
    ///
    /// **Default:** 3600 sec, i.e., one hour
    pub default_credential_expires_in: Duration,
}

impl Debug for ClientCredentials {
    /// Formats all fields except the client secret; the output is marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct(std::any::type_name::<Self>());
        builder.field("client_id", &self.client_id);
        builder.field("refresh_margin", &self.refresh_margin);
        builder.field(
            "default_credential_expires_in",
            &self.default_credential_expires_in,
        );
        builder.finish_non_exhaustive()
    }
}

impl ClientCredentials {
    /// Creates new [`ClientCredentials`] with default values for
    /// [`refresh_margin`](ClientCredentials::refresh_margin) and
    /// [`default_credential_expires_in`](ClientCredentials::default_credential_expires_in)
    /// (60 and 3600 sec, respectively)
    pub fn new(client_id: String, client_secret: String) -> Self {
        const DEFAULT_REFRESH_MARGIN: Duration = Duration::from_secs(60);
        const DEFAULT_EXPIRES_IN: Duration = Duration::from_secs(3600);

        Self {
            client_id,
            client_secret,
            refresh_margin: DEFAULT_REFRESH_MARGIN,
            default_credential_expires_in: DEFAULT_EXPIRES_IN,
        }
    }
}
163
/// An access token together with the data needed to judge whether it is still usable.
struct AuthToken {
    // The token value; left out of the `Debug` output below.
    token: String,
    // Validity period of the token, counted from `since`.
    expires_in: Duration,
    // Reference point from which `expires_in` is measured.
    since: Instant,
}

impl Debug for AuthToken {
    /// Formats the timing information only; the output is marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct(std::any::type_name::<Self>());
        builder.field("expires_in", &self.expires_in);
        builder.field("since", &self.since);
        builder.finish_non_exhaustive()
    }
}
178
/// This trait is used to mark the kind of [`Client`] to use,
/// either a [`BusinessLogic`], or a [`VirtualEndNode`]
pub trait ClientKind: Copy {}

/// Used to mark the kind of [`Client`], see also [`ClientKind`]
#[derive(Debug, Clone, Copy)]
pub struct BusinessLogic;

impl ClientKind for BusinessLogic {}

/// Used to mark the kind of [`Client`], see also [`ClientKind`]
#[derive(Debug, Clone, Copy)]
pub struct VirtualEndNode;

impl ClientKind for VirtualEndNode {}
190
191#[derive(Debug)]
192struct ClientRef<K> {
193    client: Box<dyn HttpClient + Send + Sync>,
194    vtn_base_url: Url,
195    oauth_base_url: Url,
196    default_page_size: usize,
197    auth_data: Option<ClientCredentials>,
198    auth_token: RwLock<Option<AuthToken>>,
199    phantom: PhantomData<K>,
200}
201
202impl<K: ClientKind> ClientRef<K> {
203    /// This ensures the client is authenticated.
204    ///
205    /// We follow the process according to RFC 6749, section 4.4 (client
206    /// credentials grant). The client id and secret are by default sent via
207    /// HTTP Basic Auth.
208    async fn ensure_auth(&self) -> Result<()> {
209        // if there is no auth data, we don't do any authentication
210        let Some(auth_data) = &self.auth_data else {
211            return Ok(());
212        };
213
214        // if there is a token, and it is valid long enough, we don't have to do anything
215        if let Some(token) = self.auth_token.read().await.as_ref() {
216            if token.since.elapsed() < token.expires_in - auth_data.refresh_margin {
217                return Ok(());
218            }
219        }
220
221        #[derive(serde::Serialize)]
222        struct AccessTokenRequest {
223            grant_type: &'static str,
224            #[serde(skip_serializing_if = "Option::is_none")]
225            scope: Option<String>,
226            #[serde(skip_serializing_if = "Option::is_none")]
227            client_id: Option<String>,
228            #[serde(skip_serializing_if = "Option::is_none")]
229            client_secret: Option<String>,
230        }
231
232        // we should authenticate
233        let request = self
234            .client
235            .request_builder(Method::POST, self.oauth_base_url.clone())
236            .form(&AccessTokenRequest {
237                grant_type: "client_credentials",
238                scope: None,
239                client_id: None,
240                client_secret: None,
241            });
242        let request = request.basic_auth(&auth_data.client_id, Some(&auth_data.client_secret));
243        let request = request.header("Accept", "application/json");
244        let since = Instant::now();
245        let res = self.client.send(request).await?;
246        if !res.status().is_success() {
247            let problem = res.json::<openleadr_wire::oauth::OAuthError>().await?;
248            return Err(Error::AuthProblem(problem));
249        }
250
251        #[derive(Debug, serde::Deserialize)]
252        struct AuthResult {
253            access_token: String,
254            token_type: String,
255            #[serde(default)]
256            expires_in: Option<u64>,
257            // Refresh tokens aren't supported currently
258            // #[serde(default)]
259            // refresh_token: Option<String>,
260            // #[serde(default)]
261            // scope: Option<String>,
262            // #[serde(flatten)]
263            // other: std::collections::HashMap<String, serde_json::Value>,
264        }
265
266        let auth_result = res.json::<AuthResult>().await?;
267        if auth_result.token_type.to_lowercase() != "bearer" {
268            return Err(Error::OAuthTokenNotBearer);
269        }
270        let token = AuthToken {
271            token: auth_result.access_token,
272            expires_in: auth_result
273                .expires_in
274                .map(Duration::from_secs)
275                .unwrap_or(auth_data.default_credential_expires_in),
276            since,
277        };
278
279        *self.auth_token.write().await = Some(token);
280        Ok(())
281    }
282
283    async fn request<T: serde::de::DeserializeOwned>(
284        &self,
285        mut request: RequestBuilder,
286        query: &[(&str, &str)],
287    ) -> Result<T> {
288        self.ensure_auth().await?;
289        request = request.header("Accept", "application/json");
290        if !query.is_empty() {
291            request = request.query(&query);
292        }
293
294        // read token and insert in request if available
295        {
296            let token = self.auth_token.read().await;
297            if let Some(token) = token.as_ref() {
298                request = request.bearer_auth(&token.token);
299            }
300        }
301        let res = self.client.send(request).await?;
302
303        // handle any errors returned by the server
304        if !res.status().is_success() {
305            let problem = res.json::<openleadr_wire::problem::Problem>().await?;
306            return Err(crate::error::Error::from(problem));
307        }
308
309        Ok(res.json().await?)
310    }
311
312    async fn get<T: serde::de::DeserializeOwned>(
313        &self,
314        path: &str,
315        query: &[(&str, &str)],
316    ) -> Result<T> {
317        let url = self.vtn_base_url.join(path)?;
318        let request = self.client.request_builder(Method::GET, url);
319        self.request(request, query).await
320    }
321
322    async fn post<S, T>(&self, path: &str, body: &S) -> Result<T>
323    where
324        S: serde::ser::Serialize + Sync,
325        T: serde::de::DeserializeOwned,
326    {
327        let url = self.vtn_base_url.join(path)?;
328        let request = self.client.request_builder(Method::POST, url).json(body);
329        self.request(request, &[]).await
330    }
331
332    async fn put<S, T>(&self, path: &str, body: &S) -> Result<T>
333    where
334        S: serde::ser::Serialize + Sync,
335        T: serde::de::DeserializeOwned,
336    {
337        let url = self.vtn_base_url.join(path)?;
338        let request = self.client.request_builder(Method::PUT, url).json(body);
339        self.request(request, &[]).await
340    }
341
342    async fn delete<T>(&self, path: &str) -> Result<T>
343    where
344        T: serde::de::DeserializeOwned,
345    {
346        let url = self.vtn_base_url.join(path)?;
347        let request = self.client.request_builder(Method::DELETE, url);
348        self.request(request, &[]).await
349    }
350
351    fn default_page_size(&self) -> usize {
352        self.default_page_size
353    }
354
355    async fn iterate_pages<T, Fut>(
356        &self,
357        single_page_req: impl Fn(usize, usize) -> Fut,
358    ) -> Result<Vec<T>>
359    where
360        Fut: Future<Output = Result<Vec<T>>>,
361    {
362        let page_size = self.default_page_size();
363        let mut items = vec![];
364        let mut page = 0;
365        // TODO: pagination should depend on that the server indicated there are more results
366        loop {
367            let received = single_page_req(page * page_size, page_size).await?;
368            let received_all = received.len() < page_size;
369            for item in received {
370                items.push(item);
371            }
372
373            if received_all {
374                break;
375            } else {
376                page += 1;
377            }
378        }
379
380        Ok(items)
381    }
382}
383
384#[derive(Debug)]
385struct ReqwestClientRef {
386    client: reqwest::Client,
387}
388
389#[async_trait]
390impl HttpClient for ReqwestClientRef {
391    fn request_builder(&self, method: Method, url: Url) -> RequestBuilder {
392        self.client.request(method, url)
393    }
394
395    async fn send(&self, req: RequestBuilder) -> std::result::Result<Response, reqwest::Error> {
396        req.send().await
397    }
398}
399
/// Allows setting specific `skip` and `limit` values for list queries.
///
/// In most cases, you should not need this functionality
/// but use the `_list` functions
/// that automatically try to iterate through all pages to retrieve all entities
pub struct PaginationOptions {
    /// Value sent as the `skip` query parameter
    pub skip: usize,
    /// Value sent as the `limit` query parameter
    pub limit: usize,
}
411
/// Filter based on Target as specified for various items.
///
/// **Please note:** This does only filter based on what is stored in the `target` field of an item
/// (e.g., [`ProgramRequest::targets`]) and should not get interpreted by the server.
///
/// Unfortunately, the specification is not very clear about this behavior,
/// so some servers might interpret it differently.
/// There has been some discussion with the authors of the standard in
/// <https://github.com/oadr3-org/openadr3-vtn-reference-implementation/issues/83> (sadly not public).
#[derive(Debug, Clone)]
pub enum Filter<'a, S: AsRef<str>> {
    /// Do not apply any filtering
    None,
    /// Filter a list of [`Target`](openleadr_wire::target::Target).
    ///
    /// It will be encoded to the request as query parameters,
    /// e.g., `/programs?targets=group-1&targets=group-2`.
    By(&'a [S]),
}

impl<'a> Filter<'a, &'static str> {
    /// Create a new filter that does not apply any filtering.
    ///
    /// Fixing the string type here spares callers from naming one
    /// when they do not want to filter at all.
    pub const fn none() -> Self {
        Self::None
    }
}

impl<'a, S: AsRef<str>> Filter<'a, S> {
    /// Turns the filter into `targets=<value>` query pairs, one per target value
    /// and in the order given; [`Filter::None`] yields no pairs.
    pub(crate) fn to_query_params(&'a self) -> Vec<(&'a str, &'a str)> {
        match self {
            Filter::None => Vec::new(),
            Filter::By(target_values) => target_values
                .iter()
                .map(|target_value| ("targets", target_value.as_ref()))
                .collect(),
        }
    }
}
450
451impl Client<VirtualEndNode> {
452    /// Create a new VEN entity at the VTN.
453    pub async fn create_ven(&self, ven: VenVenRequest) -> Result<VenClient<VirtualEndNode>> {
454        let ven = self
455            .client_ref
456            .post("vens", &VenRequest::VenVenRequest(ven))
457            .await?;
458        Ok(VenClient::from_ven(self.client_ref.clone(), ven))
459    }
460}
461
462impl Client<BusinessLogic> {
463    /// Create a new VEN entity at the VTN.
464    pub async fn create_ven(&self, ven: BlVenRequest) -> Result<VenClient<BusinessLogic>> {
465        let ven = self
466            .client_ref
467            .post("vens", &VenRequest::BlVenRequest(ven))
468            .await?;
469        Ok(VenClient::from_ven(self.client_ref.clone(), ven))
470    }
471}
472
473impl<K: ClientKind> Client<K> {
474    /// Create a new client for a VTN located at the specified URL.
475    ///
476    /// This assumes that the VTN also works as an OAuth provider
477    /// and exposes an API endpoint at `<base_url>/auth/token` to retrieve a token.
478    /// If you want to use another OAuth provider, please use [`Self::with_details`].
479    pub fn with_url(base_url: Url, auth: Option<ClientCredentials>) -> Self {
480        let client = reqwest::Client::new();
481        let oauth_base_url = base_url.join("/auth/token").unwrap();
482        Self::with_details(base_url, oauth_base_url, client, auth)
483    }
484
485    /// Create a new client with more details than [`Self::with_url`].
486    ///
487    /// It allows specifying a [`reqwest::Client`] instead of the default one.
488    /// This allows configuring proxy settings, timeouts, etc.
489    ///
490    /// Additionally, it allows for a separate `oauth_base_url`,
491    /// which is relevant if you don't want or cannot rely on an OAuth provider at the default URL.
492    pub fn with_details(
493        vtn_base_url: Url,
494        oauth_base_url: Url,
495        client: reqwest::Client,
496        auth: Option<ClientCredentials>,
497    ) -> Self {
498        Self::with_http_client(
499            vtn_base_url,
500            oauth_base_url,
501            Box::new(ReqwestClientRef { client }),
502            auth,
503        )
504    }
505
506    /// Create a new client with anything that implements the [`HttpClient`] trait.
507    ///
508    /// This is mainly helpful for the integration tests
509    /// and should most likely not be used for other purposes.
510    /// Please use [`Client::with_details`] for detailed HTTP client configuration.
511    pub fn with_http_client(
512        vtn_base_url: Url,
513        oauth_base_url: Url,
514        client: Box<dyn HttpClient + Send + Sync>,
515        auth: Option<ClientCredentials>,
516    ) -> Self {
517        let client_ref = ClientRef {
518            client,
519            vtn_base_url,
520            oauth_base_url,
521            default_page_size: 50,
522            auth_data: auth,
523            auth_token: RwLock::new(None),
524            phantom: PhantomData::<K>,
525        };
526        Self::new(client_ref)
527    }
528
529    fn new(client_ref: ClientRef<K>) -> Self {
530        Client {
531            client_ref: Arc::new(client_ref),
532        }
533    }
534
535    /// Create a new program on the VTN.
536    pub async fn create_program(
537        &self,
538        program_content: ProgramRequest,
539    ) -> Result<ProgramClient<K>> {
540        let program = self.client_ref.post("programs", &program_content).await?;
541        Ok(ProgramClient::from_program(self.clone(), program))
542    }
543
544    /// Lowlevel operation that gets a list of programs from the VTN with the given query parameters
545    pub async fn get_programs(
546        &self,
547        filter: Filter<'_, impl AsRef<str>>,
548        pagination: PaginationOptions,
549    ) -> Result<Vec<ProgramClient<K>>> {
550        // convert query params
551        let skip_str = pagination.skip.to_string();
552        let limit_str = pagination.limit.to_string();
553        // insert into query params
554        let mut query: Vec<(&str, &str)> = vec![("skip", &skip_str), ("limit", &limit_str)];
555
556        query.extend_from_slice(filter.to_query_params().as_slice());
557
558        // send request and return response
559        let programs: Vec<Program> = self.client_ref.get("programs", &query).await?;
560        Ok(programs
561            .into_iter()
562            .map(|program| ProgramClient::from_program(self.clone(), program))
563            .collect())
564    }
565
566    /// Get all programs from the VTN with the given query parameters
567    ///
568    /// It automatically tries to iterate pages where necessary.
569    pub async fn get_program_list(
570        &self,
571        filter: Filter<'_, impl AsRef<str> + Clone>,
572    ) -> Result<Vec<ProgramClient<K>>> {
573        self.client_ref
574            .iterate_pages(|skip, limit| {
575                self.get_programs(filter.clone(), PaginationOptions { skip, limit })
576            })
577            .await
578    }
579
580    /// Get a program by id
581    pub async fn get_program_by_id(&self, id: &ProgramId) -> Result<ProgramClient<K>> {
582        let program = self
583            .client_ref
584            .get(&format!("programs/{}", id.as_str()), &[])
585            .await?;
586
587        Ok(ProgramClient::from_program(self.clone(), program))
588    }
589
590    /// Low-level operation that gets a list of events from the VTN with the given query parameters
591    ///
592    /// To automatically iterate pages, use [`self.get_event_list`](Self::get_event_list)
593    pub async fn get_events(
594        &self,
595        program_id: Option<&ProgramId>,
596        filter: Filter<'_, impl AsRef<str>>,
597        pagination: PaginationOptions,
598    ) -> Result<Vec<EventClient<K>>> {
599        // convert query params
600        let skip_str = pagination.skip.to_string();
601        let limit_str = pagination.limit.to_string();
602        // insert into query params
603        let mut query: Vec<(&str, &str)> = vec![("skip", &skip_str), ("limit", &limit_str)];
604
605        query.extend_from_slice(filter.to_query_params().as_slice());
606
607        if let Some(program_id) = program_id {
608            query.push(("programID", program_id.as_str()));
609        }
610
611        // send request and return response
612        let events: Vec<Event> = self.client_ref.get("events", &query).await?;
613        Ok(events
614            .into_iter()
615            .map(|event| EventClient::from_event(self.client_ref.clone(), event))
616            .collect())
617    }
618
619    /// Get all events from the VTN with the given query parameters.
620    ///
621    /// It automatically tries to iterate pages where necessary.
622    pub async fn get_event_list(
623        &self,
624        program_id: Option<&ProgramId>,
625        filter: Filter<'_, impl AsRef<str> + Clone>,
626    ) -> Result<Vec<EventClient<K>>> {
627        self.client_ref
628            .iterate_pages(|skip, limit| {
629                self.get_events(
630                    program_id,
631                    filter.clone(),
632                    PaginationOptions { skip, limit },
633                )
634            })
635            .await
636    }
637
638    /// Get an event by id
639    pub async fn get_event_by_id(&self, id: &EventId) -> Result<EventClient<K>> {
640        let event = self
641            .client_ref
642            .get(&format!("events/{}", id.as_str()), &[])
643            .await?;
644
645        Ok(EventClient::from_event(self.client_ref.clone(), event))
646    }
647
648    async fn get_vens(
649        &self,
650        skip: usize,
651        limit: usize,
652        filter: Filter<'_, impl AsRef<str>>,
653    ) -> Result<Vec<VenClient<K>>> {
654        let skip_str = skip.to_string();
655        let limit_str = limit.to_string();
656        let mut query: Vec<(&str, &str)> = vec![("skip", &skip_str), ("limit", &limit_str)];
657
658        query.extend_from_slice(filter.to_query_params().as_slice());
659
660        // send request and return response
661        let vens: Vec<Ven> = self.client_ref.get("vens", &query).await?;
662        Ok(vens
663            .into_iter()
664            .map(|ven| VenClient::from_ven(self.client_ref.clone(), ven))
665            .collect())
666    }
667
668    /// Get all VENs from the VTN with the given query parameters.
669    ///
670    /// The client automatically tries to iterate pages where necessary.
671    pub async fn get_ven_list(
672        &self,
673        filter: Filter<'_, impl AsRef<str> + Clone>,
674    ) -> Result<Vec<VenClient<K>>> {
675        self.client_ref
676            .iterate_pages(|skip, limit| self.get_vens(skip, limit, filter.clone()))
677            .await
678    }
679
680    /// Get VEN by id from VTN
681    pub async fn get_ven_by_id(&self, id: &VenId) -> Result<VenClient<K>> {
682        let ven = self
683            .client_ref
684            .get(&format!("vens/{}", id.as_str()), &[])
685            .await?;
686        Ok(VenClient::from_ven(self.client_ref.clone(), ven))
687    }
688
689    /// Get VEN by name from VTN.
690    /// According to the spec, a [`ven_name`](BlVenRequest::ven_name) must be unique for the whole VTN instance.
691    pub async fn get_ven_by_name(&self, name: &str) -> Result<VenClient<K>> {
692        let mut vens: Vec<Ven> = self.client_ref.get("vens", &[("venName", name)]).await?;
693        match vens[..] {
694            [] => Err(Error::ObjectNotFound),
695            [_] => Ok(VenClient::from_ven(self.client_ref.clone(), vens.remove(0))),
696            [..] => Err(Error::DuplicateObject),
697        }
698    }
699}