db_testkit/
template.rs

1use std::{fmt::Display, sync::Arc};
2
3use parking_lot::Mutex;
4use tokio::sync::Semaphore;
5use uuid::Uuid;
6
7use crate::{
8    backend::{DatabaseBackend, DatabasePool},
9    error::{PoolError, Result},
10    pool::PoolConfig,
11};
12
13/// A unique name for a database
14#[derive(Debug, Clone)]
15pub struct DatabaseName(String);
16
17impl DatabaseName {
18    /// Create a new database name with a prefix
19    pub fn new(prefix: &str) -> Self {
20        Self(format!("{}_{}", prefix, Uuid::new_v4()))
21    }
22
23    /// Get the database name as a string
24    pub fn as_str(&self) -> &str {
25        &self.0
26    }
27}
28
29impl Display for DatabaseName {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        write!(f, "{}", self.0)
32    }
33}
34
35/// A template database that can be used to create immutable copies
36pub struct DatabaseTemplate<B: DatabaseBackend + Clone + Send + 'static>
37where
38    B::Pool: DatabasePool<Connection = B::Connection>,
39{
40    backend: B,
41    config: PoolConfig,
42    name: DatabaseName,
43    replicas: Arc<Mutex<Vec<DatabaseName>>>,
44    semaphore: Arc<Semaphore>,
45}
46
47impl<B: DatabaseBackend + Clone + Send + 'static> DatabaseTemplate<B> {
48    /// Create a new template database
49    pub async fn new(backend: B, config: PoolConfig, max_replicas: usize) -> Result<Self> {
50        let name = DatabaseName::new("template");
51        backend.create_database(&name).await?;
52
53        Ok(Self {
54            backend,
55            config,
56            name,
57            replicas: Arc::new(Mutex::new(Vec::new())),
58            semaphore: Arc::new(Semaphore::new(max_replicas)),
59        })
60    }
61
62    /// Initialize the template database with a setup function
63    pub async fn initialize_template<F, Fut>(&self, setup: F) -> Result<()>
64    where
65        F: FnOnce(B::Connection) -> Fut + Send + 'static,
66        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
67    {
68        let pool = self.backend.create_pool(&self.name, &self.config).await?;
69        let conn = pool.acquire().await?;
70        setup(conn).await?;
71        Ok(())
72    }
73
74    /// Get an immutable copy of the template database
75    pub async fn get_immutable_database(&self) -> Result<ImmutableDatabase<'_, B>> {
76        let _permit = self
77            .semaphore
78            .acquire()
79            .await
80            .map_err(|e| PoolError::PoolCreationFailed(e.to_string()))?;
81
82        let name = DatabaseName::new("test");
83        self.backend
84            .create_database_from_template(&name, &self.name)
85            .await?;
86
87        let pool = self.backend.create_pool(&name, &self.config).await?;
88        self.replicas.lock().push(name.clone());
89
90        Ok(ImmutableDatabase {
91            name,
92            pool,
93            backend: self.backend.clone(),
94            _permit,
95        })
96    }
97}
98
99impl<B: DatabaseBackend + Clone + Send + 'static> Drop for DatabaseTemplate<B> {
100    fn drop(&mut self) {
101        let replicas = self.replicas.lock().clone();
102        let backend = self.backend.clone();
103        let name = self.name.clone();
104
105        tokio::spawn(async move {
106            for replica in replicas {
107                if let Err(e) = backend.drop_database(&replica).await {
108                    tracing::error!("Failed to drop replica database: {}", e);
109                }
110            }
111            if let Err(e) = backend.drop_database(&name).await {
112                tracing::error!("Failed to drop template database: {}", e);
113            }
114        });
115    }
116}
117
118impl<B: DatabaseBackend + Clone + Send + 'static> Clone for DatabaseTemplate<B>
119where
120    B::Pool: DatabasePool<Connection = B::Connection>,
121{
122    fn clone(&self) -> Self {
123        Self {
124            backend: self.backend.clone(),
125            config: self.config.clone(),
126            name: self.name.clone(),
127            replicas: self.replicas.clone(),
128            semaphore: self.semaphore.clone(),
129        }
130    }
131}
132
133/// An immutable copy of a template database
134pub struct ImmutableDatabase<'a, B: DatabaseBackend + Clone + Send + 'static> {
135    name: DatabaseName,
136    pool: B::Pool,
137    backend: B,
138    _permit: tokio::sync::SemaphorePermit<'a>,
139}
140
141impl<'a, B: DatabaseBackend + Clone + Send + 'static> ImmutableDatabase<'a, B> {
142    /// Get the pool for this database
143    pub fn get_pool(&self) -> &B::Pool {
144        &self.pool
145    }
146
147    /// Get the name of this database
148    pub fn get_name(&self) -> &DatabaseName {
149        &self.name
150    }
151}
152
153impl<'a, B: DatabaseBackend + Clone + Send + 'static> Drop for ImmutableDatabase<'a, B> {
154    fn drop(&mut self) {
155        let backend = self.backend.clone();
156        let name = self.name.clone();
157
158        tokio::spawn(async move {
159            if let Err(e) = backend.drop_database(&name).await {
160                tracing::error!("Failed to drop database: {}", e);
161            }
162        });
163    }
164}