use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use crate::{AdminError, Database, DatabaseAdmin, TenantConfig, TenantCredentials};
/// Python-visible request describing a tenant to provision (`TenantConfig` in Python).
///
/// Thin wrapper around the crate's [`TenantConfig`]; every getter hands back an
/// owned copy of the corresponding field.
#[pyclass(name = "TenantConfig")]
pub struct PyTenantConfig {
    pub inner: TenantConfig,
}

#[pymethods]
impl PyTenantConfig {
    /// Builds a tenant config from its schema name, username and optional password.
    ///
    /// When `password` is omitted (`None`), an empty string is stored in its place.
    #[new]
    pub fn new(schema_name: String, username: String, password: Option<String>) -> Self {
        let inner = TenantConfig {
            schema_name,
            username,
            password: password.unwrap_or_default(),
        };
        Self { inner }
    }

    /// Name of the schema this config targets.
    #[getter]
    pub fn schema_name(&self) -> String {
        self.inner.schema_name.to_owned()
    }

    /// Username this config targets.
    #[getter]
    pub fn username(&self) -> String {
        self.inner.username.to_owned()
    }

    /// Configured password (empty string if none was supplied).
    #[getter]
    pub fn password(&self) -> String {
        self.inner.password.to_owned()
    }

    /// Debug representation; the password is always masked as `***`.
    pub fn __repr__(&self) -> String {
        let TenantConfig {
            schema_name,
            username,
            ..
        } = &self.inner;
        format!(
            "TenantConfig(schema_name='{}', username='{}', password='***')",
            schema_name, username
        )
    }
}
#[pyclass(name = "TenantCredentials")]
pub struct PyTenantCredentials {
pub inner: TenantCredentials,
}
#[pymethods]
impl PyTenantCredentials {
#[getter]
pub fn username(&self) -> String {
self.inner.username.clone()
}
#[getter]
pub fn password(&self) -> String {
self.inner.password.clone()
}
#[getter]
pub fn schema_name(&self) -> String {
self.inner.schema_name.clone()
}
#[getter]
pub fn connection_string(&self) -> String {
self.inner.connection_string.clone()
}
pub fn __repr__(&self) -> String {
format!(
"TenantCredentials(username='{}', schema_name='{}', password='***', connection_string='***')",
self.inner.username, self.inner.schema_name
)
}
}
/// Returns `true` when `url` starts with a PostgreSQL scheme prefix,
/// `postgres://` or `postgresql://`.
///
/// The comparison is ASCII case-insensitive because URL schemes are
/// case-insensitive (RFC 3986 §3.1), so `POSTGRES://…` or `PostgreSQL://…`
/// are accepted as well.
fn is_postgres_url(url: &str) -> bool {
    ["postgres://", "postgresql://"].iter().any(|&prefix| {
        // `str::get` yields `None` instead of panicking when `url` is shorter than
        // the prefix or the cut would land inside a multi-byte character.
        matches!(url.get(..prefix.len()), Some(head) if head.eq_ignore_ascii_case(prefix))
    })
}
/// Python-facing administrator (`DatabaseAdmin` in Python) that creates and
/// removes tenants on a PostgreSQL server by delegating to [`DatabaseAdmin`].
#[pyclass(name = "DatabaseAdmin")]
pub struct PyDatabaseAdmin {
    pub inner: DatabaseAdmin,
}

/// Splits a PostgreSQL URL into a server-level connection string and a database name.
///
/// The connection string has the shape `scheme://user[:password]@host:port`, with
/// the host defaulting to `localhost` and the port to `5432` when the URL omits
/// them. Username and password are copied in their percent-encoded URL form, so
/// the result stays a valid URL. The database name is the URL path with its
/// leading `/` characters removed.
///
/// NOTE(review): the query string (e.g. `?sslmode=require`) is not carried into
/// the connection string, so such options are silently dropped. Confirm whether
/// `Database::new` should receive them.
///
/// # Errors
/// `PyRuntimeError` if the URL cannot be parsed or its path names no database.
fn split_database_url(database_url: &str) -> PyResult<(String, String)> {
    let url = url::Url::parse(database_url)
        .map_err(|e| PyRuntimeError::new_err(format!("Invalid database URL: {}", e)))?;

    let database_name = url.path().trim_start_matches('/');
    if database_name.is_empty() {
        return Err(PyRuntimeError::new_err(
            "Database name is required in URL path",
        ));
    }

    let scheme = url.scheme();
    let username = url.username();
    let host = url.host_str().unwrap_or("localhost");
    let port = url.port().unwrap_or(5432);

    // A missing or empty password is left out entirely rather than emitted as `user:@`.
    let connection_string = match url.password() {
        Some(password) if !password.is_empty() => format!(
            "{}://{}:{}@{}:{}",
            scheme, username, password, host, port
        ),
        _ => format!("{}://{}@{}:{}", scheme, username, host, port),
    };

    Ok((connection_string, database_name.to_string()))
}

/// Drives `future` to completion on a Tokio runtime built for this call alone.
///
/// The runtime is dropped before this function returns, so each Python call is
/// self-contained. The cost is building a fresh multi-threaded runtime per call.
///
/// NOTE(review): PyO3 keeps the GIL held while this blocks, so other Python
/// threads cannot run until the database work completes.
///
/// # Errors
/// `PyRuntimeError` if the runtime cannot be created. The future's own output
/// (including any error it carries) is returned untouched inside `Ok`.
fn block_on_fresh_runtime<F>(future: F) -> PyResult<F::Output>
where
    F: std::future::Future,
{
    let rt = tokio::runtime::Runtime::new()
        .map_err(|e| PyRuntimeError::new_err(format!("Failed to create runtime: {}", e)))?;
    Ok(rt.block_on(future))
}

#[pymethods]
impl PyDatabaseAdmin {
    /// Creates an admin handle from a PostgreSQL URL such as
    /// `postgres://user:pass@host:5432/db`.
    ///
    /// # Errors
    /// `PyRuntimeError` if the URL does not use a PostgreSQL scheme, cannot be
    /// parsed, or has no database name in its path.
    #[new]
    pub fn new(database_url: String) -> PyResult<Self> {
        // Checked before parsing so non-PostgreSQL URLs (e.g. SQLite) get an
        // explanatory message instead of a generic parse error.
        if !is_postgres_url(&database_url) {
            return Err(PyRuntimeError::new_err(
                "DatabaseAdmin requires a PostgreSQL connection. \
                 SQLite does not support database schemas or user management. \
                 Use a PostgreSQL URL like 'postgres://user:pass@host/db'",
            ));
        }

        let (connection_string, database_name) = split_database_url(&database_url)?;
        // NOTE(review): `10` is hard-coded; presumably a connection/pool limit.
        // Confirm against `Database::new`.
        let database = Database::new(&connection_string, database_name.as_str(), 10);
        Ok(Self {
            inner: DatabaseAdmin::new(database),
        })
    }

    /// Provisions the tenant described by `config` and returns its credentials.
    ///
    /// # Errors
    /// `PyRuntimeError` if the runtime cannot be created or the admin operation fails.
    pub fn create_tenant(&self, config: &PyTenantConfig) -> PyResult<PyTenantCredentials> {
        // `config` is only borrowed from Python, so hand the admin an owned copy.
        let tenant_config = TenantConfig {
            schema_name: config.inner.schema_name.clone(),
            username: config.inner.username.clone(),
            password: config.inner.password.clone(),
        };

        // The async block defers the call until the runtime is polling it.
        let credentials =
            block_on_fresh_runtime(async { self.inner.create_tenant(tenant_config).await })?
                .map_err(|e: AdminError| {
                    PyRuntimeError::new_err(format!("Failed to create tenant: {}", e))
                })?;
        Ok(PyTenantCredentials { inner: credentials })
    }

    /// Removes the tenant identified by `schema_name` and `username`.
    ///
    /// # Errors
    /// `PyRuntimeError` if the runtime cannot be created or the admin operation fails.
    pub fn remove_tenant(&self, schema_name: String, username: String) -> PyResult<()> {
        block_on_fresh_runtime(async { self.inner.remove_tenant(&schema_name, &username).await })?
            .map_err(|e: AdminError| {
                PyRuntimeError::new_err(format!("Failed to remove tenant: {}", e))
            })?;
        Ok(())
    }

    /// Fixed representation; connection details are deliberately not shown.
    pub fn __repr__(&self) -> String {
        "DatabaseAdmin()".to_string()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tenant_config_new() {
        let config = PyTenantConfig::new(
            "test_schema".to_string(),
            "test_user".to_string(),
            Some("test_pass".to_string()),
        );
        assert_eq!(config.schema_name(), "test_schema");
        assert_eq!(config.username(), "test_user");
        assert_eq!(config.password(), "test_pass");
    }

    #[test]
    fn test_tenant_config_default_password() {
        let config = PyTenantConfig::new("schema".to_string(), "user".to_string(), None);
        assert_eq!(config.password(), "");
    }

    #[test]
    fn test_tenant_config_repr() {
        let config = PyTenantConfig::new(
            "my_schema".to_string(),
            "my_user".to_string(),
            Some("secret".to_string()),
        );
        let repr = config.__repr__();
        assert!(repr.contains("my_schema"));
        assert!(repr.contains("my_user"));
        assert!(repr.contains("***"));
        // The real password must never leak through the repr.
        assert!(!repr.contains("secret"));
    }

    #[test]
    fn test_is_postgres_url() {
        assert!(is_postgres_url("postgres://user:pass@localhost/db"));
        assert!(is_postgres_url("postgresql://user:pass@localhost/db"));
        assert!(!is_postgres_url("sqlite:///tmp/test.db"));
        assert!(!is_postgres_url("mysql://localhost/db"));
        assert!(!is_postgres_url(""));
        // The scheme must be a prefix, not merely contained in the URL.
        assert!(!is_postgres_url("xpostgres://localhost/db"));
    }

    #[test]
    fn test_database_admin_rejects_sqlite() {
        let result = PyDatabaseAdmin::new("sqlite:///tmp/test.db".to_string());
        assert!(result.is_err());
    }

    #[test]
    fn test_database_admin_rejects_invalid_url() {
        let result = PyDatabaseAdmin::new("not a url".to_string());
        assert!(result.is_err());
    }

    #[test]
    fn test_database_admin_rejects_missing_db_name() {
        let result = PyDatabaseAdmin::new("postgres://user:pass@localhost/".to_string());
        assert!(result.is_err());
    }

    #[test]
    fn test_database_admin_rejects_url_without_path() {
        let result = PyDatabaseAdmin::new("postgres://user:pass@localhost".to_string());
        assert!(result.is_err());
    }
}