use anyhow::{Context, Result};
use clap::Parser;
use inkwell::context::Context as InkwellContext;
use std::collections::HashSet;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
use tl_lang::compiler::codegen::CodeGenerator;
use tl_lang::compiler::error::{format_error_with_source, TlError};
use tl_lang::compiler::inference::{forward_chain, query, GroundAtom, Value};
use tl_lang::compiler::semantics::SemanticAnalyzer;
// Command-line interface for the `tlc` driver.
//
// NOTE: fields are annotated with plain `//` comments rather than `///`
// doc comments on purpose — clap's derive turns doc comments into help
// text, which would change the program's `--help` output.
#[derive(Parser)]
#[command(name = "tlc")]
#[command(version)]
#[command(about = "Tensor Logic Compiler", long_about = None)]
struct Cli {
// Input files: `.tl` sources, plus `.o`/`.s` objects passed to the linker.
#[arg(required = true)]
files: Vec<String>,
// Compile to object files / executable instead of JIT-running.
#[arg(short, long)]
compile: bool,
// Output path for the linked executable (implies compile mode).
#[arg(short, long)]
output: Option<PathBuf>,
// Emit assembly (`.s`) next to each source file (implies compile mode).
#[arg(short = 'S', long)]
save_asm: bool,
// Emit LLVM IR (`.ll`) next to each source file (implies compile mode).
#[arg(long = "emit-llvm")]
emit_llvm: bool,
// Device selection, exported to the runtime via TL_DEVICE (default "auto").
#[arg(short, long, default_value = "auto")]
device: String,
// Everything after `--` is forwarded to the JIT-executed program.
#[arg(last = true)]
args: Vec<String>,
// Enables runtime memory logging via the TL_MEM_LOG env var.
#[arg(long)]
mem_log: bool,
// Repeatable verbosity flag: -v = info, -vv = debug, -vvv = trace.
#[arg(short, long, action = clap::ArgAction::Count)]
verbose: u8,
}
fn main() -> Result<()> {
let cli = Cli::parse();
let mut builder = env_logger::Builder::new();
builder.filter_level(log::LevelFilter::Warn);
builder.filter_module("tokenizers", log::LevelFilter::Error);
match cli.verbose {
0 => {
if std::env::var("RUST_LOG").is_ok() {
builder.parse_default_env();
}
}
1 => { builder.filter_level(log::LevelFilter::Info); }
2 => { builder.filter_level(log::LevelFilter::Debug); }
_ => { builder.filter_level(log::LevelFilter::Trace); }
};
builder.init();
if std::env::var("TL_DEVICE").is_err() {
unsafe { std::env::set_var("TL_DEVICE", &cli.device); }
}
if cli.mem_log {
unsafe { std::env::set_var("TL_MEM_LOG", "1"); }
}
let mut source_files = Vec::new();
let mut object_files = Vec::new();
for f in &cli.files {
let p = PathBuf::from(f);
if let Some(ext) = p.extension() {
if ext == "tl" {
source_files.push(p);
} else if ext == "o" || ext == "s" {
object_files.push(p);
} else {
source_files.push(p);
}
} else {
source_files.push(p);
}
}
let builtins = load_builtins().context("Failed to load builtins")?;
log::info!("Loaded builtins: {} structs, {} impls", builtins.structs.len(), builtins.impls.len());
let is_compile_mode = cli.compile || cli.output.is_some() || cli.save_asm || cli.emit_llvm;
if is_compile_mode {
let mut generated_objects = Vec::new();
for file in &source_files {
log::info!("Compiling file: {:?}", file);
let (mut ast, source) = match load_module_with_source(file.clone()) {
Ok((ast, source)) => (ast, source),
Err(e) => {
let source = fs::read_to_string(file).unwrap_or_default();
print_tl_error_with_source(
&e,
&source,
Some(file.to_str().unwrap_or("unknown")),
);
std::process::exit(1);
}
};
ast.structs.extend(builtins.structs.clone());
ast.enums.extend(builtins.enums.clone());
ast.impls.extend(builtins.impls.clone());
ast.functions.extend(builtins.functions.clone());
let mut analyzer = SemanticAnalyzer::new(String::new());
if let Err(e) = analyzer.check_module(&mut ast) {
let tl_err = e.with_file(file.to_str().unwrap_or("unknown"));
print_tl_error_with_source(
&tl_err,
&source,
Some(file.to_str().unwrap_or("unknown")),
);
std::process::exit(1);
}
let mut monomorphizer = tl_lang::compiler::monomorphize::Monomorphizer::new();
if let Err(e) = monomorphizer.run(&mut ast) {
let tl_err = e.with_file(file.to_str().unwrap_or("unknown"));
print_tl_error_with_source(
&tl_err,
&source,
Some(file.to_str().unwrap_or("unknown")),
);
std::process::exit(1);
}
let context = InkwellContext::create();
let module_name = file.file_stem().unwrap().to_str().unwrap();
let mut codegen = CodeGenerator::new(&context, module_name);
if let Err(e) = codegen.compile_module(&ast, "main") {
let tl_err = TlError::Codegen {
kind: tl_lang::compiler::error::CodegenErrorKind::Generic(e),
span: None,
}
.with_file(file.to_str().unwrap_or("unknown"));
print_tl_error_with_source(
&tl_err,
&source,
Some(file.to_str().unwrap_or("unknown")),
);
std::process::exit(1);
}
if std::env::var("TL_DUMP_IR").is_ok() {
codegen.dump_ir();
}
if cli.emit_llvm {
let ll_path = file.with_extension("ll");
if let Err(e) = codegen.emit_llvm_file(&ll_path) {
log::error!("Failed to emit LLVM IR for {:?}: {}", file, e);
std::process::exit(1);
}
log::info!("Generated LLVM IR: {:?}", ll_path);
}
if cli.save_asm {
let asm_path = file.with_extension("s");
if let Err(e) = codegen.emit_assembly_file(&asm_path) {
log::error!("Failed to emit assembly for {:?}: {}", file, e);
std::process::exit(1);
}
log::info!("Generated assembly: {:?}", asm_path);
} else if !cli.emit_llvm {
let obj_path = file.with_extension("o");
if let Err(e) = codegen.emit_object_file(&obj_path) {
log::error!("Failed to emit object file for {:?}: {}", file, e);
std::process::exit(1);
}
generated_objects.push(obj_path);
}
}
let output_is_object = cli
.output
.as_ref()
.map(|p| p.extension().map_or(false, |e| e == "o"))
.unwrap_or(false);
if (cli.compile || cli.output.is_some()) && !cli.save_asm && !cli.emit_llvm && !output_is_object {
let mut link_args = Vec::new();
link_args.extend(
generated_objects
.iter()
.map(|p| p.to_str().unwrap().to_string()),
);
link_args.extend(object_files.iter().map(|p| p.to_str().unwrap().to_string()));
let output_exe = if let Some(out) = cli.output {
out
} else {
if !source_files.is_empty() {
let mut p = source_files[0].clone();
p.set_extension("");
p
} else {
PathBuf::from("a.out")
}
};
log::info!("Linking to {:?}", output_exe);
let runtime_path = PathBuf::from("target/debug");
link_args.push(format!("-L{}", runtime_path.display()));
link_args.push("-ltl_runtime".to_string());
link_args.push("-lpthread".to_string());
link_args.push("-ldl".to_string());
link_args.push("-lm".to_string());
link_args.push("-lc++".to_string());
#[cfg(target_os = "macos")]
{
link_args.push("-framework".to_string());
link_args.push("Accelerate".to_string());
link_args.push("-framework".to_string());
link_args.push("Metal".to_string());
link_args.push("-framework".to_string());
link_args.push("Foundation".to_string());
link_args.push("-framework".to_string());
link_args.push("MetalPerformanceShaders".to_string());
link_args.push("-framework".to_string());
link_args.push("Security".to_string());
link_args.push("-framework".to_string());
link_args.push("CoreFoundation".to_string());
link_args.push("-framework".to_string());
link_args.push("SystemConfiguration".to_string());
}
let status = Command::new("cc")
.args(&link_args)
.arg("-o")
.arg(&output_exe)
.status()
.context("Failed to run linker (cc)")?;
if !status.success() {
log::error!("Linking failed");
std::process::exit(1);
}
log::info!("Build successful: {:?}", output_exe);
}
} else {
tl_runtime::args::init_args(cli.args.clone());
tl_runtime::force_link();
let mut combined_module = tl_lang::compiler::ast::Module {
structs: vec![],
enums: vec![],
impls: vec![],
functions: vec![],
tensor_decls: vec![],
relations: vec![],
rules: vec![],
queries: vec![],
imports: vec![],
submodules: std::collections::HashMap::new(),
};
let mut combined_source = String::new();
for file in &source_files {
match load_module_with_source(file.clone()) {
Ok((mod_, source)) => {
combined_module.structs.extend(mod_.structs);
combined_module.enums.extend(mod_.enums);
combined_module.impls.extend(mod_.impls);
combined_module.functions.extend(mod_.functions);
combined_module.tensor_decls.extend(mod_.tensor_decls);
combined_module.relations.extend(mod_.relations);
combined_module.rules.extend(mod_.rules);
combined_module.queries.extend(mod_.queries);
combined_module.imports.extend(mod_.imports);
combined_module.submodules.extend(mod_.submodules);
if combined_source.is_empty() {
combined_source = source;
}
}
Err(e) => {
let source = fs::read_to_string(file).unwrap_or_default();
print_tl_error_with_source(
&e,
&source,
Some(file.to_str().unwrap_or("unknown")),
);
std::process::exit(1);
}
}
}
combined_module.structs.extend(builtins.structs.clone());
combined_module.enums.extend(builtins.enums.clone());
combined_module.impls.extend(builtins.impls.clone());
combined_module.functions.extend(builtins.functions.clone());
let mut analyzer = SemanticAnalyzer::new(String::new());
if let Err(e) = analyzer.check_module(&mut combined_module) {
let file_hint = if !source_files.is_empty() {
source_files[0].to_str()
} else {
None
};
let tl_err = if let Some(f) = file_hint {
e.with_file(f)
} else {
e
};
print_tl_error_with_source(&tl_err, &combined_source, file_hint);
std::process::exit(1);
}
let mut monomorphizer = tl_lang::compiler::monomorphize::Monomorphizer::new();
if let Err(e) = monomorphizer.run(&mut combined_module) {
print_tl_error_with_source(&e, &combined_source, None);
std::process::exit(1);
}
use tl_runtime::registry;
registry::reset_global_context();
let context = InkwellContext::create();
let mut codegen = CodeGenerator::new(&context, "main");
eprintln!("[DEBUG] Starting compile_module");
if let Err(e) = codegen.compile_module(&combined_module, "main") {
let tl_err = TlError::Codegen {
kind: tl_lang::compiler::error::CodegenErrorKind::Generic(e),
span: None,
};
print_tl_error_with_source(&tl_err, &combined_source, None);
std::process::exit(1);
}
eprintln!("[DEBUG] compile_module completed");
if std::env::var("TL_DUMP_IR").is_ok() {
codegen.dump_ir();
}
match codegen.jit_execute("main") {
Ok(ret) => {
let _ = ret; }
Err(e) => {
println!("Execution failed: {}", e);
std::process::exit(1);
}
}
tl_runtime::system::tl_metal_sync();
let is_logic_program = !combined_module.relations.is_empty()
|| !combined_module.rules.is_empty()
|| !combined_module.queries.is_empty();
if is_logic_program {
let tensor_context = registry::get_global_context();
run_logic_program(&combined_module, &tensor_context);
}
}
Ok(())
}
fn run_logic_program(
module: &tl_lang::compiler::ast::Module,
ctx: &tl_lang::compiler::inference::TensorContext,
) {
use tl_lang::compiler::ast::{Atom, ExprKind};
log::info!("Executing logic program...");
let mut initial_facts: HashSet<GroundAtom> = HashSet::new();
let mut rules = Vec::new();
for rule in &module.rules {
let head_ground = try_atom_to_ground(&rule.head);
if let Some(ground) = head_ground {
if rule.body.is_empty() || is_trivially_true(&rule.body) {
initial_facts.insert(ground);
continue;
}
}
rules.push(rule.clone());
}
log::info!("Initial facts: {}", initial_facts.len());
log::info!("Rules: {}", rules.len());
let derived_facts = match forward_chain(initial_facts, &rules, ctx) {
Ok(f) => f,
Err(e) => {
log::error!("Inference error: {}", e);
return;
}
};
log::info!("Derived facts: {}", derived_facts.len());
for query_expr in &module.queries {
if let ExprKind::FnCall(pred, args) = &query_expr.inner {
let query_atom = Atom {
predicate: pred.clone(),
args: args.clone(),
};
println!("\nQuery: {}({:?})", pred, args);
let results = query(&query_atom, &derived_facts, ctx);
if results.is_empty() {
println!(" Result: false (no matches)");
} else {
println!(" Result: true ({} matches)", results.len());
for (i, subst) in results.iter().enumerate() {
if !subst.is_empty() {
println!(" Match {}: {:?}", i + 1, subst);
}
}
}
} else {
println!("Unsupported query expression: {:?}", query_expr);
}
}
}
/// Convert `atom` into a fully ground [`GroundAtom`], returning `None` if
/// any argument is not a literal (int, float, bool, symbol, or string).
fn try_atom_to_ground(atom: &tl_lang::compiler::ast::Atom) -> Option<GroundAtom> {
    use tl_lang::compiler::ast::ExprKind;
    // Map each argument expression to a runtime Value; `collect` over
    // Option short-circuits on the first non-literal argument.
    let args = atom
        .args
        .iter()
        .map(|expr| match &expr.inner {
            ExprKind::Int(n) => Some(Value::Int(*n)),
            ExprKind::Float(f) => Some(Value::Float(*f)),
            ExprKind::Symbol(s) => Some(Value::Str(s.clone())),
            ExprKind::Bool(b) => Some(Value::Bool(*b)),
            ExprKind::StringLiteral(s) => Some(Value::Str(s.clone())),
            _ => None,
        })
        .collect::<Option<Vec<_>>>()?;
    Some(GroundAtom {
        predicate: atom.predicate.clone(),
        args,
    })
}
/// True when `body` imposes no real condition: it is either empty or a
/// single positive literal `true` with no arguments.
fn is_trivially_true(body: &[tl_lang::compiler::ast::LogicLiteral]) -> bool {
    use tl_lang::compiler::ast::LogicLiteral;
    match body {
        [] => true,
        [LogicLiteral::Pos(atom)] => atom.predicate == "true" && atom.args.is_empty(),
        _ => false,
    }
}
/// Parse the module at `path`, resolving its transitive imports, and return
/// it together with the file's source text (for diagnostics).
fn load_module_with_source(
    path: PathBuf,
) -> Result<(tl_lang::compiler::ast::Module, String), TlError> {
    // A fresh visited-set per top-level load: it tracks the canonicalized
    // paths on the current import chain so cycles can be detected.
    load_module_recursive(path, &mut HashSet::new())
}
/// Recursively load the module at `path` and all modules it imports.
///
/// `visited` holds the canonicalized paths of the current import chain and
/// is used for cycle detection; each path is removed again on the way out.
/// Wildcard imports (`foo::*`) are merged into the importing module, while
/// plain imports are stored as named submodules.
///
/// NOTE(review): because entries are removed from `visited` after a module
/// finishes loading, a module reachable through two different import chains
/// (a "diamond") is parsed and merged twice — confirm this is intended.
///
/// # Errors
/// Returns `TlError` for I/O failures, parse failures (annotated with the
/// file path), cyclic dependencies, and missing imported modules.
fn load_module_recursive(
    path: PathBuf,
    visited: &mut HashSet<PathBuf>,
) -> Result<(tl_lang::compiler::ast::Module, String), TlError> {
    // Canonicalize so the same file reached via different relative paths is
    // recognized; fall back to the raw path if canonicalization fails
    // (e.g. the file does not exist — the read below reports that).
    let canonical_path = fs::canonicalize(&path).unwrap_or_else(|_| path.clone());
    if visited.contains(&canonical_path) {
        return Err(TlError::Parse {
            kind: tl_lang::compiler::error::ParseErrorKind::Generic(format!(
                "Cyclic dependency detected: {:?}",
                path
            )),
            span: None,
        });
    }
    visited.insert(canonical_path.clone());

    let path_str = path.to_str().unwrap_or("unknown").to_string();
    let content = fs::read_to_string(&path).map_err(TlError::Io)?;
    let source = content.clone();
    let mut module = tl_lang::compiler::parser::parse_from_source(&content)
        .map_err(|e| e.with_file(&path_str))?;

    // Imports are resolved relative to the importing file's directory.
    let parent_dir = path.parent().unwrap_or(Path::new("."));
    let imports = module.imports.clone();
    for import_name in &imports {
        let is_wildcard = import_name.ends_with("::*");
        let real_name = if is_wildcard {
            import_name.trim_end_matches("::*")
        } else {
            import_name
        };
        let import_path = parent_dir.join(format!("{}.tl", real_name));
        if !import_path.exists() {
            return Err(TlError::Parse {
                kind: tl_lang::compiler::error::ParseErrorKind::Generic(format!(
                    "Module {} not found at {:?}",
                    real_name, import_path
                )),
                span: None,
            });
        }
        let (submodule, _) = load_module_recursive(import_path, visited)?;
        if is_wildcard {
            // `foo::*`: splice the submodule's items directly into ours.
            module.structs.extend(submodule.structs);
            module.enums.extend(submodule.enums);
            module.impls.extend(submodule.impls);
            module.functions.extend(submodule.functions);
            module.tensor_decls.extend(submodule.tensor_decls);
            module.relations.extend(submodule.relations);
            module.rules.extend(submodule.rules);
            module.queries.extend(submodule.queries);
            module.submodules.extend(submodule.submodules);
        } else {
            // Plain import: keep the submodule addressable by name.
            module.submodules.insert(import_name.clone(), submodule);
        }
    }

    // This path is fully loaded; allow it to appear on sibling chains.
    visited.remove(&canonical_path);
    Ok((module, source))
}
fn print_tl_error_with_source(error: &TlError, source: &str, file_hint: Option<&str>) {
let output = format_error_with_source(error, source, file_hint);
eprint!("{}", output);
}
/// Parse the compiled-in builtin type sources (Vec, HashMap, Option, Result,
/// LLM) and merge them into a single module for injection into user code.
fn load_builtins() -> Result<tl_lang::compiler::ast::Module> {
    use tl_lang::compiler::codegen::builtin_types;
    // Embedded sources, merged in this fixed order.
    let sources = [
        builtin_types::vec::SOURCE,
        builtin_types::hashmap::SOURCE,
        builtin_types::option::SOURCE,
        builtin_types::result::SOURCE,
        builtin_types::llm::SOURCE,
    ];
    let mut combined = tl_lang::compiler::ast::Module {
        structs: vec![],
        enums: vec![],
        impls: vec![],
        functions: vec![],
        tensor_decls: vec![],
        relations: vec![],
        rules: vec![],
        queries: vec![],
        imports: vec![],
        submodules: std::collections::HashMap::new(),
    };
    for (index, src) in sources.into_iter().enumerate() {
        let parsed = tl_lang::compiler::parser::parse_from_source(src)
            .map_err(|e| anyhow::anyhow!("Failed to parse builtin {}: {:?}", index, e))?;
        combined.structs.extend(parsed.structs);
        combined.enums.extend(parsed.enums);
        combined.impls.extend(parsed.impls);
        combined.functions.extend(parsed.functions);
        combined.tensor_decls.extend(parsed.tensor_decls);
        combined.relations.extend(parsed.relations);
        combined.rules.extend(parsed.rules);
        combined.queries.extend(parsed.queries);
    }
    Ok(combined)
}