ferrum_cli/commands/embed.rs

//! Embed command - Generate embeddings using BERT models

use crate::config::CliConfig;
use candle_core::Device as CandleDevice;
use clap::Args;
use colored::Colorize;
use ferrum_models::source::{ModelFormat, ResolvedModelSource};
use ferrum_models::HfDownloader;
use ferrum_models::{BertModelExecutor, ConfigManager};
use ferrum_types::Result;
use std::io::{self, BufRead};
use std::path::{Path, PathBuf};

/// Generate embeddings using a BERT model
#[derive(Args, Debug)]
pub struct EmbedCommand {
    /// Model name (e.g., google-bert/bert-base-chinese)
    #[arg(required = true)]
    pub model: String,

    /// Text to embed (if not provided, reads from stdin)
    #[arg(short, long)]
    pub text: Option<String>,

    /// Output format: json, csv, or raw
    #[arg(short, long, default_value = "json")]
    pub format: String,

    /// Normalize embeddings to unit length (pass `--normalize false` to disable)
    #[arg(short, long, default_value_t = true, action = clap::ArgAction::Set)]
    pub normalize: bool,
}
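
// Illustrative invocations (the `ferrum` binary name is an assumption here;
// substitute whatever this CLI actually installs as):
//
//   ferrum embed google-bert/bert-base-chinese --text "hello world"
//   cat sentences.txt | ferrum embed google-bert/bert-base-chinese --format csv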

pub async fn execute(cmd: EmbedCommand, config: CliConfig) -> Result<()> {
    eprintln!("{}", format!("Loading {}...", cmd.model).dimmed());

    // Resolve the model path using the same logic as the list/run commands
    let model_id = cmd.model.clone();
    let cache_dir = get_hf_cache_dir(&config);

    let source = match find_cached_model(&cache_dir, &model_id) {
        Some(source) => source,
        None => {
            eprintln!(
                "{} Model '{}' not found locally, downloading...",
                "📥".cyan(),
                model_id
            );

            let token = std::env::var("HF_TOKEN")
                .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
                .ok();

            let downloader = HfDownloader::new(cache_dir, token)?;
            let snapshot_path = downloader.download(&model_id, None).await?;

            let format = detect_format(&snapshot_path);
            if format == ModelFormat::Unknown {
                return Err(ferrum_types::FerrumError::model(
                    "Downloaded model has unknown format",
                ));
            }

            ResolvedModelSource {
                original: model_id.clone(),
                local_path: snapshot_path,
                format,
                from_cache: false,
            }
        }
    };

    let model_path = source.local_path.to_string_lossy().to_string();
    eprintln!("{}", "Using CPU backend".dimmed());

    // Load model definition
    let mut config_manager = ConfigManager::new();
    let model_def = config_manager.load_from_path(&source.local_path).await?;

    // Load BERT executor
    let device = CandleDevice::Cpu;
    let executor = BertModelExecutor::from_path(&model_path, &model_def, device).await?;

    eprintln!("{}", "Model loaded. Ready for embedding.".green());

    // Load tokenizer
    let tokenizer = tokenizers::Tokenizer::from_file(source.local_path.join("tokenizer.json"))
        .map_err(|e| {
            ferrum_types::FerrumError::model(format!("Failed to load tokenizer: {}", e))
        })?;
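    // Note: `Tokenizer::from_file` above requires a fast-tokenizer
    // `tokenizer.json` in the snapshot; checkpoints that only ship
    // `vocab.txt` will fail here with the error message above.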

    // Process input text
    let texts: Vec<String> = if let Some(text) = cmd.text {
        vec![text]
    } else {
        eprintln!(
            "{}",
            "Reading text from stdin (one text per line, Ctrl+D to finish):".dimmed()
        );
        let stdin = io::stdin();
        stdin.lock().lines().map_while(Result::ok).collect()
    };

    if texts.is_empty() {
        eprintln!("{}", "No text provided.".yellow());
        return Ok(());
    }

    // Generate embeddings for each text
    let mut all_embeddings = Vec::new();

    for text in &texts {
        // Tokenize
        let encoding = tokenizer
            .encode(text.as_str(), true)
            .map_err(|e| ferrum_types::FerrumError::model(format!("Tokenization failed: {}", e)))?;

        let token_ids: Vec<u32> = encoding.get_ids().to_vec();

        // Get embeddings
        let embedding_tensor = executor.get_embeddings(&token_ids)?;

        // Convert to vec
        let mut embedding = embedding_tensor
            .flatten_all()
            .map_err(|e| ferrum_types::FerrumError::model(format!("Flatten failed: {}", e)))?
            .to_vec1::<f32>()
            .map_err(|e| ferrum_types::FerrumError::model(format!("to_vec1 failed: {}", e)))?;

        // Normalize if requested: unit-length embeddings make cosine
        // similarity equivalent to a plain dot product.
        if cmd.normalize {
            let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 0.0 {
                for v in &mut embedding {
                    *v /= norm;
                }
            }
        }

        all_embeddings.push((text.clone(), embedding));
    }

    // Output embeddings
    match cmd.format.as_str() {
        "json" => {
            let output: Vec<serde_json::Value> = all_embeddings
                .iter()
                .map(|(text, emb)| {
                    serde_json::json!({
                        "text": text,
                        "embedding": emb,
                        "dimensions": emb.len()
                    })
                })
                .collect();
            println!("{}", serde_json::to_string_pretty(&output).unwrap());
        }
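        // The JSON output is a top-level array with one object per input,
        // e.g. (values illustrative):
        //   [{"text": "hello", "embedding": [0.0123, -0.0456, ...], "dimensions": 768}]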
158        "csv" => {
159            if let Some((_, first_emb)) = all_embeddings.first() {
160                let header: Vec<String> =
161                    (0..first_emb.len()).map(|i| format!("dim_{}", i)).collect();
162                println!("text,{}", header.join(","));
163            }
164            for (text, emb) in &all_embeddings {
165                let emb_str: Vec<String> = emb.iter().map(|v| format!("{:.6}", v)).collect();
166                println!("\"{}\",{}", text.replace("\"", "\\\""), emb_str.join(","));
167            }
168        }
169        "raw" => {
170            for (text, emb) in &all_embeddings {
171                eprintln!("{}: {} dimensions", text.dimmed(), emb.len());
172                let preview: Vec<String> =
173                    emb.iter().take(5).map(|v| format!("{:.4}", v)).collect();
174                println!("[{}, ...]", preview.join(", "));
175            }
176        }
177        _ => {
178            eprintln!("{}", format!("Unknown format: {}", cmd.format).red());
179        }
180    }
181
182    Ok(())
183}
184
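// For reference, the HuggingFace cache layout this function walks looks
// roughly like this (revision hash and file names illustrative):
//
//   <cache_dir>/hub/
//     models--google-bert--bert-base-chinese/
//       snapshots/
//         <revision-hash>/
//           config.json
//           model.safetensors
//           tokenizer.json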
fn find_cached_model(cache_dir: &Path, model_id: &str) -> Option<ResolvedModelSource> {
    // HuggingFace cache structure: hub/models--Org--ModelName/snapshots/<hash>/
    let hub_dir = cache_dir.join("hub");
    let model_dir_name = format!("models--{}", model_id.replace('/', "--"));
    let model_dir = hub_dir.join(&model_dir_name);

    if model_dir.exists() {
        // Pick the first snapshot that contains a config.json and a known
        // weight format (directory iteration order is arbitrary).
        let snapshots_dir = model_dir.join("snapshots");
        if snapshots_dir.exists() {
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.filter_map(|e| e.ok()) {
                    let snapshot_path = entry.path();
                    if snapshot_path.is_dir() && snapshot_path.join("config.json").exists() {
                        let format = detect_format(&snapshot_path);
                        if format != ModelFormat::Unknown {
                            return Some(ResolvedModelSource {
                                original: model_id.to_string(),
                                local_path: snapshot_path,
                                format,
                                from_cache: true,
                            });
                        }
                    }
                }
            }
        }
    }

    // Also check a direct path (for models downloaded to custom locations)
    let direct = cache_dir.join(model_id);
    if direct.exists() && direct.join("config.json").exists() {
        let format = detect_format(&direct);
        if format != ModelFormat::Unknown {
            return Some(ResolvedModelSource {
                original: model_id.to_string(),
                local_path: direct,
                format,
                from_cache: true,
            });
        }
    }

    None
}

fn get_hf_cache_dir(config: &CliConfig) -> PathBuf {
    if let Ok(hf_home) = std::env::var("HF_HOME") {
        return PathBuf::from(hf_home);
    }
    let configured = shellexpand::tilde(&config.models.download.hf_cache_dir).to_string();
    PathBuf::from(configured)
}

fn detect_format(path: &Path) -> ModelFormat {
    if path.join("model.safetensors").exists() {
        ModelFormat::SafeTensors
    } else if std::fs::read_dir(path)
        .map(|d| {
            d.filter_map(|e| e.ok()).any(|e| {
                e.path()
                    .extension()
                    .map_or(false, |ext| ext == "safetensors")
            })
        })
        .unwrap_or(false)
    {
        ModelFormat::SafeTensors
    } else if path.join("pytorch_model.bin").exists() {
        ModelFormat::PyTorchBin
    } else {
        ModelFormat::Unknown
    }
}
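
// A minimal sanity check for the format detection above; a sketch that
// assumes `tempfile` is available as a dev-dependency (not confirmed by
// this crate's manifest). `matches!` is used so `ModelFormat` needs no
// extra derives beyond the `PartialEq` the code above already relies on.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn detect_format_finds_safetensors() {
        let dir = tempfile::tempdir().unwrap();
        // An empty directory has no recognizable weights.
        assert!(matches!(detect_format(dir.path()), ModelFormat::Unknown));

        // Presence of model.safetensors (contents irrelevant) flips the result.
        std::fs::write(dir.path().join("model.safetensors"), b"").unwrap();
        assert!(matches!(detect_format(dir.path()), ModelFormat::SafeTensors));
    }
}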