1use std::{fmt, path::PathBuf};
2
3use anyhow::Result;
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6
7use crate::{
8 errors::{KclError, KclErrorDetails},
9 execution::{EnvironmentRef, PreImportedGeometry},
10 fs::{FileManager, FileSystem},
11 parsing::ast::types::{ImportPath, Node, Program},
12 source_range::SourceRange,
13};
14
// Numeric identifier of a module, stored as a `u32` newtype.
//
// The derived `Default` value (0) is the ID that `is_top_level` recognises as
// the top-level module.
//
// NOTE: plain `//` comments are used on purpose. `///` docs on a type deriving
// `ts_rs::TS` / `JsonSchema` can be copied into the generated bindings and
// schema, which would change generated output.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize, ts_rs::TS, JsonSchema)]
#[ts(export)]
pub struct ModuleId(u32);
19
20impl ModuleId {
21 pub fn from_usize(id: usize) -> Self {
22 Self(u32::try_from(id).expect("module ID should fit in a u32"))
23 }
24
25 pub fn as_usize(&self) -> usize {
26 usize::try_from(self.0).expect("module ID should fit in a usize")
27 }
28
29 pub fn is_top_level(&self) -> bool {
32 *self == Self::default()
33 }
34}
35
36impl std::fmt::Display for ModuleId {
37 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38 write!(f, "{}", self.0)
39 }
40}
41
/// Bookkeeping for module imports that are currently in progress.
#[derive(Debug, Clone, Default)]
pub(crate) struct ModuleLoader {
    /// Paths of the local modules pushed by `enter_module` and not yet popped
    /// by `leave_module`, oldest first. `cycle_check` searches this list to
    /// detect circular imports.
    pub import_stack: Vec<PathBuf>,
}
48
49impl ModuleLoader {
50 pub(crate) fn cycle_check(&self, path: &ModulePath, source_range: SourceRange) -> Result<(), KclError> {
51 if self.import_stack.contains(path.expect_path()) {
52 return Err(self.import_cycle_error(path, source_range));
53 }
54 Ok(())
55 }
56
57 pub(crate) fn import_cycle_error(&self, path: &ModulePath, source_range: SourceRange) -> KclError {
58 KclError::ImportCycle(KclErrorDetails {
59 message: format!(
60 "circular import of modules is not allowed: {} -> {}",
61 self.import_stack
62 .iter()
63 .map(|p| p.as_path().to_string_lossy())
64 .collect::<Vec<_>>()
65 .join(" -> "),
66 path,
67 ),
68 source_ranges: vec![source_range],
69 })
70 }
71
72 pub(crate) fn enter_module(&mut self, path: &ModulePath) {
73 if let ModulePath::Local { value: ref path } = path {
74 self.import_stack.push(path.clone());
75 }
76 }
77
78 pub(crate) fn leave_module(&mut self, path: &ModulePath) {
79 if let ModulePath::Local { value: ref path } = path {
80 let popped = self.import_stack.pop().unwrap();
81 assert_eq!(path, &popped);
82 }
83 }
84}
85
86pub(crate) fn read_std(mod_name: &str) -> Option<&'static str> {
87 match mod_name {
88 "prelude" => Some(include_str!("../std/prelude.kcl")),
89 "math" => Some(include_str!("../std/math.kcl")),
90 "sketch" => Some(include_str!("../std/sketch.kcl")),
91 "turns" => Some(include_str!("../std/turns.kcl")),
92 _ => None,
93 }
94}
95
/// Metadata and contents of a single module.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ModuleInfo {
    /// Identifier of this module.
    pub(crate) id: ModuleId,
    /// Where the module comes from (main program, local file, or std).
    pub(crate) path: ModulePath,
    /// The module's contents. Temporarily `ModuleRepr::Dummy` between a call
    /// to `take_repr` and the matching `restore_repr`.
    pub(crate) repr: ModuleRepr,
}
105
106impl ModuleInfo {
107 pub(crate) fn take_repr(&mut self) -> ModuleRepr {
108 let mut result = ModuleRepr::Dummy;
109 std::mem::swap(&mut self.repr, &mut result);
110 result
111 }
112
113 pub(crate) fn restore_repr(&mut self, repr: ModuleRepr) {
114 assert!(matches!(&self.repr, ModuleRepr::Dummy));
115 self.repr = repr;
116 }
117}
118
/// The loaded form of a module.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub enum ModuleRepr {
    /// NOTE(review): presumably the top-level program, which carries no
    /// separate representation here. Confirm against callers.
    Root,
    /// A KCL module: its parsed program plus an optional
    /// `(EnvironmentRef, Vec<String>)` pair.
    /// NOTE(review): the pair looks like a cached execution result
    /// (environment and exported names). Confirm against callers.
    Kcl(Node<Program>, Option<(EnvironmentRef, Vec<String>)>),
    /// A foreign import, held as pre-imported geometry.
    Foreign(PreImportedGeometry),
    /// Placeholder left behind by `ModuleInfo::take_repr` until
    /// `ModuleInfo::restore_repr` stores a real representation again.
    Dummy,
}
128
// Identifies where a module comes from.
//
// Serialized by serde as an internally tagged enum: a `type` field names the
// variant.
//
// NOTE: plain `//` comments are used on purpose. `///` docs on a type deriving
// `ts_rs::TS` can be copied into the generated TypeScript bindings.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize, Hash, ts_rs::TS)]
#[serde(tag = "type")]
pub enum ModulePath {
    // The main program; displayed as `main`. It has no loadable source:
    // `source` treats this variant as unreachable.
    Main,
    // A module stored in the file at `value`.
    Local { value: PathBuf },
    // A standard-library module named `value`; displayed as `std::<value>`.
    Std { value: String },
}
138
139impl ModulePath {
140 pub(crate) fn expect_path(&self) -> &PathBuf {
141 match self {
142 ModulePath::Local { value: p } => p,
143 _ => unreachable!(),
144 }
145 }
146
147 pub(crate) fn std_path(&self) -> Option<String> {
148 match self {
149 ModulePath::Std { value: p } => Some(p.clone()),
150 _ => None,
151 }
152 }
153
154 pub(crate) async fn source(&self, fs: &FileManager, source_range: SourceRange) -> Result<ModuleSource, KclError> {
155 match self {
156 ModulePath::Local { value: p } => Ok(ModuleSource {
157 source: fs.read_to_string(p, source_range).await?,
158 path: self.clone(),
159 }),
160 ModulePath::Std { value: name } => Ok(ModuleSource {
161 source: read_std(name)
162 .ok_or_else(|| {
163 KclError::Semantic(KclErrorDetails {
164 message: format!("Cannot find standard library module to import: std::{name}."),
165 source_ranges: vec![source_range],
166 })
167 })
168 .map(str::to_owned)?,
169 path: self.clone(),
170 }),
171 ModulePath::Main => unreachable!(),
172 }
173 }
174
175 pub(crate) fn from_import_path(path: &ImportPath, project_directory: &Option<PathBuf>) -> Self {
176 match path {
177 ImportPath::Kcl { filename: path } | ImportPath::Foreign { path } => {
178 let resolved_path = if let Some(project_dir) = project_directory {
179 project_dir.join(path)
180 } else {
181 std::path::PathBuf::from(path)
182 };
183 ModulePath::Local { value: resolved_path }
184 }
185 ImportPath::Std { path } => {
186 assert_eq!(path.len(), 2);
188 assert_eq!(&path[0], "std");
189
190 ModulePath::Std { value: path[1].clone() }
191 }
192 }
193 }
194}
195
196impl fmt::Display for ModulePath {
197 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198 match self {
199 ModulePath::Main => write!(f, "main"),
200 ModulePath::Local { value: path } => path.display().fmt(f),
201 ModulePath::Std { value: s } => write!(f, "std::{s}"),
202 }
203 }
204}
205
// A module's source text together with the path it was loaded from.
//
// NOTE: plain `//` comments are used on purpose. `///` docs on a type deriving
// `ts_rs::TS` can be copied into the generated TypeScript bindings.
#[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize, ts_rs::TS)]
pub struct ModuleSource {
    // Path of the module this source belongs to.
    pub path: ModulePath,
    // The module's full source text.
    pub source: String,
}