1use crate::UtilsError;
7use scirs2_core::ndarray::{Array1, Array2};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::fmt;
11use std::sync::{Arc, Mutex};
12use std::time::Duration;
13
/// Connection settings for a PostgreSQL-style database, plus client-side pool
/// and timeout options.
///
/// NOTE(review): `Debug` is derived, so `{:?}` output includes `password` in
/// clear text — avoid logging this struct with `{:?}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Server host name or address (`host=` in the connection string).
    pub host: String,
    /// Server TCP port (`port=`).
    pub port: u16,
    /// Database name (`dbname=`).
    pub database: String,
    /// Login user name (`user=`).
    pub username: String,
    /// Login password; `connection_string` omits it entirely when empty.
    pub password: String,
    /// Copied into `DatabasePool::max_size` when a pool is created.
    pub pool_size: usize,
    /// Connection timeout. NOTE(review): not consumed by code in this module.
    pub connection_timeout: Duration,
    /// Query timeout. NOTE(review): not consumed by code in this module.
    pub query_timeout: Duration,
    /// Rendered as the `sslmode=` parameter.
    pub ssl_mode: SslMode,
    /// Extra `key=value` pairs appended to the connection string.
    pub additional_params: HashMap<String, String>,
}
28
/// TLS policy, emitted as the `sslmode` parameter by
/// `DatabaseConfig::connection_string`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SslMode {
    /// Emitted as `sslmode=disable`.
    Disable,
    /// Emitted as `sslmode=prefer` (the `Default` for `DatabaseConfig`).
    Prefer,
    /// Emitted as `sslmode=require`.
    Require,
}
35
36impl Default for DatabaseConfig {
37 fn default() -> Self {
38 Self {
39 host: "localhost".to_string(),
40 port: 5432,
41 database: "postgres".to_string(),
42 username: "postgres".to_string(),
43 password: String::new(),
44 pool_size: 10,
45 connection_timeout: Duration::from_secs(30),
46 query_timeout: Duration::from_secs(60),
47 ssl_mode: SslMode::Prefer,
48 additional_params: HashMap::new(),
49 }
50 }
51}
52
53impl DatabaseConfig {
54 pub fn new(host: String, database: String, username: String, password: String) -> Self {
55 Self {
56 host,
57 database,
58 username,
59 password,
60 ..Default::default()
61 }
62 }
63
64 pub fn with_port(mut self, port: u16) -> Self {
65 self.port = port;
66 self
67 }
68
69 pub fn with_pool_size(mut self, pool_size: usize) -> Self {
70 self.pool_size = pool_size;
71 self
72 }
73
74 pub fn with_timeout(mut self, timeout: Duration) -> Self {
75 self.connection_timeout = timeout;
76 self.query_timeout = timeout;
77 self
78 }
79
80 pub fn connection_string(&self) -> String {
81 let ssl_param = match self.ssl_mode {
82 SslMode::Disable => "sslmode=disable",
83 SslMode::Prefer => "sslmode=prefer",
84 SslMode::Require => "sslmode=require",
85 };
86
87 let mut params = vec![
88 format!("host={}", self.host),
89 format!("port={}", self.port),
90 format!("dbname={}", self.database),
91 format!("user={}", self.username),
92 ssl_param.to_string(),
93 ];
94
95 if !self.password.is_empty() {
96 params.push(format!("password={}", self.password));
97 }
98
99 for (key, value) in &self.additional_params {
100 params.push(format!("{key}={value}"));
101 }
102
103 params.join(" ")
104 }
105}
106
/// Errors produced by the database helpers in this module.
///
/// `thiserror` derives `Display` from each variant's `#[error]` message.
#[derive(thiserror::Error, Debug, Clone)]
pub enum DatabaseError {
    /// A connection could not be used (e.g. the mock reports "Not connected").
    #[error("Connection failed: {0}")]
    ConnectionFailed(String),
    /// A statement could not be executed (not currently produced in this module).
    #[error("Query execution failed: {0}")]
    QueryFailed(String),
    /// An invalid transaction transition, such as rolling back after commit.
    #[error("Transaction failed: {0}")]
    TransactionFailed(String),
    /// A result set or value could not be converted to the requested form.
    #[error("Data conversion failed: {0}")]
    ConversionFailed(String),
    /// No connection was available (not currently produced in this module).
    #[error("Connection pool exhausted")]
    PoolExhausted,
    /// Rejected configuration (not currently produced in this module).
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),
}
123
124impl From<DatabaseError> for UtilsError {
125 fn from(err: DatabaseError) -> Self {
126 UtilsError::InvalidParameter(err.to_string())
127 }
128}
129
130#[derive(Debug, Clone)]
132pub struct Row {
133 columns: HashMap<String, Value>,
134 column_order: Vec<String>,
135}
136
137impl Row {
138 pub fn new() -> Self {
139 Self {
140 columns: HashMap::new(),
141 column_order: Vec::new(),
142 }
143 }
144
145 pub fn insert<T: Into<Value>>(&mut self, column: String, value: T) {
146 if !self.columns.contains_key(&column) {
147 self.column_order.push(column.clone());
148 }
149 self.columns.insert(column, value.into());
150 }
151
152 pub fn get(&self, column: &str) -> Option<&Value> {
153 self.columns.get(column)
154 }
155
156 pub fn get_string(&self, column: &str) -> Option<String> {
157 self.get(column)?.as_string()
158 }
159
160 pub fn get_f64(&self, column: &str) -> Option<f64> {
161 self.get(column)?.as_f64()
162 }
163
164 pub fn get_i64(&self, column: &str) -> Option<i64> {
165 self.get(column)?.as_i64()
166 }
167
168 pub fn columns(&self) -> &[String] {
169 &self.column_order
170 }
171}
172
173impl Default for Row {
174 fn default() -> Self {
175 Self::new()
176 }
177}
178
/// A single dynamically-typed cell value.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    /// Missing / SQL `NULL` value.
    Null,
    Bool(bool),
    Int(i64),
    Float(f64),
    String(String),
    Bytes(Vec<u8>),
}

impl Value {
    /// Renders scalar values as text; `None` for `Null` and `Bytes`.
    pub fn as_string(&self) -> Option<String> {
        match self {
            Value::String(s) => Some(s.clone()),
            Value::Int(i) => Some(i.to_string()),
            Value::Float(f) => Some(f.to_string()),
            Value::Bool(b) => Some(b.to_string()),
            _ => None,
        }
    }

    /// Numeric view of the value.
    ///
    /// `Float` is returned as-is and `Int` is widened; integers beyond ±2^53
    /// may round. A `String` is parsed. Returns `None` for `Null`, `Bool`,
    /// `Bytes` or unparsable text.
    pub fn as_f64(&self) -> Option<f64> {
        match self {
            Value::Float(f) => Some(*f),
            Value::Int(i) => Some(*i as f64),
            Value::String(s) => s.parse().ok(),
            _ => None,
        }
    }

    /// Integer view of the value.
    ///
    /// `Float` is truncated toward zero, but yields `None` when it is NaN,
    /// infinite or outside the `i64` range. A `String` is parsed. Returns
    /// `None` for `Null`, `Bool`, `Bytes` or unparsable text.
    pub fn as_i64(&self) -> Option<i64> {
        match self {
            Value::Int(i) => Some(*i),
            Value::Float(f) => Self::float_to_i64(*f),
            Value::String(s) => s.parse().ok(),
            _ => None,
        }
    }

    /// `true` only for [`Value::Null`].
    pub fn is_null(&self) -> bool {
        matches!(self, Value::Null)
    }

    /// Truncating float→int conversion that refuses to invent values.
    ///
    /// A bare `as` cast saturates out-of-range floats and maps NaN to 0.
    /// `i64::MIN as f64` is exactly -2^63 and `i64::MAX as f64` rounds up to
    /// 2^63, so the half-open range below admits exactly the floats whose
    /// truncation fits in an `i64`. NaN fails every comparison and is rejected.
    fn float_to_i64(f: f64) -> Option<i64> {
        const LOWER: f64 = i64::MIN as f64;
        const UPPER: f64 = i64::MAX as f64;
        if (LOWER..UPPER).contains(&f) {
            Some(f as i64)
        } else {
            None
        }
    }
}
223
224impl From<String> for Value {
225 fn from(s: String) -> Self {
226 Value::String(s)
227 }
228}
229
230impl From<&str> for Value {
231 fn from(s: &str) -> Self {
232 Value::String(s.to_string())
233 }
234}
235
236impl From<i64> for Value {
237 fn from(i: i64) -> Self {
238 Value::Int(i)
239 }
240}
241
242impl From<i32> for Value {
243 fn from(i: i32) -> Self {
244 Value::Int(i as i64)
245 }
246}
247
248impl From<f64> for Value {
249 fn from(f: f64) -> Self {
250 Value::Float(f)
251 }
252}
253
254impl From<f32> for Value {
255 fn from(f: f32) -> Self {
256 Value::Float(f as f64)
257 }
258}
259
260impl From<bool> for Value {
261 fn from(b: bool) -> Self {
262 Value::Bool(b)
263 }
264}
265
/// Minimal synchronous database connection interface.
///
/// The trait itself has no `Send`/`Sync` bound; `DatabasePool` stores
/// connections as `Box<dyn Connection + Send + Sync>`.
pub trait Connection {
    /// Runs `query` and reports the affected-row count and timing.
    /// NOTE(review): presumably meant for statements that return no rows.
    fn execute(&self, query: &Query) -> Result<QueryResult, DatabaseError>;
    /// Runs `query` and returns its rows as a [`ResultSet`].
    fn query(&self, query: &Query) -> Result<ResultSet, DatabaseError>;
    /// Starts a new [`Transaction`].
    fn begin_transaction(&self) -> Result<Transaction, DatabaseError>;
    /// Closes the connection. Takes `&self`, so implementations needing to
    /// record the closed state require interior mutability.
    fn close(&self) -> Result<(), DatabaseError>;
    /// Reports whether the connection is currently usable.
    fn is_connected(&self) -> bool;
}
274
275pub struct MockConnection {
277 connected: bool,
278 mock_data: HashMap<String, Vec<Row>>,
279}
280
281impl MockConnection {
282 pub fn new() -> Self {
283 Self {
284 connected: true,
285 mock_data: HashMap::new(),
286 }
287 }
288
289 pub fn add_mock_data(&mut self, table: String, rows: Vec<Row>) {
290 self.mock_data.insert(table, rows);
291 }
292}
293
294impl Default for MockConnection {
295 fn default() -> Self {
296 Self::new()
297 }
298}
299
300impl Connection for MockConnection {
301 fn execute(&self, _query: &Query) -> Result<QueryResult, DatabaseError> {
302 if !self.connected {
303 return Err(DatabaseError::ConnectionFailed("Not connected".to_string()));
304 }
305
306 Ok(QueryResult {
307 rows_affected: 1,
308 execution_time: Duration::from_millis(10),
309 })
310 }
311
312 fn query(&self, _query: &Query) -> Result<ResultSet, DatabaseError> {
313 if !self.connected {
314 return Err(DatabaseError::ConnectionFailed("Not connected".to_string()));
315 }
316
317 let mut result = ResultSet::new(vec!["id".to_string(), "value".to_string()]);
319 result.set_execution_time(Duration::from_millis(5));
320 Ok(result)
321 }
322
323 fn begin_transaction(&self) -> Result<Transaction, DatabaseError> {
324 if !self.connected {
325 return Err(DatabaseError::ConnectionFailed("Not connected".to_string()));
326 }
327 Ok(Transaction::new())
328 }
329
330 fn close(&self) -> Result<(), DatabaseError> {
331 Ok(())
332 }
333
334 fn is_connected(&self) -> bool {
335 self.connected
336 }
337}
338
339pub struct DatabasePool {
341 #[allow(dead_code)]
342 config: DatabaseConfig,
343 connections: Arc<Mutex<Vec<Box<dyn Connection + Send + Sync>>>>,
344 max_size: usize,
345}
346
347impl DatabasePool {
348 pub fn new(config: DatabaseConfig) -> Self {
349 let max_size = config.pool_size;
350 Self {
351 config,
352 connections: Arc::new(Mutex::new(Vec::new())),
353 max_size,
354 }
355 }
356
357 pub fn get_connection(&self) -> Result<Box<dyn Connection + Send + Sync>, DatabaseError> {
358 Ok(Box::new(MockConnection::new()))
361 }
362
363 pub fn return_connection(&self, _connection: Box<dyn Connection + Send + Sync>) {
364 }
366
367 pub fn size(&self) -> usize {
368 self.connections.lock().unwrap().len()
369 }
370
371 pub fn max_size(&self) -> usize {
372 self.max_size
373 }
374}
375
376pub struct QueryBuilder {
378 query_type: QueryType,
379 table: Option<String>,
380 columns: Vec<String>,
381 conditions: Vec<String>,
382 joins: Vec<String>,
383 order_by: Vec<String>,
384 group_by: Vec<String>,
385 having: Vec<String>,
386 limit: Option<usize>,
387 offset: Option<usize>,
388 parameters: Vec<Value>,
389}
390
391#[derive(Debug, Clone)]
392#[allow(dead_code)]
393enum QueryType {
394 Select,
395 Insert,
396 Update,
397 Delete,
398}
399
400impl QueryBuilder {
401 pub fn select() -> Self {
402 Self {
403 query_type: QueryType::Select,
404 table: None,
405 columns: Vec::new(),
406 conditions: Vec::new(),
407 joins: Vec::new(),
408 order_by: Vec::new(),
409 group_by: Vec::new(),
410 having: Vec::new(),
411 limit: None,
412 offset: None,
413 parameters: Vec::new(),
414 }
415 }
416
417 pub fn from(mut self, table: &str) -> Self {
418 self.table = Some(table.to_string());
419 self
420 }
421
422 pub fn columns(mut self, columns: &[&str]) -> Self {
423 self.columns = columns.iter().map(|s| s.to_string()).collect();
424 self
425 }
426
427 pub fn where_clause(mut self, condition: &str) -> Self {
428 self.conditions.push(condition.to_string());
429 self
430 }
431
432 pub fn join(mut self, join_clause: &str) -> Self {
433 self.joins.push(join_clause.to_string());
434 self
435 }
436
437 pub fn order_by(mut self, column: &str, ascending: bool) -> Self {
438 let direction = if ascending { "ASC" } else { "DESC" };
439 self.order_by.push(format!("{column} {direction}"));
440 self
441 }
442
443 pub fn group_by(mut self, columns: &[&str]) -> Self {
444 self.group_by = columns.iter().map(|s| s.to_string()).collect();
445 self
446 }
447
448 pub fn limit(mut self, limit: usize) -> Self {
449 self.limit = Some(limit);
450 self
451 }
452
453 pub fn offset(mut self, offset: usize) -> Self {
454 self.offset = Some(offset);
455 self
456 }
457
458 pub fn parameter<T: Into<Value>>(mut self, value: T) -> Self {
459 self.parameters.push(value.into());
460 self
461 }
462
463 pub fn build(self) -> Query {
464 let sql = self.build_sql();
465 Query::new(sql, self.parameters)
466 }
467
468 fn build_sql(&self) -> String {
469 match self.query_type {
470 QueryType::Select => self.build_select(),
471 _ => "".to_string(), }
473 }
474
475 fn build_select(&self) -> String {
476 let mut query = String::new();
477
478 query.push_str("SELECT ");
480 if self.columns.is_empty() {
481 query.push('*');
482 } else {
483 query.push_str(&self.columns.join(", "));
484 }
485
486 if let Some(table) = &self.table {
488 query.push_str(&format!(" FROM {table}"));
489 }
490
491 for join in &self.joins {
493 query.push_str(&format!(" {join}"));
494 }
495
496 if !self.conditions.is_empty() {
498 query.push_str(&format!(" WHERE {}", self.conditions.join(" AND ")));
499 }
500
501 if !self.group_by.is_empty() {
503 query.push_str(&format!(" GROUP BY {}", self.group_by.join(", ")));
504 }
505
506 if !self.having.is_empty() {
508 query.push_str(&format!(" HAVING {}", self.having.join(" AND ")));
509 }
510
511 if !self.order_by.is_empty() {
513 query.push_str(&format!(" ORDER BY {}", self.order_by.join(", ")));
514 }
515
516 if let Some(limit) = self.limit {
518 query.push_str(&format!(" LIMIT {limit}"));
519 }
520
521 if let Some(offset) = self.offset {
523 query.push_str(&format!(" OFFSET {offset}"));
524 }
525
526 query
527 }
528}
529
530#[derive(Debug, Clone)]
532pub struct Query {
533 sql: String,
534 parameters: Vec<Value>,
535}
536
537impl Query {
538 pub fn new(sql: String, parameters: Vec<Value>) -> Self {
539 Self { sql, parameters }
540 }
541
542 pub fn sql(&self) -> &str {
543 &self.sql
544 }
545
546 pub fn parameters(&self) -> &[Value] {
547 &self.parameters
548 }
549}
550
/// Outcome of a statement run through `Connection::execute`.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Number of rows the connection reports as affected.
    pub rows_affected: usize,
    /// Execution time reported by the connection.
    pub execution_time: Duration,
}
557
558#[derive(Debug, Clone)]
560pub struct ResultSet {
561 rows: Vec<Row>,
562 columns: Vec<String>,
563 execution_time: Duration,
564 #[allow(dead_code)]
565 rows_affected: Option<usize>,
566}
567
568impl ResultSet {
569 pub fn new(columns: Vec<String>) -> Self {
570 Self {
571 rows: Vec::new(),
572 columns,
573 execution_time: Duration::from_secs(0),
574 rows_affected: None,
575 }
576 }
577
578 pub fn add_row(&mut self, row: Row) {
579 self.rows.push(row);
580 }
581
582 pub fn rows(&self) -> &[Row] {
583 &self.rows
584 }
585
586 pub fn columns(&self) -> &[String] {
587 &self.columns
588 }
589
590 pub fn len(&self) -> usize {
591 self.rows.len()
592 }
593
594 pub fn is_empty(&self) -> bool {
595 self.rows.is_empty()
596 }
597
598 pub fn execution_time(&self) -> Duration {
599 self.execution_time
600 }
601
602 pub fn set_execution_time(&mut self, time: Duration) {
603 self.execution_time = time;
604 }
605
606 pub fn to_array2(&self) -> Result<Array2<f64>, DatabaseError> {
608 if self.rows.is_empty() {
609 return Err(DatabaseError::ConversionFailed(
610 "Cannot convert empty result set to array".to_string(),
611 ));
612 }
613
614 let n_rows = self.rows.len();
615 let n_cols = self.columns.len();
616 let mut data = Array2::zeros((n_rows, n_cols));
617
618 for (row_idx, row) in self.rows.iter().enumerate() {
619 for (col_idx, col_name) in self.columns.iter().enumerate() {
620 let value = row.get(col_name).ok_or_else(|| {
621 DatabaseError::ConversionFailed(format!("Column '{col_name}' not found in row"))
622 })?;
623
624 let numeric_value = value.as_f64().ok_or_else(|| {
625 DatabaseError::ConversionFailed(format!(
626 "Cannot convert value to f64: {value:?}"
627 ))
628 })?;
629
630 data[[row_idx, col_idx]] = numeric_value;
631 }
632 }
633
634 Ok(data)
635 }
636
637 pub fn column_to_array1(&self, column: &str) -> Result<Array1<f64>, DatabaseError> {
639 if !self.columns.contains(&column.to_string()) {
640 return Err(DatabaseError::ConversionFailed(format!(
641 "Column '{column}' not found"
642 )));
643 }
644
645 let mut data = Array1::zeros(self.rows.len());
646 for (idx, row) in self.rows.iter().enumerate() {
647 let value = row.get(column).ok_or_else(|| {
648 DatabaseError::ConversionFailed(format!("Column '{column}' not found in row"))
649 })?;
650
651 let numeric_value = value.as_f64().ok_or_else(|| {
652 DatabaseError::ConversionFailed(format!("Cannot convert value to f64: {value:?}"))
653 })?;
654
655 data[idx] = numeric_value;
656 }
657
658 Ok(data)
659 }
660
661 pub fn unique_values(&self, column: &str) -> Result<Vec<Value>, DatabaseError> {
663 if !self.columns.contains(&column.to_string()) {
664 return Err(DatabaseError::ConversionFailed(format!(
665 "Column '{column}' not found"
666 )));
667 }
668
669 let mut unique_values = Vec::new();
670 for row in &self.rows {
671 if let Some(value) = row.get(column) {
672 if !unique_values.contains(value) {
673 unique_values.push(value.clone());
674 }
675 }
676 }
677
678 Ok(unique_values)
679 }
680}
681
682pub struct Transaction {
684 committed: bool,
685 rolled_back: bool,
686}
687
688impl Transaction {
689 pub fn new() -> Self {
690 Self {
691 committed: false,
692 rolled_back: false,
693 }
694 }
695
696 pub fn commit(&mut self) -> Result<(), DatabaseError> {
697 if self.rolled_back {
698 return Err(DatabaseError::TransactionFailed(
699 "Transaction already rolled back".to_string(),
700 ));
701 }
702 self.committed = true;
703 Ok(())
704 }
705
706 pub fn rollback(&mut self) -> Result<(), DatabaseError> {
707 if self.committed {
708 return Err(DatabaseError::TransactionFailed(
709 "Transaction already committed".to_string(),
710 ));
711 }
712 self.rolled_back = true;
713 Ok(())
714 }
715
716 pub fn is_committed(&self) -> bool {
717 self.committed
718 }
719
720 pub fn is_rolled_back(&self) -> bool {
721 self.rolled_back
722 }
723}
724
725impl Default for Transaction {
726 fn default() -> Self {
727 Self::new()
728 }
729}
730
731impl fmt::Display for DatabaseConfig {
732 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
733 write!(
734 f,
735 "{}@{}:{}/{}",
736 self.username, self.host, self.port, self.database
737 )
738 }
739}
740
// Unit tests for this module. They only exercise in-memory types (the mock
// connection, result sets, builders), so no database server is needed.
// NOTE(review): `non_snake_case` looks unnecessary — every name below is snake_case.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    // Builder setters override the defaults and show up in the connection string.
    #[test]
    fn test_database_config() {
        let config = DatabaseConfig::new(
            "localhost".to_string(),
            "testdb".to_string(),
            "user".to_string(),
            "pass".to_string(),
        )
        .with_port(3306)
        .with_pool_size(5);

        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 3306);
        assert_eq!(config.pool_size, 5);

        let conn_str = config.connection_string();
        assert!(conn_str.contains("host=localhost"));
        assert!(conn_str.contains("port=3306"));
    }

    // `Value` accessors read back what the `From` conversions stored.
    #[test]
    fn test_value_conversions() {
        let int_val = Value::from(42i64);
        assert_eq!(int_val.as_i64(), Some(42));
        assert_eq!(int_val.as_f64(), Some(42.0));

        let float_val = Value::from(std::f64::consts::PI);
        assert_eq!(float_val.as_f64(), Some(std::f64::consts::PI));

        let string_val = Value::from("hello");
        assert_eq!(string_val.as_string(), Some("hello".to_string()));
    }

    // Typed getters return inserted cells; each distinct column is counted once.
    #[test]
    fn test_row_operations() {
        let mut row = Row::new();
        row.insert("id".to_string(), 1i64);
        row.insert("name".to_string(), "test");
        row.insert("score".to_string(), 95.5f64);

        assert_eq!(row.get_i64("id"), Some(1));
        assert_eq!(row.get_string("name"), Some("test".to_string()));
        assert_eq!(row.get_f64("score"), Some(95.5));
        assert_eq!(row.columns().len(), 3);
    }

    // Rows map to the first array axis and columns to the second.
    #[test]
    fn test_result_set_array_conversion() {
        let mut result_set = ResultSet::new(vec!["a".to_string(), "b".to_string()]);

        let mut row1 = Row::new();
        row1.insert("a".to_string(), 1.0f64);
        row1.insert("b".to_string(), 2.0f64);
        result_set.add_row(row1);

        let mut row2 = Row::new();
        row2.insert("a".to_string(), 3.0f64);
        row2.insert("b".to_string(), 4.0f64);
        result_set.add_row(row2);

        let array = result_set.to_array2().unwrap();
        assert_eq!(array.shape(), &[2, 2]);
        assert_eq!(array[[0, 0]], 1.0);
        assert_eq!(array[[1, 1]], 4.0);
    }

    // Each configured clause appears in the generated SELECT text.
    #[test]
    fn test_query_builder() {
        let query = QueryBuilder::select()
            .columns(&["id", "name", "score"])
            .from("users")
            .where_clause("score > 80")
            .order_by("score", false)
            .limit(10)
            .build();

        let sql = query.sql();
        assert!(sql.contains("SELECT id, name, score"));
        assert!(sql.contains("FROM users"));
        assert!(sql.contains("WHERE score > 80"));
        assert!(sql.contains("ORDER BY score DESC"));
        assert!(sql.contains("LIMIT 10"));
    }

    // A fresh mock connection is open, reports one affected row, and its
    // query result has two columns.
    #[test]
    fn test_mock_connection() {
        let connection = MockConnection::new();
        assert!(connection.is_connected());

        let query = Query::new("SELECT 1".to_string(), vec![]);
        let result = connection.execute(&query).unwrap();
        assert_eq!(result.rows_affected, 1);

        let result_set = connection.query(&query).unwrap();
        assert_eq!(result_set.columns().len(), 2);
    }

    // Committing is final: a later rollback is rejected.
    #[test]
    fn test_transaction() {
        let mut transaction = Transaction::new();
        assert!(!transaction.is_committed());
        assert!(!transaction.is_rolled_back());

        transaction.commit().unwrap();
        assert!(transaction.is_committed());

        assert!(transaction.rollback().is_err());
    }

    // The pool reports the default pool size and hands out an open connection.
    #[test]
    fn test_database_pool() {
        let config = DatabaseConfig::default();
        let pool = DatabasePool::new(config);

        assert_eq!(pool.max_size(), 10);

        let connection = pool.get_connection().unwrap();
        assert!(connection.is_connected());
    }

    // Duplicate values in a column are reported once.
    #[test]
    fn test_result_set_unique_values() {
        let mut result_set = ResultSet::new(vec!["category".to_string()]);

        let mut row1 = Row::new();
        row1.insert("category".to_string(), "A");
        result_set.add_row(row1);

        let mut row2 = Row::new();
        row2.insert("category".to_string(), "B");
        result_set.add_row(row2);

        let mut row3 = Row::new();
        row3.insert("category".to_string(), "A");
        result_set.add_row(row3);

        let unique_values = result_set.unique_values("category").unwrap();
        assert_eq!(unique_values.len(), 2);
        assert!(unique_values.contains(&Value::String("A".to_string())));
        assert!(unique_values.contains(&Value::String("B".to_string())));
    }
}