use proc_macro2::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::{LitStr, Token};
use std::path::{Path, PathBuf};
use std::process::Command;
pub fn expand(input: TokenStream) -> TokenStream {
let path_lit = match syn::parse2::<LitStr>(input) {
Ok(lit) => lit,
Err(err) => return err.to_compile_error(),
};
let raw_path = path_lit.value();
if raw_path.ends_with(".cbor") {
expand_cbor(&path_lit)
} else if raw_path.ends_with(".hs") || raw_path.contains(".hs::") {
expand_hs(&path_lit, &raw_path)
} else {
syn::Error::new(
path_lit.span(),
"haskell_eval! path must end in .cbor or .hs",
)
.to_compile_error()
}
}
/// Expands a `.cbor` path for `haskell_eval!`.
///
/// The file is embedded with `include_bytes!` using the user's literal, so the
/// path resolves relative to the invoking source file. The generated code
/// deserializes it and evaluates it on a fresh `VecHeap` with an empty `Env`.
fn expand_cbor(path_lit: &LitStr) -> TokenStream {
    let embed = quote! {
        static __CBOR: &[u8] = include_bytes!(#path_lit);
    };
    let evaluate = quote! {
        let __expr = tidepool_repr::serial::read::read_cbor(__CBOR)
            .expect("failed to deserialize CBOR — re-run extraction (cargo xtask extract)");
        let mut __heap = tidepool_eval::heap::VecHeap::new();
        let __env = tidepool_eval::env::Env::new();
        tidepool_eval::eval::eval(&__expr, &__env, &mut __heap)
    };
    quote! {
        {
            #embed
            #evaluate
        }
    }
}
fn expand_hs(path_lit: &LitStr, raw_path: &str) -> TokenStream {
let (hs_path_str, binding_name) = match raw_path.split_once(".hs::") {
Some((prefix, binding)) => (format!("{}.hs", prefix), Some(binding.to_string())),
None => (raw_path.to_string(), None),
};
let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
Ok(d) => d,
Err(_) => {
return syn::Error::new(path_lit.span(), "CARGO_MANIFEST_DIR not set")
.to_compile_error();
}
};
let abs_hs_path = Path::new(&manifest_dir).join(&hs_path_str);
if !abs_hs_path.exists() {
return syn::Error::new(
path_lit.span(),
format!("Haskell source not found: {}", abs_hs_path.display()),
)
.to_compile_error();
}
let basename = abs_hs_path.file_stem().unwrap().to_str().unwrap();
let output_dir = Path::new(&manifest_dir)
.join("target")
.join("tidepool-cbor")
.join(basename);
if let Err(msg) = run_tidepool_extract(
&abs_hs_path,
&output_dir,
binding_name.as_deref(),
Path::new(&manifest_dir),
) {
return syn::Error::new(path_lit.span(), msg).to_compile_error();
}
let cbor_path = match binding_name {
Some(ref name) => {
let p = output_dir.join(format!("{}.cbor", name));
if !p.exists() {
let available = list_bindings(&output_dir);
return syn::Error::new(
path_lit.span(),
format!("Binding '{}' not found. Available: {:?}", name, available),
)
.to_compile_error();
}
p
}
None => match find_single_binding(&output_dir) {
Ok(p) => p,
Err(msg) => {
return syn::Error::new(path_lit.span(), msg).to_compile_error();
}
},
};
let cbor_path_str = cbor_path.to_str().unwrap();
let hs_abs_str = abs_hs_path.to_str().unwrap();
quote! {
{
const _: &[u8] = include_bytes!(#hs_abs_str);
static __CBOR: &[u8] = include_bytes!(#cbor_path_str);
let __expr = tidepool_repr::serial::read::read_cbor(__CBOR)
.expect("failed to deserialize CBOR — re-run extraction");
let mut __heap = tidepool_eval::heap::VecHeap::new();
let __env = tidepool_eval::env::Env::new();
tidepool_eval::eval::eval(&__expr, &__env, &mut __heap)
}
}
}
/// Entry point for `haskell_expr!("path")`.
///
/// Like `haskell_eval!`, but the expansion yields the deserialized expression
/// together with its metadata table instead of evaluating it. `….cbor` paths
/// embed pre-extracted files; `….hs` / `….hs::binding` paths are extracted at
/// build time.
pub fn expand_expr(input: TokenStream) -> TokenStream {
    let path_lit: LitStr = match syn::parse2(input) {
        Ok(lit) => lit,
        Err(err) => return err.to_compile_error(),
    };
    let raw_path = path_lit.value();

    if raw_path.ends_with(".cbor") {
        return expand_expr_cbor(&path_lit);
    }
    if raw_path.ends_with(".hs") || raw_path.contains(".hs::") {
        return expand_expr_hs(&path_lit, &raw_path);
    }
    syn::Error::new(
        path_lit.span(),
        "haskell_expr! path must end in .cbor or .hs",
    )
    .to_compile_error()
}
fn expand_expr_cbor(path_lit: &LitStr) -> TokenStream {
let cbor_path = path_lit.value();
let cbor_dir = Path::new(&cbor_path)
.parent()
.expect("cbor path has no parent");
let meta_path = cbor_dir.join("meta.cbor");
let meta_path_str = meta_path.to_str().unwrap();
quote! {
{
static __CBOR: &[u8] = include_bytes!(#path_lit);
static __META: &[u8] = include_bytes!(#meta_path_str);
let __expr = tidepool_repr::serial::read::read_cbor(__CBOR)
.expect("failed to deserialize CBOR");
let (__table, _warnings) = tidepool_repr::serial::read::read_metadata(__META)
.expect("failed to deserialize metadata");
(__expr, __table)
}
}
}
fn expand_expr_hs(path_lit: &LitStr, raw_path: &str) -> TokenStream {
let (hs_path_str, binding_name) = match raw_path.split_once(".hs::") {
Some((prefix, binding)) => (format!("{}.hs", prefix), Some(binding.to_string())),
None => (raw_path.to_string(), None),
};
let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
Ok(d) => d,
Err(_) => {
return syn::Error::new(path_lit.span(), "CARGO_MANIFEST_DIR not set")
.to_compile_error();
}
};
let abs_hs_path = Path::new(&manifest_dir).join(&hs_path_str);
if !abs_hs_path.exists() {
return syn::Error::new(
path_lit.span(),
format!("Haskell source not found: {}", abs_hs_path.display()),
)
.to_compile_error();
}
let basename = abs_hs_path.file_stem().unwrap().to_str().unwrap();
let output_dir = Path::new(&manifest_dir)
.join("target")
.join("tidepool-cbor")
.join(basename);
if let Err(msg) = run_tidepool_extract(
&abs_hs_path,
&output_dir,
binding_name.as_deref(),
Path::new(&manifest_dir),
) {
return syn::Error::new(path_lit.span(), msg).to_compile_error();
}
let cbor_path = match binding_name {
Some(ref name) => {
let p = output_dir.join(format!("{}.cbor", name));
if !p.exists() {
let available = list_bindings(&output_dir);
return syn::Error::new(
path_lit.span(),
format!("Binding '{}' not found. Available: {:?}", name, available),
)
.to_compile_error();
}
p
}
None => match find_single_binding(&output_dir) {
Ok(p) => p,
Err(msg) => {
return syn::Error::new(path_lit.span(), msg).to_compile_error();
}
},
};
let cbor_path_str = cbor_path.to_str().unwrap();
let hs_abs_str = abs_hs_path.to_str().unwrap();
let meta_path = output_dir.join("meta.cbor");
let meta_path_str = meta_path.to_str().unwrap();
quote! {
{
const _: &[u8] = include_bytes!(#hs_abs_str);
static __CBOR: &[u8] = include_bytes!(#cbor_path_str);
static __META: &[u8] = include_bytes!(#meta_path_str);
let __expr = tidepool_repr::serial::read::read_cbor(__CBOR)
.expect("failed to deserialize CBOR — re-run extraction");
let (__table, _warnings) = tidepool_repr::serial::read::read_metadata(__META)
.expect("failed to deserialize metadata");
(__expr, __table)
}
}
}
struct InlineInput {
target: String,
includes: Vec<String>,
source: LitStr,
}
impl Parse for InlineInput {
fn parse(input: ParseStream) -> syn::Result<Self> {
let target_ident: syn::Ident = input.parse()?;
if target_ident != "target" {
return Err(syn::Error::new(target_ident.span(), "expected `target`"));
}
input.parse::<Token![=]>()?;
let target_lit: LitStr = input.parse()?;
let target = target_lit.value();
input.parse::<Token![,]>()?;
let mut includes = Vec::new();
if input.peek(syn::Ident) {
let maybe_include = input.fork();
let ident: syn::Ident = maybe_include.parse()?;
if ident == "include" {
let _: syn::Ident = input.parse()?;
input.parse::<Token![=]>()?;
if input.peek(syn::token::Bracket) {
let content;
syn::bracketed!(content in input);
while !content.is_empty() {
let lit: LitStr = content.parse()?;
includes.push(lit.value());
if !content.is_empty() {
content.parse::<Token![,]>()?;
}
}
} else {
let lit: LitStr = input.parse()?;
includes.push(lit.value());
}
let _ = input.parse::<Token![,]>();
}
}
let source = if input.is_empty() {
LitStr::new("", proc_macro2::Span::call_site())
} else {
let _ = input.parse::<Token![,]>();
if input.is_empty() {
LitStr::new("", proc_macro2::Span::call_site())
} else {
input.parse()?
}
};
Ok(InlineInput {
target,
includes,
source,
})
}
}
pub fn expand_inline(input: TokenStream) -> TokenStream {
let parsed = match syn::parse2::<InlineInput>(input) {
Ok(p) => p,
Err(err) => return err.to_compile_error(),
};
let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
Ok(d) => d,
Err(_) => {
return syn::Error::new(parsed.source.span(), "CARGO_MANIFEST_DIR not set")
.to_compile_error();
}
};
let module_name = capitalize(&parsed.target);
let abs_includes: Vec<PathBuf> = parsed
.includes
.iter()
.map(|d| Path::new(&manifest_dir).join(d))
.collect();
let mut included_module_names: Vec<String> = Vec::new();
for dir in &abs_includes {
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
let p = entry.path();
if p.extension().is_some_and(|ext| ext == "hs") {
if let Some(stem) = p.file_stem().and_then(|s| s.to_str()) {
included_module_names.push(stem.to_string());
}
}
}
}
}
let mut all_extensions = vec![
"GADTs".to_string(),
"DataKinds".to_string(),
"TypeOperators".to_string(),
"FlexibleContexts".to_string(),
];
let mut all_imports = vec!["import Control.Monad.Freer".to_string()];
let mut include_bodies = String::new();
for dir in &abs_includes {
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
let p = entry.path();
if p.extension().is_some_and(|ext| ext == "hs") {
if let Ok(content) = std::fs::read_to_string(&p) {
let header = strip_module_header(&content);
for ext in header.extensions {
if !all_extensions.contains(&ext) {
all_extensions.push(ext);
}
}
for imp in header.imports {
let is_internal = included_module_names.iter().any(|m| {
imp.trim().starts_with(&format!("import {}", m))
|| imp.trim().starts_with(&format!("import qualified {}", m))
});
if !is_internal && !all_imports.contains(&imp) {
all_imports.push(imp);
}
}
include_bodies.push_str(&header.body);
include_bodies.push('\n');
}
}
}
}
}
let source_text = parsed.source.value();
let extensions_str = all_extensions.join(", ");
let imports_str = all_imports.join("\n");
let full_source = format!(
"{{-# LANGUAGE {} #-}}\nmodule {} where\n{}\n{}\n{}",
extensions_str, module_name, imports_str, include_bodies, source_text
);
let inline_dir = Path::new(&manifest_dir)
.join("target")
.join("tidepool-inline");
if let Err(e) = std::fs::create_dir_all(&inline_dir) {
return syn::Error::new(
parsed.source.span(),
format!("Failed to create {}: {}", inline_dir.display(), e),
)
.to_compile_error();
}
let hs_file = inline_dir.join(format!("{}.hs", module_name));
if let Err(e) = std::fs::write(&hs_file, &full_source) {
return syn::Error::new(
parsed.source.span(),
format!("Failed to write {}: {}", hs_file.display(), e),
)
.to_compile_error();
}
let output_dir = Path::new(&manifest_dir)
.join("target")
.join("tidepool-cbor")
.join(&module_name);
if let Err(msg) = run_tidepool_extract(
&hs_file,
&output_dir,
Some(&parsed.target),
Path::new(&manifest_dir),
) {
return syn::Error::new(parsed.source.span(), msg).to_compile_error();
}
let cbor_path = output_dir.join(format!("{}.cbor", parsed.target));
if !cbor_path.exists() {
let available = list_bindings(&output_dir);
return syn::Error::new(
parsed.source.span(),
format!(
"Binding '{}' not found after compilation. Available: {:?}",
parsed.target, available
),
)
.to_compile_error();
}
let cbor_path_str = cbor_path.to_str().unwrap();
let meta_path = output_dir.join("meta.cbor");
let meta_path_str = meta_path.to_str().unwrap();
let hs_path_str = hs_file.to_str().unwrap();
let include_tracks: Vec<TokenStream> = abs_includes
.iter()
.filter_map(|dir| {
std::fs::read_dir(dir).ok().map(|entries| {
entries
.filter_map(|e| e.ok())
.map(|e| e.path())
.filter(|p| p.extension().is_some_and(|ext| ext == "hs"))
.map(|p| {
let s = p.to_str().unwrap().to_string();
quote! { const _: &[u8] = include_bytes!(#s); }
})
.collect::<Vec<_>>()
})
})
.flatten()
.collect();
quote! {
{
const _: &[u8] = include_bytes!(#hs_path_str);
#(#include_tracks)*
static __CBOR: &[u8] = include_bytes!(#cbor_path_str);
static __META: &[u8] = include_bytes!(#meta_path_str);
let __expr = tidepool_repr::serial::read::read_cbor(__CBOR)
.expect("failed to deserialize CBOR — re-run extraction");
let (__table, _warnings) = tidepool_repr::serial::read::read_metadata(__META)
.expect("failed to deserialize metadata");
(__expr, __table)
}
}
}
/// The file-header parts of a Haskell source file, split from its declarations.
struct HaskellHeader {
    /// Extensions named in `{-# LANGUAGE … #-}` pragmas, in source order.
    extensions: Vec<String>,
    /// Import declarations, verbatim. Multi-line imports keep their continuation
    /// lines, joined with `\n`.
    imports: Vec<String>,
    /// Everything from the first non-header line on, joined with `\n`.
    body: String,
}

/// Splits `source` into LANGUAGE extensions, imports and body.
///
/// The header consists of blank lines, pragmas (possibly spanning several
/// lines), `--` line comments, `{- … -}` block comments, the `module … where`
/// declaration (including multi-line export lists), and imports. An import's
/// continuation lines must be indented more than the import itself. Pragmas
/// and comments in the header are dropped. The first line that fits none of
/// these starts the body, which is kept verbatim.
fn strip_module_header(source: &str) -> HaskellHeader {
    let mut extensions: Vec<String> = Vec::new();
    let mut imports: Vec<String> = Vec::new();
    let mut lines = source.lines().peekable();
    // Indentation of the last import, while deeper-indented lines may still
    // continue it (e.g. a multi-line import list).
    let mut open_import_indent: Option<usize> = None;

    while let Some(&line) = lines.peek() {
        let trimmed = line.trim();
        let indent = line.len() - line.trim_start().len();

        if trimmed.is_empty() {
            lines.next();
            continue;
        }
        if let Some(import_indent) = open_import_indent {
            if indent > import_indent {
                if let Some(last) = imports.last_mut() {
                    last.push('\n');
                    last.push_str(line);
                }
                lines.next();
                continue;
            }
        }
        open_import_indent = None;

        if trimmed.starts_with("{-#") {
            // A pragma may span several lines; collect it up to the closing `#-}`.
            let mut pragma = String::new();
            for l in lines.by_ref() {
                pragma.push_str(l);
                pragma.push('\n');
                if l.contains("#-}") {
                    break;
                }
            }
            extensions.extend(language_pragma_extensions(&pragma));
            continue;
        }
        if trimmed.starts_with("{-") {
            skip_block_comment(&mut lines);
            continue;
        }
        if is_line_comment(trimmed) {
            lines.next();
            continue;
        }
        if is_module_decl(trimmed) {
            // Skip up to and including the line containing `where`, which
            // covers multi-line export lists.
            for l in lines.by_ref() {
                if has_where_keyword(l) {
                    break;
                }
            }
            continue;
        }
        if trimmed.starts_with("import ") {
            imports.push(line.to_string());
            open_import_indent = Some(indent);
            lines.next();
            continue;
        }
        break;
    }

    HaskellHeader {
        extensions,
        imports,
        body: lines.collect::<Vec<_>>().join("\n"),
    }
}

/// Returns the extensions listed by every `{-# LANGUAGE … #-}` pragma in
/// `pragma_text`. The keyword is matched case-insensitively, as GHC does.
fn language_pragma_extensions(pragma_text: &str) -> Vec<String> {
    let mut found = Vec::new();
    let mut rest = pragma_text;
    while let Some(open) = rest.find("{-#") {
        let after_open = &rest[open + 3..];
        let (inner, tail) = match after_open.find("#-}") {
            Some(close) => (&after_open[..close], &after_open[close + 3..]),
            None => (after_open, ""),
        };
        let inner = inner.trim_start();
        let keyword_end = inner.find(char::is_whitespace).unwrap_or(inner.len());
        let (keyword, list) = inner.split_at(keyword_end);
        if keyword.eq_ignore_ascii_case("LANGUAGE") {
            found.extend(
                list.split(',')
                    .map(str::trim)
                    .filter(|ext| !ext.is_empty())
                    .map(str::to_string),
            );
        }
        rest = tail;
    }
    found
}

/// Consumes lines up to and including the one that closes the (possibly
/// nested) `{- … -}` comment opened on the next line.
fn skip_block_comment<'a, I: Iterator<Item = &'a str>>(lines: &mut I) {
    let mut depth: usize = 0;
    for l in lines {
        depth += l.matches("{-").count();
        depth = depth.saturating_sub(l.matches("-}").count());
        if depth == 0 {
            break;
        }
    }
}

/// True for a `--` line comment. A run of dashes followed by a symbol character
/// (e.g. `-->`) is an operator in Haskell, not a comment.
fn is_line_comment(trimmed: &str) -> bool {
    let after_dashes = trimmed.trim_start_matches('-');
    trimmed.len() - after_dashes.len() >= 2
        && !after_dashes.starts_with(|c: char| "!#$%&*+./<=>?@\\^|~:".contains(c))
}

/// True when the (trimmed) line starts a `module` declaration.
fn is_module_decl(trimmed: &str) -> bool {
    trimmed
        .strip_prefix("module")
        .is_some_and(|rest| rest.is_empty() || rest.starts_with(char::is_whitespace))
}

/// True when `line` contains the `where` keyword as a whole word.
fn has_where_keyword(line: &str) -> bool {
    line.split(|c: char| !(c.is_alphanumeric() || c == '_' || c == '\''))
        .any(|word| word == "where")
}
/// Upper-cases the first character of `s` (using Unicode case mapping, so it
/// may expand to several characters) and keeps the rest unchanged.
/// An empty input yields an empty string.
fn capitalize(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    if let Some(first) = s.chars().next() {
        out.extend(first.to_uppercase());
        out.push_str(&s[first.len_utf8()..]);
    }
    out
}
/// Runs `tidepool-extract <hs_path> --output-dir <output_dir> [--target <name>]`.
///
/// First tries a `tidepool-extract` binary on `PATH`. If that cannot be
/// spawned or exits unsuccessfully, it falls back to
/// `nix run <flake_root>#tidepool-extract -- …`. Here `flake_root` is the
/// nearest ancestor of `manifest_dir` that contains a `flake.nix`.
///
/// # Errors
/// Returns a human-readable message (surfaced as a compile error) when no
/// attempt succeeds. If the PATH binary ran but failed, its stderr is included,
/// so a Haskell error is not hidden behind the fallback's message.
fn run_tidepool_extract(
    hs_path: &Path,
    output_dir: &Path,
    target: Option<&str>,
    manifest_dir: &Path,
) -> Result<(), String> {
    // Arguments shared by the direct invocation and the nix fallback.
    let add_extract_args = |cmd: &mut Command| {
        cmd.arg(hs_path).arg("--output-dir").arg(output_dir);
        if let Some(name) = target {
            cmd.arg("--target").arg(name);
        }
    };

    let mut direct = Command::new("tidepool-extract");
    add_extract_args(&mut direct);
    // `None` when the binary is simply not installed. Otherwise this records
    // why the PATH attempt failed, for the final error message.
    let direct_failure: Option<String> = match direct.output() {
        Ok(output) if output.status.success() => return Ok(()),
        Ok(output) => Some(format!(
            "tidepool-extract failed (exit {}):\n{}",
            output.status,
            String::from_utf8_lossy(&output.stderr)
        )),
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => None,
        Err(e) => Some(format!("Failed to run tidepool-extract: {}", e)),
    };
    let with_direct_failure = |msg: String| match &direct_failure {
        Some(direct) => format!(
            "{}\n\nEarlier attempt with tidepool-extract on PATH: {}",
            msg, direct
        ),
        None => msg,
    };

    let Some(flake_root) = find_flake_root(manifest_dir) else {
        return Err(match &direct_failure {
            Some(direct) => format!(
                "{}\n(no flake.nix in any parent directory to fall back to `nix run`)",
                direct
            ),
            None => "tidepool-extract not found on PATH and no flake.nix in any parent directory"
                .to_string(),
        });
    };

    let mut cmd = Command::new("nix");
    cmd.args([
        "run",
        &format!("{}#tidepool-extract", flake_root.display()),
        "--",
    ]);
    add_extract_args(&mut cmd);
    match cmd.output() {
        Ok(output) if output.status.success() => Ok(()),
        Ok(output) => Err(with_direct_failure(format!(
            "nix run tidepool-extract failed (exit {}):\n{}",
            output.status,
            String::from_utf8_lossy(&output.stderr)
        ))),
        Err(e) => Err(with_direct_failure(format!(
            "Failed to run nix: {}. Is nix installed?",
            e
        ))),
    }
}
/// Returns the nearest directory, starting at `start` itself and walking up
/// through its ancestors, that contains a `flake.nix`. Returns `None` if none does.
fn find_flake_root(start: &Path) -> Option<PathBuf> {
    start
        .ancestors()
        .find(|candidate| candidate.join("flake.nix").exists())
        .map(Path::to_path_buf)
}
/// Lists the binding names (file stems of `*.cbor`, excluding `meta`) in
/// `output_dir`, sorted so that error messages are deterministic.
/// An unreadable or missing directory yields an empty list.
fn list_bindings(output_dir: &Path) -> Vec<String> {
    let mut names: Vec<String> = std::fs::read_dir(output_dir)
        .into_iter()
        .flatten()
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .filter(|p| p.extension().is_some_and(|ext| ext == "cbor"))
        .filter_map(|p| p.file_stem().map(|s| s.to_string_lossy().into_owned()))
        .filter(|stem| stem.as_str() != "meta")
        .collect();
    names.sort();
    names
}
/// Returns the path of the only binding CBOR (`*.cbor` other than `meta.cbor`)
/// in `output_dir`.
///
/// # Errors
/// Returns an error in three cases:
/// - the directory cannot be read (the I/O error is included);
/// - no binding exists;
/// - several bindings exist (their names are listed in sorted order, so the
///   message is deterministic).
fn find_single_binding(output_dir: &Path) -> Result<PathBuf, String> {
    let entries = std::fs::read_dir(output_dir).map_err(|e| {
        format!(
            "Failed to read tidepool-extract output {}: {}",
            output_dir.display(),
            e
        )
    })?;
    let mut bindings: Vec<PathBuf> = entries
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .filter(|p| p.extension().is_some_and(|ext| ext == "cbor"))
        .filter(|p| p.file_stem().is_some_and(|s| s != "meta"))
        .collect();
    bindings.sort();

    match bindings.as_slice() {
        [] => Err("No .cbor bindings produced by tidepool-extract".to_string()),
        [single] => Ok(single.clone()),
        many => {
            let names: Vec<String> = many
                .iter()
                .filter_map(|p| p.file_stem().map(|s| s.to_string_lossy().into_owned()))
                .collect();
            Err(format!(
                "Multiple bindings found: {:?}. Use haskell_eval!(\"path.hs::binding_name\")",
                names
            ))
        }
    }
}