Skip to main content

pdk_contracts_lib/api/
validator.rs

1// Copyright (c) 2026, Salesforce, Inc.,
2// All rights reserved.
3// For full license text, see the LICENSE.txt file
4
5use super::authentication::authenticate;
6use super::authorization::authorize;
7use super::credentials::{ClientId, ClientSecret};
8use super::error::UpdateError;
9use super::error::{AuthenticationError, AuthorizationError};
10use crate::api::ClientData;
11use crate::implementation::constants::{
12    ACCEPT_HASH_ALGORITHM_VALUE, CONTRACTS_CACHE_EXPIRATION, CONTRACTS_FETCHING_INTERVAL,
13    LOCK_TIMEOUT_DURATION, PRIMARY_BACKUP_INTERVAL, PRIMARY_INACTIVE_TIMEOUT,
14};
15use crate::implementation::model::contracts_storage::{ContractsLocalStorage, ContractsStorage};
16use crate::implementation::model::distributed_storage::ContractsCache;
17use crate::implementation::model::session_storage::{SessionSharedDataStorage, SessionStorage};
18use crate::implementation::platform::client::HttpPlatformClient;
19use crate::implementation::platform::responses::{
20    contract_from_event, ContractsResponse, LoginResponse,
21};
22use crate::implementation::platform::shared::{AccessToken, ContractsRequestParams};
23use data_storage_lib::DataStorageBuilder;
24use lock_lib::{Lock, LockBuilder, TryLock};
25use pdk_core::classy::extract::context::ConfigureContext;
26use pdk_core::classy::extract::{Extract, FromContext};
27use pdk_core::classy::hl::HttpClientError;
28use pdk_core::classy::proxy_wasm::types::Status;
29use pdk_core::classy::{Clock, SharedData};
30use pdk_core::logger;
31use pdk_core::logger::debug;
32use pdk_core::policy_context::api::Metadata;
33use std::rc::Rc;
34use std::time::Duration;
35use thiserror::Error;
36
37/// The object that will collect the contracts and provide the functionality to validate incoming
38/// requests.
39pub struct ContractValidator {
40    api_id: String,
41    client: HttpPlatformClient,
42    session_lock: TryLock,
43    api_lock: TryLock,
44    session_storage: SessionSharedDataStorage,
45    contract_storage: Rc<ContractsLocalStorage>,
46    contracts_cache: ContractsCache,
47    clock: Rc<dyn Clock>,
48}
49
50#[derive(Error, Debug)]
51enum InternalUpdateError {
52    #[error("Http client error: {0}")]
53    HttpClientError(#[from] HttpClientError),
54    #[error("Parsing error: {0}")]
55    Serde(#[from] serde_json::error::Error),
56    #[error("Lost the lock while executing async function.")]
57    LostLock,
58    #[error("Upstream returned unexpected status code {0}")]
59    UnexpectedResponse(u32),
60}
61
/// Outcome of requesting one page from the contracts service.
enum PollContractsResponse {
    /// A page with contract events was processed; request the next one.
    Continue,
    /// The session token was rejected (HTTP 401); a new token is needed.
    Renegotiate,
    /// No further updates are available.
    Finish,
}
67
/// Role of this worker with respect to backing up contracts to the remote cache.
pub(crate) enum PollerType {
    /// This worker backs up the contracts state.
    Primary,
    /// This worker does not back up the contracts state.
    Secondary,
}
72
/// Reasons why the poller role could not be determined.
pub(crate) enum PollerError {
    /// The api lock was lost while negotiating the role.
    LostLock,
    /// Communication with the data storage failed.
    DataStorageError,
}
77
78impl ContractValidator {
79    fn new(
80        client: HttpPlatformClient,
81        api_id: String,
82        clock: Rc<dyn Clock>,
83        shared_data: Rc<dyn SharedData>,
84        lock_builder: LockBuilder,
85        data_storage_builder: DataStorageBuilder,
86    ) -> Self {
87        // The session storage and the session lock only use two entries in total despite the amount
88        // of Apis/policies deployed. We can safely not remove them without fear of a leak.
89        let session_storage =
90            SessionSharedDataStorage::new(Rc::clone(&clock), Rc::clone(&shared_data));
91
92        let session_lock = lock_builder
93            .new(session_storage.session_lock_key())
94            .expiration(LOCK_TIMEOUT_DURATION)
95            .shared()
96            .build();
97
98        let contract_storage = Rc::new(ContractsLocalStorage::new(
99            &api_id,
100            Rc::clone(&clock),
101            shared_data,
102        ));
103        let api_lock = lock_builder
104            .new(contract_storage.api_lock_key())
105            .expiration(LOCK_TIMEOUT_DURATION)
106            .shared()
107            .build();
108
109        let contracts_cache = ContractsCache::new(
110            Rc::clone(&clock),
111            data_storage_builder
112                .shared()
113                .remote(format!("{api_id}-CONTRACTS"), CONTRACTS_CACHE_EXPIRATION),
114            Rc::clone(&contract_storage),
115        );
116
117        Self {
118            api_id,
119            client,
120            session_lock,
121            api_lock,
122            session_storage,
123            contract_storage,
124            clock,
125            contracts_cache,
126        }
127    }
128
129    /// Initial updating interval for calling [ContractValidator::update_contracts()].
130    pub const INITIALIZATION_PERIOD: Duration = Duration::from_millis(100);
131
132    /// Remaining updating interval for calling [ContractValidator::update_contracts()].
133    pub const UPDATE_PERIOD: Duration = CONTRACTS_FETCHING_INTERVAL;
134
135    /// Validates if `client_id` credential can be authorized in the current API contract.
136    /// If the validation is succesful, this method returns the [ClientData] of `client_id`.
137    /// Otherwise an [AuthorizationError] is returned.
138    pub fn authorize(&self, client_id: &ClientId) -> Result<ClientData, AuthorizationError> {
139        authorize(self.contract_storage.as_ref(), client_id)
140    }
141
142    /// Validates if `client_id` and `client_secret` credentials can be authenticated
143    /// in the current API contract.
144    /// If the validation is succesful, this method returns the [ClientData] of `client_id`.
145    /// Otherwise an [AuthenticationError] is returned.
146    pub fn authenticate(
147        &self,
148        client_id: &ClientId,
149        client_secret: &ClientSecret,
150    ) -> Result<ClientData, AuthenticationError> {
151        authenticate(self.contract_storage.as_ref(), client_id, client_secret)
152    }
153
154    /// Returns `true` if contracts have been pulled and are available locally.
155    /// This method can be used to determine if the [`ContractValidator`] is ready
156    /// for authorization and authentication operations.
157    pub fn is_ready(&self) -> bool {
158        self.contract_storage.last_update().is_some()
159    }
160
161    /// Updates local contracts database.
162    /// This method is intended to be called periodically in order to keep
163    /// the local contracts database up to date.
164    /// During initialization time, this method should be invoked in a period of
165    /// [ContractValidator::INITIALIZATION_PERIOD].
166    /// After initialization, this method should be invoked in a period of
167    /// [ContractValidator::UPDATE_PERIOD].
168    pub async fn update_contracts(&self) -> Result<(), UpdateError> {
169        if !self.should_update() {
170            debug!("Contracts update skipped since update period hasn't elapsed.");
171            return Ok(());
172        }
173
174        debug!("Fetching contracts for API {}", self.api_id);
175        let Some(api_lock) = self.api_lock.try_lock() else {
176            debug!(
177                "Other worker has the lock for API {}. Skipping update.",
178                self.api_id
179            );
180            return Ok(());
181        };
182
183        if self.first_cycle() {
184            let _ = self.cache_contracts_poll(&api_lock).await;
185            if !api_lock.refresh_lock() {
186                return Ok(());
187            }
188        }
189
190        // Re-check if the update period has elapsed after backup restore
191        if !self.should_update() {
192            debug!("Contracts update skipped since update period hasn't elapsed.");
193            return Ok(());
194        }
195
196        let result = self.platform_contracts_poll(&api_lock).await;
197
198        if self.contract_storage.last_update().is_none() {
199            debug!("No successfully poll registered will not try to backup contracts");
200            return result.map(|_| ());
201        }
202
203        match self.poller_type(&api_lock).await {
204            Ok(PollerType::Primary) => {
205                self.backup_contracts(&api_lock, result.as_ref().map(|r| *r).unwrap_or_default())
206                    .await;
207            }
208            Ok(PollerType::Secondary) => {
209                debug!("No update backup since we are a secondary node.");
210            }
211            Err(PollerError::LostLock) => {
212                debug!("Lost the api_lock while trying to become primary, skipping update.");
213            }
214            Err(PollerError::DataStorageError) => {
215                debug!("Unexpected error communicating with the data storage.");
216            }
217        };
218
219        result.map(|_| ())
220    }
221
222    async fn platform_contracts_poll(&self, api_lock: &'_ Lock<'_>) -> Result<bool, UpdateError> {
223        let mut updates = false;
224        debug!(
225            "Fetching contracts for API {} from contracts service",
226            self.api_id
227        );
228
229        let Some(token) = self.session_token().await else {
230            return Ok(updates);
231        };
232
233        if !api_lock.refresh_lock() {
234            debug!("Lost the api lock while fetching session token.");
235            return Ok(updates);
236        }
237
238        let mut token_data = token;
239        loop {
240            match self.poll_contracts(&token_data, api_lock).await {
241                Ok(PollContractsResponse::Continue) => {
242                    debug!("Contract polling request successful. Chaining next request");
243                    updates = true;
244                }
245                Ok(PollContractsResponse::Renegotiate) => {
246                    let Some(token) = self.renegotiate_token(token_data).await else {
247                        // Could not renegotiate the token
248                        return Ok(updates);
249                    };
250                    token_data = token;
251                    if !api_lock.refresh_lock() {
252                        debug!("Lost the api lock while refreshing the session lock.");
253                        return Ok(updates);
254                    }
255                }
256                Ok(PollContractsResponse::Finish) => {
257                    return Ok(updates);
258                }
259                Err(error) => {
260                    debug!("Error while polling contracts: {error}");
261                    return Ok(updates);
262                }
263            }
264        }
265    }
266
267    async fn cache_contracts_poll(&self, api_lock: &'_ Lock<'_>) -> Result<(), UpdateError> {
268        debug!("Fetching contracts for API {} from cache.", self.api_id);
269
270        let result = self.contracts_cache.get_state().await;
271
272        if !api_lock.refresh_lock() {
273            debug!("Lost the api lock while recovering state from remote storage.");
274            return Ok(());
275        }
276
277        result
278            .into_iter()
279            .for_each(|state| self.contract_storage.set_state(state));
280
281        Ok(())
282    }
283
284    async fn poller_type(&self, api_lock: &'_ Lock<'_>) -> Result<PollerType, PollerError> {
285        let Some(primary) = self.contract_storage.is_primary() else {
286            debug!("No information regarding primary node.");
287            return self.contracts_cache.try_primary(api_lock).await;
288        };
289
290        let primary_expired = self
291            .contract_storage
292            .last_primary_update()
293            .map(|last| last + PRIMARY_INACTIVE_TIMEOUT < self.clock.get_current_time())
294            .unwrap_or(true);
295
296        if primary_expired && !primary {
297            debug!("Secondary node trying to become primary due to timeout.");
298            return self.contracts_cache.try_primary(api_lock).await;
299        } else if primary_expired {
300            debug!("We lost the primary status. We'll become secondary for at least one polling cycle.");
301            self.contract_storage.set_primary(false);
302            return Ok(PollerType::Secondary);
303        }
304
305        match primary {
306            true => Ok(PollerType::Primary),
307            false => Ok(PollerType::Secondary),
308        }
309    }
310
311    fn first_cycle(&self) -> bool {
312        self.contract_storage.last_update().is_none()
313    }
314
315    fn should_update(&self) -> bool {
316        self.contract_storage
317            .last_update()
318            .map(|last| last + CONTRACTS_FETCHING_INTERVAL < self.clock.get_current_time())
319            .unwrap_or(true)
320    }
321
322    async fn backup_contracts(&self, api_lock: &'_ Lock<'_>, has_updates: bool) {
323        if !self.should_update_backup(has_updates) {
324            return;
325        }
326
327        let time = self.clock.get_current_time();
328        let mut update = self.contract_storage.get_state();
329        update.update_primary(time);
330        if self.contracts_cache.save_state(update).await {
331            self.contract_storage.set_primary_update(time);
332        }
333
334        if !api_lock.refresh_lock() {
335            debug!("Lost the api lock while backing data to cache.");
336        }
337    }
338
339    fn should_update_backup(&self, has_updates: bool) -> bool {
340        let Some(last_update) = self.contract_storage.last_update() else {
341            debug!("Skipping cache backup since no data to backup.");
342            return false;
343        };
344
345        if has_updates {
346            debug!("Will backup contracts since new updates are available.");
347            return true;
348        }
349
350        let Some(last_primary_update) = self.contract_storage.last_primary_update() else {
351            debug!("No local records of a primary node.");
352            return true;
353        };
354
355        if last_update < last_primary_update {
356            debug!("Skipping cache backup since no updates since last save.");
357            return false; // No data to back up. This can happen when connection to the platform fails.
358        }
359
360        let result = last_primary_update + PRIMARY_BACKUP_INTERVAL < self.clock.get_current_time();
361        if !result {
362            debug!("Skipping cache backup since the elapsed time is less than the refresh rate.");
363        }
364        result
365    }
366
367    async fn session_token(&self) -> Option<AccessToken> {
368        match self.session_storage.get_token() {
369            Some(token) => Some(token),
370            None => {
371                let Some(session_lock) = self.session_lock.try_lock() else {
372                    debug!("Other worker has the session lock. Skipping update.");
373                    return None;
374                };
375
376                // Re check if the token was already set before obtaining the lock.
377                // Since the lock is "not blocking" the changes are quite slim, but not null.
378                if let Some(token) = self.session_storage.get_token() {
379                    return Some(token);
380                }
381
382                self.fetch_session_token(&session_lock).await
383            }
384        }
385    }
386
387    async fn renegotiate_token(&self, old_token: AccessToken) -> Option<AccessToken> {
388        // Validate that the token was not renegotiated by someone else.
389        if let Some(token) = self.session_storage.get_token() {
390            if token != old_token {
391                return Some(token);
392            }
393        };
394
395        // Acquire the lock
396        let Some(session_lock) = self.session_lock.try_lock() else {
397            debug!("Other worker has the session lock. Aborting token renegotiation");
398            return None;
399        };
400
401        // Validate Again that the token was not renegotiated by someone else.
402        // Since the lock is "not blocking" the changes are quite slim, but not null.
403        if let Some(token) = self.session_storage.get_token() {
404            if token != old_token {
405                return Some(token);
406            }
407        };
408
409        self.fetch_session_token(&session_lock).await
410    }
411
412    async fn fetch_session_token(&self, session_lock: &'_ Lock<'_>) -> Option<AccessToken> {
413        match self.perform_login_request().await {
414            Ok(login) => {
415                if !session_lock.refresh_lock() {
416                    debug!("Lost the session lock. Aborting update.");
417                    return None;
418                }
419                let token = login.get_token();
420                let token_data = AccessToken::new(token.to_string(), login.get_type().to_string());
421                debug!("Obtained the session token.");
422                self.session_storage.save_token(token_data.clone());
423                Some(token_data)
424            }
425            Err(e) => {
426                logger::warn!(
427                    "Unexpected error while performing login request {e}. Skipping update."
428                );
429                None
430            }
431        }
432    }
433
434    async fn perform_login_request(&self) -> Result<LoginResponse, InternalUpdateError> {
435        debug!("Getting platform token...");
436        match self.client.login().await? {
437            r if r.status_code() == 200 => Ok(serde_json::from_slice::<LoginResponse>(r.body())?),
438            r => {
439                debug!(
440                    "Fetching contracts failed with status code: {} and body:\n {}",
441                    r.status_code(),
442                    String::from_utf8_lossy(r.body())
443                );
444                Err(InternalUpdateError::UnexpectedResponse(r.status_code()))
445            }
446        }
447    }
448
449    async fn poll_contracts(
450        &self,
451        access_token: &AccessToken,
452        api_lock: &'_ Lock<'_>,
453    ) -> Result<PollContractsResponse, InternalUpdateError> {
454        let token = access_token.get_access_token();
455        let response = self
456            .client
457            .contracts(
458                token,
459                self.api_id.as_str(),
460                ACCEPT_HASH_ALGORITHM_VALUE,
461                self.next_url(),
462            )
463            .await?;
464
465        if !api_lock.refresh_lock() {
466            return Err(InternalUpdateError::LostLock);
467        }
468
469        match response.status_code() {
470            200 => {
471                let contracts: ContractsResponse = serde_json::from_slice(response.body())
472                    .map_err(|_| HttpClientError::Status(Status::InternalFailure))?;
473
474                if self.no_updates(&contracts) {
475                    self.finish_polling();
476                    Ok(PollContractsResponse::Finish)
477                } else {
478                    self.log_invalid_contracts(&contracts);
479                    self.update_data(&contracts);
480                    self.update_links(&contracts);
481                    Ok(PollContractsResponse::Continue)
482                }
483            }
484            401 => Ok(PollContractsResponse::Renegotiate),
485            n => {
486                debug!(
487                    "Fetching contracts failed with status code: {} and body:\n {}",
488                    n,
489                    String::from_utf8_lossy(response.body())
490                );
491                Err(InternalUpdateError::UnexpectedResponse(n))
492            }
493        }
494    }
495
496    fn log_invalid_contracts(&self, response: &ContractsResponse) {
497        for invalid_contract_error_msg in response.verify_contracts().err().unwrap_or_default() {
498            logger::warn!("{invalid_contract_error_msg}")
499        }
500    }
501
502    fn no_updates(&self, response: &ContractsResponse) -> bool {
503        let links = response.get_links();
504        links.self_link() == links.next_link()
505    }
506
507    fn finish_polling(&self) {
508        self.contract_storage.update_last();
509        debug!(
510            "No more contracts updates for API {}, polling in next tick.",
511            self.api_id
512        );
513    }
514
515    fn update_data(&self, response: &ContractsResponse) {
516        let data = response.get_data();
517        for contract_event in data {
518            match contract_event.removed.unwrap_or(false) {
519                true => self
520                    .contract_storage
521                    .remove_contract(&contract_event.client_id),
522                false => self
523                    .contract_storage
524                    .save_contract(contract_from_event(contract_event)),
525            }
526        }
527        self.contract_storage.update_last();
528        debug!(
529            "{} contract events processed for API {}",
530            data.len(),
531            self.api_id
532        );
533    }
534
535    fn update_links(&self, response: &ContractsResponse) {
536        let links = response.get_links();
537        let params = ContractsRequestParams::new(
538            Some(links.next_link().to_string()),
539            ACCEPT_HASH_ALGORITHM_VALUE.to_string(),
540        );
541        self.contract_storage.save_contracts_request_params(params);
542    }
543
544    fn next_url(&self) -> Option<String> {
545        self.contract_storage
546            .get_contracts_request_params()
547            .and_then(|x| x.next_url)
548    }
549}
550
551#[derive(thiserror::Error, Debug)]
552pub enum ExtractionError {
553    #[error("Api metadata is unavailable.")]
554    ApiMetadata,
555
556    #[error("Environment Context is unavailable.")]
557    EnvironmentContext,
558
559    #[error("Anypoint Context is unavailable.")]
560    AnypointContext,
561}
562
563impl FromContext<ConfigureContext> for ContractValidator {
564    type Error = ExtractionError;
565
566    fn from_context(context: &ConfigureContext) -> Result<Self, Self::Error> {
567        let metadata: Metadata = context.extract_always();
568        let api_id = metadata
569            .api_metadata
570            .id
571            .ok_or(ExtractionError::ApiMetadata)?;
572        let client = context.extract()?;
573        let clock = context.extract_always();
574        let shared_data = context.extract_always();
575        let lock_builder = context.extract_always();
576        let storage_builder: DataStorageBuilder = context
577            .extract()
578            .map_err(|_| ExtractionError::EnvironmentContext)?;
579
580        Ok(Self::new(
581            client,
582            api_id,
583            clock,
584            shared_data,
585            lock_builder,
586            storage_builder,
587        ))
588    }
589}