1use crate::config;
2use crate::engine::EngineType;
3use crate::escape::{EscapedIdentifier, EscapedLiteral};
4use crate::store::pinner::latest::Latest;
5use crate::store::pinner::spawn::Spawn;
6use crate::store::pinner::Pinner;
7use crate::store::Store;
8use crate::variables::Variables;
9use minijinja::{Environment, Value};
10
11use crate::sql_formatter::SqlDialect;
12use base64::{engine::general_purpose::STANDARD, Engine as _};
13use uuid::Uuid;
14
15use anyhow::{Context, Result};
16use minijinja::context;
17use std::sync::Arc;
18
/// Map a migration engine onto the SQL dialect used for escaping/formatting.
fn engine_to_dialect(engine: &EngineType) -> SqlDialect {
    match engine {
        // Exhaustive on purpose: adding a new engine variant forces an
        // explicit dialect choice here at compile time.
        EngineType::PostgresPSQL => SqlDialect::Postgres,
    }
}
33
34pub fn template_env(store: Store, engine: &EngineType) -> Result<Environment<'static>> {
35 let mut env = Environment::new();
36
37 let store = Arc::new(store);
38
39 let mj_store = MiniJinjaLoader {
40 store: Arc::clone(&store),
41 };
42 env.set_loader(move |name: &str| mj_store.load(name));
43 env.add_function("gen_uuid_v4", gen_uuid_v4);
44 env.add_function("gen_uuid_v5", gen_uuid_v5);
45 env.add_function("gen_uuid_v7", gen_uuid_v7);
46 env.add_filter("escape_identifier", escape_identifier_filter);
47 env.add_filter("escape_literal", escape_literal_filter);
48
49 let read_file_store = Arc::clone(&store);
50 env.add_filter(
51 "read_file",
52 move |path: &str| -> Result<Value, minijinja::Error> {
53 read_file_filter(path, &read_file_store)
54 },
55 );
56 env.add_filter("base64_encode", base64_encode_filter);
57 env.add_filter("to_string_lossy", to_string_lossy_filter);
58 env.add_filter("parse_json", parse_json_filter);
59 env.add_filter("parse_toml", parse_toml_filter);
60 env.add_filter("parse_yaml", parse_yaml_filter);
61
62 let read_json_store = Arc::clone(&store);
63 env.add_filter(
64 "read_json",
65 move |path: &str| -> Result<Value, minijinja::Error> {
66 let bytes = read_file_bytes(path, &read_json_store)?;
67 let s = string_from_bytes(&bytes)?;
68 parse_json_filter(&s)
69 },
70 );
71 let read_toml_store = Arc::clone(&store);
72 env.add_filter(
73 "read_toml",
74 move |path: &str| -> Result<Value, minijinja::Error> {
75 let bytes = read_file_bytes(path, &read_toml_store)?;
76 let s = string_from_bytes(&bytes)?;
77 parse_toml_filter(&s)
78 },
79 );
80 let read_yaml_store = Arc::clone(&store);
81 env.add_filter(
82 "read_yaml",
83 move |path: &str| -> Result<Value, minijinja::Error> {
84 let bytes = read_file_bytes(path, &read_yaml_store)?;
85 let s = string_from_bytes(&bytes)?;
86 parse_yaml_filter(&s)
87 },
88 );
89
90 let dialect = engine_to_dialect(engine);
92
93 env.set_auto_escape_callback(crate::sql_formatter::get_auto_escape_callback(dialect));
95
96 env.set_formatter(crate::sql_formatter::get_formatter(dialect));
98
99 Ok(env)
100}
101
/// Bridges MiniJinja's synchronous template-loader callback to the async
/// `Store`.
struct MiniJinjaLoader {
    // Shared with the filters registered in `template_env`.
    pub store: Arc<Store>,
}

impl MiniJinjaLoader {
    /// Load the named component from the store, or `Ok(None)` when the store
    /// reports no such component.
    ///
    /// MiniJinja's loader API is synchronous, so the async store call is
    /// bridged via `block_in_place` + `block_on`; this requires running on a
    /// multi-threaded Tokio runtime (panics on the current-thread flavor).
    pub fn load(&self, name: &str) -> std::result::Result<Option<String>, minijinja::Error> {
        let result = tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current()
                .block_on(async { self.store.load_component(name).await })
        });

        // Surface store failures as template errors so render reports them.
        result.map_err(|e| {
            minijinja::Error::new(
                minijinja::ErrorKind::InvalidOperation,
                format!("Failed to load from object store: {}", e),
            )
        })
    }
}
121
122fn gen_uuid_v4() -> Result<String, minijinja::Error> {
123 Ok(Uuid::new_v4().to_string())
124}
125
126fn gen_uuid_v5(seed: &str) -> Result<String, minijinja::Error> {
127 Ok(Uuid::new_v5(&Uuid::NAMESPACE_DNS, seed.as_bytes()).to_string())
128}
129
130fn gen_uuid_v7() -> Result<String, minijinja::Error> {
131 Ok(Uuid::now_v7().to_string())
132}
133
134fn escape_identifier_filter(value: &Value) -> Result<Value, minijinja::Error> {
141 let s = value.to_string();
142 let escaped = EscapedIdentifier::new(&s);
143 Ok(Value::from_safe_string(escaped.to_string()))
145}
146
147fn escape_literal_filter(value: &Value) -> Result<Value, minijinja::Error> {
155 let s = value.to_string();
156 let escaped = EscapedLiteral::new(&s);
157 Ok(Value::from_safe_string(escaped.to_string()))
159}
160
161fn read_file_bytes(path: &str, store: &Arc<Store>) -> Result<Vec<u8>, minijinja::Error> {
163 let bytes = tokio::task::block_in_place(|| {
164 tokio::runtime::Handle::current().block_on(async { store.read_file_bytes(path).await })
165 });
166
167 bytes.map_err(|e| {
168 minijinja::Error::new(
169 minijinja::ErrorKind::InvalidOperation,
170 format!("Failed to read file '{}': {}", path, e),
171 )
172 })
173}
174
175fn string_from_bytes(bytes: &[u8]) -> Result<String, minijinja::Error> {
177 String::from_utf8(bytes.to_vec()).map_err(|e| {
178 minijinja::Error::new(
179 minijinja::ErrorKind::InvalidOperation,
180 format!("File is not valid UTF-8: {}", e),
181 )
182 })
183}
184
185fn read_file_filter(path: &str, store: &Arc<Store>) -> Result<Value, minijinja::Error> {
191 Ok(Value::from_bytes(read_file_bytes(path, store)?))
192}
193
194fn base64_encode_filter(value: &Value) -> Result<Value, minijinja::Error> {
200 use minijinja::value::ValueKind;
201 match value.kind() {
202 ValueKind::Bytes => {
203 let bytes = value.as_bytes().unwrap();
204 Ok(Value::from(STANDARD.encode(bytes)))
205 }
206 ValueKind::String => Ok(Value::from(STANDARD.encode(value.as_str().unwrap()))),
207 _ => Err(minijinja::Error::new(
208 minijinja::ErrorKind::InvalidOperation,
209 "base64_encode filter expects bytes or string input",
210 )),
211 }
212}
213
214fn to_string_lossy_filter(value: &Value) -> Result<Value, minijinja::Error> {
219 use minijinja::value::ValueKind;
220 match value.kind() {
221 ValueKind::Bytes => {
222 let bytes = value.as_bytes().unwrap();
223 Ok(Value::from(String::from_utf8_lossy(bytes).into_owned()))
224 }
225 ValueKind::String => Ok(value.clone()),
226 _ => Err(minijinja::Error::new(
227 minijinja::ErrorKind::InvalidOperation,
228 "to_string_lossy filter expects bytes or string input",
229 )),
230 }
231}
232
233fn parse_json_filter(value: &str) -> Result<Value, minijinja::Error> {
237 let vars = Variables::from_str("json", value).map_err(|e| {
238 minijinja::Error::new(
239 minijinja::ErrorKind::InvalidOperation,
240 format!("parse_json: {}", e),
241 )
242 })?;
243 Ok(Value::from_serialize(&vars))
244}
245
246fn parse_toml_filter(value: &str) -> Result<Value, minijinja::Error> {
250 let vars = Variables::from_str("toml", value).map_err(|e| {
251 minijinja::Error::new(
252 minijinja::ErrorKind::InvalidOperation,
253 format!("parse_toml: {}", e),
254 )
255 })?;
256 Ok(Value::from_serialize(&vars))
257}
258
259fn parse_yaml_filter(value: &str) -> Result<Value, minijinja::Error> {
263 let vars = Variables::from_str("yaml", value).map_err(|e| {
264 minijinja::Error::new(
265 minijinja::ErrorKind::InvalidOperation,
266 format!("parse_yaml: {}", e),
267 )
268 })?;
269 Ok(Value::from_serialize(&vars))
270}
271
/// Result of a completed (non-streaming) generation.
pub struct Generation {
    // The rendered output text.
    pub content: String,
}
275
/// Everything needed to render one migration template lazily into a writer
/// (see `render_to_writer`).
pub struct StreamingGeneration {
    // Backing store used by the loader and the `read_*` filters.
    store: Store,
    // Raw template source of the migration to render.
    template_contents: String,
    // Exposed to templates as `env`.
    environment: String,
    // Exposed to templates as `variables`.
    variables: Variables,
    // Selects the SQL dialect for escaping/formatting.
    engine: EngineType,
}
285
286impl StreamingGeneration {
287 pub fn render_to_writer<W: std::io::Write + ?Sized>(self, writer: &mut W) -> Result<()> {
290 let mut env = template_env(self.store, &self.engine)?;
291 env.add_template("migration.sql", &self.template_contents)?;
292 let tmpl = env.get_template("migration.sql")?;
293 tmpl.render_to_write(
294 context!(env => self.environment, variables => self.variables),
295 writer,
296 )?;
297 Ok(())
298 }
299
300 pub fn into_writer_fn(self) -> crate::engine::WriterFn {
302 Box::new(move |writer: &mut dyn std::io::Write| {
303 self.render_to_writer(writer)
304 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
305 })
306 }
307}
308
309pub async fn generate_streaming(
312 cfg: &config::Config,
313 lock_file: Option<String>,
314 name: &str,
315 variables: Option<Variables>,
316) -> Result<StreamingGeneration> {
317 let pinner: Box<dyn Pinner> = if let Some(lock_file) = lock_file {
318 let lock = cfg
319 .load_lock_file(&lock_file)
320 .await
321 .context("could not load pinned files lock file")?;
322 let pinner = Spawn::new_with_root_hash(
323 cfg.pather().pinned_folder(),
324 cfg.pather().components_folder(),
325 &lock.pin,
326 &cfg.operator(),
327 )
328 .await
329 .context("could not get new root with hash")?;
330 Box::new(pinner)
331 } else {
332 let pinner = Latest::new(cfg.pather().spawn_folder_path())?;
333 Box::new(pinner)
334 };
335
336 let store = Store::new(pinner, cfg.operator().clone(), cfg.pather())
337 .context("could not create new store for generate")?;
338 let target_config = cfg
339 .target_config()
340 .context("could not get target config for generate")?;
341
342 generate_streaming_with_store(
343 name,
344 variables,
345 &target_config.environment,
346 &target_config.engine,
347 store,
348 )
349 .await
350}
351
352pub async fn generate_streaming_with_store(
354 name: &str,
355 variables: Option<Variables>,
356 environment: &str,
357 engine: &EngineType,
358 store: Store,
359) -> Result<StreamingGeneration> {
360 let contents = store
362 .load_migration(name)
363 .await
364 .context("generate_streaming_with_store could not read migration")?;
365
366 Ok(StreamingGeneration {
367 store,
368 template_contents: contents,
369 environment: environment.to_string(),
370 variables: variables.unwrap_or_default(),
371 engine: engine.clone(),
372 })
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql_formatter::{get_auto_escape_callback, get_formatter};
    use minijinja::{context, Environment, Value};

    // Render `{{ value }}` through a Postgres-escaping environment and return
    // the output; used to check auto-escaping per value kind.
    fn render_sql_value(value: Value) -> String {
        let mut env = Environment::new();
        env.set_auto_escape_callback(get_auto_escape_callback(SqlDialect::Postgres));
        env.set_formatter(get_formatter(SqlDialect::Postgres));
        env.add_template("test.sql", "{{ value }}").unwrap();
        let tmpl = env.get_template("test.sql").unwrap();
        tmpl.render(context!(value => value)).unwrap()
    }

    #[test]
    fn test_engine_to_dialect_postgres_psql() {
        let dialect = engine_to_dialect(&EngineType::PostgresPSQL);
        assert_eq!(dialect, SqlDialect::Postgres);
    }

    // Plain strings come out single-quoted.
    #[test]
    fn test_sql_escape_string() {
        let result = render_sql_value(Value::from("hello"));
        assert_eq!(result, "'hello'");
    }

    // Embedded quotes are doubled, defusing injection attempts.
    #[test]
    fn test_sql_escape_string_injection_attempt() {
        let result = render_sql_value(Value::from("'; DROP TABLE users; --"));
        assert_eq!(result, "'''; DROP TABLE users; --'");
    }

    // Numbers render bare (no quoting).
    #[test]
    fn test_sql_escape_integer() {
        let result = render_sql_value(Value::from(42));
        assert_eq!(result, "42");
    }

    #[test]
    fn test_sql_escape_bool() {
        let result = render_sql_value(Value::from(true));
        assert_eq!(result, "TRUE");
    }

    // Unit / none maps to SQL NULL.
    #[test]
    fn test_sql_escape_none() {
        let result = render_sql_value(Value::from(()));
        assert_eq!(result, "NULL");
    }

    // Sequences render as Postgres ARRAY literals.
    #[test]
    fn test_sql_escape_seq() {
        let result = render_sql_value(Value::from(vec![1, 2, 3]));
        assert_eq!(result, "ARRAY[1, 2, 3]");
    }

    // Bytes render as a hex bytea literal.
    #[test]
    fn test_sql_escape_bytes() {
        let bytes = Value::from_bytes(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        let result = render_sql_value(bytes);
        assert_eq!(result, "'\\xdeadbeef'::bytea");
    }

    // SQL escaping applies regardless of template extension (.txt here) —
    // the escape callback is unconditional.
    #[test]
    fn test_sql_escape_for_non_sql_templates() {
        let mut env = Environment::new();
        env.set_auto_escape_callback(get_auto_escape_callback(SqlDialect::Postgres));
        env.set_formatter(get_formatter(SqlDialect::Postgres));
        env.add_template("test.txt", "{{ value }}").unwrap();
        let tmpl = env.get_template("test.txt").unwrap();
        let result = tmpl.render(context!(value => "hello")).unwrap();
        assert_eq!(result, "'hello'");
    }

    // The built-in `safe` filter opts a value out of escaping entirely.
    #[test]
    fn test_sql_safe_filter_bypasses_escaping() {
        let mut env = Environment::new();
        env.set_auto_escape_callback(get_auto_escape_callback(SqlDialect::Postgres));
        env.set_formatter(get_formatter(SqlDialect::Postgres));
        env.add_template("test.sql", "{{ value|safe }}").unwrap();
        let tmpl = env.get_template("test.sql").unwrap();
        let result = tmpl.render(context!(value => "raw SQL here")).unwrap();
        assert_eq!(result, "raw SQL here");
    }

    // Escaping happens at output time only; loop plumbing (loop.last etc.)
    // is unaffected, and each emitted item is escaped individually.
    #[test]
    fn test_sql_escape_only_on_output_not_in_loops() {
        let mut env = Environment::new();
        env.set_auto_escape_callback(get_auto_escape_callback(SqlDialect::Postgres));
        env.set_formatter(get_formatter(SqlDialect::Postgres));

        let template =
            r#"{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}"#;
        env.add_template("test.sql", template).unwrap();
        let tmpl = env.get_template("test.sql").unwrap();

        let items = vec!["alice", "bob", "charlie"];
        let result = tmpl.render(context!(items => items)).unwrap();
        assert_eq!(result, "'alice', 'bob', 'charlie'");
    }

    #[test]
    fn test_base64_encode_filter() {
        let bytes = Value::from_bytes(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        let result = base64_encode_filter(&bytes).unwrap();
        assert_eq!(result.to_string(), "3q2+7w==");
    }

    #[test]
    fn test_base64_encode_filter_text() {
        let bytes = Value::from_bytes(b"hello world".to_vec());
        let result = base64_encode_filter(&bytes).unwrap();
        assert_eq!(result.to_string(), "aGVsbG8gd29ybGQ=");
    }

    // String input encodes the same as its UTF-8 bytes.
    #[test]
    fn test_base64_encode_filter_string() {
        let value = Value::from("hello world");
        let result = base64_encode_filter(&value).unwrap();
        assert_eq!(result.to_string(), "aGVsbG8gd29ybGQ=");
    }

    #[test]
    fn test_base64_encode_filter_rejects_other_types() {
        let value = Value::from(42);
        let result = base64_encode_filter(&value);
        assert!(result.is_err());
    }

    #[test]
    fn test_to_string_lossy_filter_valid_utf8() {
        let bytes = Value::from_bytes(b"hello world".to_vec());
        let result = to_string_lossy_filter(&bytes).unwrap();
        assert_eq!(result.to_string(), "hello world");
    }

    // 0xFF is invalid UTF-8 and should become U+FFFD, keeping the rest.
    #[test]
    fn test_to_string_lossy_filter_invalid_utf8() {
        let bytes = Value::from_bytes(vec![0x68, 0x65, 0x6C, 0xFF, 0x6F]);
        let result = to_string_lossy_filter(&bytes).unwrap();
        let s = result.to_string();
        assert!(s.contains("hel"));
        assert!(s.contains('\u{FFFD}'));
        assert!(s.contains('o'));
    }

    #[test]
    fn test_to_string_lossy_filter_passes_through_string() {
        let value = Value::from("already a string");
        let result = to_string_lossy_filter(&value).unwrap();
        assert_eq!(result.to_string(), "already a string");
    }

    #[test]
    fn test_to_string_lossy_filter_rejects_other_types() {
        let value = Value::from(42);
        let result = to_string_lossy_filter(&value);
        assert!(result.is_err());
    }

    // multi_thread flavor is required: the read_* filters use
    // block_in_place, which panics on the current-thread runtime.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_read_file_filter_with_store() {
        use crate::config::FolderPather;
        use crate::store::pinner::latest::Latest;
        use opendal::services::Memory;
        use opendal::Operator;

        let mem_service = Memory::default();
        let op = Operator::new(mem_service).unwrap().finish();
        op.write("components/test.txt", "file contents here")
            .await
            .unwrap();

        let pinner = Latest::new("").unwrap();
        let pather = FolderPather {
            spawn_folder: "".to_string(),
        };
        let store = Store::new(Box::new(pinner), op, pather).unwrap();

        let mut env = template_env(store, &EngineType::PostgresPSQL).unwrap();
        env.add_template(
            "test.sql",
            r#"{{ "test.txt"|read_file|to_string_lossy|safe }}"#,
        )
        .unwrap();
        let tmpl = env.get_template("test.sql").unwrap();
        let result = tmpl.render(context!()).unwrap();
        assert_eq!(result, "file contents here");
    }

    // Binary file content round-trips through read_file|base64_encode.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_read_file_with_base64_encode() {
        use crate::config::FolderPather;
        use crate::store::pinner::latest::Latest;
        use opendal::services::Memory;
        use opendal::Operator;

        let mem_service = Memory::default();
        let op = Operator::new(mem_service).unwrap().finish();
        op.write("components/binary.dat", vec![0xDE, 0xAD, 0xBE, 0xEF])
            .await
            .unwrap();

        let pinner = Latest::new("").unwrap();
        let pather = FolderPather {
            spawn_folder: "".to_string(),
        };
        let store = Store::new(Box::new(pinner), op, pather).unwrap();

        let mut env = template_env(store, &EngineType::PostgresPSQL).unwrap();
        env.add_template(
            "test.sql",
            r#"{{ "binary.dat"|read_file|base64_encode|safe }}"#,
        )
        .unwrap();
        let tmpl = env.get_template("test.sql").unwrap();
        let result = tmpl.render(context!()).unwrap();
        assert_eq!(result, "3q2+7w==");
    }

    // A missing file must fail the render, not silently produce empty output.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_read_file_missing_file_returns_error() {
        use crate::config::FolderPather;
        use crate::store::pinner::latest::Latest;
        use opendal::services::Memory;
        use opendal::Operator;

        let mem_service = Memory::default();
        let op = Operator::new(mem_service).unwrap().finish();

        let pinner = Latest::new("").unwrap();
        let pather = FolderPather {
            spawn_folder: "".to_string(),
        };
        let store = Store::new(Box::new(pinner), op, pather).unwrap();

        let mut env = template_env(store, &EngineType::PostgresPSQL).unwrap();
        env.add_template(
            "test.sql",
            r#"{{ "nonexistent.txt"|read_file|to_string_lossy }}"#,
        )
        .unwrap();
        let tmpl = env.get_template("test.sql").unwrap();
        let result = tmpl.render(context!());
        assert!(result.is_err());
    }

    // With a pinned store, reads resolve through the snapshot: the file is
    // deleted from components/ after snapshotting, yet the pinned copy is
    // still served.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_read_file_filter_uses_pinned_store() {
        use crate::config::FolderPather;
        use crate::store::pinner::snapshot;
        use crate::store::pinner::spawn::Spawn;
        use opendal::services::Memory;
        use opendal::Operator;

        let mem_service = Memory::default();
        let op = Operator::new(mem_service).unwrap().finish();

        op.write("components/test.txt", "pinned content")
            .await
            .unwrap();
        let root_hash = snapshot(&op, "pinned/", "components/").await.unwrap();

        op.delete("components/test.txt").await.unwrap();

        let pinner = Spawn::new_with_root_hash(
            "pinned/".to_string(),
            "components/".to_string(),
            &root_hash,
            &op,
        )
        .await
        .unwrap();

        let pather = FolderPather {
            spawn_folder: "".to_string(),
        };
        let store = Store::new(Box::new(pinner), op, pather).unwrap();

        let mut env = template_env(store, &EngineType::PostgresPSQL).unwrap();
        env.add_template(
            "test.sql",
            r#"{{ "test.txt"|read_file|to_string_lossy|safe }}"#,
        )
        .unwrap();
        let tmpl = env.get_template("test.sql").unwrap();
        let result = tmpl.render(context!()).unwrap();
        assert_eq!(result, "pinned content");
    }
}
677}