use std::path::PathBuf;
use std::sync::Arc;
use futures::future::try_join;
use jsonwebtoken::{decode, errors::Error as JWTError, Algorithm, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, MutexGuard};
#[cfg(feature = "tracing")]
use urlpattern::UrlPatternMatchInput;
pub use builder::{BuildError, ClientConfigurationBuilder};
use secrets::Secrets;
pub use secrets::SECRETS_PATH_VAR;
use settings::Settings;
pub use settings::{AuthServer, SETTINGS_PATH_VAR};
use crate::configuration::LoadError::AuthServerNotFound;
#[cfg(feature = "tracing-config")]
use crate::tracing_configuration::{TracingConfiguration, TracingFilterError};
mod builder;
mod path;
mod secrets;
mod settings;
/// Default base URL of the QCS HTTP API.
///
/// NOTE(review): not referenced in the code visible in this file; presumably used as a fallback
/// by the `builder`/`settings` modules — confirm there.
pub const DEFAULT_API_URL: &str = "https://api.qcs.rigetti.com";
/// Default base URL of the QCS gRPC API.
///
/// NOTE(review): not referenced in the code visible in this file; presumably a fallback used by
/// the `builder`/`settings` modules — confirm there.
pub const DEFAULT_GRPC_API_URL: &str = "https://grpc.qcs.rigetti.com";
/// Default URL of a locally running QVM.
///
/// NOTE(review): not referenced in the code visible in this file — confirm where it is applied.
pub const DEFAULT_QVM_URL: &str = "http://127.0.0.1:5000";
/// Default URL of a locally running quilc.
///
/// NOTE(review): not referenced in the code visible in this file — confirm where it is applied.
pub const DEFAULT_QUILC_URL: &str = "tcp://127.0.0.1:5555";
/// Default profile name.
///
/// NOTE(review): not referenced in the code visible in this file; presumably the default value
/// of `Settings::default_profile_name` — confirm in `settings`.
pub const DEFAULT_PROFILE_NAME: &str = "default";
/// Name of the environment variable that, when set, replaces the quilc URL taken from the
/// selected profile while a configuration is being constructed from settings.
pub const QUILC_URL_VAR: &str = "QCS_SETTINGS_APPLICATIONS_QUILC_URL";
/// Name of the environment variable that, when set, replaces the QVM URL taken from the
/// selected profile while a configuration is being constructed from settings.
pub const QVM_URL_VAR: &str = "QCS_SETTINGS_APPLICATIONS_QVM_URL";
/// Name of the environment variable that, when set, replaces the gRPC API URL taken from the
/// selected profile while a configuration is being constructed from settings.
pub const GRPC_API_URL_VAR: &str = "QCS_SETTINGS_APPLICATIONS_GRPC_URL";
/// The pair of tokens a [`ClientConfiguration`] keeps for authenticating against QCS.
///
/// Both fields are optional; the derived `Default` yields a value with neither token set.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Tokens {
/// The current access token, if any. Before it is handed out it is parsed as a JWT and its
/// expiry is checked; a token that fails that check triggers a refresh.
pub bearer_access_token: Option<String>,
/// The token sent to the auth server to obtain a new access token. Refreshing fails with
/// `RefreshError::NoRefreshToken` when this is `None`.
pub refresh_token: Option<String>,
}
/// Everything a client needs to talk to QCS: service URLs, the auth server to refresh against,
/// and the current tokens.
///
/// Cloning is cheap with respect to tokens: they live behind an `Arc<Mutex<_>>`, so every clone
/// shares (and observes refreshes of) the same token state.
#[derive(Clone, Debug)]
#[allow(clippy::module_name_repetitions)]
pub struct ClientConfiguration {
// Shared, mutable token state; the async mutex is held while a refresh request is in flight.
tokens: Arc<Mutex<Tokens>>,
// Base URL of the QCS HTTP API.
api_url: String,
// Supplies the issuer URL and client id used when refreshing the access token.
auth_server: AuthServer,
// Base URL of the QCS gRPC API; also reported as `TokenRefresher::base_url`.
grpc_api_url: String,
// URL of the quilc service.
quilc_url: String,
// URL of the QVM service.
qvm_url: String,
// Optional tracing configuration; only present when the `tracing-config` feature is enabled.
#[cfg(feature = "tracing-config")]
tracing_configuration: Option<TracingConfiguration>,
}
impl ClientConfiguration {
    /// Start building a [`ClientConfiguration`] from a default-constructed
    /// [`ClientConfigurationBuilder`].
    #[must_use]
    #[allow(clippy::missing_const_for_fn)]
    pub fn builder() -> ClientConfigurationBuilder {
        ClientConfigurationBuilder::default()
    }

    /// The configured base URL of the QCS HTTP API.
    #[must_use]
    pub fn api_url(&self) -> &str {
        self.api_url.as_str()
    }

    /// The configured base URL of the QCS gRPC API.
    #[must_use]
    pub fn grpc_api_url(&self) -> &str {
        self.grpc_api_url.as_str()
    }

    /// The configured URL of the quilc service.
    #[must_use]
    pub fn quilc_url(&self) -> &str {
        self.quilc_url.as_str()
    }

    /// The configured URL of the QVM service.
    #[must_use]
    pub fn qvm_url(&self) -> &str {
        self.qvm_url.as_str()
    }

    /// The tracing configuration, or `None` when none was set.
    ///
    /// Only available with the `tracing-config` feature.
    #[cfg(feature = "tracing-config")]
    #[must_use]
    #[allow(clippy::missing_const_for_fn)]
    pub fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
        self.tracing_configuration.as_ref()
    }
}
/// Name of the environment variable consulted for the profile name when none is passed
/// explicitly to the loader. It takes precedence over the settings' default profile name.
pub const PROFILE_NAME_VAR: &str = "QCS_PROFILE_NAME";
/// Errors that can occur while obtaining or refreshing the QCS access token.
#[derive(Debug, thiserror::Error)]
pub enum RefreshError {
/// A refresh was required but the stored tokens contain no refresh token.
#[error("No refresh token is configured within selected QCS credential")]
NoRefreshToken,
/// Building the HTTP client, sending the token request, receiving a non-success status, or
/// decoding the response body failed. The underlying `reqwest::Error` is kept as the source.
#[error("Error fetching new token")]
FetchError(#[from] reqwest::Error),
/// Wraps a JWT error.
///
/// NOTE(review): never constructed in the code visible in this file — local validation
/// failures are turned into a refresh instead of being returned. Confirm whether other
/// modules still produce it.
#[error("Error validating existing token: {0}")]
ValidationError(#[from] JWTError),
}
/// Errors that can occur while loading a [`ClientConfiguration`] from settings and secrets.
#[derive(Debug, thiserror::Error)]
pub enum LoadError {
/// The selected profile name has no entry in the settings' profiles. Carries that name.
#[error("Expected profile {0} in settings.profiles but it didn't exist")]
ProfileNotFound(String),
/// The auth server named by the selected profile has no entry in the settings' auth servers.
/// Carries the missing auth server name.
#[error("Expected auth server {0} in settings.auth_servers but it didn't exist")]
AuthServerNotFound(String),
/// The home directory could not be determined.
///
/// NOTE(review): not constructed in the code visible in this file; presumably raised by the
/// `path`/`settings`/`secrets` modules when resolving default file locations — confirm.
#[error("Failed to determine home directory. You can use an explicit path by setting the {env} environment variable")]
HomeDirError {
/// Name of the environment variable the user can set to provide an explicit path.
env: String,
},
/// A configuration file could not be opened.
///
/// NOTE(review): not constructed in the code visible in this file; presumably raised by the
/// settings/secrets loaders — confirm.
#[error("Could not open file at {path}: {source}")]
FileOpenError {
/// Path of the file that could not be opened.
path: PathBuf,
/// The underlying I/O error.
source: std::io::Error,
},
/// A configuration file was read but is not valid TOML for the expected schema.
///
/// NOTE(review): not constructed in the code visible in this file; presumably raised by the
/// settings/secrets loaders — confirm.
#[error("Could not parse TOML file at {path}: {source}")]
FileParseError {
/// Path of the file that failed to parse.
path: PathBuf,
/// The underlying TOML deserialization error.
source: toml::de::Error,
},
/// Building the tracing configuration from the environment failed.
/// Only exists with the `tracing-config` feature.
#[cfg(feature = "tracing-config")]
#[error("Could not parse tracing filter: {0}")]
TracingFilterParseError(TracingFilterError),
}
impl ClientConfiguration {
pub async fn load_default() -> Result<Self, LoadError> {
Self::load(None).await
}
pub async fn load_profile(profile_name: String) -> Result<Self, LoadError> {
Self::load(Some(profile_name)).await
}
#[inline]
async fn load(profile_name: Option<String>) -> Result<Self, LoadError> {
#[cfg(feature = "tracing")]
#[allow(clippy::option_if_let_else)]
match profile_name.as_ref() {
None => tracing::debug!("loading default QCS profile"),
Some(profile) => tracing::debug!("loading QCS profile {:?}", profile),
}
let (settings, secrets) = try_join(settings::load(), secrets::load()).await?;
Self::new(settings, secrets, profile_name)
}
fn validated_bearer_access_token(lock: &MutexGuard<Tokens>) -> Option<String> {
#[allow(clippy::option_if_let_else)]
lock.bearer_access_token.as_ref().and_then(|token| {
let dummy_key = DecodingKey::from_secret(&[]);
let mut validation = Validation::new(Algorithm::RS256);
validation.validate_exp = true;
validation.leeway = 0;
validation.insecure_disable_signature_validation();
decode::<toml::Value>(token, &dummy_key, &validation)
.map(|_| token.clone())
.ok()
})
}
pub async fn get_bearer_access_token(&self) -> Result<String, RefreshError> {
let mut lock = self.tokens.lock().await;
let validation = Self::validated_bearer_access_token(&lock);
match validation {
Some(token) => Ok(token),
None => self.internal_refresh(&mut lock).await,
}
}
pub async fn refresh(&self) -> Result<String, RefreshError> {
let mut lock = self.tokens.lock().await;
self.internal_refresh(&mut lock).await
}
async fn internal_refresh<'a>(
&'a self,
lock: &mut MutexGuard<'a, Tokens>,
) -> Result<String, RefreshError> {
#[cfg(feature = "tracing")]
tracing::trace!("refreshing QCS access token");
let refresh_token = lock
.refresh_token
.as_deref()
.ok_or(RefreshError::NoRefreshToken)?;
let token_url = format!("{}/v1/token", &self.auth_server.issuer());
let data = TokenRequest::new(self.auth_server.client_id(), refresh_token);
let resp = reqwest::Client::builder()
.user_agent(format!(
"QCS API Client (Rust)/{}",
env!("CARGO_PKG_VERSION")
))
.timeout(std::time::Duration::from_secs(10))
.build()?
.post(token_url)
.form(&data)
.send()
.await?;
let response_data: TokenResponse = resp.error_for_status()?.json().await?;
lock.bearer_access_token = Some(response_data.access_token.clone());
lock.refresh_token = Some(response_data.refresh_token);
Ok(response_data.access_token)
}
fn new(
settings: Settings,
mut secrets: Secrets,
profile_name: Option<String>,
) -> Result<Self, LoadError> {
let Settings {
default_profile_name,
mut profiles,
mut auth_servers,
} = settings;
let profile_name = profile_name
.or_else(|| std::env::var(PROFILE_NAME_VAR).ok())
.unwrap_or(default_profile_name);
let profile = profiles
.remove(&profile_name)
.ok_or(LoadError::ProfileNotFound(profile_name))?;
let auth_server = auth_servers
.remove(&profile.auth_server_name)
.ok_or_else(|| AuthServerNotFound(profile.auth_server_name.clone()))?;
let credential = secrets.credentials.remove(&profile.credentials_name);
let (access_token, refresh_token) = match credential {
Some(secrets::Credential {
token_payload: Some(token_payload),
}) => (token_payload.access_token, token_payload.refresh_token),
_ => (None, None),
};
let quilc_url =
std::env::var(QUILC_URL_VAR).unwrap_or(profile.applications.pyquil.quilc_url);
let qvm_url = std::env::var(QVM_URL_VAR).unwrap_or(profile.applications.pyquil.qvm_url);
let grpc_api_url = std::env::var(GRPC_API_URL_VAR).unwrap_or(profile.grpc_api_url);
let tokens = Tokens {
bearer_access_token: access_token,
refresh_token,
};
#[cfg(feature = "tracing-config")]
let tracing_configuration =
TracingConfiguration::from_env().map_err(LoadError::TracingFilterParseError)?;
let mut builder = Self::builder();
builder = builder
.set_tokens(tokens)
.set_auth_server(auth_server)
.set_api_url(profile.api_url)
.set_quilc_url(quilc_url)
.set_qvm_url(qvm_url)
.set_grpc_api_url(grpc_api_url);
#[cfg(feature = "tracing-config")]
{
builder = builder.set_tracing_configuration(tracing_configuration);
};
Ok({
builder
.build()
.expect("curated build process should not fail")
})
}
}
/// Payload of the token refresh request. It is serialized as form fields and posted to the
/// auth server's `/v1/token` endpoint.
#[derive(Debug, Serialize)]
struct TokenRequest<'a> {
/// Grant type sent with the request; the constructor always sets it to `"refresh_token"`.
grant_type: &'static str,
/// Client id of the auth server the request is sent to.
client_id: &'a str,
/// The refresh token being exchanged for a new access token.
refresh_token: &'a str,
}
impl<'a> TokenRequest<'a> {
const fn new(client_id: &'a str, refresh_token: &'a str) -> TokenRequest<'a> {
Self {
grant_type: "refresh_token",
client_id,
refresh_token,
}
}
}
/// The fields read from the auth server's JSON response to a token refresh request.
/// Both fields are non-optional.
#[derive(Deserialize, Debug)]
struct TokenResponse {
/// The refresh token to store for the next refresh.
refresh_token: String,
/// The newly issued access token.
access_token: String,
}
impl Default for ClientConfiguration {
    /// Build a configuration from a builder on which nothing has been set.
    ///
    /// # Panics
    ///
    /// Panics if the untouched builder fails to build, which is treated as a bug.
    fn default() -> Self {
        let built = Self::builder().build();
        built.expect("a builder without anything set should build without error")
    }
}
/// Abstraction over something that can hand out and refresh access tokens.
///
/// `ClientConfiguration` implements it in this file by delegating to its own token methods.
#[async_trait::async_trait]
pub trait TokenRefresher: Clone {
/// The error returned when a token cannot be obtained or refreshed.
type Error;
/// Return an access token to use for a request.
///
/// The `ClientConfiguration` implementation returns the stored token while it passes local
/// validation and refreshes it otherwise.
async fn get_access_token(&self) -> Result<String, Self::Error>;
/// Obtain a new access token and return it.
async fn refresh_access_token(&self) -> Result<String, Self::Error>;
/// The base URL associated with this refresher. Only exists with the `tracing` feature.
///
/// NOTE(review): how callers use this value is not visible in this file — confirm.
#[cfg(feature = "tracing")]
fn base_url(&self) -> &str;
/// The tracing configuration to consult in `should_trace`, if any. Only exists with the
/// `tracing-config` feature.
#[cfg(feature = "tracing-config")]
fn tracing_configuration(&self) -> Option<&TracingConfiguration>;
/// Decide whether a request to `url` should be traced. Only exists with the `tracing` feature.
///
/// Without the `tracing-config` feature this is always `true`. With it, the answer is `true`
/// when no tracing configuration is set, and otherwise whatever the configuration's
/// `is_enabled(url)` returns.
#[cfg(feature = "tracing")]
#[allow(clippy::needless_return)]
fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
// Exactly one of the two cfg-gated tails below survives compilation.
#[cfg(not(feature = "tracing-config"))]
{
// `url` is unused in this configuration; bind it to avoid an unused-variable warning.
let _ = url;
return true;
}
#[cfg(feature = "tracing-config")]
self.tracing_configuration()
.map_or(true, |config| config.is_enabled(url))
}
}
#[async_trait::async_trait]
impl TokenRefresher for ClientConfiguration {
    type Error = RefreshError;

    /// Delegates to [`ClientConfiguration::get_bearer_access_token`].
    async fn get_access_token(&self) -> Result<String, Self::Error> {
        self.get_bearer_access_token().await
    }

    /// Delegates to [`ClientConfiguration::refresh`].
    async fn refresh_access_token(&self) -> Result<String, Self::Error> {
        self.refresh().await
    }

    /// Reports the configured gRPC API URL as the base URL.
    #[cfg(feature = "tracing")]
    fn base_url(&self) -> &str {
        self.grpc_api_url()
    }

    /// Exposes the configuration's own tracing configuration (the inherent accessor of the same
    /// name).
    #[cfg(feature = "tracing-config")]
    fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
        self.tracing_configuration()
    }
}
#[cfg(test)]
mod describe_client_configuration_load {
    use serial_test::serial;

    #[allow(clippy::wildcard_imports)]
    use crate::configuration::*;

    /// Applies environment variable overrides for the lifetime of the guard.
    ///
    /// Environment variables are process-global, so leaving them set would leak into every
    /// test that runs afterwards in the same process. Dropping the guard restores each
    /// variable to the value it had before (or removes it if it was unset), including when
    /// the test body panics on a failed assertion.
    struct EnvOverrides {
        previous: Vec<(&'static str, Option<std::ffi::OsString>)>,
    }

    impl EnvOverrides {
        /// Set every `(name, value)` pair, remembering what each variable held before.
        fn apply(overrides: &[(&'static str, &str)]) -> Self {
            let previous = overrides
                .iter()
                .map(|&(name, value)| {
                    let old = std::env::var_os(name);
                    std::env::set_var(name, value);
                    (name, old)
                })
                .collect();
            Self { previous }
        }
    }

    impl Drop for EnvOverrides {
        fn drop(&mut self) {
            for (name, old) in self.previous.drain(..) {
                match old {
                    Some(value) => std::env::set_var(name, value),
                    None => std::env::remove_var(name),
                }
            }
        }
    }

    /// `ClientConfiguration::new` must prefer the environment over profile values.
    #[tokio::test]
    #[serial]
    async fn it_uses_env_var_overrides() {
        let quilc_url = "tcp://quilc:5555";
        let qvm_url = "http://qvm:5000";
        let grpc_url = "http://grpc:80";

        // Bound to a named variable so the guard lives until the end of the test.
        let _env = EnvOverrides::apply(&[
            (QUILC_URL_VAR, quilc_url),
            (QVM_URL_VAR, qvm_url),
            (GRPC_API_URL_VAR, grpc_url),
        ]);

        let config = ClientConfiguration::new(Settings::default(), Secrets::default(), None)
            .expect("config should load successfully");

        assert_eq!(config.quilc_url, quilc_url);
        assert_eq!(config.qvm_url, qvm_url);
        assert_eq!(config.grpc_api_url, grpc_url);
    }

    /// `ClientConfiguration::default` must also honor the environment overrides.
    #[test]
    #[serial]
    fn test_default_uses_env_var_overrides() {
        let quilc_url = "quilc_url";
        let qvm_url = "qvm_url";
        let grpc_url = "grpc_url";

        // Bound to a named variable so the guard lives until the end of the test.
        let _env = EnvOverrides::apply(&[
            (QUILC_URL_VAR, quilc_url),
            (QVM_URL_VAR, qvm_url),
            (GRPC_API_URL_VAR, grpc_url),
        ]);

        let config = ClientConfiguration::default();

        assert_eq!(config.quilc_url, quilc_url);
        assert_eq!(config.qvm_url, qvm_url);
        assert_eq!(config.grpc_api_url, grpc_url);
    }
}