use std::cell::LazyCell;
use std::collections::HashMap;
use std::ffi::CString;
use std::fs::create_dir_all;
use std::mem::MaybeUninit;
use std::panic::UnwindSafe;
use std::path::{Path, PathBuf};
use std::process::Command;
use std::sync::atomic::AtomicBool;
use std::sync::LazyLock;
use std::sync::RwLock;

use cc::Build;
use libloading::Library;
use tree_sitter::Language;
use tree_sitter_language::LanguageFn;
use walkdir::WalkDir;

use crate::Error;
use crate::queries::{has_extension, language_name};
const LOADED_LANGUAGES: LazyCell<RwLock<HashMap<PathBuf, LanguageFn>>> =
LazyCell::new(|| RwLock::new(HashMap::new()));
const REGISTERED_HANDLER: AtomicBool = AtomicBool::new(false);
const TESTING_LOADED_LANGUAGE: AtomicBool = AtomicBool::new(false);
pub(crate) fn dyload_language(path: impl AsRef<Path>) -> Result<Language, Error> {
register_handler_if_necessary();
let path = path.as_ref();
if let Some(language_fn) = LOADED_LANGUAGES.read().unwrap().get(path) {
return Ok(copy_language_fn(language_fn).into());
}
let language_fn = dyload_new_language(path)?;
LOADED_LANGUAGES.write().unwrap().insert(path.to_path_buf(), copy_language_fn(&language_fn));
Ok(language_fn.into())
}
fn dyload_new_language(path: &Path) -> Result<LanguageFn, Error> {
let dylib_path = dylib_path(path);
let symbol_name = CString::new(language_name(path)?)
.map_err(|_| Error::IllegalTSLanguageSymbolName)?;
build_dylib_if_needed(path, &dylib_path)?;
eprintln!("Dynamically loading {}...", symbol_name.to_str().unwrap());
unsafe {
let dylib = Library::new(&dylib_path.canonicalize()?).map_err(Error::LoadDylibFailed)?;
let language_fn_symbol = dylib
.get::<unsafe extern "C" fn() -> *const ()>(symbol_name.as_bytes())
.map_err(Error::LoadDylibSymbolFailed)?;
let language_fn = LanguageFn::from_raw(*language_fn_symbol);
std::mem::forget(dylib);
testing_loaded_language(|| {
let language = Language::from(copy_language_fn(&language_fn));
let version = language.version();
if version < tree_sitter::MIN_COMPATIBLE_LANGUAGE_VERSION || version > tree_sitter::LANGUAGE_VERSION {
return Err(Error::IncompatibleLanguageVersion { version });
}
Ok(())
})?;
Ok(language_fn)
}
}
fn build_dylib_if_needed(path: &Path, dylib_path: &Path) -> Result<(), Error> {
if !dylib_path.exists() {
build_dylib(path, dylib_path)?;
}
if !dylib_path.exists() {
return Err(Error::MissingDylib);
}
Ok(())
}
/// Returns where the compiled library for the grammar in `path` is stored:
/// `<path>/target/c-release-so/libtree-sitter.<ext>`.
///
/// `<ext>` is `dylib` on macOS, `dll` on Windows and `so` on every other target.
fn dylib_path(path: &Path) -> PathBuf {
    let extension = if cfg!(target_os = "macos") {
        "dylib"
    } else if cfg!(target_os = "windows") {
        "dll"
    } else {
        "so"
    };
    path.join("target/c-release-so/libtree-sitter")
        .with_extension(extension)
}
fn build_dylib(path: &Path, dylib_path: &Path) -> Result<(), Error> {
let dylib_dir = dylib_path.parent().unwrap();
create_dir_all(dylib_dir)?;
eprintln!("Building {}...", dylib_path.display());
let src_dir = path.join("src");
let sources = src_dir.read_dir()?
.filter_map(|e| e.ok())
.map(|e| e.path())
.filter(|p| has_extension(p, "c"));
Build::new()
.host(env!("HOST"))
.target(env!("TARGET"))
.opt_level_str(env!("OPT_LEVEL"))
.debug(env!("DEBUG") == "true")
.flag_if_supported("-Wno-unused-parameter")
.flag_if_supported("-Wno-unused-but-set-variable")
.flag_if_supported("-Wno-trigraphs")
.include(&src_dir)
.files(sources)
.shared_flag(true)
.cargo_metadata(false)
.out_dir(&dylib_dir)
.try_compile("tree-sitter")?;
eprintln!("Dynamic linking {}...", dylib_path.display());
let status = if cfg!(target_os = "macos") {
Command::new("/usr/bin/clang")
.args(["-dynamiclib", "-undefined", "error", "-o"])
.arg(&dylib_path)
.args(find_object_files_in(dylib_dir))
.status()
.map_err(Error::LinkDylibCmdFailed)?
} else if cfg!(target_family = "unix") {
Command::new("/usr/bin/ld")
.args(["-shared", "-o"])
.arg(&dylib_path)
.args(find_object_files_in(dylib_dir))
.status()
.map_err(Error::LinkDylibCmdFailed)?
} else {
return Err(Error::LinkDylibUnsupported);
};
if !status.success() {
return Err(Error::LinkDylibFailed { exit_status: status });
}
Ok(())
}
/// Recursively lists every regular file with a `.o` extension below `dir`.
///
/// Directory entries that cannot be read are skipped rather than reported.
fn find_object_files_in(dir: &Path) -> impl Iterator<Item = PathBuf> {
    WalkDir::new(dir)
        .into_iter()
        .flatten()
        .filter(|item| item.file_type().is_file() && has_extension(item.path(), "o"))
        .map(|item| item.into_path())
}
/// Returns a bitwise duplicate of `language_fn`.
///
/// This lets the same handle be both cached and returned.
fn copy_language_fn(language_fn: &LanguageFn) -> LanguageFn {
    // SAFETY: this code builds `LanguageFn` values from a bare function pointer via `from_raw`.
    // NOTE(review): assumes `LanguageFn` is a plain wrapper around that pointer with no `Drop`
    // impl, so two bitwise copies cannot double-free anything; confirm against
    // `tree_sitter_language`.
    unsafe { std::ptr::read(language_fn) }
}
/// Installs `loaded_language_sigsegv_handler` as the process-wide SIGSEGV handler.
///
/// Only the first call does anything; later calls return immediately.
///
/// NOTE(review): this replaces whatever SIGSEGV handler was installed before (possibly one
/// installed by the Rust runtime); confirm that trade-off is acceptable.
fn register_handler_if_necessary() {
    if REGISTERED_HANDLER.swap(true, std::sync::atomic::Ordering::Relaxed) {
        return;
    }
    // SAFETY: the handler takes a single `c_int` as `signal` expects, and `sighandler_t` is
    // the libc crate's type for a handler address (the type `libc::signal` is declared with).
    let previous = unsafe {
        libc::signal(libc::SIGSEGV, loaded_language_sigsegv_handler as libc::sighandler_t)
    };
    if previous == libc::SIG_ERR {
        // Not fatal: loading still works, but a crashing grammar will take the process down.
        eprintln!("Failed to install SIGSEGV handler; a corrupt grammar library will crash the process");
    }
}
fn testing_loaded_language(f: impl FnOnce() -> Result<(), Error> + UnwindSafe) -> Result<(), Error> {
TESTING_LOADED_LANGUAGE.store(true, std::sync::atomic::Ordering::Relaxed);
let result = std::panic::catch_unwind(f).unwrap_or(Err(Error::CorruptDylib));
TESTING_LOADED_LANGUAGE.store(false, std::sync::atomic::Ordering::Relaxed);
result
}
/// SIGSEGV handler installed by `register_handler_if_necessary`.
///
/// When `TESTING_LOADED_LANGUAGE` is clear, the fault is not one we expect. The handler
/// restores the default disposition, re-raises the signal, and returns, so the process dies as
/// it would have without this handler.
///
/// When the flag is set, SIGSEGV is unblocked and the fault is turned into a panic. The
/// `catch_unwind` in `testing_loaded_language` can then report the library as corrupt.
///
/// The ABI is `"C-unwind"` because a panic escaping an `extern "C"` function aborts the process
/// (Rust 1.81+), which would make recovery impossible. Unwinding out of a signal handler is
/// still best-effort only:
/// - it needs unwind info for the signal frame and for the faulting code;
/// - `panic!` is not async-signal-safe.
unsafe extern "C-unwind" fn loaded_language_sigsegv_handler(signal: libc::c_int) {
    let testing = TESTING_LOADED_LANGUAGE.load(std::sync::atomic::Ordering::Relaxed);
    if signal != libc::SIGSEGV || !testing {
        // SAFETY: `signal` and `raise` are plain libc calls on a valid signal number.
        unsafe {
            libc::signal(signal, libc::SIG_DFL);
            libc::raise(signal);
        }
        // Never fall through to the panic below: unwinding into code that is not guarded by
        // `catch_unwind` would be far worse than the default crash. The re-raised signal is
        // handled with the default action once this handler returns.
        return;
    }
    // SAFETY: `sigs` is a live local that `sigemptyset` initialises before anything reads it.
    unsafe {
        // Keep the `MaybeUninit` in a named local so the pointer stays valid for all three
        // calls. `uninit().as_mut_ptr()` on a temporary would dangle immediately.
        let mut sigs = MaybeUninit::<libc::sigset_t>::uninit();
        libc::sigemptyset(sigs.as_mut_ptr());
        libc::sigaddset(sigs.as_mut_ptr(), signal);
        // We leave this handler by unwinding rather than returning, so the signal mask saved
        // on entry is never restored. Unblock the signal explicitly so a later fault reaches
        // this handler again.
        libc::sigprocmask(libc::SIG_UNBLOCK, sigs.as_ptr(), std::ptr::null_mut());
    }
    panic!("SIGSEGV!");
}