1use std::{
4 borrow::Cow,
5 fmt,
6 sync::{
7 RwLock,
8 RwLockReadGuard,
9 RwLockWriteGuard,
10 atomic::{AtomicBool, Ordering},
11 },
12};
13
14use minijinja::{
15 Environment,
16 Error as JinjaError,
17 ErrorKind as JinjaErrorKind,
18 State,
19 args,
20 value::{Rest as JinjaRest, Value as JinjaValue},
21};
22use mlua::prelude::{
23 Lua,
24 LuaError,
25 LuaFunction,
26 LuaMultiValue,
27 LuaSerdeExt,
28 LuaTable,
29 LuaUserData,
30 LuaUserDataMethods,
31 LuaValue,
32};
33
34use crate::{
35 convert::{
36 LuaFunctionObject,
37 LuaObject,
38 lua_to_auto_escape,
39 lua_to_minijinja,
40 lua_to_syntax_config,
41 lua_to_undefined_behavior,
42 minijinja_to_lua,
43 undefined_behavior_to_lua,
44 },
45 state::bind_lua,
46};
47
48#[derive(Debug)]
51pub struct LuaEnvironment {
52 env: RwLock<Environment<'static>>,
53 reload_before_render: AtomicBool,
54}
55
56impl LuaEnvironment {
57 pub(crate) fn new() -> Self {
59 let mut env = Environment::new();
60 minijinja_contrib::add_to_environment(&mut env);
61
62 Self {
63 env: RwLock::new(env),
64 reload_before_render: AtomicBool::new(false),
65 }
66 }
67
68 pub(crate) fn empty() -> Self {
70 let env = Environment::empty();
71
72 Self {
73 env: RwLock::new(env),
74 reload_before_render: AtomicBool::new(false),
75 }
76 }
77
78 pub(crate) fn reload_before_render(&self) -> bool {
79 self.reload_before_render.load(Ordering::Relaxed)
80 }
81
82 pub(crate) fn set_reload_before_render(&self, enable: bool) {
83 self.reload_before_render.store(enable, Ordering::Relaxed);
84 }
85
86 pub(crate) fn read_env(&self) -> Result<RwLockReadGuard<'_, Environment<'static>>, LuaError> {
88 self.env
89 .read()
90 .map_err(|_| LuaError::runtime("environment lock poisoned"))
91 }
92
93 pub(crate) fn write_env(&self) -> Result<RwLockWriteGuard<'_, Environment<'static>>, LuaError> {
95 self.env
96 .write()
97 .map_err(|_| LuaError::runtime("environment lock poisoned"))
98 }
99}
100
101impl Default for LuaEnvironment {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
107impl fmt::Display for LuaEnvironment {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 write!(f, "Environment")
110 }
111}
112
113impl LuaUserData for LuaEnvironment {
114 fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
115 fields.add_field_method_get(
116 "reload_before_render",
117 |_, this: &LuaEnvironment| -> Result<bool, LuaError> {
118 Ok(this.reload_before_render())
119 },
120 );
121
122 fields.add_field_method_set(
123 "reload_before_render",
124 |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
125 this.set_reload_before_render(val);
126
127 Ok(())
128 },
129 );
130
131 fields.add_field_method_get(
132 "keep_trailing_newline",
133 |_, this: &LuaEnvironment| -> Result<bool, LuaError> {
134 Ok(this.read_env()?.keep_trailing_newline())
135 },
136 );
137
138 fields.add_field_method_set(
139 "keep_trailing_newline",
140 |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
141 this.write_env()?.set_keep_trailing_newline(val);
142
143 Ok(())
144 },
145 );
146
147 fields.add_field_method_get(
148 "trim_blocks",
149 |_, this: &LuaEnvironment| -> Result<bool, LuaError> {
150 Ok(this.read_env()?.trim_blocks())
151 },
152 );
153
154 fields.add_field_method_set(
155 "trim_blocks",
156 |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
157 this.write_env()?.set_trim_blocks(val);
158
159 Ok(())
160 },
161 );
162
163 fields.add_field_method_get(
164 "lstrip_blocks",
165 |_, this: &LuaEnvironment| -> Result<bool, LuaError> {
166 Ok(this.read_env()?.lstrip_blocks())
167 },
168 );
169
170 fields.add_field_method_set(
171 "lstrip_blocks",
172 |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
173 this.write_env()?.set_lstrip_blocks(val);
174
175 Ok(())
176 },
177 );
178
179 fields.add_field_method_get(
180 "debug",
181 |_, this: &LuaEnvironment| -> Result<bool, LuaError> { Ok(this.read_env()?.debug()) },
182 );
183
184 fields.add_field_method_set(
185 "debug",
186 |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
187 this.write_env()?.set_debug(val);
188
189 Ok(())
190 },
191 );
192
193 fields.add_field_method_get(
194 "fuel",
195 |_, this: &LuaEnvironment| -> Result<Option<u64>, LuaError> {
196 Ok(this.read_env()?.fuel())
197 },
198 );
199
200 fields.add_field_method_set(
201 "fuel",
202 |_, this: &mut LuaEnvironment, val: Option<u64>| -> Result<(), LuaError> {
203 this.write_env()?.set_fuel(val);
204
205 Ok(())
206 },
207 );
208
209 fields.add_field_method_get(
210 "recursion_limit",
211 |_, this: &LuaEnvironment| -> Result<usize, LuaError> {
212 Ok(this.read_env()?.recursion_limit())
213 },
214 );
215
216 fields.add_field_method_set(
217 "recursion_limit",
218 |_, this: &mut LuaEnvironment, val: usize| -> Result<(), LuaError> {
219 this.write_env()?.set_recursion_limit(val);
220
221 Ok(())
222 },
223 );
224
225 fields.add_field_method_get(
226 "undefined_behavior",
227 |_, this: &LuaEnvironment| -> Result<Option<String>, LuaError> {
228 let ub = this.read_env()?.undefined_behavior();
229
230 Ok(undefined_behavior_to_lua(ub))
231 },
232 );
233
234 fields.add_field_method_set(
235 "undefined_behavior",
236 |_, this: &mut LuaEnvironment, val: String| -> Result<(), LuaError> {
237 let behavior = lua_to_undefined_behavior(&val)?;
238
239 this.write_env()?.set_undefined_behavior(behavior);
240
241 Ok(())
242 },
243 );
244 }
245
246 fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
247 methods.add_function("new", |_, _: LuaMultiValue| -> Result<LuaEnvironment, _> {
248 Ok(LuaEnvironment::new())
249 });
250
251 methods.add_function(
252 "empty",
253 |_, _: LuaMultiValue| -> Result<LuaEnvironment, _> { Ok(LuaEnvironment::empty()) },
254 );
255
256 methods.add_method(
257 "add_template",
258 |lua: &Lua,
259 this: &LuaEnvironment,
260 (name, source): (String, String)|
261 -> Result<(), LuaError> {
262 bind_lua(lua, || {
263 this.write_env()?
264 .add_template_owned(name, source)
265 .map_err(LuaError::external)
266 })
267 },
268 );
269
270 methods.add_method(
271 "remove_template",
272 |lua: &Lua, this: &LuaEnvironment, name: String| -> Result<(), LuaError> {
273 bind_lua(lua, || {
274 this.write_env()?.remove_template(&name);
275 Ok(())
276 })
277 },
278 );
279
280 methods.add_method(
281 "clear_templates",
282 |lua: &Lua, this: &LuaEnvironment, _: LuaValue| -> Result<(), LuaError> {
283 bind_lua(lua, || {
284 this.write_env()?.clear_templates();
285
286 Ok(())
287 })
288 },
289 );
290
291 methods.add_method(
292 "undeclared_variables",
293 |lua: &Lua,
294 this: &LuaEnvironment,
295 (name, nested): (String, Option<bool>)|
296 -> Result<LuaValue, LuaError> {
297 bind_lua(lua, || {
298 if this.reload_before_render() {
299 this.write_env()?.clear_templates();
300 }
301
302 let nested = nested.unwrap_or(false);
303
304 let vars = this
305 .read_env()?
306 .get_template(&name)
307 .map_err(LuaError::external)?
308 .undeclared_variables(nested);
309
310 lua.to_value(&vars)
311 })
312 },
313 );
314
315 methods.add_method(
316 "set_loader",
317 |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
318 let key = lua.create_registry_value(callback)?;
319 let func = LuaFunctionObject::new(key);
320
321 this.write_env()?.set_loader(move |name| {
322 let source = func.with_func(args!(name), None)?;
323 Ok(source.and_then(|v| {
324 if v.is_undefined() {
328 None
329 } else {
330 Some(v.to_string())
331 }
332 }))
333 });
334
335 Ok(())
336 },
337 );
338
339 methods.add_method(
340 "set_path_join_callback",
341 |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
342 let key = lua.create_registry_value(callback)?;
343 let func = LuaFunctionObject::new(key);
344
345 this.write_env()?
346 .set_path_join_callback(move |name, parent| {
347 func.with_func(args!(name, parent), None)
348 .ok()
349 .flatten()
350 .and_then(|v| v.as_str().map(|s| Cow::Owned(s.to_string())))
351 .unwrap_or(Cow::Borrowed(name))
352 });
353 Ok(())
354 },
355 );
356
357 methods.add_method(
358 "set_unknown_method_callback",
359 |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
360 let key = lua.create_registry_value(callback)?;
361 let mut func = LuaFunctionObject::new(key);
362 func.set_pass_state(true);
363
364 this.write_env()?
365 .set_unknown_method_callback(move |state, value, method, args| {
366 func.with_func(args!(value, method, ..args), Some(state))
367 .map(|v| v.unwrap_or(JinjaValue::UNDEFINED))
368 });
369
370 Ok(())
371 },
372 );
373
374 methods.add_method(
375 "set_auto_escape_callback",
376 |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
377 let key = lua.create_registry_value(callback)?;
378 let func = LuaFunctionObject::new(key);
379
380 this.write_env()?.set_auto_escape_callback(move |name| {
381 func.with_func(args!(name), None)
382 .ok()
383 .flatten()
384 .and_then(|v| {
385 let s = v.as_str()?.to_string();
386 lua_to_auto_escape(&s).ok()
387 })
388 .unwrap_or(minijinja::AutoEscape::None)
389 });
390 Ok(())
391 },
392 );
393
394 methods.add_method(
395 "set_formatter",
396 |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
397 let key = lua.create_registry_value(callback)?;
398 let mut func = LuaFunctionObject::new(key);
399 func.set_pass_state(true);
400
401 this.write_env()?.set_formatter(move |out, state, value| {
402 let Some(val) = func.with_func(args!(value), Some(state)).ok().flatten() else {
403 return Ok(());
404 };
405
406 let Some(s) = val.as_str() else {
407 return Err(JinjaError::new(
408 JinjaErrorKind::WriteFailure,
409 "formatter must return a string",
410 ));
411 };
412
413 out.write_str(s)
414 .map_err(|_| JinjaError::new(JinjaErrorKind::WriteFailure, "write failed"))
415 });
416
417 Ok(())
418 },
419 );
420
421 methods.add_method(
422 "set_syntax",
423 |_, this: &LuaEnvironment, syntax: LuaTable| -> Result<(), LuaError> {
424 let config = lua_to_syntax_config(syntax).map_err(LuaError::external)?;
425 this.write_env()?.set_syntax(config);
426
427 Ok(())
428 },
429 );
430
431 methods.add_method(
432 "render_template",
433 |lua: &Lua,
434 this: &LuaEnvironment,
435 (name, ctx): (String, Option<LuaTable>)|
436 -> Result<String, LuaError> {
437 if this.reload_before_render() {
438 this.write_env()?.clear_templates();
439 }
440
441 let ctx = ctx.unwrap_or(lua.create_table()?);
442
443 let context = lua_to_minijinja(lua, &LuaValue::Table(ctx));
444
445 bind_lua(lua, || {
446 this.read_env()?
447 .get_template(&name)
448 .map_err(LuaError::external)?
449 .render(context)
450 .map_err(LuaError::external)
451 })
452 },
453 );
454
455 methods.add_method(
456 "render_str",
457 |lua: &Lua,
458 this: &LuaEnvironment,
459 (source, ctx, name): (String, Option<LuaTable>, Option<String>)|
460 -> Result<String, LuaError> {
461 let ctx = ctx.unwrap_or(lua.create_table()?);
462
463 let name = name.unwrap_or("<string>".to_string());
464 let context = lua_to_minijinja(lua, &LuaValue::Table(ctx));
465
466 bind_lua(lua, || {
467 this.read_env()?
468 .render_named_str(&name, &source, context)
469 .map_err(LuaError::external)
470 })
471 },
472 );
473
474 methods.add_method(
475 "eval",
476 |lua: &Lua,
477 this: &LuaEnvironment,
478 (source, ctx): (String, Option<LuaTable>)|
479 -> Result<LuaValue, LuaError> {
480 let ctx = ctx.unwrap_or(lua.create_table()?);
481
482 let context = lua_to_minijinja(lua, &LuaValue::Table(ctx));
483
484 bind_lua(lua, || {
485 let expr = this
486 .read_env()?
487 .compile_expression(&source)
488 .map_err(LuaError::external)?
489 .eval(&context)
490 .map_err(LuaError::external)?;
491
492 minijinja_to_lua(lua, &expr).ok_or_else(|| LuaError::ToLuaConversionError {
493 from: "".to_string(),
494 to: "",
495 message: None,
496 })
497 })
498 },
499 );
500
501 methods.add_method(
502 "add_filter",
503 |lua: &Lua,
504 this: &LuaEnvironment,
505 (name, filter, pass_state): (String, LuaFunction, Option<bool>)|
506 -> Result<(), LuaError> {
507 let key = lua.create_registry_value(filter)?;
508 let mut func = LuaFunctionObject::new(key);
509 func.set_pass_state(pass_state.unwrap_or(true));
510
511 this.write_env()?.add_filter(
512 name,
513 move |state: &State, args: JinjaRest<JinjaValue>| {
514 func.with_func(&args, Some(state))
515 },
516 );
517
518 Ok(())
519 },
520 );
521
522 methods.add_method(
523 "remove_filter",
524 |_, this: &LuaEnvironment, name: String| -> Result<(), LuaError> {
525 this.write_env()?.remove_filter(&name);
526
527 Ok(())
528 },
529 );
530
531 methods.add_method(
532 "add_test",
533 |lua: &Lua,
534 this: &LuaEnvironment,
535 (name, test, pass_state): (String, LuaFunction, Option<bool>)|
536 -> Result<(), LuaError> {
537 let key = lua.create_registry_value(test)?;
538 let mut func = LuaFunctionObject::new(key);
539 func.set_pass_state(pass_state.unwrap_or(true));
540
541 this.write_env()?.add_test(
542 name,
543 move |state: &State, args: JinjaRest<JinjaValue>| {
544 func.with_func(&args, Some(state))
545 },
546 );
547
548 Ok(())
549 },
550 );
551
552 methods.add_method(
553 "remove_test",
554 |_, this: &LuaEnvironment, name: String| -> Result<(), LuaError> {
555 this.write_env()?.remove_test(&name);
556
557 Ok(())
558 },
559 );
560
561 methods.add_method(
562 "add_global",
563 |lua: &Lua,
564 this: &LuaEnvironment,
565 (name, val, pass_state): (String, LuaValue, Option<bool>)|
566 -> Result<(), LuaError> {
567 match val {
568 LuaValue::Function(f) => {
569 let key = lua.create_registry_value(f)?;
570 let mut func = LuaFunctionObject::new(key);
571 func.set_pass_state(pass_state.unwrap_or(true));
572
573 this.write_env()?.add_function(
574 name,
575 move |state: &State, args: JinjaRest<JinjaValue>| {
576 func.with_func(&args, Some(state))
577 },
578 )
579 },
580 _ => this
581 .write_env()?
582 .add_global(name, lua_to_minijinja(lua, &val)),
583 };
584
585 Ok(())
586 },
587 );
588
589 methods.add_method(
590 "remove_global",
591 |_, this: &LuaEnvironment, name: String| -> Result<(), LuaError> {
592 this.write_env()?.remove_global(&name);
593
594 Ok(())
595 },
596 );
597
598 methods.add_method(
599 "globals",
600 |lua: &Lua, this: &LuaEnvironment, _val: LuaValue| -> Result<LuaTable, LuaError> {
601 let table = lua.create_table()?;
602
603 for (name, value) in this.read_env()?.globals() {
604 let val = minijinja_to_lua(lua, &value);
605 table.set(name, val)?;
606 }
607
608 Ok(table)
609 },
610 );
611 }
612}