// mold_core/expand.rs
1//! LLM-powered prompt expansion.
2//!
3//! Provides a `PromptExpander` trait with two backends:
4//! - `ApiExpander`: calls any OpenAI-compatible `/v1/chat/completions` endpoint
5//! - Local GGUF inference (in `mold-inference`, behind the `expand` feature flag)
6
7use std::collections::HashMap;
8
9use anyhow::Result;
10use serde::{Deserialize, Serialize};
11
12use crate::expand_prompts::{build_batch_messages, build_single_messages};
13
/// Maximum number of prompt variations the server API accepts.
pub const MAX_VARIATIONS: usize = 10;

/// Maximum number of prompt variations for Discord (embed character limit).
/// Deliberately lower than [`MAX_VARIATIONS`] so all variations fit in a
/// single embed.
pub const DISCORD_MAX_VARIATIONS: usize = 5;
19
/// Per-family word limit and style notes override.
///
/// Both fields are optional, so a config file may override just one
/// setting for a family and inherit the built-in default for the other
/// (serde treats a missing `Option` field as `None`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FamilyOverride {
    /// Word limit for expanded prompts (overrides built-in default for this family).
    pub word_limit: Option<u32>,
    /// Style notes injected into the system prompt (overrides built-in default).
    pub style_notes: Option<String>,
}
28
29/// Configuration for a prompt expansion request.
30#[derive(Debug, Clone)]
31pub struct ExpandConfig {
32    /// Diffusion model family (e.g. "flux", "sd15", "sdxl").
33    pub model_family: String,
34    /// Number of prompt variations to generate (1 = single expansion).
35    pub variations: usize,
36    /// Sampling temperature (0.0-2.0). Higher = more creative.
37    pub temperature: f64,
38    /// Nucleus sampling threshold.
39    pub top_p: f64,
40    /// Maximum tokens for the LLM response.
41    pub max_tokens: u32,
42    /// Enable Qwen3 thinking mode for higher quality (slower).
43    pub thinking: bool,
44    /// Custom single-expansion system prompt template (overrides built-in).
45    /// Placeholders: `{WORD_LIMIT}`, `{MODEL_NOTES}`
46    pub system_prompt: Option<String>,
47    /// Custom batch-variation system prompt template (overrides built-in).
48    /// Placeholders: `{N}`, `{WORD_LIMIT}`, `{MODEL_NOTES}`
49    pub batch_prompt: Option<String>,
50    /// Per-family overrides for word limits and style notes.
51    pub family_overrides: HashMap<String, FamilyOverride>,
52}
53
54impl Default for ExpandConfig {
55    fn default() -> Self {
56        Self {
57            model_family: "flux".to_string(),
58            variations: 1,
59            temperature: 0.7,
60            top_p: 0.9,
61            max_tokens: 300,
62            thinking: false,
63            system_prompt: None,
64            batch_prompt: None,
65            family_overrides: HashMap::new(),
66        }
67    }
68}
69
/// Result of a prompt expansion.
#[derive(Debug, Clone)]
pub struct ExpandResult {
    /// The original user prompt.
    pub original: String,
    /// Expanded prompt(s). Normally has `ExpandConfig::variations` entries,
    /// but fallback parsing of unstructured LLM output may yield fewer
    /// (e.g. a single paragraph when a numbered list was requested).
    pub expanded: Vec<String>,
}
78
/// Trait for prompt expansion backends.
///
/// `Send + Sync` is required so a single expander can be shared across
/// threads (e.g. called from async code via `spawn_blocking`).
pub trait PromptExpander: Send + Sync {
    /// Expand a user prompt into one or more detailed image generation prompts.
    ///
    /// # Errors
    /// Returns an error when the backend fails (request, read, or parse failure,
    /// or an empty response).
    fn expand(&self, prompt: &str, config: &ExpandConfig) -> Result<ExpandResult>;
}
84
85// ── API expander ─────────────────────────────────────────────────────────────
86
/// OpenAI-compatible chat completion message (role + content pair).
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
    // "system", "user", or "assistant" per the OpenAI chat schema.
    role: String,
    content: String,
}
93
/// Request body for `/v1/chat/completions`.
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
    model: String,
    messages: Vec<ChatMessage>,
    temperature: f64,
    top_p: f64,
    max_tokens: u32,
    /// Qwen3/vLLM thinking mode — only serialized when `true`
    /// (`Not::not` returns `true` for `false`, suppressing the field).
    /// NOTE: this is a non-standard extension; strict OpenAI-compatible
    /// endpoints may reject it. Only enable when the backend supports it
    /// (e.g. Ollama, vLLM with Qwen3 models).
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    enable_thinking: bool,
}
109
/// Response from `/v1/chat/completions`.
///
/// Only the fields this module reads are modeled; serde ignores any
/// additional fields the server returns (usage stats, ids, etc.).
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    choices: Vec<ChatChoice>,
}

/// A single completion choice; we only use the message payload.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessageResponse,
}

/// The assistant message within a choice; only the text content is needed.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    content: String,
}
125
/// Expander that calls an OpenAI-compatible API endpoint.
pub struct ApiExpander {
    // Base URL, stored without a trailing slash.
    endpoint: String,
    // Model name sent in the request body (e.g. "qwen2.5:3b").
    model: String,
}

impl ApiExpander {
    /// Create an expander for the given endpoint base URL and model name.
    ///
    /// Any trailing `/` on `endpoint` is removed so URL construction in
    /// `expand` never produces a double slash.
    pub fn new(endpoint: &str, model: &str) -> Self {
        ApiExpander {
            endpoint: endpoint.trim_end_matches('/').to_owned(),
            model: model.to_owned(),
        }
    }
}
142
143impl PromptExpander for ApiExpander {
144    fn expand(&self, prompt: &str, config: &ExpandConfig) -> Result<ExpandResult> {
145        let family_override = config.family_overrides.get(&config.model_family);
146        let messages = if config.variations > 1 {
147            build_batch_messages(
148                prompt,
149                &config.model_family,
150                config.variations,
151                config.batch_prompt.as_deref(),
152                family_override,
153            )
154        } else {
155            build_single_messages(
156                prompt,
157                &config.model_family,
158                config.system_prompt.as_deref(),
159                family_override,
160            )
161        };
162
163        let chat_messages: Vec<ChatMessage> = messages
164            .into_iter()
165            .map(|(role, content)| ChatMessage { role, content })
166            .collect();
167
168        let req_body = ChatCompletionRequest {
169            model: self.model.clone(),
170            messages: chat_messages,
171            temperature: config.temperature,
172            top_p: config.top_p,
173            max_tokens: config.max_tokens,
174            enable_thinking: config.thinking,
175        };
176
177        let url = format!("{}/v1/chat/completions", self.endpoint);
178
179        // Use ureq (blocking HTTP) — this trait method is sync and may be
180        // called from within a tokio runtime via spawn_blocking, so we cannot
181        // use async reqwest or Handle::block_on (which panics inside a runtime).
182        let body = serde_json::to_string(&req_body)?;
183        let response_text: String = ureq::post(&url)
184            .header("Content-Type", "application/json")
185            .send(body.as_str())
186            .map_err(|e| anyhow::anyhow!("expand API request failed: {e}"))?
187            .body_mut()
188            .read_to_string()
189            .map_err(|e| anyhow::anyhow!("failed to read expand API response: {e}"))?;
190
191        let completion: ChatCompletionResponse = serde_json::from_str(&response_text)
192            .map_err(|e| anyhow::anyhow!("failed to parse expand API response: {e}"))?;
193
194        let content = completion
195            .choices
196            .first()
197            .map(|c| c.message.content.clone())
198            .filter(|c| !c.trim().is_empty())
199            .ok_or_else(|| {
200                anyhow::anyhow!("expand API returned empty response (no choices or empty content)")
201            })?;
202
203        let expanded = if config.variations > 1 {
204            parse_variations(&content, config.variations)
205        } else {
206            vec![clean_expanded_prompt(&content)]
207        };
208
209        Ok(ExpandResult {
210            original: prompt.to_string(),
211            expanded,
212        })
213    }
214}
215
216/// Public wrapper for `parse_variations` (used by mold-inference local expander).
217pub fn parse_variations_public(text: &str, expected: usize) -> Vec<String> {
218    parse_variations(text, expected)
219}
220
221/// Public wrapper for `clean_expanded_prompt` (used by mold-inference local expander).
222pub fn clean_expanded_prompt_public(text: &str) -> String {
223    clean_expanded_prompt(text)
224}
225
226/// Parse multiple variations from LLM output.
227/// Tries JSON array first, then numbered list, then line-separated.
228fn parse_variations(text: &str, expected: usize) -> Vec<String> {
229    let trimmed = text.trim();
230
231    // Try JSON array
232    if let Ok(arr) = serde_json::from_str::<Vec<String>>(trimmed) {
233        if !arr.is_empty() {
234            return arr.into_iter().map(|s| clean_expanded_prompt(&s)).collect();
235        }
236    }
237
238    // Try to find a JSON array embedded in the text (LLM may include preamble)
239    if let Some(start) = trimmed.find('[') {
240        if let Some(end) = trimmed.rfind(']') {
241            if start < end {
242                let json_slice = &trimmed[start..=end];
243                if let Ok(arr) = serde_json::from_str::<Vec<String>>(json_slice) {
244                    if !arr.is_empty() {
245                        return arr.into_iter().map(|s| clean_expanded_prompt(&s)).collect();
246                    }
247                }
248            }
249        }
250    }
251
252    // Fall back to numbered list parsing (1. ... 2. ... etc.)
253    let lines: Vec<String> = trimmed
254        .lines()
255        .map(|l| l.trim())
256        .filter(|l| !l.is_empty())
257        .map(|l| {
258            // Strip numbered prefix: "1. ", "2) ", etc.
259            let stripped = l
260                .trim_start_matches(|c: char| c.is_ascii_digit())
261                .trim_start_matches(['.', ')', ':', '-'])
262                .trim_start_matches('"')
263                .trim_end_matches('"')
264                .trim();
265            clean_expanded_prompt(stripped)
266        })
267        .filter(|l| !l.is_empty())
268        .collect();
269
270    if lines.len() >= expected {
271        return lines;
272    }
273
274    // Last resort: split on double newlines
275    let paragraphs: Vec<String> = trimmed
276        .split("\n\n")
277        .map(|p| clean_expanded_prompt(p.trim()))
278        .filter(|p| !p.is_empty())
279        .collect();
280
281    if !paragraphs.is_empty() {
282        return paragraphs;
283    }
284
285    // Ultimate fallback: return the whole text as a single variation
286    vec![clean_expanded_prompt(trimmed)]
287}
288
289/// Clean up an expanded prompt: trim whitespace, remove quotes, collapse whitespace.
290fn clean_expanded_prompt(text: &str) -> String {
291    let trimmed = text.trim().trim_matches('"').trim_matches('\'').trim();
292
293    // Strip any thinking block if present
294    let cleaned = if let Some(end_idx) = trimmed.find("</think>") {
295        trimmed[end_idx + "</think>".len()..].trim()
296    } else {
297        trimmed
298    };
299
300    // Collapse multiple whitespace/newlines into single spaces
301    cleaned.split_whitespace().collect::<Vec<_>>().join(" ")
302}
303
/// Expand configuration from the mold config file.
///
/// Every field has a serde default, so a partial (or empty) `[expand]`
/// table deserializes cleanly and matches `ExpandSettings::default()`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ExpandSettings {
    /// Enable prompt expansion by default (overridden by --expand/--no-expand).
    #[serde(default)]
    pub enabled: bool,
    /// Backend: "local" for built-in GGUF inference, or a URL for OpenAI-compatible API.
    #[serde(default = "default_backend")]
    pub backend: String,
    /// Model name for local GGUF expansion.
    #[serde(default = "default_expand_model")]
    pub model: String,
    /// Model name when using API backend (e.g. "qwen2.5:3b" for Ollama).
    #[serde(default = "default_api_model")]
    pub api_model: String,
    /// Sampling temperature.
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    /// Nucleus sampling threshold.
    #[serde(default = "default_top_p")]
    pub top_p: f64,
    /// Maximum tokens for the LLM response.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
    /// Enable thinking mode for Qwen3 (higher quality, slower).
    #[serde(default)]
    pub thinking: bool,
    /// Custom single-expansion system prompt template.
    /// Available placeholders: `{WORD_LIMIT}`, `{MODEL_NOTES}`
    #[serde(default)]
    pub system_prompt: Option<String>,
    /// Custom batch-variation system prompt template.
    /// Available placeholders: `{N}`, `{WORD_LIMIT}`, `{MODEL_NOTES}`
    #[serde(default)]
    pub batch_prompt: Option<String>,
    /// Per-family word limit and style notes overrides.
    #[serde(default)]
    pub families: HashMap<String, FamilyOverride>,
}
343
/// Default backend selector: the built-in local GGUF expander.
fn default_backend() -> String {
    String::from("local")
}

/// Default model tag for local GGUF expansion.
fn default_expand_model() -> String {
    String::from("qwen3-expand:q8")
}

/// Default model name for the API backend (Ollama-style tag).
fn default_api_model() -> String {
    String::from("qwen2.5:3b")
}

/// Default sampling temperature.
fn default_temperature() -> f64 {
    0.7
}

/// Default nucleus-sampling threshold.
fn default_top_p() -> f64 {
    0.9
}

/// Default token budget for the LLM response.
fn default_max_tokens() -> u32 {
    300
}
367
368impl Default for ExpandSettings {
369    fn default() -> Self {
370        Self {
371            enabled: false,
372            backend: default_backend(),
373            model: default_expand_model(),
374            api_model: default_api_model(),
375            temperature: default_temperature(),
376            top_p: default_top_p(),
377            max_tokens: default_max_tokens(),
378            thinking: false,
379            system_prompt: None,
380            batch_prompt: None,
381            families: HashMap::new(),
382        }
383    }
384}
385
386impl ExpandSettings {
387    /// Load from environment variables, falling back to provided defaults.
388    pub fn with_env_overrides(mut self) -> Self {
389        if let Ok(v) = std::env::var("MOLD_EXPAND") {
390            self.enabled = matches!(v.trim().to_lowercase().as_str(), "1" | "true" | "yes");
391        }
392        if let Ok(v) = std::env::var("MOLD_EXPAND_BACKEND") {
393            if !v.is_empty() {
394                self.backend = v;
395            }
396        }
397        if let Ok(v) = std::env::var("MOLD_EXPAND_MODEL") {
398            if !v.is_empty() {
399                if self.is_local() {
400                    self.model = v;
401                } else {
402                    self.api_model = v;
403                }
404            }
405        }
406        if let Ok(v) = std::env::var("MOLD_EXPAND_TEMPERATURE") {
407            if let Ok(t) = v.parse::<f64>() {
408                self.temperature = t;
409            }
410        }
411        if let Ok(v) = std::env::var("MOLD_EXPAND_THINKING") {
412            self.thinking = matches!(v.trim().to_lowercase().as_str(), "1" | "true" | "yes");
413        }
414        if let Ok(v) = std::env::var("MOLD_EXPAND_SYSTEM_PROMPT") {
415            if !v.is_empty() {
416                self.system_prompt = Some(v);
417            }
418        }
419        if let Ok(v) = std::env::var("MOLD_EXPAND_BATCH_PROMPT") {
420            if !v.is_empty() {
421                self.batch_prompt = Some(v);
422            }
423        }
424        self
425    }
426
427    /// Build an `ExpandConfig` for a specific request.
428    pub fn to_expand_config(&self, model_family: &str, variations: usize) -> ExpandConfig {
429        ExpandConfig {
430            model_family: model_family.to_string(),
431            variations,
432            temperature: self.temperature,
433            top_p: self.top_p,
434            max_tokens: self.max_tokens,
435            thinking: self.thinking,
436            system_prompt: self.system_prompt.clone(),
437            batch_prompt: self.batch_prompt.clone(),
438            family_overrides: self.families.clone(),
439        }
440    }
441
442    /// Validate that custom templates contain expected placeholders.
443    /// Returns a list of warnings (empty = valid). Callers should treat
444    /// these as non-fatal hints — expansion still runs with partial templates.
445    pub fn validate_templates(&self) -> Vec<String> {
446        let mut warnings = Vec::new();
447        if let Some(ref tmpl) = self.system_prompt {
448            for placeholder in ["{WORD_LIMIT}", "{MODEL_NOTES}"] {
449                if !tmpl.contains(placeholder) {
450                    warnings.push(format!(
451                        "system_prompt is missing placeholder {placeholder} — it won't be substituted"
452                    ));
453                }
454            }
455        }
456        if let Some(ref tmpl) = self.batch_prompt {
457            for placeholder in ["{N}", "{WORD_LIMIT}", "{MODEL_NOTES}"] {
458                if !tmpl.contains(placeholder) {
459                    warnings.push(format!(
460                        "batch_prompt is missing placeholder {placeholder} — it won't be substituted"
461                    ));
462                }
463            }
464        }
465        warnings
466    }
467
468    /// Create the appropriate expander backend.
469    /// Returns `None` if the backend is "local" (handled by mold-inference).
470    pub fn create_api_expander(&self) -> Option<ApiExpander> {
471        if self.backend == "local" {
472            None
473        } else {
474            Some(ApiExpander::new(&self.backend, &self.api_model))
475        }
476    }
477
478    /// Check if this is configured for local (GGUF) expansion.
479    pub fn is_local(&self) -> bool {
480        self.backend == "local"
481    }
482}
483
// Unit tests for prompt cleaning, variation parsing, settings defaults,
// serde round-trips, and env-override routing logic.
#[cfg(test)]
mod tests {
    use super::*;

    // ── clean_expanded_prompt ────────────────────────────────────────────

    #[test]
    fn clean_prompt_strips_quotes() {
        assert_eq!(clean_expanded_prompt("\"a cat on mars\""), "a cat on mars");
    }

    #[test]
    fn clean_prompt_strips_single_quotes() {
        assert_eq!(clean_expanded_prompt("'a cat on mars'"), "a cat on mars");
    }

    #[test]
    fn clean_prompt_strips_thinking() {
        let input = "<think>hmm let me think</think>\n\na cat on mars";
        assert_eq!(clean_expanded_prompt(input), "a cat on mars");
    }

    #[test]
    fn clean_prompt_strips_multiline_thinking() {
        let input = "<think>\nstep 1: analyze\nstep 2: expand\n</think>\n\ndetailed prompt here";
        assert_eq!(clean_expanded_prompt(input), "detailed prompt here");
    }

    #[test]
    fn clean_prompt_collapses_whitespace() {
        assert_eq!(
            clean_expanded_prompt("a  cat\n\non   mars"),
            "a cat on mars"
        );
    }

    #[test]
    fn clean_prompt_empty_input() {
        assert_eq!(clean_expanded_prompt(""), "");
        assert_eq!(clean_expanded_prompt("   "), "");
    }

    #[test]
    fn clean_prompt_only_thinking_block() {
        // Everything up to and including </think> is dropped, leaving nothing.
        let input = "<think>some reasoning</think>";
        assert_eq!(clean_expanded_prompt(input), "");
    }

    #[test]
    fn clean_prompt_preserves_content_without_thinking() {
        let input = "a beautiful sunset over the ocean, golden light, dramatic clouds";
        assert_eq!(clean_expanded_prompt(input), input);
    }

    // ── parse_variations ─────────────────────────────────────────────────

    #[test]
    fn parse_variations_json_array() {
        let input = r#"["a cat", "a dog", "a bird"]"#;
        let result = parse_variations(input, 3);
        assert_eq!(result, vec!["a cat", "a dog", "a bird"]);
    }

    #[test]
    fn parse_variations_embedded_json() {
        let input = "Here are 3 prompts:\n[\"a cat\", \"a dog\", \"a bird\"]";
        let result = parse_variations(input, 3);
        assert_eq!(result, vec!["a cat", "a dog", "a bird"]);
    }

    #[test]
    fn parse_variations_json_with_thinking() {
        let input =
            "<think>let me think</think>\n\n[\"expanded cat\", \"expanded dog\", \"expanded bird\"]";
        // The thinking block is inside individual items, not wrapping the JSON.
        // parse_variations should find the embedded JSON array.
        let result = parse_variations(input, 3);
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn parse_variations_numbered_list() {
        let input = "1. a cat on mars\n2. a dog in space\n3. a bird underwater";
        let result = parse_variations(input, 3);
        assert_eq!(result.len(), 3);
        assert!(result[0].contains("cat"));
        assert!(result[1].contains("dog"));
        assert!(result[2].contains("bird"));
    }

    #[test]
    fn parse_variations_numbered_with_parens() {
        let input = "1) a cat\n2) a dog\n3) a bird";
        let result = parse_variations(input, 3);
        assert_eq!(result.len(), 3);
        assert!(result[0].contains("cat"));
    }

    #[test]
    fn parse_variations_numbered_with_quotes() {
        let input = "1. \"a cat on mars\"\n2. \"a dog in space\"";
        let result = parse_variations(input, 2);
        assert_eq!(result.len(), 2);
        // Quotes should be stripped by clean_expanded_prompt
        assert!(!result[0].starts_with('"'));
        assert!(result[0].contains("cat"));
    }

    #[test]
    fn parse_variations_paragraph_fallback() {
        let input = "A majestic cat sitting on mars\n\nA playful dog floating in space";
        let result = parse_variations(input, 2);
        assert_eq!(result.len(), 2);
        assert!(result[0].contains("cat"));
        assert!(result[1].contains("dog"));
    }

    #[test]
    fn parse_variations_single_text_fallback() {
        // When nothing else matches, return the whole text as one variation
        let input = "just a single prompt with no structure";
        let result = parse_variations(input, 3);
        assert!(!result.is_empty());
        assert!(result[0].contains("single prompt"));
    }

    #[test]
    fn parse_variations_empty_json_array_falls_through() {
        // Empty JSON array should fall through to other parsers
        let input = "[]";
        let result = parse_variations(input, 3);
        // Should not panic; falls through to numbered list / paragraph / fallback
        assert!(!result.is_empty());
    }

    #[test]
    fn parse_variations_cleans_each_item() {
        // Items from a JSON array still pass through clean_expanded_prompt.
        let input = r#"["  a cat  ", "  a dog  "]"#;
        let result = parse_variations(input, 2);
        assert_eq!(result[0], "a cat");
        assert_eq!(result[1], "a dog");
    }

    // ── ExpandSettings ───────────────────────────────────────────────────

    #[test]
    fn expand_settings_defaults() {
        let settings = ExpandSettings::default();
        assert!(!settings.enabled);
        assert_eq!(settings.backend, "local");
        assert_eq!(settings.model, "qwen3-expand:q8");
        assert_eq!(settings.api_model, "qwen2.5:3b");
        assert_eq!(settings.temperature, 0.7);
        assert_eq!(settings.top_p, 0.9);
        assert_eq!(settings.max_tokens, 300);
        assert!(!settings.thinking);
        assert!(settings.system_prompt.is_none());
        assert!(settings.batch_prompt.is_none());
        assert!(settings.families.is_empty());
    }

    #[test]
    fn expand_settings_is_local() {
        let settings = ExpandSettings::default();
        assert!(settings.is_local());

        let api_settings = ExpandSettings {
            backend: "http://localhost:11434".to_string(),
            ..Default::default()
        };
        assert!(!api_settings.is_local());
    }

    #[test]
    fn expand_settings_create_api_expander_none_for_local() {
        let settings = ExpandSettings::default();
        assert!(settings.create_api_expander().is_none());
    }

    #[test]
    fn expand_settings_create_api_expander_some_for_url() {
        let settings = ExpandSettings {
            backend: "http://localhost:11434".to_string(),
            api_model: "llama3:8b".to_string(),
            ..Default::default()
        };
        let expander = settings.create_api_expander();
        assert!(expander.is_some());
    }

    #[test]
    fn expand_settings_to_expand_config() {
        let settings = ExpandSettings {
            temperature: 0.5,
            top_p: 0.8,
            max_tokens: 200,
            thinking: true,
            ..Default::default()
        };
        let config = settings.to_expand_config("sdxl", 3);
        assert_eq!(config.model_family, "sdxl");
        assert_eq!(config.variations, 3);
        assert_eq!(config.temperature, 0.5);
        assert_eq!(config.top_p, 0.8);
        assert_eq!(config.max_tokens, 200);
        assert!(config.thinking);
    }

    #[test]
    fn expand_settings_serde_roundtrip() {
        // Serialize to TOML and back; every field must survive unchanged.
        let mut families = HashMap::new();
        families.insert(
            "sd15".to_string(),
            FamilyOverride {
                word_limit: Some(80),
                style_notes: Some("Custom SD1.5 notes.".to_string()),
            },
        );
        let settings = ExpandSettings {
            enabled: true,
            backend: "http://example.com".to_string(),
            model: "qwen3-expand-small:q8".to_string(),
            api_model: "gpt-4".to_string(),
            temperature: 1.2,
            top_p: 0.95,
            max_tokens: 500,
            thinking: true,
            system_prompt: Some("Custom system prompt {WORD_LIMIT} {MODEL_NOTES}".to_string()),
            batch_prompt: Some("Custom batch {N} {WORD_LIMIT} {MODEL_NOTES}".to_string()),
            families,
        };
        let toml_str = toml::to_string(&settings).unwrap();
        let deserialized: ExpandSettings = toml::from_str(&toml_str).unwrap();
        assert_eq!(deserialized.enabled, settings.enabled);
        assert_eq!(deserialized.backend, settings.backend);
        assert_eq!(deserialized.model, settings.model);
        assert_eq!(deserialized.api_model, settings.api_model);
        assert_eq!(deserialized.temperature, settings.temperature);
        assert_eq!(deserialized.max_tokens, settings.max_tokens);
        assert_eq!(deserialized.thinking, settings.thinking);
        assert_eq!(deserialized.system_prompt, settings.system_prompt);
        assert_eq!(deserialized.batch_prompt, settings.batch_prompt);
        assert_eq!(deserialized.families.len(), 1);
        let sd15 = deserialized.families.get("sd15").unwrap();
        assert_eq!(sd15.word_limit, Some(80));
        assert_eq!(sd15.style_notes.as_deref(), Some("Custom SD1.5 notes."));
    }

    #[test]
    fn expand_settings_serde_defaults_on_empty() {
        // Deserializing an empty table should produce all defaults
        let deserialized: ExpandSettings = toml::from_str("").unwrap();
        let defaults = ExpandSettings::default();
        assert_eq!(deserialized.enabled, defaults.enabled);
        assert_eq!(deserialized.backend, defaults.backend);
        assert_eq!(deserialized.model, defaults.model);
        assert_eq!(deserialized.temperature, defaults.temperature);
    }

    // ── ApiExpander ──────────────────────────────────────────────────────

    #[test]
    fn api_expander_strips_trailing_slash() {
        let expander = ApiExpander::new("http://localhost:11434/", "qwen2.5:3b");
        assert_eq!(expander.endpoint, "http://localhost:11434");
    }

    #[test]
    fn api_expander_no_slash_unchanged() {
        let expander = ApiExpander::new("http://localhost:11434", "qwen2.5:3b");
        assert_eq!(expander.endpoint, "http://localhost:11434");
    }

    // ── ExpandConfig ─────────────────────────────────────────────────────

    #[test]
    fn expand_config_default() {
        let config = ExpandConfig::default();
        assert_eq!(config.model_family, "flux");
        assert_eq!(config.variations, 1);
        assert_eq!(config.temperature, 0.7);
        assert_eq!(config.max_tokens, 300);
        assert!(!config.thinking);
    }

    // ── env overrides ────────────────────────────────────────────────────
    // These tests use a serial approach to avoid env var races.

    #[test]
    fn env_override_model_routes_to_local() {
        // When backend is "local", MOLD_EXPAND_MODEL should set self.model
        let settings = ExpandSettings::default();
        assert!(settings.is_local());
        // We can't safely set env vars in parallel tests, but we can test
        // the routing logic directly:
        let mut s = settings;
        let v = "qwen3-expand-small:q8".to_string();
        if s.is_local() {
            s.model = v.clone();
        } else {
            s.api_model = v.clone();
        }
        assert_eq!(s.model, "qwen3-expand-small:q8");
        assert_eq!(s.api_model, "qwen2.5:3b"); // unchanged
    }

    #[test]
    fn env_override_model_routes_to_api() {
        // When backend is a URL, MOLD_EXPAND_MODEL should set self.api_model
        let mut s = ExpandSettings {
            backend: "http://localhost:11434".to_string(),
            ..Default::default()
        };
        assert!(!s.is_local());
        let v = "llama3:70b".to_string();
        if s.is_local() {
            s.model = v.clone();
        } else {
            s.api_model = v.clone();
        }
        assert_eq!(s.api_model, "llama3:70b");
        assert_eq!(s.model, "qwen3-expand:q8"); // unchanged
    }

    // ── template overrides in ExpandConfig ───────────────────────────────

    #[test]
    fn to_expand_config_passes_overrides() {
        // Custom templates and family overrides must flow through unchanged.
        let mut families = HashMap::new();
        families.insert(
            "flux".to_string(),
            FamilyOverride {
                word_limit: Some(200),
                style_notes: None,
            },
        );
        let settings = ExpandSettings {
            system_prompt: Some("Custom {WORD_LIMIT} {MODEL_NOTES}".to_string()),
            batch_prompt: Some("Batch {N} {WORD_LIMIT} {MODEL_NOTES}".to_string()),
            families,
            ..Default::default()
        };
        let config = settings.to_expand_config("flux", 3);
        assert_eq!(
            config.system_prompt.as_deref(),
            Some("Custom {WORD_LIMIT} {MODEL_NOTES}")
        );
        assert_eq!(
            config.batch_prompt.as_deref(),
            Some("Batch {N} {WORD_LIMIT} {MODEL_NOTES}")
        );
        assert_eq!(config.family_overrides.len(), 1);
        assert_eq!(
            config.family_overrides.get("flux").unwrap().word_limit,
            Some(200)
        );
    }

    #[test]
    fn expand_config_default_has_no_overrides() {
        let config = ExpandConfig::default();
        assert!(config.system_prompt.is_none());
        assert!(config.batch_prompt.is_none());
        assert!(config.family_overrides.is_empty());
    }

    // ── validate_templates ──────────────────────────────────────────────

    #[test]
    fn validate_templates_valid() {
        let settings = ExpandSettings {
            system_prompt: Some("You are a writer. {WORD_LIMIT} words. {MODEL_NOTES}".to_string()),
            batch_prompt: Some(
                "Generate {N} prompts. {WORD_LIMIT} words. {MODEL_NOTES}".to_string(),
            ),
            ..Default::default()
        };
        assert!(settings.validate_templates().is_empty());
    }

    #[test]
    fn validate_templates_none_is_valid() {
        // No custom templates means nothing to validate.
        let settings = ExpandSettings::default();
        assert!(settings.validate_templates().is_empty());
    }

    #[test]
    fn validate_templates_missing_word_limit() {
        let settings = ExpandSettings {
            system_prompt: Some("You are a writer. {MODEL_NOTES}".to_string()),
            ..Default::default()
        };
        let errors = settings.validate_templates();
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("{WORD_LIMIT}"));
    }

    #[test]
    fn validate_templates_missing_model_notes() {
        let settings = ExpandSettings {
            system_prompt: Some("You are a writer. {WORD_LIMIT} words.".to_string()),
            ..Default::default()
        };
        let errors = settings.validate_templates();
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("{MODEL_NOTES}"));
    }

    #[test]
    fn validate_templates_batch_missing_n() {
        let settings = ExpandSettings {
            batch_prompt: Some("Generate prompts. {WORD_LIMIT} {MODEL_NOTES}".to_string()),
            ..Default::default()
        };
        let errors = settings.validate_templates();
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("{N}"));
    }

    #[test]
    fn validate_templates_batch_missing_all() {
        // One warning per missing placeholder: {N}, {WORD_LIMIT}, {MODEL_NOTES}.
        let settings = ExpandSettings {
            batch_prompt: Some("Generate prompts.".to_string()),
            ..Default::default()
        };
        let errors = settings.validate_templates();
        assert_eq!(errors.len(), 3);
    }

    // ── FamilyOverride serde ────────────────────────────────────────────

    #[test]
    fn family_override_serde_roundtrip() {
        let ov = FamilyOverride {
            word_limit: Some(100),
            style_notes: Some("Be creative.".to_string()),
        };
        let json = serde_json::to_string(&ov).unwrap();
        let deserialized: FamilyOverride = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.word_limit, Some(100));
        assert_eq!(deserialized.style_notes.as_deref(), Some("Be creative."));
    }

    #[test]
    fn family_override_partial_toml() {
        // Omitted Option fields deserialize as None.
        let toml_str = "word_limit = 75\n";
        let ov: FamilyOverride = toml::from_str(toml_str).unwrap();
        assert_eq!(ov.word_limit, Some(75));
        assert!(ov.style_notes.is_none());
    }

    // ── full config with families in TOML ───────────────────────────────

    #[test]
    fn expand_settings_toml_with_families() {
        let toml_str = r#"
enabled = true
system_prompt = "Custom prompt. {WORD_LIMIT} words. {MODEL_NOTES}"

[families.sd15]
word_limit = 40
style_notes = "Short keywords only."

[families.flux]
word_limit = 250
"#;
        let settings: ExpandSettings = toml::from_str(toml_str).unwrap();
        assert!(settings.enabled);
        assert!(settings.system_prompt.is_some());
        assert_eq!(settings.families.len(), 2);
        let sd15 = settings.families.get("sd15").unwrap();
        assert_eq!(sd15.word_limit, Some(40));
        assert_eq!(sd15.style_notes.as_deref(), Some("Short keywords only."));
        let flux = settings.families.get("flux").unwrap();
        assert_eq!(flux.word_limit, Some(250));
        assert!(flux.style_notes.is_none());
    }
}