use std::{collections::HashMap, fmt::Debug, sync::Arc};
use async_trait::async_trait;
use bytes::Bytes;
use crate::{
api::{
ClientInfo,
auth::{LoginInfo, sasl::SASLState},
},
error::{PgWireError, PgWireResult},
messages::startup::{Authentication, PasswordMessageFamily},
};
#[cfg(feature = "simple-oidc-validator")]
pub use crate::api::auth::simple_oidc_validator::SimpleOidcValidator;
/// Server-side configuration for the SASL `OAUTHBEARER` mechanism (RFC 7628).
#[derive(Debug)]
pub struct Oauth {
    /// Issuer identifier. Used to build the `openid-configuration` URL sent to
    /// clients in the error response, and passed to the validator.
    pub issuer: String,
    /// Scope string advertised to clients in the error response and passed to
    /// the validator as the required scopes.
    pub scope: String,
    /// Module that decides whether a presented bearer token is accepted.
    pub validator: Arc<dyn OauthValidator>,
    /// Set through `with_skip_usermapping`; `Oauth::new` initialises it to `false`.
    /// NOTE(review): not read by any code in this file — presumably meant to
    /// bypass user-name mapping of the validator's `authn_id`; confirm.
    pub skip_usermap: bool,
}
/// Authentication scheme prefix expected at the start of the `auth` value,
/// including the single separating space.
const BEARER_SCHEME: &str = "Bearer ";
/// Key of the key/value pair that carries the client's credentials.
const AUTH_KEY: &str = "auth";
/// Key/value separator byte (`kvsep = %x01`, RFC 7628 §3.1).
const KVSEP: u8 = b'\x01';
/// Outcome reported by an [`OauthValidator`] for one token.
#[derive(Debug, Clone)]
pub struct ValidatorModuleResult {
    /// Whether the token is accepted. `Oauth::process_oauth_message` fails
    /// authentication when this is `false`.
    pub authorized: bool,
    /// Presumably the identity the validator authenticated.
    /// NOTE(review): not consumed by the code in this file; confirm usage.
    pub authn_id: Option<String>,
    /// Validator-supplied key/value data.
    /// NOTE(review): not consumed by the code in this file; confirm usage.
    pub metadata: Option<HashMap<String, String>>,
}
/// Pluggable bearer-token validation used by [`Oauth`].
///
/// [`Oauth`] stores implementations as `Arc<dyn OauthValidator>`.
#[async_trait]
pub trait OauthValidator: Send + Sync + Debug {
    /// Validates a bearer token for a connecting user.
    ///
    /// * `token` - the token with the `Bearer ` scheme prefix removed.
    /// * `username` - the user name taken from the client's login info.
    /// * `issuer` - the configured `Oauth::issuer`.
    /// * `required_scopes` - the configured `Oauth::scope`.
    ///
    /// Returning `Ok` with `authorized == false` makes authentication fail;
    /// returning `Err` aborts the exchange with that error.
    async fn validate(
        &self,
        token: &str,
        username: &str,
        issuer: &str,
        required_scopes: &str,
    ) -> PgWireResult<ValidatorModuleResult>;
}
impl Oauth {
    /// Creates an `OAUTHBEARER` authenticator with `skip_usermap` set to `false`.
    ///
    /// * `issuer` - issuer identifier, or a full `/.well-known/` discovery URL.
    /// * `scope` - scope(s) advertised to clients and passed to the validator.
    /// * `validator` - module that decides whether a bearer token is accepted.
    pub fn new(issuer: String, scope: String, validator: Arc<dyn OauthValidator>) -> Self {
        Self {
            issuer,
            scope,
            validator,
            skip_usermap: false,
        }
    }

    /// Builder-style setter for the `skip_usermap` flag.
    pub fn with_skip_usermapping(mut self, skip: bool) -> Self {
        self.skip_usermap = skip;
        self
    }

    /// Builds the JSON document sent to a client that supplied no usable
    /// `auth` value (RFC 7628 §3.2.2).
    ///
    /// `openid-configuration` is the issuer itself when it already contains
    /// `/.well-known/`; otherwise it is the issuer with
    /// `/.well-known/openid-configuration` appended.
    fn generate_error_response(&self) -> String {
        let config = if self.issuer.contains("/.well-known/") {
            self.issuer.clone()
        } else {
            format!("{}/.well-known/openid-configuration", self.issuer)
        };
        // serde_json performs all required escaping; escaping by hand first
        // would double-escape backslashes and quotes.
        serde_json::json!({
            "status": "invalid_token",
            "openid-configuration": config,
            "scope": self.scope
        })
        .to_string()
    }

    /// Parses the client's SASL initial response (RFC 7628 §3.1):
    /// `gs2-header kvsep *kvpair kvsep`.
    ///
    /// Returns `Ok(None)` for an empty message or when no non-empty `auth`
    /// value is present, and `Ok(Some(auth))` with the raw `auth` value
    /// otherwise.
    ///
    /// # Errors
    /// `PgWireError::InvalidOauthMessage` when the message is not UTF-8, the
    /// GS2 header requests channel binding or an authzid, or the key/value
    /// section is malformed.
    fn parse_client_initial_response(&self, data: &[u8]) -> PgWireResult<Option<String>> {
        if data.is_empty() {
            return Ok(None);
        }
        // Reject invalid UTF-8 up front, so replacement characters do not
        // surface later as a misleading error.
        let s = std::str::from_utf8(data).map_err(|_| {
            PgWireError::InvalidOauthMessage("Message is not valid UTF-8".to_string())
        })?;
        let mut chars = s.chars();

        // gs2-cbind-flag: OAUTHBEARER does not use channel binding, so only
        // "n" and "y" are acceptable.
        let cbind_flag = chars
            .next()
            .ok_or_else(|| PgWireError::InvalidOauthMessage("Empty message".to_string()))?;
        match cbind_flag {
            'n' | 'y' => {
                if chars.next() != Some(',') {
                    return Err(PgWireError::InvalidOauthMessage(
                        "Expected comma after channel binding flag".to_string(),
                    ));
                }
            }
            'p' => {
                return Err(PgWireError::InvalidOauthMessage(
                    "Channel binding not supported for oauth".to_string(),
                ));
            }
            _ => {
                return Err(PgWireError::InvalidOauthMessage(format!(
                    "Invalid channel binding flag: {cbind_flag}"
                )));
            }
        }
        // gs2-authzid must be empty.
        if chars.next() != Some(',') {
            return Err(PgWireError::InvalidOauthMessage(
                "authzid not supported".to_string(),
            ));
        }
        // The GS2 header is followed by one kvsep before the kvpairs.
        if chars.next() != Some(char::from(KVSEP)) {
            return Err(PgWireError::InvalidOauthMessage(
                "Expected kvsep after GS2 header".to_string(),
            ));
        }
        self.parse_kvpairs(chars.as_str())
    }

    /// Parses `*kvpair kvsep`, where `kvpair = key "=" value kvsep`,
    /// `key = 1*ALPHA` and `value = *(VCHAR / SP / HTAB / CR / LF)`
    /// (RFC 7628 §3.1).
    ///
    /// Returns the `auth` value if one is present and non-empty. Other keys
    /// (e.g. `host`, `port`) are validated and then ignored. The list must be
    /// closed by an empty kvpair (a second kvsep), and nothing may follow it.
    fn parse_kvpairs(&self, data: &str) -> PgWireResult<Option<String>> {
        let kvsep = char::from(KVSEP);
        // Tracks whether `auth` was seen at all (even with an empty value), so
        // that duplicates are detected regardless of content.
        let mut auth: Option<&str> = None;
        let mut rest = data;
        loop {
            let (kv, tail) = match rest.split_once(kvsep) {
                Some(parts) => parts,
                None => {
                    return Err(PgWireError::InvalidOauthMessage(
                        "Message did not contain a final terminator".to_owned(),
                    ));
                }
            };
            if kv.is_empty() {
                // Empty kvpair: end of the list. It must also be the end of
                // the message.
                if !tail.is_empty() {
                    return Err(PgWireError::InvalidOauthMessage(
                        "Additional data after the final terminator".to_owned(),
                    ));
                }
                break;
            }
            rest = tail;

            let (key, value) = match kv.split_once('=') {
                Some(parts) => parts,
                None => {
                    return Err(PgWireError::InvalidOauthMessage(
                        "Malformed key-value pair".to_owned(),
                    ));
                }
            };
            // key = 1*ALPHA: at least one character is required.
            if key.is_empty() || !key.chars().all(|c| c.is_ascii_alphabetic()) {
                return Err(PgWireError::InvalidOauthMessage(
                    "Invalid key name".to_owned(),
                ));
            }
            if !value
                .chars()
                .all(|c| matches!(c, '\x21'..='\x7E' | ' ' | '\t' | '\r' | '\n'))
            {
                return Err(PgWireError::InvalidOauthMessage(
                    "Invalid value characters".to_owned(),
                ));
            }
            if key == AUTH_KEY {
                if auth.is_some() {
                    return Err(PgWireError::InvalidOauthMessage(
                        "Multiple oauth values".to_string(),
                    ));
                }
                auth = Some(value);
            }
        }
        // An empty `auth` value carries no credentials; report it as absent.
        Ok(auth.filter(|v| !v.is_empty()).map(str::to_owned))
    }

    /// Extracts the token from a Bearer credential (RFC 6750 §2.1):
    /// `"Bearer" 1*SP b64token`. The scheme is matched case-insensitively, and
    /// `b64token = 1*(ALPHA / DIGIT / "-" / "." / "_" / "~" / "+" / "/") *"="`.
    ///
    /// Returns the token (including any trailing `=` padding), or `None` when
    /// the value does not match that grammar.
    fn validate_token_format<'a>(&self, value: &'a str) -> Option<&'a str> {
        // `get` yields `None` (rather than panicking) when the value is too
        // short or the cut would split a multi-byte character.
        let scheme = value.get(..BEARER_SCHEME.len())?;
        if !scheme.eq_ignore_ascii_case(BEARER_SCHEME) {
            return None;
        }
        // Only SP may separate the scheme from the token.
        let token = value[BEARER_SCHEME.len()..].trim_start_matches(' ');
        // '=' is allowed only as trailing padding after at least one token
        // character.
        let body = token.trim_end_matches('=');
        if body.is_empty() {
            return None;
        }
        let valid_chars = body
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '.' | '_' | '~' | '+' | '/'));
        if valid_chars {
            Some(token)
        } else {
            None
        }
    }

    /// Runs one server step of the `OAUTHBEARER` exchange.
    ///
    /// * `SASLState::OauthStateInit`: decodes the SASL initial response. If it
    ///   carries no `auth` value (absent or empty), the JSON error document is
    ///   sent and the state moves to `OauthStateError`. Otherwise the bearer
    ///   token's syntax is checked and the token is handed to the validator.
    ///   On success the state becomes `SASLState::Finished`, with no further
    ///   message to send.
    /// * `SASLState::OauthStateError`: expects the client's single-kvsep
    ///   acknowledgement, then fails authentication.
    ///
    /// # Errors
    /// * `PgWireError::InvalidOauthMessage` for malformed messages.
    /// * `PgWireError::UserNameRequired` when the client supplied no user name.
    /// * `PgWireError::OAuthAuthenticationFailed` for a malformed or rejected
    ///   token, and after the error round-trip.
    /// * `PgWireError::InvalidSASLState` for any other state.
    ///
    /// Errors from message decoding and from the validator are propagated
    /// unchanged.
    pub async fn process_oauth_message<C>(
        &self,
        client: &C,
        msg: PasswordMessageFamily,
        state: &SASLState,
    ) -> PgWireResult<(Option<Authentication>, SASLState)>
    where
        C: ClientInfo + Unpin + Send,
    {
        match state {
            SASLState::OauthStateInit => {
                let res = msg.into_sasl_initial_response()?;
                let data = match res.data.as_deref() {
                    Some(d) => d,
                    None => {
                        // No initial response: reply with an empty challenge.
                        // NOTE(review): the state stays `OauthStateInit`, so the
                        // next message is decoded as an *initial* response
                        // again; confirm this matches how the caller frames the
                        // client's follow-up.
                        return Ok((
                            Some(Authentication::SASLContinue(Bytes::from(""))),
                            SASLState::OauthStateInit,
                        ));
                    }
                };

                let auth = match self.parse_client_initial_response(data)? {
                    Some(auth) => auth,
                    None => {
                        // No credentials: tell the client where to obtain a
                        // token, then wait for its kvsep acknowledgement.
                        let err = self.generate_error_response();
                        return Ok((
                            Some(Authentication::SASLContinue(Bytes::from(err))),
                            SASLState::OauthStateError,
                        ));
                    }
                };
                // Defensive: `parse_client_initial_response` maps an empty
                // `auth` value to `None`, so this should not trigger.
                if auth.is_empty() {
                    return Err(PgWireError::OAuthAuthenticationFailed(
                        "validation of OAuth token requested without a auth header".to_string(),
                    ));
                }
                let token = self.validate_token_format(&auth).ok_or_else(|| {
                    PgWireError::OAuthAuthenticationFailed(
                        "malformed OAuth bearer token".to_string(),
                    )
                })?;

                let login_info = LoginInfo::from_client_info(client);
                let username = login_info.user().ok_or(PgWireError::UserNameRequired)?;
                let validation_result = self
                    .validator
                    .validate(token, username, &self.issuer, &self.scope)
                    .await?;
                // NOTE(review): `authn_id`, `metadata` and `skip_usermap` are not
                // consulted here; confirm whether identity mapping belongs here.
                if !validation_result.authorized {
                    return Err(PgWireError::OAuthAuthenticationFailed(format!(
                        "OAuth bearer authentication failed for user: {}",
                        username
                    )));
                }
                Ok((None, SASLState::Finished))
            }
            SASLState::OauthStateError => {
                // After the error document, the client must answer with exactly
                // one kvsep byte; the exchange then fails.
                let res = msg.into_sasl_response()?;
                if res.data.len() != 1 || res.data[0] != KVSEP {
                    return Err(PgWireError::InvalidOauthMessage(
                        "Expected single kvsep byte in error response".to_string(),
                    ));
                }
                Err(PgWireError::OAuthAuthenticationFailed(
                    "OAuth authentication failed".to_string(),
                ))
            }
            _ => Err(PgWireError::InvalidSASLState),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Validator stub that authorizes every token. These tests only exercise
    /// message parsing and formatting, so it is never actually invoked.
    #[derive(Debug)]
    struct MockValidator;

    #[async_trait]
    impl OauthValidator for MockValidator {
        async fn validate(
            &self,
            _token: &str,
            _username: &str,
            _issuer: &str,
            _required_scopes: &str,
        ) -> PgWireResult<ValidatorModuleResult> {
            Ok(ValidatorModuleResult {
                authorized: true,
                authn_id: Some("test@example.com".to_string()),
                metadata: None,
            })
        }
    }

    /// Builds an `Oauth` instance with a fixed issuer and scope.
    fn make_oauth() -> Oauth {
        Oauth::new(
            "https://example.com".to_string(),
            "openid".to_string(),
            Arc::new(MockValidator),
        )
    }

    #[test]
    fn test_parse_kvpairs() {
        let oauth = make_oauth();
        let data = "auth=Bearer token123\x01\x01";
        let result = oauth.parse_kvpairs(data).unwrap();
        assert_eq!(result, Some("Bearer token123".to_string()));

        // Unknown keys are ignored.
        let data = "host=localhost\x01auth=Bearer token123\x01port=5432\x01\x01";
        let result = oauth.parse_kvpairs(data).unwrap();
        assert_eq!(result, Some("Bearer token123".to_string()));

        // No auth key at all.
        assert_eq!(oauth.parse_kvpairs("host=localhost\x01\x01").unwrap(), None);
        // Duplicate auth keys are rejected.
        assert!(oauth.parse_kvpairs("auth=a\x01auth=b\x01\x01").is_err());
        // Keys must be alphabetic.
        assert!(oauth.parse_kvpairs("ho1st=x\x01\x01").is_err());
        // A pair without '=' is malformed.
        assert!(oauth.parse_kvpairs("auth\x01\x01").is_err());
    }

    #[test]
    fn test_parse_client_initial_response() {
        let oauth = make_oauth();
        assert_eq!(
            oauth
                .parse_client_initial_response(b"n,,\x01auth=Bearer abc\x01\x01")
                .unwrap(),
            Some("Bearer abc".to_string())
        );
        assert_eq!(
            oauth
                .parse_client_initial_response(b"y,,\x01auth=Bearer abc\x01\x01")
                .unwrap(),
            Some("Bearer abc".to_string())
        );
        // Empty auth value (no credentials) and empty message.
        assert_eq!(
            oauth
                .parse_client_initial_response(b"n,,\x01auth=\x01\x01")
                .unwrap(),
            None
        );
        assert_eq!(oauth.parse_client_initial_response(b"").unwrap(), None);
        // Channel binding, authzid and a missing kvsep are all rejected.
        assert!(
            oauth
                .parse_client_initial_response(b"p=tls-server-end-point,,\x01auth=Bearer abc\x01\x01")
                .is_err()
        );
        assert!(
            oauth
                .parse_client_initial_response(b"n,a=user,\x01auth=Bearer abc\x01\x01")
                .is_err()
        );
        assert!(
            oauth
                .parse_client_initial_response(b"n,,auth=Bearer abc\x01\x01")
                .is_err()
        );
    }

    #[test]
    fn test_validate_token_format() {
        let oauth = make_oauth();
        assert_eq!(oauth.validate_token_format("Bearer abc123"), Some("abc123"));
        assert!(
            oauth
                .validate_token_format("Bearer abc.123_def-ghi+jkl/mno===")
                .is_some()
        );
        // The scheme is case-insensitive, and extra spaces are skipped.
        assert_eq!(oauth.validate_token_format("bearer abc"), Some("abc"));
        assert_eq!(oauth.validate_token_format("BEARER   abc"), Some("abc"));
        assert!(oauth.validate_token_format("").is_none());
        assert!(oauth.validate_token_format("Bearer ").is_none());
        assert!(oauth.validate_token_format("Basic abc123").is_none());
        assert!(oauth.validate_token_format("Bearer abc def").is_none());
    }

    #[test]
    fn test_generate_error_response() {
        let body: serde_json::Value =
            serde_json::from_str(&make_oauth().generate_error_response()).unwrap();
        assert_eq!(body["status"], "invalid_token");
        assert_eq!(
            body["openid-configuration"],
            "https://example.com/.well-known/openid-configuration"
        );
        assert_eq!(body["scope"], "openid");

        // An issuer that already names a discovery document is used verbatim.
        let explicit = Oauth::new(
            "https://example.com/.well-known/oauth-authorization-server".to_string(),
            "openid".to_string(),
            Arc::new(MockValidator),
        );
        let body: serde_json::Value =
            serde_json::from_str(&explicit.generate_error_response()).unwrap();
        assert_eq!(
            body["openid-configuration"],
            "https://example.com/.well-known/oauth-authorization-server"
        );
    }
}