use super::authentication::authenticate;
use super::authorization::authorize;
use super::credentials::{ClientId, ClientSecret};
use super::error::UpdateError;
use super::error::{AuthenticationError, AuthorizationError};
use crate::api::tracker::AttemptTracker;
use crate::api::ClientData;
use crate::implementation::constants::{
ACCEPT_HASH_ALGORITHM_VALUE, CONTRACTS_CACHE_EXPIRATION, CONTRACTS_FETCHING_INTERVAL,
LOCK_TIMEOUT_DURATION, PRIMARY_BACKUP_INTERVAL, PRIMARY_INACTIVE_TIMEOUT,
};
use crate::implementation::model::contracts_storage::{ContractsLocalStorage, ContractsStorage};
use crate::implementation::model::distributed_storage::ContractsCache;
use crate::implementation::model::session_storage::{SessionSharedDataStorage, SessionStorage};
use crate::implementation::platform::client::HttpPlatformClient;
use crate::implementation::platform::responses::{
contract_from_event, ContractsResponse, LoginResponse,
};
use crate::implementation::platform::shared::{AccessToken, ContractsRequestParams};
use data_storage_lib::DataStorageBuilder;
use lock_lib::{Lock, LockBuilder, TryLock};
use pdk_core::classy::extract::context::ConfigureContext;
use pdk_core::classy::extract::{Extract, FromContext};
use pdk_core::classy::hl::HttpClientError;
use pdk_core::classy::proxy_wasm::types::Status;
use pdk_core::classy::{Clock, SharedData};
use pdk_core::logger;
use pdk_core::logger::debug;
use pdk_core::policy_context::api::Metadata;
use std::rc::Rc;
use std::time::Duration;
use thiserror::Error;
/// Authenticates/authorizes clients against the contracts of a single API and keeps the
/// local contract set in sync with the platform's contracts service.
///
/// Contracts are polled from the platform into local storage; when this node is the primary
/// poller the local state is also backed up to a remote cache, which is read back on a
/// validator's first update cycle.
pub struct ContractValidator {
/// Id of the API whose contracts are tracked; used in log messages and in the remote cache key.
api_id: String,
/// HTTP client used for the login and contracts requests.
client: HttpPlatformClient,
/// Lock guarding session-token login, so only the worker holding it performs the login.
session_lock: TryLock,
/// Lock guarding contract polling/backup for this API.
api_lock: TryLock,
/// Storage for the session access token.
session_storage: SessionSharedDataStorage,
/// Local contract state (contracts, timestamps, primary flag, paging params).
contract_storage: Rc<ContractsLocalStorage>,
/// Remote backup of the contract state, also used to negotiate the primary poller.
contracts_cache: ContractsCache,
/// Time source for all freshness checks.
clock: Rc<dyn Clock>,
/// Throttles update attempts to at most one per `CONTRACTS_FETCHING_INTERVAL`.
tracker: AttemptTracker,
}
#[derive(Error, Debug)]
enum InternalUpdateError {
#[error("Http client error: {0}")]
HttpClientError(#[from] HttpClientError),
#[error("Parsing error: {0}")]
Serde(#[from] serde_json::error::Error),
#[error("Lost the lock while executing async function.")]
LostLock,
#[error("Upstream returned unexpected status code {0}")]
UnexpectedResponse(u32),
}
/// Outcome of one successful contracts request, driving the polling loop.
enum PollContractsResponse {
/// A page of contract events was applied; request the next page.
Continue,
/// The contracts service answered 401; obtain a new session token and retry.
Renegotiate,
/// The `self` and `next` links matched; polling is done for this cycle.
Finish,
}
/// Role of this node in the contracts backup scheme.
pub(crate) enum PollerType {
/// This node may back its local contract state up to the remote cache.
Primary,
/// This node does not write backups to the remote cache.
Secondary,
}
/// Failures while determining the poller role.
///
/// In this module these are only produced by `ContractsCache::try_primary` (defined elsewhere).
pub(crate) enum PollerError {
/// The API lock was lost while trying to claim the primary role.
LostLock,
/// The remote data storage could not be read or written.
DataStorageError,
}
impl ContractValidator {
/// Builds a validator for `api_id`, wiring its storages, locks, remote cache and attempt tracker.
///
/// Both the session lock and the API lock are shared locks that expire after
/// `LOCK_TIMEOUT_DURATION`. The remote cache is keyed `"<api_id>-CONTRACTS"` and its entries
/// expire after `CONTRACTS_CACHE_EXPIRATION`.
fn new(
    client: HttpPlatformClient,
    api_id: String,
    clock: Rc<dyn Clock>,
    shared_data: Rc<dyn SharedData>,
    lock_builder: LockBuilder,
    data_storage_builder: DataStorageBuilder,
) -> Self {
    let session_storage =
        SessionSharedDataStorage::new(Rc::clone(&clock), Rc::clone(&shared_data));
    let session_lock = lock_builder
        .new(session_storage.session_lock_key())
        .expiration(LOCK_TIMEOUT_DURATION)
        .shared()
        .build();

    let contract_storage = Rc::new(ContractsLocalStorage::new(
        &api_id,
        Rc::clone(&clock),
        shared_data,
    ));
    let api_lock = lock_builder
        .new(contract_storage.api_lock_key())
        .expiration(LOCK_TIMEOUT_DURATION)
        .shared()
        .build();

    let cache_clock = Rc::clone(&clock);
    let remote_store = data_storage_builder
        .shared()
        .remote(format!("{api_id}-CONTRACTS"), CONTRACTS_CACHE_EXPIRATION);
    let contracts_cache =
        ContractsCache::new(cache_clock, remote_store, Rc::clone(&contract_storage));

    let tracker = AttemptTracker::new(Rc::clone(&clock), CONTRACTS_FETCHING_INTERVAL);

    Self {
        api_id,
        client,
        session_lock,
        api_lock,
        session_storage,
        contract_storage,
        clock,
        contracts_cache,
        tracker,
    }
}
/// 100 ms period. NOTE(review): presumably the tick used while `is_ready()` is still false —
/// confirm against the scheduler that reads this constant.
pub const INITIALIZATION_PERIOD: Duration = Duration::from_millis(100);
/// Same as `CONTRACTS_FETCHING_INTERVAL`, the interval that also gates `should_update` and the
/// attempt tracker, so ticking faster than this would only produce skipped cycles.
pub const UPDATE_PERIOD: Duration = CONTRACTS_FETCHING_INTERVAL;
pub fn authorize(&self, client_id: &ClientId) -> Result<ClientData, AuthorizationError> {
authorize(self.contract_storage.as_ref(), client_id)
}
pub fn authenticate(
&self,
client_id: &ClientId,
client_secret: &ClientSecret,
) -> Result<ClientData, AuthenticationError> {
authenticate(self.contract_storage.as_ref(), client_id, client_secret)
}
/// Returns `true` once local contract state has a last-update timestamp, whether it came
/// from a platform poll or was restored from the remote cache.
pub fn is_ready(&self) -> bool {
    !self.first_cycle()
}
/// Runs one contracts update cycle for this API, if one is due.
///
/// The cycle returns `Ok(())` early when the attempt tracker has not expired, when the last
/// recorded update is newer than `CONTRACTS_FETCHING_INTERVAL`, or when another worker holds
/// the API lock. On the first cycle, local state is seeded from the remote cache before polling
/// the platform. After polling, the node works out its poller role and, as primary, may back
/// the local state up to the remote cache.
///
/// # Errors
///
/// Propagates the result of the platform poll.
pub async fn update_contracts(&self) -> Result<(), UpdateError> {
    if !self.tracker.expired() {
        return Ok(());
    }
    self.tracker.track();

    if !self.should_update() {
        debug!("Contracts update skipped since update period hasn't elapsed.");
        return Ok(());
    }

    debug!("Fetching contracts for API {}", self.api_id);
    let api_lock = match self.api_lock.try_lock() {
        Some(lock) => lock,
        None => {
            debug!(
                "Other worker has the lock for API {}. Skipping update.",
                self.api_id
            );
            return Ok(());
        }
    };

    if self.first_cycle() {
        // Seeding from the cache is best-effort; its result is deliberately ignored.
        let _ = self.cache_contracts_poll(&api_lock).await;
        if !api_lock.refresh_lock() {
            return Ok(());
        }
    }

    // The cache may have restored a state recent enough that polling is not needed yet.
    if !self.should_update() {
        debug!("Contracts update skipped since update period hasn't elapsed.");
        return Ok(());
    }

    let outcome = self.platform_contracts_poll(&api_lock).await;
    if self.contract_storage.last_update().is_none() {
        debug!("No successfully poll registered will not try to backup contracts");
        return outcome.map(|_| ());
    }

    let has_updates = matches!(outcome, Ok(true));
    match self.poller_type(&api_lock).await {
        Ok(PollerType::Primary) => {
            self.backup_contracts(&api_lock, has_updates).await;
        }
        Ok(PollerType::Secondary) => {
            debug!("No update backup since we are a secondary node.");
        }
        Err(PollerError::LostLock) => {
            debug!("Lost the api_lock while trying to become primary, skipping update.");
        }
        Err(PollerError::DataStorageError) => {
            debug!("Unexpected error communicating with the data storage.");
        }
    }

    outcome.map(|_| ())
}
/// Pages through the contracts service for this API and applies each page locally.
///
/// Returns `Ok(true)` if at least one page of contract events was applied. Every failure
/// (no session token, lost lock, request error, rejected credentials) is logged and ends the
/// loop with `Ok(updates)`. In practice this function never returns `Err`.
///
/// A 401 triggers at most one token renegotiation before the next accepted request. If a
/// freshly renegotiated token is rejected too, polling stops for this cycle instead of
/// re-logging in indefinitely.
async fn platform_contracts_poll(&self, api_lock: &'_ Lock<'_>) -> Result<bool, UpdateError> {
    let mut updates = false;
    debug!(
        "Fetching contracts for API {} from contracts service",
        self.api_id
    );
    let Some(mut token_data) = self.session_token().await else {
        return Ok(updates);
    };
    if !api_lock.refresh_lock() {
        debug!("Lost the api lock while fetching session token.");
        return Ok(updates);
    }
    // Set right after a renegotiation and cleared once a request is accepted. Without this,
    // a service that keeps answering 401 would loop on login requests forever.
    let mut just_renegotiated = false;
    loop {
        match self.poll_contracts(&token_data, api_lock).await {
            Ok(PollContractsResponse::Continue) => {
                debug!("Contract polling request successful. Chaining next request");
                updates = true;
                just_renegotiated = false;
            }
            Ok(PollContractsResponse::Renegotiate) => {
                if just_renegotiated {
                    debug!(
                        "Contracts service rejected a freshly negotiated token for API {}. Aborting update.",
                        self.api_id
                    );
                    return Ok(updates);
                }
                let Some(token) = self.renegotiate_token(token_data).await else {
                    return Ok(updates);
                };
                token_data = token;
                just_renegotiated = true;
                if !api_lock.refresh_lock() {
                    debug!("Lost the api lock while refreshing the session lock.");
                    return Ok(updates);
                }
            }
            Ok(PollContractsResponse::Finish) => {
                return Ok(updates);
            }
            Err(error) => {
                debug!("Error while polling contracts: {error}");
                return Ok(updates);
            }
        }
    }
}
/// Restores local contract state from the remote contracts cache.
///
/// The fetched state is applied only if the API lock can still be refreshed after the awaited
/// read. Always returns `Ok(())`.
async fn cache_contracts_poll(&self, api_lock: &'_ Lock<'_>) -> Result<(), UpdateError> {
    debug!("Fetching contracts for API {} from cache.", self.api_id);
    let cached = self.contracts_cache.get_state().await;
    if api_lock.refresh_lock() {
        cached
            .into_iter()
            .for_each(|state| self.contract_storage.set_state(state));
    } else {
        debug!("Lost the api lock while recovering state from remote storage.");
    }
    Ok(())
}
/// Determines whether this node currently acts as primary or secondary poller.
///
/// With no local primary information, or when a secondary sees the primary's last backup as
/// older than `PRIMARY_INACTIVE_TIMEOUT`, the node tries to claim the primary role through the
/// remote cache. A primary whose own last backup is that old demotes itself to secondary for
/// this cycle.
async fn poller_type(&self, api_lock: &'_ Lock<'_>) -> Result<PollerType, PollerError> {
    let Some(is_primary) = self.contract_storage.is_primary() else {
        debug!("No information regarding primary node.");
        return self.contracts_cache.try_primary(api_lock).await;
    };

    let primary_expired = match self.contract_storage.last_primary_update() {
        Some(last) => last + PRIMARY_INACTIVE_TIMEOUT < self.clock.get_current_time(),
        None => true,
    };

    if !primary_expired {
        return Ok(if is_primary {
            PollerType::Primary
        } else {
            PollerType::Secondary
        });
    }

    if is_primary {
        debug!("We lost the primary status. We'll become secondary for at least one polling cycle.");
        self.contract_storage.set_primary(false);
        Ok(PollerType::Secondary)
    } else {
        debug!("Secondary node trying to become primary due to timeout.");
        self.contracts_cache.try_primary(api_lock).await
    }
}
/// Returns `true` while local storage holds no last-update timestamp, i.e. no state has been
/// recorded yet from either a poll or the remote cache.
fn first_cycle(&self) -> bool {
    let last_update = self.contract_storage.last_update();
    last_update.is_none()
}
/// Returns `true` when no update has been recorded yet or the last one is strictly older than
/// `CONTRACTS_FETCHING_INTERVAL`.
fn should_update(&self) -> bool {
    match self.contract_storage.last_update() {
        None => true,
        Some(last) => last + CONTRACTS_FETCHING_INTERVAL < self.clock.get_current_time(),
    }
}
/// Saves the local contract state, stamped with the current time as the primary update, to the
/// remote cache when `should_update_backup` allows it.
///
/// The local primary-update timestamp is advanced only if the save succeeded.
async fn backup_contracts(&self, api_lock: &'_ Lock<'_>, has_updates: bool) {
    if !self.should_update_backup(has_updates) {
        return;
    }
    let now = self.clock.get_current_time();
    let mut snapshot = self.contract_storage.get_state();
    snapshot.update_primary(now);
    let saved = self.contracts_cache.save_state(snapshot).await;
    if saved {
        self.contract_storage.set_primary_update(now);
    }
    if !api_lock.refresh_lock() {
        debug!("Lost the api lock while backing data to cache.");
    }
}
/// Decides whether the local contract state should be written to the remote cache.
///
/// Rules, in order:
/// - there must be local data (a last-update timestamp);
/// - new updates from this cycle always trigger a backup;
/// - with no recorded primary update, back up;
/// - skip if nothing changed since the last backup;
/// - otherwise back up only once `PRIMARY_BACKUP_INTERVAL` has elapsed since the last one.
fn should_update_backup(&self, has_updates: bool) -> bool {
    let last_update = match self.contract_storage.last_update() {
        Some(timestamp) => timestamp,
        None => {
            debug!("Skipping cache backup since no data to backup.");
            return false;
        }
    };
    if has_updates {
        debug!("Will backup contracts since new updates are available.");
        return true;
    }
    let last_primary_update = match self.contract_storage.last_primary_update() {
        Some(timestamp) => timestamp,
        None => {
            debug!("No local records of a primary node.");
            return true;
        }
    };
    if last_update < last_primary_update {
        debug!("Skipping cache backup since no updates since last save.");
        return false;
    }
    let interval_elapsed =
        last_primary_update + PRIMARY_BACKUP_INTERVAL < self.clock.get_current_time();
    if !interval_elapsed {
        debug!("Skipping cache backup since the elapsed time is less than the refresh rate.");
    }
    interval_elapsed
}
/// Returns the stored session token, logging in to obtain one if none is stored.
///
/// Only the worker that acquires the session lock performs the login. Storage is checked again
/// after the lock is taken (presumably because another worker may have stored a token in the
/// meantime). Returns `None` if the lock is held elsewhere or the login fails.
async fn session_token(&self) -> Option<AccessToken> {
    if let Some(stored) = self.session_storage.get_token() {
        return Some(stored);
    }
    let Some(session_lock) = self.session_lock.try_lock() else {
        debug!("Other worker has the session lock. Skipping update.");
        return None;
    };
    match self.session_storage.get_token() {
        Some(stored) => Some(stored),
        None => self.fetch_session_token(&session_lock).await,
    }
}
/// Replaces `old_token`, which the contracts service rejected, with a usable token.
///
/// A stored token that differs from `old_token` is reused as-is. Storage is checked before and
/// after taking the session lock. Otherwise a fresh login is performed under the session lock.
/// Returns `None` if the lock is held by another worker or the login fails.
async fn renegotiate_token(&self, old_token: AccessToken) -> Option<AccessToken> {
    let stored_replacement = || {
        self.session_storage
            .get_token()
            .filter(|token| token != &old_token)
    };
    if let Some(token) = stored_replacement() {
        return Some(token);
    }
    let Some(session_lock) = self.session_lock.try_lock() else {
        debug!("Other worker has the session lock. Aborting token renegotiation");
        return None;
    };
    if let Some(token) = stored_replacement() {
        return Some(token);
    }
    self.fetch_session_token(&session_lock).await
}
/// Logs in to the platform and stores the resulting session token.
///
/// The token is stored and returned only if the session lock can still be refreshed after
/// the login request. A failed login is logged as a warning and yields `None`.
async fn fetch_session_token(&self, session_lock: &'_ Lock<'_>) -> Option<AccessToken> {
    let login = match self.perform_login_request().await {
        Ok(login) => login,
        Err(e) => {
            logger::warn!(
                "Unexpected error while performing login request {e}. Skipping update."
            );
            return None;
        }
    };
    if !session_lock.refresh_lock() {
        debug!("Lost the session lock. Aborting update.");
        return None;
    }
    let fresh = AccessToken::new(login.get_token().to_string(), login.get_type().to_string());
    debug!("Obtained the session token.");
    self.session_storage.save_token(fresh.clone());
    Some(fresh)
}
/// Performs the platform login request and parses its JSON body.
///
/// # Errors
///
/// - `HttpClientError` if the request itself fails.
/// - `Serde` if a 200 body is not a valid `LoginResponse`.
/// - `UnexpectedResponse` for any non-200 status; the status and body are logged at debug level.
async fn perform_login_request(&self) -> Result<LoginResponse, InternalUpdateError> {
    debug!("Getting platform token...");
    let response = self.client.login().await?;
    let status = response.status_code();
    if status != 200 {
        // Previously logged as "Fetching contracts failed", which misattributed login failures.
        debug!(
            "Login request failed with status code: {} and body:\n {}",
            status,
            String::from_utf8_lossy(response.body())
        );
        return Err(InternalUpdateError::UnexpectedResponse(status));
    }
    Ok(serde_json::from_slice::<LoginResponse>(response.body())?)
}
async fn poll_contracts(
&self,
access_token: &AccessToken,
api_lock: &'_ Lock<'_>,
) -> Result<PollContractsResponse, InternalUpdateError> {
let token = access_token.get_access_token();
let response = self
.client
.contracts(
token,
self.api_id.as_str(),
ACCEPT_HASH_ALGORITHM_VALUE,
self.next_url(),
)
.await?;
if !api_lock.refresh_lock() {
return Err(InternalUpdateError::LostLock);
}
match response.status_code() {
200 => {
let contracts: ContractsResponse = serde_json::from_slice(response.body())
.map_err(|_| HttpClientError::Status(Status::InternalFailure))?;
if self.no_updates(&contracts) {
self.finish_polling();
Ok(PollContractsResponse::Finish)
} else {
self.log_invalid_contracts(&contracts);
self.update_data(&contracts);
self.update_links(&contracts);
Ok(PollContractsResponse::Continue)
}
}
401 => Ok(PollContractsResponse::Renegotiate),
n => {
debug!(
"Fetching contracts failed with status code: {} and body:\n {}",
n,
String::from_utf8_lossy(response.body())
);
Err(InternalUpdateError::UnexpectedResponse(n))
}
}
}
/// Logs a warning for every validation message produced by `verify_contracts` on this page.
fn log_invalid_contracts(&self, response: &ContractsResponse) {
    if let Err(messages) = response.verify_contracts() {
        for invalid_contract_error_msg in messages {
            logger::warn!("{invalid_contract_error_msg}")
        }
    }
}
/// Returns `true` when the page's `self` link equals its `next` link, which this module treats
/// as the end of the change feed.
fn no_updates(&self, response: &ContractsResponse) -> bool {
    let links = response.get_links();
    let (current, next) = (links.self_link(), links.next_link());
    current == next
}
/// Marks the local state as up to date for this cycle when the feed has no more pages.
fn finish_polling(&self) {
    let storage = &self.contract_storage;
    storage.update_last();
    debug!(
        "No more contracts updates for API {}, polling in next tick.",
        self.api_id
    );
}
/// Applies a page of contract events to local storage and bumps the last-update time.
///
/// Events with `removed == Some(true)` delete the client's contract. All other events insert or
/// replace it.
fn update_data(&self, response: &ContractsResponse) {
    let events = response.get_data();
    for event in events {
        if event.removed.unwrap_or(false) {
            self.contract_storage.remove_contract(&event.client_id);
        } else {
            self.contract_storage
                .save_contract(contract_from_event(event));
        }
    }
    self.contract_storage.update_last();
    debug!(
        "{} contract events processed for API {}",
        events.len(),
        self.api_id
    );
}
/// Saves the page's `next` link (and the accepted hash algorithm) as the parameters for the
/// following contracts request.
fn update_links(&self, response: &ContractsResponse) {
    let next = response.get_links().next_link().to_string();
    let algorithm = ACCEPT_HASH_ALGORITHM_VALUE.to_string();
    self.contract_storage
        .save_contracts_request_params(ContractsRequestParams::new(Some(next), algorithm));
}
/// Returns the saved `next` link to continue the contracts feed from, if any.
fn next_url(&self) -> Option<String> {
    let params = self.contract_storage.get_contracts_request_params()?;
    params.next_url
}
}
#[derive(thiserror::Error, Debug)]
pub enum ExtractionError {
#[error("Api metadata is unavailable.")]
ApiMetadata,
#[error("Environment Context is unavailable.")]
EnvironmentContext,
#[error("Anypoint Context is unavailable.")]
AnypointContext,
}
impl FromContext<ConfigureContext> for ContractValidator {
    type Error = ExtractionError;

    /// Builds a validator from the policy's configure context.
    ///
    /// # Errors
    ///
    /// - `ApiMetadata` if the API metadata has no id.
    /// - `EnvironmentContext` if the data storage builder cannot be extracted.
    /// - Whatever the platform client's extraction error converts into.
    fn from_context(context: &ConfigureContext) -> Result<Self, Self::Error> {
        let metadata: Metadata = context.extract_always();
        let Some(api_id) = metadata.api_metadata.id else {
            return Err(ExtractionError::ApiMetadata);
        };
        let client = context.extract()?;
        let clock = context.extract_always();
        let shared_data = context.extract_always();
        let lock_builder = context.extract_always();
        let storage_builder: DataStorageBuilder = match context.extract() {
            Ok(builder) => builder,
            Err(_) => return Err(ExtractionError::EnvironmentContext),
        };
        Ok(Self::new(
            client,
            api_id,
            clock,
            shared_data,
            lock_builder,
            storage_builder,
        ))
    }
}