#![allow(dead_code)]
#![allow(missing_docs)]
#![allow(clippy::too_many_arguments)]
use crate::error::{IoError, Result};
use crate::metadata::Metadata;
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[cfg(feature = "sqlite")]
pub mod sqlite;
#[cfg(feature = "postgres")]
pub mod postgres;
#[cfg(feature = "mysql")]
pub mod mysql;
pub mod pool;
pub mod bulk;
pub mod timeseries;
pub use self::pool::ConnectionPool;
/// Database backends that a [`DatabaseConfig`] can describe.
///
/// `DatabaseConnector::connect` only opens connections for `SQLite`, `PostgreSQL` and
/// `MySQL`, each gated behind its cargo feature. Every other variant is rejected with
/// `IoError::UnsupportedFormat`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DatabaseType {
    PostgreSQL,
    MySQL,
    SQLite,
    MongoDB,
    InfluxDB,
    Redis,
    Cassandra,
}
/// Connection settings for a database.
///
/// Build one with `DatabaseConfig::new`, then chain the builder methods `host`, `port`,
/// `credentials` and `option`.
///
/// NOTE(review): `password` is written out by the derived `Serialize` and printed by the
/// derived `Debug`. Avoid logging or persisting this struct when credentials are set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Backend this configuration targets.
    pub db_type: DatabaseType,
    /// Server host. For network backends, `connection_string` falls back to `"localhost"`.
    pub host: Option<String>,
    /// Server port. For network backends, `connection_string` uses a per-backend default.
    pub port: Option<u16>,
    /// Database name. For SQLite, this is the value placed after `sqlite://`.
    pub database: String,
    /// Login user name, set via `credentials`.
    pub username: Option<String>,
    /// Login password, set via `credentials`.
    pub password: Option<String>,
    /// Extra key/value options. These are not read by `connection_string`.
    pub options: HashMap<String, String>,
}
impl DatabaseConfig {
    /// Creates a configuration for `db_type` targeting `database`.
    ///
    /// No host, port, credentials or options are set.
    pub fn new(db_type: DatabaseType, database: impl Into<String>) -> Self {
        Self {
            db_type,
            host: None,
            port: None,
            database: database.into(),
            username: None,
            password: None,
            options: HashMap::new(),
        }
    }

    /// Sets the server host name (builder style).
    pub fn host(mut self, host: impl Into<String>) -> Self {
        self.host = Some(host.into());
        self
    }

    /// Sets the server port (builder style).
    pub fn port(mut self, port: u16) -> Self {
        self.port = Some(port);
        self
    }

    /// Sets the login user name and password (builder style).
    pub fn credentials(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
        self.username = Some(username.into());
        self.password = Some(password.into());
        self
    }

    /// Adds or replaces an extra option (builder style).
    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.options.insert(key.into(), value.into());
        self
    }

    /// Builds a URL-style connection string for this configuration.
    ///
    /// PostgreSQL, MySQL and MongoDB produce `scheme://[user[:password]@]host:port/database`:
    /// - `host` defaults to `localhost`.
    /// - `port` defaults to 5432, 3306 and 27017 respectively.
    /// - PostgreSQL and MySQL default the user to `postgres` / `root`. MongoDB only includes
    ///   credentials when a username is configured.
    /// - The configured password is used; it is omitted when none is set.
    /// - User and password are percent-encoded so characters such as `@`, `:` or `/`
    ///   cannot corrupt the URL.
    ///
    /// SQLite yields `sqlite://<database>`. The remaining backends yield
    /// `<scheme>://<database>` and ignore host, port and credentials.
    pub fn connection_string(&self) -> String {
        match self.db_type {
            DatabaseType::PostgreSQL => self.network_url("postgresql", 5432, Some("postgres")),
            DatabaseType::MySQL => self.network_url("mysql", 3306, Some("root")),
            DatabaseType::SQLite => format!("sqlite://{}", self.database),
            DatabaseType::MongoDB => self.network_url("mongodb", 27017, None),
            DatabaseType::InfluxDB | DatabaseType::Redis | DatabaseType::Cassandra => {
                format!("{}://{}", self.db_type.as_str(), self.database)
            }
        }
    }

    /// Formats `scheme://[userinfo@]host:port/database` using this config's settings.
    ///
    /// `default_user` is used when no username is configured. When the resulting user is
    /// `None`, no userinfo section is emitted.
    fn network_url(&self, scheme: &str, default_port: u16, default_user: Option<&str>) -> String {
        let host = self.host.as_deref().unwrap_or("localhost");
        let port = self.port.unwrap_or(default_port);
        let user = self.username.as_deref().or(default_user);
        let userinfo = match (user, self.password.as_deref()) {
            (Some(user), Some(password)) => format!(
                "{}:{}@",
                Self::encode_userinfo(user),
                Self::encode_userinfo(password)
            ),
            (Some(user), None) => format!("{}@", Self::encode_userinfo(user)),
            // A password without a user cannot be expressed in URL userinfo.
            (None, _) => String::new(),
        };
        format!("{scheme}://{userinfo}{host}:{port}/{}", self.database)
    }

    /// Percent-encodes `raw` for use in the userinfo part of a URL.
    ///
    /// Every byte outside the RFC 3986 "unreserved" set (ALPHA / DIGIT / `-` `.` `_` `~`)
    /// is escaped, so multi-byte UTF-8 is encoded byte by byte.
    fn encode_userinfo(raw: &str) -> String {
        let mut encoded = String::with_capacity(raw.len());
        for byte in raw.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                encoded.push(char::from(byte));
            } else {
                encoded.push_str(&format!("%{byte:02X}"));
            }
        }
        encoded
    }
}
impl DatabaseType {
fn as_str(&self) -> &'static str {
match self {
Self::PostgreSQL => "postgresql",
Self::MySQL => "mysql",
Self::SQLite => "sqlite",
Self::MongoDB => "mongodb",
Self::InfluxDB => "influxdb",
Self::Redis => "redis",
Self::Cassandra => "cassandra",
}
}
}
/// Fluent builder for simple queries.
///
/// The query is rendered either as SQL text (`build_sql`) or as a MongoDB-style JSON
/// document (`build_mongo`). This module only provides the `select` and `insert`
/// constructors. The fields are `pub(crate)`, so other modules in the crate can read or set
/// them directly.
pub struct QueryBuilder {
    /// Kind of statement to generate.
    pub(crate) query_type: QueryType,
    /// Target table, or collection for `build_mongo`. Copied verbatim into the output.
    pub(crate) table: String,
    /// Column list. `select` starts with `["*"]`, `insert` with an empty list.
    pub(crate) columns: Vec<String>,
    /// Raw condition strings. `build_sql` joins them with `AND`.
    pub(crate) conditions: Vec<String>,
    /// Row values for `INSERT`. `build_sql` emits one `?` placeholder per value.
    pub(crate) values: Vec<serde_json::Value>,
    /// Pre-rendered `"<column> ASC|DESC"` ordering clause.
    pub(crate) order_by: Option<String>,
    /// Maximum number of rows (`LIMIT` / Mongo `limit`).
    pub(crate) limit: Option<usize>,
    /// Number of rows to skip (`OFFSET` / Mongo `skip`).
    pub(crate) offset: Option<usize>,
}
/// Statement kind produced by a [`QueryBuilder`].
///
/// Only `Select` and `Insert` have an SQL rendering in `build_sql`. The other variants
/// produce an empty string there.
#[derive(Debug, Clone)]
// `Update`, `Delete` and `CreateTable` have no `QueryBuilder` constructor in this module,
// so they would otherwise trigger dead-code warnings.
#[allow(dead_code)]
pub(crate) enum QueryType {
    Select,
    Insert,
    Update,
    Delete,
    CreateTable,
}
impl QueryBuilder {
    /// Starts a `SELECT` on `table`.
    ///
    /// Every column (`*`) is selected until `columns` is called.
    pub fn select(table: impl Into<String>) -> Self {
        Self {
            query_type: QueryType::Select,
            table: table.into(),
            columns: vec!["*".to_string()],
            conditions: Vec::new(),
            values: Vec::new(),
            order_by: None,
            limit: None,
            offset: None,
        }
    }

    /// Starts an `INSERT` into `table`.
    ///
    /// Set the target columns with `columns` and the row with `values`.
    pub fn insert(table: impl Into<String>) -> Self {
        Self {
            query_type: QueryType::Insert,
            table: table.into(),
            columns: Vec::new(),
            conditions: Vec::new(),
            values: Vec::new(),
            order_by: None,
            limit: None,
            offset: None,
        }
    }

    /// Replaces the column list.
    pub fn columns(mut self, columns: Vec<impl Into<String>>) -> Self {
        self.columns = columns.into_iter().map(|c| c.into()).collect();
        self
    }

    /// Adds a condition. `build_sql` joins all conditions with `AND`.
    ///
    /// The text is copied verbatim into the generated SQL, so never build it from untrusted
    /// input. Prefer passing values as `params` to `DatabaseConnection::execute_sql`.
    ///
    /// `build_mongo` only understands conditions of the form `field = value`, with spaces
    /// around `=`. It ignores all other conditions.
    pub fn where_clause(mut self, condition: impl Into<String>) -> Self {
        self.conditions.push(condition.into());
        self
    }

    /// Sets the values to insert.
    ///
    /// `build_sql` emits one `?` placeholder per value; the values are not embedded in the
    /// SQL text.
    pub fn values(mut self, values: Vec<serde_json::Value>) -> Self {
        self.values = values;
        self
    }

    /// Sets the sort column and direction.
    ///
    /// The direction is `DESC` when `desc` is true, `ASC` otherwise. Calling this again
    /// replaces the previous ordering.
    pub fn order_by(mut self, column: impl Into<String>, desc: bool) -> Self {
        let direction = if desc { "DESC" } else { "ASC" };
        self.order_by = Some(format!("{} {}", column.into(), direction));
        self
    }

    /// Sets the maximum number of rows to return.
    pub fn limit(mut self, limit: usize) -> Self {
        self.limit = Some(limit);
        self
    }

    /// Sets the number of rows to skip.
    pub fn offset(mut self, offset: usize) -> Self {
        self.offset = Some(offset);
        self
    }

    /// Renders the query as SQL text.
    ///
    /// - `Select` produces `SELECT cols FROM t [WHERE …] [ORDER BY …] [LIMIT n] [OFFSET n]`.
    ///   An empty column list falls back to `*`.
    /// - `Insert` produces `INSERT INTO t (cols) VALUES (?, …)`, with one placeholder per
    ///   value.
    /// - The other query types have no SQL rendering and return an empty string.
    pub fn build_sql(&self) -> String {
        match self.query_type {
            QueryType::Select => {
                // An empty list would otherwise produce invalid `SELECT  FROM …`.
                let columns = if self.columns.is_empty() {
                    "*".to_string()
                } else {
                    self.columns.join(", ")
                };
                let mut sql = format!("SELECT {} FROM {}", columns, self.table);
                if !self.conditions.is_empty() {
                    sql.push_str(" WHERE ");
                    sql.push_str(&self.conditions.join(" AND "));
                }
                if let Some(order) = &self.order_by {
                    sql.push_str(" ORDER BY ");
                    sql.push_str(order);
                }
                if let Some(limit) = self.limit {
                    sql.push_str(&format!(" LIMIT {limit}"));
                }
                if let Some(offset) = self.offset {
                    sql.push_str(&format!(" OFFSET {offset}"));
                }
                sql
            }
            QueryType::Insert => {
                let placeholders = vec!["?"; self.values.len()].join(", ");
                format!(
                    "INSERT INTO {} ({}) VALUES ({})",
                    self.table,
                    self.columns.join(", "),
                    placeholders
                )
            }
            QueryType::Update | QueryType::Delete | QueryType::CreateTable => String::new(),
        }
    }

    /// Renders a `Select` as a JSON document shaped like
    /// `{"collection": …, "filter": {…}, "limit": …, "skip": …}`.
    ///
    /// Other query types yield `{}`. Each `field = value` condition becomes one `filter`
    /// entry, and a later entry for the same field wins.
    ///
    /// Value handling:
    /// - Values wrapped in single quotes become strings, with SQL-style `''` unescaped.
    /// - Unquoted numbers, `true`, `false` and `null` keep their JSON type.
    /// - Anything else is kept as a string.
    pub fn build_mongo(&self) -> serde_json::Value {
        match self.query_type {
            QueryType::Select => {
                let mut filter = serde_json::Map::new();
                for condition in &self.conditions {
                    if let Some((field, raw)) = condition.split_once(" = ") {
                        filter.insert(
                            field.trim().to_string(),
                            Self::mongo_filter_value(raw.trim()),
                        );
                    }
                }
                serde_json::json!({
                    "collection": self.table,
                    "filter": serde_json::Value::Object(filter),
                    "limit": self.limit,
                    "skip": self.offset,
                })
            }
            QueryType::Insert | QueryType::Update | QueryType::Delete | QueryType::CreateTable => {
                serde_json::json!({})
            }
        }
    }

    /// Converts the right-hand side of a `field = value` condition into a JSON value.
    fn mongo_filter_value(raw: &str) -> serde_json::Value {
        if let Some(inner) = raw.strip_prefix('\'').and_then(|s| s.strip_suffix('\'')) {
            // SQL string literal: `''` is the escape for an embedded single quote.
            return serde_json::Value::String(inner.replace("''", "'"));
        }
        match serde_json::from_str::<serde_json::Value>(raw) {
            // Keep scalar JSON types so numeric/boolean fields match in MongoDB.
            Ok(
                value @ (serde_json::Value::Number(_)
                | serde_json::Value::Bool(_)
                | serde_json::Value::Null),
            ) => value,
            _ => serde_json::Value::String(raw.trim_matches('\'').to_string()),
        }
    }
}
/// Tabular query result: column names plus rows of JSON values.
#[derive(Debug, Clone)]
pub struct ResultSet {
    /// Column names. `to_array` and `get_column` index row values by this position.
    pub columns: Vec<String>,
    /// Row data. `add_row` does not check a row's length against `columns`.
    pub rows: Vec<Vec<serde_json::Value>>,
    /// Metadata attached to the result. `ResultSet::new` initialises it with `Metadata::new()`.
    pub metadata: Metadata,
}
impl ResultSet {
    /// Creates an empty result set with the given column names and fresh metadata.
    pub fn new(columns: Vec<String>) -> Self {
        Self {
            columns,
            rows: Vec::new(),
            metadata: Metadata::new(),
        }
    }

    /// Appends a row.
    ///
    /// The row is not validated here. Conversions report rows whose length doesn't match.
    pub fn add_row(&mut self, row: Vec<serde_json::Value>) {
        self.rows.push(row);
    }

    /// Number of rows.
    pub fn row_count(&self) -> usize {
        self.rows.len()
    }

    /// Number of columns (length of `columns`).
    pub fn column_count(&self) -> usize {
        self.columns.len()
    }

    /// Converts every value into a `row_count() × column_count()` matrix of `f64`.
    ///
    /// # Errors
    /// - `IoError::Other` if a row's length differs from `column_count()`. Without this
    ///   check, ragged rows whose total length happens to equal rows × columns would be
    ///   silently shifted into the wrong cells.
    /// - `IoError::ConversionError` if any value is not numeric.
    pub fn to_array(&self) -> Result<Array2<f64>> {
        let n_cols = self.column_count();
        let mut data = Vec::with_capacity(self.row_count() * n_cols);
        for (row_idx, row) in self.rows.iter().enumerate() {
            if row.len() != n_cols {
                return Err(IoError::Other(format!(
                    "Row {row_idx} has {} values but the result set has {n_cols} columns",
                    row.len()
                )));
            }
            for value in row {
                let num = value.as_f64().ok_or_else(|| {
                    IoError::ConversionError("Non-numeric value in result set".to_string())
                })?;
                data.push(num);
            }
        }
        Array2::from_shape_vec((self.row_count(), n_cols), data)
            .map_err(|e| IoError::Other(e.to_string()))
    }

    /// Extracts the column called `name` as a vector of `f64`.
    ///
    /// # Errors
    /// - `IoError::Other` if the column doesn't exist or a row is too short to contain it.
    /// - `IoError::ConversionError` if a value in the column is not numeric.
    pub fn get_column(&self, name: &str) -> Result<Array1<f64>> {
        let col_idx = self
            .columns
            .iter()
            .position(|c| c == name)
            .ok_or_else(|| IoError::Other(format!("Column '{name}' not found")))?;
        let mut data = Vec::with_capacity(self.rows.len());
        for (row_idx, row) in self.rows.iter().enumerate() {
            // `get` instead of indexing: a short row must be an error, not a panic.
            let value = row.get(col_idx).ok_or_else(|| {
                IoError::Other(format!("Row {row_idx} has no value for column '{name}'"))
            })?;
            let num = value.as_f64().ok_or_else(|| {
                IoError::ConversionError("Non-numeric value in column".to_string())
            })?;
            data.push(num);
        }
        Ok(Array1::from_vec(data))
    }
}
/// Backend-agnostic interface implemented by concrete database connections.
///
/// `Send + Sync` lets a boxed connection, as returned by `DatabaseConnector::connect`, be
/// shared across threads.
pub trait DatabaseConnection: Send + Sync {
    /// Runs the query described by `query` and returns its rows.
    fn query(&self, query: &QueryBuilder) -> Result<ResultSet>;
    /// Executes raw SQL with positional `params`.
    ///
    /// NOTE(review): binding semantics (placeholder syntax) are defined by each
    /// implementation. Confirm there.
    fn execute_sql(&self, sql: &str, params: &[serde_json::Value]) -> Result<ResultSet>;
    /// Inserts the rows of the 2-D `data` array into `table` under the given `columns`.
    ///
    /// Returns a count, presumably the number of rows inserted. TODO confirm in
    /// implementations.
    fn insert_array(&self, table: &str, data: ArrayView2<f64>, columns: &[&str]) -> Result<usize>;
    /// Creates `table` according to `schema`.
    fn create_table(&self, table: &str, schema: &TableSchema) -> Result<()>;
    /// Reports whether `table` exists.
    fn table_exists(&self, table: &str) -> Result<bool>;
    /// Reads the schema of an existing `table`.
    fn get_schema(&self, table: &str) -> Result<TableSchema>;
}
/// Description of a table: its columns, optional primary key and indexes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableSchema {
    /// Table name.
    pub name: String,
    /// Column definitions.
    pub columns: Vec<ColumnDef>,
    /// Primary-key column names, or `None` when no primary key is declared.
    pub primary_key: Option<Vec<String>>,
    /// Secondary indexes on the table.
    pub indexes: Vec<Index>,
}
/// Definition of a single table column.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnDef {
    /// Column name.
    pub name: String,
    /// Column type.
    pub data_type: DataType,
    /// Whether the column accepts NULL.
    pub nullable: bool,
    /// Default value, if any.
    pub default: Option<serde_json::Value>,
}
/// Portable column types used in a [`TableSchema`].
///
/// With `rename_all = "lowercase"`, serde writes variant names lowercased, e.g. `BigInt`
/// becomes `"bigint"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum DataType {
    Integer,
    BigInt,
    Float,
    Double,
    /// Fixed-point number; presumably `(precision, scale)`. Confirm against backends.
    Decimal(u8, u8),
    /// Variable-length string; the parameter is presumably the maximum length.
    Varchar(usize),
    Text,
    Boolean,
    Date,
    Timestamp,
    Json,
    Binary,
}
/// A named index over one or more columns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Index {
    /// Index name.
    pub name: String,
    /// Indexed column names.
    pub columns: Vec<String>,
    /// Whether the index enforces uniqueness.
    pub unique: bool,
}
/// Factory that opens a [`DatabaseConnection`] from a [`DatabaseConfig`].
pub struct DatabaseConnector;

impl DatabaseConnector {
    /// Opens a connection for `config.db_type`.
    ///
    /// SQLite, PostgreSQL and MySQL are available when the `sqlite`, `postgres` and `mysql`
    /// cargo features are enabled, respectively.
    ///
    /// # Errors
    /// - `IoError::UnsupportedFormat` when the backend's feature is disabled.
    /// - `IoError::UnsupportedFormat` for MongoDB, InfluxDB, Redis and Cassandra, which are
    ///   not implemented.
    /// - Any error from the backend's own constructor.
    pub fn connect(config: &DatabaseConfig) -> Result<Box<dyn DatabaseConnection>> {
        let kind = config.db_type;
        match kind {
            #[cfg(feature = "sqlite")]
            DatabaseType::SQLite => {
                let connection = sqlite::SQLiteConnection::new(config)?;
                Ok(Box::new(connection))
            }
            #[cfg(not(feature = "sqlite"))]
            DatabaseType::SQLite => {
                let message = "SQLite support not enabled. Enable 'sqlite' feature.";
                Err(IoError::UnsupportedFormat(message.to_string()))
            }
            #[cfg(feature = "postgres")]
            DatabaseType::PostgreSQL => {
                let connection = postgres::PostgreSQLConnection::new(config)?;
                Ok(Box::new(connection))
            }
            #[cfg(not(feature = "postgres"))]
            DatabaseType::PostgreSQL => {
                let message = "PostgreSQL support not enabled. Enable 'postgres' feature.";
                Err(IoError::UnsupportedFormat(message.to_string()))
            }
            #[cfg(feature = "mysql")]
            DatabaseType::MySQL => {
                let connection = mysql::MySQLConnection::new(config)?;
                Ok(Box::new(connection))
            }
            #[cfg(not(feature = "mysql"))]
            DatabaseType::MySQL => {
                let message = "MySQL support not enabled. Enable 'mysql' feature.";
                Err(IoError::UnsupportedFormat(message.to_string()))
            }
            DatabaseType::MongoDB
            | DatabaseType::InfluxDB
            | DatabaseType::Redis
            | DatabaseType::Cassandra => Err(IoError::UnsupportedFormat(format!(
                "Database type {:?} not yet implemented",
                kind
            ))),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_database_config() {
        let cfg = DatabaseConfig::new(DatabaseType::SQLite, "test.db");
        assert_eq!(cfg.db_type, DatabaseType::SQLite);
        assert_eq!(cfg.database, "test.db");
        assert_eq!(cfg.connection_string(), "sqlite://test.db");
    }

    #[test]
    fn test_query_builder() {
        let sql = QueryBuilder::select("users")
            .columns(vec!["id", "name", "email"])
            .where_clause("age > 21")
            .limit(10)
            .build_sql();
        let expected_fragments = [
            "SELECT id, name, email FROM users",
            "WHERE age > 21",
            "LIMIT 10",
        ];
        for fragment in expected_fragments {
            assert!(sql.contains(fragment));
        }
    }
}