use std::path::PathBuf;
use faucet_core::FaucetError;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
/// Maximum number of bound parameters allowed in a single statement sent to
/// SQL Server.
///
/// 2100 matches SQL Server's documented per-request parameter maximum.
/// NOTE(review): confirm whether callers must reserve headroom for any
/// parameters the driver adds on its own.
pub const PARAM_LIMIT: usize = 2_100;
// Connection settings for an MSSQL (SQL Server) endpoint.
//
// Exactly one of `connection_url` / `connection_string` is expected to be set;
// `MssqlConnectionConfig::validate` enforces that.
// `Debug` is not derived. NOTE(review): presumably this is deliberate, so that
// credentials embedded in the URL or connection string are never printed; confirm.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Default)]
pub struct MssqlConnectionConfig {
    // URL form, e.g. `mssql://user:pass@host:1433/db`. Optional; omitted from
    // serialized output when `None`. Presumably consumed via
    // `parse_connection_url` — verify against callers.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub connection_url: Option<String>,
    // Key/value connection-string form, e.g. `Server=host;Database=db`.
    // Optional; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub connection_string: Option<String>,
    // TLS settings; falls back to `MssqlTls::default()` when the key is absent.
    #[serde(default)]
    pub tls: MssqlTls,
}
// TLS configuration for the MSSQL connection.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, Default, PartialEq, Eq)]
pub struct MssqlTls {
    // TLS mode, read from / written to the key `type` (not `mode`).
    // Defaults to `MssqlTlsMode::Prefer` when absent.
    #[serde(rename = "type", default)]
    pub mode: MssqlTlsMode,
    // Optional path to a CA certificate file; omitted from serialized output
    // when `None`. NOTE(review): how it combines with each mode is decided by
    // the connection code outside this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub ca_cert_path: Option<PathBuf>,
}
// TLS negotiation mode for the MSSQL connection.
//
// Variants are (de)serialized in snake_case: `prefer`, `require`,
// `trust_server_certificate`, `disable`. `Prefer` is the default.
// NOTE(review): the concrete per-mode behaviour lives in the connection code,
// not in this file — check there before relying on a mode's exact semantics.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, Default, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum MssqlTlsMode {
    #[default]
    Prefer,
    Require,
    TrustServerCertificate,
    Disable,
}
/// Components of an MSSQL connection, as produced by `parse_connection_url`.
///
/// `Debug` is implemented by hand rather than derived so the password never
/// leaks into logs, error chains, or `assert_eq!` failure output: a non-empty
/// password is shown as `"<redacted>"`, an empty one as `""`.
#[derive(Clone, PartialEq, Eq)]
pub(crate) struct ConnectionParts {
    /// Server host name or address.
    pub host: String,
    /// TCP port (`parse_connection_url` fills in `1433` when the URL omits it).
    pub port: u16,
    /// Database name, or `None` when no database was given.
    pub database: Option<String>,
    /// Login name (percent-decoded by `parse_connection_url`; may be empty).
    pub username: String,
    /// Login password (percent-decoded by `parse_connection_url`; may be empty).
    pub password: String,
}

impl std::fmt::Debug for ConnectionParts {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a password is present, never its value.
        let password = if self.password.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        f.debug_struct("ConnectionParts")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("database", &self.database)
            .field("username", &self.username)
            .field("password", &password)
            .finish()
    }
}
impl MssqlConnectionConfig {
    /// Checks that exactly one connection source is configured.
    ///
    /// Only `connection_url` and `connection_string` are inspected; their
    /// contents and the TLS settings are not validated here.
    ///
    /// # Errors
    ///
    /// Returns [`FaucetError::Config`] when both sources are set, or when
    /// neither is.
    pub fn validate(&self) -> Result<(), FaucetError> {
        let has_url = self.connection_url.is_some();
        let has_string = self.connection_string.is_some();

        if has_url && has_string {
            return Err(FaucetError::Config(
                "MSSQL config sets both `connection_url` and `connection_string`; set exactly one"
                    .into(),
            ));
        }
        if !(has_url || has_string) {
            return Err(FaucetError::Config(
                "MSSQL config requires either `connection_url` or `connection_string`".into(),
            ));
        }
        Ok(())
    }
}
/// Parses an `mssql://` or `sqlserver://` URL into its connection parts.
///
/// Expected shape: `mssql://[user[:password]@]host[:port][/database]`.
///
/// - The port defaults to `1433` when the URL does not specify one.
/// - User name, password and database are percent-decoded, so characters
///   such as `@`, `:` or `/` inside them must be percent-encoded.
/// - A missing user name or password yields an empty string.
/// - Leading and trailing slashes around the database are ignored. An empty
///   path yields `database: None`.
/// - Query strings and fragments are not read.
///
/// # Errors
///
/// Returns [`FaucetError::Config`] when the URL cannot be parsed, uses a
/// scheme other than `mssql`/`sqlserver`, has no (or an empty) host, or when a
/// percent-decoded component is not valid UTF-8.
pub(crate) fn parse_connection_url(raw: &str) -> Result<ConnectionParts, FaucetError> {
    let url = url::Url::parse(raw)
        .map_err(|e| FaucetError::Config(format!("invalid MSSQL connection_url: {e}")))?;
    if url.scheme() != "mssql" && url.scheme() != "sqlserver" {
        // Both schemes are accepted, so the message names both.
        return Err(FaucetError::Config(format!(
            "MSSQL connection_url scheme must be `mssql://` or `sqlserver://`, got `{}://`",
            url.scheme()
        )));
    }
    // An empty host string is treated like a missing one; this guards against
    // inputs such as `mssql:///db` that may parse with an empty authority.
    let host = url
        .host_str()
        .filter(|h| !h.is_empty())
        .ok_or_else(|| FaucetError::Config("MSSQL connection_url is missing a host".into()))?
        .to_string();
    // 1433 is SQL Server's default TCP port.
    let port = url.port().unwrap_or(1433);
    // Trim both ends so `mssql://host/db/` selects `db`, not `db/`.
    let database = match url.path().trim_matches('/') {
        "" => None,
        seg => Some(
            percent_decode(seg)
                .map_err(|e| FaucetError::Config(format!("invalid database in URL: {e}")))?,
        ),
    };
    let username = percent_decode(url.username())
        .map_err(|e| FaucetError::Config(format!("invalid username in URL: {e}")))?;
    let password = percent_decode(url.password().unwrap_or(""))
        .map_err(|e| FaucetError::Config(format!("invalid password in URL: {e}")))?;
    Ok(ConnectionParts {
        host,
        port,
        database,
        username,
        password,
    })
}
/// Percent-decodes `s` into an owned `String`.
///
/// # Errors
///
/// Returns the [`std::str::Utf8Error`] if the decoded bytes are not valid UTF-8.
fn percent_decode(s: &str) -> Result<String, std::str::Utf8Error> {
    let decoded = percent_encoding::percent_decode_str(s).decode_utf8()?;
    Ok(decoded.into_owned())
}
/// Quotes `name` as a single bracket-delimited T-SQL identifier.
///
/// Every `]` inside the name is doubled (`]]`), which is how T-SQL escapes
/// the closing delimiter. The result is therefore always exactly one
/// identifier. `.` is not treated as a separator, so `dbo.events` becomes
/// `[dbo.events]`. To build a multi-part name, quote each part separately.
///
/// # Errors
///
/// Returns [`FaucetError::Config`] if `name` is empty (SQL Server rejects the
/// empty delimited identifier `[]`) or contains a NUL character.
pub fn quote_ident_mssql(name: &str) -> Result<String, FaucetError> {
    if name.is_empty() {
        return Err(FaucetError::Config(
            "invalid MSSQL identifier (empty)".into(),
        ));
    }
    if name.contains('\0') {
        return Err(FaucetError::Config(format!(
            "invalid MSSQL identifier (contains NUL): {name:?}"
        )));
    }
    Ok(format!("[{}]", name.replace(']', "]]")))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_accepts_exactly_one() {
        let url_only = MssqlConnectionConfig {
            connection_url: Some("mssql://sa:pw@host/db".into()),
            ..Default::default()
        };
        assert!(url_only.validate().is_ok());
        let str_only = MssqlConnectionConfig {
            connection_string: Some("Server=host;Database=db".into()),
            ..Default::default()
        };
        assert!(str_only.validate().is_ok());
    }

    #[test]
    fn validate_rejects_both_and_neither() {
        let both = MssqlConnectionConfig {
            connection_url: Some("mssql://sa:pw@host/db".into()),
            connection_string: Some("Server=host".into()),
            ..Default::default()
        };
        assert!(both.validate().is_err());
        let neither = MssqlConnectionConfig::default();
        assert!(neither.validate().is_err());
    }

    #[test]
    fn parse_url_extracts_all_parts() {
        let parts = parse_connection_url("mssql://sa:s3cret@db.example.com:1433/sales").unwrap();
        assert_eq!(parts.host, "db.example.com");
        assert_eq!(parts.port, 1433);
        assert_eq!(parts.database.as_deref(), Some("sales"));
        assert_eq!(parts.username, "sa");
        assert_eq!(parts.password, "s3cret");
    }

    #[test]
    fn parse_url_accepts_sqlserver_scheme_and_explicit_port() {
        let parts = parse_connection_url("sqlserver://sa:pw@host:1444/db").unwrap();
        assert_eq!(parts.host, "host");
        assert_eq!(parts.port, 1444);
        assert_eq!(parts.database.as_deref(), Some("db"));
    }

    #[test]
    fn parse_url_defaults_port_and_optional_database() {
        let parts = parse_connection_url("mssql://sa:pw@localhost").unwrap();
        assert_eq!(parts.port, 1433);
        assert_eq!(parts.database, None);
    }

    #[test]
    fn parse_url_percent_decodes_credentials() {
        let parts = parse_connection_url("mssql://us%65r:p%40ss%3Aw%2Frd@host/db").unwrap();
        assert_eq!(parts.username, "user");
        assert_eq!(parts.password, "p@ss:w/rd");
    }

    #[test]
    fn parse_url_rejects_wrong_scheme_and_missing_host() {
        // Wrong scheme.
        assert!(parse_connection_url("postgres://sa:pw@host/db").is_err());
        // Not a URL at all.
        assert!(parse_connection_url("not a url").is_err());
        // Right scheme, but an empty authority / no authority at all.
        assert!(parse_connection_url("mssql:///db").is_err());
        assert!(parse_connection_url("mssql:db").is_err());
    }

    #[test]
    fn quote_ident_brackets_and_doubles_closing_bracket() {
        assert_eq!(quote_ident_mssql("events").unwrap(), "[events]");
        assert_eq!(quote_ident_mssql("dbo.events").unwrap(), "[dbo.events]");
        assert_eq!(quote_ident_mssql("we[i]rd").unwrap(), "[we[i]]rd]");
        assert_eq!(quote_ident_mssql("]").unwrap(), "[]]]");
        assert!(quote_ident_mssql("bad\0name").is_err());
    }
}