use std::path::{Path, PathBuf};
use std::process::ExitCode;
use flodl::{Device, Graph};
use flodl_cli::{parse_or_schema, FdlArgs, FdlArgsTrait};
use flodl_hf::export::{build_for_export, export_hf_dir};
use flodl_hf::hub::HubExportHead;
use flodl_hf::models::auto::{AutoConfig, AutoModel};
use flodl_hf::safetensors_io::keys_have_pooler;
/// Command-line arguments for the export tool.
///
/// Exactly one of `hub` / `checkpoint` selects the input source; `out` is
/// always required. Cross-flag validation lives in `dispatch`, not here.
#[derive(FdlArgs, Debug)]
struct ExportArgs {
    /// Hub repo id to export from (mutually exclusive with `checkpoint`).
    #[option]
    hub: Option<String>,
    /// Path to a local `.fdl` checkpoint (mutually exclusive with `hub`).
    #[option]
    checkpoint: Option<String>,
    /// Output directory for model.safetensors + config.json (required).
    #[option]
    out: Option<String>,
    /// Explicit config file; checkpoint mode only (overrides the sidecar).
    #[option]
    config: Option<String>,
    /// Overwrite an existing export directory instead of refusing.
    #[option]
    force: bool,
    /// Also write the untouched source config as config.source.json
    /// (checkpoint mode only).
    #[option]
    preserve_source_config: bool,
    /// Head-class override; Hub mode only. "auto" behaves like unset.
    #[option]
    head: Option<String>,
}
/// Turn a user-supplied path argument into a concrete `PathBuf`.
///
/// Absolute paths pass through untouched. Relative paths are anchored at
/// `$FDL_PROJECT_ROOT` when that variable is set; otherwise they stay
/// relative to the current working directory.
fn resolve_path(arg: &str) -> PathBuf {
    let candidate = Path::new(arg);
    if candidate.is_absolute() {
        candidate.to_path_buf()
    } else {
        match std::env::var_os("FDL_PROJECT_ROOT") {
            Some(root) => Path::new(&root).join(candidate),
            None => candidate.to_path_buf(),
        }
    }
}
/// Entry point: parse arguments, run the export, and map errors to exit codes.
///
/// Usage errors print the message plus the full help text and exit with
/// code 2; runtime failures print only the message and exit with the
/// generic failure code.
fn main() -> ExitCode {
    let cli: ExportArgs = parse_or_schema();
    let Err(err) = dispatch(&cli) else {
        return ExitCode::SUCCESS;
    };
    match err {
        DispatchError::Usage(msg) => {
            eprintln!("error: {msg}");
            eprintln!();
            eprintln!("{}", ExportArgs::render_help());
            ExitCode::from(2)
        }
        DispatchError::Runtime(msg) => {
            eprintln!("error: {msg}");
            ExitCode::FAILURE
        }
    }
}
/// Error channel for `dispatch`, distinguishing how `main` reports failure.
enum DispatchError {
    /// Bad command-line usage: reported together with the help text
    /// and exit code 2.
    Usage(String),
    /// Failure while actually performing the export: reported alone with
    /// the generic failure exit code.
    Runtime(String),
}
/// Lift tensor-level errors into the runtime error channel so `?` works
/// inside `dispatch` for fallible flodl calls.
impl From<flodl::TensorError> for DispatchError {
    fn from(e: flodl::TensorError) -> Self {
        // Display output only; usage errors are never produced here.
        DispatchError::Runtime(e.to_string())
    }
}
/// Validate the parsed CLI arguments and route to the right export mode.
///
/// Checks mutual exclusivity of `--hub`/`--checkpoint`, rejects flags that
/// only apply to the other mode, resolves the output directory, refuses to
/// clobber an existing export unless `--force`, then delegates to
/// `run_hub` or `run_checkpoint`.
fn dispatch(cli: &ExportArgs) -> Result<(), DispatchError> {
    let hub_mode = cli.hub.is_some();
    let checkpoint_mode = cli.checkpoint.is_some();
    if hub_mode && checkpoint_mode {
        return Err(DispatchError::Usage(
            "--hub and --checkpoint are mutually exclusive; pass exactly one.".into(),
        ));
    }
    if !hub_mode && !checkpoint_mode {
        return Err(DispatchError::Usage(
            "missing required input: pass --hub <repo_id> or --checkpoint <file.fdl>.".into(),
        ));
    }
    // Mode-specific flags: each of these is meaningless in the other mode.
    if cli.preserve_source_config && hub_mode {
        return Err(DispatchError::Usage(
            "--preserve-source-config requires --checkpoint (Hub mode regenerates config from to_json_str)."
                .into(),
        ));
    }
    if cli.config.is_some() && hub_mode {
        return Err(DispatchError::Usage(
            "--config is only meaningful with --checkpoint.".into(),
        ));
    }
    if cli.head.is_some() && checkpoint_mode {
        return Err(DispatchError::Usage(
            "--head is Hub-mode only (checkpoint mode reads the architecture from the sidecar config)."
                .into(),
        ));
    }
    // "auto" is the explicit spelling of the default (no override).
    let head_override = match cli.head.as_deref() {
        None | Some("auto") => None,
        Some(explicit) => {
            let head = HubExportHead::parse(explicit)
                .map_err(|e| DispatchError::Usage(e.to_string()))?;
            Some(head)
        }
    };
    let out_arg = match cli.out.as_deref() {
        Some(dir) => dir,
        None => {
            return Err(DispatchError::Usage("missing required --out <dir>.".into()));
        }
    };
    let out_dir = resolve_path(out_arg);
    if !cli.force {
        // Refuse to overwrite any file this run would produce.
        let model_present = out_dir.join("model.safetensors").exists();
        let config_present = out_dir.join("config.json").exists();
        let source_present =
            cli.preserve_source_config && out_dir.join("config.source.json").exists();
        if model_present || config_present || source_present {
            return Err(DispatchError::Runtime(format!(
                "{} already contains model.safetensors / config.json (or config.source.json under --preserve-source-config). Pass --force to overwrite.",
                out_dir.display(),
            )));
        }
    }
    if let Some(repo_id) = cli.hub.as_deref() {
        run_hub(repo_id, &out_dir, head_override)?;
    } else if let Some(checkpoint_path) = cli.checkpoint.as_deref() {
        run_checkpoint(
            checkpoint_path,
            cli.config.as_deref(),
            &out_dir,
            cli.preserve_source_config,
        )?;
    }
    Ok(())
}
/// Export directly from a Hub repository: fetch the config, load weights
/// (optionally forcing the head class), stamp the source repo id into the
/// canonical config, and write the export directory.
fn run_hub(
    repo_id: &str,
    out_dir: &Path,
    head_override: Option<HubExportHead>,
) -> flodl::Result<()> {
    eprintln!("fetching config.json for {repo_id} ...");
    let config = AutoConfig::from_pretrained(repo_id)?;
    eprintln!("detected family: {}", config.model_type());
    eprintln!("loading weights for {repo_id} ...");
    let graph = if let Some(head) = head_override {
        eprintln!("forcing head class: {head:?} (overrides architectures[0])");
        AutoModel::from_pretrained_for_export_with_head(repo_id, head)?
    } else {
        AutoModel::from_pretrained_for_export(repo_id)?
    };
    // Prefer the config the graph itself carries; fall back to regenerating
    // it from the fetched AutoConfig.
    let canonical = match graph.source_config() {
        Some(src) => src,
        None => config.to_json_str(),
    };
    let stamped = inject_source_repo(&canonical, repo_id)?;
    eprintln!("exporting to {} ...", out_dir.display());
    export_hf_dir(&graph, &stamped, out_dir)?;
    println!(
        "exported {repo_id} → {}\n model.safetensors + config.json ready for AutoModel.from_pretrained",
        out_dir.display(),
    );
    Ok(())
}
/// Parse the canonical config JSON, record `repo_id` under the
/// `flodl_source_repo` key, and return the pretty-printed result.
///
/// Fails if the input is not valid JSON or its top level is not an object.
fn inject_source_repo(canonical: &str, repo_id: &str) -> flodl::Result<String> {
    let mut doc: serde_json::Value = serde_json::from_str(canonical).map_err(|e| {
        flodl::TensorError::new(&format!(
            "inject_source_repo: parse canonical config: {e}"
        ))
    })?;
    let fields = doc.as_object_mut().ok_or_else(|| {
        flodl::TensorError::new("inject_source_repo: canonical config is not a JSON object")
    })?;
    let stamp = serde_json::Value::String(repo_id.to_owned());
    fields.insert("flodl_source_repo".into(), stamp);
    serde_json::to_string_pretty(&doc).map_err(|e| {
        flodl::TensorError::new(&format!(
            "inject_source_repo: re-serialize canonical config: {e}"
        ))
    })
}
/// Export from a local `.fdl` checkpoint: resolve the config text (explicit
/// `--config` override, or the auto-derived sidecar next to the checkpoint),
/// rebuild the export graph, load the weights, and write the HF-style
/// directory. When `preserve_source_config` is set, the unmodified source
/// config is additionally written as `config.source.json`.
fn run_checkpoint(
    checkpoint_path: &str,
    config_override: Option<&str>,
    out_dir: &Path,
    preserve_source_config: bool,
) -> flodl::Result<()> {
    // Relative checkpoint paths are anchored at $FDL_PROJECT_ROOT (see resolve_path).
    let checkpoint_path = resolve_path(checkpoint_path);
    let checkpoint_str = checkpoint_path.to_string_lossy();
    // Config text: an explicit --config wins; otherwise the sidecar must exist.
    let config_str = if let Some(cfg) = config_override {
        let cfg_path = resolve_path(cfg);
        eprintln!("reading config from {} ...", cfg_path.display());
        std::fs::read_to_string(&cfg_path).map_err(|e| {
            flodl::TensorError::new(&format!(
                "cannot read --config {}: {e}",
                cfg_path.display()
            ))
        })?
    } else {
        let sidecar = sidecar_for(&checkpoint_path);
        if !sidecar.exists() {
            // Checked eagerly so the user gets a pointed hint instead of a
            // bare read error.
            return Err(flodl::TensorError::new(&format!(
                "no sidecar config at {}; pass --config <file> to override (or save the checkpoint via flodl-hf so the sidecar is emitted automatically)",
                sidecar.display(),
            )));
        }
        eprintln!("reading sidecar from {} ...", sidecar.display());
        std::fs::read_to_string(&sidecar).map_err(|e| {
            flodl::TensorError::new(&format!(
                "cannot read sidecar {}: {e}",
                sidecar.display()
            ))
        })?
    };
    let config = AutoConfig::from_json_str(&config_str)?;
    eprintln!("detected family: {}", config.model_type());
    // Inspect the checkpoint's key set first: whether pooler keys exist
    // decides which graph variant build_for_export constructs.
    let keys = flodl::checkpoint_keys(&checkpoint_str)?;
    let has_pooler = keys_have_pooler(&keys);
    eprintln!(
        "checkpoint declares {} keys, with_pooler={has_pooler}",
        keys.len()
    );
    let graph: Graph = build_for_export(&config, has_pooler, Device::CPU)?;
    let report = graph.load_checkpoint(&checkpoint_str)?;
    eprintln!(
        "loaded {} param(s)/buffer(s); {} skipped, {} missing",
        report.loaded.len(),
        report.skipped.len(),
        report.missing.len(),
    );
    // Surface top-level config fields that the canonical re-serialization
    // drops, so the user knows what config.json will not carry.
    let canonical_config = config.to_json_str();
    let normalized = source_only_top_level_keys(&config_str, &canonical_config);
    if !normalized.is_empty() {
        eprintln!(
            "note: {} field(s) present in source config not emitted in canonical: {}",
            normalized.len(),
            normalized.join(", "),
        );
    }
    eprintln!("exporting to {} ...", out_dir.display());
    export_hf_dir(&graph, &canonical_config, out_dir)?;
    if preserve_source_config {
        // Keep the original config verbatim alongside the canonical one.
        let source_path = out_dir.join("config.source.json");
        std::fs::write(&source_path, &config_str).map_err(|e| {
            flodl::TensorError::new(&format!(
                "write {}: {e}",
                source_path.display(),
            ))
        })?;
        eprintln!(
            "wrote source config to {} (canonical config.json kept for AutoConfig)",
            source_path.display(),
        );
    }
    // Best-effort: bring along tokenizer files sitting next to the checkpoint.
    let copied = copy_tokenizer_files(&checkpoint_path, out_dir)?;
    if copied == 0 {
        eprintln!(
            "warning: no tokenizer files matched the auto-whitelist next to {}. \
             Copy them into {} manually if HF Python needs them (tokenizer.json, \
             vocab.txt, sentencepiece.bpe.model, ...).",
            checkpoint_path.display(),
            out_dir.display(),
        );
    } else {
        eprintln!("copied {copied} tokenizer file(s) into {}", out_dir.display());
    }
    println!(
        "exported {} → {}\n model.safetensors + config.json ready for AutoModel.from_pretrained",
        checkpoint_path.display(),
        out_dir.display(),
    );
    Ok(())
}
/// Tokenizer artifacts copied verbatim from the checkpoint's directory into
/// the export directory when present (see `copy_tokenizer_files`); the
/// warning in `run_checkpoint` points users here when nothing matches.
const TOKENIZER_WHITELIST: &[&str] = &[
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "added_tokens.json",
    "vocab.txt",
    "vocab.json",
    "merges.txt",
    "sentencepiece.bpe.model",
    "spm.model",
];
/// Copy every whitelisted tokenizer file sitting next to the checkpoint
/// into `out_dir`, returning how many were copied.
///
/// A checkpoint with no usable parent directory yields `Ok(0)` without
/// touching the filesystem; a failed copy aborts with an error naming both
/// paths.
fn copy_tokenizer_files(checkpoint_path: &Path, out_dir: &Path) -> flodl::Result<usize> {
    let parent = match checkpoint_path.parent() {
        Some(dir) if !dir.as_os_str().is_empty() => dir,
        _ => return Ok(0),
    };
    let mut count = 0_usize;
    for name in TOKENIZER_WHITELIST.iter() {
        let src = parent.join(name);
        // Skip absent entries silently; only existing regular files count.
        if !src.is_file() {
            continue;
        }
        let dst = out_dir.join(name);
        if let Err(e) = std::fs::copy(&src, &dst) {
            return Err(flodl::TensorError::new(&format!(
                "copy tokenizer file {} -> {}: {e}",
                src.display(),
                dst.display(),
            )));
        }
        count += 1;
    }
    Ok(count)
}
/// List the top-level keys present in `source_json` but absent from
/// `canonical_json`, sorted alphabetically.
///
/// Returns an empty list whenever either input fails to parse or its top
/// level is not a JSON object — this is diagnostic-only, so parse problems
/// are swallowed rather than reported.
fn source_only_top_level_keys(source_json: &str, canonical_json: &str) -> Vec<String> {
    let parse = |text: &str| serde_json::from_str::<serde_json::Value>(text).ok();
    let (Some(src), Some(canon)) = (parse(source_json), parse(canonical_json)) else {
        return Vec::new();
    };
    let (Some(src_map), Some(canon_map)) = (src.as_object(), canon.as_object()) else {
        return Vec::new();
    };
    let mut extras: Vec<String> = src_map
        .keys()
        .filter(|key| !canon_map.contains_key(key.as_str()))
        .cloned()
        .collect();
    extras.sort();
    extras
}
/// Derive the sidecar config path for a checkpoint file: strip a trailing
/// `.gz` component (if any), then swap the remaining extension for
/// `config.json` — e.g. `model.fdl.gz` → `model.config.json`.
fn sidecar_for(checkpoint: &Path) -> PathBuf {
    let mut sidecar = checkpoint.to_path_buf();
    let gzipped = sidecar
        .extension()
        .and_then(|ext| ext.to_str())
        .map_or(false, |ext| ext == "gz");
    if gzipped {
        // Drop the ".gz" so the real extension is the one replaced below.
        sidecar.set_extension("");
    }
    sidecar.set_extension("config.json");
    sidecar
}