jepa 0.1.0

CLI and TUI for the jepa-rs workspace
Documentation
use std::path::{Path, PathBuf};

use anyhow::{Context, Result};
use image::{Rgb, RgbImage};

use crate::cli::{ArchPreset, EnergyChoice, MaskingChoice, RegularizerChoice, TrainArgs};

/// Number of entries in the demo image table below.
pub const DEMO_IMAGE_COUNT: usize = 6;
/// Width and height, in pixels, of every rendered demo image.
const DEMO_IMAGE_SIZE: u32 = 96;

/// Pixel generator: maps an `(x, y)` coordinate to an `[r, g, b]` triple.
type PixelFn = fn(u32, u32) -> [u8; 3];

/// Relative output path and pixel generator for each demo image.
///
/// The paths place the images in two subdirectories, `class_a/` and `class_b/`.
const DEMO_IMAGES: [(&str, PixelFn); DEMO_IMAGE_COUNT] = [
    ("class_a/gradient_h.png", gradient_horizontal),
    ("class_a/gradient_v.png", gradient_vertical),
    ("class_a/checker.png", checkerboard),
    ("class_b/diagonal.png", diagonal_mix),
    ("class_b/rings.png", concentric_rings),
    ("class_b/quadrants.png", quadrants),
];

/// Identifies one of the example-backed demos and carries its display metadata.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DemoId {
    /// Training demo that runs over generated image files.
    ImageFolderTraining,
    /// Training demo that runs over synthetic random tensors.
    SyntheticTraining,
    /// Setup demo that generates the demo image dataset.
    PrepareImageFolder,
}

impl DemoId {
    /// Every demo, in declaration order.
    pub const ALL: [DemoId; 3] = [
        Self::ImageFolderTraining,
        Self::SyntheticTraining,
        Self::PrepareImageFolder,
    ];

    /// Short human-readable title.
    pub fn title(self) -> &'static str {
        match self {
            DemoId::PrepareImageFolder => "Prepare Demo Dataset",
            DemoId::SyntheticTraining => "Synthetic Training",
            DemoId::ImageFolderTraining => "Image-Folder Training",
        }
    }

    /// Name of the cargo example that [`DemoId::command`] runs.
    pub fn example_name(self) -> &'static str {
        match self {
            DemoId::PrepareImageFolder => "prepare_demo_image_folder",
            DemoId::SyntheticTraining => "train_synthetic_demo",
            DemoId::ImageFolderTraining => "train_image_folder_demo",
        }
    }

    /// One-line description shown alongside the title.
    pub fn subtitle(self) -> &'static str {
        match self {
            DemoId::PrepareImageFolder => {
                "Generate a tiny recursive image dataset under target/example-data"
            }
            DemoId::SyntheticTraining => {
                "Same optimizer and EMA path, but with synthetic random tensors"
            }
            DemoId::ImageFolderTraining => {
                "Real strict I-JEPA optimization over generated image files"
            }
        }
    }

    /// Rough wall-clock estimate, as display text.
    pub fn estimated_duration(self) -> &'static str {
        match self {
            DemoId::ImageFolderTraining | DemoId::SyntheticTraining => "~10-20s on CPU",
            DemoId::PrepareImageFolder => "<1s",
        }
    }

    /// Shell command that runs this demo's cargo example.
    pub fn command(self) -> String {
        let mut invocation = String::from("cargo run -p jepa --example ");
        invocation.push_str(self.example_name());
        invocation
    }

    /// Bullet points describing what the demo does.
    pub fn process_notes(self) -> &'static [&'static str] {
        const IMAGE_FOLDER: &[&str] = &[
            "Generates a tiny nested image dataset at runtime.",
            "Exercises recursive file discovery and deterministic preprocessing.",
            "Runs the real strict masked-image training loop with AdamW and EMA.",
        ];
        const SYNTHETIC: &[&str] = &[
            "Skips dataset I/O and preprocessing.",
            "Still uses the real masking, optimizer, EMA, and predictor path.",
            "Good for a fast sanity check of the training stack.",
        ];
        const PREPARE: &[&str] = &[
            "Creates 6 small PNG fixtures under target/example-data/jepa/.",
            "Uses nested directories so --dataset-dir recursion is exercised.",
            "Keeps binary demo assets out of git while remaining reproducible.",
        ];

        match self {
            DemoId::ImageFolderTraining => IMAGE_FOLDER,
            DemoId::SyntheticTraining => SYNTHETIC,
            DemoId::PrepareImageFolder => PREPARE,
        }
    }

    /// Bullet points describing what to watch while the demo runs.
    pub fn monitoring_notes(self) -> &'static [&'static str] {
        const IMAGE_FOLDER: &[&str] = &[
            "Watch loss and energy move step by step.",
            "Confirm the dataset path, file count, resize, and crop settings.",
            "Use the live log to follow preprocessing, training, and completion.",
        ];
        const SYNTHETIC: &[&str] = &[
            "Watch loss, learning rate, and EMA updates without disk I/O.",
            "Useful to confirm the strict reference path still executes cleanly.",
            "Compare behavior against the image-folder demo.",
        ];
        const PREPARE: &[&str] = &[
            "Watch the generated file list and output directory.",
            "Use the resulting path directly with jepa train --dataset-dir.",
            "This is a setup demo, not a learning run.",
        ];

        match self {
            DemoId::ImageFolderTraining => IMAGE_FOLDER,
            DemoId::SyntheticTraining => SYNTHETIC,
            DemoId::PrepareImageFolder => PREPARE,
        }
    }
}

/// Identifies one of the inference walkthrough demos and carries its display metadata.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InferenceDemoId {
    /// Walkthrough that uses the `VitSmall16` preset.
    PatternVitSmall,
    /// Walkthrough that uses the `VitBase16` preset.
    PatternVitBase,
}

impl InferenceDemoId {
    /// Every inference demo, in declaration order.
    pub const ALL: [InferenceDemoId; 2] = [Self::PatternVitSmall, Self::PatternVitBase];

    /// Short human-readable title.
    pub fn title(self) -> &'static str {
        match self {
            InferenceDemoId::PatternVitBase => "Pattern Walkthrough XL",
            InferenceDemoId::PatternVitSmall => "Pattern Walkthrough",
        }
    }

    /// One-line description shown alongside the title.
    pub fn subtitle(self) -> &'static str {
        match self {
            InferenceDemoId::PatternVitBase => {
                "ViT-B/16 inference over the same patterns to compare scale and cost"
            }
            InferenceDemoId::PatternVitSmall => {
                "ViT-S/16 inference over deterministic demo image patterns"
            }
        }
    }

    /// Rough wall-clock estimate, as display text.
    pub fn estimated_duration(self) -> &'static str {
        match self {
            InferenceDemoId::PatternVitBase => "~2-4s on CPU",
            InferenceDemoId::PatternVitSmall => "~1-2s on CPU",
        }
    }

    /// Architecture preset associated with this demo.
    pub fn preset(self) -> ArchPreset {
        match self {
            InferenceDemoId::PatternVitBase => ArchPreset::VitBase16,
            InferenceDemoId::PatternVitSmall => ArchPreset::VitSmall16,
        }
    }

    /// Input dimensions for the demo; the same square size for every variant.
    pub fn input_size(self) -> (usize, usize) {
        const SIDE: usize = 224;
        (SIDE, SIDE)
    }

    /// Number of samples the demo reports; the same for every variant.
    pub fn sample_count(self) -> usize {
        const SAMPLES: usize = 3;
        SAMPLES
    }

    /// Bullet points describing what the demo does.
    pub fn process_notes(self) -> &'static [&'static str] {
        const SMALL: &[&str] = &[
            "Builds a random-initialized ViT-S/16 encoder in demo mode.",
            "Synthesizes three deterministic image patterns used throughout the demos.",
            "Runs real tokenization and encoder attention, then inspects the embedding output.",
        ];
        const BASE: &[&str] = &[
            "Uses the same deterministic inputs with a larger ViT-B/16 encoder.",
            "Highlights how patch count stays fixed while latency and representation scale change.",
            "Good for comparing runtime cost against the smaller demo.",
        ];

        match self {
            InferenceDemoId::PatternVitSmall => SMALL,
            InferenceDemoId::PatternVitBase => BASE,
        }
    }

    /// Bullet points describing what to watch while the demo runs.
    pub fn monitoring_notes(self) -> &'static [&'static str] {
        const SMALL: &[&str] = &[
            "Watch the phase panel move from encoder init to per-sample inference.",
            "Latency, activation mean/std, and token norms update after each sample.",
            "The result panel explains what the embedding stats mean and what they do not.",
        ];
        const BASE: &[&str] = &[
            "Compare runtime and token-norm drift against the smaller pattern walkthrough.",
            "Use the sample previews to verify the same inputs flowed through both presets.",
            "This is a structure-and-monitoring demo, not a pretrained semantic benchmark.",
        ];

        match self {
            InferenceDemoId::PatternVitSmall => SMALL,
            InferenceDemoId::PatternVitBase => BASE,
        }
    }

    /// Display text describing the engine mode; currently identical for every variant.
    pub fn engine_note(self) -> &'static str {
        const NOTE: &str = "Demo mode with random-initialized weights";
        // Kept as an exhaustive match so a new variant must choose its note explicitly.
        match self {
            InferenceDemoId::PatternVitSmall => NOTE,
            InferenceDemoId::PatternVitBase => NOTE,
        }
    }
}

/// Description of a demo dataset written by `prepare_demo_image_folder`.
#[derive(Debug, Clone)]
pub struct PreparedDemoDataset {
    /// Directory under which the demo images were written.
    pub root: PathBuf,
    /// Path of each written image, relative to `root`.
    pub files: Vec<String>,
}

/// Returns the directory two levels above this crate's manifest directory.
///
/// # Panics
///
/// Panics if the compile-time `CARGO_MANIFEST_DIR` has fewer than two parent
/// directories.
pub fn workspace_root() -> PathBuf {
    let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));
    // `ancestors` yields the path itself, then its parent, then its grandparent.
    let grandparent = manifest_dir
        .ancestors()
        .nth(2)
        .expect("workspace root should exist above crates/jepa");
    grandparent.to_path_buf()
}

/// Returns `<workspace root>/target/example-data/jepa/demo-image-folder`.
///
/// Only builds the path; nothing is created on disk.
pub fn demo_image_folder() -> PathBuf {
    let mut folder = workspace_root();
    folder.extend(["target", "example-data", "jepa", "demo-image-folder"]);
    folder
}

/// Returns `<workspace root>/target/example-data/jepa/checkpoints/<name>`.
///
/// Only builds the path; nothing is created on disk.
pub fn demo_checkpoint_dir(name: &str) -> PathBuf {
    let mut dir = workspace_root();
    dir.extend(["target", "example-data", "jepa", "checkpoints", name]);
    dir
}

/// Regenerates the demo image dataset from scratch and reports what was written.
///
/// Any existing directory at [`demo_image_folder`] is removed first, then every
/// entry of `DEMO_IMAGES` is rendered and saved beneath it.
///
/// # Errors
///
/// Returns an error, annotated with the offending path, if clearing or creating
/// a directory fails or if an image cannot be saved. Processing stops at the
/// first failure.
pub fn prepare_demo_image_folder() -> Result<PreparedDemoDataset> {
    let root = demo_image_folder();
    if root.exists() {
        std::fs::remove_dir_all(&root)
            .with_context(|| format!("failed to clear {}", root.display()))?;
    }
    std::fs::create_dir_all(&root)
        .with_context(|| format!("failed to create {}", root.display()))?;

    // Collecting into `Result` stops at the first entry that fails.
    let files = DEMO_IMAGES
        .iter()
        .map(|&(relative_path, pixel_fn)| -> Result<String> {
            let destination = root.join(relative_path);
            if let Some(class_dir) = destination.parent() {
                std::fs::create_dir_all(class_dir)
                    .with_context(|| format!("failed to create {}", class_dir.display()))?;
            }

            render_demo_image(pixel_fn)
                .save(&destination)
                .with_context(|| format!("failed to save {}", destination.display()))?;
            Ok(relative_path.to_string())
        })
        .collect::<Result<Vec<_>>>()?;

    Ok(PreparedDemoDataset { root, files })
}

pub(crate) fn demo_pattern_images() -> Vec<(String, RgbImage)> {
    DEMO_IMAGES
        .iter()
        .map(|(relative_path, pixel_fn)| (relative_path.to_string(), render_demo_image(*pixel_fn)))
        .collect()
}

/// Builds the training arguments used by the synthetic demo.
///
/// No dataset file or directory is set, and every preprocessing option is left
/// unset. Output goes to the `synthetic-demo` checkpoint directory.
pub fn synthetic_demo_args() -> TrainArgs {
    let output_dir = demo_checkpoint_dir("synthetic-demo");
    let dataset_key = String::from("images");

    TrainArgs {
        // Model and schedule.
        preset: ArchPreset::VitSmall16,
        steps: 2,
        warmup: 1,
        lr: 1e-3,
        batch_size: 2,

        // Data source and preprocessing: all unset for this demo.
        dataset: None,
        dataset_key,
        dataset_dir: None,
        resize: None,
        crop_size: None,
        mean: None,
        std: None,
        dataset_limit: None,
        shuffle: false,

        // Objective.
        masking: MaskingChoice::Block,
        energy: EnergyChoice::L2,
        regularizer: RegularizerChoice::Vicreg,
        reg_weight: 0.01,
        ema_momentum: 0.996,

        // Logging and checkpoints.
        log_interval: 1,
        checkpoint_interval: 10,
        output_dir,
    }
}

/// Builds the training arguments used by the image-folder demo.
///
/// `dataset_dir` becomes the dataset directory. Resize is 256, crop is 224,
/// shuffling is on, and the dataset is limited to `DEMO_IMAGE_COUNT` entries.
/// Output goes to the `image-folder-demo` checkpoint directory.
pub fn image_folder_demo_args(dataset_dir: PathBuf) -> TrainArgs {
    let output_dir = demo_checkpoint_dir("image-folder-demo");
    let dataset_key = String::from("images");

    TrainArgs {
        // Model and schedule.
        preset: ArchPreset::VitSmall16,
        steps: 2,
        warmup: 1,
        lr: 1e-3,
        batch_size: 2,

        // Data source and preprocessing.
        dataset: None,
        dataset_key,
        dataset_dir: Some(dataset_dir),
        resize: Some(256),
        crop_size: Some(224),
        mean: None,
        std: None,
        dataset_limit: Some(DEMO_IMAGE_COUNT),
        shuffle: true,

        // Objective.
        masking: MaskingChoice::Block,
        energy: EnergyChoice::L2,
        regularizer: RegularizerChoice::Vicreg,
        reg_weight: 0.01,
        ema_momentum: 0.996,

        // Logging and checkpoints.
        log_interval: 1,
        checkpoint_interval: 10,
        output_dir,
    }
}

/// Renders a `DEMO_IMAGE_SIZE` x `DEMO_IMAGE_SIZE` image by calling `pixel_fn`
/// once for every `(x, y)` coordinate.
fn render_demo_image(pixel_fn: PixelFn) -> RgbImage {
    RgbImage::from_fn(DEMO_IMAGE_SIZE, DEMO_IMAGE_SIZE, |x, y| {
        Rgb(pixel_fn(x, y))
    })
}

/// Pixel generator: red rises with `x`, green rises with `y`, and blue falls
/// as red rises.
fn gradient_horizontal(x: u32, y: u32) -> [u8; 3] {
    let last = DEMO_IMAGE_SIZE - 1;
    let r = scale_u32_to_u8(x, last);
    let g = scale_u32_to_u8(y, last);
    // `r / 2` is at most 127, so this subtraction cannot underflow.
    let b = u8::MAX - r / 2;
    [r, g, b]
}

/// Pixel generator: red rises with `y`, green falls as `x` rises, and blue
/// rises with `x + y`.
fn gradient_vertical(x: u32, y: u32) -> [u8; 3] {
    let last = DEMO_IMAGE_SIZE - 1;
    let r = scale_u32_to_u8(y, last);
    // The halved value is at most 127, so this subtraction cannot underflow.
    let g = u8::MAX - scale_u32_to_u8(x, last) / 2;
    let b = scale_u32_to_u8(x + y, last * 2);
    [r, g, b]
}

/// Pixel generator: two-colour checkerboard made of 12x12 pixel tiles.
///
/// The tile containing `(0, 0)` uses the first colour.
fn checkerboard(x: u32, y: u32) -> [u8; 3] {
    const TILE: u32 = 12;
    const FIRST: [u8; 3] = [240, 210, 40];
    const SECOND: [u8; 3] = [20, 40, 180];

    let (column, row) = (x / TILE, y / TILE);
    // The tile indices sum to an even number exactly when they share parity.
    if column % 2 == row % 2 {
        FIRST
    } else {
        SECOND
    }
}

/// Pixel generator: red rises with `x + y`, green falls as `x` rises, and blue
/// falls as `y` rises.
fn diagonal_mix(x: u32, y: u32) -> [u8; 3] {
    let last = DEMO_IMAGE_SIZE - 1;
    // Saturating so a coordinate past the last pixel maps to zero.
    let from_max_x = last.saturating_sub(x);
    let from_max_y = last.saturating_sub(y);
    [
        scale_u32_to_u8(x + y, last * 2),
        scale_u32_to_u8(from_max_x, last),
        scale_u32_to_u8(from_max_y, last),
    ]
}

/// Pixel generator: rings 6 pixels wide around the image centre, cycling
/// through three colours.
fn concentric_rings(x: u32, y: u32) -> [u8; 3] {
    const RING_WIDTH: u32 = 6;
    const PALETTE: [[u8; 3]; 3] = [[230, 70, 70], [40, 210, 120], [60, 90, 230]];

    let center = (DEMO_IMAGE_SIZE / 2) as i32;
    let (dx, dy) = (x as i32 - center, y as i32 - center);
    // Distance from the centre, truncated to whole pixels.
    let radius = ((dx * dx + dy * dy) as f32).sqrt() as u32;
    let ring = radius / RING_WIDTH;
    PALETTE[(ring % 3) as usize]
}

/// Pixel generator: splits the image into four equal quadrants, each with its
/// own flat colour.
fn quadrants(x: u32, y: u32) -> [u8; 3] {
    let half = DEMO_IMAGE_SIZE / 2;
    let low_x = x < half;
    let low_y = y < half;

    if low_x && low_y {
        [255, 110, 70]
    } else if low_y {
        [80, 220, 120]
    } else if low_x {
        [70, 140, 255]
    } else {
        [240, 230, 120]
    }
}

/// Linearly maps `value` from the range `0..=max` onto `0..=255`.
///
/// `value` is clamped to `max` first, so anything larger maps to 255. The
/// result is rounded down. Returns 0 when `max` is 0.
fn scale_u32_to_u8(value: u32, max: u32) -> u8 {
    if max == 0 {
        return 0;
    }
    // Widen before multiplying: `value * 255` overflows `u32` once `value`
    // exceeds `u32::MAX / 255`.
    let clamped = u64::from(value.min(max));
    let scaled = clamped * 255 / u64::from(max);
    // `clamped <= max` guarantees `scaled <= 255`, so the cast is lossless.
    scaled as u8
}