1use crate::config;
2use crate::engine::EngineType;
3use crate::escape::EscapedIdentifier;
4use crate::store::pinner::latest::Latest;
5use crate::store::pinner::spawn::Spawn;
6use crate::store::pinner::Pinner;
7use crate::store::Store;
8use crate::variables::Variables;
9use minijinja::{Environment, Value};
10
11use crate::sql_formatter::SqlDialect;
12use uuid::Uuid;
13
14use anyhow::{Context, Result};
15use minijinja::context;
16
17fn engine_to_dialect(engine: &EngineType) -> SqlDialect {
23 match engine {
24 EngineType::PostgresPSQL => SqlDialect::Postgres,
25 }
30}
31
32pub fn template_env(store: Store, engine: &EngineType) -> Result<Environment<'static>> {
33 let mut env = Environment::new();
34
35 let mj_store = MiniJinjaLoader { store };
36 env.set_loader(move |name: &str| mj_store.load(name));
37 env.add_function("gen_uuid_v4", gen_uuid_v4);
38 env.add_function("gen_uuid_v5", gen_uuid_v5);
39 env.add_filter("escape_identifier", escape_identifier_filter);
40
41 let dialect = engine_to_dialect(engine);
43
44 env.set_auto_escape_callback(crate::sql_formatter::get_auto_escape_callback(dialect));
46
47 env.set_formatter(crate::sql_formatter::get_formatter(dialect));
49
50 Ok(env)
51}
52
53struct MiniJinjaLoader {
54 pub store: Store,
55}
56
57impl MiniJinjaLoader {
58 pub fn load(&self, name: &str) -> std::result::Result<Option<String>, minijinja::Error> {
59 let result = tokio::task::block_in_place(|| {
60 tokio::runtime::Handle::current()
61 .block_on(async { self.store.load_component(name).await })
62 });
63
64 result.map_err(|e| {
65 minijinja::Error::new(
66 minijinja::ErrorKind::InvalidOperation,
67 format!("Failed to load from object store: {}", e),
68 )
69 })
70 }
71}
72
73fn gen_uuid_v4() -> Result<String, minijinja::Error> {
74 Ok(Uuid::new_v4().to_string())
75}
76
77fn gen_uuid_v5(seed: &str) -> Result<String, minijinja::Error> {
78 Ok(Uuid::new_v5(&Uuid::NAMESPACE_DNS, seed.as_bytes()).to_string())
79}
80
81fn escape_identifier_filter(value: &Value) -> Result<Value, minijinja::Error> {
88 let s = value.to_string();
89 let escaped = EscapedIdentifier::new(&s);
90 Ok(Value::from_safe_string(escaped.to_string()))
92}
93
/// Rendered output of a generation held in memory.
///
/// NOTE(review): not constructed anywhere in this file. Presumably `content`
/// is the complete rendered migration text; confirm against callers.
pub struct Generation {
    /// The rendered text.
    pub content: String,
}
97
98pub struct StreamingGeneration {
101 store: Store,
102 template_contents: String,
103 environment: String,
104 variables: Variables,
105 engine: EngineType,
106}
107
108impl StreamingGeneration {
109 pub fn render_to_writer<W: std::io::Write + ?Sized>(self, writer: &mut W) -> Result<()> {
112 let mut env = template_env(self.store, &self.engine)?;
113 env.add_template("migration.sql", &self.template_contents)?;
114 let tmpl = env.get_template("migration.sql")?;
115 tmpl.render_to_write(
116 context!(env => self.environment, variables => self.variables),
117 writer,
118 )?;
119 Ok(())
120 }
121
122 pub fn into_writer_fn(self) -> crate::engine::WriterFn {
124 Box::new(move |writer: &mut dyn std::io::Write| {
125 self.render_to_writer(writer)
126 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
127 })
128 }
129}
130
131pub async fn generate_streaming(
134 cfg: &config::Config,
135 lock_file: Option<String>,
136 name: &str,
137 variables: Option<Variables>,
138) -> Result<StreamingGeneration> {
139 let pinner: Box<dyn Pinner> = if let Some(lock_file) = lock_file {
140 let lock = cfg
141 .load_lock_file(&lock_file)
142 .await
143 .context("could not load pinned files lock file")?;
144 let pinner = Spawn::new_with_root_hash(
145 cfg.pather().pinned_folder(),
146 cfg.pather().components_folder(),
147 &lock.pin,
148 &cfg.operator(),
149 )
150 .await
151 .context("could not get new root with hash")?;
152 Box::new(pinner)
153 } else {
154 let pinner = Latest::new(cfg.pather().spawn_folder_path())?;
155 Box::new(pinner)
156 };
157
158 let store = Store::new(pinner, cfg.operator().clone(), cfg.pather())
159 .context("could not create new store for generate")?;
160 let db_config = cfg
161 .db_config()
162 .context("could not get db config for generate")?;
163
164 generate_streaming_with_store(
165 name,
166 variables,
167 &db_config.environment,
168 &db_config.engine,
169 store,
170 )
171 .await
172}
173
174pub async fn generate_streaming_with_store(
176 name: &str,
177 variables: Option<Variables>,
178 environment: &str,
179 engine: &EngineType,
180 store: Store,
181) -> Result<StreamingGeneration> {
182 let contents = store
184 .load_migration(name)
185 .await
186 .context("generate_streaming_with_store could not read migration")?;
187
188 Ok(StreamingGeneration {
189 store,
190 template_contents: contents,
191 environment: environment.to_string(),
192 variables: variables.unwrap_or_default(),
193 engine: engine.clone(),
194 })
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql_formatter::{get_auto_escape_callback, get_formatter};
    use minijinja::{context, Environment, Value};

    /// Environment with the Postgres auto-escape callback and formatter, the
    /// same pair `template_env` installs for `EngineType::PostgresPSQL`.
    fn postgres_env() -> Environment<'static> {
        let mut env = Environment::new();
        env.set_auto_escape_callback(get_auto_escape_callback(SqlDialect::Postgres));
        env.set_formatter(get_formatter(SqlDialect::Postgres));
        env
    }

    /// Registers `source` as template `name` in a Postgres environment and
    /// renders it with `ctx`.
    fn render_template(name: &'static str, source: &'static str, ctx: Value) -> String {
        let mut env = postgres_env();
        env.add_template(name, source).unwrap();
        let tmpl = env.get_template(name).unwrap();
        tmpl.render(ctx).unwrap()
    }

    /// Renders a single value through `{{ value }}` in a `.sql` template.
    fn render_sql_value(value: Value) -> String {
        render_template("test.sql", "{{ value }}", context!(value => value))
    }

    #[test]
    fn test_engine_to_dialect_postgres_psql() {
        assert_eq!(
            engine_to_dialect(&EngineType::PostgresPSQL),
            SqlDialect::Postgres
        );
    }

    #[test]
    fn test_sql_escape_string() {
        assert_eq!(render_sql_value(Value::from("hello")), "'hello'");
    }

    #[test]
    fn test_sql_escape_string_injection_attempt() {
        assert_eq!(
            render_sql_value(Value::from("'; DROP TABLE users; --")),
            "'''; DROP TABLE users; --'"
        );
    }

    #[test]
    fn test_sql_escape_integer() {
        assert_eq!(render_sql_value(Value::from(42)), "42");
    }

    #[test]
    fn test_sql_escape_bool() {
        assert_eq!(render_sql_value(Value::from(true)), "TRUE");
    }

    #[test]
    fn test_sql_escape_none() {
        assert_eq!(render_sql_value(Value::from(())), "NULL");
    }

    #[test]
    fn test_sql_escape_seq() {
        assert_eq!(render_sql_value(Value::from(vec![1, 2, 3])), "ARRAY[1, 2, 3]");
    }

    #[test]
    fn test_sql_escape_bytes() {
        let bytes = Value::from_bytes(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(render_sql_value(bytes), "'\\xdeadbeef'::bytea");
    }

    #[test]
    fn test_sql_escape_for_non_sql_templates() {
        // Escaping applies regardless of the template's file extension.
        let rendered = render_template("test.txt", "{{ value }}", context!(value => "hello"));
        assert_eq!(rendered, "'hello'");
    }

    #[test]
    fn test_sql_safe_filter_bypasses_escaping() {
        let rendered = render_template(
            "test.sql",
            "{{ value|safe }}",
            context!(value => "raw SQL here"),
        );
        assert_eq!(rendered, "raw SQL here");
    }

    #[test]
    fn test_sql_escape_only_on_output_not_in_loops() {
        let rendered = render_template(
            "test.sql",
            r#"{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}"#,
            context!(items => vec!["alice", "bob", "charlie"]),
        );
        assert_eq!(rendered, "'alice', 'bob', 'charlie'");
    }
}