use std::process::ExitCode;
use std::str::FromStr;
use std::time::Instant;
use lattice_embed::{EmbeddingModel, EmbeddingService, NativeEmbeddingService};
fn usage(msg: &str) -> ExitCode {
eprintln!("ERROR: {msg}\n");
eprintln!("{USAGE}");
ExitCode::FAILURE
}
const USAGE: &str = "\
usage: embed [--model <NAME>] --text <TEXT> [--text <TEXT> ...] [--json]
Generate embeddings for one or more text strings.
options:
--model <NAME> Embedding model to use. Default: bge-small-en-v1.5
Accepted: bge-small-en-v1.5, bge-base-en-v1.5, bge-large-en-v1.5,
multilingual-e5-small, multilingual-e5-base, all-minilm-l6-v2,
paraphrase-multilingual-minilm-l12-v2
Also accepts HuggingFace IDs like BAAI/bge-small-en-v1.5.
--text <TEXT> Text to embed. Repeat for multiple texts.
--json Emit a structured @@lattice {\"ev\":\"embed_done\",...} line to stdout.
--download-only Ensure the model is downloaded and loadable, then exit (no --text needed).
Emits @@lattice {\"ev\":\"download_done\",\"ok\":bool} with --json.
-h, --help Print this help and exit.
";
#[tokio::main]
async fn main() -> ExitCode {
let args: Vec<String> = std::env::args().collect();
let mut model_name: Option<String> = None;
let mut texts: Vec<String> = Vec::new();
let mut emit_json: bool = false;
let mut download_only: bool = false;
let mut i = 1;
while i < args.len() {
match args[i].as_str() {
"--model" => {
i += 1;
let Some(v) = args.get(i) else {
return usage("--model requires an argument");
};
model_name = Some(v.clone());
}
"--text" => {
i += 1;
let Some(v) = args.get(i) else {
return usage("--text requires an argument");
};
texts.push(v.clone());
}
"--json" => {
emit_json = true;
}
"--download-only" => {
download_only = true;
}
"--help" | "-h" => {
eprintln!("{USAGE}");
return ExitCode::SUCCESS;
}
other => return usage(&format!("unknown argument: {other}")),
}
i += 1;
}
if !download_only && texts.is_empty() {
return usage("at least one --text argument is required");
}
let model = match model_name {
Some(ref name) => match EmbeddingModel::from_str(name) {
Ok(m) => m,
Err(_) => {
return usage(&format!(
"--model '{name}' is not a recognised embedding model"
));
}
},
None => EmbeddingModel::default(),
};
eprintln!("Model: {model}");
eprintln!("Dimensions: {}", model.dimensions());
eprintln!("Texts: {}", texts.len());
eprintln!();
eprintln!("Generating embeddings (model loads on first call — may download ~130 MB)...");
let service = NativeEmbeddingService::with_model(model);
if download_only {
match service.ensure_loaded().await {
Ok(()) => {
eprintln!("Model {model} is downloaded and ready.");
if emit_json {
let obj = serde_json::json!({
"ev": "download_done",
"model": model.to_string(),
"ok": true,
});
println!("@@lattice {obj}");
}
return ExitCode::SUCCESS;
}
Err(err) => {
eprintln!("ERROR: model download/load failed: {err}");
if emit_json {
let obj = serde_json::json!({
"ev": "download_done",
"model": model.to_string(),
"ok": false,
"error": err.to_string(),
});
println!("@@lattice {obj}");
}
return ExitCode::FAILURE;
}
}
}
let t0 = Instant::now();
let embeddings = match service.embed(&texts, model).await {
Ok(e) => e,
Err(err) => {
eprintln!("ERROR: embedding failed: {err}");
return ExitCode::FAILURE;
}
};
let elapsed_ms = t0.elapsed().as_millis();
if embeddings.is_empty() {
eprintln!("ERROR: service returned zero embeddings");
return ExitCode::FAILURE;
}
let dims = embeddings[0].len();
let count = embeddings.len();
let mut cosine: Vec<Vec<f32>> = Vec::with_capacity(count);
for i in 0..count {
let mut row = Vec::with_capacity(count);
for j in 0..count {
let sim = lattice_embed::utils::cosine_similarity(&embeddings[i], &embeddings[j]);
row.push(sim);
}
cosine.push(row);
}
let preview_len = dims.min(8);
let preview: Vec<Vec<f32>> = embeddings
.iter()
.map(|e| e[..preview_len].to_vec())
.collect();
eprintln!("=== Embedding Results ===");
eprintln!("Dims: {dims}");
eprintln!("Count: {count}");
eprintln!("Elapsed: {elapsed_ms}ms");
eprintln!();
eprintln!("Pairwise cosine similarity:");
for (i, row) in cosine.iter().enumerate() {
let vals: Vec<String> = row.iter().map(|v| format!("{v:.4}")).collect();
eprintln!(" [{i}] {}", vals.join(" "));
}
if emit_json {
let obj = serde_json::json!({
"ev": "embed_done",
"model": model.to_string(),
"dims": dims,
"count": count,
"cosine": cosine,
"preview": preview,
"ms": elapsed_ms,
});
println!("@@lattice {obj}");
}
ExitCode::SUCCESS
}