1use std::{fmt, path::PathBuf};
2
3use anyhow::Result;
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6
7use crate::{
8 errors::{KclError, KclErrorDetails},
9 execution::{EnvironmentRef, PreImportedGeometry},
10 fs::{FileManager, FileSystem},
11 parsing::ast::types::{ImportPath, Node, Program},
12 source_range::SourceRange,
13};
14
/// Numeric identifier of a module.
///
/// The `Default` value (`ModuleId(0)`) is what [`ModuleId::is_top_level`]
/// treats as the top-level module.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize, ts_rs::TS, JsonSchema)]
#[ts(export)]
pub struct ModuleId(u32);
19
20impl ModuleId {
21 pub fn from_usize(id: usize) -> Self {
22 Self(u32::try_from(id).expect("module ID should fit in a u32"))
23 }
24
25 pub fn as_usize(&self) -> usize {
26 usize::try_from(self.0).expect("module ID should fit in a usize")
27 }
28
29 pub fn is_top_level(&self) -> bool {
32 *self == Self::default()
33 }
34}
35
/// Tracks the chain of local modules currently being imported, so that
/// circular imports can be detected.
#[derive(Debug, Clone, Default)]
pub(crate) struct ModuleLoader {
    /// Local module paths currently being loaded, outermost first.
    /// Pushed by `enter_module` and popped by `leave_module`.
    pub import_stack: Vec<PathBuf>,
}
42
43impl ModuleLoader {
44 pub(crate) fn cycle_check(&self, path: &ModulePath, source_range: SourceRange) -> Result<(), KclError> {
45 if self.import_stack.contains(path.expect_path()) {
46 return Err(self.import_cycle_error(path, source_range));
47 }
48 Ok(())
49 }
50
51 pub(crate) fn import_cycle_error(&self, path: &ModulePath, source_range: SourceRange) -> KclError {
52 KclError::ImportCycle(KclErrorDetails {
53 message: format!(
54 "circular import of modules is not allowed: {} -> {}",
55 self.import_stack
56 .iter()
57 .map(|p| p.as_path().to_string_lossy())
58 .collect::<Vec<_>>()
59 .join(" -> "),
60 path,
61 ),
62 source_ranges: vec![source_range],
63 })
64 }
65
66 pub(crate) fn enter_module(&mut self, path: &ModulePath) {
67 if let ModulePath::Local { value: ref path } = path {
68 self.import_stack.push(path.clone());
69 }
70 }
71
72 pub(crate) fn leave_module(&mut self, path: &ModulePath) {
73 if let ModulePath::Local { value: ref path } = path {
74 let popped = self.import_stack.pop().unwrap();
75 assert_eq!(path, &popped);
76 }
77 }
78}
79
80pub(crate) fn read_std(mod_name: &str) -> Option<&'static str> {
81 match mod_name {
82 "prelude" => Some(include_str!("../std/prelude.kcl")),
83 "math" => Some(include_str!("../std/math.kcl")),
84 _ => None,
85 }
86}
87
/// A module together with its ID, where it came from, and its current representation.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ModuleInfo {
    /// The module's identifier.
    pub(crate) id: ModuleId,
    /// Where the module was loaded from.
    pub(crate) path: ModulePath,
    /// The module's representation; holds `ModuleRepr::Dummy` while it has been
    /// moved out with `take_repr` and not yet put back with `restore_repr`.
    pub(crate) repr: ModuleRepr,
}
97
98impl ModuleInfo {
99 pub(crate) fn take_repr(&mut self) -> ModuleRepr {
100 let mut result = ModuleRepr::Dummy;
101 std::mem::swap(&mut self.repr, &mut result);
102 result
103 }
104
105 pub(crate) fn restore_repr(&mut self, repr: ModuleRepr) {
106 assert!(matches!(&self.repr, ModuleRepr::Dummy));
107 self.repr = repr;
108 }
109}
110
/// The loaded representation of a module.
// `Kcl` carries a whole parsed program, so variant sizes differ a lot; the lint is
// silenced on purpose rather than boxing the program.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub enum ModuleRepr {
    /// The root module. NOTE(review): presumably the main program — confirm.
    Root,
    /// A KCL module: its parsed program plus an optional environment and list of
    /// names. NOTE(review): presumably filled in once the module has been
    /// evaluated (environment + exported names) — confirm.
    Kcl(Node<Program>, Option<(EnvironmentRef, Vec<String>)>),
    /// A foreign (non-KCL) geometry module that has been pre-imported.
    Foreign(PreImportedGeometry),
    /// Placeholder left in `ModuleInfo::repr` by `ModuleInfo::take_repr` while the
    /// real representation is moved out.
    Dummy,
}
120
/// Identifies where a module comes from.
// NOTE(review): the data-carrying variants hold a `PathBuf` and a `String`, which
// are similar in size, so this allow looks unnecessary; confirm and remove.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize, Hash, ts_rs::TS)]
#[serde(tag = "type")]
pub enum ModulePath {
    /// The main program module (displayed as `main`).
    Main,
    /// A module file on the local file system. `from_import_path` joins it onto
    /// the project directory when one is set.
    Local { value: PathBuf },
    /// A standard-library module, identified by name (e.g. `math` for `std::math`).
    Std { value: String },
}
130
131impl ModulePath {
132 pub(crate) fn expect_path(&self) -> &PathBuf {
133 match self {
134 ModulePath::Local { value: p } => p,
135 _ => unreachable!(),
136 }
137 }
138
139 pub(crate) fn std_path(&self) -> Option<String> {
140 match self {
141 ModulePath::Std { value: p } => Some(p.clone()),
142 _ => None,
143 }
144 }
145
146 pub(crate) async fn source(&self, fs: &FileManager, source_range: SourceRange) -> Result<ModuleSource, KclError> {
147 match self {
148 ModulePath::Local { value: p } => Ok(ModuleSource {
149 source: fs.read_to_string(p, source_range).await?,
150 path: self.clone(),
151 }),
152 ModulePath::Std { value: name } => Ok(ModuleSource {
153 source: read_std(name)
154 .ok_or_else(|| {
155 KclError::Semantic(KclErrorDetails {
156 message: format!("Cannot find standard library module to import: std::{name}."),
157 source_ranges: vec![source_range],
158 })
159 })
160 .map(str::to_owned)?,
161 path: self.clone(),
162 }),
163 ModulePath::Main => unreachable!(),
164 }
165 }
166
167 pub(crate) fn from_import_path(path: &ImportPath, project_directory: &Option<PathBuf>) -> Self {
168 match path {
169 ImportPath::Kcl { filename: path } | ImportPath::Foreign { path } => {
170 let resolved_path = if let Some(project_dir) = project_directory {
171 project_dir.join(path)
172 } else {
173 std::path::PathBuf::from(path)
174 };
175 ModulePath::Local { value: resolved_path }
176 }
177 ImportPath::Std { path } => {
178 assert_eq!(path.len(), 2);
180 assert_eq!(&path[0], "std");
181
182 ModulePath::Std { value: path[1].clone() }
183 }
184 }
185 }
186}
187
188impl fmt::Display for ModulePath {
189 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190 match self {
191 ModulePath::Main => write!(f, "main"),
192 ModulePath::Local { value: path } => path.display().fmt(f),
193 ModulePath::Std { value: s } => write!(f, "std::{s}"),
194 }
195 }
196}
197
/// The source text of a module together with the path it was loaded from.
#[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize, ts_rs::TS)]
pub struct ModuleSource {
    /// Where the source was loaded from.
    pub path: ModulePath,
    /// The module's KCL source text.
    pub source: String,
}