use crate::config::CliConfig;
use candle_core::Device as CandleDevice;
use clap::Args;
use colored::Colorize;
use ferrum_models::source::{ModelFormat, ResolvedModelSource};
use ferrum_models::HfDownloader;
use ferrum_models::{BertModelExecutor, ConfigManager};
use ferrum_types::Result;
use std::io::{self, BufRead};
use std::path::{Path, PathBuf};

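/// Arguments for the `embed` subcommand: embed text (BERT-style encoders)
/// or text and images (CLIP) and print the resulting vectors.
///
/// Illustrative invocation (binary and subcommand names assumed here):
///
/// ```text
/// ferrum embed sentence-transformers/all-MiniLM-L6-v2 --text "hello world" --format csv
/// ```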
#[derive(Args, Debug)]
pub struct EmbedCommand {
    /// Model ID (e.g. a Hugging Face repo) or local path.
    #[arg(required = true)]
    pub model: String,

    /// Text to embed; if omitted, texts are read from stdin.
    #[arg(short, long)]
    pub text: Option<String>,

    /// Image to embed (CLIP models only).
    #[arg(short, long)]
    pub image: Option<String>,

    /// Output format: json, csv, or raw.
    #[arg(short, long, default_value = "json")]
    pub format: String,

    /// L2-normalize embeddings (pass `--normalize false` to disable).
    #[arg(short, long, default_value_t = true, action = clap::ArgAction::Set)]
    pub normalize: bool,
}

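/// Runs the `embed` command: resolve the model (local cache first, download
/// otherwise), load it on the CPU backend, embed every input, and print the
/// vectors in the requested output format.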
pub async fn execute(cmd: EmbedCommand, config: CliConfig) -> Result<()> {
    eprintln!("{}", format!("Loading {}...", cmd.model).dimmed());

    let model_id = cmd.model.clone();
    let cache_dir = get_hf_cache_dir(&config);

    let source = match find_cached_model(&cache_dir, &model_id) {
        Some(source) => source,
        None => {
            eprintln!(
                "{} Model '{}' not found locally, downloading...",
                "📥".cyan(),
                model_id
            );

            let token = std::env::var("HF_TOKEN")
                .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
                .ok();

            let downloader = HfDownloader::new(cache_dir, token)?;
            let snapshot_path = downloader.download(&model_id, None).await?;

            let format = detect_format(&snapshot_path);
            if format == ModelFormat::Unknown {
                return Err(ferrum_types::FerrumError::model(
                    "Downloaded model has unknown format",
                ));
            }

            ResolvedModelSource {
                original: model_id.clone(),
                local_path: snapshot_path,
                format,
                from_cache: false,
            }
        }
    };

    let model_path = source.local_path.to_string_lossy().to_string();
    eprintln!("{}", "Using CPU backend".dimmed());

    let mut config_manager = ConfigManager::new();
    let model_def = config_manager.load_from_path(&source.local_path).await?;

    let device = CandleDevice::Cpu;
    let is_clip = model_def.architecture == ferrum_models::Architecture::Clip;

    let mut all_embeddings: Vec<(String, Vec<f32>)> = Vec::new();

    if is_clip {
        let executor = ferrum_models::ClipModelExecutor::from_path(
            &model_path,
            device.clone(),
            candle_core::DType::F32,
        )?;
        eprintln!("{}", "CLIP model loaded.".green());

        if let Some(ref image_path) = cmd.image {
            let embedding_tensor = executor.embed_image_path(image_path)?;
            let embedding = tensor_to_vec(&embedding_tensor, cmd.normalize)?;
            all_embeddings.push((format!("[image] {image_path}"), embedding));
        }

        let texts = collect_texts(&cmd)?;
        if !texts.is_empty() {
            let tokenizer = load_tokenizer(&source.local_path)?;

            for text in &texts {
                let encoding = tokenizer
                    .encode(text.as_str(), true)
                    .map_err(|e| ferrum_types::FerrumError::model(format!("Tokenize: {e}")))?;
                let embedding_tensor = executor.embed_text(encoding.get_ids())?;
                let embedding = tensor_to_vec(&embedding_tensor, cmd.normalize)?;
                all_embeddings.push((text.clone(), embedding));
            }
        }

        if all_embeddings.is_empty() {
            eprintln!("{}", "No input provided. Use --text or --image.".yellow());
            return Ok(());
        }
    } else {
        let executor = BertModelExecutor::from_path(&model_path, &model_def, device).await?;
        eprintln!("{}", "BERT model loaded.".green());

        let tokenizer = load_tokenizer(&source.local_path)?;

        let texts = collect_texts(&cmd)?;
        if texts.is_empty() {
            eprintln!("{}", "No text provided.".yellow());
            return Ok(());
        }

        for text in &texts {
            let encoding = tokenizer
                .encode(text.as_str(), true)
                .map_err(|e| ferrum_types::FerrumError::model(format!("Tokenize: {e}")))?;
            let embedding_tensor = executor.get_embeddings(encoding.get_ids())?;
            let embedding = tensor_to_vec(&embedding_tensor, cmd.normalize)?;
            all_embeddings.push((text.clone(), embedding));
        }
    }

    match cmd.format.as_str() {
        "json" => {
            let output: Vec<serde_json::Value> = all_embeddings
                .iter()
                .map(|(text, emb)| {
                    serde_json::json!({
                        "text": text,
                        "embedding": emb,
                        "dimensions": emb.len()
                    })
                })
                .collect();
            println!("{}", serde_json::to_string_pretty(&output).unwrap());
        }
        "csv" => {
            if let Some((_, first_emb)) = all_embeddings.first() {
                let header: Vec<String> =
                    (0..first_emb.len()).map(|i| format!("dim_{i}")).collect();
                println!("text,{}", header.join(","));
            }
            for (text, emb) in &all_embeddings {
                let emb_str: Vec<String> = emb.iter().map(|v| format!("{v:.6}")).collect();
                // RFC 4180: quotes inside a quoted CSV field are escaped by doubling them.
                println!("\"{}\",{}", text.replace('"', "\"\""), emb_str.join(","));
            }
        }
171 "raw" => {
172 for (text, emb) in &all_embeddings {
173 eprintln!("{}: {} dimensions", text.dimmed(), emb.len());
174 let preview: Vec<String> =
175 emb.iter().take(5).map(|v| format!("{:.4}", v)).collect();
176 println!("[{}, ...]", preview.join(", "));
177 }
178 }
179 _ => {
180 eprintln!("{}", format!("Unknown format: {}", cmd.format).red());
181 }
182 }
183
184 Ok(())
185}
186
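/// Loads a tokenizer from a model directory, preferring `tokenizer.json`.
///
/// For older BERT checkpoints that ship only `vocab.txt`, an equivalent
/// WordPiece tokenizer is rebuilt by hand: BERT-style pre-tokenization plus
/// a `[CLS] $A [SEP]` post-processing template. The 101/102 fallbacks are
/// the conventional BERT vocabulary IDs for `[CLS]` and `[SEP]`.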
pub fn load_tokenizer(model_dir: &Path) -> Result<tokenizers::Tokenizer> {
    let tokenizer_json = model_dir.join("tokenizer.json");
    if tokenizer_json.exists() {
        return tokenizers::Tokenizer::from_file(&tokenizer_json)
            .map_err(|e| ferrum_types::FerrumError::model(format!("Load tokenizer.json: {e}")));
    }

    let vocab_txt = model_dir.join("vocab.txt");
    if vocab_txt.exists() {
        use tokenizers::models::wordpiece::WordPiece;
        use tokenizers::pre_tokenizers::sequence::Sequence;
        use tokenizers::pre_tokenizers::unicode_scripts::UnicodeScripts;
        use tokenizers::processors::template::TemplateProcessing;
        use tokenizers::Model;

        let wp = WordPiece::from_file(vocab_txt.to_str().unwrap())
            .unk_token("[UNK]".to_string())
            .build()
            .map_err(|e| ferrum_types::FerrumError::model(format!("Load vocab.txt: {e}")))?;

        let vocab = wp.get_vocab();
        let cls_id = vocab.get("[CLS]").copied().unwrap_or(101);
        let sep_id = vocab.get("[SEP]").copied().unwrap_or(102);

        let mut tokenizer = tokenizers::Tokenizer::new(wp);
        tokenizer.with_pre_tokenizer(Some(Sequence::new(vec![
            tokenizers::pre_tokenizers::PreTokenizerWrapper::UnicodeScripts(UnicodeScripts),
            tokenizers::pre_tokenizers::PreTokenizerWrapper::BertPreTokenizer(
                tokenizers::pre_tokenizers::bert::BertPreTokenizer,
            ),
        ])));
        let template = TemplateProcessing::builder()
            .try_single("[CLS] $A [SEP]")
            .unwrap()
            .special_tokens(vec![("[CLS]", cls_id), ("[SEP]", sep_id)])
            .build()
            .map_err(|e| ferrum_types::FerrumError::model(format!("Template: {e}")))?;
        tokenizer.with_post_processor(Some(template));
        return Ok(tokenizer);
    }

    Err(ferrum_types::FerrumError::model(
        "No tokenizer.json or vocab.txt found in model directory",
    ))
}

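/// Collects the texts to embed: the `--text` value when given, nothing for
/// an image-only CLIP run, otherwise one text per line read from stdin.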
fn collect_texts(cmd: &EmbedCommand) -> Result<Vec<String>> {
    if let Some(ref text) = cmd.text {
        Ok(vec![text.clone()])
    } else if cmd.image.is_some() {
        Ok(vec![])
    } else {
        eprintln!(
            "{}",
            "Reading text from stdin (one per line, Ctrl+D to finish):".dimmed()
        );
        let stdin = io::stdin();
        Ok(stdin.lock().lines().map_while(Result::ok).collect())
    }
}

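/// Flattens an embedding tensor into a `Vec<f32>`, optionally rescaling it
/// to unit L2 norm (`v / ||v||`); zero vectors are left as-is to avoid a
/// division by zero.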
fn tensor_to_vec(tensor: &candle_core::Tensor, normalize: bool) -> Result<Vec<f32>> {
    let mut embedding = tensor
        .flatten_all()
        .map_err(|e| ferrum_types::FerrumError::model(format!("Flatten: {e}")))?
        .to_vec1::<f32>()
        .map_err(|e| ferrum_types::FerrumError::model(format!("to_vec1: {e}")))?;

    if normalize {
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for v in &mut embedding {
                *v /= norm;
            }
        }
    }
    Ok(embedding)
}

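/// Searches for a previously downloaded copy of the model.
///
/// First checks the Hugging Face hub cache layout
/// (`hub/models--{org}--{name}/snapshots/<revision>/`), then a plain
/// `<cache_dir>/<model_id>` directory; the first candidate containing a
/// `config.json` and a recognizable weight format wins.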
fn find_cached_model(cache_dir: &Path, model_id: &str) -> Option<ResolvedModelSource> {
    let hub_dir = cache_dir.join("hub");
    let model_dir_name = format!("models--{}", model_id.replace('/', "--"));
    let model_dir = hub_dir.join(&model_dir_name);

    if model_dir.exists() {
        let snapshots_dir = model_dir.join("snapshots");
        if snapshots_dir.exists() {
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.filter_map(|e| e.ok()) {
                    let snapshot_path = entry.path();
                    if snapshot_path.is_dir() && snapshot_path.join("config.json").exists() {
                        let format = detect_format(&snapshot_path);
                        if format != ModelFormat::Unknown {
                            return Some(ResolvedModelSource {
                                original: model_id.to_string(),
                                local_path: snapshot_path,
                                format,
                                from_cache: true,
                            });
                        }
                    }
                }
            }
        }
    }

    let direct = cache_dir.join(model_id);
    if direct.exists() && direct.join("config.json").exists() {
        let format = detect_format(&direct);
        if format != ModelFormat::Unknown {
            return Some(ResolvedModelSource {
                original: model_id.to_string(),
                local_path: direct,
                format,
                from_cache: true,
            });
        }
    }

    None
}

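/// Resolves the Hugging Face cache root: `HF_HOME` when set, otherwise the
/// tilde-expanded path configured under `models.download.hf_cache_dir`.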
fn get_hf_cache_dir(config: &CliConfig) -> PathBuf {
    if let Ok(hf_home) = std::env::var("HF_HOME") {
        return PathBuf::from(hf_home);
    }
    let configured = shellexpand::tilde(&config.models.download.hf_cache_dir).to_string();
    PathBuf::from(configured)
}

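/// Detects the weight format of a model directory: any `*.safetensors`
/// file means SafeTensors, a `pytorch_model.bin` means PyTorch, and
/// anything else is reported as Unknown.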
fn detect_format(path: &Path) -> ModelFormat {
    if path.join("model.safetensors").exists() {
        ModelFormat::SafeTensors
    } else if std::fs::read_dir(path)
        .map(|d| {
            d.filter_map(|e| e.ok())
                .any(|e| e.path().extension().is_some_and(|ext| ext == "safetensors"))
        })
        .unwrap_or(false)
    {
        ModelFormat::SafeTensors
    } else if path.join("pytorch_model.bin").exists() {
        ModelFormat::PyTorchBin
    } else {
        ModelFormat::Unknown
    }
}