1use std::collections::HashMap;
8
9use anyhow::Result;
10use serde::{Deserialize, Serialize};
11
12use crate::expand_prompts::{build_batch_messages, build_single_messages};
13
/// Hard upper bound on how many prompt variations one expand call may request.
pub const MAX_VARIATIONS: usize = 10;

/// Tighter variation cap for the Discord surface (per the name; presumably due
/// to Discord message/rate limits — confirm against the Discord integration).
pub const DISCORD_MAX_VARIATIONS: usize = 5;
19
/// Per-model-family tweaks layered on top of the base prompt templates.
///
/// Consumed by the prompt builders in `expand_prompts` (passed through as
/// `family_override` in `ApiExpander::expand`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FamilyOverride {
    // Word budget for expanded prompts; `None` keeps the family default.
    pub word_limit: Option<u32>,
    // Free-form style guidance for this family; `None` keeps the default notes.
    pub style_notes: Option<String>,
}
28
/// Fully-resolved parameters for a single expansion request.
///
/// Built from `ExpandSettings::to_expand_config` plus the per-call model
/// family and variation count.
#[derive(Debug, Clone)]
pub struct ExpandConfig {
    // Target image-model family (e.g. "flux", "sdxl"); selects templates and overrides.
    pub model_family: String,
    // Number of variations requested; values > 1 switch to the batch-prompt path.
    pub variations: usize,
    // Sampling temperature forwarded to the LLM backend.
    pub temperature: f64,
    // Nucleus-sampling cutoff forwarded to the LLM backend.
    pub top_p: f64,
    // Completion token budget forwarded to the LLM backend.
    pub max_tokens: u32,
    // Request the model's thinking mode (serialized as `enable_thinking`).
    pub thinking: bool,
    // Optional custom single-prompt system template.
    pub system_prompt: Option<String>,
    // Optional custom batch template, used only when `variations > 1`.
    pub batch_prompt: Option<String>,
    // Per-family tweaks keyed by family name.
    pub family_overrides: HashMap<String, FamilyOverride>,
}
53
54impl Default for ExpandConfig {
55 fn default() -> Self {
56 Self {
57 model_family: "flux".to_string(),
58 variations: 1,
59 temperature: 0.7,
60 top_p: 0.9,
61 max_tokens: 300,
62 thinking: false,
63 system_prompt: None,
64 batch_prompt: None,
65 family_overrides: HashMap::new(),
66 }
67 }
68}
69
/// Outcome of one expansion: the untouched input plus one or more rewrites.
#[derive(Debug, Clone)]
pub struct ExpandResult {
    // The prompt exactly as the caller supplied it.
    pub original: String,
    // Cleaned expanded prompt(s); one entry per variation.
    pub expanded: Vec<String>,
}
78
/// Strategy interface for prompt-expansion backends.
///
/// `Send + Sync` so a single expander can be shared across threads.
pub trait PromptExpander: Send + Sync {
    /// Expands `prompt` into one or more variations according to `config`.
    fn expand(&self, prompt: &str, config: &ExpandConfig) -> Result<ExpandResult>;
}
84
/// One chat turn in the OpenAI-compatible wire format (`role` + `content`).
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
    role: String,
    content: String,
}
93
/// Request payload for the OpenAI-compatible `/v1/chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
    model: String,
    messages: Vec<ChatMessage>,
    temperature: f64,
    top_p: f64,
    max_tokens: u32,
    // `Not::not` means "skip when false": the non-standard field is only sent
    // when thinking is actually requested, so servers that don't know it
    // never see it.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    enable_thinking: bool,
}
109
/// Minimal mirror of the chat-completions response: only the fields this
/// module reads are declared; serde ignores the rest of the payload.
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    choices: Vec<ChatChoice>,
}

/// One completion choice; only the message is used.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessageResponse,
}

/// Assistant message body of a choice.
#[derive(Debug, Deserialize)]
struct ChatMessageResponse {
    content: String,
}
125
/// `PromptExpander` backed by an OpenAI-compatible chat-completions HTTP API.
pub struct ApiExpander {
    // Base URL with any trailing slash stripped (normalized in `new`).
    endpoint: String,
    // Model name sent with every request.
    model: String,
}
131
132impl ApiExpander {
133 pub fn new(endpoint: &str, model: &str) -> Self {
134 let endpoint = endpoint.trim_end_matches('/').to_string();
136 Self {
137 endpoint,
138 model: model.to_string(),
139 }
140 }
141}
142
impl PromptExpander for ApiExpander {
    /// Sends `prompt` to the chat-completions endpoint and returns the
    /// cleaned expansion(s).
    ///
    /// Uses the batch template when more than one variation is requested,
    /// otherwise the single-prompt template. Errors on transport failure,
    /// unparsable responses, or an empty completion.
    fn expand(&self, prompt: &str, config: &ExpandConfig) -> Result<ExpandResult> {
        // Per-family tweak, if one is configured for this family.
        let family_override = config.family_overrides.get(&config.model_family);
        // Choose the message template: batch (asks for N prompts) vs single.
        let messages = if config.variations > 1 {
            build_batch_messages(
                prompt,
                &config.model_family,
                config.variations,
                config.batch_prompt.as_deref(),
                family_override,
            )
        } else {
            build_single_messages(
                prompt,
                &config.model_family,
                config.system_prompt.as_deref(),
                family_override,
            )
        };

        // (role, content) pairs -> wire-format messages.
        let chat_messages: Vec<ChatMessage> = messages
            .into_iter()
            .map(|(role, content)| ChatMessage { role, content })
            .collect();

        let req_body = ChatCompletionRequest {
            model: self.model.clone(),
            messages: chat_messages,
            temperature: config.temperature,
            top_p: config.top_p,
            max_tokens: config.max_tokens,
            enable_thinking: config.thinking,
        };

        // `endpoint` has no trailing slash (normalized in `new`).
        let url = format!("{}/v1/chat/completions", self.endpoint);

        let body = serde_json::to_string(&req_body)?;
        let response_text: String = ureq::post(&url)
            .header("Content-Type", "application/json")
            .send(body.as_str())
            .map_err(|e| anyhow::anyhow!("expand API request failed: {e}"))?
            .body_mut()
            .read_to_string()
            .map_err(|e| anyhow::anyhow!("failed to read expand API response: {e}"))?;

        let completion: ChatCompletionResponse = serde_json::from_str(&response_text)
            .map_err(|e| anyhow::anyhow!("failed to parse expand API response: {e}"))?;

        // First choice only; whitespace-only content counts as empty.
        let content = completion
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .filter(|c| !c.trim().is_empty())
            .ok_or_else(|| {
                anyhow::anyhow!("expand API returned empty response (no choices or empty content)")
            })?;

        // Batch responses are split into variations; single responses are
        // just cleaned (quotes/think-blocks/whitespace).
        let expanded = if config.variations > 1 {
            parse_variations(&content, config.variations)
        } else {
            vec![clean_expanded_prompt(&content)]
        };

        Ok(ExpandResult {
            original: prompt.to_string(),
            expanded,
        })
    }
}
215
/// Public wrapper over the private `parse_variations`, for callers outside
/// this module.
pub fn parse_variations_public(text: &str, expected: usize) -> Vec<String> {
    parse_variations(text, expected)
}
220
/// Public wrapper over the private `clean_expanded_prompt`, for callers
/// outside this module.
pub fn clean_expanded_prompt_public(text: &str) -> String {
    clean_expanded_prompt(text)
}
225
226fn parse_variations(text: &str, expected: usize) -> Vec<String> {
229 let trimmed = text.trim();
230
231 if let Ok(arr) = serde_json::from_str::<Vec<String>>(trimmed) {
233 if !arr.is_empty() {
234 return arr.into_iter().map(|s| clean_expanded_prompt(&s)).collect();
235 }
236 }
237
238 if let Some(start) = trimmed.find('[') {
240 if let Some(end) = trimmed.rfind(']') {
241 if start < end {
242 let json_slice = &trimmed[start..=end];
243 if let Ok(arr) = serde_json::from_str::<Vec<String>>(json_slice) {
244 if !arr.is_empty() {
245 return arr.into_iter().map(|s| clean_expanded_prompt(&s)).collect();
246 }
247 }
248 }
249 }
250 }
251
252 let lines: Vec<String> = trimmed
254 .lines()
255 .map(|l| l.trim())
256 .filter(|l| !l.is_empty())
257 .map(|l| {
258 let stripped = l
260 .trim_start_matches(|c: char| c.is_ascii_digit())
261 .trim_start_matches(['.', ')', ':', '-'])
262 .trim_start_matches('"')
263 .trim_end_matches('"')
264 .trim();
265 clean_expanded_prompt(stripped)
266 })
267 .filter(|l| !l.is_empty())
268 .collect();
269
270 if lines.len() >= expected {
271 return lines;
272 }
273
274 let paragraphs: Vec<String> = trimmed
276 .split("\n\n")
277 .map(|p| clean_expanded_prompt(p.trim()))
278 .filter(|p| !p.is_empty())
279 .collect();
280
281 if !paragraphs.is_empty() {
282 return paragraphs;
283 }
284
285 vec![clean_expanded_prompt(trimmed)]
287}
288
/// Normalizes a raw LLM completion into a usable prompt string.
///
/// Strips wrapping double/single quotes, discards everything up to and
/// including a `</think>` tag if one is present, and collapses all runs of
/// whitespace (including newlines) to single spaces.
fn clean_expanded_prompt(text: &str) -> String {
    // Peel surrounding whitespace and any quote wrapping.
    let unquoted = text.trim().trim_matches('"').trim_matches('\'').trim();

    // Drop a leading thinking block: keep only what follows `</think>`.
    let payload = match unquoted.find("</think>") {
        Some(idx) => unquoted[idx + "</think>".len()..].trim(),
        None => unquoted,
    };

    // Re-join on single spaces to flatten newlines and repeated blanks.
    payload.split_whitespace().collect::<Vec<_>>().join(" ")
}
303
/// User-facing expansion settings, deserialized from configuration (the test
/// suite exercises TOML).
///
/// Every field carries a serde default so a partial or empty config section
/// deserializes cleanly; `Default` mirrors the same values.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ExpandSettings {
    // Master switch; expansion stays off unless explicitly enabled.
    #[serde(default)]
    pub enabled: bool,
    // Either the literal string "local" or a remote API base URL.
    #[serde(default = "default_backend")]
    pub backend: String,
    // Model used when the backend is "local".
    #[serde(default = "default_expand_model")]
    pub model: String,
    // Model name sent to a remote API backend.
    #[serde(default = "default_api_model")]
    pub api_model: String,
    // Sampling temperature.
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    // Nucleus-sampling cutoff.
    #[serde(default = "default_top_p")]
    pub top_p: f64,
    // Completion token budget.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
    // Request the model's thinking mode.
    #[serde(default)]
    pub thinking: bool,
    // Custom single-prompt template; expected to contain {WORD_LIMIT} and
    // {MODEL_NOTES} (see validate_templates).
    #[serde(default)]
    pub system_prompt: Option<String>,
    // Custom batch template; expected to contain {N}, {WORD_LIMIT} and
    // {MODEL_NOTES} (see validate_templates).
    #[serde(default)]
    pub batch_prompt: Option<String>,
    // Per-family overrides keyed by family name.
    #[serde(default)]
    pub families: HashMap<String, FamilyOverride>,
}
343
/// Serde default: expand using the in-process "local" backend.
fn default_backend() -> String {
    String::from("local")
}

/// Serde default: model name for the local backend.
fn default_expand_model() -> String {
    String::from("qwen3-expand:q8")
}

/// Serde default: model name for a remote API backend.
fn default_api_model() -> String {
    String::from("qwen2.5:3b")
}

/// Serde default: sampling temperature.
fn default_temperature() -> f64 {
    0.7
}

/// Serde default: nucleus-sampling cutoff.
fn default_top_p() -> f64 {
    0.9
}

/// Serde default: completion token budget.
fn default_max_tokens() -> u32 {
    300
}
367
368impl Default for ExpandSettings {
369 fn default() -> Self {
370 Self {
371 enabled: false,
372 backend: default_backend(),
373 model: default_expand_model(),
374 api_model: default_api_model(),
375 temperature: default_temperature(),
376 top_p: default_top_p(),
377 max_tokens: default_max_tokens(),
378 thinking: false,
379 system_prompt: None,
380 batch_prompt: None,
381 families: HashMap::new(),
382 }
383 }
384}
385
386impl ExpandSettings {
387 pub fn with_env_overrides(mut self) -> Self {
389 if let Ok(v) = std::env::var("MOLD_EXPAND") {
390 self.enabled = matches!(v.trim().to_lowercase().as_str(), "1" | "true" | "yes");
391 }
392 if let Ok(v) = std::env::var("MOLD_EXPAND_BACKEND") {
393 if !v.is_empty() {
394 self.backend = v;
395 }
396 }
397 if let Ok(v) = std::env::var("MOLD_EXPAND_MODEL") {
398 if !v.is_empty() {
399 if self.is_local() {
400 self.model = v;
401 } else {
402 self.api_model = v;
403 }
404 }
405 }
406 if let Ok(v) = std::env::var("MOLD_EXPAND_TEMPERATURE") {
407 if let Ok(t) = v.parse::<f64>() {
408 self.temperature = t;
409 }
410 }
411 if let Ok(v) = std::env::var("MOLD_EXPAND_THINKING") {
412 self.thinking = matches!(v.trim().to_lowercase().as_str(), "1" | "true" | "yes");
413 }
414 if let Ok(v) = std::env::var("MOLD_EXPAND_SYSTEM_PROMPT") {
415 if !v.is_empty() {
416 self.system_prompt = Some(v);
417 }
418 }
419 if let Ok(v) = std::env::var("MOLD_EXPAND_BATCH_PROMPT") {
420 if !v.is_empty() {
421 self.batch_prompt = Some(v);
422 }
423 }
424 self
425 }
426
427 pub fn to_expand_config(&self, model_family: &str, variations: usize) -> ExpandConfig {
429 ExpandConfig {
430 model_family: model_family.to_string(),
431 variations,
432 temperature: self.temperature,
433 top_p: self.top_p,
434 max_tokens: self.max_tokens,
435 thinking: self.thinking,
436 system_prompt: self.system_prompt.clone(),
437 batch_prompt: self.batch_prompt.clone(),
438 family_overrides: self.families.clone(),
439 }
440 }
441
442 pub fn validate_templates(&self) -> Vec<String> {
446 let mut warnings = Vec::new();
447 if let Some(ref tmpl) = self.system_prompt {
448 for placeholder in ["{WORD_LIMIT}", "{MODEL_NOTES}"] {
449 if !tmpl.contains(placeholder) {
450 warnings.push(format!(
451 "system_prompt is missing placeholder {placeholder} — it won't be substituted"
452 ));
453 }
454 }
455 }
456 if let Some(ref tmpl) = self.batch_prompt {
457 for placeholder in ["{N}", "{WORD_LIMIT}", "{MODEL_NOTES}"] {
458 if !tmpl.contains(placeholder) {
459 warnings.push(format!(
460 "batch_prompt is missing placeholder {placeholder} — it won't be substituted"
461 ));
462 }
463 }
464 }
465 warnings
466 }
467
468 pub fn create_api_expander(&self) -> Option<ApiExpander> {
471 if self.backend == "local" {
472 None
473 } else {
474 Some(ApiExpander::new(&self.backend, &self.api_model))
475 }
476 }
477
478 pub fn is_local(&self) -> bool {
480 self.backend == "local"
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    // ---- clean_expanded_prompt ----

    #[test]
    fn clean_prompt_strips_quotes() {
        assert_eq!(clean_expanded_prompt("\"a cat on mars\""), "a cat on mars");
    }

    #[test]
    fn clean_prompt_strips_single_quotes() {
        assert_eq!(clean_expanded_prompt("'a cat on mars'"), "a cat on mars");
    }

    #[test]
    fn clean_prompt_strips_thinking() {
        let input = "<think>hmm let me think</think>\n\na cat on mars";
        assert_eq!(clean_expanded_prompt(input), "a cat on mars");
    }

    #[test]
    fn clean_prompt_strips_multiline_thinking() {
        let input = "<think>\nstep 1: analyze\nstep 2: expand\n</think>\n\ndetailed prompt here";
        assert_eq!(clean_expanded_prompt(input), "detailed prompt here");
    }

    #[test]
    fn clean_prompt_collapses_whitespace() {
        assert_eq!(
            clean_expanded_prompt("a cat\n\non mars"),
            "a cat on mars"
        );
    }

    #[test]
    fn clean_prompt_empty_input() {
        assert_eq!(clean_expanded_prompt(""), "");
        assert_eq!(clean_expanded_prompt(" "), "");
    }

    #[test]
    fn clean_prompt_only_thinking_block() {
        // Nothing after </think> -> empty result.
        let input = "<think>some reasoning</think>";
        assert_eq!(clean_expanded_prompt(input), "");
    }

    #[test]
    fn clean_prompt_preserves_content_without_thinking() {
        let input = "a beautiful sunset over the ocean, golden light, dramatic clouds";
        assert_eq!(clean_expanded_prompt(input), input);
    }

    // ---- parse_variations ----

    #[test]
    fn parse_variations_json_array() {
        let input = r#"["a cat", "a dog", "a bird"]"#;
        let result = parse_variations(input, 3);
        assert_eq!(result, vec!["a cat", "a dog", "a bird"]);
    }

    #[test]
    fn parse_variations_embedded_json() {
        // JSON array surrounded by prose still parses (first '['..last ']').
        let input = "Here are 3 prompts:\n[\"a cat\", \"a dog\", \"a bird\"]";
        let result = parse_variations(input, 3);
        assert_eq!(result, vec!["a cat", "a dog", "a bird"]);
    }

    #[test]
    fn parse_variations_json_with_thinking() {
        let input =
            "<think>let me think</think>\n\n[\"expanded cat\", \"expanded dog\", \"expanded bird\"]";
        let result = parse_variations(input, 3);
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn parse_variations_numbered_list() {
        let input = "1. a cat on mars\n2. a dog in space\n3. a bird underwater";
        let result = parse_variations(input, 3);
        assert_eq!(result.len(), 3);
        assert!(result[0].contains("cat"));
        assert!(result[1].contains("dog"));
        assert!(result[2].contains("bird"));
    }

    #[test]
    fn parse_variations_numbered_with_parens() {
        let input = "1) a cat\n2) a dog\n3) a bird";
        let result = parse_variations(input, 3);
        assert_eq!(result.len(), 3);
        assert!(result[0].contains("cat"));
    }

    #[test]
    fn parse_variations_numbered_with_quotes() {
        let input = "1. \"a cat on mars\"\n2. \"a dog in space\"";
        let result = parse_variations(input, 2);
        assert_eq!(result.len(), 2);
        assert!(!result[0].starts_with('"'));
        assert!(result[0].contains("cat"));
    }

    #[test]
    fn parse_variations_paragraph_fallback() {
        // Two items but no numbering: line split yields 2 >= expected, or
        // the paragraph split picks them up.
        let input = "A majestic cat sitting on mars\n\nA playful dog floating in space";
        let result = parse_variations(input, 2);
        assert_eq!(result.len(), 2);
        assert!(result[0].contains("cat"));
        assert!(result[1].contains("dog"));
    }

    #[test]
    fn parse_variations_single_text_fallback() {
        let input = "just a single prompt with no structure";
        let result = parse_variations(input, 3);
        assert!(!result.is_empty());
        assert!(result[0].contains("single prompt"));
    }

    #[test]
    fn parse_variations_empty_json_array_falls_through() {
        // An empty JSON array must not be accepted as "parsed"; the text
        // fallbacks still produce something.
        let input = "[]";
        let result = parse_variations(input, 3);
        assert!(!result.is_empty());
    }

    #[test]
    fn parse_variations_cleans_each_item() {
        let input = r#"[" a cat ", " a dog "]"#;
        let result = parse_variations(input, 2);
        assert_eq!(result[0], "a cat");
        assert_eq!(result[1], "a dog");
    }

    // ---- ExpandSettings ----

    #[test]
    fn expand_settings_defaults() {
        let settings = ExpandSettings::default();
        assert!(!settings.enabled);
        assert_eq!(settings.backend, "local");
        assert_eq!(settings.model, "qwen3-expand:q8");
        assert_eq!(settings.api_model, "qwen2.5:3b");
        assert_eq!(settings.temperature, 0.7);
        assert_eq!(settings.top_p, 0.9);
        assert_eq!(settings.max_tokens, 300);
        assert!(!settings.thinking);
        assert!(settings.system_prompt.is_none());
        assert!(settings.batch_prompt.is_none());
        assert!(settings.families.is_empty());
    }

    #[test]
    fn expand_settings_is_local() {
        let settings = ExpandSettings::default();
        assert!(settings.is_local());

        let api_settings = ExpandSettings {
            backend: "http://localhost:11434".to_string(),
            ..Default::default()
        };
        assert!(!api_settings.is_local());
    }

    #[test]
    fn expand_settings_create_api_expander_none_for_local() {
        let settings = ExpandSettings::default();
        assert!(settings.create_api_expander().is_none());
    }

    #[test]
    fn expand_settings_create_api_expander_some_for_url() {
        let settings = ExpandSettings {
            backend: "http://localhost:11434".to_string(),
            api_model: "llama3:8b".to_string(),
            ..Default::default()
        };
        let expander = settings.create_api_expander();
        assert!(expander.is_some());
    }

    #[test]
    fn expand_settings_to_expand_config() {
        let settings = ExpandSettings {
            temperature: 0.5,
            top_p: 0.8,
            max_tokens: 200,
            thinking: true,
            ..Default::default()
        };
        let config = settings.to_expand_config("sdxl", 3);
        assert_eq!(config.model_family, "sdxl");
        assert_eq!(config.variations, 3);
        assert_eq!(config.temperature, 0.5);
        assert_eq!(config.top_p, 0.8);
        assert_eq!(config.max_tokens, 200);
        assert!(config.thinking);
    }

    #[test]
    fn expand_settings_serde_roundtrip() {
        let mut families = HashMap::new();
        families.insert(
            "sd15".to_string(),
            FamilyOverride {
                word_limit: Some(80),
                style_notes: Some("Custom SD1.5 notes.".to_string()),
            },
        );
        let settings = ExpandSettings {
            enabled: true,
            backend: "http://example.com".to_string(),
            model: "qwen3-expand-small:q8".to_string(),
            api_model: "gpt-4".to_string(),
            temperature: 1.2,
            top_p: 0.95,
            max_tokens: 500,
            thinking: true,
            system_prompt: Some("Custom system prompt {WORD_LIMIT} {MODEL_NOTES}".to_string()),
            batch_prompt: Some("Custom batch {N} {WORD_LIMIT} {MODEL_NOTES}".to_string()),
            families,
        };
        let toml_str = toml::to_string(&settings).unwrap();
        let deserialized: ExpandSettings = toml::from_str(&toml_str).unwrap();
        assert_eq!(deserialized.enabled, settings.enabled);
        assert_eq!(deserialized.backend, settings.backend);
        assert_eq!(deserialized.model, settings.model);
        assert_eq!(deserialized.api_model, settings.api_model);
        assert_eq!(deserialized.temperature, settings.temperature);
        assert_eq!(deserialized.max_tokens, settings.max_tokens);
        assert_eq!(deserialized.thinking, settings.thinking);
        assert_eq!(deserialized.system_prompt, settings.system_prompt);
        assert_eq!(deserialized.batch_prompt, settings.batch_prompt);
        assert_eq!(deserialized.families.len(), 1);
        let sd15 = deserialized.families.get("sd15").unwrap();
        assert_eq!(sd15.word_limit, Some(80));
        assert_eq!(sd15.style_notes.as_deref(), Some("Custom SD1.5 notes."));
    }

    #[test]
    fn expand_settings_serde_defaults_on_empty() {
        // Empty TOML must deserialize thanks to the serde defaults and match
        // the Default impl.
        let deserialized: ExpandSettings = toml::from_str("").unwrap();
        let defaults = ExpandSettings::default();
        assert_eq!(deserialized.enabled, defaults.enabled);
        assert_eq!(deserialized.backend, defaults.backend);
        assert_eq!(deserialized.model, defaults.model);
        assert_eq!(deserialized.temperature, defaults.temperature);
    }

    // ---- ApiExpander construction ----

    #[test]
    fn api_expander_strips_trailing_slash() {
        let expander = ApiExpander::new("http://localhost:11434/", "qwen2.5:3b");
        assert_eq!(expander.endpoint, "http://localhost:11434");
    }

    #[test]
    fn api_expander_no_slash_unchanged() {
        let expander = ApiExpander::new("http://localhost:11434", "qwen2.5:3b");
        assert_eq!(expander.endpoint, "http://localhost:11434");
    }

    // ---- ExpandConfig ----

    #[test]
    fn expand_config_default() {
        let config = ExpandConfig::default();
        assert_eq!(config.model_family, "flux");
        assert_eq!(config.variations, 1);
        assert_eq!(config.temperature, 0.7);
        assert_eq!(config.max_tokens, 300);
        assert!(!config.thinking);
    }

    // NOTE(review): the two tests below replicate the MOLD_EXPAND_MODEL
    // routing logic inline instead of setting real env vars (avoids mutating
    // process-global state across tests).

    #[test]
    fn env_override_model_routes_to_local() {
        let settings = ExpandSettings::default();
        assert!(settings.is_local());
        let mut s = settings;
        let v = "qwen3-expand-small:q8".to_string();
        if s.is_local() {
            s.model = v.clone();
        } else {
            s.api_model = v.clone();
        }
        assert_eq!(s.model, "qwen3-expand-small:q8");
        assert_eq!(s.api_model, "qwen2.5:3b");
    }

    #[test]
    fn env_override_model_routes_to_api() {
        let mut s = ExpandSettings {
            backend: "http://localhost:11434".to_string(),
            ..Default::default()
        };
        assert!(!s.is_local());
        let v = "llama3:70b".to_string();
        if s.is_local() {
            s.model = v.clone();
        } else {
            s.api_model = v.clone();
        }
        assert_eq!(s.api_model, "llama3:70b");
        assert_eq!(s.model, "qwen3-expand:q8");
    }

    #[test]
    fn to_expand_config_passes_overrides() {
        let mut families = HashMap::new();
        families.insert(
            "flux".to_string(),
            FamilyOverride {
                word_limit: Some(200),
                style_notes: None,
            },
        );
        let settings = ExpandSettings {
            system_prompt: Some("Custom {WORD_LIMIT} {MODEL_NOTES}".to_string()),
            batch_prompt: Some("Batch {N} {WORD_LIMIT} {MODEL_NOTES}".to_string()),
            families,
            ..Default::default()
        };
        let config = settings.to_expand_config("flux", 3);
        assert_eq!(
            config.system_prompt.as_deref(),
            Some("Custom {WORD_LIMIT} {MODEL_NOTES}")
        );
        assert_eq!(
            config.batch_prompt.as_deref(),
            Some("Batch {N} {WORD_LIMIT} {MODEL_NOTES}")
        );
        assert_eq!(config.family_overrides.len(), 1);
        assert_eq!(
            config.family_overrides.get("flux").unwrap().word_limit,
            Some(200)
        );
    }

    #[test]
    fn expand_config_default_has_no_overrides() {
        let config = ExpandConfig::default();
        assert!(config.system_prompt.is_none());
        assert!(config.batch_prompt.is_none());
        assert!(config.family_overrides.is_empty());
    }

    // ---- validate_templates ----

    #[test]
    fn validate_templates_valid() {
        let settings = ExpandSettings {
            system_prompt: Some("You are a writer. {WORD_LIMIT} words. {MODEL_NOTES}".to_string()),
            batch_prompt: Some(
                "Generate {N} prompts. {WORD_LIMIT} words. {MODEL_NOTES}".to_string(),
            ),
            ..Default::default()
        };
        assert!(settings.validate_templates().is_empty());
    }

    #[test]
    fn validate_templates_none_is_valid() {
        let settings = ExpandSettings::default();
        assert!(settings.validate_templates().is_empty());
    }

    #[test]
    fn validate_templates_missing_word_limit() {
        let settings = ExpandSettings {
            system_prompt: Some("You are a writer. {MODEL_NOTES}".to_string()),
            ..Default::default()
        };
        let errors = settings.validate_templates();
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("{WORD_LIMIT}"));
    }

    #[test]
    fn validate_templates_missing_model_notes() {
        let settings = ExpandSettings {
            system_prompt: Some("You are a writer. {WORD_LIMIT} words.".to_string()),
            ..Default::default()
        };
        let errors = settings.validate_templates();
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("{MODEL_NOTES}"));
    }

    #[test]
    fn validate_templates_batch_missing_n() {
        let settings = ExpandSettings {
            batch_prompt: Some("Generate prompts. {WORD_LIMIT} {MODEL_NOTES}".to_string()),
            ..Default::default()
        };
        let errors = settings.validate_templates();
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("{N}"));
    }

    #[test]
    fn validate_templates_batch_missing_all() {
        let settings = ExpandSettings {
            batch_prompt: Some("Generate prompts.".to_string()),
            ..Default::default()
        };
        let errors = settings.validate_templates();
        assert_eq!(errors.len(), 3);
    }

    // ---- FamilyOverride ----

    #[test]
    fn family_override_serde_roundtrip() {
        let ov = FamilyOverride {
            word_limit: Some(100),
            style_notes: Some("Be creative.".to_string()),
        };
        let json = serde_json::to_string(&ov).unwrap();
        let deserialized: FamilyOverride = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.word_limit, Some(100));
        assert_eq!(deserialized.style_notes.as_deref(), Some("Be creative."));
    }

    #[test]
    fn family_override_partial_toml() {
        // Missing Option fields deserialize as None.
        let toml_str = "word_limit = 75\n";
        let ov: FamilyOverride = toml::from_str(toml_str).unwrap();
        assert_eq!(ov.word_limit, Some(75));
        assert!(ov.style_notes.is_none());
    }

    #[test]
    fn expand_settings_toml_with_families() {
        let toml_str = r#"
enabled = true
system_prompt = "Custom prompt. {WORD_LIMIT} words. {MODEL_NOTES}"

[families.sd15]
word_limit = 40
style_notes = "Short keywords only."

[families.flux]
word_limit = 250
"#;
        let settings: ExpandSettings = toml::from_str(toml_str).unwrap();
        assert!(settings.enabled);
        assert!(settings.system_prompt.is_some());
        assert_eq!(settings.families.len(), 2);
        let sd15 = settings.families.get("sd15").unwrap();
        assert_eq!(sd15.word_limit, Some(40));
        assert_eq!(sd15.style_notes.as_deref(), Some("Short keywords only."));
        let flux = settings.families.get("flux").unwrap();
        assert_eq!(flux.word_limit, Some(250));
        assert!(flux.style_notes.is_none());
    }
}