pdk_contracts_lib/api/
validator.rs

1// Copyright (c) 2025, Salesforce, Inc.,
2// All rights reserved.
3// For full license text, see the LICENSE.txt file
4
5use lock_lib::{Lock, LockBuilder, TryLock};
6use pdk_core::classy::extract::context::ConfigureContext;
7use pdk_core::classy::extract::{Extract as _, FromContext};
8use pdk_core::classy::hl::HttpClientError;
9use pdk_core::classy::proxy_wasm::types::Status;
10use pdk_core::classy::{Clock, SharedData};
11use pdk_core::logger;
12use pdk_core::policy_context::api::Metadata;
13use std::rc::Rc;
14use std::time::Duration;
15use thiserror::Error;
16
17use super::error::UpdateError;
18use super::error::{AuthenticationError, AuthorizationError};
19use crate::api::ClientData;
20use crate::implementation::constants::{
21    ACCEPT_HASH_ALGORITHM_VALUE, CONTRACTS_FETCHING_INTERVAL, LOCK_TIMEOUT_DURATION,
22};
23use crate::implementation::model::contracts_storage::{ContractsLocalStorage, ContractsStorage};
24use crate::implementation::model::session_storage::{SessionSharedDataStorage, SessionStorage};
25use crate::implementation::platform::client::HttpPlatformClient;
26use crate::implementation::platform::responses::{
27    contract_from_event, ContractsResponse, LoginResponse,
28};
29use crate::implementation::platform::shared::{AccessToken, ContractsRequestParams};
30
31use super::authentication::authenticate;
32use super::authorization::authorize;
33use super::credentials::{ClientId, ClientSecret};
34
/// The object that will collect the contracts and provide the functionality to validate incoming
/// requests.
///
/// The constructor is private; instances are obtained through the [`FromContext`]
/// implementation for [`ConfigureContext`].
pub struct ContractValidator {
    /// Identifier of the API whose contracts are fetched and validated.
    api_id: String,
    /// Platform client used to perform the login and the contracts requests.
    client: HttpPlatformClient,
    /// Lock taken before requesting a new session token.
    session_lock: TryLock,
    /// Lock taken before polling the contracts of `api_id`.
    api_lock: TryLock,
    /// Storage from which the platform access token is read and where it is saved.
    session_storage: SessionSharedDataStorage,
    /// Storage holding the contracts, the time of the last update and the next polling URL.
    contract_storage: ContractsLocalStorage,
    /// Time source used to decide whether the update period has elapsed.
    clock: Rc<dyn Clock>,
}
46
/// Failures that can interrupt a contracts update.
///
/// This type is private to the module: `update_contracts` and `fetch_session_token` log
/// these errors instead of propagating them to the caller.
#[derive(Error, Debug)]
enum InternalUpdateError {
    /// The HTTP client failed while performing a request.
    #[error("Http client error: {0}")]
    HttpClientError(#[from] HttpClientError),
    /// A response body could not be deserialized.
    #[error("Parsing error: {0}")]
    Serde(#[from] serde_json::error::Error),
    /// The API lock could not be refreshed after an awaited request completed.
    #[error("Lost the lock while executing async function.")]
    LostLock,
    /// The platform answered with a status code that is not handled; carries that code.
    #[error("Upstream returned unexpected status code {0}")]
    UnexpectedResponse(u32),
}
58
/// Outcome of a single contracts polling request, telling the update loop what to do next.
enum PollContractsResponse {
    /// A page of contract events was applied; another request should be chained.
    Continue,
    /// The platform answered 401; the session token must be renegotiated before retrying.
    Renegotiate,
    /// The response reported no further updates; polling is done for this tick.
    Finish,
}
64
65impl ContractValidator {
66    fn new(
67        client: HttpPlatformClient,
68        api_id: String,
69        clock: Rc<dyn Clock>,
70        shared_data: Rc<dyn SharedData>,
71        lock_builder: LockBuilder,
72    ) -> Self {
73        // The session storage and the session lock only use two entries in total despite the amount
74        // of Apis/policies deployed. We can safely not remove them without fear of a leak.
75        let session_storage =
76            SessionSharedDataStorage::new(Rc::clone(&clock), Rc::clone(&shared_data));
77
78        let session_lock = lock_builder
79            .new(session_storage.session_lock_key())
80            .expiration(LOCK_TIMEOUT_DURATION)
81            .shared()
82            .build();
83
84        let contract_storage = ContractsLocalStorage::new(&api_id, Rc::clone(&clock), shared_data);
85        let api_lock = lock_builder
86            .new(contract_storage.api_lock_key())
87            .expiration(LOCK_TIMEOUT_DURATION)
88            .shared()
89            .build();
90
91        Self {
92            api_id,
93            client,
94            session_lock,
95            api_lock,
96            session_storage,
97            contract_storage,
98            clock,
99        }
100    }
101
    /// Interval (100 milliseconds) at which [`ContractValidator::update_contracts()`] should be
    /// called during initialization.
    pub const INITIALIZATION_PERIOD: Duration = Duration::from_millis(100);
104
    /// Interval at which [`ContractValidator::update_contracts()`] should be called once
    /// initialization is over. It equals `CONTRACTS_FETCHING_INTERVAL`, the same interval used
    /// to decide whether an update is due.
    pub const UPDATE_PERIOD: Duration = CONTRACTS_FETCHING_INTERVAL;
107
    /// Validates if `client_id` credential can be authorized in the current API contract.
    /// If the validation is successful, this method returns the [`ClientData`] of `client_id`.
    /// Otherwise an [`AuthorizationError`] is returned.
    ///
    /// The check is performed against the locally stored contracts.
    pub fn authorize(&self, client_id: &ClientId) -> Result<ClientData, AuthorizationError> {
        authorize(&self.contract_storage, client_id)
    }
114
    /// Validates if `client_id` and `client_secret` credentials can be authenticated
    /// in the current API contract.
    /// If the validation is successful, this method returns the [`ClientData`] of `client_id`.
    /// Otherwise an [`AuthenticationError`] is returned.
    ///
    /// The check is performed against the locally stored contracts.
    pub fn authenticate(
        &self,
        client_id: &ClientId,
        client_secret: &ClientSecret,
    ) -> Result<ClientData, AuthenticationError> {
        authenticate(&self.contract_storage, client_id, client_secret)
    }
126
    /// Returns `true` if contracts have been pulled and are available locally.
    /// This method can be used to determine if the [`ContractValidator`] is ready
    /// for authorization and authentication operations.
    ///
    /// Readiness is derived from the contracts storage having a recorded last update.
    pub fn is_ready(&self) -> bool {
        self.contract_storage.last_update().is_some()
    }
133
    /// Updates local contracts database.
    ///
    /// This method is intended to be called periodically in order to keep
    /// the local contracts database up to date.
    /// During initialization time, this method should be invoked in a period of
    /// [`ContractValidator::INITIALIZATION_PERIOD`].
    /// After initialization, this method should be invoked in a period of
    /// [`ContractValidator::UPDATE_PERIOD`].
    ///
    /// The update is skipped when the fetching interval has not elapsed since the last
    /// update, when no session token can be obtained, or when the API lock cannot be acquired.
    ///
    /// Every path of the current implementation returns `Ok(())`: failures while polling are
    /// logged and the update is abandoned until the next invocation.
    pub async fn update_contracts(&self) -> Result<(), UpdateError> {
        if !self.should_update() {
            logger::debug!("Contracts update skipped since update period hasn't elapsed.");
            return Ok(());
        }

        logger::debug!("Fetching contracts for API {}", self.api_id);

        // The session token is resolved before taking the API lock; without a token there is
        // nothing to poll with.
        let Some(token) = self.session_token().await else {
            return Ok(());
        };

        // The lock is non-blocking: when it is not available this update is skipped.
        let Some(api_lock) = self.api_lock.try_lock() else {
            logger::debug!(
                "Other worker has the lock for API {}. Skipping update.",
                self.api_id
            );
            return Ok(());
        };

        // Keep chaining requests until the platform reports no more updates, an error occurs
        // or a lock is lost.
        // NOTE(review): a platform that keeps answering 401 after each successful login would
        // keep this loop renegotiating — confirm whether a retry bound is needed.
        let mut token_data = token;
        loop {
            match self.poll_contracts(&token_data, &api_lock).await {
                Ok(PollContractsResponse::Continue) => {
                    logger::debug!("Contract polling request successful. Chaining next request");
                }
                Ok(PollContractsResponse::Renegotiate) => {
                    let Some(token) = self.renegotiate_token(token_data).await else {
                        // Could not renegotiate the token
                        return Ok(());
                    };
                    token_data = token;
                    // Renegotiation awaits a login request, so the API lock is refreshed
                    // before polling again.
                    if !api_lock.refresh_lock() {
                        logger::debug!("Lost the api lock while refreshing the session lock.");
                        return Ok(());
                    }
                }
                Ok(PollContractsResponse::Finish) => {
                    return Ok(());
                }
                Err(error) => {
                    // Errors are only logged; the next invocation will try again.
                    logger::debug!("Error while polling contracts: {error}");
                    return Ok(());
                }
            }
        }
    }
188
189    fn should_update(&self) -> bool {
190        self.contract_storage
191            .last_update()
192            .map(|last| last + CONTRACTS_FETCHING_INTERVAL < self.clock.get_current_time())
193            .unwrap_or(true)
194    }
195
196    async fn session_token(&self) -> Option<AccessToken> {
197        match self.session_storage.get_token() {
198            Some(token) => Some(token),
199            None => {
200                let Some(session_lock) = self.session_lock.try_lock() else {
201                    logger::debug!("Other worker has the session lock. Skipping update.");
202                    return None;
203                };
204
205                // Re check if the token was already set before obtaining the lock.
206                // Since the lock is "not blocking" the changes are quite slim, but not null.
207                if let Some(token) = self.session_storage.get_token() {
208                    return Some(token);
209                }
210
211                self.fetch_session_token(&session_lock).await
212            }
213        }
214    }
215
216    async fn renegotiate_token(&self, old_token: AccessToken) -> Option<AccessToken> {
217        // Validate that the token was not renegotiated by someone else.
218        if let Some(token) = self.session_storage.get_token() {
219            if token != old_token {
220                return Some(token);
221            }
222        };
223
224        // Acquire the lock
225        let Some(session_lock) = self.session_lock.try_lock() else {
226            logger::debug!("Other worker has the session lock. Aborting token renegotiation");
227            return None;
228        };
229
230        // Validate Again that the token was not renegotiated by someone else.
231        // Since the lock is "not blocking" the changes are quite slim, but not null.
232        if let Some(token) = self.session_storage.get_token() {
233            if token != old_token {
234                return Some(token);
235            }
236        };
237
238        self.fetch_session_token(&session_lock).await
239    }
240
241    async fn fetch_session_token(&self, session_lock: &'_ Lock<'_>) -> Option<AccessToken> {
242        match self.perform_login_request().await {
243            Ok(login) => {
244                if !session_lock.refresh_lock() {
245                    logger::debug!("Lost the session lock. Aborting update.");
246                    return None;
247                }
248                let token = login.get_token();
249                let token_data = AccessToken::new(token.to_string(), login.get_type().to_string());
250                logger::debug!("Obtained the session token.");
251                self.session_storage.save_token(token_data.clone());
252                Some(token_data)
253            }
254            Err(e) => {
255                logger::warn!(
256                    "Unexpected error while performing login request {e}. Skipping update."
257                );
258                None
259            }
260        }
261    }
262
263    async fn perform_login_request(&self) -> Result<LoginResponse, InternalUpdateError> {
264        logger::debug!("Getting platform token...");
265        match self.client.login().await? {
266            r if r.status_code() == 200 => Ok(serde_json::from_slice::<LoginResponse>(r.body())?),
267            r => {
268                logger::debug!(
269                    "Fetching contracts failed with status code: {} and body:\n {}",
270                    r.status_code(),
271                    String::from_utf8_lossy(r.body())
272                );
273                Err(InternalUpdateError::UnexpectedResponse(r.status_code()))
274            }
275        }
276    }
277
278    async fn poll_contracts(
279        &self,
280        access_token: &AccessToken,
281        api_lock: &'_ Lock<'_>,
282    ) -> Result<PollContractsResponse, InternalUpdateError> {
283        let token = access_token.get_access_token();
284        let response = self
285            .client
286            .contracts(
287                token,
288                self.api_id.as_str(),
289                ACCEPT_HASH_ALGORITHM_VALUE,
290                self.next_url(),
291            )
292            .await?;
293
294        if !api_lock.refresh_lock() {
295            return Err(InternalUpdateError::LostLock);
296        }
297
298        match response.status_code() {
299            200 => {
300                let contracts: ContractsResponse = serde_json::from_slice(response.body())
301                    .map_err(|_| HttpClientError::Status(Status::InternalFailure))?;
302
303                if self.no_updates(&contracts) {
304                    self.finish_polling();
305                    Ok(PollContractsResponse::Finish)
306                } else {
307                    self.log_invalid_contracts(&contracts);
308                    self.update_data(&contracts);
309                    self.update_links(&contracts);
310                    Ok(PollContractsResponse::Continue)
311                }
312            }
313            401 => Ok(PollContractsResponse::Renegotiate),
314            n => {
315                logger::debug!(
316                    "Fetching contracts failed with status code: {} and body:\n {}",
317                    n,
318                    String::from_utf8_lossy(response.body())
319                );
320                Err(InternalUpdateError::UnexpectedResponse(n))
321            }
322        }
323    }
324
325    fn log_invalid_contracts(&self, response: &ContractsResponse) {
326        for invalid_contract_error_msg in response.verify_contracts().err().unwrap_or_default() {
327            logger::warn!("{invalid_contract_error_msg}")
328        }
329    }
330
331    fn no_updates(&self, response: &ContractsResponse) -> bool {
332        let links = response.get_links();
333        links.self_link() == links.next_link()
334    }
335
    /// Marks the current polling round as complete by recording the last update in the
    /// contracts storage, then logs that polling resumes on the next tick.
    fn finish_polling(&self) {
        self.contract_storage.update_last();
        logger::debug!(
            "No more contracts updates for API {}, polling in next tick.",
            self.api_id
        );
    }
343
344    fn update_data(&self, response: &ContractsResponse) {
345        let data = response.get_data();
346        for contract_event in data {
347            match contract_event.removed.unwrap_or(false) {
348                true => self
349                    .contract_storage
350                    .remove_contract(&contract_event.client_id),
351                false => self
352                    .contract_storage
353                    .save_contract(contract_from_event(contract_event)),
354            }
355        }
356        logger::debug!(
357            "{} contract events processed for API {}",
358            data.len(),
359            self.api_id
360        );
361    }
362
363    fn update_links(&self, response: &ContractsResponse) {
364        let links = response.get_links();
365        let params = ContractsRequestParams::new(
366            Some(links.next_link().to_string()),
367            ACCEPT_HASH_ALGORITHM_VALUE.to_string(),
368        );
369        self.contract_storage.save_contracts_request_params(params);
370    }
371
372    fn next_url(&self) -> Option<String> {
373        self.contract_storage
374            .get_contracts_request_params()
375            .and_then(|x| x.next_url)
376    }
377}
378
/// Errors that can occur while extracting a [`ContractValidator`] from a configuration
/// context.
#[derive(thiserror::Error, Debug)]
pub enum ExtractionError {
    /// The API metadata does not provide an API id.
    #[error("Api metadata is unavailable.")]
    ApiMetadata,

    /// The environment context is unavailable.
    #[error("Environment Context is unavailable.")]
    EnvironmentContext,

    /// The Anypoint context is unavailable.
    #[error("Anypoint Context is unavailable.")]
    AnypointContext,
}
390
391impl FromContext<ConfigureContext> for ContractValidator {
392    type Error = ExtractionError;
393
394    fn from_context(context: &ConfigureContext) -> Result<Self, Self::Error> {
395        let metadata: Metadata = context.extract_always();
396        let api_id = metadata
397            .api_metadata
398            .id
399            .ok_or(ExtractionError::ApiMetadata)?;
400        let client = context.extract()?;
401        let clock = context.extract_always();
402        let shared_data = context.extract_always();
403        let lock_builder = context.extract_always();
404
405        Ok(Self::new(client, api_id, clock, shared_data, lock_builder))
406    }
407}