use std::str::from_utf8;
use memchr::memchr;
use sqlx_core::bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::ProtocolDecode;
use crate::message::{BackendMessage, BackendMessageFormat};
use base64::prelude::{Engine as _, BASE64_STANDARD};
/// An authentication request from the backend, decoded from the body of an
/// `Authentication` message.
///
/// The variant is selected by the big-endian `u32` method code at the start of the body;
/// whatever follows the code is method-specific.
#[derive(Debug)]
pub enum Authentication {
    /// Method code 0: authentication succeeded.
    Ok,

    /// Method code 3: a clear-text password is requested.
    CleartextPassword,

    /// Method code 5: an MD5 password is requested, using the carried salt.
    Md5Password(AuthenticationMd5Password),

    /// Method code 10: SASL authentication; carries the list of offered mechanisms.
    Sasl(AuthenticationSasl),

    /// Method code 11: SASL challenge data (`r=`/`s=`/`i=` attributes).
    SaslContinue(AuthenticationSaslContinue),

    /// Method code 12: SASL final data (`v=` attribute).
    SaslFinal(AuthenticationSaslFinal),
}

impl BackendMessage for Authentication {
    const FORMAT: BackendMessageFormat = BackendMessageFormat::Authentication;

    /// Decodes an `Authentication` message body.
    ///
    /// # Errors
    ///
    /// Returns a protocol error if the body is too short to hold the method code (or, for
    /// MD5, the 4-byte salt), if the method code is not recognized, or if a SASL
    /// continue/final payload fails to decode.
    fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
        // `Buf::get_u32` and `Buf::copy_to_slice` panic when too few bytes remain. The body
        // comes straight off the wire, so a truncated message must surface as a protocol
        // error instead of a panic.
        if buf.remaining() < 4 {
            return Err(err_protocol!(
                "authentication message too short: expected at least 4 bytes, got {}",
                buf.remaining()
            ));
        }

        Ok(match buf.get_u32() {
            0 => Authentication::Ok,
            3 => Authentication::CleartextPassword,
            5 => {
                if buf.remaining() < 4 {
                    return Err(err_protocol!(
                        "MD5 authentication request missing salt: got {} of 4 bytes",
                        buf.remaining()
                    ));
                }

                let mut salt = [0; 4];
                buf.copy_to_slice(&mut salt);
                Authentication::Md5Password(AuthenticationMd5Password { salt })
            }
            // The rest of the body is the mechanism list; it is parsed lazily by
            // `AuthenticationSasl::mechanisms`.
            10 => Authentication::Sasl(AuthenticationSasl(buf)),
            11 => Authentication::SaslContinue(AuthenticationSaslContinue::decode(buf)?),
            12 => Authentication::SaslFinal(AuthenticationSaslFinal::decode(buf)?),
            ty => {
                return Err(err_protocol!("unknown authentication method: {}", ty));
            }
        })
    }
}

/// Payload of an MD5 password request (method code 5).
#[derive(Debug)]
pub struct AuthenticationMd5Password {
    /// The 4-byte salt that follows the method code in the message body.
    pub salt: [u8; 4],
}
/// Body of a SASL authentication request: the raw, NUL-separated list of mechanism
/// names that follows the method code.
#[derive(Debug)]
pub struct AuthenticationSasl(Bytes);

impl AuthenticationSasl {
    /// Returns an iterator over the mechanism names carried by this message.
    #[inline]
    pub fn mechanisms(&self) -> SaslMechanisms<'_> {
        let raw: &[u8] = &self.0;
        SaslMechanisms(raw)
    }
}
/// Iterator over the NUL-terminated SASL mechanism names in a SASL request body.
///
/// Iteration ends at the empty name that terminates the list, when no further NUL
/// terminator is present, or at the first name that is not valid UTF-8.
pub struct SaslMechanisms<'a>(&'a [u8]);

impl<'a> Iterator for SaslMechanisms<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<Self::Item> {
        let rest = self.0;

        // A NUL in the first position is the empty name that ends the list.
        if rest.first() == Some(&b'\0') {
            return None;
        }

        // Without a terminator there is no complete name left to yield.
        let end = memchr(b'\0', rest)?;
        let (name, tail) = rest.split_at(end);
        let name = from_utf8(name).ok()?;

        // `tail` begins with the terminator; step past it for the next call.
        self.0 = &tail[1..];
        Some(name)
    }
}
/// Decoded SASL challenge (method code 11): a comma-separated list of `<key>=<value>`
/// attributes.
#[derive(Debug)]
pub struct AuthenticationSaslContinue {
    /// Salt decoded from the base64 `s=` attribute (empty if the attribute is absent).
    pub salt: Vec<u8>,
    /// Iteration count from the `i=` attribute (4096 if absent or not a valid integer).
    pub iterations: u32,
    /// Nonce from the `r=` attribute (empty if the attribute is absent).
    pub nonce: String,
    /// The entire raw challenge, as text.
    pub message: String,
}

impl ProtocolDecode<'_> for AuthenticationSaslContinue {
    /// Parses the challenge attributes.
    ///
    /// Unknown keys and malformed entries (anything not shaped like `<key>=...`) are
    /// ignored.
    ///
    /// # Errors
    ///
    /// Returns a protocol error if the salt is not valid base64, or if the nonce or the
    /// message as a whole is not valid UTF-8.
    fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
        let mut iterations: u32 = 4096;
        let mut salt = Vec::new();
        let mut nonce = Bytes::new();

        for attribute in buf.split(|b| *b == b',') {
            // Match on the slice shape rather than indexing: an empty or 1-byte entry
            // (empty body, trailing comma, truncated message) would otherwise panic on
            // server-controlled input.
            match attribute {
                [b'r', b'=', value @ ..] => nonce = buf.slice_ref(value),
                [b'i', b'=', value @ ..] => iterations = atoi::atoi(value).unwrap_or(4096),
                [b's', b'=', value @ ..] => {
                    salt = BASE64_STANDARD.decode(value).map_err(Error::protocol)?;
                }
                _ => {}
            }
        }

        Ok(Self {
            iterations,
            salt,
            nonce: from_utf8(&nonce).map_err(Error::protocol)?.to_owned(),
            message: from_utf8(&buf).map_err(Error::protocol)?.to_owned(),
        })
    }
}
/// Decoded SASL final message (method code 12): a comma-separated list of
/// `<key>=<value>` attributes.
#[derive(Debug)]
pub struct AuthenticationSaslFinal {
    /// Bytes decoded from the base64 `v=` attribute (empty if the attribute is absent).
    pub verifier: Vec<u8>,
}

impl ProtocolDecode<'_> for AuthenticationSaslFinal {
    /// Extracts the `v=` attribute; all other keys and malformed entries are ignored.
    ///
    /// # Errors
    ///
    /// Returns a protocol error if the `v=` value is not valid base64.
    fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
        let mut verifier = Vec::new();

        for attribute in buf.split(|b| *b == b',') {
            // Slice pattern instead of `item[0]` / `&item[2..]`, which panic on empty or
            // 1-byte entries coming from the server.
            if let [b'v', b'=', value @ ..] = attribute {
                verifier = BASE64_STANDARD.decode(value).map_err(Error::protocol)?;
            }
        }

        Ok(Self { verifier })
    }
}