#![doc(html_favicon_url = "https://cipherstash.com/favicon.ico")]
#![doc = include_str!("../README.md")]
#![deny(unsafe_code)]
#![warn(clippy::unwrap_used)]
#![warn(clippy::expect_used)]
#![warn(clippy::panic)]
#![warn(clippy::mem_forget)]
#![warn(clippy::print_stdout)]
#![warn(clippy::print_stderr)]
#![warn(clippy::dbg_macro)]
#![warn(unreachable_pub)]
#![warn(unused_results)]
#![warn(clippy::todo)]
#![warn(clippy::unimplemented)]
#![cfg_attr(test, allow(clippy::unwrap_used))]
#![cfg_attr(test, allow(clippy::expect_used))]
#![cfg_attr(test, allow(clippy::panic))]
#![cfg_attr(test, allow(unused_results))]
use std::convert::Infallible;
use std::future::Future;
#[cfg(all(not(any(test, feature = "test-utils")), not(target_arch = "wasm32")))]
use std::time::Duration;
use vitaminc::protected::OpaqueDebug;
use zeroize::ZeroizeOnDrop;
mod access_key;
mod access_key_refresher;
mod access_key_strategy;
mod auto_refresh;
mod auto_strategy;
mod oauth_refresher;
mod oauth_strategy;
mod refresher;
mod service_token;
mod token;
#[cfg(not(target_arch = "wasm32"))]
mod device_client;
#[cfg(not(target_arch = "wasm32"))]
mod device_code;
#[cfg(any(test, feature = "test-utils"))]
mod static_token_strategy;
pub use access_key::{AccessKey, InvalidAccessKey};
pub use access_key_strategy::{AccessKeyStrategy, AccessKeyStrategyBuilder};
pub use auto_strategy::{AutoStrategy, AutoStrategyBuilder};
pub use oauth_strategy::{OAuthStrategy, OAuthStrategyBuilder};
pub use service_token::ServiceToken;
#[cfg(any(test, feature = "test-utils"))]
pub use static_token_strategy::StaticTokenStrategy;
pub use token::Token;
#[cfg(not(target_arch = "wasm32"))]
pub use device_client::{bind_client_device, DeviceClientError};
#[cfg(not(target_arch = "wasm32"))]
pub use device_code::{DeviceCodeStrategy, DeviceCodeStrategyBuilder, PendingDeviceCode};
#[cfg(not(target_arch = "wasm32"))]
pub use stack_profile::DeviceIdentity;
#[cfg_attr(doc, aquamarine::aquamarine)]
/// A strategy for obtaining a [`ServiceToken`] used to authenticate requests.
///
/// `get_token` takes `self` by value, so a call consumes the strategy value.
/// NOTE(review): implementors that need to be called repeatedly presumably
/// implement this trait for a reference or a cheaply-cloneable handle —
/// confirm against the concrete strategy types.
///
/// On non-wasm targets, both the strategy (`Send` supertrait) and the future
/// returned by `get_token` (`+ Send`) must be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub trait AuthStrategy: Send {
/// Resolves to a [`ServiceToken`], or an [`AuthError`] if one could not be obtained.
fn get_token(self) -> impl Future<Output = Result<ServiceToken, AuthError>> + Send;
}
/// wasm32 variant of the token-strategy trait: same contract, but with no `Send`
/// bound on either the strategy or the returned future.
/// NOTE(review): presumably because futures on wasm32 may wrap non-`Send` JS
/// values — confirm.
#[cfg(target_arch = "wasm32")]
pub trait AuthStrategy {
/// Resolves to a [`ServiceToken`], or an [`AuthError`] if one could not be obtained.
fn get_token(self) -> impl Future<Output = Result<ServiceToken, AuthError>>;
}
/// A secret string (e.g. a raw token) whose memory is zeroized on drop.
///
/// `Debug` comes from vitaminc's `OpaqueDebug` derive (presumably redacting the
/// value — confirm). With `#[serde(transparent)]` it (de)serializes as a bare
/// string, so serializing it writes the secret in plain text.
#[derive(Clone, OpaqueDebug, ZeroizeOnDrop, serde::Deserialize, serde::Serialize)]
#[serde(transparent)]
pub struct SecretToken(String);

impl SecretToken {
    /// Wraps anything convertible into a `String` as a secret token.
    pub fn new(value: impl Into<String>) -> Self {
        let secret: String = value.into();
        Self(secret)
    }

    /// Borrows the underlying secret text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
/// Errors produced while obtaining or validating authentication credentials.
///
/// Marked `#[non_exhaustive]`: matches outside this crate need a wildcard arm.
/// A payload-free, machine-readable code for each variant is available via
/// `AuthError::error_code`.
#[derive(Debug, thiserror::Error, miette::Diagnostic)]
#[non_exhaustive]
pub enum AuthError {
/// An HTTP request made with `reqwest` failed (`?` converts `reqwest::Error`).
#[error("HTTP request failed: {0}")]
Request(#[from] reqwest::Error),
/// Authorization was denied.
/// NOTE(review): presumably mirrors the OAuth `access_denied` error — confirm.
#[error("Authorization was denied")]
AccessDenied,
/// The grant was rejected as invalid.
/// NOTE(review): presumably mirrors OAuth `invalid_grant` (RFC 6749 §5.2) — confirm.
#[error("Invalid grant")]
InvalidGrant,
/// The client was rejected as invalid.
/// NOTE(review): presumably mirrors OAuth `invalid_client` (RFC 6749 §5.2) — confirm.
#[error("Invalid client")]
InvalidClient,
/// A URL failed to parse (`?` converts `url::ParseError`).
#[error("Invalid URL: {0}")]
InvalidUrl(#[from] url::ParseError),
/// The region is not supported (`?` converts `cts_common::RegionError`).
#[error("Unsupported region: {0}")]
Region(#[from] cts_common::RegionError),
/// A workspace CRN was invalid. There is no `#[from]` here, so callers must
/// construct this variant explicitly.
#[error("Invalid workspace CRN: {0}")]
InvalidCrn(cts_common::InvalidCrn),
/// Access-key authentication was attempted without a workspace CRN.
#[error("Workspace CRN is required when using an access key — set CS_WORKSPACE_CRN or call AutoStrategyBuilder::with_workspace_crn")]
MissingWorkspaceCrn,
/// No credentials are available.
#[error("Not authenticated")]
NotAuthenticated,
/// The token has expired.
#[error("Token expired")]
TokenExpired,
/// An access key failed validation (`?` converts `InvalidAccessKey`).
#[error("Invalid access key: {0}")]
InvalidAccessKey(#[from] access_key::InvalidAccessKey),
/// A token could not be decoded; the string describes what went wrong.
#[error("Invalid token: {0}")]
InvalidToken(String),
/// The server reported an error; the string carries its message.
#[error("Server error: {0}")]
Server(String),
/// Native-only: the profile/token store failed
/// (`?` converts `stack_profile::ProfileError`).
#[cfg(not(target_arch = "wasm32"))]
#[error("Token store error: {0}")]
Store(#[from] stack_profile::ProfileError),
}
impl AuthError {
    /// Returns a short, upper-snake-case code identifying this error's variant.
    ///
    /// Unlike the `Display` message, the code carries no payload details, so it
    /// is suitable for programmatic matching. Note that `TokenExpired` maps to
    /// `"EXPIRED_TOKEN"`; the word order differs from the variant name.
    pub fn error_code(&self) -> &'static str {
        // Arms follow the enum's declaration order.
        match self {
            Self::Request(_) => "REQUEST_ERROR",
            Self::AccessDenied => "ACCESS_DENIED",
            Self::InvalidGrant => "INVALID_GRANT",
            Self::InvalidClient => "INVALID_CLIENT",
            Self::InvalidUrl(_) => "INVALID_URL",
            Self::Region(_) => "INVALID_REGION",
            Self::InvalidCrn(_) => "INVALID_CRN",
            Self::MissingWorkspaceCrn => "MISSING_WORKSPACE_CRN",
            Self::NotAuthenticated => "NOT_AUTHENTICATED",
            Self::TokenExpired => "EXPIRED_TOKEN",
            Self::InvalidAccessKey(_) => "INVALID_ACCESS_KEY",
            Self::InvalidToken(_) => "INVALID_TOKEN",
            Self::Server(_) => "SERVER_ERROR",
            #[cfg(not(target_arch = "wasm32"))]
            Self::Store(_) => "STORE_ERROR",
        }
    }
}
/// Allows `?` on `Result<_, Infallible>` inside functions returning `AuthError`.
///
/// `Infallible` has no values, so this conversion can never actually run.
impl From<Infallible> for AuthError {
    fn from(impossible: Infallible) -> Self {
        // An empty match on an uninhabited type type-checks as any type.
        match impossible {}
    }
}
/// Reads an override for the CTS base URL from the `CS_CTS_HOST` environment variable.
///
/// Returns:
/// - `Ok(None)` when the variable is unset, empty, whitespace-only, or not
///   valid Unicode. A non-Unicode value is silently ignored.
/// - `Ok(Some(url))` when the trimmed value parses as a URL.
///
/// # Errors
/// Returns [`AuthError::InvalidUrl`] when a non-blank value fails to parse.
pub(crate) fn cts_base_url_from_env() -> Result<Option<url::Url>, AuthError> {
    let Ok(raw) = std::env::var("CS_CTS_HOST") else {
        return Ok(None);
    };
    // Treat whitespace-only values (e.g. from templated .env files) like an
    // empty value instead of surfacing a confusing URL parse error.
    let host = raw.trim();
    if host.is_empty() {
        return Ok(None);
    }
    Ok(Some(host.parse()?))
}
/// Returns `url` with a `/` appended to its path if the path does not already end in one.
///
/// Only the path changes; query and fragment are left untouched by `set_path`.
/// NOTE(review): presumably used so that later relative `Url::join` calls
/// extend the last path segment rather than replace it — confirm at call sites.
pub(crate) fn ensure_trailing_slash(mut url: url::Url) -> url::Url {
    let current_path = url.path();
    if current_path.ends_with('/') {
        return url;
    }
    let slashed = format!("{current_path}/");
    url.set_path(&slashed);
    url
}
/// Decodes the claims (middle segment) of a compact JWT into `C`, on wasm32 only.
///
/// The signature segment is NOT verified; only the segment count is checked.
/// The payload is decoded as unpadded URL-safe base64, then parsed as JSON.
///
/// # Errors
/// Returns [`AuthError::InvalidToken`] if the token does not have exactly three
/// `.`-separated segments, if base64 decoding fails, or if JSON decoding into
/// `C` fails.
#[cfg(target_arch = "wasm32")]
pub(crate) fn decode_jwt_payload_wasm<C>(token: &str) -> Result<C, AuthError>
where
    C: serde::de::DeserializeOwned,
{
    use base64::Engine;

    // Exactly three segments: header, claims, signature, and nothing after.
    let mut parts = token.split('.');
    let (Some(_header), Some(claims_b64), Some(_signature), None) =
        (parts.next(), parts.next(), parts.next(), parts.next())
    else {
        return Err(AuthError::InvalidToken(
            "JWT must have three segments".to_string(),
        ));
    };

    let claims_json = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(claims_b64)
        .map_err(|e| AuthError::InvalidToken(format!("base64 decode failed: {e}")))?;

    serde_json::from_slice(&claims_json)
        .map_err(|e| AuthError::InvalidToken(format!("failed to decode JWT claims: {e}")))
}
/// Builds the HTTP client used in tests and `test-utils` builds: reqwest defaults, no custom timeouts.
///
/// If building fails, falls back to `reqwest::Client::new()`, which per the
/// reqwest docs panics when the client cannot be constructed.
#[cfg(any(test, feature = "test-utils"))]
pub(crate) fn http_client() -> reqwest::Client {
    let builder = reqwest::Client::builder();
    match builder.build() {
        Ok(client) => client,
        Err(_) => reqwest::Client::new(),
    }
}
/// Builds the production HTTP client for native targets.
///
/// Settings:
/// - 10s connect timeout and 30s overall request timeout.
/// - Idle pooled connections are dropped after 5s.
/// - At most 10 idle connections are kept per host.
///
/// NOTE(review): if building fails, this silently falls back to
/// `reqwest::Client::new()`, which has none of the settings above. Per the
/// reqwest docs, that fallback panics if the client cannot be constructed.
#[cfg(all(not(any(test, feature = "test-utils")), not(target_arch = "wasm32")))]
pub(crate) fn http_client() -> reqwest::Client {
    const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
    const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
    const POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(5);
    const MAX_IDLE_PER_HOST: usize = 10;

    let builder = reqwest::Client::builder()
        .connect_timeout(CONNECT_TIMEOUT)
        .timeout(REQUEST_TIMEOUT)
        .pool_idle_timeout(POOL_IDLE_TIMEOUT)
        .pool_max_idle_per_host(MAX_IDLE_PER_HOST);

    match builder.build() {
        Ok(client) => client,
        Err(_) => reqwest::Client::new(),
    }
}
/// Builds the production HTTP client for wasm32 using reqwest's default settings.
///
/// If building fails, falls back to `reqwest::Client::new()`.
#[cfg(all(not(any(test, feature = "test-utils")), target_arch = "wasm32"))]
pub(crate) fn http_client() -> reqwest::Client {
    reqwest::Client::builder()
        .build()
        .ok()
        .unwrap_or_else(reqwest::Client::new)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the machine-readable code for each directly-constructible variant.
    #[test]
    fn auth_error_code_known_variants() {
        let cases = [
            (AuthError::AccessDenied, "ACCESS_DENIED"),
            (AuthError::TokenExpired, "EXPIRED_TOKEN"),
            (AuthError::InvalidGrant, "INVALID_GRANT"),
            (AuthError::InvalidClient, "INVALID_CLIENT"),
            (AuthError::NotAuthenticated, "NOT_AUTHENTICATED"),
            (AuthError::MissingWorkspaceCrn, "MISSING_WORKSPACE_CRN"),
            (AuthError::Server("x".into()), "SERVER_ERROR"),
            (AuthError::InvalidToken("malformed".into()), "INVALID_TOKEN"),
        ];
        for (err, expected) in cases {
            assert_eq!(err.error_code(), expected, "unexpected code for {err:?}");
        }
    }

    /// Variants produced through `#[from]` conversions map to the right code.
    #[test]
    fn auth_error_code_from_conversions() {
        let err = AuthError::from(url::ParseError::EmptyHost);
        assert_eq!(err.error_code(), "INVALID_URL");
    }

    /// A slash is appended only when missing, and the query string is preserved.
    #[test]
    fn ensure_trailing_slash_appends_only_when_missing() {
        let cases = [
            ("https://example.com", "https://example.com/"),
            ("https://example.com/", "https://example.com/"),
            ("https://example.com/api", "https://example.com/api/"),
            ("https://example.com/api/", "https://example.com/api/"),
            ("https://example.com/v1?x=1", "https://example.com/v1/?x=1"),
        ];
        for (input, expected) in cases {
            let url: url::Url = input.parse().unwrap();
            assert_eq!(ensure_trailing_slash(url).as_str(), expected);
        }
    }
}