// lust/modules/embedded.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec::Vec,
5};
6use hashbrown::{HashMap, HashSet};
7#[cfg(feature = "std")]
8use std::path::PathBuf;
9
10use crate::{
11    ast::{ItemKind, UseTree},
12    lexer::Lexer,
13    modules::{LoadedModule, ModuleExports, ModuleImports, Program},
14    parser::Parser,
15    LustError, Result,
16};
17
/// One module entry baked into the binary.
///
/// Entries are consumed by `build_directory_map` (grouping by `parent`) and by
/// `load_program_from_embedded` (parsing the entries that carry `source`).
#[derive(Clone, Debug)]
pub struct EmbeddedModule<'a> {
    /// Dotted module name, used as the key when modules are registered and looked up.
    pub module: &'a str,
    /// Dotted name of the enclosing entry; `None` is grouped under the empty key `""`.
    pub parent: Option<&'a str>,
    /// Module source text; `None` means the name is known but there is nothing to parse.
    pub source: Option<&'a str>,
}
24
25pub fn build_directory_map(entries: &[EmbeddedModule<'_>]) -> HashMap<String, Vec<String>> {
26    let mut map: HashMap<String, Vec<String>> = HashMap::new();
27    for entry in entries {
28        let parent = entry.parent.unwrap_or("");
29        map.entry(parent.to_string())
30            .or_default()
31            .push(entry.module.to_string());
32    }
33
34    for children in map.values_mut() {
35        children.sort();
36    }
37    map
38}
39
40pub fn load_program_from_embedded(
41    entries: &[EmbeddedModule<'_>],
42    entry_module: &str,
43) -> Result<Program> {
44    let mut module_names: HashSet<String> = entries.iter().map(|e| e.module.to_string()).collect();
45
46    let mut registry: HashMap<String, LoadedModule> = HashMap::new();
47    for entry in entries {
48        if let Some(source) = entry.source {
49            let module = parse_module(entry.module, source)?;
50            registry.insert(entry.module.to_string(), module);
51        } else {
52            module_names.insert(entry.module.to_string());
53        }
54    }
55
56    let dependency_map = build_dependency_map(&registry, &module_names);
57    let mut ordered = Vec::new();
58    let mut visited = HashSet::new();
59    let mut stack = HashSet::new();
60
61    for module in registry.keys().cloned().collect::<Vec<_>>() {
62        visit_dependencies(
63            &module,
64            &dependency_map,
65            &mut visited,
66            &mut stack,
67            &mut ordered,
68        )?;
69    }
70
71    for module in ordered {
72        finalize_module(&module_names, &mut registry, &module)?;
73    }
74
75    let mut modules: Vec<LoadedModule> = registry.into_values().collect();
76    modules.sort_by(|a, b| a.path.cmp(&b.path));
77
78    Ok(Program {
79        modules,
80        entry_module: entry_module.to_string(),
81    })
82}
83
84fn parse_module(module: &str, source: &str) -> Result<LoadedModule> {
85    let mut lexer = Lexer::new(source);
86    let tokens = lexer.tokenize()?;
87    let mut parser = Parser::new(tokens);
88    let items = parser.parse()?;
89
90    Ok(LoadedModule {
91        path: module.to_string(),
92        items,
93        imports: ModuleImports::default(),
94        exports: ModuleExports::default(),
95        init_function: None,
96        #[cfg(feature = "std")]
97        source_path: PathBuf::new(),
98    })
99}
100
101fn build_dependency_map(
102    modules: &HashMap<String, LoadedModule>,
103    module_names: &HashSet<String>,
104) -> HashMap<String, Vec<String>> {
105    let mut deps = HashMap::new();
106    for (name, module) in modules {
107        let collected = collect_dependencies(module, module_names);
108        deps.insert(name.clone(), collected);
109    }
110
111    deps
112}
113
114fn collect_dependencies(module: &LoadedModule, module_names: &HashSet<String>) -> Vec<String> {
115    let mut deps = HashSet::new();
116    for item in &module.items {
117        match &item.kind {
118            ItemKind::Use { public: _, tree } => {
119                collect_deps_from_use(tree, module_names, &mut deps);
120            }
121            ItemKind::Script(stmts) => {
122                for stmt in stmts {
123                    collect_deps_from_lua_require_stmt(stmt, &mut deps);
124                }
125            }
126            ItemKind::Function(func) => {
127                for stmt in &func.body {
128                    collect_deps_from_lua_require_stmt(stmt, &mut deps);
129                }
130            }
131            ItemKind::Const { value, .. } | ItemKind::Static { value, .. } => {
132                collect_deps_from_lua_require_expr(value, &mut deps);
133            }
134            ItemKind::Impl(impl_block) => {
135                for method in &impl_block.methods {
136                    for stmt in &method.body {
137                        collect_deps_from_lua_require_stmt(stmt, &mut deps);
138                    }
139                }
140            }
141            ItemKind::Trait(trait_def) => {
142                for method in &trait_def.methods {
143                    if let Some(default_impl) = &method.default_impl {
144                        for stmt in default_impl {
145                            collect_deps_from_lua_require_stmt(stmt, &mut deps);
146                        }
147                    }
148                }
149            }
150            _ => {}
151        }
152    }
153
154    deps.into_iter().collect()
155}
156
157fn collect_deps_from_lua_require_stmt(stmt: &crate::ast::Stmt, deps: &mut HashSet<String>) {
158    use crate::ast::StmtKind;
159    match &stmt.kind {
160        StmtKind::Local { initializer, .. } => {
161            if let Some(values) = initializer {
162                for expr in values {
163                    collect_deps_from_lua_require_expr(expr, deps);
164                }
165            }
166        }
167        StmtKind::Assign { targets, values } => {
168            for expr in targets {
169                collect_deps_from_lua_require_expr(expr, deps);
170            }
171            for expr in values {
172                collect_deps_from_lua_require_expr(expr, deps);
173            }
174        }
175        StmtKind::CompoundAssign { target, value, .. } => {
176            collect_deps_from_lua_require_expr(target, deps);
177            collect_deps_from_lua_require_expr(value, deps);
178        }
179        StmtKind::Expr(expr) => collect_deps_from_lua_require_expr(expr, deps),
180        StmtKind::If {
181            condition,
182            then_block,
183            elseif_branches,
184            else_block,
185        } => {
186            collect_deps_from_lua_require_expr(condition, deps);
187            for stmt in then_block {
188                collect_deps_from_lua_require_stmt(stmt, deps);
189            }
190            for (cond, block) in elseif_branches {
191                collect_deps_from_lua_require_expr(cond, deps);
192                for stmt in block {
193                    collect_deps_from_lua_require_stmt(stmt, deps);
194                }
195            }
196            if let Some(block) = else_block {
197                for stmt in block {
198                    collect_deps_from_lua_require_stmt(stmt, deps);
199                }
200            }
201        }
202        StmtKind::While { condition, body } => {
203            collect_deps_from_lua_require_expr(condition, deps);
204            for stmt in body {
205                collect_deps_from_lua_require_stmt(stmt, deps);
206            }
207        }
208        StmtKind::ForNumeric {
209            start, end, step, body, ..
210        } => {
211            collect_deps_from_lua_require_expr(start, deps);
212            collect_deps_from_lua_require_expr(end, deps);
213            if let Some(step) = step {
214                collect_deps_from_lua_require_expr(step, deps);
215            }
216            for stmt in body {
217                collect_deps_from_lua_require_stmt(stmt, deps);
218            }
219        }
220        StmtKind::ForIn { iterator, body, .. } => {
221            collect_deps_from_lua_require_expr(iterator, deps);
222            for stmt in body {
223                collect_deps_from_lua_require_stmt(stmt, deps);
224            }
225        }
226        StmtKind::Return(values) => {
227            for expr in values {
228                collect_deps_from_lua_require_expr(expr, deps);
229            }
230        }
231        StmtKind::Block(stmts) => {
232            for stmt in stmts {
233                collect_deps_from_lua_require_stmt(stmt, deps);
234            }
235        }
236        StmtKind::Break | StmtKind::Continue => {}
237    }
238}
239
240fn collect_deps_from_lua_require_expr(expr: &crate::ast::Expr, deps: &mut HashSet<String>) {
241    use crate::ast::{ExprKind, Literal};
242    match &expr.kind {
243        ExprKind::Call { callee, args } => {
244            if is_lua_require_callee(callee) {
245                if let Some(name) = args.get(0).and_then(extract_lua_require_name) {
246                    if !is_lua_builtin_module_name(&name) {
247                        deps.insert(name);
248                    }
249                }
250            }
251            collect_deps_from_lua_require_expr(callee, deps);
252            for arg in args {
253                collect_deps_from_lua_require_expr(arg, deps);
254            }
255        }
256        ExprKind::MethodCall { receiver, args, .. } => {
257            collect_deps_from_lua_require_expr(receiver, deps);
258            for arg in args {
259                collect_deps_from_lua_require_expr(arg, deps);
260            }
261        }
262        ExprKind::Binary { left, right, .. } => {
263            collect_deps_from_lua_require_expr(left, deps);
264            collect_deps_from_lua_require_expr(right, deps);
265        }
266        ExprKind::Unary { operand, .. } => collect_deps_from_lua_require_expr(operand, deps),
267        ExprKind::FieldAccess { object, .. } => collect_deps_from_lua_require_expr(object, deps),
268        ExprKind::Index { object, index } => {
269            collect_deps_from_lua_require_expr(object, deps);
270            collect_deps_from_lua_require_expr(index, deps);
271        }
272        ExprKind::Array(elements) | ExprKind::Tuple(elements) => {
273            for element in elements {
274                collect_deps_from_lua_require_expr(element, deps);
275            }
276        }
277        ExprKind::Map(entries) => {
278            for (k, v) in entries {
279                collect_deps_from_lua_require_expr(k, deps);
280                collect_deps_from_lua_require_expr(v, deps);
281            }
282        }
283        ExprKind::StructLiteral { fields, .. } => {
284            for field in fields {
285                collect_deps_from_lua_require_expr(&field.value, deps);
286            }
287        }
288        ExprKind::EnumConstructor { args, .. } => {
289            for arg in args {
290                collect_deps_from_lua_require_expr(arg, deps);
291            }
292        }
293        ExprKind::Lambda { body, .. } => collect_deps_from_lua_require_expr(body, deps),
294        ExprKind::Paren(inner) => collect_deps_from_lua_require_expr(inner, deps),
295        ExprKind::Cast { expr, .. } => collect_deps_from_lua_require_expr(expr, deps),
296        ExprKind::TypeCheck { expr, .. } => collect_deps_from_lua_require_expr(expr, deps),
297        ExprKind::IsPattern { expr, .. } => collect_deps_from_lua_require_expr(expr, deps),
298        ExprKind::If {
299            condition,
300            then_branch,
301            else_branch,
302        } => {
303            collect_deps_from_lua_require_expr(condition, deps);
304            collect_deps_from_lua_require_expr(then_branch, deps);
305            if let Some(other) = else_branch {
306                collect_deps_from_lua_require_expr(other, deps);
307            }
308        }
309        ExprKind::Block(stmts) => {
310            for stmt in stmts {
311                collect_deps_from_lua_require_stmt(stmt, deps);
312            }
313        }
314        ExprKind::Return(values) => {
315            for value in values {
316                collect_deps_from_lua_require_expr(value, deps);
317            }
318        }
319        ExprKind::Range { start, end, .. } => {
320            collect_deps_from_lua_require_expr(start, deps);
321            collect_deps_from_lua_require_expr(end, deps);
322        }
323        ExprKind::Literal(Literal::String(_))
324        | ExprKind::Literal(_)
325        | ExprKind::Identifier(_) => {}
326    }
327}
328
/// Returns `true` for names of Lua standard-library modules, which a `require(...)` can
/// load without any embedded module backing it.
fn is_lua_builtin_module_name(name: &str) -> bool {
    const LUA_STDLIB: [&str; 9] = [
        "coroutine", "debug", "io", "math", "os", "package", "string", "table", "utf8",
    ];
    LUA_STDLIB.iter().any(|&builtin| builtin == name)
}
343
344fn is_lua_require_callee(callee: &crate::ast::Expr) -> bool {
345    use crate::ast::ExprKind;
346    match &callee.kind {
347        ExprKind::Identifier(name) => name == "require",
348        ExprKind::FieldAccess { object, field } => {
349            field == "require" && matches!(&object.kind, ExprKind::Identifier(name) if name == "lua")
350        }
351        _ => false,
352    }
353}
354
355fn extract_lua_require_name(expr: &crate::ast::Expr) -> Option<String> {
356    use crate::ast::{ExprKind, Literal};
357    match &expr.kind {
358        ExprKind::Literal(Literal::String(s)) => Some(s.clone()),
359        ExprKind::Call { callee, args } if is_lua_to_value_callee(callee) => args
360            .get(0)
361            .and_then(|arg| match &arg.kind {
362                ExprKind::Literal(Literal::String(s)) => Some(s.clone()),
363                _ => None,
364            }),
365        _ => None,
366    }
367}
368
369fn is_lua_to_value_callee(callee: &crate::ast::Expr) -> bool {
370    use crate::ast::ExprKind;
371    matches!(
372        &callee.kind,
373        ExprKind::FieldAccess { object, field }
374            if field == "to_value" && matches!(&object.kind, ExprKind::Identifier(name) if name == "lua")
375    )
376}
377
378fn collect_deps_from_use(
379    tree: &UseTree,
380    module_names: &HashSet<String>,
381    deps: &mut HashSet<String>,
382) {
383    match tree {
384        UseTree::Path { path, .. } => {
385            let full = path.join(".");
386            if module_names.contains(&full) {
387                deps.insert(full);
388            } else if path.len() > 1 {
389                deps.insert(path[..path.len() - 1].join("."));
390            }
391        }
392        UseTree::Group { prefix, items } => {
393            let module = prefix.join(".");
394            if !module.is_empty() {
395                deps.insert(module);
396            }
397
398            for item in items {
399                if item.path.len() > 1 {
400                    let mut combined = prefix.clone();
401                    combined.extend(item.path[..item.path.len() - 1].iter().cloned());
402                    let module_path = combined.join(".");
403                    if !module_path.is_empty() {
404                        deps.insert(module_path);
405                    }
406                }
407            }
408        }
409        UseTree::Glob { prefix } => {
410            deps.insert(prefix.join("."));
411        }
412    }
413}
414
415fn visit_dependencies(
416    module: &str,
417    deps: &HashMap<String, Vec<String>>,
418    visited: &mut HashSet<String>,
419    stack: &mut HashSet<String>,
420    ordered: &mut Vec<String>,
421) -> Result<()> {
422    if visited.contains(module) {
423        return Ok(());
424    }
425
426    if !stack.insert(module.to_string()) {
427        return Err(LustError::Unknown(format!(
428            "Cyclic dependency detected while loading module '{}'",
429            module
430        )));
431    }
432
433    if let Some(list) = deps.get(module) {
434        for dep in list {
435            visit_dependencies(dep, deps, visited, stack, ordered)?;
436        }
437    }
438
439    stack.remove(module);
440    visited.insert(module.to_string());
441    ordered.push(module.to_string());
442    Ok(())
443}
444
445fn finalize_module(
446    module_names: &HashSet<String>,
447    registry: &mut HashMap<String, LoadedModule>,
448    module_name: &str,
449) -> Result<()> {
450    let mut module = registry
451        .remove(module_name)
452        .ok_or_else(|| LustError::Unknown(format!("Unknown module '{}'", module_name)))?;
453
454    let registry_ref = ModuleRegistryView { modules: registry };
455    for item in &module.items {
456        if let ItemKind::Use { tree, .. } = &item.kind {
457            process_use_tree(&registry_ref, module_names, tree, &mut module.imports)?;
458        }
459    }
460
461    for item in &module.items {
462        if let ItemKind::Use { public: true, tree } = &item.kind {
463            apply_reexport(&registry_ref, module_names, tree, &mut module.exports)?;
464        }
465    }
466
467    let tail = simple_tail(module_name);
468    module
469        .imports
470        .module_aliases
471        .entry(tail.to_string())
472        .or_insert_with(|| module_name.to_string());
473
474    registry.insert(module_name.to_string(), module);
475    Ok(())
476}
477
478struct ModuleRegistryView<'a> {
479    modules: &'a HashMap<String, LoadedModule>,
480}
481
482impl<'a> ModuleRegistryView<'a> {
483    fn get(&self, name: &str) -> Option<&'a LoadedModule> {
484        self.modules.get(name)
485    }
486}
487
488fn process_use_tree(
489    registry: &ModuleRegistryView<'_>,
490    module_names: &HashSet<String>,
491    tree: &UseTree,
492    imports: &mut ModuleImports,
493) -> Result<()> {
494    match tree {
495        UseTree::Path { path, alias, .. } => {
496            let full = path.join(".");
497            if module_names.contains(&full) {
498                let alias_name = alias
499                    .clone()
500                    .unwrap_or_else(|| path.last().unwrap().clone());
501                imports.module_aliases.insert(alias_name, full);
502            } else if path.len() > 1 {
503                let module = path[..path.len() - 1].join(".");
504                let item = path.last().unwrap().clone();
505                let alias_name = alias.clone().unwrap_or_else(|| item.clone());
506                let classification = classify_import_target(registry, &module, &item);
507                let fq = format!("{}.{}", module, item);
508                if classification.import_value {
509                    imports
510                        .function_aliases
511                        .insert(alias_name.clone(), fq.clone());
512                }
513
514                if classification.import_type {
515                    imports.type_aliases.insert(alias_name, fq);
516                }
517            }
518        }
519        UseTree::Group { prefix, items } => {
520            for item in items {
521                if item.path.is_empty() {
522                    continue;
523                }
524
525                let alias_name = item
526                    .alias
527                    .clone()
528                    .unwrap_or_else(|| item.path.last().unwrap().clone());
529                let mut full_segments = prefix.clone();
530                full_segments.extend(item.path.clone());
531                let full = full_segments.join(".");
532                if module_names.contains(&full) {
533                    imports.module_aliases.insert(alias_name, full);
534                    continue;
535                }
536
537                let mut module_segments = full_segments.clone();
538                let item_name = module_segments.pop().unwrap();
539                let module_path = module_segments.join(".");
540                let fq_name = if module_path.is_empty() {
541                    item_name.clone()
542                } else {
543                    format!("{}.{}", module_path, item_name)
544                };
545                let classification = classify_import_target(registry, &module_path, &item_name);
546                if classification.import_value {
547                    imports
548                        .function_aliases
549                        .insert(alias_name.clone(), fq_name.clone());
550                }
551
552                if classification.import_type {
553                    imports.type_aliases.insert(alias_name, fq_name);
554                }
555            }
556        }
557        UseTree::Glob { prefix } => {
558            let module = prefix.join(".");
559            if let Some(loaded) = registry.get(&module) {
560                for (name, fq) in &loaded.exports.functions {
561                    imports.function_aliases.insert(name.clone(), fq.clone());
562                }
563
564                for (name, fq) in &loaded.exports.types {
565                    imports.type_aliases.insert(name.clone(), fq.clone());
566                }
567            }
568
569            if !module.is_empty() {
570                let alias_name = prefix.last().cloned().unwrap_or_else(|| module.clone());
571                imports.module_aliases.insert(alias_name, module);
572            }
573        }
574    }
575
576    Ok(())
577}
578
579fn apply_reexport(
580    registry: &ModuleRegistryView<'_>,
581    module_names: &HashSet<String>,
582    tree: &UseTree,
583    exports: &mut ModuleExports,
584) -> Result<()> {
585    match tree {
586        UseTree::Path { path, alias, .. } => {
587            if path.len() == 1 {
588                return Ok(());
589            }
590
591            let module = path[..path.len() - 1].join(".");
592            let item = path.last().unwrap().clone();
593            let alias_name = alias.clone().unwrap_or_else(|| item.clone());
594            let fq = format!("{}.{}", module, item);
595            let classification = classify_import_target(registry, &module, &item);
596            if classification.import_type {
597                exports.types.insert(alias_name.clone(), fq.clone());
598            }
599
600            if classification.import_value {
601                exports.functions.insert(alias_name, fq);
602            }
603
604            Ok(())
605        }
606        UseTree::Group { prefix, items } => {
607            for item in items {
608                if item.path.is_empty() {
609                    continue;
610                }
611
612                let mut full_segments = prefix.clone();
613                full_segments.extend(item.path.clone());
614                let full = full_segments.join(".");
615                if module_names.contains(&full) {
616                    continue;
617                }
618
619                let mut module_segments = full_segments.clone();
620                let item_name = module_segments.pop().unwrap();
621                let module_path = module_segments.join(".");
622                let fq_name = if module_path.is_empty() {
623                    item_name.clone()
624                } else {
625                    format!("{}.{}", module_path, item_name)
626                };
627                let alias_name = item
628                    .alias
629                    .clone()
630                    .unwrap_or_else(|| item.path.last().unwrap().clone());
631                let classification = classify_import_target(registry, &module_path, &item_name);
632                if classification.import_type {
633                    exports.types.insert(alias_name.clone(), fq_name.clone());
634                }
635
636                if classification.import_value {
637                    exports.functions.insert(alias_name, fq_name);
638                }
639            }
640
641            Ok(())
642        }
643        UseTree::Glob { prefix } => {
644            let module = prefix.join(".");
645            if let Some(loaded) = registry.get(&module) {
646                for (n, fq) in &loaded.exports.types {
647                    exports.types.insert(n.clone(), fq.clone());
648                }
649
650                for (n, fq) in &loaded.exports.functions {
651                    exports.functions.insert(n.clone(), fq.clone());
652                }
653            }
654
655            Ok(())
656        }
657    }
658}
659
/// Says how an imported name should be bound: as a value (function alias), as a type,
/// or both.
#[derive(Copy, Clone)]
struct ImportResolution {
    /// Bind the name in the function/value alias table.
    import_value: bool,
    /// Bind the name in the type alias table.
    import_type: bool,
}

impl ImportResolution {
    /// Resolution used when the target cannot be narrowed down: bind the name both ways.
    fn both() -> Self {
        ImportResolution {
            import_type: true,
            import_value: true,
        }
    }
}
674
675fn classify_import_target(
676    registry: &ModuleRegistryView<'_>,
677    module_path: &str,
678    item_name: &str,
679) -> ImportResolution {
680    if module_path.is_empty() {
681        return ImportResolution::both();
682    }
683
684    if let Some(module) = registry.get(module_path) {
685        let has_value = module.exports.functions.contains_key(item_name);
686        let has_type = module.exports.types.contains_key(item_name);
687        if has_value || has_type {
688            return ImportResolution {
689                import_value: has_value,
690                import_type: has_type,
691            };
692        }
693    }
694
695    ImportResolution::both()
696}
697
/// Returns the last segment of a dotted module path (`"a.b.c"` -> `"c"`), or the whole
/// path when it contains no `'.'`.
fn simple_tail(module_path: &str) -> &str {
    match module_path.rfind('.') {
        // '.' is a single byte, so `dot + 1` is always a char boundary.
        Some(dot) => &module_path[dot + 1..],
        None => module_path,
    }
}