prx 0.5.9

Praxis — agent-native Unix tools. Single binary replacing grep, cat, find, sed, diff for AI coding agents.
use assert_cmd::Command;
use serde::Deserialize;

#[derive(Deserialize)]
struct NdcgDataset {
    queries: Vec<NdcgQuery>,
    #[serde(default)]
    root: Option<String>,
}

#[derive(Deserialize)]
struct NdcgQuery {
    query: String,
    category: String,
    relevant: Vec<RelevantFile>,
}

#[derive(Deserialize)]
struct RelevantFile {
    file: String,
    relevance: u32,
}

#[derive(Deserialize)]
struct SearchEnvelope {
    data: SearchData,
}

#[derive(Deserialize)]
struct SearchData {
    matches: Vec<SearchMatch>,
}

#[derive(Deserialize)]
struct SearchMatch {
    file: String,
}

fn dcg(relevances: &[f64], k: usize) -> f64 {
    relevances
        .iter()
        .take(k)
        .enumerate()
        .map(|(i, &rel)| rel / (i as f64 + 2.0).log2())
        .sum()
}

fn ndcg_at_k(relevances: &[f64], k: usize) -> f64 {
    let actual_dcg = dcg(relevances, k);
    if actual_dcg == 0.0 {
        return 0.0;
    }

    let mut ideal = relevances.to_vec();
    ideal.sort_by(|a, b| b.partial_cmp(a).unwrap());
    let ideal_dcg = dcg(&ideal, k);

    if ideal_dcg == 0.0 {
        0.0
    } else {
        actual_dcg / ideal_dcg
    }
}

fn ag() -> Command {
    let mut cmd = Command::cargo_bin("prx").unwrap();
    cmd.env("PRX_STATS_FILE", "/dev/null");
    cmd.env("PRX_ERRORS_FILE", "/dev/null");
    cmd
}

fn run_ndcg(dataset_file: &str, label: &str) {
    let dataset_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("benchmarks")
        .join(dataset_file);
    let dataset_str = std::fs::read_to_string(&dataset_path)
        .unwrap_or_else(|e| panic!("failed to read {}: {e}", dataset_path.display()));
    let dataset: NdcgDataset = serde_json::from_str(&dataset_str).unwrap();

    let search_root = match &dataset.root {
        Some(root) => root.clone(),
        None => {
            let src_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("src");
            src_dir.to_string_lossy().to_string()
        }
    };
    let src_path = search_root;

    let mut all_ndcg5 = Vec::new();
    let mut all_ndcg10 = Vec::new();
    let mut by_category: std::collections::HashMap<String, Vec<f64>> =
        std::collections::HashMap::new();

    for q in &dataset.queries {
        let output = ag()
            .args(["search", &q.query, &src_path, "--top-k", "10"])
            .output()
            .unwrap();

        if !output.status.success() {
            eprintln!("SKIP query {:?}: command failed", q.query);
            continue;
        }

        let envelope: SearchEnvelope = match serde_json::from_slice(&output.stdout) {
            Ok(e) => e,
            Err(e) => {
                eprintln!("SKIP query {:?}: parse error: {e}", q.query);
                continue;
            }
        };

        let relevances: Vec<f64> = envelope
            .data
            .matches
            .iter()
            .map(|m| {
                q.relevant
                    .iter()
                    .find(|r| {
                        m.file == r.file || m.file.ends_with(&r.file) || r.file.ends_with(&m.file)
                    })
                    .map(|r| r.relevance as f64)
                    .unwrap_or(0.0)
            })
            .collect();

        let n5 = ndcg_at_k(&relevances, 5);
        let n10 = ndcg_at_k(&relevances, 10);
        all_ndcg5.push(n5);
        all_ndcg10.push(n10);
        by_category.entry(q.category.clone()).or_default().push(n10);
    }

    let mean_ndcg5 = if all_ndcg5.is_empty() {
        0.0
    } else {
        all_ndcg5.iter().sum::<f64>() / all_ndcg5.len() as f64
    };
    let mean_ndcg10 = if all_ndcg10.is_empty() {
        0.0
    } else {
        all_ndcg10.iter().sum::<f64>() / all_ndcg10.len() as f64
    };

    eprintln!("\n=== NDCG Results ({label}) ===");
    eprintln!("Queries evaluated: {}", all_ndcg10.len());
    eprintln!("Mean NDCG@5:  {mean_ndcg5:.4}");
    eprintln!("Mean NDCG@10: {mean_ndcg10:.4}");

    for (cat, scores) in &by_category {
        let mean = scores.iter().sum::<f64>() / scores.len() as f64;
        eprintln!("  {cat}: NDCG@10 = {mean:.4} (n={})", scores.len());
    }

    let results = serde_json::json!({
        "label": label,
        "mean_ndcg5": mean_ndcg5,
        "mean_ndcg10": mean_ndcg10,
        "queries_evaluated": all_ndcg10.len(),
        "by_category": by_category.iter().map(|(k, v)| {
            (k.clone(), serde_json::json!({
                "mean_ndcg10": v.iter().sum::<f64>() / v.len() as f64,
                "count": v.len(),
            }))
        }).collect::<serde_json::Map<String, serde_json::Value>>(),
    });

    let results_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("benchmarks");
    let results_file = format!(
        "ndcg_results_{}.json",
        label.replace(' ', "_").to_lowercase()
    );
    let results_path = results_dir.join(&results_file);
    let _ = std::fs::write(
        &results_path,
        serde_json::to_string_pretty(&results).unwrap(),
    );
    eprintln!("Results written to {}", results_path.display());

    assert!(
        !all_ndcg10.is_empty(),
        "no queries were evaluated successfully"
    );
}

#[test]
fn ndcg_at_10_measurement() {
    run_ndcg("ndcg_dataset.json", "prx");
}

#[test]
fn ndcg_at_10_fiddler() {
    let dataset_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("benchmarks")
        .join("ndcg_dataset_fiddler.json");
    if !dataset_path.exists() {
        eprintln!("SKIP: fiddler dataset not found");
        return;
    }
    let root_path = std::path::Path::new("/Users/jeryn/workspace/projects/fiddler");
    if !root_path.exists() {
        eprintln!("SKIP: fiddler repo not found at {}", root_path.display());
        return;
    }
    run_ndcg("ndcg_dataset_fiddler.json", "fiddler");
}