#[cfg(feature = "sqlx-storage")]
use bon::Builder;
#[cfg(feature = "sqlx-storage")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "sqlx-storage")]
use std::collections::HashMap;
#[cfg(feature = "sqlx-storage")]
/// SQL database backend, as identified from a connection URL's scheme.
///
/// `Hash` is derived alongside `Eq` so backends can be used as keys in
/// hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DatabaseType {
    /// SQLite (`sqlite:` URLs).
    Sqlite,
    /// PostgreSQL (`postgres:` or `postgresql:` URLs).
    Postgres,
    /// MySQL or MariaDB (`mysql:` or `mariadb:` URLs).
    Mysql,
}
#[cfg(feature = "sqlx-storage")]
impl DatabaseType {
    /// Detects the backend from the scheme of a connection URL.
    ///
    /// The scheme is the text before the first `:`. Recognized (case-sensitive)
    /// schemes are `sqlite`, `postgres`/`postgresql` and `mysql`/`mariadb`.
    /// Returns `None` for any other scheme, or when `url` contains no `:`.
    pub fn from_url(url: &str) -> Option<Self> {
        let (scheme, _remainder) = url.split_once(':')?;
        match scheme {
            "sqlite" => Some(Self::Sqlite),
            "postgres" | "postgresql" => Some(Self::Postgres),
            "mysql" | "mariadb" => Some(Self::Mysql),
            _ => None,
        }
    }

    /// Reports whether this crate was compiled with the Cargo feature that
    /// backs this database (`sqlite`, `postgres` or `mysql`).
    pub fn is_feature_enabled(self) -> bool {
        match self {
            Self::Mysql => cfg!(feature = "mysql"),
            Self::Postgres => cfg!(feature = "postgres"),
            Self::Sqlite => cfg!(feature = "sqlite"),
        }
    }

    /// Name of the Cargo feature that enables this backend.
    pub fn feature_name(self) -> &'static str {
        match self {
            Self::Mysql => "mysql",
            Self::Postgres => "postgres",
            Self::Sqlite => "sqlite",
        }
    }
}
#[cfg(feature = "sqlx-storage")]
impl std::fmt::Display for DatabaseType {
    /// Writes the product name: `SQLite`, `PostgreSQL` or `MySQL`.
    ///
    /// Uses `Formatter::pad` rather than `write!` so that width, alignment and
    /// precision flags (e.g. `{:<12}`) are honored exactly as they are for `str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Sqlite => "SQLite",
            Self::Postgres => "PostgreSQL",
            Self::Mysql => "MySQL",
        };
        f.pad(name)
    }
}
/// Connection settings for the SQL-backed storage.
///
/// Build with `DatabaseConfig::builder()`, load with `DatabaseConfig::from_env`,
/// or deserialize with serde. Only `url` is required. The builder and the serde
/// deserializer fill in the same defaults for the remaining fields.
#[cfg(feature = "sqlx-storage")]
#[derive(Debug, Clone, Builder, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Connection URL; its scheme determines the backend.
    pub url: String,
    /// Maximum number of connections. Defaults to `10`; `validate` rejects `0`.
    #[builder(default = 10)]
    #[serde(default = "default_max_connections")]
    pub max_connections: u32,
    /// Timeout in seconds. Defaults to `30`; `validate` rejects `0`.
    #[builder(default = 30)]
    #[serde(default = "default_timeout_seconds")]
    pub timeout_seconds: u64,
    /// Whether logging is enabled. Defaults to `false`.
    #[builder(default = false)]
    #[serde(default)]
    pub enable_logging: bool,
}

/// Serde default for `DatabaseConfig::max_connections`; keep in sync with the builder default.
#[cfg(feature = "sqlx-storage")]
fn default_max_connections() -> u32 {
    10
}

/// Serde default for `DatabaseConfig::timeout_seconds`; keep in sync with the builder default.
#[cfg(feature = "sqlx-storage")]
fn default_timeout_seconds() -> u64 {
    30
}
#[cfg(feature = "sqlx-storage")]
impl DatabaseConfig {
    /// Returns sample configurations keyed by a short descriptive name.
    ///
    /// Keys: `sqlite_memory`, `sqlite_file`, `postgres_dev`, `postgres_prod` and
    /// `mysql_dev`. The user names and passwords in the server URLs are
    /// illustrative placeholders.
    pub fn examples() -> HashMap<&'static str, Self> {
        [
            (
                "sqlite_memory",
                Self::builder()
                    .url("sqlite::memory:".to_string())
                    .max_connections(1)
                    .enable_logging(true)
                    .build(),
            ),
            (
                "sqlite_file",
                Self::builder()
                    .url("sqlite:a2a_tasks.db".to_string())
                    .max_connections(5)
                    .build(),
            ),
            (
                "postgres_dev",
                Self::builder()
                    .url("postgres://user:password@localhost/a2a_dev".to_string())
                    .max_connections(10)
                    .timeout_seconds(10)
                    .build(),
            ),
            (
                "postgres_prod",
                Self::builder()
                    .url("postgres://user:password@prod-db/a2a_prod".to_string())
                    .max_connections(50)
                    .timeout_seconds(5)
                    .enable_logging(false)
                    .build(),
            ),
            (
                "mysql_dev",
                Self::builder()
                    .url("mysql://user:password@localhost/a2a_dev".to_string())
                    .max_connections(10)
                    .timeout_seconds(10)
                    .build(),
            ),
        ]
        .into_iter()
        .collect()
    }

    /// Builds a configuration from environment variables.
    ///
    /// * `DATABASE_URL` (required): the connection URL, used verbatim.
    /// * `DATABASE_MAX_CONNECTIONS`: `u32`, default `10`.
    /// * `DATABASE_TIMEOUT_SECONDS`: `u64`, default `30`.
    /// * `DATABASE_ENABLE_LOGGING`: boolean flag, default `false`. Accepts
    ///   `true`/`false`, `1`/`0`, `yes`/`no` and `on`/`off` in any letter case.
    ///
    /// Surrounding whitespace in the optional variables is ignored. An optional
    /// variable that is unset, not valid Unicode, or unparseable falls back to
    /// its default without an error. Call [`validate`](Self::validate) on the
    /// result to reject values such as `0`.
    ///
    /// # Errors
    ///
    /// Returns the [`std::env::VarError`] from reading `DATABASE_URL` when it is
    /// unset or not valid Unicode.
    pub fn from_env() -> Result<Self, std::env::VarError> {
        let url = std::env::var("DATABASE_URL")?;
        let max_connections: u32 = Self::env_parsed("DATABASE_MAX_CONNECTIONS").unwrap_or(10);
        let timeout_seconds: u64 = Self::env_parsed("DATABASE_TIMEOUT_SECONDS").unwrap_or(30);
        // `bool::from_str` only accepts exactly "true"/"false", which silently
        // turned common spellings like "1" or "TRUE" into `false`.
        let enable_logging = std::env::var("DATABASE_ENABLE_LOGGING")
            .ok()
            .and_then(|raw| Self::parse_flag(&raw))
            .unwrap_or(false);
        Ok(Self::builder()
            .url(url)
            .max_connections(max_connections)
            .timeout_seconds(timeout_seconds)
            .enable_logging(enable_logging)
            .build())
    }

    /// Checks the configuration for obviously invalid values.
    ///
    /// This does not check that the URL scheme is a known or enabled backend;
    /// use [`validate_database_support`](Self::validate_database_support) for that.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when:
    /// * the URL is empty,
    /// * `max_connections` is zero,
    /// * `timeout_seconds` is zero, or
    /// * the URL contains no `://` and does not start with `sqlite:`.
    pub fn validate(&self) -> Result<(), String> {
        if self.url.is_empty() {
            return Err("Database URL cannot be empty".to_string());
        }
        if self.max_connections == 0 {
            return Err("Max connections must be greater than 0".to_string());
        }
        if self.timeout_seconds == 0 {
            return Err("Timeout must be greater than 0".to_string());
        }
        // SQLite URLs such as `sqlite::memory:` or `sqlite:file.db` have no `//`.
        if !self.url.contains("://") && !self.url.starts_with("sqlite:") {
            return Err(
                "Database URL must contain a protocol (e.g., sqlite://, postgres://, mysql://)"
                    .to_string(),
            );
        }
        Ok(())
    }

    /// Backend detected from the URL scheme, or `None` if unrecognized.
    pub fn database_type(&self) -> Option<DatabaseType> {
        DatabaseType::from_url(&self.url)
    }

    /// Detects the backend and checks that its Cargo feature is compiled in.
    ///
    /// # Errors
    ///
    /// Returns a message when the URL scheme is not recognized, or when the
    /// matching backend feature is not enabled in this build.
    pub fn validate_database_support(&self) -> Result<DatabaseType, String> {
        let db_type = self.database_type().ok_or_else(|| {
            format!(
                "Unrecognized database URL scheme in '{}'. Expected sqlite:, postgres:, postgresql:, mysql:, or mariadb:",
                self.url
            )
        })?;
        if !db_type.is_feature_enabled() {
            return Err(format!(
                "{} database detected from URL but the '{}' feature is not enabled. \
                 Add `features = [\"{}\"]` to your a2a-rs dependency.",
                db_type,
                db_type.feature_name(),
                db_type.feature_name(),
            ));
        }
        Ok(db_type)
    }

    /// Reads environment variable `name` and parses its trimmed value.
    ///
    /// Returns `None` if the variable is unset, not valid Unicode, or unparseable as `T`.
    fn env_parsed<T: std::str::FromStr>(name: &str) -> Option<T> {
        std::env::var(name).ok()?.trim().parse().ok()
    }

    /// Parses a boolean flag, ignoring surrounding whitespace and ASCII letter case.
    ///
    /// Returns `None` for anything other than true/false, 1/0, yes/no or on/off.
    fn parse_flag(raw: &str) -> Option<bool> {
        match raw.trim().to_ascii_lowercase().as_str() {
            "true" | "1" | "yes" | "on" => Some(true),
            "false" | "0" | "no" | "off" => Some(false),
            _ => None,
        }
    }
}
#[cfg(feature = "sqlx-storage")]
impl Default for DatabaseConfig {
    /// An in-memory SQLite configuration; every other field takes its builder default.
    fn default() -> Self {
        let in_memory_sqlite = String::from("sqlite::memory:");
        Self::builder().url(in_memory_sqlite).build()
    }
}
#[cfg(test)]
#[cfg(feature = "sqlx-storage")]
mod tests {
    use super::*;

    #[test]
    fn test_database_config_validation() {
        // Well-formed SQLite file URL with default limits.
        let config = DatabaseConfig::builder()
            .url("sqlite:test.db".to_string())
            .build();
        assert!(config.validate().is_ok());

        // Server URL with a `://` separator.
        let config = DatabaseConfig::builder()
            .url("postgres://localhost/test".to_string())
            .build();
        assert!(config.validate().is_ok());

        // Empty URL.
        let config = DatabaseConfig::builder().url("".to_string()).build();
        assert!(config.validate().is_err());

        // Zero connections.
        let config = DatabaseConfig::builder()
            .url("sqlite:test.db".to_string())
            .max_connections(0)
            .build();
        assert!(config.validate().is_err());

        // Zero timeout.
        let config = DatabaseConfig::builder()
            .url("sqlite:test.db".to_string())
            .timeout_seconds(0)
            .build();
        assert!(config.validate().is_err());

        // No protocol at all.
        let config = DatabaseConfig::builder()
            .url("localhost/test".to_string())
            .build();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_database_type_detection() {
        let sqlite_config = DatabaseConfig::builder()
            .url("sqlite:test.db".to_string())
            .build();
        assert_eq!(sqlite_config.database_type(), Some(DatabaseType::Sqlite));
        let postgres_config = DatabaseConfig::builder()
            .url("postgres://localhost/test".to_string())
            .build();
        assert_eq!(
            postgres_config.database_type(),
            Some(DatabaseType::Postgres)
        );
        let postgresql_config = DatabaseConfig::builder()
            .url("postgresql://localhost/test".to_string())
            .build();
        assert_eq!(
            postgresql_config.database_type(),
            Some(DatabaseType::Postgres)
        );
        let mysql_config = DatabaseConfig::builder()
            .url("mysql://localhost/test".to_string())
            .build();
        assert_eq!(mysql_config.database_type(), Some(DatabaseType::Mysql));
        let unknown_config = DatabaseConfig::builder()
            .url("http://localhost".to_string())
            .build();
        assert_eq!(unknown_config.database_type(), None);
    }

    #[test]
    fn test_database_type_from_url() {
        assert_eq!(
            DatabaseType::from_url("sqlite::memory:"),
            Some(DatabaseType::Sqlite)
        );
        assert_eq!(
            DatabaseType::from_url("sqlite:data.db"),
            Some(DatabaseType::Sqlite)
        );
        assert_eq!(
            DatabaseType::from_url("postgres://user:pass@host/db"),
            Some(DatabaseType::Postgres)
        );
        assert_eq!(
            DatabaseType::from_url("postgresql://user:pass@host/db"),
            Some(DatabaseType::Postgres)
        );
        assert_eq!(
            DatabaseType::from_url("mysql://user:pass@host/db"),
            Some(DatabaseType::Mysql)
        );
        assert_eq!(
            DatabaseType::from_url("mariadb://user:pass@host/db"),
            Some(DatabaseType::Mysql)
        );
        assert_eq!(DatabaseType::from_url("ftp://something"), None);
        // A bare scheme name without the `:` separator is not a URL.
        assert_eq!(DatabaseType::from_url("sqlite"), None);
        assert_eq!(DatabaseType::from_url(""), None);
    }

    #[test]
    fn test_database_type_display_and_feature_name() {
        assert_eq!(DatabaseType::Sqlite.to_string(), "SQLite");
        assert_eq!(DatabaseType::Postgres.to_string(), "PostgreSQL");
        assert_eq!(DatabaseType::Mysql.to_string(), "MySQL");
        assert_eq!(DatabaseType::Sqlite.feature_name(), "sqlite");
        assert_eq!(DatabaseType::Postgres.feature_name(), "postgres");
        assert_eq!(DatabaseType::Mysql.feature_name(), "mysql");
    }

    #[test]
    fn test_validate_database_support() {
        // An unrecognized scheme is always rejected, regardless of features.
        let unknown = DatabaseConfig::builder()
            .url("http://localhost".to_string())
            .build();
        assert!(unknown.validate_database_support().is_err());

        // A recognized scheme succeeds exactly when its feature is compiled in.
        let sqlite = DatabaseConfig::builder()
            .url("sqlite::memory:".to_string())
            .build();
        match sqlite.validate_database_support() {
            Ok(db_type) => assert_eq!(db_type, DatabaseType::Sqlite),
            Err(_) => assert!(!DatabaseType::Sqlite.is_feature_enabled()),
        }
    }

    #[test]
    fn test_examples() {
        let examples = DatabaseConfig::examples();
        assert!(examples.contains_key("sqlite_memory"));
        assert!(examples.contains_key("postgres_dev"));
        for (name, config) in examples {
            assert!(
                config.validate().is_ok(),
                "Example '{}' failed validation",
                name
            );
        }
    }
}