1use std::{fmt, path::PathBuf};
2
3use anyhow::Result;
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6
7use crate::{
8 errors::{KclError, KclErrorDetails},
9 execution::{EnvironmentRef, PreImportedGeometry},
10 fs::{FileManager, FileSystem},
11 parsing::ast::types::{ImportPath, Node, Program},
12 source_range::SourceRange,
13};
14
/// Identifier of a module, stored as a compact `u32` newtype.
///
/// The default ID (`0`) denotes the top-level module; see [`ModuleId::is_top_level`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize, ts_rs::TS, JsonSchema)]
#[ts(export)]
pub struct ModuleId(u32);
19
20impl ModuleId {
21 pub fn from_usize(id: usize) -> Self {
22 Self(u32::try_from(id).expect("module ID should fit in a u32"))
23 }
24
25 pub fn as_usize(&self) -> usize {
26 usize::try_from(self.0).expect("module ID should fit in a usize")
27 }
28
29 pub fn is_top_level(&self) -> bool {
32 *self == Self::default()
33 }
34}
35
36impl std::fmt::Display for ModuleId {
37 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38 write!(f, "{}", self.0)
39 }
40}
41
42#[derive(Debug, Clone, Default)]
43pub(crate) struct ModuleLoader {
44 pub import_stack: Vec<PathBuf>,
47}
48
49impl ModuleLoader {
50 pub(crate) fn cycle_check(&self, path: &ModulePath, source_range: SourceRange) -> Result<(), KclError> {
51 if self.import_stack.contains(path.expect_path()) {
52 return Err(self.import_cycle_error(path, source_range));
53 }
54 Ok(())
55 }
56
57 pub(crate) fn import_cycle_error(&self, path: &ModulePath, source_range: SourceRange) -> KclError {
58 KclError::ImportCycle(KclErrorDetails {
59 message: format!(
60 "circular import of modules is not allowed: {} -> {}",
61 self.import_stack
62 .iter()
63 .map(|p| p.as_path().to_string_lossy())
64 .collect::<Vec<_>>()
65 .join(" -> "),
66 path,
67 ),
68 source_ranges: vec![source_range],
69 })
70 }
71
72 pub(crate) fn enter_module(&mut self, path: &ModulePath) {
73 if let ModulePath::Local { value: ref path } = path {
74 self.import_stack.push(path.clone());
75 }
76 }
77
78 pub(crate) fn leave_module(&mut self, path: &ModulePath) {
79 if let ModulePath::Local { value: ref path } = path {
80 let popped = self.import_stack.pop().unwrap();
81 assert_eq!(path, &popped);
82 }
83 }
84}
85
86pub(crate) fn read_std(mod_name: &str) -> Option<&'static str> {
87 match mod_name {
88 "prelude" => Some(include_str!("../std/prelude.kcl")),
89 "math" => Some(include_str!("../std/math.kcl")),
90 "sketch" => Some(include_str!("../std/sketch.kcl")),
91 _ => None,
92 }
93}
94
/// Bookkeeping for a single module: its ID, where it comes from, and its
/// current representation.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ModuleInfo {
    /// The ID of the module.
    pub(crate) id: ModuleId,
    /// Where the module comes from (main program, local file, or standard library).
    pub(crate) path: ModulePath,
    /// The module's current representation; [`ModuleRepr::Dummy`] while it has
    /// been moved out with [`ModuleInfo::take_repr`].
    pub(crate) repr: ModuleRepr,
}
104
105impl ModuleInfo {
106 pub(crate) fn take_repr(&mut self) -> ModuleRepr {
107 let mut result = ModuleRepr::Dummy;
108 std::mem::swap(&mut self.repr, &mut result);
109 result
110 }
111
112 pub(crate) fn restore_repr(&mut self, repr: ModuleRepr) {
113 assert!(matches!(&self.repr, ModuleRepr::Dummy));
114 self.repr = repr;
115 }
116}
117
/// The loaded form of a module.
// `Kcl` stores a whole parsed program inline, which presumably makes it much
// larger than the other variants; the lint is silenced instead of boxing it.
// NOTE(review): confirm that boxing was deliberately rejected.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub enum ModuleRepr {
    /// The root module. NOTE(review): presumably the main program — confirm.
    Root,
    /// A KCL module: its parsed program plus an optional environment and list
    /// of names. NOTE(review): the `Option` looks like it is filled in once the
    /// module has been evaluated, and the names are presumably its exports — confirm.
    Kcl(Node<Program>, Option<(EnvironmentRef, Vec<String>)>),
    /// Pre-imported geometry from a foreign (non-KCL) import.
    Foreign(PreImportedGeometry),
    /// Placeholder left behind by [`ModuleInfo::take_repr`].
    Dummy,
}
127
/// Where a module comes from. Serialized with an internal `type` tag naming the variant.
// NOTE(review): the variants only hold a `PathBuf` or `String`, so this `allow`
// looks unnecessary — confirm before removing.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize, Hash, ts_rs::TS)]
#[serde(tag = "type")]
pub enum ModulePath {
    /// The main (top-level) module; displayed as `main`.
    Main,
    /// A file on disk, joined onto the project directory when one is known
    /// (see [`ModulePath::from_import_path`]).
    Local { value: PathBuf },
    /// A bundled standard library module, identified by its name (e.g. `math`).
    Std { value: String },
}
137
138impl ModulePath {
139 pub(crate) fn expect_path(&self) -> &PathBuf {
140 match self {
141 ModulePath::Local { value: p } => p,
142 _ => unreachable!(),
143 }
144 }
145
146 pub(crate) fn std_path(&self) -> Option<String> {
147 match self {
148 ModulePath::Std { value: p } => Some(p.clone()),
149 _ => None,
150 }
151 }
152
153 pub(crate) async fn source(&self, fs: &FileManager, source_range: SourceRange) -> Result<ModuleSource, KclError> {
154 match self {
155 ModulePath::Local { value: p } => Ok(ModuleSource {
156 source: fs.read_to_string(p, source_range).await?,
157 path: self.clone(),
158 }),
159 ModulePath::Std { value: name } => Ok(ModuleSource {
160 source: read_std(name)
161 .ok_or_else(|| {
162 KclError::Semantic(KclErrorDetails {
163 message: format!("Cannot find standard library module to import: std::{name}."),
164 source_ranges: vec![source_range],
165 })
166 })
167 .map(str::to_owned)?,
168 path: self.clone(),
169 }),
170 ModulePath::Main => unreachable!(),
171 }
172 }
173
174 pub(crate) fn from_import_path(path: &ImportPath, project_directory: &Option<PathBuf>) -> Self {
175 match path {
176 ImportPath::Kcl { filename: path } | ImportPath::Foreign { path } => {
177 let resolved_path = if let Some(project_dir) = project_directory {
178 project_dir.join(path)
179 } else {
180 std::path::PathBuf::from(path)
181 };
182 ModulePath::Local { value: resolved_path }
183 }
184 ImportPath::Std { path } => {
185 assert_eq!(path.len(), 2);
187 assert_eq!(&path[0], "std");
188
189 ModulePath::Std { value: path[1].clone() }
190 }
191 }
192 }
193}
194
195impl fmt::Display for ModulePath {
196 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
197 match self {
198 ModulePath::Main => write!(f, "main"),
199 ModulePath::Local { value: path } => path.display().fmt(f),
200 ModulePath::Std { value: s } => write!(f, "std::{s}"),
201 }
202 }
203}
204
/// The source text of a module together with the path it was loaded from.
#[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize, ts_rs::TS)]
pub struct ModuleSource {
    /// Where the source was loaded from.
    pub path: ModulePath,
    /// The module's source text.
    pub source: String,
}