gluon/
import.rs

1//! Implementation of the `import!` macro.
2
3use std::{
4    any::{Any, TypeId},
5    borrow::Cow,
6    fs::File,
7    io::Read,
8    mem,
9    ops::{Deref, DerefMut},
10    path::PathBuf,
11    sync::{Arc, Mutex, MutexGuard, RwLock},
12};
13
14use {
15    async_trait::async_trait,
16    futures::{
17        future::{self},
18        prelude::*,
19    },
20    itertools::Itertools,
21    salsa::debug::DebugQueryTable,
22};
23
24use crate::base::{
25    ast::{self, expr_to_path, Expr, Literal, SpannedExpr},
26    filename_to_module, pos,
27    source::FileId,
28    symbol::{Symbol, Symbols},
29    types::ArcType,
30};
31
32use crate::vm::{
33    self,
34    gc::Trace,
35    macros::{Error as MacroError, LazyMacroResult, Macro, MacroExpander, MacroFuture},
36    thread::{RootedThread, Thread},
37    vm::VmEnv,
38    ExternLoader, ExternModule,
39};
40
41use crate::{
42    compiler_pipeline::{Salvage, SalvageResult},
43    query::{AsyncCompilation, Compilation, CompilerDatabase},
44    IoError, ModuleCompiler, ThreadExt,
45};
46
quick_error! {
    /// Error type for the import macro
    #[derive(Debug, Clone, Eq, PartialEq, Hash)]
    pub enum Error {
        /// The importer found a cyclic dependency when loading files
        // Displayed as `cycle` followed by `module` itself, joined with " -> ", so the printed
        // chain ends at the module that closes the cycle.
        CyclicDependency(module: String, cycle: Vec<String>) {
            display(
                "Module '{}' occurs in a cyclic dependency: `{}`",
                module,
                cycle.iter().chain(Some(module)).format(" -> ")
            )
        }
        /// Generic message error
        String(message: String) {
            display("{}", message)
        }
        /// The importer could not load the imported file
        // `from()` makes quick_error generate `From<IoError> for Error`.
        IO(err: IoError) {
            display("{}", err)
            from()
        }
    }
}
70
71impl base::error::AsDiagnostic for Error {
72    fn as_diagnostic(
73        &self,
74        _map: &base::source::CodeMap,
75    ) -> codespan_reporting::diagnostic::Diagnostic<FileId> {
76        codespan_reporting::diagnostic::Diagnostic::error().with_message(self.to_string())
77    }
78}
79
// Output of the build script (`$OUT_DIR/std_modules.rs`). Presumably this defines `STD_LIBS`, the
// standard library modules embedded in the binary: the rest of this file reads `.0` of each entry
// as the module name and `.1` as its `'static` source text. TODO confirm against build.rs.
include!(concat!(env!("OUT_DIR"), "/std_modules.rs"));
81
/// Strategy used by the `import!` macro to load a module.
#[async_trait]
pub trait Importer: Any + Clone + Sync + Send {
    /// Loads the module named `modulename` using `compiler` and returns the module's type.
    ///
    /// On failure the error is wrapped in a `Salvage`, which may also carry a type recovered
    /// despite the error (see `DefaultImporter`, which salvages the module type when possible).
    async fn import(
        &self,
        compiler: &mut ModuleCompiler<'_, '_>,
        vm: &Thread,
        modulename: &str,
    ) -> SalvageResult<ArcType, crate::Error>;
}
91
92#[derive(Clone)]
93pub struct DefaultImporter;
94#[async_trait]
95impl Importer for DefaultImporter {
96    async fn import(
97        &self,
98        compiler: &mut ModuleCompiler<'_, '_>,
99        _vm: &Thread,
100        modulename: &str,
101    ) -> SalvageResult<ArcType> {
102        let result = compiler.database.global(modulename.to_string()).await;
103        // Forcibly load module_type so we can salvage a type for the error if necessary
104        let _ = compiler
105            .database
106            .module_type(modulename.to_string(), None)
107            .await;
108
109        let value = result.map_err(|error| {
110            let value = compiler.database.peek_module_type(modulename);
111            Salvage { value, error }
112        })?;
113        Ok(value.typ)
114    }
115}
116
/// Locked, mutable access to the `CompilerDatabase` owned by an `Import` macro.
///
/// Created by `Import::database_mut` and dereferences to `CompilerDatabase`. The database mutex
/// stays locked until this value is dropped.
pub struct DatabaseMut {
    // Only needed to ensure that the `Compiler` the guard points to lives long enough.
    // NOTE: declared first, so it would be dropped before `compiler`; the `Drop` impl therefore
    // releases the guard explicitly before the fields are dropped.
    _import: Arc<dyn Macro>,
    // This isn't actually static but relies on `import` to live longer than the guard
    compiler: Option<MutexGuard<'static, CompilerDatabase>>,
}
123
124impl Drop for DatabaseMut {
125    fn drop(&mut self) {
126        // The compiler moves back to only be owned by the import so we need to remove the thread
127        // to break the cycle
128        self.thread = None;
129        self.compiler.take();
130    }
131}
132
133impl Deref for DatabaseMut {
134    type Target = CompilerDatabase;
135    fn deref(&self) -> &Self::Target {
136        self.compiler.as_ref().unwrap()
137    }
138}
139
140impl DerefMut for DatabaseMut {
141    fn deref_mut(&mut self) -> &mut Self::Target {
142        self.compiler.as_mut().unwrap()
143    }
144}
145
/// Object-safe view of an `Import` macro that hides its `Importer` type parameter.
///
/// Obtained from the macro through `get_capability_impl` as an `Arc<dyn ImportApi>`.
#[async_trait]
pub(crate) trait ImportApi: Send + Sync {
    /// Returns the source of `module`, preferring the embedded standard library when
    /// `use_standard_lib` is set, otherwise reading `filename` from the search paths.
    fn get_module_source(
        &self,
        use_standard_lib: bool,
        module: &str,
        filename: &str,
    ) -> Result<Cow<'static, str>, Error>;
    /// Loads the module named by the global symbol `module_id` and returns its type.
    async fn load_module(
        &self,
        compiler: &mut ModuleCompiler<'_, '_>,
        vm: &Thread,
        module_id: &Symbol,
    ) -> SalvageResult<ArcType>;
    /// Takes a snapshot of the compiler database associated with `thread`.
    fn snapshot(&self, thread: RootedThread) -> salsa::Snapshot<CompilerDatabase>;
    /// Forks the compiler database using `forker`, associating the fork with `thread`.
    fn fork(
        &mut self,
        forker: salsa::ForkState,
        thread: RootedThread,
    ) -> salsa::Snapshot<CompilerDatabase>;
}
167
168#[async_trait]
169impl<I> ImportApi for Import<I>
170where
171    I: Importer,
172{
173    fn get_module_source(
174        &self,
175        use_standard_lib: bool,
176        module: &str,
177        filename: &str,
178    ) -> Result<Cow<'static, str>, Error> {
179        Self::get_module_source(self, use_standard_lib, module, filename)
180    }
181    async fn load_module(
182        &self,
183        compiler: &mut ModuleCompiler<'_, '_>,
184        vm: &Thread,
185        module_id: &Symbol,
186    ) -> SalvageResult<ArcType> {
187        assert!(module_id.is_global());
188        let modulename = module_id.name().definition_name();
189
190        self.importer.import(compiler, vm, &modulename).await
191    }
192    fn snapshot(&self, thread: RootedThread) -> salsa::Snapshot<CompilerDatabase> {
193        Self::snapshot(self, thread)
194    }
195    fn fork(
196        &mut self,
197        forker: salsa::ForkState,
198        thread: RootedThread,
199    ) -> salsa::Snapshot<CompilerDatabase> {
200        Self::fork(self, forker, thread)
201    }
202}
203
204/// Macro which rewrites occurances of `import! "filename"` to a load of that file if it is not
205/// already loaded and then a global access to the loaded module
206pub struct Import<I = DefaultImporter> {
207    pub paths: RwLock<Vec<PathBuf>>,
208    pub importer: I,
209
210    pub compiler: Mutex<CompilerDatabase>,
211}
212
/// Shared pointer wrapper whose `PartialEq`, `Eq` and `Hash` implementations look only at the
/// address of the pointee, never at its value: two `PtrEq`s are equal exactly when they point to
/// the same allocation.
#[derive(Debug)]
pub struct PtrEq<T>(pub Arc<T>);

impl<T> std::ops::Deref for PtrEq<T> {
    type Target = T;
    fn deref(&self) -> &T {
        &*self.0
    }
}

impl<T> Clone for PtrEq<T> {
    // Written by hand so that `T` itself need not be `Clone`; only the `Arc` is cloned.
    fn clone(&self) -> Self {
        PtrEq(Arc::clone(&self.0))
    }
}

impl<T> PartialEq for PtrEq<T> {
    fn eq(&self, other: &Self) -> bool {
        let (PtrEq(lhs), PtrEq(rhs)) = (self, other);
        Arc::ptr_eq(lhs, rhs)
    }
}

impl<T> Eq for PtrEq<T> {}

impl<T> std::hash::Hash for PtrEq<T> {
    fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
        // Hash the address, consistent with the pointer-based `PartialEq` above.
        let address: *const T = &*self.0;
        std::hash::Hash::hash(&address, hasher)
    }
}
242
243impl<I> Import<I> {
244    /// Creates a new import macro
245    pub fn new(importer: I) -> Import<I> {
246        Import {
247            paths: RwLock::new(vec![PathBuf::from(".")]),
248            compiler: CompilerDatabase::new_base(None).into(),
249            importer: importer,
250        }
251    }
252
253    /// Adds a path to the list of paths which the importer uses to find files
254    pub fn add_path<P: Into<PathBuf>>(&self, path: P) {
255        self.paths.write().unwrap().push(path.into());
256    }
257
258    pub fn set_paths(&self, paths: Vec<PathBuf>) {
259        *self.paths.write().unwrap() = paths;
260    }
261
262    pub fn modules(&self, compiler: &mut ModuleCompiler<'_, '_>) -> Vec<Cow<'static, str>> {
263        STD_LIBS
264            .iter()
265            .map(|t| Cow::Borrowed(t.0))
266            .chain(
267                crate::query::ExternLoaderQuery
268                    .in_db(&*compiler.database)
269                    .entries::<Vec<_>>()
270                    .into_iter()
271                    .map(|entry| Cow::Owned(entry.key)),
272            )
273            .collect()
274    }
275
276    pub fn database_mut(self: Arc<Self>, thread: RootedThread) -> DatabaseMut
277    where
278        I: Importer,
279    {
280        // Since `self` lives longer than the lifetime in the mutex guard this is safe
281        let mut compiler = unsafe {
282            DatabaseMut {
283                compiler: Some(mem::transmute::<
284                    MutexGuard<CompilerDatabase>,
285                    MutexGuard<CompilerDatabase>,
286                >(self.compiler.lock().unwrap())),
287                _import: self,
288            }
289        };
290
291        compiler.thread = Some(thread);
292
293        compiler
294    }
295
296    pub fn snapshot(&self, thread: RootedThread) -> salsa::Snapshot<CompilerDatabase> {
297        self.compiler.lock().unwrap().snapshot(thread)
298    }
299
300    pub fn fork(
301        &mut self,
302        forker: salsa::ForkState,
303        thread: RootedThread,
304    ) -> salsa::Snapshot<CompilerDatabase> {
305        self.compiler.lock().unwrap().fork(forker, thread)
306    }
307
308    pub(crate) fn get_module_source(
309        &self,
310        use_standard_lib: bool,
311        module: &str,
312        filename: &str,
313    ) -> Result<Cow<'static, str>, Error> {
314        let mut buffer = String::new();
315
316        // Retrieve the source, first looking in the standard library included in the
317        // binary
318
319        let std_file = if use_standard_lib {
320            STD_LIBS.iter().find(|tup| tup.0 == module)
321        } else {
322            None
323        };
324        Ok(match std_file {
325            Some(tup) => Cow::Borrowed(tup.1),
326            None => {
327                let paths = self.paths.read().unwrap();
328                let file = paths
329                    .iter()
330                    .filter_map(|p| {
331                        let base = p.join(filename);
332                        match File::open(&base) {
333                            Ok(file) => Some(file),
334                            Err(_) => None,
335                        }
336                    })
337                    .next();
338                let mut file = file.ok_or_else(|| {
339                    Error::String(format!(
340                        "Could not find module '{}'. Searched {}.",
341                        module,
342                        paths
343                            .iter()
344                            .map(|p| format!("`{}`", p.display()))
345                            .format(", ")
346                    ))
347                })?;
348                file.read_to_string(&mut buffer)
349                    .map_err(|err| Error::IO(err.into()))?;
350                Cow::Owned(buffer)
351            }
352        })
353    }
354}
355
356/// Adds an extern module to `thread`, letting it be loaded with `import! name` from gluon code.
357///
358/// ```
359/// use gluon::vm::{self, ExternModule};
360/// use gluon::{primitive, record, Thread, ThreadExt};
361/// use gluon::import::add_extern_module;
362///
363/// fn yell(s: &str) -> String {
364///     s.to_uppercase()
365/// }
366///
367/// fn my_module(thread: &Thread) -> vm::Result<ExternModule> {
368///     ExternModule::new(
369///         thread,
370///         record!{
371///             message => "Hello World!",
372///             yell => primitive!(1, yell)
373///         }
374///     )
375/// }
376///
377/// #[tokio::main]
378/// async fn main() -> gluon::Result<()> {
379///     let thread = gluon::new_vm_async().await;
380///     add_extern_module(&thread, "my_module", my_module);
381///     let script = r#"
382///         let module = import! "my_module"
383///         module.yell module.message
384///     "#;
385///     let (result, _) = thread.run_expr_async::<String>("example", script).await?;
386///     assert_eq!(result, "HELLO WORLD!");
387///     Ok(())
388/// }
389/// ```
390pub fn add_extern_module<F>(thread: &Thread, name: &str, loader: F)
391where
392    F: Fn(&Thread) -> vm::Result<ExternModule> + Send + Sync + 'static,
393{
394    add_extern_module_(
395        thread,
396        name,
397        ExternLoader {
398            load_fn: Box::new(loader),
399            dependencies: Vec::new(),
400        },
401    )
402}
403
404pub fn add_extern_module_with_deps<F>(
405    thread: &Thread,
406    name: &str,
407    loader: F,
408    dependencies: Vec<String>,
409) where
410    F: Fn(&Thread) -> vm::Result<ExternModule> + Send + Sync + 'static,
411{
412    add_extern_module_(
413        thread,
414        name,
415        ExternLoader {
416            load_fn: Box::new(loader),
417            dependencies,
418        },
419    )
420}
421
422fn add_extern_module_(thread: &Thread, name: &str, loader: ExternLoader) {
423    thread
424        .get_database_mut()
425        .set_extern_loader(name.into(), PtrEq(Arc::new(loader)));
426}
427
/// Registers an extern module whose real loader is only compiled in when a cfg predicate holds.
///
/// If `#[cfg(...)]` is satisfied, `$loader` is registered under `$mod_name` with the optional
/// `dependencies` list. Otherwise a placeholder loader is registered under the same name. That
/// placeholder always fails with the message "<name> is only available if <available_if>".
macro_rules! add_extern_module_if {
    (
        #[cfg($($features: tt)*)],
        available_if = $msg: expr,
        $(dependencies = $dependencies: expr,)?
        args($vm: expr, $mod_name: expr, $loader: path)
    ) => {{
        #[cfg($($features)*)]
        $crate::import::add_extern_module_with_deps(
            $vm,
            $mod_name,
            $loader,
            // `None.into_iter()` yields nothing; it only gives the optional `dependencies` an
            // iterator to be chained onto. `$dependencies` must have an `.iter()` whose cloned
            // items are `&str` (e.g. an array or slice of string literals).
            None.into_iter() $( .chain($dependencies.iter().cloned()) )? .map(|s: &str| s.into()).collect::<Vec<String>>()
        );

        // Feature disabled: register a stub so `import!` reports why the module is unavailable.
        #[cfg(not($($features)*))]
        $crate::import::add_extern_module($vm, $mod_name, |_: &::vm::thread::Thread| -> ::vm::Result<::vm::ExternModule> {
            Err(::vm::Error::Message(
                format!(
                    "{} is only available if {}",
                    $mod_name,
                    $msg
                )
            ))
        });
    }};
}
455
/// The `import!` macro itself: resolves its single argument to a module name and expands to an
/// identifier referring to the loaded module.
impl<I> Macro for Import<I>
where
    I: Importer,
{
    /// Hands out capabilities of this macro, selected by `id`:
    /// a `Box<dyn VmEnv>` over a database snapshot, this macro as an `Arc<dyn ImportApi>`,
    /// a raw `salsa::Snapshot<CompilerDatabase>`, or a locking `DatabaseMut`.
    /// Returns `None` for any other type.
    fn get_capability_impl(
        &self,
        thread: &Thread,
        arc_self: &Arc<dyn Macro>,
        id: TypeId,
    ) -> Option<Box<dyn Any>> {
        if id == TypeId::of::<Box<dyn VmEnv>>() {
            Some(Box::new(Box::new(crate::query::snapshot_env(
                self.snapshot(thread.root_thread()),
            )) as Box<dyn VmEnv>))
        } else if id == TypeId::of::<Arc<dyn ImportApi>>() {
            // NOTE(review): assumes `arc_self` is the `Arc` holding this very `Import<I>`;
            // the `unwrap` panics otherwise.
            Some(Box::new(
                arc_self.clone().downcast_arc::<Self>().ok().unwrap() as Arc<dyn ImportApi>,
            ))
        } else if id == TypeId::of::<salsa::Snapshot<CompilerDatabase>>() {
            Some(Box::new(self.snapshot(thread.root_thread())))
        } else if id == TypeId::of::<DatabaseMut>() {
            // Same assumption as above: `arc_self` must downcast to `Self`.
            Some(Box::new(
                arc_self
                    .clone()
                    .downcast_arc::<Self>()
                    .ok()
                    .unwrap()
                    .database_mut(thread.root_thread()),
            ))
        } else {
            None
        }
    }

    /// Expands `import! <arg>` into an identifier for the imported module.
    ///
    /// The argument may be an identifier/projection path (`a.b.c`) or a string literal filename.
    /// The import itself runs either on a spawned task (tokio feature with a spawner available)
    /// or lazily when the returned `LazyMacroResult` is evaluated.
    fn expand<'r, 'a: 'r, 'b: 'r, 'c: 'r, 'ast: 'r>(
        &self,
        macros: &'b mut MacroExpander<'a>,
        _symbols: &'c mut Symbols,
        _arena: &'b mut ast::OwnedArena<'ast, Symbol>,
        args: &'b mut [SpannedExpr<'ast, Symbol>],
    ) -> MacroFuture<'r, 'ast> {
        // Turns the macro arguments into a module name: exactly one argument is accepted; a path
        // expression is flattened with `expr_to_path`, a string literal goes through
        // `filename_to_module`, and anything else is rejected.
        fn get_module_name(args: &[SpannedExpr<Symbol>]) -> Result<String, Error> {
            if args.len() != 1 {
                return Err(Error::String("Expected import to get 1 argument".into()).into());
            }

            let modulename = match args[0].value {
                Expr::Ident(_) | Expr::Projection(..) => {
                    let mut modulename = String::new();
                    expr_to_path(&args[0], &mut modulename)
                        .map_err(|err| Error::String(err.to_string()))?;
                    modulename
                }
                Expr::Literal(Literal::String(ref filename)) => filename_to_module(filename),
                _ => {
                    return Err(Error::String(
                        "Expected a string literal or path to import".into(),
                    )
                    .into());
                }
            };
            Ok(modulename)
        }

        let modulename = match get_module_name(&args).map_err(MacroError::new) {
            Ok(modulename) => modulename,
            Err(err) => return Box::pin(future::err(err)),
        };

        info!("import! {}", modulename);

        // Fork the expander's user data; it must be a `salsa::Snapshot<CompilerDatabase>`.
        // The resulting `db` is moved into whichever future performs the import.
        let mut db = try_future!(macros
            .userdata
            .fork(macros.vm.root_thread())
            .downcast::<salsa::Snapshot<CompilerDatabase>>()
            .map_err(|_| MacroError::new(Error::String(
                "`import` requires a `CompilerDatabase` as user data during macro expansion".into(),
            ))));

        let span = args[0].span;

        // With tokio and a spawner available, start the import right away on a separate task
        // and let the lazy result merely wait for it over a oneshot channel.
        #[cfg(feature = "tokio")]
        if let Some(spawn) = macros.spawn {
            use futures::task::SpawnExt;

            let (tx, rx) = tokio::sync::oneshot::channel();
            // NOTE(review): the `.unwrap()` below panics if the spawner rejects the task.
            spawn
                .spawn(Box::pin(async move {
                    // A panic during the import is caught and reported as a macro error carrying
                    // the panic message when it is a `String` or `&str`.
                    let result = std::panic::AssertUnwindSafe(db.import(modulename))
                        .catch_unwind()
                        .await
                        .map(|r| {
                            r.map_err(|salvage| {
                                salvage.map_err(|err| MacroError::message(err.to_string()))
                            })
                        })
                        .unwrap_or_else(|err| {
                            Err(Salvage::from(MacroError::message(
                                err.downcast::<String>()
                                    .map(|s| *s)
                                    .or_else(|e| e.downcast::<&str>().map(|s| String::from(&s[..])))
                                    .unwrap_or_else(|_| "Unknown panic".to_string()),
                            )))
                        });
                    // Drop the database before sending the result, otherwise the forker may drop before the forked database
                    drop(db);
                    let _ = tx.send(result);
                }))
                .unwrap();
            return Box::pin(async move {
                Ok(LazyMacroResult::from(move || {
                    async move {
                        // A receive error (sender dropped without sending) becomes the error.
                        rx.await
                            .unwrap_or_else(|err| {
                                Err(Salvage::from(MacroError::new(Error::String(
                                    err.to_string(),
                                ))))
                            })
                            .map(|id| pos::spanned(span, Expr::Ident(id)))
                            .map_err(|salvage| {
                                salvage.map(|id| pos::spanned(span, Expr::Ident(id)))
                            })
                    }
                    .boxed()
                }))
            });
        }

        // Fallback: run the import inside the lazy result, only once it is evaluated.
        Box::pin(async move {
            Ok(LazyMacroResult::from(move || {
                async move {
                    let result = db
                        .import(modulename)
                        .await
                        .map_err(|salvage| {
                            salvage
                                .map(|id| pos::spanned(span, Expr::Ident(id)))
                                .map_err(|err| MacroError::message(err.to_string()))
                        })
                        .map(move |id| pos::spanned(span, Expr::Ident(id)));
                    // As in the spawned path, release the forked database before returning.
                    drop(db);
                    result
                }
                .boxed()
            }))
        })
    }
}
604
// SAFETY(review): `impl_trace!` with an empty `()` body presumably declares that `Import` holds no
// GC-managed values which need tracing. Confirm against the `impl_trace!` definition.
unsafe impl<I> Trace for Import<I> {
    impl_trace! { self, _gc, () }
}