use std::path::PathBuf;
use std::time::Instant;
use anyhow::Result;
use clap::Args;
use tldr_core::semantic::{
chunk_code, CacheConfig, ChunkGranularity, ChunkOptions, EmbedReport, EmbeddedChunk, Embedder,
EmbeddingCache, EmbeddingModel,
};
use crate::output::{OutputFormat, OutputWriter};
/// Arguments for the `embed` subcommand: chunk the code under a path and
/// generate vector embeddings for each chunk.
///
/// The `///` doc comments below double as `--help` text via clap's derive.
#[derive(Debug, Args)]
pub struct EmbedArgs {
    /// File or directory whose code should be chunked and embedded
    pub path: PathBuf,

    /// Write the JSON report to this file instead of the standard output
    #[arg(short, long)]
    pub output: Option<PathBuf>,

    /// Chunking granularity: 'file' or 'function'
    #[arg(short, long, default_value = "function")]
    pub granularity: String,

    /// Embedding model: arctic-xs, arctic-s, arctic-m, arctic-m-long, arctic-l
    #[arg(short, long, default_value = "arctic-m")]
    pub model: String,

    /// Only embed these languages (identified by file extension)
    #[arg(long)]
    pub lang: Option<Vec<String>>,

    /// Include the raw embedding vectors in the report output
    #[arg(long)]
    pub include_vectors: bool,

    /// Bypass the on-disk embedding cache (neither read nor write it)
    #[arg(long)]
    pub no_cache: bool,
}
impl EmbedArgs {
    /// Run the embed command end-to-end: parse the CLI strings into typed
    /// options, chunk the target path, resolve an embedding for every chunk
    /// (cache lookups first, then one batch embed for the misses), and emit
    /// an `EmbedReport` to `--output` or to the format-aware writer.
    ///
    /// # Errors
    /// Fails on an invalid `--model` or `--granularity` value, and propagates
    /// errors from chunking, cache open/flush, the embedder, file creation,
    /// and report serialization.
    pub fn run(&self, format: OutputFormat, quiet: bool) -> Result<()> {
        let writer = OutputWriter::new(format, quiet);
        // Wall-clock timer; feeds `latency_ms` in the final report.
        let start = Instant::now();
        let model = parse_model(&self.model)?;
        // Map the CLI string to a typed granularity; reject anything else up front.
        let granularity = match self.granularity.as_str() {
            "file" => ChunkGranularity::File,
            "function" => ChunkGranularity::Function,
            _ => {
                return Err(anyhow::anyhow!(
                    "Invalid granularity '{}'. Use 'file' or 'function'.",
                    self.granularity
                ))
            }
        };
        writer.progress(&format!(
            "Embedding code in {} ({:?} granularity, {} model)...",
            self.path.display(),
            granularity,
            self.model
        ));
        // Optional language filter for chunking.
        // NOTE(review): strings that `Language::from_extension` does not
        // recognize are silently dropped by `filter_map` — confirm that is
        // intended rather than a user-facing error.
        let languages = self.lang.as_ref().map(|langs| {
            langs
                .iter()
                .filter_map(|s| tldr_core::Language::from_extension(s))
                .collect()
        });
        let chunk_opts = ChunkOptions {
            granularity,
            languages,
            ..Default::default()
        };
        let chunk_result = chunk_code(&self.path, &chunk_opts)?;
        writer.progress(&format!(
            "Found {} chunks, generating embeddings...",
            chunk_result.chunks.len()
        ));
        // The cache is opt-out: `--no-cache` skips both lookups and writes.
        let mut cache = if self.no_cache {
            None
        } else {
            Some(EmbeddingCache::open(CacheConfig::default())?)
        };
        let mut cache_hits = 0usize;
        let mut cache_misses = 0usize;
        // First pass: build `embedded_chunks` positionally (same order as
        // `chunk_result.chunks`). Cache hits get their stored vector; misses
        // get an empty placeholder vector and are remembered by index so the
        // second pass can patch them in place.
        let mut embedded_chunks: Vec<EmbeddedChunk> = Vec::with_capacity(chunk_result.chunks.len());
        let mut uncached_indices: Vec<usize> = Vec::new();
        for (i, chunk) in chunk_result.chunks.iter().enumerate() {
            if let Some(ref mut c) = cache {
                if let Some(e) = c.get(chunk, model) {
                    cache_hits += 1;
                    embedded_chunks.push(EmbeddedChunk {
                        chunk: chunk.clone(),
                        embedding: e,
                    });
                    continue;
                }
            }
            cache_misses += 1;
            embedded_chunks.push(EmbeddedChunk {
                chunk: chunk.clone(),
                // Placeholder; replaced by the batch-embed pass below.
                embedding: Vec::new(),
            });
            uncached_indices.push(i);
        }
        // Second pass: batch-embed only the cache misses. The embedder is
        // constructed lazily here so a fully cached run never loads the model.
        if !uncached_indices.is_empty() {
            let mut embedder = Embedder::new(model)?;
            let texts: Vec<&str> = uncached_indices
                .iter()
                .map(|&i| chunk_result.chunks[i].content.as_str())
                .collect();
            // NOTE(review): the meaning of the `true` flag is not visible
            // here — confirm against `Embedder::embed_batch`'s signature.
            let embeddings = embedder.embed_batch(texts, true)?;
            // Patch each freshly computed vector into its positional slot,
            // writing through to the cache as we go.
            // NOTE(review): if `embed_batch` ever returned fewer vectors than
            // inputs, the `zip` would silently leave tail chunks with their
            // empty placeholder — confirm the API guarantees 1:1 results.
            for (idx, embedding) in uncached_indices.iter().zip(embeddings) {
                if let Some(ref mut c) = cache {
                    c.put(&chunk_result.chunks[*idx], embedding.clone(), model);
                }
                embedded_chunks[*idx].embedding = embedding;
            }
        }
        // Persist newly cached entries to disk.
        if let Some(ref mut c) = cache {
            c.flush()?;
        }
        let latency_ms = start.elapsed().as_millis() as u64;
        let report = EmbedReport {
            path: self.path.clone(),
            model,
            granularity,
            chunks_embedded: cache_misses,
            chunks_cached: cache_hits,
            // Vectors can be large; include them only when explicitly asked.
            chunks: if self.include_vectors {
                Some(embedded_chunks)
            } else {
                None
            },
            latency_ms,
        };
        writer.progress(&format!(
            "Embedded {} chunks ({} cached, {} new) in {}ms",
            cache_hits + cache_misses,
            cache_hits,
            cache_misses,
            latency_ms
        ));
        // `--output` writes pretty-printed JSON to a file; otherwise defer to
        // the format-aware writer (which honors `format` and `quiet`).
        if let Some(ref output_path) = self.output {
            let file = std::fs::File::create(output_path)?;
            serde_json::to_writer_pretty(file, &report)?;
            writer.progress(&format!("Output written to {}", output_path.display()));
        } else {
            writer.write(&report)?;
        }
        Ok(())
    }
}
/// Resolve a CLI model name (full name or short alias) to an
/// [`EmbeddingModel`] variant.
///
/// # Errors
/// Returns an error listing the valid full model names when `model_str`
/// matches none of the accepted spellings.
fn parse_model(model_str: &str) -> Result<EmbeddingModel> {
    let model = match model_str {
        "arctic-xs" | "xs" => EmbeddingModel::ArcticXS,
        "arctic-s" | "s" => EmbeddingModel::ArcticS,
        "arctic-m" | "m" => EmbeddingModel::ArcticM,
        "arctic-m-long" | "m-long" => EmbeddingModel::ArcticMLong,
        "arctic-l" | "l" => EmbeddingModel::ArcticL,
        unknown => {
            return Err(anyhow::anyhow!(
                "Invalid model '{}'. Options: arctic-xs, arctic-s, arctic-m, arctic-m-long, arctic-l",
                unknown
            ))
        }
    };
    Ok(model)
}