#![recursion_limit = "256"]
#![allow(dead_code, unused_imports, unused_variables)]
mod config;
mod data;
mod model;
mod training;
use std::path::{Path, PathBuf};
use anyhow::Result;
use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};
use clap::{Parser, Subcommand};
use tracing_subscriber::EnvFilter;
// Concrete burn backend: wgpu with f32 floats and i32 ints.
type WgpuBackend = Wgpu<f32, i32>;
// Autodiff wrapper over the wgpu backend — the backend used for training
// (see `run_stage::<AutoWgpu>` call sites below).
type AutoWgpu = Autodiff<WgpuBackend>;
// Top-level CLI definition (clap derive).
// NOTE: only `//` comments are used here — `///` doc comments on clap-derived
// items become help text and would change the `--help` output.
#[derive(Parser)]
#[command(
    name = "opentslm",
    about = "OpenTSLM — Time-Series Language Model (Rust / Burn / llama.cpp)",
    long_about = "\
Zero-configuration quick start:
opentslm train
Model — downloaded automatically from HuggingFace on first run and cached in
~/.cache/huggingface/hub/. Set $HF_TOKEN for private repos.
Data — downloaded automatically if data/ is missing or incomplete.
All sources are public; no login required."
)]
struct Cli {
    // Active subcommand; dispatched in `main`.
    #[command(subcommand)]
    command: Commands,
    // Default tracing level for the EnvFilter built in `main`; a set RUST_LOG
    // environment variable takes precedence over this flag.
    #[arg(long, global = true, default_value = "info")]
    log_level: String,
}
// Subcommands. Only `//` comments are added (see note on `Cli`): `///` would
// feed clap's help text and change runtime output.
#[derive(Subcommand)]
enum Commands {
    // Run curriculum training; with no --stages, all of
    // config::CURRICULUM_STAGES are run in order (see `cmd_train`).
    Train {
        // Path to a local model file; when absent the default repo/file is
        // fetched from the HF Hub cache (see `resolve_model`).
        #[arg(long)]
        model: Option<PathBuf>,
        #[arg(long, default_value = "data/")]
        data_dir: PathBuf,
        // Space-separated list, e.g. `--stages stage1 stage2`.
        #[arg(long, num_args = 1.., value_delimiter = ' ')]
        stages: Option<Vec<String>>,
        // Overrides the trainer's default batch size when given.
        #[arg(long)]
        batch_size: Option<usize>,
    },
    // Evaluate a single curriculum stage (see `cmd_eval`).
    Eval {
        #[arg(long)]
        model: Option<PathBuf>,
        #[arg(long, default_value = "data/")]
        data_dir: PathBuf,
        #[arg(long)]
        stage: String,
    },
    // One-shot inference over a comma-separated series (see `cmd_infer`).
    Infer {
        #[arg(long)]
        model: Option<PathBuf>,
        // Comma-separated f32 values, e.g. "1.0, 2.5, 3.0".
        #[arg(long)]
        series: String,
        #[arg(long)]
        prompt: String,
        #[arg(long, default_value_t = 200)]
        max_tokens: usize,
    },
    // Re-render plots from previously written metrics CSVs (see `cmd_plot`).
    Plot {
        #[arg(long, num_args = 1.., value_delimiter = ' ')]
        stages: Option<Vec<String>>,
        #[arg(long, default_value = "figures/")]
        figures_dir: PathBuf,
    },
    // Only available when built with `--features download`.
    #[cfg(feature = "download")]
    DownloadData {
        #[arg(long, default_value = "data/")]
        out_dir: PathBuf,
        // Caps how much is downloaded — semantics defined by the downloader.
        #[arg(long)]
        limit: Option<usize>,
        // Restrict the download to one dataset family.
        #[arg(long, value_parser = ["har", "sleep", "ecg"])]
        only: Option<String>,
    },
}
fn main() -> Result<()> {
    let cli = Cli::parse();
    // RUST_LOG (if set) wins; otherwise build a filter from --log-level.
    // Without the `verbose` feature the chatty cubecl crates are capped at
    // `warn`. Exactly one of the two cfg-gated blocks below compiles in,
    // acting as the closure's tail expression.
    let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| {
        #[cfg(not(feature = "verbose"))]
        {
            EnvFilter::new(format!(
                "{level},\
                cubecl_runtime::tune=warn,\
                cubecl_wgpu=warn",
                level = cli.log_level,
            ))
        }
        #[cfg(feature = "verbose")]
        {
            EnvFilter::new(&cli.log_level)
        }
    });
    tracing_subscriber::fmt()
        .with_env_filter(filter)
        .init();
    // Dispatch to the per-subcommand handler; each returns anyhow::Result.
    match cli.command {
        Commands::Train { model, data_dir, stages, batch_size } =>
            cmd_train(model, data_dir, stages, batch_size),
        Commands::Eval { model, data_dir, stage } =>
            cmd_eval(model, data_dir, stage),
        Commands::Infer { model, series, prompt, max_tokens } =>
            cmd_infer(model, series, prompt, max_tokens),
        Commands::Plot { stages, figures_dir } =>
            cmd_plot(stages, figures_dir),
        // Variant and handler are both gated on the `download` feature.
        #[cfg(feature = "download")]
        Commands::DownloadData { out_dir, limit, only } =>
            cmd_download_data(out_dir, limit, only),
    }
}
fn cmd_train(
    model: Option<PathBuf>,
    data_dir: PathBuf,
    stages: Option<Vec<String>>,
    batch_size: Option<usize>,
) -> Result<()> {
    // Resolve the model (downloading to the HF cache if needed) and make
    // sure the dataset files exist before constructing the trainer.
    let model = resolve_model(model)?;
    ensure_data(&data_dir)?;
    tracing::info!("Model : {}", model.display());
    tracing::info!("Data dir : {}", data_dir.display());

    let mut trainer =
        training::curriculum::CurriculumTrainer::new(&model, &data_dir, "wgpu");
    if let Some(bs) = batch_size {
        trainer.batch_size = bs;
    }

    // No explicit stage list means "run the whole curriculum in order".
    let selected: Vec<String> = match stages {
        Some(list) => list,
        None => config::CURRICULUM_STAGES.iter().map(|s| s.to_string()).collect(),
    };
    for name in &selected {
        trainer.run_stage::<AutoWgpu>(name)?;
    }

    // Best-effort: assemble a cross-stage overview from whatever per-stage
    // metrics CSVs exist on disk. Failures only warn — training succeeded.
    {
        use training::metrics::{plot_curriculum_overview, StageMetrics};
        let figs = PathBuf::from("figures");
        let loaded: Vec<StageMetrics> = config::CURRICULUM_STAGES
            .iter()
            .filter_map(|s| StageMetrics::from_csv(s, &figs).ok())
            .collect();
        if !loaded.is_empty() {
            let views: Vec<&StageMetrics> = loaded.iter().collect();
            if let Err(e) = plot_curriculum_overview(&views, &figs) {
                tracing::warn!("Could not write curriculum overview: {e}");
            }
        }
    }
    Ok(())
}
fn cmd_eval(model: Option<PathBuf>, data_dir: PathBuf, stage: String) -> Result<()> {
    // Evaluation reuses the curriculum trainer: resolve the model, verify
    // the data, then run exactly the requested stage.
    let model_path = resolve_model(model)?;
    ensure_data(&data_dir)?;
    let mut trainer =
        training::curriculum::CurriculumTrainer::new(&model_path, &data_dir, "wgpu");
    trainer.run_stage::<AutoWgpu>(&stage)
}
fn cmd_infer(
model: Option<PathBuf>,
series_str: String,
prompt: String,
max_tokens: usize,
) -> Result<()> {
use model::llm::{llama_cpp::LlamaCppBackend, opentslm_sp::OpenTslmSp};
let model = resolve_model(model)?;
let device = WgpuDevice::default();
let llm = LlamaCppBackend::load(&model, config::N_GPU_LAYERS, config::CTX_SIZE)?;
let sp_model: OpenTslmSp<AutoWgpu> = OpenTslmSp::new(&llm, &device);
let series: Vec<f32> = series_str
.split(',')
.map(|s| s.trim().parse::<f32>().unwrap_or(0.0))
.collect();
let sample = data::batch::Sample {
pre_prompt: prompt,
time_series_text: vec!["Time series data:".to_string()],
time_series: vec![series],
post_prompt: String::new(),
answer: String::new(),
label: None,
};
let outputs = sp_model.generate(&[sample], &llm, max_tokens, &device);
println!("\n─── Model Output ────────────────────────────────────────");
println!("{}", outputs.into_iter().next().unwrap_or_default());
println!("─────────────────────────────────────────────────────────");
Ok(())
}
/// Re-render per-stage plots and the HTML index from existing metrics CSVs,
/// then write the cross-stage curriculum overview.
///
/// Fix: the `any_ok` flag was redundant — it was always equivalent to
/// `!loaded.is_empty()`, so the flag is removed.
fn cmd_plot(stages: Option<Vec<String>>, figures_dir: PathBuf) -> Result<()> {
    use training::metrics::{StageMetrics, plot_curriculum_overview, write_html_index};
    // No explicit stage list means "plot every curriculum stage".
    let run_stages: Vec<String> = stages.unwrap_or_else(|| {
        config::CURRICULUM_STAGES.iter().map(|s| s.to_string()).collect()
    });
    let mut loaded: Vec<StageMetrics> = Vec::new();
    for stage in &run_stages {
        match StageMetrics::from_csv(stage, &figures_dir) {
            Ok(m) => {
                m.save(&figures_dir)?;
                write_html_index(&m, &figures_dir)?;
                tracing::info!("{stage}: plots written → {}/", figures_dir.join(stage).display());
                loaded.push(m);
            }
            // A missing/unreadable CSV skips the stage but is not fatal.
            Err(e) => tracing::warn!("{stage}: skipped — {e}"),
        }
    }
    if loaded.is_empty() {
        anyhow::bail!(
            "No metrics.csv files found in {}.\n\
            Run training first: cargo run --release -- train",
            figures_dir.display()
        );
    }
    // The overview is best-effort: a failure only warns.
    let refs: Vec<&StageMetrics> = loaded.iter().collect();
    if let Err(e) = plot_curriculum_overview(&refs, &figures_dir) {
        tracing::warn!("Could not write curriculum overview: {e}");
    }
    Ok(())
}
// Explicit `download-data` subcommand: forwards the CLI flags to the
// dataset downloader. Only compiled with the `download` feature.
#[cfg(feature = "download")]
fn cmd_download_data(
    out_dir: PathBuf,
    limit: Option<usize>,
    only: Option<String>,
) -> Result<()> {
    tracing::info!("Downloading wearable datasets → {}", out_dir.display());
    let cfg = data::downloader::DownloadConfig { out_dir, limit, only };
    data::downloader::run(&cfg)
}
// JSONL files (relative to the data dir) that must exist before training or
// evaluation; `ensure_data` checks for these and triggers a download when
// any are missing. Presumably one group per curriculum stage dataset —
// confirm against config::CURRICULUM_STAGES.
const REQUIRED_DATA_FILES: &[&str] = &[
    "tsqa/train.jsonl",
    "m4/train_samples.jsonl",
    // HAR chain-of-thought splits
    "har_cot/train.jsonl",
    "har_cot/val.jsonl",
    "har_cot/test.jsonl",
    // Sleep chain-of-thought splits
    "sleep_cot/train.jsonl",
    "sleep_cot/val.jsonl",
    "sleep_cot/test.jsonl",
    // ECG-QA chain-of-thought splits
    "ecg_qa_cot/train.jsonl",
    "ecg_qa_cot/val.jsonl",
    "ecg_qa_cot/test.jsonl",
];
fn ensure_data(data_dir: &PathBuf) -> Result<()> {
let missing: Vec<&str> = REQUIRED_DATA_FILES
.iter()
.copied()
.filter(|f| !data_dir.join(f).exists())
.collect();
if missing.is_empty() {
return Ok(());
}
tracing::info!(
"Data directory {:?} is missing {} file(s) — downloading now …",
data_dir,
missing.len(),
);
for f in &missing {
tracing::info!(" missing: {f}");
}
#[cfg(feature = "download")]
{
data::downloader::run(&data::downloader::DownloadConfig {
out_dir: data_dir.clone(),
limit: None,
only: None,
})
}
#[cfg(not(feature = "download"))]
anyhow::bail!(
"Dataset files are missing and the binary was built without the \
`download` feature.\n\
Re-build with: cargo build --features download\n\
Or place the required JSONL files in {:?} manually.",
data_dir
)
}
// Use an explicitly supplied model path when given; otherwise fall back to
// fetching the default model file from the default HF repo.
fn resolve_model(explicit: Option<PathBuf>) -> Result<PathBuf> {
    match explicit {
        Some(path) => Ok(path),
        None => hf_get(config::DEFAULT_MODEL_REPO, config::DEFAULT_MODEL_FILE),
    }
}
/// Fetch `filename` from the HF Hub model repo `repo`, returning the path in
/// the local HF cache (downloading on first use).
///
/// Fix: the log and error messages printed the literal `(unknown)` instead
/// of the actual `filename`, making diagnostics useless; `{filename}` is now
/// interpolated everywhere.
fn hf_get(repo: &str, filename: &str) -> Result<PathBuf> {
    use hf_hub::api::sync::Api;
    let api = Api::new()
        .map_err(|e| anyhow::anyhow!("HF Hub init failed: {e}"))?;
    tracing::info!("Resolving {repo}/{filename} (downloading to HF cache if needed) …");
    let path = api
        .model(repo.to_string())
        .get(filename)
        .map_err(|e| anyhow::anyhow!(
            "Could not fetch {repo}/{filename}: {e}\n\
            \n\
            To download manually:\n huggingface-cli download {repo} {filename}\n\
            To use a local file:\n --model /path/to/{filename}"
        ))?;
    tracing::info!("Model : {}", path.display());
    Ok(path)
}