use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use arc_swap::ArcSwapOption;
use futures::lock::Mutex;
use secrecy::{ExposeSecret, SecretString};
#[cfg(feature = "cert-auth")]
use snowflake_jwt::generate_jwt_token;
use thiserror::Error;
use crate::connection;
use crate::connection::{Connection, QueryType};
#[cfg(feature = "browser-auth")]
use crate::requests::{
AuthenticatorRequest, AuthenticatorRequestData, BrowserLoginRequest, BrowserRequestData,
};
#[cfg(feature = "cert-auth")]
use crate::requests::{CertLoginRequest, CertRequestData};
use crate::requests::{
ClientEnvironment, LoginRequest, LoginRequestCommon, PasswordLoginRequest, PasswordRequestData,
RenewSessionRequest, SessionParameters,
};
use crate::responses::{AuthResponse, BaseRestResponse};
#[derive(Error, Debug)]
pub enum AuthError {
#[error(transparent)]
#[cfg(feature = "cert-auth")]
JwtError(#[from] snowflake_jwt::JwtError),
#[error(transparent)]
RequestError(#[from] connection::ConnectionError),
#[error("Environment variable `{0}` is required, but were not set")]
MissingEnvArgument(String),
#[error("Password auth was requested, but password wasn't provided")]
MissingPassword,
#[error("Certificate auth was requested, but certificate wasn't provided")]
MissingCertificate,
#[error("Unexpected API response")]
UnexpectedResponse,
#[error("Failed to authenticate. Error code: {0}. Message: {1}")]
AuthFailed(String, String),
#[error("Can not renew closed session token")]
OutOfOrderRenew,
#[error("Failed to exchange or request a new token")]
TokenFetchFailed,
#[error("Login timed out after {0:?}")]
LoginTimeout(Duration),
#[error("Enable the cert-auth feature to use certificate authentication")]
CertAuthNotEnabled,
#[error("Enable the browser-auth feature to use external browser authentication")]
BrowserAuthNotEnabled,
#[cfg(feature = "browser-auth")]
#[error(transparent)]
BrowserAuthError(#[from] crate::browser::BrowserAuthError),
}
/// The token pair backing a live session, plus the pre-rendered header that
/// requests authenticate with.
#[derive(Debug)]
struct AuthState {
    /// Short-lived token sent with ordinary requests.
    session_token: AuthToken,
    /// Longer-lived token used to renew `session_token`.
    master_token: AuthToken,
    /// `Authorization` header value derived from `session_token`, cached so
    /// every request does not have to re-format it.
    auth_header: String,
}

impl AuthState {
    /// Bundles a freshly issued token pair and renders the session header.
    fn new(session_token: AuthToken, master_token: AuthToken) -> Self {
        // Field initialisers run in source order, so the header is rendered
        // from `session_token` before the token is moved into the struct.
        Self {
            auth_header: session_token.auth_header(),
            session_token,
            master_token,
        }
    }

    /// `true` while neither the master nor the session token has expired.
    fn is_fresh(&self) -> bool {
        !(self.master_token.is_expired() || self.session_token.is_expired())
    }
}
/// A secret token together with the moment it was received and how long the
/// server said it stays valid.
#[derive(Clone)]
struct AuthToken {
    token: SecretString,
    valid_for: Duration,
    issued_on: Instant,
}

impl AuthToken {
    /// Wraps `token`, stamping it with the current time.
    ///
    /// A negative `validity_in_seconds` cannot be converted to `u64` and is
    /// therefore treated as "never expires" (`u64::MAX` seconds).
    pub fn new(token: &str, validity_in_seconds: i64) -> Self {
        let seconds = u64::try_from(validity_in_seconds).unwrap_or(u64::MAX);
        Self {
            token: SecretString::from(token),
            valid_for: Duration::from_secs(seconds),
            issued_on: Instant::now(),
        }
    }

    /// `true` once at least `valid_for` has elapsed since the token was issued.
    pub fn is_expired(&self) -> bool {
        self.issued_on.elapsed() >= self.valid_for
    }

    /// Renders the `Authorization` header value carrying this token.
    pub fn auth_header(&self) -> String {
        format!("Snowflake Token=\"{}\"", self.token.expose_secret())
    }
}

/// Hand-written so the secret never ends up in logs or panic messages.
impl std::fmt::Debug for AuthToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("AuthToken");
        out.field("token", &"[REDACTED]");
        out.field("valid_for", &self.valid_for);
        out.field("issued_on", &self.issued_on);
        out.finish()
    }
}

/// What a caller needs to authenticate one request: the session `Authorization`
/// header and the request's sequence number within the session.
#[derive(Debug, Clone)]
pub struct AuthParts {
    pub session_token_auth_header: String,
    pub sequence_id: u64,
}
/// Authentication method a [`Session`] logs in with.
enum AuthType {
    /// Key-pair authentication: a JWT signed with the configured private key.
    Certificate,
    /// Username/password authentication.
    Password,
    /// `EXTERNALBROWSER` single sign-on through the user's browser.
    #[cfg(feature = "browser-auth")]
    Browser,
}

/// A Snowflake session that logs in lazily and keeps its session/master
/// tokens fresh for concurrent callers.
pub struct Session {
    connection: Arc<Connection>,
    /// Current tokens; empty until the first login and again after `close`.
    auth_state: ArcSwapOption<AuthState>,
    /// Request counter handed out with every auth header; reset on each full login.
    sequence_id: AtomicU64,
    /// Serialises logins/renewals so concurrent callers refresh only once.
    refresh_lock: Mutex<()>,
    auth_type: AuthType,
    account_identifier: String,
    warehouse: Option<String>,
    database: Option<String>,
    schema: Option<String>,
    username: String,
    role: Option<String>,
    /// PEM private key for certificate auth. It is only read by code compiled
    /// with the `cert-auth` feature, so the dead-code lint is silenced only
    /// when that feature is off.
    #[cfg_attr(not(feature = "cert-auth"), allow(dead_code))]
    private_key_pem: Option<SecretString>,
    password: Option<SecretString>,
    /// Upper bound applied to each login request.
    login_timeout: Duration,
}

/// Login timeout used unless overridden with [`Session::set_login_timeout`].
const DEFAULT_LOGIN_TIMEOUT: Duration = Duration::from_secs(300);
impl Session {
#[allow(clippy::too_many_arguments)]
pub fn cert_auth(
connection: Arc<Connection>,
account_identifier: &str,
warehouse: Option<&str>,
database: Option<&str>,
schema: Option<&str>,
username: &str,
role: Option<&str>,
private_key_pem: SecretString,
) -> Self {
let account_identifier = account_identifier.to_uppercase();
let database = database.map(str::to_uppercase);
let schema = schema.map(str::to_uppercase);
let username = username.to_uppercase();
let role = role.map(str::to_uppercase);
Self {
connection,
auth_state: ArcSwapOption::empty(),
sequence_id: AtomicU64::new(0),
refresh_lock: Mutex::new(()),
auth_type: AuthType::Certificate,
private_key_pem: Some(private_key_pem),
account_identifier,
warehouse: warehouse.map(str::to_uppercase),
database,
username,
role,
schema,
password: None,
login_timeout: DEFAULT_LOGIN_TIMEOUT,
}
}
#[allow(clippy::too_many_arguments)]
pub fn password_auth(
connection: Arc<Connection>,
account_identifier: &str,
warehouse: Option<&str>,
database: Option<&str>,
schema: Option<&str>,
username: &str,
role: Option<&str>,
password: SecretString,
) -> Self {
let account_identifier = account_identifier.to_uppercase();
let database = database.map(str::to_uppercase);
let schema = schema.map(str::to_uppercase);
let username = username.to_uppercase();
let role = role.map(str::to_uppercase);
Self {
connection,
auth_state: ArcSwapOption::empty(),
sequence_id: AtomicU64::new(0),
refresh_lock: Mutex::new(()),
auth_type: AuthType::Password,
account_identifier,
warehouse: warehouse.map(str::to_uppercase),
database,
username,
role,
password: Some(password),
schema,
private_key_pem: None,
login_timeout: DEFAULT_LOGIN_TIMEOUT,
}
}
#[cfg(feature = "browser-auth")]
#[allow(clippy::too_many_arguments)]
pub fn browser_auth(
connection: Arc<Connection>,
account_identifier: &str,
warehouse: Option<&str>,
database: Option<&str>,
schema: Option<&str>,
username: &str,
role: Option<&str>,
) -> Self {
let account_identifier = account_identifier.to_uppercase();
let database = database.map(str::to_uppercase);
let schema = schema.map(str::to_uppercase);
let username = username.to_uppercase();
let role = role.map(str::to_uppercase);
Self {
connection,
auth_state: ArcSwapOption::empty(),
sequence_id: AtomicU64::new(0),
refresh_lock: Mutex::new(()),
auth_type: AuthType::Browser,
account_identifier,
warehouse: warehouse.map(str::to_uppercase),
database,
username,
role,
password: None,
schema,
private_key_pem: None,
login_timeout: DEFAULT_LOGIN_TIMEOUT,
}
}
pub fn set_login_timeout(&mut self, timeout: Duration) {
self.login_timeout = timeout;
}
pub async fn get_token(&self) -> Result<AuthParts, AuthError> {
if let Some(state) = self.auth_state.load_full() {
if state.is_fresh() {
return Ok(self.build_parts(&state));
}
}
let _refresh_guard = self.refresh_lock.lock().await;
if let Some(state) = self.auth_state.load_full() {
if state.is_fresh() {
return Ok(self.build_parts(&state));
}
}
let current = self.auth_state.load_full();
let need_full_create = current
.as_deref()
.is_none_or(|s| s.master_token.is_expired());
let new_state = if need_full_create {
let tokens = match self.auth_type {
AuthType::Certificate => {
log::info!("Starting session with certificate authentication");
if cfg!(feature = "cert-auth") {
self.create(self.cert_request_body()?).await?
} else {
return Err(AuthError::MissingCertificate);
}
}
AuthType::Password => {
log::info!("Starting session with password authentication");
self.create(self.passwd_request_body()?).await?
}
#[cfg(feature = "browser-auth")]
AuthType::Browser => {
log::info!("Starting session with external browser authentication");
self.create_browser_session().await?
}
};
self.sequence_id.store(0, Ordering::Relaxed);
tokens
} else {
match current {
Some(state) => self.renew(&state).await?,
None => return Err(AuthError::OutOfOrderRenew),
}
};
let new_state = Arc::new(new_state);
self.auth_state.store(Some(Arc::clone(&new_state)));
Ok(self.build_parts(&new_state))
}
fn build_parts(&self, state: &AuthState) -> AuthParts {
let sequence_id = self.sequence_id.fetch_add(1, Ordering::Relaxed) + 1;
AuthParts {
session_token_auth_header: state.auth_header.clone(),
sequence_id,
}
}
pub async fn force_renew(&self) -> Result<AuthParts, AuthError> {
let _refresh_guard = self.refresh_lock.lock().await;
let current = self.auth_state.load_full();
let new_state = match current.as_deref() {
Some(s) if !s.master_token.is_expired() => self.renew(s).await?,
_ => {
let tokens = match self.auth_type {
AuthType::Certificate => {
log::info!("Re-creating session (certificate auth)");
if cfg!(feature = "cert-auth") {
self.create(self.cert_request_body()?).await?
} else {
return Err(AuthError::MissingCertificate);
}
}
AuthType::Password => {
log::info!("Re-creating session (password auth)");
self.create(self.passwd_request_body()?).await?
}
#[cfg(feature = "browser-auth")]
AuthType::Browser => {
log::info!("Re-creating session (browser auth)");
self.create_browser_session().await?
}
};
self.sequence_id.store(0, Ordering::Relaxed);
tokens
}
};
let new_state = Arc::new(new_state);
self.auth_state.store(Some(Arc::clone(&new_state)));
Ok(self.build_parts(&new_state))
}
pub async fn close(&self) -> Result<(), AuthError> {
let Some(state) = self.auth_state.swap(None) else {
return Ok(());
};
log::debug!("Closing sessions");
let resp = self
.connection
.request::<AuthResponse>(
QueryType::CloseSession,
&self.account_identifier,
&[("delete", "true")],
Some(&state.auth_header),
serde_json::Value::default(),
None,
)
.await?;
match resp {
AuthResponse::Close(_) => Ok(()),
AuthResponse::Error(e) => Err(AuthError::AuthFailed(
e.code.unwrap_or_default(),
e.message.unwrap_or_default(),
)),
_ => Err(AuthError::UnexpectedResponse),
}
}
#[cfg(feature = "cert-auth")]
fn cert_request_body(&self) -> Result<CertLoginRequest, AuthError> {
let full_identifier = format!("{}.{}", &self.account_identifier, &self.username);
let private_key_pem = self
.private_key_pem
.as_ref()
.ok_or(AuthError::MissingCertificate)?;
let jwt_token = generate_jwt_token(private_key_pem.expose_secret(), &full_identifier)?;
Ok(CertLoginRequest {
data: CertRequestData {
login_request_common: self.login_request_common(),
authenticator: "SNOWFLAKE_JWT".to_string(),
token: jwt_token,
},
})
}
fn passwd_request_body(&self) -> Result<PasswordLoginRequest, AuthError> {
let password = self.password.as_ref().ok_or(AuthError::MissingPassword)?;
Ok(PasswordLoginRequest {
data: PasswordRequestData {
login_request_common: self.login_request_common(),
password: password.expose_secret().to_string(),
},
})
}
async fn create<T: serde::ser::Serialize>(
&self,
body: LoginRequest<T>,
) -> Result<AuthState, AuthError> {
let timeout = self.login_timeout;
tokio::time::timeout(timeout, self.create_inner(body))
.await
.map_err(|_| AuthError::LoginTimeout(timeout))?
}
async fn create_inner<T: serde::ser::Serialize>(
&self,
body: LoginRequest<T>,
) -> Result<AuthState, AuthError> {
let mut get_params = Vec::new();
if let Some(warehouse) = &self.warehouse {
get_params.push(("warehouse", warehouse.as_str()));
}
if let Some(database) = &self.database {
get_params.push(("databaseName", database.as_str()));
}
if let Some(schema) = &self.schema {
get_params.push(("schemaName", schema.as_str()));
}
if let Some(role) = &self.role {
get_params.push(("roleName", role.as_str()));
}
let resp = self
.connection
.request::<AuthResponse>(
QueryType::LoginRequest,
&self.account_identifier,
&get_params,
None,
body,
None,
)
.await?;
log::debug!("Auth response: {resp:?}");
match resp {
AuthResponse::Login(lr) => {
let session_token = AuthToken::new(&lr.data.token, lr.data.validity_in_seconds);
let master_token =
AuthToken::new(&lr.data.master_token, lr.data.master_validity_in_seconds);
Ok(AuthState::new(session_token, master_token))
}
AuthResponse::Error(e) => Err(AuthError::AuthFailed(
e.code.unwrap_or_default(),
e.message.unwrap_or_default(),
)),
_ => Err(AuthError::UnexpectedResponse),
}
}
fn login_request_common(&self) -> LoginRequestCommon {
LoginRequestCommon {
client_app_id: "Go".to_string(),
client_app_version: "1.6.22".to_string(),
svn_revision: String::new(),
account_name: self.account_identifier.clone(),
login_name: self.username.clone(),
session_parameters: SessionParameters {
client_validate_default_parameters: true,
},
client_environment: ClientEnvironment {
application: "Rust".to_string(),
os: match std::env::consts::OS {
"macos" => "darwin".to_owned(),
other => other.to_owned(),
},
os_version: std::env::consts::ARCH.to_owned(),
ocsp_mode: "FAIL_OPEN".to_string(),
},
}
}
#[cfg(feature = "browser-auth")]
async fn create_browser_session(&self) -> Result<AuthState, AuthError> {
use crate::browser::{
create_local_listener, generate_proof_key, open_browser, wait_for_token,
};
let (listener, port) = create_local_listener()?;
let proof_key = generate_proof_key();
let auth_request = AuthenticatorRequest {
data: AuthenticatorRequestData {
client_app_id: "Go".to_string(),
client_app_version: "1.6.22".to_string(),
svn_revision: String::new(),
account_name: self.account_identifier.clone(),
login_name: self.username.clone(),
authenticator: "EXTERNALBROWSER".to_string(),
browser_mode_redirect_port: port.to_string(),
proof_key: proof_key.clone(),
client_environment: crate::requests::AuthenticatorClientEnvironment {
application: "Rust".to_string(),
os: std::env::consts::OS.to_string(),
os_version: "unknown".to_string(),
},
},
};
let resp = self
.connection
.request::<AuthResponse>(
QueryType::AuthenticatorRequest,
&self.account_identifier,
&[],
None,
auth_request,
None,
)
.await?;
let (sso_url, server_proof_key) = match resp {
AuthResponse::Auth(auth_resp) => (auth_resp.data.sso_url, auth_resp.data.proof_key),
AuthResponse::Error(e) => {
return Err(AuthError::AuthFailed(
e.code.unwrap_or_default(),
e.message.unwrap_or_default(),
));
}
_ => return Err(AuthError::UnexpectedResponse),
};
let final_proof_key = if server_proof_key.is_empty() {
proof_key
} else {
server_proof_key
};
open_browser(&sso_url)?;
let token = tokio::task::spawn_blocking(move || wait_for_token(&listener))
.await
.map_err(|_| AuthError::TokenFetchFailed)??;
let login_request = BrowserLoginRequest {
data: BrowserRequestData {
login_request_common: self.login_request_common(),
authenticator: "EXTERNALBROWSER".to_string(),
token,
proof_key: final_proof_key,
},
};
self.create(login_request).await
}
async fn renew(&self, old: &AuthState) -> Result<AuthState, AuthError> {
log::debug!("Renewing the token");
let auth = old.master_token.auth_header();
let body = RenewSessionRequest {
old_session_token: old.session_token.token.expose_secret().to_string(),
request_type: "RENEW".to_string(),
};
let resp = self
.connection
.request(
QueryType::TokenRequest,
&self.account_identifier,
&[],
Some(&auth),
body,
None,
)
.await?;
match resp {
AuthResponse::Renew(rs) => {
let session_token =
AuthToken::new(&rs.data.session_token, rs.data.validity_in_seconds_s_t);
let master_token =
AuthToken::new(&rs.data.master_token, rs.data.validity_in_seconds_m_t);
Ok(AuthState::new(session_token, master_token))
}
AuthResponse::Error(e) => Err(AuthError::AuthFailed(
e.code.unwrap_or_default(),
e.message.unwrap_or_default(),
)),
_ => Err(AuthError::UnexpectedResponse),
}
}
pub async fn heartbeat(&self) -> Result<(), AuthError> {
if self.auth_state.load().is_none() {
return Ok(());
}
let parts = self.get_token().await?;
let resp = self
.send_heartbeat(&parts.session_token_auth_header)
.await?;
if resp.success {
return Ok(());
}
if resp.code.as_deref() == Some("390112") {
log::debug!("Heartbeat saw 390112; renewing and retrying once");
let parts = self.force_renew().await?;
let resp = self
.send_heartbeat(&parts.session_token_auth_header)
.await?;
if resp.success {
return Ok(());
}
return Err(AuthError::AuthFailed(
resp.code.unwrap_or_default(),
resp.message.unwrap_or_default(),
));
}
Err(AuthError::AuthFailed(
resp.code.unwrap_or_default(),
resp.message.unwrap_or_default(),
))
}
async fn send_heartbeat(
&self,
auth_header: &str,
) -> Result<BaseRestResponse<serde_json::Value>, AuthError> {
Ok(self
.connection
.request::<BaseRestResponse<serde_json::Value>>(
QueryType::Heartbeat,
&self.account_identifier,
&[],
Some(auth_header),
serde_json::Value::Null,
None,
)
.await?)
}
}