db_testkit/
migrations.rs

1use std::{fs, path::PathBuf};
2
3use async_trait::async_trait;
4
5use crate::{
6    backend::Connection,
7    error::{PoolError, Result},
8};
9
10#[derive(Debug, Clone)]
11pub enum SqlSource {
12    Directory(PathBuf),
13    File(PathBuf),
14    Embedded(&'static [&'static str]),
15}
16
17impl SqlSource {
18    fn read_sql_files(&self) -> Result<Vec<String>> {
19        match self {
20            SqlSource::Directory(path) => {
21                let mut scripts = Vec::new();
22                for entry in fs::read_dir(path).map_err(|e| {
23                    PoolError::MigrationError(format!("Failed to read directory: {}", e))
24                })? {
25                    let entry = entry.map_err(|e| {
26                        PoolError::MigrationError(format!("Failed to read directory entry: {}", e))
27                    })?;
28                    let path = entry.path();
29                    if path.is_file() {
30                        let sql = fs::read_to_string(&path).map_err(|e| {
31                            PoolError::MigrationError(format!("Failed to read SQL file: {}", e))
32                        })?;
33                        scripts.push(sql);
34                    }
35                }
36                Ok(scripts)
37            }
38            SqlSource::File(path) => {
39                let sql = fs::read_to_string(path).map_err(|e| {
40                    PoolError::MigrationError(format!("Failed to read SQL file: {}", e))
41                })?;
42                Ok(vec![sql])
43            }
44            SqlSource::Embedded(scripts) => Ok(scripts.iter().map(|s| s.to_string()).collect()),
45        }
46    }
47}
48
49#[async_trait]
50pub trait RunSql {
51    async fn run_sql_scripts(&mut self, source: &SqlSource) -> Result<()>;
52}
53
54#[async_trait]
55impl<T> RunSql for T
56where
57    T: Connection + Send,
58{
59    async fn run_sql_scripts(&mut self, source: &SqlSource) -> Result<()> {
60        let scripts = source.read_sql_files()?;
61        for script in scripts {
62            tracing::info!("Running SQL script");
63            self.execute(&script).await?;
64        }
65        Ok(())
66    }
67}
68
#[cfg(test)]
#[cfg(feature = "postgres")]
mod tests {
    use super::*;
    use crate::{
        backend::DatabasePool, backends::PostgresBackend, env::get_postgres_url, pool::PoolConfig,
        template::DatabaseTemplate,
    };

    /// Schema written to a temporary `setup.sql` and applied to the template database.
    const SETUP_SQL: &str = r#"
            CREATE TABLE users (
                id SERIAL PRIMARY KEY,
                name TEXT NOT NULL,
                email TEXT
            );
            "#;

    /// Insert that only succeeds if `users` exists with `name` and `email` columns.
    const INSERT_SQL: &str = r#"
            INSERT INTO users (name, email)
            VALUES ('test', 'test@example.com');
            "#;

    /// Running a file-based `SqlSource` while initializing the template makes the
    /// schema visible in a database obtained from that template.
    #[tokio::test]
    async fn test_sql_scripts() {
        let url = get_postgres_url().unwrap();
        let backend = PostgresBackend::new(&url).await.unwrap();
        let template = DatabaseTemplate::new(backend, PoolConfig::default(), 5)
            .await
            .unwrap();

        // Stage the schema script in a throwaway directory.
        let scratch = tempfile::tempdir().unwrap();
        let script_path = scratch.path().join("setup.sql");
        fs::write(&script_path, SETUP_SQL).unwrap();

        // Apply the script to the template database.
        template
            .initialize_template(|mut conn| async move {
                conn.run_sql_scripts(&SqlSource::File(script_path)).await?;
                Ok(())
            })
            .await
            .unwrap();

        // A database handed out by the template should already contain `users`.
        let db = template.get_immutable_database().await.unwrap();
        let mut conn = db.get_pool().acquire().await.unwrap();
        conn.execute(INSERT_SQL).await.unwrap();
    }
}
125}