use core::fmt;
use std::collections::BTreeMap;
use reqwest::{Method, StatusCode};
use secrecy::{ExposeSecret, SecretString};
use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Visitor, ser::SerializeMap};
use crate::{
Authenticated, Client, Error, Result,
path::{validate_mount_path, validate_secret_path},
response::{
Empty, ResponseEnvelope, deserialize_bounded_string_map_or_default,
deserialize_bounded_string_vec,
},
};
/// Handle for a single database secrets engine mount.
///
/// It borrows an authenticated client and stores the already-validated mount path; every
/// request path built by this handle starts with those mount segments.
#[derive(Debug)]
pub struct Database<'a> {
    /// Authenticated client used to send every request.
    client: &'a Client<Authenticated>,
    /// Validated mount path segments prefixed to each request path.
    mount: Vec<String>,
}
/// Parameters for writing a database connection configuration (`config/{name}`).
///
/// Optional fields left as `None`, and lists left empty, are not sent. Plugin-specific
/// parameters without a dedicated field are passed through `extra` as plain key/value pairs.
#[derive(Clone, Default)]
pub struct DatabaseConnectionConfig {
    /// Name of the database plugin to use (always sent).
    pub plugin_name: String,
    /// Specific plugin version to pin, if any.
    pub plugin_version: Option<String>,
    /// Whether the connection should be verified when it is written.
    pub verify_connection: Option<bool>,
    /// Roles permitted to use this connection.
    pub allowed_roles: Vec<String>,
    /// Statements for rotating the root credential.
    pub root_rotation_statements: Vec<String>,
    /// Name of the password policy, if any.
    pub password_policy: Option<String>,
    /// Connection URL; kept secret because it may embed credentials.
    pub connection_url: Option<SecretString>,
    /// Username for the connection, if any.
    pub username: Option<String>,
    /// Password for the connection, if any.
    pub password: Option<SecretString>,
    /// Whether template escaping should be disabled.
    pub disable_escaping: Option<bool>,
    /// Additional plugin-specific parameters, sent as top-level keys.
    pub extra: BTreeMap<String, String>,
}

impl DatabaseConnectionConfig {
    /// Builds a configuration for `plugin_name`, leaving every other field at its default.
    pub fn new(plugin_name: impl Into<String>) -> Self {
        let plugin_name = plugin_name.into();
        Self {
            plugin_name,
            ..Default::default()
        }
    }
}
impl fmt::Debug for DatabaseConnectionConfig {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter
.debug_struct("DatabaseConnectionConfig")
.field("plugin_name", &self.plugin_name)
.field("plugin_version", &self.plugin_version)
.field("verify_connection", &self.verify_connection)
.field("allowed_roles", &self.allowed_roles)
.field("root_rotation_statements", &self.root_rotation_statements)
.field("password_policy", &self.password_policy)
.field("has_connection_url", &self.connection_url.is_some())
.field("username", &self.username)
.field("password", &self.password.as_ref().map(|_| "<redacted>"))
.field("disable_escaping", &self.disable_escaping)
.field("extra", &self.extra)
.finish()
}
}
/// Connection configuration as returned when reading `config/{name}`.
#[derive(Clone, Debug, Default, Deserialize)]
pub struct DatabaseConnectionInfo {
    /// Roles permitted to use this connection; empty when the key is absent.
    #[serde(default, deserialize_with = "deserialize_bounded_string_vec")]
    pub allowed_roles: Vec<String>,
    /// Plugin-specific connection details; empty when the key is absent.
    #[serde(default, deserialize_with = "deserialize_bounded_string_map_or_default")]
    pub connection_details: BTreeMap<String, String>,
    /// Password policy name, if reported.
    #[serde(default)]
    pub password_policy: Option<String>,
    /// Plugin name; this key is required in the response.
    pub plugin_name: String,
    /// Plugin version, if reported.
    #[serde(default)]
    pub plugin_version: Option<String>,
    /// Root rotation statements; the `root_rotation_statements` key is accepted as an alias.
    #[serde(
        default,
        alias = "root_rotation_statements",
        deserialize_with = "deserialize_bounded_string_vec"
    )]
    pub root_credentials_rotate_statements: Vec<String>,
}

/// Names returned by a LIST request.
#[derive(Clone, Debug, Default, Deserialize)]
pub struct DatabaseList {
    /// Listed names; decoding fails when the response exceeds the string-count limit.
    #[serde(default, deserialize_with = "deserialize_bounded_string_vec")]
    pub keys: Vec<String>,
}
/// A dynamic database role (`roles/{name}`), used both as a request body and a read result.
///
/// When serializing, empty strings, empty lists, empty maps and `None` values are omitted.
/// When deserializing, statement fields accept either a string or a list of strings, and the
/// TTL fields accept either a string or an integer (normalised to a string).
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct DatabaseRole {
    /// Name of the connection this role is bound to.
    #[serde(skip_serializing_if = "String::is_empty")]
    pub db_name: String,
    /// Creation statements for this role.
    #[serde(
        default,
        deserialize_with = "deserialize_bounded_string_or_vec",
        skip_serializing_if = "Vec::is_empty"
    )]
    pub creation_statements: Vec<String>,
    /// Default TTL, kept as the textual form OpenBao accepts.
    #[serde(
        default,
        deserialize_with = "deserialize_optional_string_or_u64",
        skip_serializing_if = "Option::is_none"
    )]
    pub default_ttl: Option<String>,
    /// Maximum TTL, kept as the textual form OpenBao accepts.
    #[serde(
        default,
        deserialize_with = "deserialize_optional_string_or_u64",
        skip_serializing_if = "Option::is_none"
    )]
    pub max_ttl: Option<String>,
    /// Revocation statements for this role.
    #[serde(
        default,
        deserialize_with = "deserialize_bounded_string_or_vec",
        skip_serializing_if = "Vec::is_empty"
    )]
    pub revocation_statements: Vec<String>,
    /// Rollback statements for this role.
    #[serde(
        default,
        deserialize_with = "deserialize_bounded_string_or_vec",
        skip_serializing_if = "Vec::is_empty"
    )]
    pub rollback_statements: Vec<String>,
    /// Renewal statements for this role.
    #[serde(
        default,
        deserialize_with = "deserialize_bounded_string_or_vec",
        skip_serializing_if = "Vec::is_empty"
    )]
    pub renew_statements: Vec<String>,
    /// Credential type, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub credential_type: Option<String>,
    /// Credential-type-specific settings.
    #[serde(
        default,
        deserialize_with = "deserialize_bounded_string_map_or_default",
        skip_serializing_if = "BTreeMap::is_empty"
    )]
    pub credential_config: BTreeMap<String, String>,
}

impl DatabaseRole {
    /// Creates a role bound to the connection `db_name`, with every other setting defaulted.
    pub fn new(db_name: impl Into<String>) -> Self {
        let db_name = db_name.into();
        Self {
            db_name,
            ..Default::default()
        }
    }

    /// Appends one creation statement and hands the role back, for builder-style chaining.
    #[must_use]
    pub fn with_creation_statement(self, statement: impl Into<String>) -> Self {
        let mut role = self;
        role.creation_statements.push(statement.into());
        role
    }
}
/// A static database role (`static-roles/{name}`), used both as a request body and a read
/// result.
///
/// When serializing, empty strings, empty lists, empty maps and `None` values are omitted.
/// When deserializing, `rotation_statements` accepts a string or a list, and
/// `rotation_period` accepts a string or an integer (normalised to a string).
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct DatabaseStaticRole {
    /// Name of the connection this role is bound to.
    #[serde(skip_serializing_if = "String::is_empty")]
    pub db_name: String,
    /// Database username managed by this role.
    #[serde(skip_serializing_if = "String::is_empty")]
    pub username: String,
    /// Rotation period, kept as the textual form OpenBao accepts.
    #[serde(
        default,
        deserialize_with = "deserialize_optional_string_or_u64",
        skip_serializing_if = "Option::is_none"
    )]
    pub rotation_period: Option<String>,
    /// Rotation statements for this role.
    #[serde(
        default,
        deserialize_with = "deserialize_bounded_string_or_vec",
        skip_serializing_if = "Vec::is_empty"
    )]
    pub rotation_statements: Vec<String>,
    /// Credential type, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub credential_type: Option<String>,
    /// Credential-type-specific settings.
    #[serde(
        default,
        deserialize_with = "deserialize_bounded_string_map_or_default",
        skip_serializing_if = "BTreeMap::is_empty"
    )]
    pub credential_config: BTreeMap<String, String>,
}

impl DatabaseStaticRole {
    /// Creates a static role for `username` on the connection `db_name`; all else defaulted.
    pub fn new(db_name: impl Into<String>, username: impl Into<String>) -> Self {
        let (db_name, username) = (db_name.into(), username.into());
        Self {
            db_name,
            username,
            ..Default::default()
        }
    }
}
/// Credentials issued by `creds/{name}`, together with their lease metadata.
///
/// The `Debug` output never shows the password, the private key, or the lease id.
#[derive(Clone)]
pub struct DatabaseCredentials {
    /// Generated username.
    pub username: String,
    /// Generated password, when the credential type provides one.
    pub password: Option<SecretString>,
    /// Private key, when the credential type provides one.
    pub private_key: Option<SecretString>,
    /// Certificate, when provided.
    pub certificate: Option<String>,
    /// Issuing CA, when provided.
    pub issuing_ca: Option<String>,
    /// CA chain; empty when not provided.
    pub ca_chain: Vec<String>,
    /// Lease identifier (treated as secret).
    pub lease_id: SecretString,
    /// Lease duration as reported in the response envelope.
    pub lease_duration: u64,
    /// Whether the lease is renewable.
    pub renewable: bool,
}

impl fmt::Debug for DatabaseCredentials {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        const REDACTED: &str = "<redacted>";
        // Present secrets only as presence markers.
        let mask = |secret: Option<&SecretString>| secret.map(|_| REDACTED);
        let mut out = formatter.debug_struct("DatabaseCredentials");
        out.field("username", &self.username);
        out.field("password", &mask(self.password.as_ref()));
        out.field("private_key", &mask(self.private_key.as_ref()));
        out.field("certificate", &self.certificate);
        out.field("issuing_ca", &self.issuing_ca);
        out.field("ca_chain", &self.ca_chain);
        out.field("lease_id", &REDACTED);
        out.field("lease_duration", &self.lease_duration);
        out.field("renewable", &self.renewable);
        out.finish()
    }
}
/// Current credentials of a static role, as read from `static-creds/{name}`.
///
/// The `Debug` output never shows the password.
#[derive(Clone, Deserialize)]
pub struct DatabaseStaticCredentials {
    /// Managed username.
    pub username: String,
    /// Current password.
    pub password: SecretString,
    /// Timestamp of the last rotation, if reported.
    #[serde(default)]
    pub last_openbao_rotation: Option<String>,
    /// Rotation period, if reported.
    #[serde(default)]
    pub rotation_period: Option<u64>,
    /// Remaining TTL, if reported.
    #[serde(default)]
    pub ttl: Option<u64>,
}

impl fmt::Debug for DatabaseStaticCredentials {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        const REDACTED: &str = "<redacted>";
        let mut out = formatter.debug_struct("DatabaseStaticCredentials");
        out.field("username", &self.username);
        out.field("password", &REDACTED);
        out.field("last_openbao_rotation", &self.last_openbao_rotation);
        out.field("rotation_period", &self.rotation_period);
        out.field("ttl", &self.ttl);
        out.finish()
    }
}
/// Borrowed request body sent by `Database::configure_connection`.
///
/// It mirrors `DatabaseConnectionConfig`, with secrets exposed as plain `&str` only for the
/// lifetime of the request. Unset optional fields and empty lists are omitted, and `extra`
/// entries are flattened into top-level keys.
#[derive(Serialize)]
struct DatabaseConnectionPayload<'a> {
    plugin_name: &'a str,
    #[serde(skip_serializing_if = "Option::is_none")]
    plugin_version: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    verify_connection: Option<bool>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    allowed_roles: Vec<&'a str>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    root_rotation_statements: Vec<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    password_policy: Option<&'a str>,
    /// Exposed secret; only lives as long as the borrowed config.
    #[serde(skip_serializing_if = "Option::is_none")]
    connection_url: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    username: Option<&'a str>,
    /// Exposed secret; only lives as long as the borrowed config.
    #[serde(skip_serializing_if = "Option::is_none")]
    password: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    disable_escaping: Option<bool>,
    /// Plugin-specific parameters, emitted as sibling top-level keys.
    #[serde(flatten)]
    extra: &'a BTreeMap<String, String>,
}

/// The `data` section of a `creds/{name}` response, before lease metadata is merged in.
#[derive(Deserialize)]
struct DatabaseCredentialData {
    username: String,
    #[serde(default)]
    password: Option<SecretString>,
    #[serde(default)]
    private_key: Option<SecretString>,
    #[serde(default)]
    certificate: Option<String>,
    #[serde(default)]
    issuing_ca: Option<String>,
    /// Bounded in length; empty when absent.
    #[serde(default, deserialize_with = "deserialize_bounded_string_vec")]
    ca_chain: Vec<String>,
}
impl Client<Authenticated> {
    /// Returns a handle for the database secrets engine mounted at `mount`.
    ///
    /// # Errors
    ///
    /// Fails when `mount` does not pass mount-path validation.
    pub fn database(&self, mount: impl Into<String>) -> Result<Database<'_>> {
        let raw: String = mount.into();
        let segments = validate_mount_path(&raw)?;
        Ok(Database {
            client: self,
            mount: segments,
        })
    }
}
impl Database<'_> {
pub async fn configure_connection(
&self,
name: &str,
config: &DatabaseConnectionConfig,
) -> Result<Empty> {
let payload = DatabaseConnectionPayload {
plugin_name: &config.plugin_name,
plugin_version: config.plugin_version.as_deref(),
verify_connection: config.verify_connection,
allowed_roles: config.allowed_roles.iter().map(String::as_str).collect(),
root_rotation_statements: config
.root_rotation_statements
.iter()
.map(String::as_str)
.collect(),
password_policy: config.password_policy.as_deref(),
connection_url: config
.connection_url
.as_ref()
.map(SecretString::expose_secret),
username: config.username.as_deref(),
password: config.password.as_ref().map(SecretString::expose_secret),
disable_escaping: config.disable_escaping,
extra: &config.extra,
};
self.client
.request_json(Method::POST, &self.path(&["config", name])?, Some(&payload))
.await
}
pub async fn read_connection(&self, name: &str) -> Result<DatabaseConnectionInfo> {
let envelope: ResponseEnvelope<DatabaseConnectionInfo> = self
.client
.request_json(
Method::GET,
&self.path(&["config", name])?,
Option::<&Empty>::None,
)
.await?;
Ok(envelope.data)
}
pub async fn list_connections(&self) -> Result<DatabaseList> {
self.list_at("config", None, None).await
}
pub async fn delete_connection(&self, name: &str) -> Result<Empty> {
self.delete_at("config", name).await
}
pub async fn reset_connection(&self, name: &str) -> Result<Empty> {
self.client
.request_json(Method::POST, &self.path(&["reset", name])?, Some(&Empty {}))
.await
}
pub async fn rotate_root(&self, name: &str) -> Result<Empty> {
self.client
.request_json(
Method::POST,
&self.path(&["rotate-root", name])?,
Some(&Empty {}),
)
.await
}
pub async fn write_role(&self, name: &str, role: &DatabaseRole) -> Result<Empty> {
self.client
.request_json(Method::POST, &self.path(&["roles", name])?, Some(role))
.await
}
pub async fn read_role(&self, name: &str) -> Result<DatabaseRole> {
let envelope: ResponseEnvelope<DatabaseRole> = self
.client
.request_json(
Method::GET,
&self.path(&["roles", name])?,
Option::<&Empty>::None,
)
.await?;
Ok(envelope.data)
}
pub async fn list_roles(&self) -> Result<DatabaseList> {
self.list_roles_after(None, None).await
}
pub async fn list_roles_after(
&self,
after: Option<&str>,
limit: Option<u64>,
) -> Result<DatabaseList> {
self.list_at("roles", after, limit).await
}
pub async fn delete_role(&self, name: &str) -> Result<Empty> {
self.delete_at("roles", name).await
}
pub async fn credentials(&self, name: &str) -> Result<DatabaseCredentials> {
let envelope: ResponseEnvelope<DatabaseCredentialData> = self
.client
.request_json(
Method::GET,
&self.path(&["creds", name])?,
Option::<&Empty>::None,
)
.await?;
Ok(database_credentials_from_envelope(envelope))
}
pub async fn write_static_role(&self, name: &str, role: &DatabaseStaticRole) -> Result<Empty> {
self.client
.request_json(
Method::POST,
&self.path(&["static-roles", name])?,
Some(role),
)
.await
}
pub async fn read_static_role(&self, name: &str) -> Result<DatabaseStaticRole> {
let envelope: ResponseEnvelope<DatabaseStaticRole> = self
.client
.request_json(
Method::GET,
&self.path(&["static-roles", name])?,
Option::<&Empty>::None,
)
.await?;
Ok(envelope.data)
}
pub async fn list_static_roles(&self) -> Result<DatabaseList> {
self.list_static_roles_after(None, None).await
}
pub async fn list_static_roles_after(
&self,
after: Option<&str>,
limit: Option<u64>,
) -> Result<DatabaseList> {
self.list_at("static-roles", after, limit).await
}
pub async fn delete_static_role(&self, name: &str) -> Result<Empty> {
self.delete_at("static-roles", name).await
}
pub async fn static_credentials(&self, name: &str) -> Result<DatabaseStaticCredentials> {
let envelope: ResponseEnvelope<DatabaseStaticCredentials> = self
.client
.request_json(
Method::GET,
&self.path(&["static-creds", name])?,
Option::<&Empty>::None,
)
.await?;
Ok(envelope.data)
}
pub async fn rotate_static_role(&self, name: &str) -> Result<Empty> {
self.client
.request_json(
Method::POST,
&self.path(&["rotate-role", name])?,
Some(&Empty {}),
)
.await
}
async fn list_at(
&self,
segment: &'static str,
after: Option<&str>,
limit: Option<u64>,
) -> Result<DatabaseList> {
let method =
Method::from_bytes(b"LIST").map_err(|error| Error::InvalidHeader(error.to_string()))?;
let mut query = Vec::new();
if let Some(after) = after {
query.push(("after", validate_mount_path(after)?.join("/")));
}
if let Some(limit) = limit {
query.push(("limit", limit.to_string()));
}
let envelope: ResponseEnvelope<DatabaseList> = self
.client
.request_json_query_accepting(
method,
&self.path(&[segment])?,
&query,
Option::<&Empty>::None,
&[StatusCode::OK],
)
.await?;
Ok(envelope.data)
}
async fn delete_at(&self, segment: &'static str, name: &str) -> Result<Empty> {
self.client
.request_json_accepting(
Method::DELETE,
&self.path(&[segment, name])?,
Option::<&Empty>::None,
&[StatusCode::OK, StatusCode::NO_CONTENT],
)
.await
}
fn path(&self, tail: &[&str]) -> Result<String> {
let mut segments = self.mount.clone();
for segment in tail {
segments.extend(validate_secret_path(segment)?);
}
Ok(segments.join("/"))
}
}
fn database_credentials_from_envelope(
envelope: ResponseEnvelope<DatabaseCredentialData>,
) -> DatabaseCredentials {
DatabaseCredentials {
username: envelope.data.username,
password: envelope.data.password,
private_key: envelope.data.private_key,
certificate: envelope.data.certificate,
issuing_ca: envelope.data.issuing_ca,
ca_chain: envelope.data.ca_chain,
lease_id: envelope.lease_id,
lease_duration: envelope.lease_duration,
renewable: envelope.renewable,
}
}
impl Serialize for DatabaseConnectionConfig {
    /// Serializes the configuration as one flat map.
    ///
    /// The output is `plugin_name` first, then every optional field that is set (or list that
    /// is non-empty) in declaration order, then each `extra` entry in key order. Secret fields
    /// (`connection_url`, `password`) are written in the clear.
    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // One flag per optional entry; the map length hint counts the ones that will be written.
        let present = [
            self.plugin_version.is_some(),
            self.verify_connection.is_some(),
            !self.allowed_roles.is_empty(),
            !self.root_rotation_statements.is_empty(),
            self.password_policy.is_some(),
            self.connection_url.is_some(),
            self.username.is_some(),
            self.password.is_some(),
            self.disable_escaping.is_some(),
        ];
        let entries = 1 + self.extra.len() + present.iter().filter(|flag| **flag).count();

        let mut map = serializer.serialize_map(Some(entries))?;
        map.serialize_entry("plugin_name", &self.plugin_name)?;
        if let Some(version) = &self.plugin_version {
            map.serialize_entry("plugin_version", version)?;
        }
        if let Some(verify) = self.verify_connection {
            map.serialize_entry("verify_connection", &verify)?;
        }
        if !self.allowed_roles.is_empty() {
            map.serialize_entry("allowed_roles", &self.allowed_roles)?;
        }
        if !self.root_rotation_statements.is_empty() {
            map.serialize_entry("root_rotation_statements", &self.root_rotation_statements)?;
        }
        if let Some(policy) = &self.password_policy {
            map.serialize_entry("password_policy", policy)?;
        }
        if let Some(url) = &self.connection_url {
            map.serialize_entry("connection_url", url.expose_secret())?;
        }
        if let Some(user) = &self.username {
            map.serialize_entry("username", user)?;
        }
        if let Some(secret) = &self.password {
            map.serialize_entry("password", secret.expose_secret())?;
        }
        if let Some(escaping) = self.disable_escaping {
            map.serialize_entry("disable_escaping", &escaping)?;
        }
        self.extra
            .iter()
            .try_for_each(|(key, value)| map.serialize_entry(key, value))?;
        map.end()
    }
}
fn deserialize_bounded_string_or_vec<'de, D>(
deserializer: D,
) -> core::result::Result<Vec<String>, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_any(StringOrListVisitor::<{ crate::response::MAX_RESPONSE_STRINGS }>)
}
struct StringOrListVisitor<const MAX: usize>;
impl<'de, const MAX: usize> Visitor<'de> for StringOrListVisitor<MAX> {
type Value = Vec<String>;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
formatter,
"a comma-separated string or a list of at most {MAX} strings"
)
}
fn visit_unit<E>(self) -> core::result::Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Vec::new())
}
fn visit_str<E>(self, value: &str) -> core::result::Result<Self::Value, E>
where
E: serde::de::Error,
{
if value.trim().is_empty() {
return Ok(Vec::new());
}
let values: Vec<String> = value
.split(',')
.map(str::trim)
.filter(|part| !part.is_empty())
.map(str::to_owned)
.collect();
if values.len() > MAX {
return Err(E::custom("OpenBao string list exceeds item limit"));
}
Ok(values)
}
fn visit_seq<A>(self, mut seq: A) -> core::result::Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut values = Vec::new();
while values.len() < MAX {
let Some(value) = seq.next_element::<String>()? else {
return Ok(values);
};
values.push(value);
}
if seq.next_element::<serde::de::IgnoredAny>()?.is_some() {
return Err(serde::de::Error::custom(
"OpenBao string list exceeds item limit",
));
}
Ok(values)
}
}
/// Deserializes an optional value that may be reported either as a string (e.g. `"24h"`) or
/// as an integer (e.g. `3600`), normalising both to `Option<String>`.
///
/// `null` yields `None`. Integers are rendered in decimal.
fn deserialize_optional_string_or_u64<'de, D>(
    deserializer: D,
) -> core::result::Result<Option<String>, D::Error>
where
    D: Deserializer<'de>,
{
    deserializer.deserialize_any(OptionalStringOrU64Visitor)
}

/// Visitor behind `deserialize_optional_string_or_u64`.
struct OptionalStringOrU64Visitor;

impl<'de> Visitor<'de> for OptionalStringOrU64Visitor {
    type Value = Option<String>;

    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("a string, integer, null, or omitted value")
    }

    fn visit_none<E>(self) -> core::result::Result<Self::Value, E> {
        Ok(None)
    }

    fn visit_unit<E>(self) -> core::result::Result<Self::Value, E> {
        Ok(None)
    }

    fn visit_some<D>(self, deserializer: D) -> core::result::Result<Self::Value, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_any(self)
    }

    fn visit_str<E>(self, value: &str) -> core::result::Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        Ok(Some(value.to_owned()))
    }

    fn visit_string<E>(self, value: String) -> core::result::Result<Self::Value, E> {
        Ok(Some(value))
    }

    fn visit_u64<E>(self, value: u64) -> core::result::Result<Self::Value, E> {
        Ok(Some(value.to_string()))
    }

    /// Signed integers are accepted as well, matching the "integer" promised by `expecting`.
    /// Some serde formats may report integers through the signed path, and without this
    /// method serde's default would reject them.
    fn visit_i64<E>(self, value: i64) -> core::result::Result<Self::Value, E> {
        Ok(Some(value.to_string()))
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::panic)]
    use secrecy::{ExposeSecret, SecretString};

    use super::{DatabaseCredentials, DatabaseList, DatabaseRole, DatabaseStaticCredentials};

    /// Debug output for issued credentials must hide the password, private key and lease id.
    #[test]
    fn database_credentials_debug_redacts_password_and_lease() {
        let rendered = format!(
            "{:?}",
            DatabaseCredentials {
                username: "app".to_owned(),
                password: Some(SecretString::from("db-password")),
                private_key: Some(SecretString::from("private-key")),
                certificate: None,
                issuing_ca: None,
                ca_chain: Vec::new(),
                lease_id: SecretString::from("database/creds/app/lease"),
                lease_duration: 3600,
                renewable: true,
            }
        );
        assert!(rendered.contains("<redacted>"));
        for secret in ["db-password", "private-key", "database/creds/app/lease"] {
            assert!(!rendered.contains(secret));
        }
    }

    /// Debug output for static credentials must hide the password.
    #[test]
    fn database_static_credentials_debug_redacts_password() {
        let rendered = format!(
            "{:?}",
            DatabaseStaticCredentials {
                username: "static-user".to_owned(),
                password: SecretString::from("static-password"),
                last_openbao_rotation: None,
                rotation_period: Some(3600),
                ttl: Some(300),
            }
        );
        assert!(rendered.contains("<redacted>"));
        assert!(!rendered.contains("static-password"));
    }

    /// A list one entry longer than the limit must fail to decode.
    #[test]
    fn database_list_is_bounded() {
        let keys: Vec<String> = (0..=crate::response::MAX_RESPONSE_STRINGS)
            .map(|index| format!("role-{index}"))
            .collect();
        let payload = serde_json::json!({ "keys": keys });
        let Err(error) = serde_json::from_value::<DatabaseList>(payload) else {
            panic!("oversized database list unexpectedly decoded");
        };
        assert!(error.to_string().contains("exceeds item limit"));
    }

    /// Integer TTLs become strings and a lone statement string becomes a one-item list.
    #[test]
    fn database_role_accepts_integer_ttls_and_string_statements() {
        let raw = r#"{"db_name":"postgres","creation_statements":"CREATE ROLE {{name}}","default_ttl":3600,"max_ttl":"24h"}"#;
        let role: DatabaseRole =
            serde_json::from_str(raw).unwrap_or_else(|error| panic!("{error}"));
        assert_eq!(role.creation_statements, ["CREATE ROLE {{name}}"]);
        assert_eq!(role.default_ttl.as_deref(), Some("3600"));
        assert_eq!(role.max_ttl.as_deref(), Some("24h"));
    }

    /// The static password is decoded into a secret whose exposed value matches the input.
    #[test]
    fn database_static_password_deserializes_secret() {
        let raw = r#"{"username":"static","password":"secret"}"#;
        let credentials: DatabaseStaticCredentials =
            serde_json::from_str(raw).unwrap_or_else(|error| panic!("{error}"));
        assert_eq!(credentials.password.expose_secret(), "secret");
    }
}