1use std::{fs, path::PathBuf};
2
3use async_trait::async_trait;
4
5use crate::{
6 backend::Connection,
7 error::{PoolError, Result},
8};
9
/// Describes where SQL scripts come from.
///
/// The variants only carry the location (or the text itself); nothing is read
/// from disk until the scripts are actually loaded.
#[derive(Clone, Debug)]
pub enum SqlSource {
    /// Path of a directory whose files hold the scripts.
    Directory(PathBuf),
    /// Path of a single script file.
    File(PathBuf),
    /// Script texts compiled into the binary, in the order given by the slice.
    Embedded(&'static [&'static str]),
}
16
17impl SqlSource {
18 fn read_sql_files(&self) -> Result<Vec<String>> {
19 match self {
20 SqlSource::Directory(path) => {
21 let mut scripts = Vec::new();
22 for entry in fs::read_dir(path).map_err(|e| {
23 PoolError::MigrationError(format!("Failed to read directory: {}", e))
24 })? {
25 let entry = entry.map_err(|e| {
26 PoolError::MigrationError(format!("Failed to read directory entry: {}", e))
27 })?;
28 let path = entry.path();
29 if path.is_file() {
30 let sql = fs::read_to_string(&path).map_err(|e| {
31 PoolError::MigrationError(format!("Failed to read SQL file: {}", e))
32 })?;
33 scripts.push(sql);
34 }
35 }
36 Ok(scripts)
37 }
38 SqlSource::File(path) => {
39 let sql = fs::read_to_string(path).map_err(|e| {
40 PoolError::MigrationError(format!("Failed to read SQL file: {}", e))
41 })?;
42 Ok(vec![sql])
43 }
44 SqlSource::Embedded(scripts) => Ok(scripts.iter().map(|s| s.to_string()).collect()),
45 }
46 }
47}
48
/// Adds the ability to run the SQL scripts described by a [`SqlSource`].
///
/// This file provides a blanket implementation for every type that is
/// `Connection + Send`, so the method is available on those types once the
/// trait is in scope.
#[async_trait]
pub trait RunSql {
    /// Runs the scripts described by `source` against `self`.
    ///
    /// Returns the crate's `Result`; `Ok(())` signals that no step reported
    /// an error.
    async fn run_sql_scripts(&mut self, source: &SqlSource) -> Result<()>;
}
53
54#[async_trait]
55impl<T> RunSql for T
56where
57 T: Connection + Send,
58{
59 async fn run_sql_scripts(&mut self, source: &SqlSource) -> Result<()> {
60 let scripts = source.read_sql_files()?;
61 for script in scripts {
62 tracing::info!("Running SQL script");
63 self.execute(&script).await?;
64 }
65 Ok(())
66 }
67}
68
// Integration test: compiled only for test builds with the `postgres` feature
// enabled, and it needs a reachable PostgreSQL server (URL from `get_postgres_url`).
#[cfg(test)]
#[cfg(feature = "postgres")]
mod tests {
    use super::*;
    use crate::{
        backend::DatabasePool, backends::PostgresBackend, env::get_postgres_url, pool::PoolConfig,
        template::DatabaseTemplate,
    };

    /// Applies a schema script from a file while initializing a template, then
    /// checks that a database obtained from the template accepts an insert
    /// into the table the script created.
    #[tokio::test]
    async fn test_sql_scripts() {
        let backend = PostgresBackend::new(&get_postgres_url().unwrap())
            .await
            .unwrap();
        // NOTE(review): the meaning of the trailing `5` is not visible here
        // (presumably a size/limit for the template) — confirm against
        // `DatabaseTemplate::new`.
        let template = DatabaseTemplate::new(backend, PoolConfig::default(), 5)
            .await
            .unwrap();

        // Write the schema script into a temporary directory. `temp_dir` stays
        // bound until the end of the test, so the file exists while it is read.
        let temp_dir = tempfile::tempdir().unwrap();
        let setup_path = temp_dir.path().join("setup.sql");
        fs::write(
            &setup_path,
            r#"
            CREATE TABLE users (
                id SERIAL PRIMARY KEY,
                name TEXT NOT NULL,
                email TEXT
            );
            "#,
        )
        .unwrap();

        // Run the script through `RunSql::run_sql_scripts` using the
        // single-file source variant.
        template
            .initialize_template(|mut conn| async move {
                conn.run_sql_scripts(&SqlSource::File(setup_path)).await?;
                Ok(())
            })
            .await
            .unwrap();

        // Obtain a database from the template and a connection from its pool.
        // NOTE(review): what "immutable" implies for this database is not
        // visible in this file — confirm against `DatabaseTemplate`.
        let db = template.get_immutable_database().await.unwrap();
        let mut conn = db.get_pool().acquire().await.unwrap();

        // The insert only succeeds if the `users` table from the setup script
        // exists; the `unwrap` is the test's sole assertion.
        conn.execute(
            r#"
            INSERT INTO users (name, email)
            VALUES ('test', 'test@example.com');
            "#,
        )
        .await
        .unwrap();
    }
}