use async_trait::async_trait;
use ocpi::{
types::{self, CredentialsToken, CsString},
Party, PartyStore, Result, Store,
};
use std::{
collections::HashMap,
sync::{
atomic::{AtomicU64, Ordering},
Arc, Mutex,
},
};
use url::Url;
#[derive(Clone)]
pub struct TestParty {
pub id: u64,
pub name: types::CsString<100>,
pub url: Url,
pub token_they_use: types::CsString<64>,
pub token_we_use: types::CsString<64>,
pub version_details: types::VersionDetails,
pub roles: Vec<types::CredentialsRole>,
}
impl Party for TestParty {
    type Id = u64;

    /// Returns the numeric id stored on this party.
    fn id(&self) -> Self::Id {
        let Self { id, .. } = self;
        *id
    }

    /// Returns an owned copy of the stored `token_we_use`.
    fn token_we_use(&self) -> types::CsString<64> {
        Clone::clone(&self.token_we_use)
    }

    /// Returns an owned copy of the stored `token_they_use`.
    fn token_they_use(&self) -> types::CsString<64> {
        Clone::clone(&self.token_they_use)
    }
}
/// In-memory implementation of the `Store` and `PartyStore` traits.
///
/// Every field sits behind an `Arc`, so clones of a `TestStore` share the same counters,
/// registration tokens and parties.
#[derive(Clone, Default)]
pub struct TestStore {
    /// Source of ids for newly saved parties.
    pub id_counter: Arc<AtomicU64>,
    /// Sequence number embedded in `TEMPTOKEN-<n>` registration tokens.
    pub temp_token_gen: Arc<AtomicU64>,
    /// Sequence number embedded in generated `TOKEN-<n>` party tokens.
    pub token_gen: Arc<AtomicU64>,
    /// Registration tokens, each mapped to whether `get_authorized` has accepted it.
    pub reg_tokens: Arc<Mutex<HashMap<String, bool>>>,
    /// Registered parties keyed by their id.
    pub parties: Arc<Mutex<HashMap<u64, TestParty>>>,
}
impl TestStore {
    /// Creates a new registration token `TEMPTOKEN-<n>` and records it as not yet used.
    ///
    /// `<n>` is taken from `temp_token_gen`.
    ///
    /// # Panics
    /// Panics if the `reg_tokens` mutex is poisoned, or if the generated string is rejected
    /// by `CredentialsToken::try_from`.
    pub fn create_reg_token(&self) -> CredentialsToken {
        let token = format!(
            "TEMPTOKEN-{}",
            self.temp_token_gen.fetch_add(1, Ordering::Relaxed)
        );
        // Convert before touching the map. Otherwise a conversion failure would panic while
        // the lock is held (poisoning it) and would leave an unusable token registered.
        let credentials_token = CredentialsToken::try_from(token.clone())
            .expect("a `TEMPTOKEN-<u64>` string should be a valid credentials token");
        self.reg_tokens
            .lock()
            .expect("Locking")
            .insert(token, false);
        credentials_token
    }

    /// Reports whether a registration token has been accepted by `get_authorized`.
    ///
    /// Returns `None` if `s` was never issued by `create_reg_token`.
    pub fn is_reg_token_used(&self, s: impl AsRef<str>) -> Option<bool> {
        let lock = self.reg_tokens.lock().expect("Locking");
        lock.get(s.as_ref()).copied()
    }

    /// Finds a stored party whose `token_we_use` equals `token`.
    pub fn by_token_we_use(&self, token: impl AsRef<str>) -> Option<TestParty> {
        let needle = token.as_ref();
        // Compare the field in place instead of cloning it through the `Party` getter.
        self.find_party(|tp| tp.token_we_use == needle)
    }

    /// Finds a stored party whose `token_they_use` equals `token`.
    pub fn by_token_they_use(&self, token: impl AsRef<str>) -> Option<TestParty> {
        let needle = token.as_ref();
        self.find_party(|tp| tp.token_they_use == needle)
    }

    /// Returns a clone of the first stored party satisfying `pred`, if any.
    fn find_party(&self, pred: impl Fn(&TestParty) -> bool) -> Option<TestParty> {
        let parties = self.parties.lock().expect("Locking Parties");
        parties.values().find(|&tp| pred(tp)).cloned()
    }
}
#[async_trait]
impl Store for TestStore {
type PartyModel = TestParty;
type RegistrationModel = String;
async fn get_authorized(
&self,
token: types::CredentialsToken,
) -> Result<ocpi::Authorized<Self::PartyModel, Self::RegistrationModel>> {
let mut lock = self.reg_tokens.lock().expect("Locking");
if let Some(b) = lock.get_mut(token.as_ref()) {
*b = true;
return Ok(ocpi::Authorized::Registration(token.to_string()));
}
drop(lock);
if let Some(party) = self.by_token_they_use(&token) {
return Ok(ocpi::Authorized::Party(party));
}
Err(ocpi::Error::unauthorized("Invalid token"))
}
}
#[async_trait]
impl PartyStore for TestStore {
async fn delete_party(&self, party_id: <Self::PartyModel as Party>::Id) -> Result<()> {
let mut lock = self.parties.lock().expect("locking");
lock.remove(&party_id);
Ok(())
}
async fn save_new_party(
&self,
_temporary_model: Self::RegistrationModel,
credentials: types::Credential,
version_details: types::VersionDetails,
) -> Result<Self::PartyModel> {
let id = self.id_counter.fetch_add(1, Ordering::Relaxed);
let party = TestParty {
id,
name: credentials
.roles
.get(0)
.expect("At least one role must be present")
.business_details
.name
.clone(),
url: credentials.url,
token_we_use: credentials.token,
token_they_use: self.generate_token(),
roles: credentials.roles,
version_details,
};
let mut lock = self.parties.lock().expect("Locking parties");
lock.insert(id, party.clone());
Ok(party)
}
async fn update_party(
&self,
model: Self::PartyModel,
credentials: types::Credential,
details: types::VersionDetails,
) -> Result<Self::PartyModel> {
let mut lock = self.parties.lock().expect("Locking parties");
let existing = lock
.get_mut(&model.id)
.ok_or_else(|| ocpi::Error::client_generic("Party not found"))?;
existing.token_we_use = credentials.token;
existing.name = credentials.roles[0].business_details.name.clone();
existing.url = credentials.url;
existing.roles = credentials.roles;
existing.version_details = details;
existing.token_they_use = self.generate_token();
Ok(existing.clone())
}
async fn get_our_roles(&self) -> Result<Vec<types::CredentialsRole>> {
Ok(vec![types::CredentialsRole {
role: types::Role::Cpo,
business_details: types::BusinessDetails {
name: "TestStore details".parse().expect("Parsing name"),
website: None,
logo: None,
},
party_id: "EXA".parse().expect("PartyId"),
country_code: "se".parse().expect("CountryCode"),
}])
}
}
impl TestStore {
    /// Builds the next `TOKEN-<n>` party token, taking `<n>` from `token_gen`.
    ///
    /// # Panics
    /// Panics if the formatted string does not parse as a `CsString<64>`.
    fn generate_token(&self) -> CsString<64> {
        let sequence = self.token_gen.fetch_add(1, Ordering::Relaxed);
        let raw = format!("TOKEN-{}", sequence);
        str::parse::<types::CsString<64>>(&raw).expect("Token")
    }
}