#[cfg(feature = "eval")]
use std::path::PathBuf;
use rag_rat_core::Config;
#[cfg(feature = "eval")]
use rag_rat_core::OutputFormat;
#[cfg(feature = "eval")]
use crate::cli::{BenchmarkEmbeddingArgs, EvalArgs};
use crate::cli::{ModelsArgs, ModelsCommand};
#[cfg(feature = "eval")]
use crate::commands::output_format;
use crate::open_index;
#[cfg(feature = "eval")]
use crate::render::print_eval_summary;
use crate::render::print_output;
#[cfg(feature = "eval")]
/// Runs the retrieval eval suite (or the replay-parent-state variant) and reports the result.
///
/// * With `--replay-parent-state`, delegates to `run_replay_parent_state`, prints its report
///   via `print_output`, and returns without pass/fail checking.
/// * Otherwise builds `EvalOptions`, where any path not supplied on the command line falls
///   back to `<root>/evals/<file>`; the SCIP oracle default is used only if it exists on disk.
///   It then runs the suite and prints either the full report (JSON output or
///   `--update-baseline`) or a human summary.
///
/// # Errors
/// Propagates errors from the eval runner and from printing. It also fails with an
/// `eval failed: ...` error carrying the stale-source violation count and the number of
/// failed queries whenever the report does not pass.
pub(crate) fn eval(config: &Config, args: &EvalArgs) -> anyhow::Result<()> {
    // Both replay modes share the same limits taken from the CLI.
    let replay_options = || rag_rat_core::eval::ReplayOptions {
        max_cases: args.replay_max_cases,
        max_files: args.replay_max_files,
    };

    if args.replay_parent_state {
        let replay_report =
            rag_rat_core::eval::run_replay_parent_state(config, &replay_options())?;
        return print_output(&replay_report);
    }

    // A CLI-supplied path wins; otherwise use the conventional file under `<root>/evals`.
    let explicit_or_default = |explicit: &Option<PathBuf>, file_name: &str| -> PathBuf {
        explicit.clone().unwrap_or_else(|| default_eval_path(config, file_name))
    };

    let options = rag_rat_core::eval::EvalOptions {
        queries_path: explicit_or_default(&args.queries, "queries.toml"),
        expected_path: explicit_or_default(&args.expected, "expected_hits.toml"),
        update_baseline: args.update_baseline,
        // The default SCIP oracle is optional: only pass it along when it is present.
        scip_path: args.scip.clone().or_else(|| {
            Some(default_eval_path(config, "oracle.scip")).filter(|candidate| candidate.exists())
        }),
        replay: args.replay.then(replay_options),
        rerank: args.rerank,
        search_limit: args.search_limit,
    };

    let report = rag_rat_core::eval::run(config, &options)?;
    let wants_full_report = output_format() == OutputFormat::Json || options.update_baseline;
    if wants_full_report {
        print_output(&report)?;
    } else {
        print_eval_summary(&report);
    }

    anyhow::ensure!(
        report.pass,
        "eval failed: stale_current_source_violations={}, failed_queries={}",
        report.metrics.stale_current_source_violations,
        report.results.iter().filter(|result| !result.passed).count()
    );
    Ok(())
}
#[cfg(feature = "eval")]
/// Returns the conventional location of an eval fixture: `<config.root>/evals/<file_name>`.
pub(crate) fn default_eval_path(config: &Config, file_name: &str) -> PathBuf {
    let mut path = config.root.join("evals");
    path.push(file_name);
    path
}
#[cfg(feature = "eval")]
/// Provisions a remote embedding box and measures throughput at several concurrency levels.
///
/// Candidate concurrency levels come from `--candidates` (each clamped with
/// `bounded_concurrency_value`, then sorted and de-duplicated). When none are given, they come
/// from `default_benchmark_candidates` for the configured cap. The box is provisioned for the
/// largest candidate, never below the configured cap.
///
/// The resulting JSON report (backend, model, cookbook, gpu, dim, budget, measured rows,
/// skipped candidates, and the best non-aborted row as `peak`) is written atomically to
/// `--output`, or printed to stdout when no output path is given.
///
/// # Errors
/// * `--budget-ms` below `min_benchmark_budget_ms` for the candidate count. This is checked
///   before anything is provisioned.
/// * Provisioning failure, or failure to measure the embedding dimension of an unregistered model.
/// * JSON serialization or output-file write failure.
pub(crate) fn benchmark_embedding(
    config: &Config,
    args: &BenchmarkEmbeddingArgs,
) -> anyhow::Result<()> {
    use rag_rat_core::config::RemoteEmbeddingConfig;
    use rag_rat_core::index::ai;

    let base = config.llm.embedding.remote.clone().unwrap_or_default();
    let cap = base.bounded_concurrency();
    let max_embedding_chars = config.llm.embedding.runtime.max_embedding_chars;

    // Explicit candidates are clamped to the allowed range, then sorted and de-duplicated.
    let candidates: Vec<u32> = if !args.candidates.is_empty() {
        let mut requested: Vec<u32> = args
            .candidates
            .iter()
            .map(|&level| RemoteEmbeddingConfig::bounded_concurrency_value(level))
            .collect();
        requested.sort_unstable();
        requested.dedup();
        requested
    } else {
        ai::default_benchmark_candidates(cap)
    };

    // Provision for the highest level we will probe, and never less than the configured cap.
    let highest_level = candidates.iter().copied().max().unwrap_or(cap);
    let provision_concurrency =
        RemoteEmbeddingConfig::bounded_concurrency_value(highest_level.max(cap));

    // The benchmark target: the requested model/backend on a fresh box. Connection details
    // are cleared so they come from provisioning rather than the base config.
    let remote = RemoteEmbeddingConfig {
        model: args.model.clone(),
        backend: args.backend,
        endpoint: None,
        cookbook: Some(args.cookbook.clone()),
        query_endpoint: None,
        auth_env: None,
        gpu: args.gpu.clone(),
        concurrency: provision_concurrency,
        ..base
    };

    // Reject an impossible budget up front, before any box is provisioned.
    let budget_ms = args.budget_ms.unwrap_or_else(ai::default_benchmark_budget_ms);
    let candidate_count = candidates.len();
    let min_budget = ai::min_benchmark_budget_ms(candidate_count);
    anyhow::ensure!(
        budget_ms >= min_budget,
        "--budget-ms {budget_ms} is too small to benchmark {} candidate(s): need at least \
         {min_budget} ms (~1s per candidate). Raise --budget-ms or pass fewer --candidates.",
        candidate_count,
    );

    // Unregistered models are provisioned with a placeholder spec; their dim is then measured.
    let spec = rag_rat_core::embedding_models::spec(&args.model);
    let provisioned =
        ai::provision_box_for_benchmark(&remote, spec_or_measure_placeholder(spec))?;
    let endpoint = &provisioned.endpoint;
    let auth_token = provisioned.auth_token.as_deref();

    let (selected_model_id, dim) = if let Some(known) = spec {
        (known.model_id.to_string(), known.dim)
    } else {
        let measured_dim = ai::measure_remote_dim(endpoint, auth_token, &remote)?;
        (args.model.clone(), measured_dim)
    };

    let measured = ai::benchmark_remote_concurrency(
        endpoint,
        auth_token,
        &remote,
        &selected_model_id,
        dim,
        max_embedding_chars,
        &candidates,
        budget_ms,
    );

    // Any requested level with no measured row is reported rather than silently dropped.
    let measured_levels: std::collections::BTreeSet<u32> =
        measured.iter().map(|row| row.concurrency).collect();
    let skipped: Vec<u32> = candidates
        .iter()
        .copied()
        .filter(|level| !measured_levels.contains(level))
        .collect();
    if !skipped.is_empty() {
        eprintln!(
            "benchmark-embedding: WARNING — {} requested candidate(s) not measured (probe window \
             / budget limit): {skipped:?}. Lower --candidates or [runtime] max_embedding_chars, \
             or raise --budget-ms.",
            skipped.len(),
        );
    }

    // Peak = the fastest row that actually issued requests and was not aborted.
    let peak = measured
        .iter()
        .filter(|row| row.requests > 0 && !row.aborted)
        .max_by(|a, b| a.texts_per_second.total_cmp(&b.texts_per_second))
        .map(|best| {
            serde_json::json!({
                "concurrency": best.concurrency,
                "texts_per_second": best.texts_per_second,
            })
        });

    let report = serde_json::json!({
        "backend": args.backend.as_db_str(),
        "model": args.model,
        "cookbook": args.cookbook,
        "gpu": args.gpu,
        "dim": dim,
        "budget_ms": budget_ms,
        "candidates": measured,
        "skipped_candidates": skipped,
        "peak": peak,
    });
    let json = serde_json::to_string_pretty(&report)?;

    if let Some(path) = &args.output {
        crate::write_atomic(path, json.as_bytes())?;
        eprintln!(
            "benchmark-embedding: wrote {} candidate rows to {}",
            measured.len(),
            path.display()
        );
    } else {
        println!("{json}");
    }
    Ok(())
}
#[cfg(feature = "eval")]
/// Returns `spec` when the model is registered. Otherwise it returns the built-in fastembed
/// (all-MiniLM) spec as a stand-in for provisioning; the real dimension is measured afterwards.
///
/// # Panics
/// Panics only if the fallback spec is missing from the registry, which is a build-time invariant.
fn spec_or_measure_placeholder(
    spec: Option<&'static rag_rat_core::embedding_models::EmbeddingModelSpec>,
) -> &'static rag_rat_core::embedding_models::EmbeddingModelSpec {
    if let Some(known) = spec {
        return known;
    }
    rag_rat_core::embedding_models::spec(rag_rat_core::embedding_models::FASTEMBED_MODEL_ID)
        .expect("the fallback all-MiniLM spec is always registered")
}
/// Decides whether `models install <model_id>` should go through the configured remote backend.
///
/// * No `[llm.embedding.remote]` block: returns `Ok(None)`, so the model is installed locally.
/// * A remote block is present and `model_id` resolves (via the model registry) to the same
///   canonical id as the configured backend: returns that remote config.
/// * Otherwise this is an error. The remote block only describes the configured model, so
///   applying it to a different (or unregistered) model would be wrong.
fn remote_for_install<'a>(
    config: &'a Config,
    model_id: &str,
) -> anyhow::Result<Option<&'a rag_rat_core::config::RemoteEmbeddingConfig>> {
    let remote = match config.llm.embedding.remote.as_ref() {
        Some(remote) => remote,
        None => return Ok(None),
    };
    let requested = rag_rat_core::embedding_models::spec(model_id).map(|s| s.model_id);
    let configured = config.llm.embedding.backend.model_id();
    match (requested, configured) {
        (Some(wanted), Some(active)) if wanted == active => Ok(Some(remote)),
        _ => anyhow::bail!(
            "remote embedding is configured for `{}`; install that model remotely, or remove the \
             [llm.embedding.remote] block to install `{model_id}` locally",
            configured.unwrap_or("none"),
        ),
    }
}
/// Implements `models` / `models list` (print the installed models) and
/// `models install <id>`.
///
/// Before installing, this prints a short-context note if applicable. It then resolves whether
/// the configured remote backend applies to the requested model.
pub(crate) fn models(config: &Config, args: &ModelsArgs) -> anyhow::Result<()> {
    let index = open_index(config)?;
    match args.command.as_ref() {
        Some(ModelsCommand::Install { model_id }) => {
            warn_if_short_context(model_id);
            let remote = remote_for_install(config, model_id)?;
            let installed = index.install_model(model_id, remote)?;
            print_output(&installed)
        },
        Some(ModelsCommand::List) | None => {
            let listed = index.list_models()?;
            print_output(&listed)
        },
    }
}
/// Prints a note to stderr when a registered model's input limit (`max_input_chars`) is below
/// `DEFAULT_MAX_EMBEDDING_CHARS`, since longer chunks get truncated before embedding.
///
/// Unregistered models, and models with no declared token or character limit, are skipped silently.
fn warn_if_short_context(model_id: &str) {
    let Some(spec) = rag_rat_core::embedding_models::spec(model_id) else { return };
    match (spec.max_tokens, spec.max_input_chars()) {
        (Some(max_tokens), Some(model_chars))
            if model_chars < rag_rat_core::index::ai::DEFAULT_MAX_EMBEDDING_CHARS =>
        {
            eprintln!(
                "note: {model_id} has a {max_tokens}-token context, so code chunks longer than that \
                 are truncated — their tail is not embedded, costing precision/recall on large \
                 functions. For code, a long-context model like jinaai/jina-embeddings-v2-base-code \
                 (8192 tokens) embeds whole chunks."
            );
        },
        _ => {},
    }
}
#[cfg(test)]
mod tests {
    use std::path::{Path, PathBuf};
    use std::sync::atomic::{AtomicU64, Ordering};

    use rag_rat_core::Config;

    /// Per-process counter so tests running in parallel never share a scratch directory.
    static N: AtomicU64 = AtomicU64::new(0);

    /// Scratch project root that is removed when dropped.
    ///
    /// Cleanup lives in `Drop` so it also runs while unwinding from a failed assertion or a
    /// failed setup `unwrap()`. Before, it ran only at the end of a passing test, so every
    /// failure leaked a directory into the system temp dir.
    struct TempRoot(PathBuf);

    impl TempRoot {
        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempRoot {
        fn drop(&mut self) {
            // Best-effort: a cleanup failure must not mask the test's own result.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Writes a minimal one-file Rust project whose `rag-rat.toml` selects `model`, optionally
    /// with an `[llm.embedding.remote]` block, and loads it as a `Config`.
    ///
    /// Keep the returned `TempRoot` bound (e.g. `_root`, not `_`) for the whole test; dropping
    /// it deletes the project.
    fn config_with_remote(model: &str, with_remote: bool) -> (TempRoot, Config) {
        let root = TempRoot(std::env::temp_dir().join(format!(
            "rag-rat-cli-remote-{}-{}",
            std::process::id(),
            N.fetch_add(1, Ordering::Relaxed)
        )));
        // Clear leftovers from an earlier run that reused this pid/counter pair.
        let _ = std::fs::remove_dir_all(root.path());
        std::fs::create_dir_all(root.path().join("src")).unwrap();
        std::fs::write(root.path().join("src/a.rs"), "pub fn a() {}\n").unwrap();
        let remote = if with_remote {
            "\n[llm.embedding.remote]\nendpoint = \"http://127.0.0.1:1\"\nmodel = \"all-minilm\"\n"
        } else {
            ""
        };
        std::fs::write(
            root.path().join("rag-rat.toml"),
            format!(
                "[index]\nroot = \".\"\n\n[target_bindings]\nrust = \
                 [\"src\"]\n\n[llm.embedding]\nmodel = \"{model}\"\n{remote}"
            ),
        )
        .unwrap();
        let config = Config::load(root.path().join("rag-rat.toml")).unwrap();
        (root, config)
    }

    #[test]
    fn remote_for_install_only_applies_the_remote_block_to_the_configured_model() {
        let (_root, config) = config_with_remote("sentence-transformers/all-MiniLM-L6-v2", true);
        assert!(
            super::remote_for_install(&config, "sentence-transformers/all-MiniLM-L6-v2")
                .unwrap()
                .is_some(),
            "the configured model installs over the remote",
        );
        let err = super::remote_for_install(&config, "BAAI/bge-small-en-v1.5")
            .expect_err("a different model than the configured one must be rejected");
        let msg = err.to_string();
        assert!(msg.contains("remote embedding is configured for"), "{msg}");
        assert!(msg.contains("sentence-transformers/all-MiniLM-L6-v2"), "names configured: {msg}");
        assert!(msg.contains("BAAI/bge-small-en-v1.5"), "names requested: {msg}");
    }

    #[test]
    fn remote_for_install_returns_none_without_a_remote_block() {
        let (_root, config) = config_with_remote("sentence-transformers/all-MiniLM-L6-v2", false);
        assert!(super::remote_for_install(&config, "BAAI/bge-small-en-v1.5").unwrap().is_none());
        assert!(super::remote_for_install(&config, "embedding-hash").unwrap().is_none());
    }
}