openleadr_client/
lib.rs

1#![warn(missing_docs)]
2
3//! # OpenADR 3.0 VEN client
4//!
//! This is a client library to interact with an OpenADR 3.0 compliant VTN server.
6//! It mainly wraps the HTTP REST interface into an easy-to-use Rust API.
7//!
8//! Basic usage
9//! ```no_run
10//! # use openleadr_client::{Client, ClientCredentials};
11//! # use openleadr_wire::event::{EventInterval, EventType, EventValuesMap, Priority};
12//! # use openleadr_wire::program::ProgramContent;
13//! # use openleadr_wire::values_map::Value;
14//! # tokio_test::block_on(async {
15//! let credentials =
16//!     ClientCredentials::new("client_id".to_string(), "client_secret".to_string());
17//! let client = Client::with_url(
18//!     "https://your-vtn.com".parse().unwrap(),
19//!     Some(credentials),
20//! );
21//! let new_program = ProgramContent::new("example-program-name".to_string());
22//! let example_program = client.create_program(new_program).await.unwrap();
23//! let mut new_event = example_program.new_event(vec![EventInterval {
24//!     id: 0,
25//!     interval_period: None,
26//!     payloads: vec![EventValuesMap {
27//!         value_type: EventType::Price,
28//!         values: vec![Value::Number(1.23)],
29//!     }],
30//! }]);
31//! new_event.priority = Priority::new(10);
32//! new_event.event_name = Some("Some descriptive name".to_string());
33//! example_program.create_event(new_event).await.unwrap();
34//! # })
35//! ```
36//!
37//! If you want to use a separate OAuth provider that is not built into the VTN server,
38//! you can do so as well.
39//! ```no_run
40//! # use openleadr_client::{Client, ClientCredentials};
41//! # let credentials =
42//! #    ClientCredentials::new("client_id".to_string(), "client_secret".to_string());
43//! // optionally, you can build special configuration into the reqwest client here as well
44//! let reqwest_client = reqwest::Client::new();
45//! let client = Client::with_details(
46//!     "https://your-vtn.com".parse().unwrap(),
47//!     "https://your-oauth-provider.com".parse().unwrap(),
48//!     reqwest_client,
49//!     Some(credentials),
50//! );
51//! ```
52
53mod error;
54mod event;
55mod program;
56mod report;
57mod resource;
58mod target;
59mod timeline;
60mod ven;
61
62use axum::async_trait;
63use openleadr_wire::{event::EventId, Event, Ven};
64use std::{
65    fmt::Debug,
66    future::Future,
67    sync::Arc,
68    time::{Duration, Instant},
69};
70use tokio::sync::RwLock;
71
72use reqwest::{Method, RequestBuilder, Response};
73use url::Url;
74
75pub use error::*;
76pub use event::*;
77pub use program::*;
78pub use report::*;
79pub use resource::*;
80pub use target::*;
81pub use timeline::*;
82pub use ven::*;
83
84use crate::error::Result;
85use openleadr_wire::ven::{VenContent, VenId};
86pub(crate) use openleadr_wire::{
87    event::EventContent,
88    program::{ProgramContent, ProgramId},
89    target::TargetType,
90    Program,
91};
92
93#[async_trait]
94/// Abstracts the implementation used for actual requests.
95///
96/// This is used for testing purposes such that we don't need
97/// to run an actual server instance but instead directly call into the axum router
98pub trait HttpClient: Debug {
99    #[allow(missing_docs)]
100    fn request_builder(&self, method: Method, url: Url) -> RequestBuilder;
101    #[allow(missing_docs)]
102    async fn send(&self, req: RequestBuilder) -> reqwest::Result<Response>;
103}
104
105/// Client for managing top-level entities on a VTN, i.e., programs and VENs.
106///
107/// Can be used to implement both, the VEN and the business logic.
108///
109/// If using the VTN of this project with the built-in OAuth authentication provider,
110/// the [`Client`] also allows managing the users.
111#[derive(Debug, Clone)]
112pub struct Client {
113    client_ref: Arc<ClientRef>,
114}
115
/// Credentials necessary for authentication at the VTN.
///
/// The [`Debug`] output of this type never contains the client secret.
pub struct ClientCredentials {
    /// OAuth client identifier, sent together with the secret when requesting an access token.
    pub client_id: String,
    // Not public and intentionally omitted from the `Debug` output below.
    client_secret: String,
    /// Margin by which the authentication token is refreshed (using the `client_id` and
    /// client secret) before it expires.
    ///
    /// This helps to prevent an "unauthorized" response caused by small
    /// differences between client and server clocks and by network latency.
    ///
    /// **Default:** 60 sec
    pub refresh_margin: Duration,
    /// Time the authorization token is assumed to be valid for
    /// if the OAuth provider does not state an expiry in its token response.
    ///
    /// **Default:** 3600 sec, i.e., one hour
    pub default_credential_expires_in: Duration,
}

impl Debug for ClientCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `client_secret` is left out on purpose; `finish_non_exhaustive` marks the omission.
        f.debug_struct(std::any::type_name::<Self>())
            .field("client_id", &self.client_id)
            .field("refresh_margin", &self.refresh_margin)
            .field(
                "default_credential_expires_in",
                &self.default_credential_expires_in,
            )
            .finish_non_exhaustive()
    }
}

impl ClientCredentials {
    /// Creates new [`ClientCredentials`] with default values for
    /// [`refresh_margin`](ClientCredentials::refresh_margin) and
    /// [`default_credential_expires_in`](ClientCredentials::default_credential_expires_in)
    /// (60 and 3600 sec, respectively).
    pub fn new(client_id: String, client_secret: String) -> Self {
        Self {
            client_id,
            client_secret,
            refresh_margin: Duration::from_secs(60),
            default_credential_expires_in: Duration::from_secs(3600),
        }
    }
}
160
/// An access token together with the data needed to decide when to refresh it.
struct AuthToken {
    /// The token value attached to requests as bearer authentication.
    token: String,
    /// Validity period of the token, counted from `since`.
    expires_in: Duration,
    /// Moment the token request was started.
    since: Instant,
}

/// Leaves out `token` so the secret value never appears in debug output.
impl Debug for AuthToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct(std::any::type_name::<Self>());
        out.field("expires_in", &self.expires_in);
        out.field("since", &self.since);
        out.finish_non_exhaustive()
    }
}
175
176#[derive(Debug)]
177struct ClientRef {
178    client: Box<dyn HttpClient + Send + Sync>,
179    vtn_base_url: Url,
180    oauth_base_url: Url,
181    default_page_size: usize,
182    auth_data: Option<ClientCredentials>,
183    auth_token: RwLock<Option<AuthToken>>,
184}
185
186impl ClientRef {
187    /// This ensures the client is authenticated.
188    ///
189    /// We follow the process according to RFC 6749, section 4.4 (client
190    /// credentials grant). The client id and secret are by default sent via
191    /// HTTP Basic Auth.
192    async fn ensure_auth(&self) -> Result<()> {
193        // if there is no auth data, we don't do any authentication
194        let Some(auth_data) = &self.auth_data else {
195            return Ok(());
196        };
197
198        // if there is a token, and it is valid long enough, we don't have to do anything
199        if let Some(token) = self.auth_token.read().await.as_ref() {
200            if token.since.elapsed() < token.expires_in - auth_data.refresh_margin {
201                return Ok(());
202            }
203        }
204
205        #[derive(serde::Serialize)]
206        struct AccessTokenRequest {
207            grant_type: &'static str,
208            #[serde(skip_serializing_if = "Option::is_none")]
209            scope: Option<String>,
210            #[serde(skip_serializing_if = "Option::is_none")]
211            client_id: Option<String>,
212            #[serde(skip_serializing_if = "Option::is_none")]
213            client_secret: Option<String>,
214        }
215
216        // we should authenticate
217        let request = self
218            .client
219            .request_builder(Method::POST, self.oauth_base_url.clone())
220            .form(&AccessTokenRequest {
221                grant_type: "client_credentials",
222                scope: None,
223                client_id: None,
224                client_secret: None,
225            });
226        let request = request.basic_auth(&auth_data.client_id, Some(&auth_data.client_secret));
227        let request = request.header("Accept", "application/json");
228        let since = Instant::now();
229        let res = self.client.send(request).await?;
230        if !res.status().is_success() {
231            let problem = res.json::<openleadr_wire::oauth::OAuthError>().await?;
232            return Err(Error::AuthProblem(problem));
233        }
234
235        #[derive(Debug, serde::Deserialize)]
236        struct AuthResult {
237            access_token: String,
238            token_type: String,
239            #[serde(default)]
240            expires_in: Option<u64>,
241            // Refresh tokens aren't supported currently
242            // #[serde(default)]
243            // refresh_token: Option<String>,
244            // #[serde(default)]
245            // scope: Option<String>,
246            // #[serde(flatten)]
247            // other: std::collections::HashMap<String, serde_json::Value>,
248        }
249
250        let auth_result = res.json::<AuthResult>().await?;
251        if auth_result.token_type.to_lowercase() != "bearer" {
252            return Err(Error::OAuthTokenNotBearer);
253        }
254        let token = AuthToken {
255            token: auth_result.access_token,
256            expires_in: auth_result
257                .expires_in
258                .map(Duration::from_secs)
259                .unwrap_or(auth_data.default_credential_expires_in),
260            since,
261        };
262
263        *self.auth_token.write().await = Some(token);
264        Ok(())
265    }
266
267    async fn request<T: serde::de::DeserializeOwned>(
268        &self,
269        mut request: RequestBuilder,
270        query: &[(&str, &str)],
271    ) -> Result<T> {
272        self.ensure_auth().await?;
273        request = request.header("Accept", "application/json");
274        if !query.is_empty() {
275            request = request.query(&query);
276        }
277
278        // read token and insert in request if available
279        {
280            let token = self.auth_token.read().await;
281            if let Some(token) = token.as_ref() {
282                request = request.bearer_auth(&token.token);
283            }
284        }
285        let res = self.client.send(request).await?;
286
287        // handle any errors returned by the server
288        if !res.status().is_success() {
289            let problem = res.json::<openleadr_wire::problem::Problem>().await?;
290            return Err(crate::error::Error::from(problem));
291        }
292
293        Ok(res.json().await?)
294    }
295
296    async fn get<T: serde::de::DeserializeOwned>(
297        &self,
298        path: &str,
299        query: &[(&str, &str)],
300    ) -> Result<T> {
301        let url = self.vtn_base_url.join(path)?;
302        let request = self.client.request_builder(Method::GET, url);
303        self.request(request, query).await
304    }
305
306    async fn post<S, T>(&self, path: &str, body: &S) -> Result<T>
307    where
308        S: serde::ser::Serialize + Sync,
309        T: serde::de::DeserializeOwned,
310    {
311        let url = self.vtn_base_url.join(path)?;
312        let request = self.client.request_builder(Method::POST, url).json(body);
313        self.request(request, &[]).await
314    }
315
316    async fn put<S, T>(&self, path: &str, body: &S) -> Result<T>
317    where
318        S: serde::ser::Serialize + Sync,
319        T: serde::de::DeserializeOwned,
320    {
321        let url = self.vtn_base_url.join(path)?;
322        let request = self.client.request_builder(Method::PUT, url).json(body);
323        self.request(request, &[]).await
324    }
325
326    async fn delete<T>(&self, path: &str) -> Result<T>
327    where
328        T: serde::de::DeserializeOwned,
329    {
330        let url = self.vtn_base_url.join(path)?;
331        let request = self.client.request_builder(Method::DELETE, url);
332        self.request(request, &[]).await
333    }
334
335    fn default_page_size(&self) -> usize {
336        self.default_page_size
337    }
338
339    async fn iterate_pages<T, Fut>(
340        &self,
341        single_page_req: impl Fn(usize, usize) -> Fut,
342    ) -> Result<Vec<T>>
343    where
344        Fut: Future<Output = Result<Vec<T>>>,
345    {
346        let page_size = self.default_page_size();
347        let mut items = vec![];
348        let mut page = 0;
349        // TODO: pagination should depend on that the server indicated there are more results
350        loop {
351            let received = single_page_req(page * page_size, page_size).await?;
352            let received_all = received.len() < page_size;
353            for item in received {
354                items.push(item);
355            }
356
357            if received_all {
358                break;
359            } else {
360                page += 1;
361            }
362        }
363
364        Ok(items)
365    }
366}
367
368#[derive(Debug)]
369struct ReqwestClientRef {
370    client: reqwest::Client,
371}
372
373#[async_trait]
374impl HttpClient for ReqwestClientRef {
375    fn request_builder(&self, method: Method, url: Url) -> RequestBuilder {
376        self.client.request(method, url)
377    }
378
379    async fn send(&self, req: RequestBuilder) -> std::result::Result<Response, reqwest::Error> {
380        req.send().await
381    }
382}
383
/// Allows setting specific `skip` and `limit` values for list queries.
///
/// In most cases, you should not need this functionality.
/// Use the `_list` functions instead,
/// which automatically try to iterate through all pages to retrieve all entities.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PaginationOptions {
    /// Number of items to skip, sent as the `skip` query parameter.
    pub skip: usize,
    /// Maximum number of items to return, sent as the `limit` query parameter.
    pub limit: usize,
}
395
396/// Filter based on TargetType and TargetValues as specified for various items.
397///
398/// **Please note:** This does only filter based on what is stored in the `target` field of an item
399/// (e.g., [`ProgramContent::targets`]) and should not get interpreted by the server.
400/// For example, setting the [`TargetType`] to [`ProgramName`](TargetType::ProgramName)
401/// will not filter based on the [`program_name`](ProgramContent::program_name)
402/// value but only consider what is stored in the [`targets`](`ProgramContent::targets`)
403/// of that program.
404///
405/// Unfortunately, the specification is not very clear about this behavior,
406/// so some servers might interpret it differently.
407/// There has been some discussion with the authors of the standard in
408/// <https://github.com/oadr3-org/openadr3-vtn-reference-implementation/issues/83> (sadly not public).
409#[derive(Debug, Clone)]
410pub enum Filter<'a, S: AsRef<str>> {
411    /// Do not apply any filtering
412    None,
413    /// Filter by [`TargetType`] and a list of values.
414    ///
415    /// It will be encoded to the request as query parameters,
416    /// e.g., `/programs?targetType=GROUP&targetValues=Group-1&targetValues=Group-2`.
417    By(TargetType, &'a [S]),
418}
419
420impl<'a> Filter<'a, &'static str> {
421    /// Create a new filter that does not apply any filtering.
422    pub const fn none() -> Filter<'a, &'static str> {
423        Filter::None
424    }
425}
426
427impl<'a, S: AsRef<str>> Filter<'a, S> {
428    pub(crate) fn to_query_params(&'a self) -> Vec<(&'a str, &'a str)> {
429        let mut query = vec![];
430        if let Filter::By(ref target_label, target_values) = self {
431            query.push(("targetType", target_label.as_str()));
432
433            for target_value in *target_values {
434                query.push(("targetValues", target_value.as_ref()));
435            }
436        }
437        query
438    }
439}
440
441impl Client {
442    /// Create a new client for a VTN located at the specified URL.
443    ///
444    /// This assumes that the VTN also works as an OAuth provider
445    /// and exposes an API endpoint at `<base_url>/auth/token` to retrieve a token.
446    /// If you want to use another OAuth provider, please use [`Self::with_details`].
447    pub fn with_url(base_url: Url, auth: Option<ClientCredentials>) -> Self {
448        let client = reqwest::Client::new();
449        let oauth_base_url = base_url.join("/auth/token").unwrap();
450        Self::with_details(base_url, oauth_base_url, client, auth)
451    }
452
453    /// Create a new client with more details than [`Self::with_url`].
454    ///
455    /// It allows specifying a [`reqwest::Client`] instead of the default one.
456    /// This allows configuring proxy settings, timeouts, etc.
457    ///
458    /// Additionally, it allows for a separate `oauth_base_url`,
459    /// which is relevant if you don't want or cannot rely on an OAuth provider at the default URL.
460    pub fn with_details(
461        vtn_base_url: Url,
462        oauth_base_url: Url,
463        client: reqwest::Client,
464        auth: Option<ClientCredentials>,
465    ) -> Self {
466        Self::with_http_client(
467            vtn_base_url,
468            oauth_base_url,
469            Box::new(ReqwestClientRef { client }),
470            auth,
471        )
472    }
473
474    /// Create a new client with anything that implements the [`HttpClient`] trait.
475    ///
476    /// This is mainly helpful for the integration tests
477    /// and should most likely not be used for other purposes.
478    /// Please use [`Client::with_details`] for detailed HTTP client configuration.
479    pub fn with_http_client(
480        vtn_base_url: Url,
481        oauth_base_url: Url,
482        client: Box<dyn HttpClient + Send + Sync>,
483        auth: Option<ClientCredentials>,
484    ) -> Self {
485        let client_ref = ClientRef {
486            client,
487            vtn_base_url,
488            oauth_base_url,
489            default_page_size: 50,
490            auth_data: auth,
491            auth_token: RwLock::new(None),
492        };
493        Self::new(client_ref)
494    }
495
496    fn new(client_ref: ClientRef) -> Self {
497        Client {
498            client_ref: Arc::new(client_ref),
499        }
500    }
501
502    /// Create a new program on the VTN.
503    pub async fn create_program(&self, program_content: ProgramContent) -> Result<ProgramClient> {
504        let program = self.client_ref.post("programs", &program_content).await?;
505        Ok(ProgramClient::from_program(self.clone(), program))
506    }
507
508    /// Lowlevel operation that gets a list of programs from the VTN with the given query parameters
509    pub async fn get_programs(
510        &self,
511        filter: Filter<'_, impl AsRef<str>>,
512        pagination: PaginationOptions,
513    ) -> Result<Vec<ProgramClient>> {
514        // convert query params
515        let skip_str = pagination.skip.to_string();
516        let limit_str = pagination.limit.to_string();
517        // insert into query params
518        let mut query: Vec<(&str, &str)> = vec![("skip", &skip_str), ("limit", &limit_str)];
519
520        query.extend_from_slice(filter.to_query_params().as_slice());
521
522        // send request and return response
523        let programs: Vec<Program> = self.client_ref.get("programs", &query).await?;
524        Ok(programs
525            .into_iter()
526            .map(|program| ProgramClient::from_program(self.clone(), program))
527            .collect())
528    }
529
530    /// Get all programs from the VTN with the given query parameters
531    ///
532    /// It automatically tries to iterate pages where necessary.
533    pub async fn get_program_list(
534        &self,
535        filter: Filter<'_, impl AsRef<str> + Clone>,
536    ) -> Result<Vec<ProgramClient>> {
537        self.client_ref
538            .iterate_pages(|skip, limit| {
539                self.get_programs(filter.clone(), PaginationOptions { skip, limit })
540            })
541            .await
542    }
543
544    /// Get a program by id
545    pub async fn get_program_by_id(&self, id: &ProgramId) -> Result<ProgramClient> {
546        let program = self
547            .client_ref
548            .get(&format!("programs/{}", id.as_str()), &[])
549            .await?;
550
551        Ok(ProgramClient::from_program(self.clone(), program))
552    }
553
554    /// Low-level operation that gets a list of events from the VTN with the given query parameters
555    ///
556    /// To automatically iterate pages, use [`self.get_event_list`]
557    pub async fn get_events(
558        &self,
559        program_id: Option<&ProgramId>,
560        filter: Filter<'_, impl AsRef<str>>,
561        pagination: PaginationOptions,
562    ) -> Result<Vec<EventClient>> {
563        // convert query params
564        let skip_str = pagination.skip.to_string();
565        let limit_str = pagination.limit.to_string();
566        // insert into query params
567        let mut query: Vec<(&str, &str)> = vec![("skip", &skip_str), ("limit", &limit_str)];
568
569        query.extend_from_slice(filter.to_query_params().as_slice());
570
571        if let Some(program_id) = program_id {
572            query.push(("programID", program_id.as_str()));
573        }
574
575        // send request and return response
576        let events: Vec<Event> = self.client_ref.get("events", &query).await?;
577        Ok(events
578            .into_iter()
579            .map(|event| EventClient::from_event(self.client_ref.clone(), event))
580            .collect())
581    }
582
583    /// Get all events from the VTN with the given query parameters.
584    ///
585    /// It automatically tries to iterate pages where necessary.
586    pub async fn get_event_list(
587        &self,
588        program_id: Option<&ProgramId>,
589        filter: Filter<'_, impl AsRef<str> + Clone>,
590    ) -> Result<Vec<EventClient>> {
591        self.client_ref
592            .iterate_pages(|skip, limit| {
593                self.get_events(
594                    program_id,
595                    filter.clone(),
596                    PaginationOptions { skip, limit },
597                )
598            })
599            .await
600    }
601
602    /// Get an event by id
603    pub async fn get_event_by_id(&self, id: &EventId) -> Result<EventClient> {
604        let event = self
605            .client_ref
606            .get(&format!("events/{}", id.as_str()), &[])
607            .await?;
608
609        Ok(EventClient::from_event(self.client_ref.clone(), event))
610    }
611
612    /// Create a new VEN entity at the VTN. The content should be created with [`VenContent::new`].
613    pub async fn create_ven(&self, ven: VenContent) -> Result<VenClient> {
614        let ven = self.client_ref.post("vens", &ven).await?;
615        Ok(VenClient::from_ven(self.client_ref.clone(), ven))
616    }
617
618    async fn get_vens(
619        &self,
620        skip: usize,
621        limit: usize,
622        filter: Filter<'_, impl AsRef<str>>,
623    ) -> Result<Vec<VenClient>> {
624        let skip_str = skip.to_string();
625        let limit_str = limit.to_string();
626        let mut query: Vec<(&str, &str)> = vec![("skip", &skip_str), ("limit", &limit_str)];
627
628        query.extend_from_slice(filter.to_query_params().as_slice());
629
630        // send request and return response
631        let vens: Vec<Ven> = self.client_ref.get("vens", &query).await?;
632        Ok(vens
633            .into_iter()
634            .map(|ven| VenClient::from_ven(self.client_ref.clone(), ven))
635            .collect())
636    }
637
638    /// Get all VENs from the VTN with the given query parameters.
639    ///
640    /// The client automatically tries to iterate pages where necessary.
641    pub async fn get_ven_list(
642        &self,
643        filter: Filter<'_, impl AsRef<str> + Clone>,
644    ) -> Result<Vec<VenClient>> {
645        self.client_ref
646            .iterate_pages(|skip, limit| self.get_vens(skip, limit, filter.clone()))
647            .await
648    }
649
650    /// Get VEN by id from VTN
651    pub async fn get_ven_by_id(&self, id: &VenId) -> Result<VenClient> {
652        let ven = self
653            .client_ref
654            .get(&format!("vens/{}", id.as_str()), &[])
655            .await?;
656        Ok(VenClient::from_ven(self.client_ref.clone(), ven))
657    }
658
659    /// Get VEN by name from VTN.
660    /// According to the spec, a [`ven_name`](VenContent::ven_name) must be unique for the whole VTN instance.
661    pub async fn get_ven_by_name(&self, name: &str) -> Result<VenClient> {
662        let mut vens: Vec<Ven> = self.client_ref.get("vens", &[("venName", name)]).await?;
663        match vens[..] {
664            [] => Err(Error::ObjectNotFound),
665            [_] => Ok(VenClient::from_ven(self.client_ref.clone(), vens.remove(0))),
666            [..] => Err(Error::DuplicateObject),
667        }
668    }
669}