1use anyhow::{Context, Result};
7use mlua::{Function, Lua, LuaSerdeExt, Table, Value};
8use serde::Deserialize;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11
12#[derive(Debug, Clone)]
14pub struct ConfigData {
15 pub site: SiteConfig,
16 pub seo: SeoConfig,
17 pub build: BuildConfig,
18 pub images: ImagesConfig,
19 pub highlight: HighlightConfig,
20 pub paths: PathsConfig,
21 pub templates: TemplatesConfig,
22 pub permalinks: PermalinksConfig,
23 pub encryption: EncryptionConfig,
24 pub graph: GraphConfig,
25 pub rss: RssConfig,
26 pub text: TextConfig,
27 pub sections: SectionsConfig,
28}
29
30pub struct Config {
32 pub data: ConfigData,
34
35 lua: Lua,
37 computed: HashMap<String, mlua::RegistryKey>,
38 filters: HashMap<String, mlua::RegistryKey>,
39 functions: HashMap<String, mlua::RegistryKey>,
40 computed_pages: Option<mlua::RegistryKey>,
41 sort_fns: HashMap<String, mlua::RegistryKey>,
42 before_build: Option<mlua::RegistryKey>,
43 after_build: Option<mlua::RegistryKey>,
44}
45
46impl std::ops::Deref for Config {
48 type Target = ConfigData;
49 fn deref(&self) -> &Self::Target {
50 &self.data
51 }
52}
53
54impl std::ops::DerefMut for Config {
55 fn deref_mut(&mut self) -> &mut Self::Target {
56 &mut self.data
57 }
58}
59
60#[derive(Debug, Deserialize, Clone)]
61pub struct SiteConfig {
62 pub title: String,
63 pub description: String,
64 pub base_url: String,
65 pub author: String,
66}
67
68#[derive(Debug, Deserialize, Clone, Default)]
69pub struct SeoConfig {
70 pub twitter_handle: Option<String>,
71 pub default_og_image: Option<String>,
72}
73
74#[derive(Debug, Deserialize, Clone)]
75pub struct BuildConfig {
76 pub output_dir: String,
77 #[serde(default = "default_true")]
78 pub minify_css: bool,
79 #[serde(default = "default_css_output")]
80 pub css_output: String,
81}
82
/// Default file name for the generated stylesheet.
fn default_css_output() -> String {
    String::from("rs.css")
}
86
87#[derive(Debug, Deserialize, Clone)]
88pub struct ImagesConfig {
89 #[serde(default = "default_quality")]
90 pub quality: f32,
91 #[serde(default = "default_scale_factor")]
92 pub scale_factor: f64,
93}
94
/// Default image encoding quality.
fn default_quality() -> f32 {
    85f32
}

/// Default image scale factor (no scaling).
fn default_scale_factor() -> f64 {
    1f64
}
102
103#[derive(Debug, Deserialize, Clone, Default)]
104pub struct HighlightConfig {
105 #[serde(default)]
106 pub names: Vec<String>,
107 #[serde(default = "default_highlight_class")]
108 pub class: String,
109}
110
/// Default class string for highlighted names.
fn default_highlight_class() -> String {
    "me".into()
}
114
115#[derive(Debug, Deserialize, Clone)]
116pub struct PathsConfig {
117 #[serde(default = "default_content_dir")]
118 pub content: String,
119 #[serde(default = "default_styles_dir")]
120 pub styles: String,
121 #[serde(default = "default_static_dir")]
122 pub static_files: String,
123 #[serde(default = "default_templates_dir")]
124 pub templates: String,
125 #[serde(default = "default_home_page")]
126 pub home: String,
127 #[serde(default)]
128 pub exclude: Vec<String>,
129 #[serde(default = "default_true")]
130 pub exclude_defaults: bool,
131 #[serde(default = "default_true")]
132 pub respect_gitignore: bool,
133}
134
135impl Default for PathsConfig {
136 fn default() -> Self {
137 Self {
138 content: default_content_dir(),
139 styles: default_styles_dir(),
140 static_files: default_static_dir(),
141 templates: default_templates_dir(),
142 home: default_home_page(),
143 exclude: Vec::new(),
144 exclude_defaults: true,
145 respect_gitignore: true,
146 }
147 }
148}
149
/// Default content directory.
fn default_content_dir() -> String {
    String::from("content")
}
/// Default styles directory.
fn default_styles_dir() -> String {
    String::from("styles")
}
/// Default static-files directory.
fn default_static_dir() -> String {
    String::from("static")
}
/// Default templates directory.
fn default_templates_dir() -> String {
    String::from("templates")
}
/// Default home page file.
fn default_home_page() -> String {
    String::from("index.md")
}
165
166#[derive(Debug, Deserialize, Clone, Default)]
168pub struct TemplatesConfig {
169 #[serde(flatten)]
170 pub sections: HashMap<String, String>,
171}
172
173#[derive(Debug, Deserialize, Clone, Default)]
175pub struct PermalinksConfig {
176 #[serde(flatten)]
177 pub sections: HashMap<String, String>,
178}
179
180#[derive(Debug, Deserialize, Clone, Default)]
182pub struct EncryptionConfig {
183 pub password_command: Option<String>,
184 pub password: Option<String>,
185}
186
187#[derive(Debug, Deserialize, Clone)]
189pub struct GraphConfig {
190 #[serde(default = "default_true")]
191 pub enabled: bool,
192 #[serde(default = "default_graph_template")]
193 pub template: String,
194 #[serde(default = "default_graph_path")]
195 pub path: String,
196}
197
198impl Default for GraphConfig {
199 fn default() -> Self {
200 Self {
201 enabled: true,
202 template: default_graph_template(),
203 path: default_graph_path(),
204 }
205 }
206}
207
/// Default template for the graph page.
fn default_graph_template() -> String {
    "graph.html".to_owned()
}

/// Default output path of the graph page.
fn default_graph_path() -> String {
    "graph".to_owned()
}
215
216#[derive(Debug, Deserialize, Clone)]
218pub struct RssConfig {
219 #[serde(default = "default_true")]
220 pub enabled: bool,
221 #[serde(default = "default_rss_filename")]
222 pub filename: String,
223 #[serde(default)]
224 pub sections: Vec<String>,
225 #[serde(default = "default_rss_limit")]
226 pub limit: usize,
227 #[serde(default)]
228 pub exclude_encrypted_blocks: bool,
229}
230
231impl Default for RssConfig {
232 fn default() -> Self {
233 Self {
234 enabled: true,
235 filename: default_rss_filename(),
236 sections: Vec::new(),
237 limit: default_rss_limit(),
238 exclude_encrypted_blocks: false,
239 }
240 }
241}
242
/// Default RSS feed file name.
fn default_rss_filename() -> String {
    String::from("rss.xml")
}

/// Default maximum number of RSS entries.
fn default_rss_limit() -> usize {
    20usize
}
250
251#[derive(Debug, Deserialize, Clone)]
253pub struct TextConfig {
254 #[serde(default)]
255 pub enabled: bool,
256 #[serde(default)]
257 pub sections: Vec<String>,
258 #[serde(default)]
259 pub exclude_encrypted: bool,
260 #[serde(default = "default_true")]
261 pub include_home: bool,
262}
263
264impl Default for TextConfig {
265 fn default() -> Self {
266 Self {
267 enabled: false,
268 sections: Vec::new(),
269 exclude_encrypted: false,
270 include_home: true,
271 }
272 }
273}
274
275#[derive(Debug, Deserialize, Clone, Default)]
277pub struct SectionsConfig {
278 #[serde(flatten)]
279 pub sections: HashMap<String, SectionConfig>,
280}
281
282#[derive(Debug, Deserialize, Clone)]
284pub struct SectionConfig {
285 #[serde(default = "default_iterate")]
287 pub iterate: String,
288}
289
290impl Default for SectionConfig {
291 fn default() -> Self {
292 Self {
293 iterate: default_iterate(),
294 }
295 }
296}
297
/// Default iteration mode for a section.
fn default_iterate() -> String {
    "files".into()
}

/// Serde helper for boolean fields that default to `true`.
fn default_true() -> bool {
    true
}
305
306#[derive(Debug, Clone, serde::Deserialize)]
308pub struct ComputedPage {
309 pub path: String,
311 pub template: String,
313 pub title: String,
315 pub data: serde_json::Value,
317}
318
319impl Config {
320 #[cfg(test)]
322 pub fn from_data(data: ConfigData) -> Self {
323 let lua = Lua::new();
324 Self {
325 data,
326 lua,
327 computed: HashMap::new(),
328 filters: HashMap::new(),
329 functions: HashMap::new(),
330 computed_pages: None,
331 sort_fns: HashMap::new(),
332 before_build: None,
333 after_build: None,
334 }
335 }
336
337 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
339 let path = path.as_ref();
340
341 let config_path = if path.is_dir() {
343 let lua_path = path.join("config.lua");
344 if lua_path.exists() {
345 lua_path
346 } else {
347 anyhow::bail!("No config.lua found in {:?}", path);
348 }
349 } else {
350 path.to_path_buf()
351 };
352
353 let lua = Lua::new();
354
355 let project_root = config_path
357 .parent()
358 .map(|p| {
359 if p.as_os_str().is_empty() {
360 PathBuf::from(".")
361 } else {
362 p.to_path_buf()
363 }
364 })
365 .unwrap_or_else(|| PathBuf::from("."));
366 let project_root = project_root
367 .canonicalize()
368 .unwrap_or_else(|_| project_root.clone());
369
370 register_lua_functions(&lua, &project_root, false)
372 .map_err(|e| anyhow::anyhow!("Failed to register Lua functions: {}", e))?;
373
374 let content = std::fs::read_to_string(&config_path)
376 .with_context(|| format!("Failed to read config file: {:?}", config_path))?;
377
378 let config_table: Table = lua
379 .load(&content)
380 .set_name(config_path.to_string_lossy())
381 .eval()
382 .map_err(|e| {
383 anyhow::anyhow!("Failed to execute config file {:?}: {}", config_path, e)
384 })?;
385
386 let sandbox = config_table
388 .get::<Table>("lua")
389 .ok()
390 .and_then(|t| t.get::<bool>("sandbox").ok())
391 .unwrap_or(true);
392
393 if sandbox {
395 register_lua_functions(&lua, &project_root, true)
396 .map_err(|e| anyhow::anyhow!("Failed to register Lua functions: {}", e))?;
397 }
398
399 let mut sort_fns = HashMap::new();
401 let data = parse_config(&lua, &config_table, &mut sort_fns)
402 .map_err(|e| anyhow::anyhow!("Failed to parse config: {}", e))?;
403
404 let computed = extract_functions(&lua, &config_table, "computed")
406 .map_err(|e| anyhow::anyhow!("Failed to extract computed functions: {}", e))?;
407
408 let filters = extract_functions(&lua, &config_table, "filters")
410 .map_err(|e| anyhow::anyhow!("Failed to extract filter functions: {}", e))?;
411
412 let functions = extract_functions(&lua, &config_table, "functions")
414 .map_err(|e| anyhow::anyhow!("Failed to extract custom functions: {}", e))?;
415
416 let computed_pages = if let Ok(func) = config_table.get::<Function>("computed_pages") {
418 Some(
419 lua.create_registry_value(func)
420 .map_err(|e| anyhow::anyhow!("Failed to store computed_pages: {}", e))?,
421 )
422 } else {
423 None
424 };
425
426 let hooks: Option<Table> = config_table.get("hooks").ok();
428 let before_build = if let Some(ref h) = hooks {
429 h.get::<Function>("before_build")
430 .ok()
431 .map(|f| lua.create_registry_value(f))
432 .transpose()
433 .map_err(|e| anyhow::anyhow!("Failed to store before_build hook: {}", e))?
434 } else {
435 None
436 };
437 let after_build = if let Some(ref h) = hooks {
438 h.get::<Function>("after_build")
439 .ok()
440 .map(|f| lua.create_registry_value(f))
441 .transpose()
442 .map_err(|e| anyhow::anyhow!("Failed to store after_build hook: {}", e))?
443 } else {
444 None
445 };
446
447 Ok(Config {
448 data,
449 lua,
450 computed,
451 filters,
452 functions,
453 computed_pages,
454 sort_fns,
455 before_build,
456 after_build,
457 })
458 }
459
460 pub fn call_computed(&self, name: &str, sections_json: &str) -> Result<serde_json::Value> {
462 let key = self
463 .computed
464 .get(name)
465 .with_context(|| format!("Computed function '{}' not found", name))?;
466
467 let func: Function = self
468 .lua
469 .registry_value(key)
470 .map_err(|e| anyhow::anyhow!("Failed to get computed function: {}", e))?;
471
472 let json_value: serde_json::Value = serde_json::from_str(sections_json)
473 .map_err(|e| anyhow::anyhow!("Invalid JSON: {}", e))?;
474 let sections: Value = self
475 .lua
476 .to_value(&json_value)
477 .map_err(|e| anyhow::anyhow!("Failed to convert to Lua: {}", e))?;
478
479 let result: Value = func
480 .call(sections)
481 .map_err(|e| anyhow::anyhow!("Failed to call computed '{}': {}", name, e))?;
482 let json_value: serde_json::Value = self
483 .lua
484 .from_value(result)
485 .map_err(|e| anyhow::anyhow!("Failed to convert result: {}", e))?;
486
487 Ok(json_value)
488 }
489
490 pub fn has_sort_fn(&self, section_name: &str) -> bool {
492 self.sort_fns.contains_key(section_name)
493 }
494
495 pub fn call_sort_fn(
497 &self,
498 section_name: &str,
499 a_json: &serde_json::Value,
500 b_json: &serde_json::Value,
501 ) -> Result<std::cmp::Ordering> {
502 let key = self
503 .sort_fns
504 .get(section_name)
505 .with_context(|| format!("Sort function for '{}' not found", section_name))?;
506
507 let func: Function = self
508 .lua
509 .registry_value(key)
510 .map_err(|e| anyhow::anyhow!("Failed to get sort function: {}", e))?;
511
512 let a: Value = self
513 .lua
514 .to_value(a_json)
515 .map_err(|e| anyhow::anyhow!("Failed to convert a to Lua: {}", e))?;
516 let b: Value = self
517 .lua
518 .to_value(b_json)
519 .map_err(|e| anyhow::anyhow!("Failed to convert b to Lua: {}", e))?;
520
521 let result: i32 = func
522 .call((a, b))
523 .map_err(|e| anyhow::anyhow!("Sort function failed: {}", e))?;
524
525 Ok(match result {
526 n if n < 0 => std::cmp::Ordering::Less,
527 n if n > 0 => std::cmp::Ordering::Greater,
528 _ => std::cmp::Ordering::Equal,
529 })
530 }
531
532 pub fn computed_names(&self) -> Vec<&str> {
534 self.computed.keys().map(|s| s.as_str()).collect()
535 }
536
537 pub fn filter_names(&self) -> Vec<&str> {
539 self.filters.keys().map(|s| s.as_str()).collect()
540 }
541
542 pub fn has_computed_pages(&self) -> bool {
544 self.computed_pages.is_some()
545 }
546
547 pub fn call_before_build(&self) -> Result<()> {
549 if let Some(ref key) = self.before_build {
550 let func: Function = self
551 .lua
552 .registry_value(key)
553 .map_err(|e| anyhow::anyhow!("Failed to get before_build: {}", e))?;
554 func.call::<()>(())
555 .map_err(|e| anyhow::anyhow!("before_build hook failed: {}", e))?;
556 }
557 Ok(())
558 }
559
560 pub fn call_after_build(&self) -> Result<()> {
562 if let Some(ref key) = self.after_build {
563 let func: Function = self
564 .lua
565 .registry_value(key)
566 .map_err(|e| anyhow::anyhow!("Failed to get after_build: {}", e))?;
567 func.call::<()>(())
568 .map_err(|e| anyhow::anyhow!("after_build hook failed: {}", e))?;
569 }
570 Ok(())
571 }
572
573 pub fn call_filter(&self, name: &str, value: &str) -> Result<String> {
575 let key = self
576 .filters
577 .get(name)
578 .with_context(|| format!("Filter '{}' not found", name))?;
579
580 let func: Function = self
581 .lua
582 .registry_value(key)
583 .map_err(|e| anyhow::anyhow!("Failed to get filter: {}", e))?;
584
585 let result: String = func
586 .call(value.to_string())
587 .map_err(|e| anyhow::anyhow!("Filter '{}' failed: {}", name, e))?;
588
589 Ok(result)
590 }
591
592 pub fn call_function(
594 &self,
595 name: &str,
596 args: Vec<serde_json::Value>,
597 ) -> Result<serde_json::Value> {
598 let key = self
599 .functions
600 .get(name)
601 .with_context(|| format!("Function '{}' not found", name))?;
602
603 let func: Function = self
604 .lua
605 .registry_value(key)
606 .map_err(|e| anyhow::anyhow!("Failed to get function: {}", e))?;
607
608 let lua_args: Vec<Value> = args
609 .into_iter()
610 .map(|v| self.lua.to_value(&v))
611 .collect::<mlua::Result<Vec<_>>>()
612 .map_err(|e| anyhow::anyhow!("Failed to convert args: {}", e))?;
613
614 let result: Value = func
615 .call(mlua::MultiValue::from_iter(lua_args))
616 .map_err(|e| anyhow::anyhow!("Function '{}' failed: {}", name, e))?;
617
618 let json_result: serde_json::Value = self
619 .lua
620 .from_value(result)
621 .map_err(|e| anyhow::anyhow!("Failed to convert result: {}", e))?;
622
623 Ok(json_result)
624 }
625
626 pub fn function_names(&self) -> Vec<&str> {
628 self.functions.keys().map(|s| s.as_str()).collect()
629 }
630
631 pub fn call_computed_pages(&self, sections_json: &str) -> Result<Vec<ComputedPage>> {
633 let key = match &self.computed_pages {
634 Some(k) => k,
635 None => return Ok(Vec::new()),
636 };
637
638 let func: Function = self
639 .lua
640 .registry_value(key)
641 .map_err(|e| anyhow::anyhow!("Failed to get computed_pages: {}", e))?;
642
643 let json_value: serde_json::Value = serde_json::from_str(sections_json)
644 .map_err(|e| anyhow::anyhow!("Invalid JSON: {}", e))?;
645 let sections: Value = self
646 .lua
647 .to_value(&json_value)
648 .map_err(|e| anyhow::anyhow!("Failed to convert to Lua: {}", e))?;
649
650 let result: Value = func
651 .call(sections)
652 .map_err(|e| anyhow::anyhow!("Failed to call computed_pages: {}", e))?;
653 let pages: Vec<ComputedPage> = self
654 .lua
655 .from_value(result)
656 .map_err(|e| anyhow::anyhow!("Failed to convert result: {}", e))?;
657
658 Ok(pages)
659 }
660}
661
/// Returns `true` when `path` resolves to a location inside `root`.
///
/// `root` is expected to be canonical already (`Config::load` canonicalizes the
/// project root when it can). Resolution works as follows:
///
/// - Existing paths are canonicalized, so symlinks are followed.
/// - For paths that do not exist yet, the nearest existing ancestor is canonicalized
///   and the missing components are appended. This lets callers create new nested
///   files inside the root.
/// - Rejected: missing components containing `..`, and entries that exist but cannot
///   be resolved (such as dangling symlinks), because their final target cannot be
///   verified.
fn is_path_within_root(path: &Path, root: &Path) -> bool {
    use std::path::Component;

    // An empty relative path refers to the current directory.
    fn on_disk(p: &Path) -> &Path {
        if p.as_os_str().is_empty() {
            Path::new(".")
        } else {
            p
        }
    }

    // Walk up until an ancestor exists. `symlink_metadata` does not follow a final
    // symlink, so a dangling link counts as "existing" and is caught below.
    let mut existing = path;
    let mut missing = Vec::new();
    while std::fs::symlink_metadata(on_disk(existing)).is_err() {
        match (existing.parent(), existing.components().next_back()) {
            (Some(parent), Some(last)) => {
                missing.push(last);
                existing = parent;
            }
            _ => return false,
        }
    }

    let Ok(mut resolved) = on_disk(existing).canonicalize() else {
        return false;
    };

    for component in missing.into_iter().rev() {
        match component {
            Component::Normal(name) => resolved.push(name),
            Component::CurDir => {}
            // `..` below a directory that does not exist yet cannot be verified.
            _ => return false,
        }
    }

    resolved.starts_with(root)
}
686
/// Interprets `path` relative to `root` unless it is already absolute.
fn resolve_path(path: &str, root: &Path) -> PathBuf {
    let candidate = Path::new(path);
    if !candidate.is_absolute() {
        return root.join(candidate);
    }
    candidate.to_path_buf()
}
696
697fn register_lua_functions(lua: &Lua, project_root: &Path, sandbox: bool) -> mlua::Result<()> {
699 let globals = lua.globals();
700
701 globals.set("__sandbox_enabled", sandbox)?;
703 globals.set("__project_root", project_root.to_string_lossy().to_string())?;
704
705 let root = project_root.to_path_buf();
706
707 let root_clone = root.clone();
709 let load_json = lua.create_function(move |lua, path: String| {
710 let resolved = resolve_path(&path, &root_clone);
711 if sandbox && !is_path_within_root(&resolved, &root_clone) {
712 return Err(mlua::Error::RuntimeError(format!(
713 "Sandbox: cannot access '{}' outside project directory. Set lua.sandbox = false to disable.",
714 path
715 )));
716 }
717
718 let content = match std::fs::read_to_string(&resolved) {
719 Ok(c) => c,
720 Err(_) => return Ok(Value::Nil),
721 };
722
723 match serde_json::from_str::<serde_json::Value>(&content) {
724 Ok(v) => lua.to_value(&v),
725 Err(_) => Ok(Value::Nil),
726 }
727 })?;
728 globals.set("load_json", load_json)?;
729
730 let root_clone = root.clone();
732 let read_file = lua.create_function(move |lua, path: String| {
733 let resolved = resolve_path(&path, &root_clone);
734 if sandbox && !is_path_within_root(&resolved, &root_clone) {
735 return Err(mlua::Error::RuntimeError(format!(
736 "Sandbox: cannot access '{}' outside project directory. Set lua.sandbox = false to disable.",
737 path
738 )));
739 }
740
741 match std::fs::read_to_string(&resolved) {
742 Ok(content) => Ok(Value::String(lua.create_string(&content)?)),
743 Err(_) => Ok(Value::Nil),
744 }
745 })?;
746 globals.set("read_file", read_file)?;
747
748 let root_clone = root.clone();
750 let file_exists = lua.create_function(move |_, path: String| {
751 let resolved = resolve_path(&path, &root_clone);
752 if sandbox && !is_path_within_root(&resolved, &root_clone) {
753 return Err(mlua::Error::RuntimeError(format!(
754 "Sandbox: cannot access '{}' outside project directory. Set lua.sandbox = false to disable.",
755 path
756 )));
757 }
758 Ok(resolved.exists())
759 })?;
760 globals.set("file_exists", file_exists)?;
761
762 let root_clone = root.clone();
764 let list_files = lua.create_function(move |lua, (path, pattern): (String, Option<String>)| {
765 let resolved = resolve_path(&path, &root_clone);
766 if sandbox && !is_path_within_root(&resolved, &root_clone) {
767 return Err(mlua::Error::RuntimeError(format!(
768 "Sandbox: cannot access '{}' outside project directory. Set lua.sandbox = false to disable.",
769 path
770 )));
771 }
772
773 let pattern = pattern.unwrap_or_else(|| "*".to_string());
774 let glob_pattern = format!("{}/{}", resolved.display(), pattern);
775
776 let mut files = Vec::new();
777 if let Ok(entries) = glob::glob(&glob_pattern) {
778 for entry in entries.flatten() {
779 if sandbox && !is_path_within_root(&entry, &root_clone) {
781 continue;
782 }
783 if entry.is_file() {
784 let table = lua.create_table()?;
785 table.set("path", entry.to_string_lossy().to_string())?;
786 table.set(
787 "name",
788 entry
789 .file_name()
790 .map(|n| n.to_string_lossy().to_string())
791 .unwrap_or_default(),
792 )?;
793 table.set(
794 "stem",
795 entry
796 .file_stem()
797 .map(|n| n.to_string_lossy().to_string())
798 .unwrap_or_default(),
799 )?;
800 table.set(
801 "ext",
802 entry
803 .extension()
804 .map(|n| n.to_string_lossy().to_string())
805 .unwrap_or_default(),
806 )?;
807 files.push(table);
808 }
809 }
810 }
811
812 let result = lua.create_table()?;
813 for (i, file) in files.into_iter().enumerate() {
814 result.set(i + 1, file)?;
815 }
816 Ok(result)
817 })?;
818 globals.set("list_files", list_files)?;
819
820 let root_clone = root.clone();
822 let list_dirs = lua.create_function(move |lua, path: String| {
823 let resolved = resolve_path(&path, &root_clone);
824 if sandbox && !is_path_within_root(&resolved, &root_clone) {
825 return Err(mlua::Error::RuntimeError(format!(
826 "Sandbox: cannot access '{}' outside project directory. Set lua.sandbox = false to disable.",
827 path
828 )));
829 }
830
831 let mut dirs = Vec::new();
832 if let Ok(entries) = std::fs::read_dir(&resolved) {
833 for entry in entries.flatten() {
834 let entry_path = entry.path();
835 if sandbox && !is_path_within_root(&entry_path, &root_clone) {
837 continue;
838 }
839 if entry_path.is_dir()
840 && let Some(name) = entry_path.file_name().and_then(|n| n.to_str())
841 && !name.starts_with('.')
842 {
843 dirs.push(name.to_string());
844 }
845 }
846 }
847 dirs.sort();
848
849 let result = lua.create_table()?;
850 for (i, dir) in dirs.into_iter().enumerate() {
851 result.set(i + 1, dir)?;
852 }
853 Ok(result)
854 })?;
855 globals.set("list_dirs", list_dirs)?;
856
857 let root_clone = root.clone();
859 let write_file = lua.create_function(move |_, (path, content): (String, String)| {
860 let resolved = resolve_path(&path, &root_clone);
861 if sandbox && !is_path_within_root(&resolved, &root_clone) {
862 return Err(mlua::Error::RuntimeError(format!(
863 "Sandbox: cannot write '{}' outside project directory. Set lua.sandbox = false to disable.",
864 path
865 )));
866 }
867
868 if let Some(parent) = resolved.parent() {
870 let _ = std::fs::create_dir_all(parent);
871 }
872 match std::fs::write(&resolved, &content) {
873 Ok(_) => Ok(true),
874 Err(_) => Ok(false),
875 }
876 })?;
877 globals.set("write_file", write_file)?;
878
879 let env_fn = lua.create_function(|lua, name: String| match std::env::var(&name) {
881 Ok(val) => Ok(Value::String(lua.create_string(&val)?)),
882 Err(_) => Ok(Value::Nil),
883 })?;
884 globals.set("env", env_fn)?;
885
886 let print_fn = lua.create_function(|_, args: mlua::Variadic<String>| {
888 let msg = args
889 .iter()
890 .map(|s| s.as_str())
891 .collect::<Vec<_>>()
892 .join("\t");
893 log::info!("[Lua] {}", msg);
894 Ok(())
895 })?;
896 globals.set("print", print_fn)?;
897
898 register_async_helpers(lua)?;
900
901 register_parallel_functions(lua, project_root, sandbox)?;
903
904 Ok(())
905}
906
907fn register_async_helpers(lua: &Lua) -> mlua::Result<()> {
909 let async_code = r#"
911 local async = {}
912
913 -- Create a task from a function (wraps in coroutine)
914 function async.task(fn)
915 return {
916 _co = coroutine.create(fn),
917 _completed = false,
918 _result = nil,
919 }
920 end
921
922 -- Run a task to completion
923 function async.await(task)
924 if task._completed then
925 return task._result
926 end
927 while coroutine.status(task._co) ~= "dead" do
928 local ok, result = coroutine.resume(task._co)
929 if not ok then
930 error(result)
931 end
932 task._result = result
933 end
934 task._completed = true
935 return task._result
936 end
937
938 -- Yield from current task (for cooperative multitasking)
939 function async.yield(value)
940 return coroutine.yield(value)
941 end
942
943 -- Run multiple tasks concurrently (interleaved execution)
944 function async.all(tasks)
945 local results = {}
946 local pending = {}
947
948 for i, task in ipairs(tasks) do
949 pending[i] = task
950 results[i] = nil
951 end
952
953 -- Round-robin execution until all complete
954 local any_pending = true
955 while any_pending do
956 any_pending = false
957 for i, task in ipairs(pending) do
958 if task and coroutine.status(task._co) ~= "dead" then
959 any_pending = true
960 local ok, result = coroutine.resume(task._co)
961 if not ok then
962 error(result)
963 end
964 task._result = result
965 elseif task then
966 results[i] = task._result
967 task._completed = true
968 pending[i] = nil
969 end
970 end
971 end
972
973 return results
974 end
975
976 -- Run tasks and return first completed result
977 function async.race(tasks)
978 while true do
979 for i, task in ipairs(tasks) do
980 if coroutine.status(task._co) ~= "dead" then
981 local ok, result = coroutine.resume(task._co)
982 if not ok then
983 error(result)
984 end
985 if coroutine.status(task._co) == "dead" then
986 task._result = result
987 task._completed = true
988 return result, i
989 end
990 end
991 end
992 end
993 end
994
995 -- Sleep/delay (yields N times for cooperative scheduling)
996 function async.sleep(n)
997 for _ = 1, (n or 1) do
998 coroutine.yield()
999 end
1000 end
1001
1002 return async
1003 "#;
1004
1005 let async_module: Table = lua.load(async_code).eval()?;
1006 lua.globals().set("async", async_module)?;
1007
1008 Ok(())
1009}
1010
1011fn register_parallel_functions(lua: &Lua, project_root: &Path, sandbox: bool) -> mlua::Result<()> {
1013 let parallel = lua.create_table()?;
1014 let root = project_root.to_path_buf();
1015
1016 let root_clone = root.clone();
1018 let load_json_parallel = lua.create_function(move |lua, paths: Table| {
1019 use rayon::prelude::*;
1020
1021 let path_list: Vec<String> = paths
1023 .sequence_values::<String>()
1024 .filter_map(|r| r.ok())
1025 .collect();
1026
1027 let results: Vec<Option<serde_json::Value>> = path_list
1029 .par_iter()
1030 .map(|path| {
1031 let resolved = resolve_path(path, &root_clone);
1032 if sandbox && !is_path_within_root(&resolved, &root_clone) {
1033 return None;
1034 }
1035 std::fs::read_to_string(&resolved)
1036 .ok()
1037 .and_then(|content| serde_json::from_str(&content).ok())
1038 })
1039 .collect();
1040
1041 let result_table = lua.create_table()?;
1043 for (i, result) in results.into_iter().enumerate() {
1044 match result {
1045 Some(v) => result_table.set(i + 1, lua.to_value(&v)?)?,
1046 None => result_table.set(i + 1, Value::Nil)?,
1047 }
1048 }
1049 Ok(result_table)
1050 })?;
1051 parallel.set("load_json", load_json_parallel)?;
1052
1053 let root_clone = root.clone();
1055 let read_files_parallel = lua.create_function(move |lua, paths: Table| {
1056 use rayon::prelude::*;
1057
1058 let path_list: Vec<String> = paths
1059 .sequence_values::<String>()
1060 .filter_map(|r| r.ok())
1061 .collect();
1062
1063 let results: Vec<Option<String>> = path_list
1064 .par_iter()
1065 .map(|path| {
1066 let resolved = resolve_path(path, &root_clone);
1067 if sandbox && !is_path_within_root(&resolved, &root_clone) {
1068 return None;
1069 }
1070 std::fs::read_to_string(&resolved).ok()
1071 })
1072 .collect();
1073
1074 let result_table = lua.create_table()?;
1075 for (i, result) in results.into_iter().enumerate() {
1076 match result {
1077 Some(content) => result_table.set(i + 1, lua.create_string(&content)?)?,
1078 None => result_table.set(i + 1, Value::Nil)?,
1079 }
1080 }
1081 Ok(result_table)
1082 })?;
1083 parallel.set("read_files", read_files_parallel)?;
1084
1085 let root_clone = root.clone();
1087 let file_exists_parallel = lua.create_function(move |lua, paths: Table| {
1088 use rayon::prelude::*;
1089
1090 let path_list: Vec<String> = paths
1091 .sequence_values::<String>()
1092 .filter_map(|r| r.ok())
1093 .collect();
1094
1095 let results: Vec<bool> = path_list
1096 .par_iter()
1097 .map(|path| {
1098 let resolved = resolve_path(path, &root_clone);
1099 if sandbox && !is_path_within_root(&resolved, &root_clone) {
1100 return false;
1101 }
1102 resolved.exists()
1103 })
1104 .collect();
1105
1106 let result_table = lua.create_table()?;
1107 for (i, exists) in results.into_iter().enumerate() {
1108 result_table.set(i + 1, exists)?;
1109 }
1110 Ok(result_table)
1111 })?;
1112 parallel.set("file_exists", file_exists_parallel)?;
1113
1114 let map_fn = lua.create_function(|lua, (items, func): (Table, Function)| {
1116 let result_table = lua.create_table()?;
1117 let mut i = 1;
1118 for v in items.sequence_values::<Value>().flatten() {
1119 let res: Value = func.call(v)?;
1120 result_table.set(i, res)?;
1121 i += 1;
1122 }
1123 Ok(result_table)
1124 })?;
1125 parallel.set("map", map_fn)?;
1126
1127 let filter_fn = lua.create_function(|lua, (items, func): (Table, Function)| {
1129 let result_table = lua.create_table()?;
1130 let mut i = 1;
1131 for v in items.sequence_values::<Value>().flatten() {
1132 let keep: bool = func.call(v.clone())?;
1133 if keep {
1134 result_table.set(i, v)?;
1135 i += 1;
1136 }
1137 }
1138 Ok(result_table)
1139 })?;
1140 parallel.set("filter", filter_fn)?;
1141
1142 let reduce_fn =
1144 lua.create_function(|_, (items, initial, func): (Table, Value, Function)| {
1145 let mut acc = initial;
1146 for v in items.sequence_values::<Value>().flatten() {
1147 acc = func.call((acc, v))?;
1148 }
1149 Ok(acc)
1150 })?;
1151 parallel.set("reduce", reduce_fn)?;
1152
1153 lua.globals().set("parallel", parallel)?;
1154 Ok(())
1155}
1156
1157fn extract_functions(
1159 lua: &Lua,
1160 config_table: &Table,
1161 key: &str,
1162) -> mlua::Result<HashMap<String, mlua::RegistryKey>> {
1163 let mut functions = HashMap::new();
1164
1165 if let Ok(table) = config_table.get::<Table>(key) {
1166 for pair in table.pairs::<String, Function>() {
1167 let (name, func) = pair?;
1168 let registry_key = lua.create_registry_value(func)?;
1169 functions.insert(name, registry_key);
1170 }
1171 }
1172
1173 Ok(functions)
1174}
1175
1176fn parse_config(
1178 lua: &Lua,
1179 table: &Table,
1180 sort_fns: &mut HashMap<String, mlua::RegistryKey>,
1181) -> mlua::Result<ConfigData> {
1182 let site = parse_site_config(table)?;
1183 let seo = parse_seo_config(table)?;
1184 let build = parse_build_config(table)?;
1185 let images = parse_images_config(table)?;
1186 let highlight = parse_highlight_config(table)?;
1187 let paths = parse_paths_config(table)?;
1188 let templates = parse_templates_config(table)?;
1189 let permalinks = parse_permalinks_config(table)?;
1190 let encryption = parse_encryption_config(table)?;
1191 let graph = parse_graph_config(table)?;
1192 let rss = parse_rss_config(table)?;
1193 let text = parse_text_config(table)?;
1194 let sections = parse_sections_config(lua, table, sort_fns)?;
1195
1196 Ok(ConfigData {
1197 site,
1198 seo,
1199 build,
1200 images,
1201 highlight,
1202 paths,
1203 templates,
1204 permalinks,
1205 encryption,
1206 graph,
1207 rss,
1208 text,
1209 sections,
1210 })
1211}
1212
1213fn parse_site_config(table: &Table) -> mlua::Result<SiteConfig> {
1214 let site: Table = table.get("site")?;
1215
1216 Ok(SiteConfig {
1217 title: site.get("title").unwrap_or_default(),
1218 description: site.get("description").unwrap_or_default(),
1219 base_url: site.get("base_url").unwrap_or_default(),
1220 author: site.get("author").unwrap_or_default(),
1221 })
1222}
1223
1224fn parse_seo_config(table: &Table) -> mlua::Result<SeoConfig> {
1225 let seo: Table = table.get("seo").unwrap_or_else(|_| table.clone());
1226
1227 Ok(SeoConfig {
1228 twitter_handle: seo.get("twitter_handle").ok(),
1229 default_og_image: seo.get("default_og_image").ok(),
1230 })
1231}
1232
1233fn parse_build_config(table: &Table) -> mlua::Result<BuildConfig> {
1234 let build: Table = table.get("build").unwrap_or_else(|_| table.clone());
1235
1236 Ok(BuildConfig {
1237 output_dir: build
1238 .get("output_dir")
1239 .unwrap_or_else(|_| "dist".to_string()),
1240 minify_css: build.get("minify_css").unwrap_or(true),
1241 css_output: build
1242 .get("css_output")
1243 .unwrap_or_else(|_| "rs.css".to_string()),
1244 })
1245}
1246
1247fn parse_images_config(table: &Table) -> mlua::Result<ImagesConfig> {
1248 let images: Table = table.get("images").unwrap_or_else(|_| table.clone());
1249
1250 Ok(ImagesConfig {
1251 quality: images.get("quality").unwrap_or(85.0),
1252 scale_factor: images.get("scale_factor").unwrap_or(1.0),
1253 })
1254}
1255
1256fn parse_highlight_config(table: &Table) -> mlua::Result<HighlightConfig> {
1257 let highlight: Table = table.get("highlight").unwrap_or_else(|_| table.clone());
1258
1259 let names: Vec<String> = highlight
1260 .get::<Table>("names")
1261 .map(|t| {
1262 t.sequence_values::<String>()
1263 .filter_map(|r| r.ok())
1264 .collect()
1265 })
1266 .unwrap_or_default();
1267
1268 Ok(HighlightConfig {
1269 names,
1270 class: highlight.get("class").unwrap_or_else(|_| "me".to_string()),
1271 })
1272}
1273
1274fn parse_paths_config(table: &Table) -> mlua::Result<PathsConfig> {
1275 let paths: Table = table.get("paths").unwrap_or_else(|_| table.clone());
1276
1277 let exclude: Vec<String> = paths
1278 .get::<Table>("exclude")
1279 .map(|t| {
1280 t.sequence_values::<String>()
1281 .filter_map(|r| r.ok())
1282 .collect()
1283 })
1284 .unwrap_or_default();
1285
1286 Ok(PathsConfig {
1287 content: paths
1288 .get("content")
1289 .unwrap_or_else(|_| "content".to_string()),
1290 styles: paths.get("styles").unwrap_or_else(|_| "styles".to_string()),
1291 static_files: paths
1292 .get("static_files")
1293 .unwrap_or_else(|_| "static".to_string()),
1294 templates: paths
1295 .get("templates")
1296 .unwrap_or_else(|_| "templates".to_string()),
1297 home: paths.get("home").unwrap_or_else(|_| "index.md".to_string()),
1298 exclude,
1299 exclude_defaults: paths.get("exclude_defaults").unwrap_or(true),
1300 respect_gitignore: paths.get("respect_gitignore").unwrap_or(true),
1301 })
1302}
1303
1304fn parse_templates_config(table: &Table) -> mlua::Result<TemplatesConfig> {
1305 let mut sections = HashMap::new();
1306
1307 if let Ok(templates) = table.get::<Table>("templates") {
1308 for (k, v) in templates.pairs::<String, String>().flatten() {
1309 sections.insert(k, v);
1310 }
1311 }
1312
1313 Ok(TemplatesConfig { sections })
1314}
1315
1316fn parse_permalinks_config(table: &Table) -> mlua::Result<PermalinksConfig> {
1317 let mut sections = HashMap::new();
1318
1319 if let Ok(permalinks) = table.get::<Table>("permalinks") {
1320 for (k, v) in permalinks.pairs::<String, String>().flatten() {
1321 sections.insert(k, v);
1322 }
1323 }
1324
1325 Ok(PermalinksConfig { sections })
1326}
1327
1328fn parse_encryption_config(table: &Table) -> mlua::Result<EncryptionConfig> {
1329 let encryption: Table = table.get("encryption").unwrap_or_else(|_| table.clone());
1330
1331 Ok(EncryptionConfig {
1332 password_command: encryption.get("password_command").ok(),
1333 password: encryption.get("password").ok(),
1334 })
1335}
1336
1337fn parse_graph_config(table: &Table) -> mlua::Result<GraphConfig> {
1338 let graph: Table = table.get("graph").unwrap_or_else(|_| table.clone());
1339
1340 Ok(GraphConfig {
1341 enabled: graph.get("enabled").unwrap_or(true),
1342 template: graph
1343 .get("template")
1344 .unwrap_or_else(|_| "graph.html".to_string()),
1345 path: graph.get("path").unwrap_or_else(|_| "graph".to_string()),
1346 })
1347}
1348
1349fn parse_rss_config(table: &Table) -> mlua::Result<RssConfig> {
1350 let rss: Table = table.get("rss").unwrap_or_else(|_| table.clone());
1351
1352 let sections: Vec<String> = rss
1353 .get::<Table>("sections")
1354 .map(|t| {
1355 t.sequence_values::<String>()
1356 .filter_map(|r| r.ok())
1357 .collect()
1358 })
1359 .unwrap_or_default();
1360
1361 Ok(RssConfig {
1362 enabled: rss.get("enabled").unwrap_or(true),
1363 filename: rss
1364 .get("filename")
1365 .unwrap_or_else(|_| "rss.xml".to_string()),
1366 sections,
1367 limit: rss.get("limit").unwrap_or(20),
1368 exclude_encrypted_blocks: rss.get("exclude_encrypted_blocks").unwrap_or(false),
1369 })
1370}
1371
1372fn parse_text_config(table: &Table) -> mlua::Result<TextConfig> {
1373 let text: Table = table.get("text").unwrap_or_else(|_| table.clone());
1374
1375 let sections: Vec<String> = text
1376 .get::<Table>("sections")
1377 .map(|t| {
1378 t.sequence_values::<String>()
1379 .filter_map(|r| r.ok())
1380 .collect()
1381 })
1382 .unwrap_or_default();
1383
1384 Ok(TextConfig {
1385 enabled: text.get("enabled").unwrap_or(false),
1386 sections,
1387 exclude_encrypted: text.get("exclude_encrypted").unwrap_or(false),
1388 include_home: text.get("include_home").unwrap_or(true),
1389 })
1390}
1391
1392fn parse_sections_config(
1393 lua: &Lua,
1394 table: &Table,
1395 sort_fns: &mut HashMap<String, mlua::RegistryKey>,
1396) -> mlua::Result<SectionsConfig> {
1397 let mut sections = HashMap::new();
1398
1399 if let Ok(sections_table) = table.get::<Table>("sections") {
1400 for (name, section_table) in sections_table.pairs::<String, Table>().flatten() {
1401 let iterate = section_table
1402 .get("iterate")
1403 .unwrap_or_else(|_| "files".to_string());
1404
1405 if let Ok(func) = section_table.get::<mlua::Function>("sort_by") {
1407 let key = lua.create_registry_value(func)?;
1408 sort_fns.insert(name.clone(), key);
1409 }
1410
1411 sections.insert(name, SectionConfig { iterate });
1412 }
1413 }
1414
1415 Ok(SectionsConfig { sections })
1416}
1417
#[cfg(test)]
mod tests {
    use super::*;

    /// Root directory handed to `register_lua_functions`. The tests below
    /// expect it to contain `Cargo.toml`.
    fn test_project_root() -> PathBuf {
        std::env::current_dir().unwrap()
    }

    /// Creates a fresh Lua state with the config helpers registered against
    /// `test_project_root()`, with or without the sandbox.
    fn lua_with_helpers(sandboxed: bool) -> Lua {
        let lua = Lua::new();
        let root = test_project_root();
        register_lua_functions(&lua, &root, sandboxed).unwrap();
        lua
    }

    #[test]
    fn test_minimal_lua_config() {
        let lua = lua_with_helpers(false);

        let source = r#"
            return {
                site = {
                    title = "Test Site",
                    description = "A test site",
                    base_url = "https://example.com",
                    author = "Test Author",
                },
                build = {
                    output_dir = "dist",
                },
            }
        "#;

        let root_table: Table = lua.load(source).eval().unwrap();
        let mut sort_fns = HashMap::new();
        let parsed = parse_config(&lua, &root_table, &mut sort_fns).unwrap();

        assert_eq!(parsed.site.title, "Test Site");
        assert_eq!(parsed.site.base_url, "https://example.com");
        assert_eq!(parsed.build.output_dir, "dist");
    }

    #[test]
    fn test_lua_config_with_sections() {
        let lua = lua_with_helpers(false);

        let source = r#"
            return {
                site = {
                    title = "Test",
                    description = "Test",
                    base_url = "https://example.com",
                    author = "Test",
                },
                build = { output_dir = "dist" },
                sections = {
                    problems = { iterate = "directories" },
                    blog = { iterate = "files" },
                },
            }
        "#;

        let root_table: Table = lua.load(source).eval().unwrap();
        let mut sort_fns = HashMap::new();
        let parsed = parse_config(&lua, &root_table, &mut sort_fns).unwrap();

        let problems = parsed
            .sections
            .sections
            .get("problems")
            .expect("`problems` section should be parsed");
        assert_eq!(problems.iterate, "directories");

        let blog = parsed
            .sections
            .sections
            .get("blog")
            .expect("`blog` section should be parsed");
        assert_eq!(blog.iterate, "files");
    }

    #[test]
    fn test_lua_helper_functions() {
        let lua = lua_with_helpers(false);

        let present: bool = lua.load("return file_exists('Cargo.toml')").eval().unwrap();
        assert!(present);

        let absent: bool = lua
            .load("return file_exists('nonexistent.file')")
            .eval()
            .unwrap();
        assert!(!absent);
    }

    #[test]
    fn test_sandbox_blocks_outside_access() {
        let lua = lua_with_helpers(true);

        // Both an absolute path and a parent-relative path escape the root.
        for script in [
            "return read_file('/etc/passwd')",
            "return read_file('../some_file')",
        ] {
            let outcome = lua.load(script).eval::<Value>();
            assert!(outcome.is_err(), "sandbox should reject: {script}");
        }
    }

    #[test]
    fn test_sandbox_allows_project_access() {
        let lua = lua_with_helpers(true);

        let present: bool = lua.load("return file_exists('Cargo.toml')").eval().unwrap();
        assert!(present);

        let contents = lua.load("return read_file('Cargo.toml')").eval::<Value>();
        assert!(contents.is_ok());
    }
}