use async_trait::async_trait;
use itertools::Itertools;
use lazy_static::lazy_static;
use std::{
collections::HashMap,
env, fmt,
path::PathBuf,
time::{Duration, Instant},
};
use tokio::{fs, sync::Mutex};
use crate::common::*;
use crate::config::config_dir;
/// A bag of named credential values, with an optional expiration time.
#[derive(Clone, Debug)]
pub(crate) struct Credentials {
    /// Credential values keyed by name (for example `"access_key_id"`).
    data: HashMap<String, String>,
    /// When these credentials expire, if they expire at all.
    expires: Option<Instant>,
}

impl Credentials {
    /// Look up `key`, returning an error if this credential has no such key.
    pub(crate) fn get_required<'a>(&'a self, key: &str) -> Result<&'a str> {
        match self.get_optional(key) {
            Some(value) => Ok(value),
            None => Err(format_err!("no key {:?} in credential", key)),
        }
    }

    /// Look up `key`, returning `None` if this credential has no such key.
    pub(crate) fn get_optional<'a>(&'a self, key: &str) -> Option<&'a str> {
        self.data.get(key).map(String::as_str)
    }

    /// Should these credentials be fetched again before use?
    ///
    /// Credentials without an expiration never need refreshing. Credentials
    /// that expire less than 30 minutes from now are treated as stale.
    fn needs_refresh(&self) -> bool {
        const MINIMUM_REMAINING_LIFETIME: Duration = Duration::from_secs(30 * 60);
        self.expires
            .map_or(false, |expires| expires < Instant::now() + MINIMUM_REMAINING_LIFETIME)
    }
}
lazy_static! {
    /// The process-wide credentials manager handed out by
    /// `CredentialsManager::singleton`.
    ///
    /// Built on first access. If `CredentialsManager::new` returns an error,
    /// this panics with a message naming what failed instead of a bare
    /// `unwrap` message.
    static ref MANAGER: CredentialsManager =
        CredentialsManager::new().expect("could not initialize credentials manager");
}
/// Registry of every known credential source, plus a cache of the credentials
/// each one has already produced.
pub(crate) struct CredentialsManager {
    /// Sources keyed by credential name (`"aws"`, `"shopify"`, ...). Each has
    /// its own lock, which `get` holds while checking the cache and fetching,
    /// so lookups for the same name are serialized.
    sources: HashMap<String, Mutex<Box<dyn CredentialsSource>>>,
    /// Previously fetched credentials, keyed by the same names as `sources`.
    cache: Mutex<HashMap<String, Credentials>>,
}

impl CredentialsManager {
    /// Return the shared, lazily-initialized manager.
    pub(crate) fn singleton() -> &'static CredentialsManager {
        &MANAGER
    }

    /// Build a manager containing the built-in credential sources.
    ///
    /// Fails if `config_dir()` returns an error.
    fn new() -> Result<CredentialsManager> {
        let config_dir = config_dir()?;

        // A credential read from a single environment variable or, if that
        // variable is unset, from a file in the configuration directory.
        let env_then_file = |var: &'static str, file_name: &str| -> Box<dyn CredentialsSource> {
            CredentialsSources::new(vec![
                EnvCredentialsSource::new(vec![EnvMapping::required("value", var)]).boxed(),
                FileCredentialsSource::new("value", config_dir.join(file_name)).boxed(),
            ])
            .boxed()
        };

        let named_sources: Vec<(&str, Box<dyn CredentialsSource>)> = vec![
            (
                "aws",
                EnvCredentialsSource::new(vec![
                    EnvMapping::required("access_key_id", "AWS_ACCESS_KEY_ID"),
                    EnvMapping::required("secret_access_key", "AWS_SECRET_ACCESS_KEY"),
                    EnvMapping::optional("session_token", "AWS_SESSION_TOKEN"),
                    EnvMapping::required("default_region", "AWS_DEFAULT_REGION"),
                ])
                .boxed(),
            ),
            (
                "gcloud_service_account_key",
                env_then_file(
                    "GCLOUD_SERVICE_ACCOUNT_KEY",
                    "gcloud_service_account_key.json",
                ),
            ),
            (
                "gcloud_client_secret",
                env_then_file("GCLOUD_CLIENT_SECRET", "gcloud_client_secret.json"),
            ),
            (
                "shopify",
                EnvCredentialsSource::new(vec![EnvMapping::required(
                    "auth_token",
                    "SHOPIFY_AUTH_TOKEN",
                )])
                .boxed(),
            ),
        ];

        let sources = named_sources
            .into_iter()
            .map(|(name, source)| (name.to_owned(), Mutex::new(source)))
            .collect();

        Ok(CredentialsManager {
            sources,
            cache: Mutex::new(HashMap::new()),
        })
    }

    /// Fetch the credentials registered under `name`.
    ///
    /// Returns a cached copy unless it needs refreshing. Otherwise, asks the
    /// source and caches the result. Errors if `name` is unknown, if the
    /// source fails, or if the source finds nothing.
    pub(crate) async fn get(&self, name: &str) -> Result<Credentials> {
        let source = match self.sources.get(name) {
            Some(source) => source,
            None => return Err(format_err!("unknown credential {:?}", name)),
        };
        // Held until we return, so concurrent lookups of `name` run one at a time.
        let source = source.lock().await;

        let cached = {
            let cache = self.cache.lock().await;
            cache.get(name).cloned()
        };
        if let Some(fresh) = cached.filter(|c| !c.needs_refresh()) {
            return Ok(fresh);
        }

        match source.get_credentials().await? {
            Some(fetched) => {
                self.cache
                    .lock()
                    .await
                    .insert(name.to_owned(), fetched.clone());
                Ok(fetched)
            }
            None => Err(format_err!(
                "could not find credentials for {} in any of:\n{}",
                name,
                source,
            )),
        }
    }
}
/// Somewhere credentials can be loaded from.
///
/// `Ok(None)` means the source has nothing to offer; `Err` means something
/// went wrong while looking. The `Display` output lists where the source
/// looks, and is included in the "could not find credentials" error.
#[async_trait]
trait CredentialsSource: fmt::Display + fmt::Debug + Send + Sync + 'static {
    /// Attempt to load credentials from this source.
    async fn get_credentials(&self) -> Result<Option<Credentials>>;
}

/// Helper methods available on every concrete `CredentialsSource`.
trait CredentialsSourceExt: CredentialsSource + Sized + 'static {
    /// Erase the concrete type into a boxed trait object.
    fn boxed(self) -> Box<dyn CredentialsSource> {
        Box::new(self)
    }
}

impl<CS> CredentialsSourceExt for CS where CS: CredentialsSource {}
/// Associates one credential key with the environment variable it comes from.
#[derive(Debug)]
struct EnvMapping {
    /// Key under which the variable's value is stored.
    key: &'static str,
    /// Name of the environment variable to read.
    var: &'static str,
    /// Whether the variable may be absent.
    optional: bool,
}

impl EnvMapping {
    /// A mapping whose environment variable is expected to be set.
    fn required(key: &'static str, var: &'static str) -> Self {
        EnvMapping {
            optional: false,
            key,
            var,
        }
    }

    /// A mapping whose environment variable may be left unset.
    fn optional(key: &'static str, var: &'static str) -> Self {
        EnvMapping {
            optional: true,
            key,
            var,
        }
    }
}

impl fmt::Display for EnvMapping {
    /// Shows the variable name, prefixed with `(optional) ` when optional.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let marker = if self.optional { "(optional) " } else { "" };
        write!(f, "{}{}", marker, self.var)
    }
}
/// Loads credentials from one or more environment variables.
///
/// The first mapping decides whether this source applies at all: if its
/// variable is unset, the source yields nothing.
#[derive(Debug)]
struct EnvCredentialsSource {
    mapping: Vec<EnvMapping>,
}

impl EnvCredentialsSource {
    /// Create a source from `mapping`.
    ///
    /// Panics if `mapping` is empty or if its first entry is optional.
    fn new(mapping: Vec<EnvMapping>) -> Self {
        assert!(!mapping.is_empty());
        assert!(!mapping[0].optional);
        EnvCredentialsSource { mapping }
    }
}

impl fmt::Display for EnvCredentialsSource {
    /// Writes a single `- ` bullet line naming the variable(s) consulted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self.mapping.as_slice() {
            [single] => writeln!(f, "- The environment variable {}", single),
            several => writeln!(
                f,
                "- The environment variables {}",
                several.iter().join(", "),
            ),
        }
    }
}
#[async_trait]
impl CredentialsSource for EnvCredentialsSource {
    /// Read every mapped variable into a `Credentials`.
    ///
    /// Returns `Ok(None)` if the first variable is unset. Once it is set, a
    /// missing required variable is an error and a missing optional one is
    /// skipped.
    async fn get_credentials(&self) -> Result<Option<Credentials>> {
        let primary = &self.mapping[0];
        let primary_value = match try_var(primary.var)? {
            Some(value) => value,
            None => return Ok(None),
        };

        let mut data = HashMap::new();
        data.insert(primary.key.to_owned(), primary_value);
        for extra in &self.mapping[1..] {
            let found = if extra.optional {
                try_var(extra.var)?
            } else {
                Some(var(extra.var)?)
            };
            if let Some(value) = found {
                data.insert(extra.key.to_owned(), value);
            }
        }

        Ok(Some(Credentials {
            data,
            expires: None,
        }))
    }
}
/// Read environment variable `name`.
///
/// Returns `Ok(None)` if it is unset, and an error if its value is not valid
/// UTF-8.
fn try_var(name: &str) -> Result<Option<String>> {
    env::var(name).map(Some).or_else(|err| match err {
        env::VarError::NotPresent => Ok(None),
        env::VarError::NotUnicode(..) => Err(format_err!(
            "environment variable {} cannot be converted to UTF-8",
            name,
        )),
    })
}
/// Read environment variable `name`, treating an unset variable as an error.
fn var(name: &str) -> Result<String> {
    try_var(name)?.ok_or_else(|| {
        format_err!("expected environment variable {} to be set", name)
    })
}
/// Loads a single credential value from the entire contents of a file.
#[derive(Debug)]
struct FileCredentialsSource {
    /// Key under which the file's contents are stored.
    key: &'static str,
    /// File to read.
    path: PathBuf,
}

impl FileCredentialsSource {
    /// Create a source that stores the contents of `path` under `key`.
    fn new(key: &'static str, path: PathBuf) -> Self {
        FileCredentialsSource { path, key }
    }
}

impl fmt::Display for FileCredentialsSource {
    /// Writes a single `- ` bullet line naming the file consulted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let shown = self.path.display();
        writeln!(f, "- The file {}", shown)
    }
}
#[async_trait]
impl CredentialsSource for FileCredentialsSource {
    /// Read the file into a one-entry `Credentials`.
    ///
    /// A missing file yields `Ok(None)`; any other read failure is an error
    /// mentioning the path.
    async fn get_credentials(&self) -> Result<Option<Credentials>> {
        let contents = match fs::read_to_string(&self.path).await {
            Ok(contents) => contents,
            Err(err) if err.kind() == io::ErrorKind::NotFound => return Ok(None),
            Err(err) => {
                return Err(format_err!(
                    "error reading {}: {}",
                    self.path.display(),
                    err
                ))
            }
        };

        let data: HashMap<String, String> =
            std::iter::once((self.key.to_owned(), contents)).collect();
        Ok(Some(Credentials {
            data,
            expires: None,
        }))
    }
}
/// An ordered list of sources, consulted one after another.
#[derive(Debug)]
struct CredentialsSources {
    sources: Vec<Box<dyn CredentialsSource>>,
}

impl CredentialsSources {
    /// Wrap `sources`, preserving their order.
    fn new(sources: Vec<Box<dyn CredentialsSource>>) -> Self {
        CredentialsSources { sources }
    }
}

impl fmt::Display for CredentialsSources {
    /// Concatenates each inner source's description, in order.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.sources
            .iter()
            .try_for_each(|source| write!(f, "{}", source))
    }
}
#[async_trait]
impl CredentialsSource for CredentialsSources {
    /// Return the first credentials any inner source produces.
    ///
    /// An error from any source stops the search immediately. `Ok(None)` is
    /// returned only if every source came up empty.
    async fn get_credentials(&self) -> Result<Option<Credentials>> {
        let mut found = None;
        for candidate in self.sources.iter() {
            found = candidate.get_credentials().await?;
            if found.is_some() {
                break;
            }
        }
        Ok(found)
    }
}