lc/cli/embed.rs

//! Embedding command implementations: generate embeddings for text and files,
//! store them in vector databases, and retrieve similar content for RAG.

use anyhow::Result;
use colored::*;

use crate::chat;
use crate::config;
use crate::data::vector_db::{FileProcessor, VectorDatabase};
use crate::cli::set_debug_mode;
use crate::provider::EmbeddingRequest;
use crate::utils::resolve_model_and_provider;

/// Handle the `embed` command: generate embeddings for text input and/or files,
/// optionally storing the resulting vectors in a named vector database.
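///
/// # Example
///
/// A minimal usage sketch (illustrative only): it assumes this function is in
/// scope, that a provider and API key are already configured, and that the
/// model and database names below are placeholders.
///
/// ```ignore
/// // Inside an async context that returns anyhow::Result<()>:
/// handle_embed_command(
///     "text-embedding-3-small".to_string(), // embedding model (example value)
///     None,                                 // provider: resolved from config
///     Some("notes".to_string()),            // store vectors in the "notes" DB
///     vec![],                               // no file glob patterns
///     Some("hello world".to_string()),      // text to embed
///     false,                                // debug output off
/// )
/// .await?;
/// ```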
pub async fn handle_embed_command(
    model: String,
    provider: Option<String>,
    database: Option<String>,
    files: Vec<String>,
    text: Option<String>,
    debug: bool,
) -> Result<()> {
    // Set debug mode if requested
    if debug {
        set_debug_mode(true);
    }

    // Validate input: either text or files must be provided
    if text.is_none() && files.is_empty() {
        anyhow::bail!("Either text or files must be provided for embedding");
    }

    let config = config::Config::load()?;

    // Resolve provider and model using the same logic as direct prompts
    let (provider_name, resolved_model) =
        resolve_model_and_provider(&config, provider, Some(model))?;

    // Get provider config with authentication from centralized keys
    let provider_config = config.get_provider_with_auth(&provider_name)?;

    // Allow either API key or resolved custom auth headers (e.g., x-goog-api-key)
    let header_has_resolved_key = provider_config.headers.iter().any(|(k, v)| {
        let k_l = k.to_lowercase();
        (k_l.contains("key") || k_l.contains("token") || k_l.contains("auth"))
            && !v.trim().is_empty()
            && !v.contains("${api_key}")
    });
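    // For example, a header like ("x-goog-api-key", "<non-empty literal key>")
    // counts as resolved, while ("x-goog-api-key", "${api_key}") is still an
    // unreplaced placeholder and does not pass the check. (Illustrative values.)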
    if provider_config.api_key.is_none() && !header_has_resolved_key {
        anyhow::bail!(
            "No API key configured for provider '{}'. Add one with 'lc keys add {}'",
            provider_name,
            provider_name
        );
    }

    let mut config_mut = config.clone();
    let client = chat::create_authenticated_client(&mut config_mut, &provider_name).await?;

    // Save config if tokens were updated
    if config_mut.get_cached_token(&provider_name) != config.get_cached_token(&provider_name) {
        config_mut.save()?;
    }

    println!("{} Starting embedding process...", "🔄".blue());
    println!("{} Model: {}", "📊".blue(), resolved_model);
    println!("{} Provider: {}", "🏢".blue(), provider_name);

    let mut total_embeddings = 0;
    let mut total_tokens = 0;

    // Process files if provided
    if !files.is_empty() {
        println!("{} Processing files with glob patterns...", "📁".blue());

        // Expand file patterns and filter for text files
        let file_paths = FileProcessor::expand_file_patterns(&files)?;

        if file_paths.is_empty() {
            println!(
                "{} No text files found matching the patterns",
                "⚠ī¸".yellow()
            );
        } else {
            println!(
                "{} Found {} text files to process",
                "✅".green(),
                file_paths.len()
            );

            for file_path in file_paths {
                println!("\n{} Processing file: {}", "📄".blue(), file_path.display());

                // Read and chunk the file
                match FileProcessor::process_file(&file_path) {
                    Ok(chunks) => {
                        println!("{} Split into {} chunks", "✂ī¸".blue(), chunks.len());

                        // Process each chunk
                        for (chunk_index, chunk) in chunks.iter().enumerate() {
                            let embedding_request = EmbeddingRequest {
                                model: resolved_model.clone(),
                                input: chunk.clone(),
                                encoding_format: Some("float".to_string()),
                            };

                            match client.embeddings(&embedding_request).await {
                                Ok(response) => {
                                    if let Some(embedding_data) = response.data.first() {
                                        total_embeddings += 1;
                                        total_tokens += response.usage.total_tokens;

                                        // Store in vector database if specified
                                        if let Some(db_name) = &database {
                                            match VectorDatabase::new(db_name) {
                                                Ok(vector_db) => {
                                                    let file_path_str = file_path.to_string_lossy();
                                                    match vector_db.add_vector_with_metadata(
                                                        chunk,
                                                        &embedding_data.embedding,
                                                        &resolved_model,
                                                        &provider_name,
                                                        Some(&file_path_str),
                                                        Some(chunk_index as i32),
                                                        Some(chunks.len() as i32),
                                                    ) {
                                                        Ok(id) => {
                                                            println!("  {} Chunk {}/{} stored with ID: {}",
                                                                "💾".green(), chunk_index + 1, chunks.len(), id);
                                                        }
                                                        Err(e) => {
                                                            eprintln!("  Warning: Failed to store chunk {}: {}", chunk_index + 1, e);
                                                        }
                                                    }
                                                }
                                                Err(e) => {
                                                    eprintln!("  Warning: Failed to create/open vector database '{}': {}", db_name, e);
                                                }
                                            }
                                        } else {
                                            // Just show progress without storing
                                            println!(
                                                "  {} Chunk {}/{} embedded ({} dimensions)",
                                                "✅".green(),
                                                chunk_index + 1,
                                                chunks.len(),
                                                embedding_data.embedding.len()
                                            );
                                        }
                                    }
                                }
                                Err(e) => {
                                    eprintln!(
                                        "  Warning: Failed to embed chunk {}: {}",
                                        chunk_index + 1,
                                        e
                                    );
                                }
                            }
                        }
                    }
                    Err(e) => {
                        eprintln!(
                            "Warning: Failed to process file '{}': {}",
                            file_path.display(),
                            e
                        );
                    }
                }
            }
        }
    }

    // Process text if provided
    if let Some(text_content) = text {
        println!("\n{} Processing text input...", "📝".blue());
        println!(
            "{} Text: \"{}\"",
            "📝".blue(),
            if text_content.len() > 50 {
                format!("{}...", &text_content[..50])
            } else {
                text_content.clone()
            }
        );

        let embedding_request = EmbeddingRequest {
            model: resolved_model.clone(),
            input: text_content.clone(),
            encoding_format: Some("float".to_string()),
        };

        match client.embeddings(&embedding_request).await {
            Ok(response) => {
                if let Some(embedding_data) = response.data.first() {
                    total_embeddings += 1;
                    total_tokens += response.usage.total_tokens;

                    println!(
                        "{} Vector dimensions: {}",
                        "📏".blue(),
                        embedding_data.embedding.len()
                    );

                    // Display vector preview
                    let embedding = &embedding_data.embedding;
                    if embedding.len() > 10 {
                        println!("\n{} Vector preview:", "🔍".blue());
                        print!("  [");
                        for (i, val) in embedding.iter().take(5).enumerate() {
                            if i > 0 {
                                print!(", ");
                            }
                            print!("{:.6}", val);
                        }
                        print!(" ... ");
                        for (i, val) in embedding.iter().skip(embedding.len() - 5).enumerate() {
                            if i > 0 {
                                print!(", ");
                            }
                            print!("{:.6}", val);
                        }
                        println!("]");
                    }

                    // Store in vector database if specified
                    if let Some(db_name) = &database {
                        match VectorDatabase::new(db_name) {
                            Ok(vector_db) => {
                                match vector_db.add_vector(
                                    &text_content,
                                    &embedding,
                                    &resolved_model,
                                    &provider_name,
                                ) {
                                    Ok(id) => {
                                        println!(
                                            "\n{} Stored in vector database '{}' with ID: {}",
                                            "💾".green(),
                                            db_name,
                                            id
                                        );
                                    }
                                    Err(e) => {
                                        eprintln!(
                                            "Warning: Failed to store in vector database: {}",
                                            e
                                        );
                                    }
                                }
                            }
                            Err(e) => {
                                eprintln!(
                                    "Warning: Failed to create/open vector database '{}': {}",
                                    db_name, e
                                );
                            }
                        }
                    }

                    // Output full vector as JSON for programmatic use
                    if files.is_empty() {
                        // Only show full vector for single text input
                        println!("\n{} Full vector (JSON):", "📋".dimmed());
                        println!("{}", serde_json::to_string(&embedding)?);
                    }
                }
            }
            Err(e) => {
                anyhow::bail!("Failed to generate embeddings for text: {}", e);
            }
        }
    }

    // Summary
    println!("\n{} Embedding process completed!", "🎉".green());
    println!(
        "{} Total embeddings generated: {}",
        "📊".blue(),
        total_embeddings
    );
    println!("{} Total tokens used: {}", "💰".yellow(), total_tokens);

    if let Some(db_name) = &database {
        println!(
            "{} All embeddings stored in database: {}",
            "💾".green(),
            db_name
        );
    }

    Ok(())
}

/// Handle the `similar` command: embed a query string and list the most similar
/// entries stored in a vector database.
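///
/// # Example
///
/// A minimal usage sketch (illustrative only): it assumes this function is in
/// scope and that the "notes" database was populated earlier, e.g. via
/// `lc embed -d notes ...`; when model and provider are `None`, the values
/// stored in the database are used.
///
/// ```ignore
/// // Inside an async context that returns anyhow::Result<()>:
/// handle_similar_command(
///     None,                                // model: fall back to the DB's model
///     None,                                // provider: fall back as well
///     "notes".to_string(),                 // database name (example value)
///     5,                                   // return the top 5 matches
///     "how do I rotate logs?".to_string(), // query text (example value)
/// )
/// .await?;
/// ```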
pub async fn handle_similar_command(
    model: Option<String>,
    provider: Option<String>,
    database: String,
    limit: usize,
    query: String,
) -> Result<()> {
    // Open the vector database
    let vector_db = VectorDatabase::new(&database)?;

    // Check if database has any vectors
    let count = vector_db.count()?;
    if count == 0 {
        anyhow::bail!(
            "Vector database '{}' is empty. Add some vectors first using 'lc embed -d {}'",
            database,
            database
        );
    }

    // Get model info from database if not provided
    let (resolved_model, resolved_provider) = match (&model, &provider) {
        (Some(m), Some(p)) => (m.clone(), p.clone()),
        _ => {
            if let Some((db_model, db_provider)) = vector_db.get_model_info()? {
                if model.is_some() || provider.is_some() {
                    println!(
                        "{} Using model from database: {}:{}",
                        "ℹī¸".blue(),
                        db_provider,
                        db_model
                    );
                }
                (db_model, db_provider)
            } else {
                anyhow::bail!(
                    "No model specified and database '{}' has no stored model info",
                    database
                );
            }
        }
    };

    let config = config::Config::load()?;

    // Resolve provider and model
    let (provider_name, model_name) =
        resolve_model_and_provider(&config, Some(resolved_provider), Some(resolved_model))?;

    // Get provider config with authentication from centralized keys
    let provider_config = config.get_provider_with_auth(&provider_name)?;

    // Allow either API key or resolved custom auth headers (e.g., x-goog-api-key)
    let header_has_resolved_key = provider_config.headers.iter().any(|(k, v)| {
        let k_l = k.to_lowercase();
        (k_l.contains("key") || k_l.contains("token") || k_l.contains("auth"))
            && !v.trim().is_empty()
            && !v.contains("${api_key}")
    });
    if provider_config.api_key.is_none() && !header_has_resolved_key {
        anyhow::bail!(
            "No API key configured for provider '{}'. Add one with 'lc keys add {}'",
            provider_name,
            provider_name
        );
    }

    let mut config_mut = config.clone();
    let client = chat::create_authenticated_client(&mut config_mut, &provider_name).await?;

    // Save config if tokens were updated
    if config_mut.get_cached_token(&provider_name) != config.get_cached_token(&provider_name) {
        config_mut.save()?;
    }

    // Generate embedding for query
    let embedding_request = EmbeddingRequest {
        model: model_name.clone(),
        input: query.clone(),
        encoding_format: Some("float".to_string()),
    };

    println!("{} Searching for similar content...", "🔍".blue());
    println!("{} Database: {}", "📊".blue(), database);
    println!(
        "{} Query: \"{}\"",
        "📝".blue(),
        if query.len() > 50 {
            format!("{}...", &query[..50])
        } else {
            query.clone()
        }
    );

    match client.embeddings(&embedding_request).await {
        Ok(response) => {
            if let Some(embedding_data) = response.data.first() {
                let query_vector = &embedding_data.embedding;

                // Find similar vectors
                let similar_results = vector_db.find_similar(query_vector, limit)?;

                if similar_results.is_empty() {
                    println!(
                        "\n{} No similar content found in database '{}'",
                        "❌".red(),
                        database
                    );
                } else {
                    println!(
                        "\n{} Found {} similar results:",
                        "✅".green(),
                        similar_results.len()
                    );

                    for (i, (entry, similarity)) in similar_results.iter().enumerate() {
                        let similarity_percent = (similarity * 100.0).round() as u32;
                        let similarity_color = if similarity_percent >= 80 {
                            format!("{}%", similarity_percent).green()
                        } else if similarity_percent >= 60 {
                            format!("{}%", similarity_percent).yellow()
                        } else {
                            format!("{}%", similarity_percent).red()
                        };
                        println!(
                            "\n{} {} (Similarity: {})",
                            format!("{}.", i + 1).bold(),
                            format!("ID: {}", entry.id).dimmed(),
                            similarity_color
                        );
                        println!("   {}", entry.text);
                        println!(
                            "   {}",
                            format!(
                                "Added: {}",
                                entry.created_at.format("%Y-%m-%d %H:%M:%S UTC")
                            )
                            .dimmed()
                        );
                    }
                }
            } else {
                anyhow::bail!("No embedding data in response");
            }
        }
        Err(e) => {
            anyhow::bail!("Failed to generate query embedding: {}", e);
        }
    }

    Ok(())
}

/// RAG helper: embed the query with the database's own embedding model and
/// return the most relevant stored text as a bulleted context string.
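///
/// # Example
///
/// A minimal usage sketch (illustrative only): it assumes this function is in
/// scope, that a chat `LLMClient` named `chat_client` already exists, and that
/// the "notes" database stores model/provider metadata. The client, model, and
/// provider arguments are currently unused because the embedding is generated
/// with the provider recorded in the database itself.
///
/// ```ignore
/// // Inside an async context that returns anyhow::Result<()>:
/// let context = retrieve_rag_context(
///     "notes",                 // database name (example value)
///     "how do I rotate logs?", // user query (example value)
///     &chat_client,            // unused by this helper (hypothetical client)
///     "gpt-4o",                // chat model (unused here, example value)
///     "openai",                // chat provider (unused here, example value)
/// )
/// .await?;
/// if !context.is_empty() {
///     println!("Relevant context:\n{}", context);
/// }
/// ```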
pub async fn retrieve_rag_context(
    db_name: &str,
    query: &str,
    _client: &crate::chat::LLMClient,
    _model: &str,
    _provider: &str,
) -> Result<String> {
    crate::debug_log!(
        "RAG: Starting context retrieval for database '{}' with query '{}'",
        db_name,
        query
    );

    // Open the vector database
    let vector_db = VectorDatabase::new(db_name)?;
    crate::debug_log!("RAG: Successfully opened vector database '{}'", db_name);

    // Check if database has any vectors
    let count = vector_db.count()?;
    crate::debug_log!("RAG: Database '{}' contains {} vectors", db_name, count);
    if count == 0 {
        crate::debug_log!("RAG: Database is empty, returning empty context");
        return Ok(String::new());
    }

    // Get model info from database
    let (db_model, db_provider) = if let Some((m, p)) = vector_db.get_model_info()? {
        crate::debug_log!("RAG: Using database model '{}' from provider '{}'", m, p);
        (m, p)
    } else {
        crate::debug_log!("RAG: No model info in database, returning empty context");
        return Ok(String::new());
    };

    // Create a client for the embedding provider (not the chat provider)
    let config = config::Config::load()?;
    let mut config_mut = config.clone();
    let embedding_client = chat::create_authenticated_client(&mut config_mut, &db_provider).await?;
    crate::debug_log!(
        "RAG: Created embedding client for provider '{}'",
        db_provider
    );

    // Use the database's embedding model for consistency
    let embedding_request = EmbeddingRequest {
        model: db_model.clone(),
        input: query.to_string(),
        encoding_format: Some("float".to_string()),
    };

    crate::debug_log!(
        "RAG: Generating embedding for query using model '{}'",
        db_model
    );

    // Generate embedding for query using the correct provider
    let response = embedding_client.embeddings(&embedding_request).await?;
    crate::debug_log!("RAG: Successfully generated embedding for query");

    if let Some(embedding_data) = response.data.first() {
        let query_vector = &embedding_data.embedding;
        crate::debug_log!("RAG: Query vector has {} dimensions", query_vector.len());

        // Find top 3 most similar vectors for context
        let similar_results = vector_db.find_similar(query_vector, 3)?;
        crate::debug_log!("RAG: Found {} similar results", similar_results.len());

        if similar_results.is_empty() {
            crate::debug_log!("RAG: No similar results found, returning empty context");
            return Ok(String::new());
        }

        // Format context
        let mut context = String::new();
        let mut included_count = 0;
        for (entry, similarity) in similar_results {
            crate::debug_log!(
                "RAG: Result similarity: {:.3} for text: '{}'",
                similarity,
                &entry.text[..50.min(entry.text.len())]
            );
            // Only include results with reasonable similarity (>0.3)
            if similarity > 0.3 {
                context.push_str(&format!("- {}\n", entry.text));
                included_count += 1;
            }
        }

        crate::debug_log!(
            "RAG: Included {} results in context (similarity > 0.3)",
            included_count
        );
        crate::debug_log!("RAG: Final context length: {} characters", context.len());

        Ok(context)
    } else {
        crate::debug_log!("RAG: No embedding data in response, returning empty context");
        Ok(String::new())
    }
}