Skip to main content

spawn_db/
template.rs

1use crate::config;
2use crate::engine::EngineType;
3use crate::escape::EscapedIdentifier;
4use crate::store::pinner::latest::Latest;
5use crate::store::pinner::spawn::Spawn;
6use crate::store::pinner::Pinner;
7use crate::store::Store;
8use crate::variables::Variables;
9use minijinja::{Environment, Value};
10
11use crate::sql_formatter::SqlDialect;
12use uuid::Uuid;
13
14use anyhow::{Context, Result};
15use minijinja::context;
16
17/// Maps an EngineType to the appropriate SQL dialect for formatting.
18///
19/// Multiple engine types may share the same dialect. For example,
20/// both a psql CLI engine and a native PostgreSQL driver would use
21/// the Postgres dialect.
22fn engine_to_dialect(engine: &EngineType) -> SqlDialect {
23    match engine {
24        EngineType::PostgresPSQL => SqlDialect::Postgres,
25        // Future engines:
26        // EngineType::PostgresNative => SqlDialect::Postgres,
27        // EngineType::MySQL => SqlDialect::MySQL,
28        // EngineType::SqlServer => SqlDialect::SqlServer,
29    }
30}
31
32pub fn template_env(store: Store, engine: &EngineType) -> Result<Environment<'static>> {
33    let mut env = Environment::new();
34
35    let mj_store = MiniJinjaLoader { store };
36    env.set_loader(move |name: &str| mj_store.load(name));
37    env.add_function("gen_uuid_v4", gen_uuid_v4);
38    env.add_function("gen_uuid_v5", gen_uuid_v5);
39    env.add_filter("escape_identifier", escape_identifier_filter);
40
41    // Get the appropriate dialect for this engine
42    let dialect = engine_to_dialect(engine);
43
44    // Enable SQL auto-escaping for .sql files using the dialect-specific callback
45    env.set_auto_escape_callback(crate::sql_formatter::get_auto_escape_callback(dialect));
46
47    // Set custom formatter that handles SQL escaping based on the dialect
48    env.set_formatter(crate::sql_formatter::get_formatter(dialect));
49
50    Ok(env)
51}
52
53struct MiniJinjaLoader {
54    pub store: Store,
55}
56
57impl MiniJinjaLoader {
58    pub fn load(&self, name: &str) -> std::result::Result<Option<String>, minijinja::Error> {
59        let result = tokio::task::block_in_place(|| {
60            tokio::runtime::Handle::current()
61                .block_on(async { self.store.load_component(name).await })
62        });
63
64        result.map_err(|e| {
65            minijinja::Error::new(
66                minijinja::ErrorKind::InvalidOperation,
67                format!("Failed to load from object store: {}", e),
68            )
69        })
70    }
71}
72
73fn gen_uuid_v4() -> Result<String, minijinja::Error> {
74    Ok(Uuid::new_v4().to_string())
75}
76
77fn gen_uuid_v5(seed: &str) -> Result<String, minijinja::Error> {
78    Ok(Uuid::new_v5(&Uuid::NAMESPACE_DNS, seed.as_bytes()).to_string())
79}
80
81/// Filter to escape a value as a PostgreSQL identifier (e.g., database name, table name).
82///
83/// This wraps the value in double quotes and escapes any embedded double quotes,
84/// making it safe to use in SQL statements where an identifier is expected.
85///
86/// Usage in templates: `{{ dbname|escape_identifier }}`
87fn escape_identifier_filter(value: &Value) -> Result<Value, minijinja::Error> {
88    let s = value.to_string();
89    let escaped = EscapedIdentifier::new(&s);
90    // Return as a safe string so it won't be further escaped by the SQL formatter
91    Ok(Value::from_safe_string(escaped.to_string()))
92}
93
94pub struct Generation {
95    pub content: String,
96}
97
98/// Holds all the data needed to render a template to a writer.
99/// This struct is Send and can be moved into a WriterFn closure.
100pub struct StreamingGeneration {
101    store: Store,
102    template_contents: String,
103    environment: String,
104    variables: Variables,
105    engine: EngineType,
106}
107
108impl StreamingGeneration {
109    /// Render the template to the provided writer.
110    /// This creates the minijinja environment and renders in one step.
111    pub fn render_to_writer<W: std::io::Write + ?Sized>(self, writer: &mut W) -> Result<()> {
112        let mut env = template_env(self.store, &self.engine)?;
113        env.add_template("migration.sql", &self.template_contents)?;
114        let tmpl = env.get_template("migration.sql")?;
115        tmpl.render_to_write(
116            context!(env => self.environment, variables => self.variables),
117            writer,
118        )?;
119        Ok(())
120    }
121
122    /// Convert this streaming generation into a WriterFn that can be passed to migration_apply.
123    pub fn into_writer_fn(self) -> crate::engine::WriterFn {
124        Box::new(move |writer: &mut dyn std::io::Write| {
125            self.render_to_writer(writer)
126                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
127        })
128    }
129}
130
131/// Generate a streaming migration that can be rendered directly to a writer.
132/// This avoids materializing the entire SQL in memory.
133pub async fn generate_streaming(
134    cfg: &config::Config,
135    lock_file: Option<String>,
136    name: &str,
137    variables: Option<Variables>,
138) -> Result<StreamingGeneration> {
139    let pinner: Box<dyn Pinner> = if let Some(lock_file) = lock_file {
140        let lock = cfg
141            .load_lock_file(&lock_file)
142            .await
143            .context("could not load pinned files lock file")?;
144        let pinner = Spawn::new_with_root_hash(
145            cfg.pather().pinned_folder(),
146            cfg.pather().components_folder(),
147            &lock.pin,
148            &cfg.operator(),
149        )
150        .await
151        .context("could not get new root with hash")?;
152        Box::new(pinner)
153    } else {
154        let pinner = Latest::new(cfg.pather().spawn_folder_path())?;
155        Box::new(pinner)
156    };
157
158    let store = Store::new(pinner, cfg.operator().clone(), cfg.pather())
159        .context("could not create new store for generate")?;
160    let db_config = cfg
161        .db_config()
162        .context("could not get db config for generate")?;
163
164    generate_streaming_with_store(
165        name,
166        variables,
167        &db_config.environment,
168        &db_config.engine,
169        store,
170    )
171    .await
172}
173
174/// Generate a streaming migration with an existing store.
175pub async fn generate_streaming_with_store(
176    name: &str,
177    variables: Option<Variables>,
178    environment: &str,
179    engine: &EngineType,
180    store: Store,
181) -> Result<StreamingGeneration> {
182    // Read contents from our object store first:
183    let contents = store
184        .load_migration(name)
185        .await
186        .context("generate_streaming_with_store could not read migration")?;
187
188    Ok(StreamingGeneration {
189        store,
190        template_contents: contents,
191        environment: environment.to_string(),
192        variables: variables.unwrap_or_default(),
193        engine: engine.clone(),
194    })
195}
196
197#[cfg(test)]
198mod tests {
199    use super::*;
200    use crate::sql_formatter::{get_auto_escape_callback, get_formatter};
201    use minijinja::{context, Environment, Value};
202
203    /// Helper to test SQL formatting of a value by rendering it in a .sql template
204    fn render_sql_value(value: Value) -> String {
205        let mut env = Environment::new();
206        env.set_auto_escape_callback(get_auto_escape_callback(SqlDialect::Postgres));
207        env.set_formatter(get_formatter(SqlDialect::Postgres));
208        env.add_template("test.sql", "{{ value }}").unwrap();
209        let tmpl = env.get_template("test.sql").unwrap();
210        tmpl.render(context!(value => value)).unwrap()
211    }
212
213    #[test]
214    fn test_engine_to_dialect_postgres_psql() {
215        let dialect = engine_to_dialect(&EngineType::PostgresPSQL);
216        assert_eq!(dialect, SqlDialect::Postgres);
217    }
218
219    // Basic escaping tests - verify the integration with spawn-sql-format works
220    // More comprehensive tests are in the spawn-sql-format crate itself
221
222    #[test]
223    fn test_sql_escape_string() {
224        let result = render_sql_value(Value::from("hello"));
225        assert_eq!(result, "'hello'");
226    }
227
228    #[test]
229    fn test_sql_escape_string_injection_attempt() {
230        let result = render_sql_value(Value::from("'; DROP TABLE users; --"));
231        assert_eq!(result, "'''; DROP TABLE users; --'");
232    }
233
234    #[test]
235    fn test_sql_escape_integer() {
236        let result = render_sql_value(Value::from(42));
237        assert_eq!(result, "42");
238    }
239
240    #[test]
241    fn test_sql_escape_bool() {
242        let result = render_sql_value(Value::from(true));
243        assert_eq!(result, "TRUE");
244    }
245
246    #[test]
247    fn test_sql_escape_none() {
248        let result = render_sql_value(Value::from(()));
249        assert_eq!(result, "NULL");
250    }
251
252    #[test]
253    fn test_sql_escape_seq() {
254        let result = render_sql_value(Value::from(vec![1, 2, 3]));
255        assert_eq!(result, "ARRAY[1, 2, 3]");
256    }
257
258    #[test]
259    fn test_sql_escape_bytes() {
260        let bytes = Value::from_bytes(vec![0xDE, 0xAD, 0xBE, 0xEF]);
261        let result = render_sql_value(bytes);
262        assert_eq!(result, "'\\xdeadbeef'::bytea");
263    }
264
265    #[test]
266    fn test_sql_escape_for_non_sql_templates() {
267        let mut env = Environment::new();
268        env.set_auto_escape_callback(get_auto_escape_callback(SqlDialect::Postgres));
269        env.set_formatter(get_formatter(SqlDialect::Postgres));
270        // Use .txt extension - should still trigger SQL escaping
271        env.add_template("test.txt", "{{ value }}").unwrap();
272        let tmpl = env.get_template("test.txt").unwrap();
273        let result = tmpl.render(context!(value => "hello")).unwrap();
274        // SQL escaping applies to all files
275        assert_eq!(result, "'hello'");
276    }
277
278    #[test]
279    fn test_sql_safe_filter_bypasses_escaping() {
280        let mut env = Environment::new();
281        env.set_auto_escape_callback(get_auto_escape_callback(SqlDialect::Postgres));
282        env.set_formatter(get_formatter(SqlDialect::Postgres));
283        // Using |safe filter should bypass escaping
284        env.add_template("test.sql", "{{ value|safe }}").unwrap();
285        let tmpl = env.get_template("test.sql").unwrap();
286        let result = tmpl.render(context!(value => "raw SQL here")).unwrap();
287        // Should be output as-is without quotes
288        assert_eq!(result, "raw SQL here");
289    }
290
291    #[test]
292    fn test_sql_escape_only_on_output_not_in_loops() {
293        let mut env = Environment::new();
294        env.set_auto_escape_callback(get_auto_escape_callback(SqlDialect::Postgres));
295        env.set_formatter(get_formatter(SqlDialect::Postgres));
296
297        let template =
298            r#"{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}"#;
299        env.add_template("test.sql", template).unwrap();
300        let tmpl = env.get_template("test.sql").unwrap();
301
302        let items = vec!["alice", "bob", "charlie"];
303        let result = tmpl.render(context!(items => items)).unwrap();
304        assert_eq!(result, "'alice', 'bob', 'charlie'");
305    }
306}