1use std::{fmt, path::PathBuf};
2
3use anyhow::Result;
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6
7use crate::{
8 errors::{KclError, KclErrorDetails},
9 execution::{EnvironmentRef, PreImportedGeometry},
10 fs::{FileManager, FileSystem},
11 parsing::ast::types::{ImportPath, Node, Program},
12 source_range::SourceRange,
13};
14
15#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize, ts_rs::TS, JsonSchema)]
17#[ts(export)]
18pub struct ModuleId(u32);
19
20impl ModuleId {
21 pub fn from_usize(id: usize) -> Self {
22 Self(u32::try_from(id).expect("module ID should fit in a u32"))
23 }
24
25 pub fn as_usize(&self) -> usize {
26 usize::try_from(self.0).expect("module ID should fit in a usize")
27 }
28
29 pub fn is_top_level(&self) -> bool {
32 *self == Self::default()
33 }
34}
35
36impl std::fmt::Display for ModuleId {
37 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38 write!(f, "{}", self.0)
39 }
40}
41
/// Bookkeeping for module loading, used to detect circular imports of local files.
#[derive(Default, Clone, Debug)]
pub(crate) struct ModuleLoader {
    /// Paths of the local modules currently being loaded, outermost first.
    /// Entries are pushed by `enter_module` and popped by `leave_module`.
    pub import_stack: Vec<PathBuf>,
}
48
49impl ModuleLoader {
50 pub(crate) fn cycle_check(&self, path: &ModulePath, source_range: SourceRange) -> Result<(), KclError> {
51 if self.import_stack.contains(path.expect_path()) {
52 return Err(self.import_cycle_error(path, source_range));
53 }
54 Ok(())
55 }
56
57 pub(crate) fn import_cycle_error(&self, path: &ModulePath, source_range: SourceRange) -> KclError {
58 KclError::ImportCycle(KclErrorDetails {
59 message: format!(
60 "circular import of modules is not allowed: {} -> {}",
61 self.import_stack
62 .iter()
63 .map(|p| p.as_path().to_string_lossy())
64 .collect::<Vec<_>>()
65 .join(" -> "),
66 path,
67 ),
68 source_ranges: vec![source_range],
69 })
70 }
71
72 pub(crate) fn enter_module(&mut self, path: &ModulePath) {
73 if let ModulePath::Local { value: ref path } = path {
74 self.import_stack.push(path.clone());
75 }
76 }
77
78 pub(crate) fn leave_module(&mut self, path: &ModulePath) {
79 if let ModulePath::Local { value: ref path } = path {
80 let popped = self.import_stack.pop().unwrap();
81 assert_eq!(path, &popped);
82 }
83 }
84}
85
86pub(crate) fn read_std(mod_name: &str) -> Option<&'static str> {
87 match mod_name {
88 "prelude" => Some(include_str!("../std/prelude.kcl")),
89 "math" => Some(include_str!("../std/math.kcl")),
90 _ => None,
91 }
92}
93
94#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
96pub struct ModuleInfo {
97 pub(crate) id: ModuleId,
99 pub(crate) path: ModulePath,
101 pub(crate) repr: ModuleRepr,
102}
103
104impl ModuleInfo {
105 pub(crate) fn take_repr(&mut self) -> ModuleRepr {
106 let mut result = ModuleRepr::Dummy;
107 std::mem::swap(&mut self.repr, &mut result);
108 result
109 }
110
111 pub(crate) fn restore_repr(&mut self, repr: ModuleRepr) {
112 assert!(matches!(&self.repr, ModuleRepr::Dummy));
113 self.repr = repr;
114 }
115}
116
117#[allow(clippy::large_enum_variant)]
118#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
119pub enum ModuleRepr {
120 Root,
121 Kcl(Node<Program>, Option<(EnvironmentRef, Vec<String>)>),
123 Foreign(PreImportedGeometry),
124 Dummy,
125}
126
127#[allow(clippy::large_enum_variant)]
128#[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize, Hash, ts_rs::TS)]
129#[serde(tag = "type")]
130pub enum ModulePath {
131 Main,
133 Local { value: PathBuf },
134 Std { value: String },
135}
136
137impl ModulePath {
138 pub(crate) fn expect_path(&self) -> &PathBuf {
139 match self {
140 ModulePath::Local { value: p } => p,
141 _ => unreachable!(),
142 }
143 }
144
145 pub(crate) fn std_path(&self) -> Option<String> {
146 match self {
147 ModulePath::Std { value: p } => Some(p.clone()),
148 _ => None,
149 }
150 }
151
152 pub(crate) async fn source(&self, fs: &FileManager, source_range: SourceRange) -> Result<ModuleSource, KclError> {
153 match self {
154 ModulePath::Local { value: p } => Ok(ModuleSource {
155 source: fs.read_to_string(p, source_range).await?,
156 path: self.clone(),
157 }),
158 ModulePath::Std { value: name } => Ok(ModuleSource {
159 source: read_std(name)
160 .ok_or_else(|| {
161 KclError::Semantic(KclErrorDetails {
162 message: format!("Cannot find standard library module to import: std::{name}."),
163 source_ranges: vec![source_range],
164 })
165 })
166 .map(str::to_owned)?,
167 path: self.clone(),
168 }),
169 ModulePath::Main => unreachable!(),
170 }
171 }
172
173 pub(crate) fn from_import_path(path: &ImportPath, project_directory: &Option<PathBuf>) -> Self {
174 match path {
175 ImportPath::Kcl { filename: path } | ImportPath::Foreign { path } => {
176 let resolved_path = if let Some(project_dir) = project_directory {
177 project_dir.join(path)
178 } else {
179 std::path::PathBuf::from(path)
180 };
181 ModulePath::Local { value: resolved_path }
182 }
183 ImportPath::Std { path } => {
184 assert_eq!(path.len(), 2);
186 assert_eq!(&path[0], "std");
187
188 ModulePath::Std { value: path[1].clone() }
189 }
190 }
191 }
192}
193
194impl fmt::Display for ModulePath {
195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196 match self {
197 ModulePath::Main => write!(f, "main"),
198 ModulePath::Local { value: path } => path.display().fmt(f),
199 ModulePath::Std { value: s } => write!(f, "std::{s}"),
200 }
201 }
202}
203
204#[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize, ts_rs::TS)]
205pub struct ModuleSource {
206 pub path: ModulePath,
207 pub source: String,
208}