// ferrum_cli/commands/embed.rs
use crate::config::CliConfig;

use candle_core::Device as CandleDevice;
use clap::Args;
use colored::Colorize;
use ferrum_models::source::{ModelFormat, ResolvedModelSource};
use ferrum_models::{BertModelExecutor, ConfigManager, HfDownloader};
use ferrum_types::Result;
use std::io::{self, BufRead};
use std::path::{Path, PathBuf};
13
// CLI arguments for the `embed` subcommand. Plain `//` comments are used
// deliberately: `///` doc comments on clap derive items become runtime
// `--help` text, which would change observable behavior.
#[derive(Args, Debug)]
pub struct EmbedCommand {
    // Model identifier (e.g. "org/name") or local path; resolved against
    // the Hugging Face cache, downloaded on miss.
    #[arg(required = true)]
    pub model: String,

    // Single input text; when absent, texts are read one-per-line from stdin.
    #[arg(short, long)]
    pub text: Option<String>,

    // Output format; execute() understands "json", "csv", and "raw".
    #[arg(short, long, default_value = "json")]
    pub format: String,

    // L2-normalize each embedding vector before printing.
    // NOTE(review): a `bool` arg with `default_value = "true"` usually needs
    // `action = ArgAction::Set` to be overridable from the command line —
    // confirm this flag can actually be turned off.
    #[arg(short, long, default_value = "true")]
    pub normalize: bool,
}
33
34pub async fn execute(cmd: EmbedCommand, config: CliConfig) -> Result<()> {
35 eprintln!("{}", format!("Loading {}...", cmd.model).dimmed());
36
37 let model_id = cmd.model.clone();
39 let cache_dir = get_hf_cache_dir(&config);
40
41 let source = match find_cached_model(&cache_dir, &model_id) {
42 Some(source) => source,
43 None => {
44 eprintln!(
45 "{} Model '{}' not found locally, downloading...",
46 "📥".cyan(),
47 model_id
48 );
49
50 let token = std::env::var("HF_TOKEN")
51 .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
52 .ok();
53
54 let downloader = HfDownloader::new(cache_dir, token)?;
55 let snapshot_path = downloader.download(&model_id, None).await?;
56
57 let format = detect_format(&snapshot_path);
58 if format == ModelFormat::Unknown {
59 return Err(ferrum_types::FerrumError::model(
60 "Downloaded model has unknown format",
61 ));
62 }
63
64 ResolvedModelSource {
65 original: model_id.clone(),
66 local_path: snapshot_path,
67 format,
68 from_cache: false,
69 }
70 }
71 };
72
73 let model_path = source.local_path.to_string_lossy().to_string();
74 eprintln!("{}", "Using CPU backend".dimmed());
75
76 let mut config_manager = ConfigManager::new();
78 let model_def = config_manager.load_from_path(&source.local_path).await?;
79
80 let device = CandleDevice::Cpu;
82 let executor = BertModelExecutor::from_path(&model_path, &model_def, device).await?;
83
84 eprintln!("{}", "Model loaded. Ready for embedding.".green());
85
86 let tokenizer = tokenizers::Tokenizer::from_file(source.local_path.join("tokenizer.json"))
88 .map_err(|e| {
89 ferrum_types::FerrumError::model(format!("Failed to load tokenizer: {}", e))
90 })?;
91
92 let texts: Vec<String> = if let Some(text) = cmd.text {
94 vec![text]
95 } else {
96 eprintln!(
97 "{}",
98 "Reading text from stdin (one text per line, Ctrl+D to finish):".dimmed()
99 );
100 let stdin = io::stdin();
101 stdin.lock().lines().filter_map(|l| l.ok()).collect()
102 };
103
104 if texts.is_empty() {
105 eprintln!("{}", "No text provided.".yellow());
106 return Ok(());
107 }
108
109 let mut all_embeddings = Vec::new();
111
112 for text in &texts {
113 let encoding = tokenizer
115 .encode(text.as_str(), true)
116 .map_err(|e| ferrum_types::FerrumError::model(format!("Tokenization failed: {}", e)))?;
117
118 let token_ids: Vec<u32> = encoding.get_ids().to_vec();
119
120 let embedding_tensor = executor.get_embeddings(&token_ids)?;
122
123 let mut embedding = embedding_tensor
125 .flatten_all()
126 .map_err(|e| ferrum_types::FerrumError::model(format!("Flatten failed: {}", e)))?
127 .to_vec1::<f32>()
128 .map_err(|e| ferrum_types::FerrumError::model(format!("to_vec1 failed: {}", e)))?;
129
130 if cmd.normalize {
132 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
133 if norm > 0.0 {
134 for v in &mut embedding {
135 *v /= norm;
136 }
137 }
138 }
139
140 all_embeddings.push((text.clone(), embedding));
141 }
142
143 match cmd.format.as_str() {
145 "json" => {
146 let output: Vec<serde_json::Value> = all_embeddings
147 .iter()
148 .map(|(text, emb)| {
149 serde_json::json!({
150 "text": text,
151 "embedding": emb,
152 "dimensions": emb.len()
153 })
154 })
155 .collect();
156 println!("{}", serde_json::to_string_pretty(&output).unwrap());
157 }
158 "csv" => {
159 if let Some((_, first_emb)) = all_embeddings.first() {
160 let header: Vec<String> =
161 (0..first_emb.len()).map(|i| format!("dim_{}", i)).collect();
162 println!("text,{}", header.join(","));
163 }
164 for (text, emb) in &all_embeddings {
165 let emb_str: Vec<String> = emb.iter().map(|v| format!("{:.6}", v)).collect();
166 println!("\"{}\",{}", text.replace("\"", "\\\""), emb_str.join(","));
167 }
168 }
169 "raw" => {
170 for (text, emb) in &all_embeddings {
171 eprintln!("{}: {} dimensions", text.dimmed(), emb.len());
172 let preview: Vec<String> =
173 emb.iter().take(5).map(|v| format!("{:.4}", v)).collect();
174 println!("[{}, ...]", preview.join(", "));
175 }
176 }
177 _ => {
178 eprintln!("{}", format!("Unknown format: {}", cmd.format).red());
179 }
180 }
181
182 Ok(())
183}
184
185fn find_cached_model(cache_dir: &PathBuf, model_id: &str) -> Option<ResolvedModelSource> {
186 let hub_dir = cache_dir.join("hub");
188 let model_dir_name = format!("models--{}", model_id.replace("/", "--"));
189 let model_dir = hub_dir.join(&model_dir_name);
190
191 if model_dir.exists() {
192 let snapshots_dir = model_dir.join("snapshots");
194 if snapshots_dir.exists() {
195 if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
196 for entry in entries.filter_map(|e| e.ok()) {
197 let snapshot_path = entry.path();
198 if snapshot_path.is_dir() && snapshot_path.join("config.json").exists() {
199 let format = detect_format(&snapshot_path);
200 if format != ModelFormat::Unknown {
201 return Some(ResolvedModelSource {
202 original: model_id.to_string(),
203 local_path: snapshot_path,
204 format,
205 from_cache: true,
206 });
207 }
208 }
209 }
210 }
211 }
212 }
213
214 let direct = cache_dir.join(model_id);
216 if direct.exists() && direct.join("config.json").exists() {
217 let format = detect_format(&direct);
218 if format != ModelFormat::Unknown {
219 return Some(ResolvedModelSource {
220 original: model_id.to_string(),
221 local_path: direct,
222 format,
223 from_cache: true,
224 });
225 }
226 }
227
228 None
229}
230
231fn get_hf_cache_dir(config: &CliConfig) -> PathBuf {
232 if let Ok(hf_home) = std::env::var("HF_HOME") {
233 return PathBuf::from(hf_home);
234 }
235 let configured = shellexpand::tilde(&config.models.download.hf_cache_dir).to_string();
236 PathBuf::from(configured)
237}
238
239fn detect_format(path: &PathBuf) -> ModelFormat {
240 if path.join("model.safetensors").exists() {
241 ModelFormat::SafeTensors
242 } else if std::fs::read_dir(path)
243 .map(|d| {
244 d.filter_map(|e| e.ok()).any(|e| {
245 e.path()
246 .extension()
247 .map_or(false, |ext| ext == "safetensors")
248 })
249 })
250 .unwrap_or(false)
251 {
252 ModelFormat::SafeTensors
253 } else if path.join("pytorch_model.bin").exists() {
254 ModelFormat::PyTorchBin
255 } else {
256 ModelFormat::Unknown
257 }
258}