ferrum_cli/commands/embed.rs

//! Embed command - Generate embeddings using BERT or CLIP models

use crate::config::CliConfig;
use candle_core::Device as CandleDevice;
use clap::Args;
use colored::Colorize;
use ferrum_models::source::{ModelFormat, ResolvedModelSource};
use ferrum_models::HfDownloader;
use ferrum_models::{BertModelExecutor, ConfigManager};
use ferrum_types::Result;
use std::io::{self, BufRead};
use std::path::{Path, PathBuf};

/// Generate embeddings using BERT or CLIP models
#[derive(Args, Debug)]
pub struct EmbedCommand {
    /// Model name (e.g., google-bert/bert-base-chinese, OFA-Sys/chinese-clip-vit-base-patch16)
    #[arg(required = true)]
    pub model: String,

    /// Text to embed (if not provided, reads from stdin)
    #[arg(short, long)]
    pub text: Option<String>,

    /// Image path to embed (CLIP models only)
    #[arg(short, long)]
    pub image: Option<String>,

    /// Output format: json, csv, or raw
    #[arg(short, long, default_value = "json")]
    pub format: String,

    /// Normalize embeddings to unit length (pass --normalize=false to disable)
    #[arg(short, long, default_value_t = true, action = clap::ArgAction::Set)]
    pub normalize: bool,
}
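
// Illustrative invocations (the `ferrum` binary name and file names are assumptions;
// the flags match the fields above):
//
//     ferrum embed google-bert/bert-base-chinese --text "你好"
//     ferrum embed OFA-Sys/chinese-clip-vit-base-patch16 --image photo.jpg --text "一只猫"
//     cat sentences.txt | ferrum embed google-bert/bert-base-chinese --format csv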

pub async fn execute(cmd: EmbedCommand, config: CliConfig) -> Result<()> {
    eprintln!("{}", format!("Loading {}...", cmd.model).dimmed());

    // Resolve model path using same logic as list/run commands
    let model_id = cmd.model.clone();
    let cache_dir = get_hf_cache_dir(&config);

    let source = match find_cached_model(&cache_dir, &model_id) {
        Some(source) => source,
        None => {
            eprintln!(
                "{} Model '{}' not found locally, downloading...",
                "📥".cyan(),
                model_id
            );

            let token = std::env::var("HF_TOKEN")
                .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
                .ok();

            let downloader = HfDownloader::new(cache_dir, token)?;
            let snapshot_path = downloader.download(&model_id, None).await?;

            let format = detect_format(&snapshot_path);
            if format == ModelFormat::Unknown {
                return Err(ferrum_types::FerrumError::model(
                    "Downloaded model has unknown format",
                ));
            }

            ResolvedModelSource {
                original: model_id.clone(),
                local_path: snapshot_path,
                format,
                from_cache: false,
            }
        }
    };

    let model_path = source.local_path.to_string_lossy().to_string();
    eprintln!("{}", "Using CPU backend".dimmed());

    // Load model definition to detect architecture
    let mut config_manager = ConfigManager::new();
    let model_def = config_manager.load_from_path(&source.local_path).await?;

    let device = CandleDevice::Cpu;
    let is_clip = model_def.architecture == ferrum_models::Architecture::Clip;

    let mut all_embeddings: Vec<(String, Vec<f32>)> = Vec::new();

    if is_clip {
        // CLIP path: supports both text and image
        let executor = ferrum_models::ClipModelExecutor::from_path(
            &model_path,
            device.clone(),
            candle_core::DType::F32,
        )?;
        eprintln!("{}", "CLIP model loaded.".green());

        if let Some(ref image_path) = cmd.image {
            let embedding_tensor = executor.embed_image_path(image_path)?;
            let embedding = tensor_to_vec(&embedding_tensor, cmd.normalize)?;
            all_embeddings.push((format!("[image] {image_path}"), embedding));
        }

        let texts = collect_texts(&cmd)?;
        if !texts.is_empty() {
            let tokenizer = load_tokenizer(&source.local_path)?;

            for text in &texts {
                let encoding = tokenizer
                    .encode(text.as_str(), true)
                    .map_err(|e| ferrum_types::FerrumError::model(format!("Tokenize: {e}")))?;
                let embedding_tensor = executor.embed_text(encoding.get_ids())?;
                let embedding = tensor_to_vec(&embedding_tensor, cmd.normalize)?;
                all_embeddings.push((text.clone(), embedding));
            }
        }

        if all_embeddings.is_empty() {
            eprintln!("{}", "No input provided. Use --text or --image.".yellow());
            return Ok(());
        }
    } else {
        // BERT path: text-only embeddings
        let executor = BertModelExecutor::from_path(&model_path, &model_def, device).await?;
        eprintln!("{}", "BERT model loaded.".green());

        let tokenizer = load_tokenizer(&source.local_path)?;

        let texts = collect_texts(&cmd)?;
        if texts.is_empty() {
            eprintln!("{}", "No text provided.".yellow());
            return Ok(());
        }

        for text in &texts {
            let encoding = tokenizer
                .encode(text.as_str(), true)
                .map_err(|e| ferrum_types::FerrumError::model(format!("Tokenize: {e}")))?;
            let embedding_tensor = executor.get_embeddings(encoding.get_ids())?;
            let embedding = tensor_to_vec(&embedding_tensor, cmd.normalize)?;
            all_embeddings.push((text.clone(), embedding));
        }
    }

    // Output embeddings
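    // Example `--format json` output for one input (values and the 768-dim size
    // are illustrative, matching the json! fields below):
    //   [ { "text": "你好", "embedding": [0.0123, -0.0456, ...], "dimensions": 768 } ]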
    match cmd.format.as_str() {
        "json" => {
            let output: Vec<serde_json::Value> = all_embeddings
                .iter()
                .map(|(text, emb)| {
                    serde_json::json!({
                        "text": text,
                        "embedding": emb,
                        "dimensions": emb.len()
                    })
                })
                .collect();
            println!("{}", serde_json::to_string_pretty(&output).unwrap());
        }
        "csv" => {
            if let Some((_, first_emb)) = all_embeddings.first() {
                let header: Vec<String> =
                    (0..first_emb.len()).map(|i| format!("dim_{}", i)).collect();
                println!("text,{}", header.join(","));
            }
            for (text, emb) in &all_embeddings {
                let emb_str: Vec<String> = emb.iter().map(|v| format!("{:.6}", v)).collect();
                // CSV-escape embedded quotes by doubling them (RFC 4180)
                println!("\"{}\",{}", text.replace('"', "\"\""), emb_str.join(","));
            }
        }
        "raw" => {
            for (text, emb) in &all_embeddings {
                eprintln!("{}: {} dimensions", text.dimmed(), emb.len());
                let preview: Vec<String> =
                    emb.iter().take(5).map(|v| format!("{:.4}", v)).collect();
                println!("[{}, ...]", preview.join(", "));
            }
        }
        _ => {
            eprintln!("{}", format!("Unknown format: {}", cmd.format).red());
        }
    }

    Ok(())
}

/// Load tokenizer: try tokenizer.json first, fall back to vocab.txt (BERT-style).
pub fn load_tokenizer(model_dir: &Path) -> Result<tokenizers::Tokenizer> {
    let tokenizer_json = model_dir.join("tokenizer.json");
    if tokenizer_json.exists() {
        return tokenizers::Tokenizer::from_file(&tokenizer_json)
            .map_err(|e| ferrum_types::FerrumError::model(format!("Load tokenizer.json: {e}")));
    }

    // Fall back to vocab.txt (Chinese-CLIP, older BERT models)
    let vocab_txt = model_dir.join("vocab.txt");
    if vocab_txt.exists() {
        use tokenizers::models::wordpiece::WordPiece;
        use tokenizers::processors::template::TemplateProcessing;
        use tokenizers::Model;
        let wp = WordPiece::from_file(vocab_txt.to_str().unwrap())
            .unk_token("[UNK]".to_string())
            .build()
            .map_err(|e| ferrum_types::FerrumError::model(format!("Load vocab.txt: {e}")))?;

        // Look up [CLS] and [SEP] IDs from vocab dynamically
        let vocab = wp.get_vocab();
        let cls_id = vocab.get("[CLS]").copied().unwrap_or(101);
        let sep_id = vocab.get("[SEP]").copied().unwrap_or(102);

        let mut tokenizer = tokenizers::Tokenizer::new(wp);
        // BertPreTokenizer + Chinese char splitting (add spaces around CJK chars
        // so WordPiece treats each character as a separate token, matching Python's
        // BertTokenizer._tokenize_chinese_chars behavior)
        use tokenizers::pre_tokenizers::sequence::Sequence;
        use tokenizers::pre_tokenizers::unicode_scripts::UnicodeScripts;
        tokenizer.with_pre_tokenizer(Some(Sequence::new(vec![
            tokenizers::pre_tokenizers::PreTokenizerWrapper::UnicodeScripts(UnicodeScripts),
            tokenizers::pre_tokenizers::PreTokenizerWrapper::BertPreTokenizer(
                tokenizers::pre_tokenizers::bert::BertPreTokenizer,
            ),
        ])));
        let template = TemplateProcessing::builder()
            .try_single("[CLS] $A [SEP]")
            .unwrap()
            .special_tokens(vec![("[CLS]", cls_id), ("[SEP]", sep_id)])
            .build()
            .map_err(|e| ferrum_types::FerrumError::model(format!("Template: {e}")))?;
        tokenizer.with_post_processor(Some(template));
        return Ok(tokenizer);
    }

    Err(ferrum_types::FerrumError::model(
        "No tokenizer.json or vocab.txt found in model directory",
    ))
}
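
// Illustrative `load_tokenizer` call (the snapshot path is a placeholder), mirroring
// how `execute` uses it above:
//
//     let tokenizer = load_tokenizer(Path::new("/path/to/snapshot"))?;
//     let ids = tokenizer
//         .encode("你好", true)
//         .map_err(|e| ferrum_types::FerrumError::model(format!("Tokenize: {e}")))?
//         .get_ids()
//         .to_vec();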

fn collect_texts(cmd: &EmbedCommand) -> Result<Vec<String>> {
    if let Some(ref text) = cmd.text {
        Ok(vec![text.clone()])
    } else if cmd.image.is_some() {
        // Image-only mode, no text needed
        Ok(vec![])
    } else {
        eprintln!(
            "{}",
            "Reading text from stdin (one per line, Ctrl+D to finish):".dimmed()
        );
        let stdin = io::stdin();
        Ok(stdin.lock().lines().filter_map(|l| l.ok()).collect())
    }
}

fn tensor_to_vec(tensor: &candle_core::Tensor, normalize: bool) -> Result<Vec<f32>> {
    let mut embedding = tensor
        .flatten_all()
        .map_err(|e| ferrum_types::FerrumError::model(format!("Flatten: {e}")))?
        .to_vec1::<f32>()
        .map_err(|e| ferrum_types::FerrumError::model(format!("to_vec1: {e}")))?;

    if normalize {
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for v in &mut embedding {
                *v /= norm;
            }
        }
    }
    Ok(embedding)
}

fn find_cached_model(cache_dir: &Path, model_id: &str) -> Option<ResolvedModelSource> {
    // HuggingFace cache structure: hub/models--Org--ModelName/snapshots/<hash>/
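    // e.g. for "google-bert/bert-base-chinese" (<hash> is a snapshot revision):
    //   <cache>/hub/models--google-bert--bert-base-chinese/snapshots/<hash>/config.json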
    let hub_dir = cache_dir.join("hub");
    let model_dir_name = format!("models--{}", model_id.replace('/', "--"));
    let model_dir = hub_dir.join(&model_dir_name);

    if model_dir.exists() {
        // Return the first snapshot that contains a usable model
        let snapshots_dir = model_dir.join("snapshots");
        if snapshots_dir.exists() {
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.filter_map(|e| e.ok()) {
                    let snapshot_path = entry.path();
                    if snapshot_path.is_dir() && snapshot_path.join("config.json").exists() {
                        let format = detect_format(&snapshot_path);
                        if format != ModelFormat::Unknown {
                            return Some(ResolvedModelSource {
                                original: model_id.to_string(),
                                local_path: snapshot_path,
                                format,
                                from_cache: true,
                            });
                        }
                    }
                }
            }
        }
    }

    // Also check direct path (for models downloaded to custom locations)
    let direct = cache_dir.join(model_id);
    if direct.exists() && direct.join("config.json").exists() {
        let format = detect_format(&direct);
        if format != ModelFormat::Unknown {
            return Some(ResolvedModelSource {
                original: model_id.to_string(),
                local_path: direct,
                format,
                from_cache: true,
            });
        }
    }

    None
}

fn get_hf_cache_dir(config: &CliConfig) -> PathBuf {
    if let Ok(hf_home) = std::env::var("HF_HOME") {
        return PathBuf::from(hf_home);
    }
    let configured = shellexpand::tilde(&config.models.download.hf_cache_dir).to_string();
    PathBuf::from(configured)
}
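
// Note on precedence in `get_hf_cache_dir`: $HF_HOME overrides the configured
// hf_cache_dir, so e.g. HF_HOME=/data/hf makes `find_cached_model` look under
// /data/hf/hub/ (path illustrative).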

fn detect_format(path: &Path) -> ModelFormat {
    if path.join("model.safetensors").exists() {
        ModelFormat::SafeTensors
    } else if std::fs::read_dir(path)
        .map(|d| {
            d.filter_map(|e| e.ok()).any(|e| {
                e.path()
                    .extension()
                    .map_or(false, |ext| ext == "safetensors")
            })
        })
        .unwrap_or(false)
    {
        ModelFormat::SafeTensors
    } else if path.join("pytorch_model.bin").exists() {
        ModelFormat::PyTorchBin
    } else {
        ModelFormat::Unknown
    }
}
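
// A minimal sketch of a unit test for the normalization helper; it assumes only
// that candle's CPU device is available (no model files needed).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tensor_to_vec_l2_normalizes() {
        // A [3, 4] vector has L2 norm 5, so the unit-length result is [0.6, 0.8].
        let tensor =
            candle_core::Tensor::new(&[3.0f32, 4.0], &candle_core::Device::Cpu).unwrap();
        let unit = tensor_to_vec(&tensor, true).unwrap();
        assert!((unit[0] - 0.6).abs() < 1e-6);
        assert!((unit[1] - 0.8).abs() < 1e-6);

        // With normalize = false the raw values pass through unchanged.
        let raw = tensor_to_vec(&tensor, false).unwrap();
        assert_eq!(raw, vec![3.0, 4.0]);
    }
}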