use sqlx::{PgPool, Row};
use std::sync::Arc;
use std::time::Duration;
use crate::backends::error::{DatabaseError, Result};
/// Settings used to open a [`CockroachDBConnection`].
///
/// Build one with [`CockroachDBConnectionConfig::new`] and the `with_*` methods, or start
/// from [`Default`]. Because the struct is `#[non_exhaustive]`, code outside this crate
/// cannot construct it with a struct literal.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct CockroachDBConnectionConfig {
/// Base connection URL (default `postgresql://localhost:26257/defaultdb`).
pub url: String,
/// Intended upper bound on pooled connections (default `10`).
/// NOTE(review): only effective if `CockroachDBConnection::connect` forwards it to the pool options.
pub max_connections: u32,
/// Intended number of connections kept open (default `2`).
/// NOTE(review): only effective if `CockroachDBConnection::connect` forwards it to the pool options.
pub min_connections: u32,
/// Intended timeout for obtaining a connection (default 30 s).
/// NOTE(review): only effective if `CockroachDBConnection::connect` forwards it to the pool options.
pub connect_timeout: Duration,
/// Intended idle period after which a pooled connection is closed (default 600 s).
/// NOTE(review): only effective if `CockroachDBConnection::connect` forwards it to the pool options.
pub idle_timeout: Duration,
/// When set, appended to the URL as the `application_name` query parameter on connect.
pub application_name: Option<String>,
}
impl Default for CockroachDBConnectionConfig {
    /// Local single-node defaults.
    ///
    /// Targets `defaultdb` on `localhost:26257` with `max_connections = 10`,
    /// `min_connections = 2`, a 30 s connect timeout, a 10 minute idle timeout and
    /// no application name.
    fn default() -> Self {
        const DEFAULT_URL: &str = "postgresql://localhost:26257/defaultdb";
        const TEN_MINUTES_SECS: u64 = 10 * 60;

        Self {
            url: String::from(DEFAULT_URL),
            max_connections: 10,
            min_connections: 2,
            connect_timeout: Duration::from_secs(30),
            idle_timeout: Duration::from_secs(TEN_MINUTES_SECS),
            application_name: None,
        }
    }
}
impl CockroachDBConnectionConfig {
    /// Creates a configuration for `url`; every other field keeps its [`Default`] value.
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        Self {
            url,
            ..Self::default()
        }
    }

    /// Returns this configuration with `max_connections` replaced by `max`.
    pub fn with_max_connections(self, max: u32) -> Self {
        Self {
            max_connections: max,
            ..self
        }
    }

    /// Returns this configuration with `min_connections` replaced by `min`.
    pub fn with_min_connections(self, min: u32) -> Self {
        Self {
            min_connections: min,
            ..self
        }
    }

    /// Returns this configuration with `connect_timeout` replaced by `timeout`.
    pub fn with_connect_timeout(self, timeout: Duration) -> Self {
        Self {
            connect_timeout: timeout,
            ..self
        }
    }

    /// Returns this configuration with `idle_timeout` replaced by `timeout`.
    pub fn with_idle_timeout(self, timeout: Duration) -> Self {
        Self {
            idle_timeout: timeout,
            ..self
        }
    }

    /// Returns this configuration with `application_name` set to `name`.
    pub fn with_application_name(self, name: impl Into<String>) -> Self {
        Self {
            application_name: Some(name.into()),
            ..self
        }
    }
}
/// Cloneable handle to a sqlx Postgres connection pool used to talk to CockroachDB.
///
/// The pool is held behind an `Arc`, so every clone of this handle shares the same
/// underlying pool.
#[derive(Clone)]
pub struct CockroachDBConnection {
/// Shared pool; exposed through `pool()` (borrowed) and `pool_arc()` (shared owner).
pool: Arc<PgPool>,
}
impl CockroachDBConnection {
    /// Opens a connection pool for `config.url`, honouring every setting in `config`.
    ///
    /// * `max_connections` / `min_connections` size the pool. A minimum larger than
    ///   the maximum is clamped down to the maximum.
    /// * `connect_timeout` bounds how long acquiring a connection may wait.
    /// * `idle_timeout` closes connections left unused for longer than that.
    /// * `application_name`, when set, is added to the URL as the `application_name`
    ///   query parameter.
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseError`] when the pool cannot be created, e.g. the URL is
    /// rejected or connecting to the server fails.
    pub async fn connect(config: CockroachDBConnectionConfig) -> Result<Self> {
        let url = build_connection_url(&config.url, config.application_name.as_deref());
        // An inconsistent config (min > max) would make the pool try to keep more
        // connections open than it may create; clamp instead of failing at runtime.
        let min_connections = config.min_connections.min(config.max_connections);
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(config.max_connections)
            .min_connections(min_connections)
            .acquire_timeout(config.connect_timeout)
            .idle_timeout(config.idle_timeout)
            .connect(&url)
            .await
            .map_err(DatabaseError::from)?;
        Ok(Self::from_pool(pool))
    }

    /// Wraps an existing pool.
    pub fn from_pool(pool: PgPool) -> Self {
        Self {
            pool: Arc::new(pool),
        }
    }

    /// Wraps an already shared pool without re-allocating it.
    pub fn from_pool_arc(pool: Arc<PgPool>) -> Self {
        Self { pool }
    }

    /// Borrows the underlying pool.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// Returns a new shared handle to the underlying pool.
    pub fn pool_arc(&self) -> Arc<PgPool> {
        Arc::clone(&self.pool)
    }

    /// Checks connectivity by running `SELECT 1`.
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseError`] if the query fails.
    pub async fn ping(&self) -> Result<()> {
        sqlx::query("SELECT 1")
            .execute(self.pool.as_ref())
            .await
            .map_err(DatabaseError::from)?;
        Ok(())
    }

    /// Returns the server's `version()` string.
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseError`] if the query fails or the value cannot be decoded.
    pub async fn version(&self) -> Result<String> {
        self.fetch_single_string("SELECT version()").await
    }

    /// Returns the name reported by `current_database()`.
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseError`] if the query fails or the value cannot be decoded.
    pub async fn current_database(&self) -> Result<String> {
        self.fetch_single_string("SELECT current_database()").await
    }

    /// Returns the first column of every row produced by `SHOW REGIONS`.
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseError`] if the query fails or any value cannot be decoded
    /// as text; decoding stops at the first failure.
    pub async fn list_regions(&self) -> Result<Vec<String>> {
        let rows = sqlx::query("SHOW REGIONS")
            .fetch_all(self.pool.as_ref())
            .await
            .map_err(DatabaseError::from)?;
        rows.iter()
            .map(|row| row.try_get::<String, _>(0).map_err(DatabaseError::from))
            .collect()
    }

    /// Returns the first column of the first row of `SHOW PRIMARY REGION`.
    ///
    /// Yields `Ok(None)` when the statement returns no rows or the value is SQL NULL.
    /// NOTE(review): confirm this statement is supported by the target CockroachDB
    /// version; otherwise the call returns an error from the server.
    ///
    /// # Errors
    ///
    /// Returns a [`DatabaseError`] if the query fails or the value cannot be decoded.
    pub async fn primary_region(&self) -> Result<Option<String>> {
        let row = sqlx::query("SHOW PRIMARY REGION")
            .fetch_optional(self.pool.as_ref())
            .await
            .map_err(DatabaseError::from)?;
        match row {
            // Decode as Option so a NULL region maps to None instead of a decode error.
            Some(row) => row
                .try_get::<Option<String>, _>(0)
                .map_err(DatabaseError::from),
            None => Ok(None),
        }
    }

    /// Closes the shared pool; this affects every clone of this handle.
    pub async fn close(&self) {
        self.pool.close().await;
    }

    /// Runs `sql` and decodes column 0 of the first returned row as text.
    ///
    /// Fails with a [`DatabaseError`] when the query errors or returns no rows.
    async fn fetch_single_string(&self, sql: &str) -> Result<String> {
        let row = sqlx::query(sql)
            .fetch_one(self.pool.as_ref())
            .await
            .map_err(DatabaseError::from)?;
        row.try_get(0).map_err(DatabaseError::from)
    }
}
/// Returns `base_url` with `application_name` appended as a query parameter.
///
/// Every UTF-8 byte of the name outside the RFC 3986 *unreserved* set
/// (`A-Z a-z 0-9 - . _ ~`) is percent-encoded. Spaces, `&`, `=`, `+`, `%`, `#`, `?`,
/// control characters (which URL parsers would otherwise strip) and non-ASCII text
/// therefore all survive parsing unchanged.
///
/// The parameter goes into the query part, before any `#fragment` of `base_url`. It is
/// joined with `&` when a query already exists and with `?` otherwise. With `None` the
/// URL is returned unchanged.
fn build_connection_url(base_url: &str, application_name: Option<&str>) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let app_name = match application_name {
        Some(name) => name,
        None => return base_url.to_string(),
    };

    // Split off a fragment so the parameter lands in the query, not after `#`.
    let (head, fragment) = match base_url.find('#') {
        Some(pos) => base_url.split_at(pos),
        None => (base_url, ""),
    };
    let separator = if head.contains('?') { '&' } else { '?' };

    let mut url = String::with_capacity(
        base_url.len() + "?application_name=".len() + app_name.len() * 3,
    );
    url.push_str(head);
    url.push(separator);
    url.push_str("application_name=");
    for &byte in app_name.as_bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            url.push(char::from(byte));
        } else {
            url.push('%');
            url.push(char::from(HEX[usize::from(byte >> 4)]));
            url.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    url.push_str(fragment);
    url
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// URL for the builder and URL-construction tests (no server is contacted).
    const MY_DB: &str = "postgresql://localhost:26257/mydb";
    /// URL for the lazy-pool tests.
    const TEST_DB: &str = "postgresql://localhost:26257/testdb";

    #[rstest]
    fn test_config_default() {
        let cfg = CockroachDBConnectionConfig::default();
        assert_eq!(
            (cfg.url.as_str(), cfg.max_connections, cfg.min_connections),
            ("postgresql://localhost:26257/defaultdb", 10, 2)
        );
    }

    #[rstest]
    fn test_config_new() {
        assert_eq!(CockroachDBConnectionConfig::new(MY_DB).url, MY_DB);
    }

    #[rstest]
    fn test_config_with_max_connections() {
        let cfg = CockroachDBConnectionConfig::new(MY_DB).with_max_connections(20);
        assert_eq!(cfg.max_connections, 20);
    }

    #[rstest]
    fn test_config_with_min_connections() {
        let cfg = CockroachDBConnectionConfig::new(MY_DB).with_min_connections(5);
        assert_eq!(cfg.min_connections, 5);
    }

    #[rstest]
    fn test_config_with_connect_timeout() {
        let ten_secs = Duration::from_secs(10);
        let cfg = CockroachDBConnectionConfig::new(MY_DB).with_connect_timeout(ten_secs);
        assert_eq!(cfg.connect_timeout, ten_secs);
    }

    #[rstest]
    fn test_config_with_idle_timeout() {
        let five_mins = Duration::from_secs(300);
        let cfg = CockroachDBConnectionConfig::new(MY_DB).with_idle_timeout(five_mins);
        assert_eq!(cfg.idle_timeout, five_mins);
    }

    #[rstest]
    fn test_config_with_application_name() {
        let cfg = CockroachDBConnectionConfig::new(MY_DB).with_application_name("my-app");
        assert_eq!(cfg.application_name.as_deref(), Some("my-app"));
    }

    #[rstest]
    fn test_config_chaining() {
        let cfg = CockroachDBConnectionConfig::new(MY_DB)
            .with_max_connections(20)
            .with_min_connections(5)
            .with_connect_timeout(Duration::from_secs(10))
            .with_application_name("my-app");
        assert_eq!((cfg.max_connections, cfg.min_connections), (20, 5));
        assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
        assert_eq!(cfg.application_name.as_deref(), Some("my-app"));
    }

    #[tokio::test]
    async fn test_connection_from_pool() {
        let lazy = PgPool::connect_lazy(TEST_DB).expect("Failed to create lazy pool");
        let conn = CockroachDBConnection::from_pool(lazy);
        assert!(Arc::strong_count(&conn.pool) >= 1);
    }

    #[tokio::test]
    async fn test_connection_clone() {
        let shared = Arc::new(PgPool::connect_lazy(TEST_DB).expect("Failed to create lazy pool"));
        let first = CockroachDBConnection::from_pool_arc(Arc::clone(&shared));
        let second = first.clone();
        assert!(Arc::ptr_eq(&first.pool, &second.pool));
    }

    #[rstest]
    #[case::no_app_name(MY_DB, None, "postgresql://localhost:26257/mydb")]
    #[case::simple_app_name(
        MY_DB,
        Some("my-app"),
        "postgresql://localhost:26257/mydb?application_name=my-app"
    )]
    #[case::special_chars_encoded(
        MY_DB,
        Some("my app&name=v1"),
        "postgresql://localhost:26257/mydb?application_name=my%20app%26name%3Dv1"
    )]
    #[case::existing_query_params(
        "postgresql://localhost:26257/mydb?sslmode=require",
        Some("my-app"),
        "postgresql://localhost:26257/mydb?sslmode=require&application_name=my-app"
    )]
    #[case::percent_in_name(
        MY_DB,
        Some("100%done"),
        "postgresql://localhost:26257/mydb?application_name=100%25done"
    )]
    #[case::hash_and_question_mark(
        MY_DB,
        Some("app#1?v2"),
        "postgresql://localhost:26257/mydb?application_name=app%231%3Fv2"
    )]
    fn test_build_connection_url(
        #[case] base_url: &str,
        #[case] app_name: Option<&str>,
        #[case] expected: &str,
    ) {
        assert_eq!(build_connection_url(base_url, app_name), expected);
    }
}