// lc/cli/image.rs

//! Image generation commands

3use anyhow::Result;
4use colored::*;
5use std::fs;
6use std::io::{self, Write};
7use std::path::Path;
8
9/// Handle image generation command
10pub async fn handle(
11    prompt: Vec<String>,
12    model: Option<String>,
13    provider: Option<String>,
14    size: Option<String>,
15    count: Option<u32>,
16    output: Option<String>,
17    debug: bool,
18) -> Result<()> {
19    // Set debug mode if requested
20    if debug {
21        crate::utils::cli_utils::set_debug_mode(true);
22    }
23
24    // Join prompt parts into a single string
25    let prompt_str = prompt.join(" ");
26    if prompt_str.is_empty() {
27        anyhow::bail!("No prompt provided for image generation");
28    }
29
30    let config = crate::config::Config::load()?;
31
32    // Default values
33    let size_str = size.unwrap_or_else(|| "1024x1024".to_string());
34    let count_val = count.unwrap_or(1);
35
36    // Resolve provider and model using the same logic as other commands
37    let (provider_name, model_name) = crate::utils::cli_utils::resolve_model_and_provider(
38        &config,
39        provider,
40        model,
41    )?;
42
43    // Get provider config with authentication from centralized keys
44    let provider_config = config.get_provider_with_auth(&provider_name)?;
45    
46    // Allow either API key or resolved custom auth headers
47    let header_has_resolved_key = provider_config.headers.iter().any(|(k, v)| {
48        let k_l = k.to_lowercase();
49        (k_l.contains("key") || k_l.contains("token") || k_l.contains("auth"))
50            && !v.trim().is_empty()
51            && !v.contains("${api_key}")
52    });
53    if provider_config.api_key.is_none() && !header_has_resolved_key {
54        anyhow::bail!(
55            "No API key configured for provider '{}'. Add one with 'lc keys add {}'",
56            provider_name,
57            provider_name
58        );
59    }
60
61    let mut config_mut = config.clone();
62    let client = crate::core::chat::create_authenticated_client(&mut config_mut, &provider_name).await?;
63
64    // Save config if tokens were updated
65    if config_mut.get_cached_token(&provider_name) != config.get_cached_token(&provider_name) {
66        config_mut.save()?;
67    }
68
69    println!(
70        "{} Generating {} image(s) with prompt: \"{}\"",
71        "🎨".blue(),
72        count_val,
73        prompt_str
74    );
75    println!("{} Model: {}", "🤖".blue(), model_name);
76    println!("{} Provider: {}", "🏭".blue(), provider_name);
77    println!("{} Size: {}", "📐".blue(), size_str);
78
79    // Create image generation request
80    let image_request = crate::core::provider::ImageGenerationRequest {
81        prompt: prompt_str.clone(),
82        model: Some(model_name.clone()),
83        n: Some(count_val),
84        size: Some(size_str.clone()),
85        quality: Some("standard".to_string()),
86        style: None,
87        response_format: Some("url".to_string()),
88    };
89
90    // Generate images
91    print!("{} ", "Generating...".dimmed());
92    io::stdout().flush()?;
93
94    match client.generate_images(&image_request).await {
95        Ok(response) => {
96            print!("\r{}\r", " ".repeat(20)); // Clear "Generating..."
97            println!(
98                "{} Successfully generated {} image(s)!",
99                "✅".green(),
100                response.data.len()
101            );
102
103            // Create output directory if specified
104            let output_dir = if let Some(dir) = output {
105                let path = Path::new(&dir);
106                if !path.exists() {
107                    fs::create_dir_all(path)?;
108                    println!("{} Created output directory: {}", "📁".blue(), dir);
109                }
110                Some(dir)
111            } else {
112                None
113            };
114
115            // Process each generated image
116            for (i, image_data) in response.data.iter().enumerate() {
117                let image_num = i + 1;
118
119                if let Some(url) = &image_data.url {
120                    println!(
121                        "\n{} Image {}/{}",
122                        "🖼️".blue(),
123                        image_num,
124                        response.data.len()
125                    );
126                    println!("   URL: {}", url);
127
128                    if let Some(revised_prompt) = &image_data.revised_prompt {
129                        if revised_prompt != &prompt_str {
130                            println!("   Revised prompt: {}", revised_prompt.dimmed());
131                        }
132                    }
133
134                    // Download image if output directory is specified
135                    if let Some(ref dir) = output_dir {
136                        let filename = format!(
137                            "image_{}_{}.png",
138                            chrono::Utc::now().format("%Y%m%d_%H%M%S"),
139                            image_num
140                        );
141                        let filepath = Path::new(dir).join(&filename);
142
143                        match download_image(url, &filepath).await {
144                            Ok(_) => {
145                                println!("   {} Saved to: {}", "💾".green(), filepath.display());
146                            }
147                            Err(e) => {
148                                eprintln!("   {} Failed to download image: {}", "❌".red(), e);
149                            }
150                        }
151                    }
152                } else if let Some(b64_data) = &image_data.b64_json {
153                    println!(
154                        "\n{} Image {}/{} (Base64)",
155                        "🖼️".blue(),
156                        image_num,
157                        response.data.len()
158                    );
159
160                    // For base64 data, always save to a file (either specified output dir or current dir)
161                    let save_dir = output_dir.as_deref().unwrap_or(".");
162                    let filename = format!(
163                        "image_{}_{}.png",
164                        chrono::Utc::now().format("%Y%m%d_%H%M%S"),
165                        image_num
166                    );
167                    let filepath = Path::new(save_dir).join(&filename);
168
169                    match save_base64_image(b64_data, &filepath) {
170                        Ok(_) => {
171                            println!("   {} Saved to: {}", "💾".green(), filepath.display());
172                        }
173                        Err(e) => {
174                            eprintln!("   {} Failed to save image: {}", "❌".red(), e);
175                        }
176                    }
177
178                    if let Some(revised_prompt) = &image_data.revised_prompt {
179                        if revised_prompt != &prompt_str {
180                            println!("   Revised prompt: {}", revised_prompt.dimmed());
181                        }
182                    }
183                }
184            }
185
186            if output_dir.is_none() {
187                // Check if we had any URL-based images that weren't downloaded
188                let has_url_images = response.data.iter().any(|img| img.url.is_some());
189                if has_url_images {
190                    println!(
191                        "\n{} Use --output <directory> to automatically download URL-based images",
192                        "💡".yellow()
193                    );
194                }
195            }
196        }
197        Err(e) => {
198            print!("\r{}\r", " ".repeat(20)); // Clear "Generating..."
199            anyhow::bail!("Failed to generate images: {}", e);
200        }
201    }
202
203    Ok(())
204}
205
206// Helper function to download image from URL
207async fn download_image(url: &str, filepath: &std::path::Path) -> Result<()> {
208    let response = reqwest::get(url).await?;
209
210    if !response.status().is_success() {
211        anyhow::bail!("Failed to download image: HTTP {}", response.status());
212    }
213
214    let bytes = response.bytes().await?;
215    std::fs::write(filepath, bytes)?;
216
217    Ok(())
218}
219
220// Helper function to save base64 image data
221fn save_base64_image(b64_data: &str, filepath: &std::path::Path) -> Result<()> {
222    use base64::{engine::general_purpose, Engine as _};
223
224    let image_bytes = general_purpose::STANDARD.decode(b64_data)?;
225    std::fs::write(filepath, image_bytes)?;
226
227    Ok(())
228}