Skip to main content

sqz_engine/
preset.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4use crate::error::{Result, SqzError};
5
6/// Parses, validates, and serializes TOML Preset files.
7pub struct PresetParser;
8
9impl PresetParser {
10    /// Parse a TOML string into a validated `Preset`.
11    pub fn parse(toml_str: &str) -> Result<Preset> {
12        let preset: Preset = toml::from_str(toml_str)?;
13        Self::validate(&preset)?;
14        Ok(preset)
15    }
16
17    /// Serialize a `Preset` back to a pretty-printed TOML string.
18    pub fn to_toml(preset: &Preset) -> Result<String> {
19        Ok(toml::to_string_pretty(preset)?)
20    }
21
22    /// Validate all fields of a `Preset`, returning descriptive errors.
23    pub fn validate(preset: &Preset) -> Result<()> {
24        if preset.preset.name.is_empty() {
25            return Err(SqzError::PresetValidation {
26                field: "preset.name".to_string(),
27                message: "must not be empty".to_string(),
28            });
29        }
30
31        if preset.preset.version.is_empty() {
32            return Err(SqzError::PresetValidation {
33                field: "preset.version".to_string(),
34                message: "must not be empty".to_string(),
35            });
36        }
37
38        let wt = preset.budget.warning_threshold;
39        if !(wt > 0.0 && wt < 1.0) {
40            return Err(SqzError::PresetValidation {
41                field: "budget.warning_threshold".to_string(),
42                message: "must be between 0.0 and 1.0".to_string(),
43            });
44        }
45
46        let ct = preset.budget.ceiling_threshold;
47        if !(ct > 0.0 && ct < 1.0) || ct <= wt {
48            return Err(SqzError::PresetValidation {
49                field: "budget.ceiling_threshold".to_string(),
50                message: "must be between 0.0 and 1.0 and greater than warning_threshold"
51                    .to_string(),
52            });
53        }
54
55        let max_tools = preset.tool_selection.max_tools;
56        if !(1..=50).contains(&max_tools) {
57            return Err(SqzError::PresetValidation {
58                field: "tool_selection.max_tools".to_string(),
59                message: "must be between 1 and 50".to_string(),
60            });
61        }
62
63        let st = preset.tool_selection.similarity_threshold;
64        if !(st > 0.0 && st < 1.0) {
65            return Err(SqzError::PresetValidation {
66                field: "tool_selection.similarity_threshold".to_string(),
67                message: "must be between 0.0 and 1.0".to_string(),
68            });
69        }
70
71        let cxt = preset.model.complexity_threshold;
72        if !(cxt > 0.0 && cxt < 1.0) {
73            return Err(SqzError::PresetValidation {
74                field: "model.complexity_threshold".to_string(),
75                message: "must be between 0.0 and 1.0".to_string(),
76            });
77        }
78
79        Ok(())
80    }
81}
82
/// A complete preset: one field per top-level TOML table.
///
/// None of the sections carries `#[serde(default)]`, so every table must be
/// present when deserializing. `Preset::default()` provides a fully populated
/// built-in preset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Preset {
    /// `[preset]` — name, version and optional description.
    pub preset: PresetHeader,
    /// `[compression]` — stage list and optional per-stage settings.
    pub compression: CompressionConfig,
    /// `[tool_selection]` — tool-count limit, similarity threshold, default tools.
    pub tool_selection: ToolSelectionConfig,
    /// `[budget]` — warning/ceiling thresholds, window size, per-agent values.
    pub budget: BudgetConfig,
    /// `[terse_mode]` — on/off switch and level.
    pub terse_mode: TerseModeConfig,
    /// `[model]` — model family and names, complexity threshold, optional pricing.
    pub model: ModelConfig,
}
92
/// Identity block at the top of every `.toml` preset file (the `[preset]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PresetHeader {
    /// Short human-readable name, e.g. `"code-review"`. Must be non-empty
    /// (checked by `PresetParser::validate`).
    pub name: String,
    /// Version string, e.g. `"1.0"`. Only non-emptiness is validated; the
    /// format is not checked, so it need not be full semver (`MAJOR.MINOR.PATCH`).
    pub version: String,
    /// Optional one-line description shown in `sqz preset list`; empty when
    /// omitted from the TOML.
    #[serde(default)]
    pub description: String,
}
104
/// Older name for `PresetHeader`, kept as an alias so existing code compiles.
pub type PresetMeta = PresetHeader;
107
108// --- Compression ---
109
/// `[compression]` table: which stages to run plus per-stage settings.
///
/// Every per-stage sub-table is optional; when a sub-table is absent from the
/// TOML the corresponding field deserializes to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionConfig {
    /// Stage names; empty when omitted. The built-in default preset uses names
    /// that match the fields below (e.g. `"keep_fields"`).
    /// NOTE(review): how names are resolved and whether order matters is up to
    /// the consumer of this config — confirm there.
    #[serde(default)]
    pub stages: Vec<String>,
    /// `[compression.keep_fields]` settings.
    pub keep_fields: Option<KeepFieldsConfig>,
    /// `[compression.strip_fields]` settings.
    pub strip_fields: Option<StripFieldsConfig>,
    /// `[compression.condense]` settings.
    pub condense: Option<CondenseConfig>,
    /// `[compression.git_diff_fold]` settings.
    pub git_diff_fold: Option<GitDiffFoldConfig>,
    /// `[compression.strip_nulls]` settings.
    pub strip_nulls: Option<StripNullsConfig>,
    /// `[compression.flatten]` settings.
    pub flatten: Option<FlattenConfig>,
    /// `[compression.truncate_strings]` settings.
    pub truncate_strings: Option<TruncateStringsConfig>,
    /// `[compression.collapse_arrays]` settings.
    pub collapse_arrays: Option<CollapseArraysConfig>,
    /// `[compression.custom_transforms]` settings.
    pub custom_transforms: Option<CustomTransformsConfig>,
}
124
/// `[compression.git_diff_fold]` settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitDiffFoldConfig {
    /// Whether the stage is active; required whenever the table is present.
    pub enabled: bool,
    /// Context-line limit; 2 when omitted (`default_max_context_lines`).
    /// NOTE(review): presumably the number of unchanged lines kept around each
    /// hunk — confirm in the stage implementation.
    #[serde(default = "default_max_context_lines")]
    pub max_context_lines: u32,
}
131
/// Serde default for `GitDiffFoldConfig::max_context_lines` when the key is omitted.
fn default_max_context_lines() -> u32 {
    const CONTEXT_LINES: u32 = 2;
    CONTEXT_LINES
}
135
/// `[compression.keep_fields]` settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeepFieldsConfig {
    /// Whether the stage is active; required whenever the table is present.
    pub enabled: bool,
    /// Field names for this stage; empty when omitted.
    /// NOTE(review): presumably an allow-list of fields to retain — confirm in
    /// the stage implementation.
    #[serde(default)]
    pub fields: Vec<String>,
}
142
/// `[compression.strip_fields]` settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StripFieldsConfig {
    /// Whether the stage is active; required whenever the table is present.
    pub enabled: bool,
    /// Field names for this stage; empty when omitted. The built-in preset
    /// includes dotted paths such as `"metadata.internal_id"`.
    /// NOTE(review): presumably fields to remove — confirm in the stage implementation.
    #[serde(default)]
    pub fields: Vec<String>,
}
149
/// `[compression.condense]` settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CondenseConfig {
    /// Whether the stage is active; required whenever the table is present.
    pub enabled: bool,
    /// Repeated-line limit; 3 when omitted (`default_max_repeated_lines`).
    #[serde(default = "default_max_repeated_lines")]
    pub max_repeated_lines: u32,
}
156
/// Serde default for `CondenseConfig::max_repeated_lines` when the key is omitted.
fn default_max_repeated_lines() -> u32 {
    const REPEATED_LINES: u32 = 3;
    REPEATED_LINES
}
160
/// `[compression.strip_nulls]` settings (an on/off switch only).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StripNullsConfig {
    /// Whether the stage is active; required whenever the table is present.
    pub enabled: bool,
}
165
/// `[compression.flatten]` settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlattenConfig {
    /// Whether the stage is active; required whenever the table is present.
    pub enabled: bool,
    /// Depth limit; 3 when omitted (`default_max_depth`).
    #[serde(default = "default_max_depth")]
    pub max_depth: u32,
}
172
/// Serde default for `FlattenConfig::max_depth` when the key is omitted.
fn default_max_depth() -> u32 {
    const DEPTH: u32 = 3;
    DEPTH
}
176
/// `[compression.truncate_strings]` settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TruncateStringsConfig {
    /// Whether the stage is active; required whenever the table is present.
    pub enabled: bool,
    /// Length limit; 500 when omitted (`default_max_length`).
    /// NOTE(review): unit (bytes vs. chars) is decided by the stage — confirm there.
    #[serde(default = "default_max_length")]
    pub max_length: u32,
}
183
/// Serde default for `TruncateStringsConfig::max_length` when the key is omitted.
fn default_max_length() -> u32 {
    const LENGTH: u32 = 500;
    LENGTH
}
187
/// `[compression.collapse_arrays]` settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CollapseArraysConfig {
    /// Whether the stage is active; required whenever the table is present.
    pub enabled: bool,
    /// Item limit; 5 when omitted (`default_max_items`).
    #[serde(default = "default_max_items")]
    pub max_items: u32,
    /// Summary text template; empty when omitted. The built-in preset uses
    /// `"... and {remaining} more items"`.
    /// NOTE(review): placeholder substitution (`{remaining}`) happens in the
    /// stage, not here — confirm the supported placeholders there.
    #[serde(default)]
    pub summary_template: String,
}
196
/// Serde default for `CollapseArraysConfig::max_items` when the key is omitted.
fn default_max_items() -> u32 {
    const ITEMS: u32 = 5;
    ITEMS
}
200
/// `[compression.custom_transforms]` settings (an on/off switch only).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomTransformsConfig {
    /// Whether the stage is active; required whenever the table is present.
    pub enabled: bool,
}
205
206// --- Tool selection ---
207
/// `[tool_selection]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSelectionConfig {
    /// Tool-count limit; 5 when omitted. `PresetParser::validate` requires `1..=50`.
    #[serde(default = "default_max_tools")]
    pub max_tools: usize,
    /// Similarity threshold; 0.7 when omitted. `PresetParser::validate`
    /// requires it to lie strictly inside (0.0, 1.0).
    #[serde(default = "default_similarity_threshold")]
    pub similarity_threshold: f64,
    /// Tool names listed as defaults; empty when omitted.
    #[serde(default)]
    pub default_tools: Vec<String>,
}
217
/// Serde default for `ToolSelectionConfig::max_tools` when the key is omitted.
fn default_max_tools() -> usize {
    const TOOLS: usize = 5;
    TOOLS
}
221
/// Serde default for `ToolSelectionConfig::similarity_threshold` when the key is omitted.
fn default_similarity_threshold() -> f64 {
    const THRESHOLD: f64 = 0.7;
    THRESHOLD
}
225
226// --- Budget ---
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct BudgetConfig {
230    #[serde(default = "default_warning_threshold")]
231    pub warning_threshold: f64,
232    #[serde(default = "default_ceiling_threshold")]
233    pub ceiling_threshold: f64,
234    #[serde(default = "default_window_size")]
235    pub default_window_size: u32,
236    #[serde(default)]
237    pub agents: HashMap<String, f64>,
238}
239
/// Serde default for `BudgetConfig::warning_threshold` when the key is omitted.
fn default_warning_threshold() -> f64 {
    const WARNING: f64 = 0.70;
    WARNING
}
243
/// Serde default for `BudgetConfig::ceiling_threshold` when the key is omitted.
fn default_ceiling_threshold() -> f64 {
    const CEILING: f64 = 0.85;
    CEILING
}
247
/// Serde default for `BudgetConfig::default_window_size` when the key is omitted.
fn default_window_size() -> u32 {
    const WINDOW: u32 = 200_000;
    WINDOW
}
251
252// --- Terse mode ---
253
/// `[terse_mode]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TerseModeConfig {
    /// Whether terse mode is on; required (no serde default).
    pub enabled: bool,
    /// Terse level; `TerseLevel::Moderate` when omitted (`default_terse_level`).
    #[serde(default = "default_terse_level")]
    pub level: TerseLevel,
}
260
/// Terse-mode level. In TOML each variant is written as its lowercase name:
/// `"minimal"`, `"moderate"` or `"verbose"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TerseLevel {
    Minimal,
    Moderate,
    Verbose,
}

/// Serde default for `TerseModeConfig::level` when the key is omitted.
fn default_terse_level() -> TerseLevel {
    TerseLevel::Moderate
}
272
273// --- Model ---
274
/// `[model]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Model family, e.g. `"anthropic"` in the built-in preset; required.
    pub family: String,
    /// Primary model identifier; empty when omitted.
    #[serde(default)]
    pub primary: String,
    /// Local model identifier; empty when omitted.
    #[serde(default)]
    pub local: String,
    /// Complexity threshold; 0.4 when omitted. `PresetParser::validate`
    /// requires it to lie strictly inside (0.0, 1.0).
    /// NOTE(review): presumably used to choose between `primary` and `local` —
    /// confirm with the consumer.
    #[serde(default = "default_complexity_threshold")]
    pub complexity_threshold: f64,
    /// Optional `[model.pricing]` table; `None` when absent.
    pub pricing: Option<ModelPricingConfig>,
}
286
/// Serde default for `ModelConfig::complexity_threshold` when the key is omitted.
fn default_complexity_threshold() -> f64 {
    const COMPLEXITY: f64 = 0.4;
    COMPLEXITY
}
290
/// `[model.pricing]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricingConfig {
    /// Input price per 1k units; required.
    /// NOTE(review): presumably USD per 1,000 input tokens — confirm.
    pub input_per_1k: f64,
    /// Output price per 1k units; required.
    /// NOTE(review): presumably USD per 1,000 output tokens — confirm.
    pub output_per_1k: f64,
    /// Cache-read discount; 0.0 when omitted (the built-in preset uses 0.9).
    /// NOTE(review): presumably a fraction off `input_per_1k` for cached reads — confirm.
    #[serde(default)]
    pub cache_read_discount: f64,
}
298
299impl Default for Preset {
300    fn default() -> Self {
301        Preset {
302            preset: PresetMeta {
303                name: "default".to_string(),
304                version: "1.0".to_string(),
305                description: "Default compression preset for general development".to_string(),
306            },
307            compression: CompressionConfig {
308                stages: vec![
309                    "keep_fields".to_string(),
310                    "strip_fields".to_string(),
311                    "condense".to_string(),
312                    "strip_nulls".to_string(),
313                    "flatten".to_string(),
314                    "truncate_strings".to_string(),
315                    "collapse_arrays".to_string(),
316                    "custom_transforms".to_string(),
317                ],
318                keep_fields: Some(KeepFieldsConfig {
319                    enabled: false,
320                    fields: vec![
321                        "id".to_string(),
322                        "name".to_string(),
323                        "type".to_string(),
324                        "status".to_string(),
325                        "error".to_string(),
326                        "message".to_string(),
327                    ],
328                }),
329                strip_fields: Some(StripFieldsConfig {
330                    enabled: true,
331                    fields: vec![
332                        "metadata.internal_id".to_string(),
333                        "debug_info".to_string(),
334                        "trace_id".to_string(),
335                    ],
336                }),
337                condense: Some(CondenseConfig {
338                    enabled: true,
339                    max_repeated_lines: 3,
340                }),
341                git_diff_fold: Some(GitDiffFoldConfig {
342                    enabled: true,
343                    max_context_lines: 2,
344                }),
345                strip_nulls: Some(StripNullsConfig { enabled: true }),
346                flatten: Some(FlattenConfig {
347                    enabled: true,
348                    max_depth: 3,
349                }),
350                truncate_strings: Some(TruncateStringsConfig {
351                    enabled: true,
352                    max_length: 500,
353                }),
354                collapse_arrays: Some(CollapseArraysConfig {
355                    enabled: true,
356                    max_items: 5,
357                    summary_template: "... and {remaining} more items".to_string(),
358                }),
359                custom_transforms: Some(CustomTransformsConfig { enabled: true }),
360            },
361            tool_selection: ToolSelectionConfig {
362                max_tools: 5,
363                similarity_threshold: 0.7,
364                default_tools: vec![
365                    "read_file".to_string(),
366                    "write_file".to_string(),
367                    "search".to_string(),
368                ],
369            },
370            budget: BudgetConfig {
371                warning_threshold: 0.70,
372                ceiling_threshold: 0.85,
373                default_window_size: 200_000,
374                agents: {
375                    let mut m = HashMap::new();
376                    m.insert("parent".to_string(), 0.60);
377                    m.insert("child".to_string(), 0.20);
378                    m
379                },
380            },
381            terse_mode: TerseModeConfig {
382                enabled: true,
383                level: TerseLevel::Moderate,
384            },
385            model: ModelConfig {
386                family: "anthropic".to_string(),
387                primary: "claude-sonnet-4-20250514".to_string(),
388                local: "llama-3.1-8b".to_string(),
389                complexity_threshold: 0.4,
390                pricing: Some(ModelPricingConfig {
391                    input_per_1k: 0.003,
392                    output_per_1k: 0.015,
393                    cache_read_discount: 0.9,
394                }),
395            },
396        }
397    }
398}
399
400// ---------------------------------------------------------------------------
401// Tests
402// ---------------------------------------------------------------------------
403
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;

    // ---------------------------------------------------------------------------
    // Strategies for generating valid Preset objects
    // ---------------------------------------------------------------------------

    /// Non-empty identifier-like string: 1 to 32 characters from `[a-zA-Z0-9_.-]`.
    fn arb_nonempty_string() -> impl Strategy<Value = String> {
        "[a-zA-Z0-9_\\-\\.]{1,32}"
    }

    /// f64 strictly inside (0.0, 1.0): multiples of 0.0001 from 0.0001 to 0.9999.
    fn arb_open_unit() -> impl Strategy<Value = f64> {
        (1u32..=9999u32).prop_map(|n| n as f64 / 10_000.0)
    }

    /// Valid `BudgetConfig`: both thresholds in (0, 1) with ceiling > warning.
    fn arb_budget_config() -> impl Strategy<Value = BudgetConfig> {
        // Warning in [0.0001, 0.8999]; the ceiling is then drawn from the grid
        // points strictly above it, up to 0.9999.
        (1u32..=8999u32).prop_flat_map(|w_raw| {
            let warning = w_raw as f64 / 10_000.0;
            ((w_raw + 1)..=9999u32).prop_map(move |c_raw| BudgetConfig {
                warning_threshold: warning,
                ceiling_threshold: c_raw as f64 / 10_000.0,
                default_window_size: 200_000,
                agents: HashMap::new(),
            })
        })
    }

    /// Valid `ToolSelectionConfig`.
    fn arb_tool_selection_config() -> impl Strategy<Value = ToolSelectionConfig> {
        (1usize..=50usize, arb_open_unit()).prop_map(|(max_tools, similarity_threshold)| {
            ToolSelectionConfig {
                max_tools,
                similarity_threshold,
                default_tools: vec![],
            }
        })
    }

    /// Valid `ModelConfig`.
    fn arb_model_config() -> impl Strategy<Value = ModelConfig> {
        (arb_nonempty_string(), arb_open_unit()).prop_map(|(family, complexity_threshold)| {
            ModelConfig {
                family,
                primary: String::new(),
                local: String::new(),
                complexity_threshold,
                pricing: None,
            }
        })
    }

    /// Valid `Preset` with every optional compression stage unset.
    fn arb_preset() -> impl Strategy<Value = Preset> {
        (
            arb_nonempty_string(), // name
            arb_nonempty_string(), // version
            arb_budget_config(),
            arb_tool_selection_config(),
            arb_model_config(),
        )
            .prop_map(|(name, version, budget, tool_selection, model)| Preset {
                preset: PresetHeader {
                    name,
                    version,
                    description: String::new(),
                },
                compression: CompressionConfig {
                    stages: vec![],
                    keep_fields: None,
                    strip_fields: None,
                    condense: None,
                    git_diff_fold: None,
                    strip_nulls: None,
                    flatten: None,
                    truncate_strings: None,
                    collapse_arrays: None,
                    custom_transforms: None,
                },
                tool_selection,
                budget,
                terse_mode: TerseModeConfig {
                    enabled: false,
                    level: TerseLevel::Moderate,
                },
                model,
            })
    }

    // ---------------------------------------------------------------------------
    // Strategies and helpers for invalid values
    // ---------------------------------------------------------------------------

    /// f64 outside the open interval (0.0, 1.0): both boundaries, negatives in
    /// [-1.0, -0.0001], and values in [1.0001, 2.0].
    fn arb_outside_open_unit() -> impl Strategy<Value = f64> {
        prop_oneof![
            Just(0.0_f64),
            Just(1.0_f64),
            (1u32..=10000u32).prop_map(|n| -(n as f64 / 10_000.0)),
            (10001u32..=20000u32).prop_map(|n| n as f64 / 10_000.0),
        ]
    }

    /// Invalid `max_tools`: 0 or 51..=200.
    fn arb_invalid_max_tools() -> impl Strategy<Value = usize> {
        prop_oneof![Just(0usize), 51usize..=200usize]
    }

    /// Check that `PresetParser::validate` rejects `preset` and that the
    /// rendered error mentions `field`.
    fn expect_rejected(preset: &Preset, field: &str) -> std::result::Result<(), TestCaseError> {
        match PresetParser::validate(preset) {
            Ok(()) => Err(TestCaseError::fail(format!(
                "expected a validation error mentioning '{}', but the preset was accepted",
                field
            ))),
            Err(err) => {
                let err_msg = err.to_string();
                prop_assert!(
                    err_msg.contains(field),
                    "error message '{}' does not mention '{}'",
                    err_msg,
                    field
                );
                Ok(())
            }
        }
    }

    // ---------------------------------------------------------------------------
    // Property 31: TOML Preset round-trip
    // Validates: Requirements 29.1, 29.2, 29.3
    // ---------------------------------------------------------------------------

    proptest! {
        /// **Validates: Requirements 29.1, 29.2, 29.3**
        ///
        /// Property 31: TOML Preset round-trip.
        ///
        /// For all valid Preset objects, serializing to TOML then deserializing
        /// SHALL produce an equivalent Preset object.
        ///
        /// Compared by double-serializing (serialize, parse, serialize again,
        /// compare the strings). This avoids direct f64 comparison while still
        /// checking every serialized field.
        #[test]
        fn prop_preset_toml_round_trip(preset in arb_preset()) {
            let toml1 = PresetParser::to_toml(&preset)
                .expect("to_toml should not fail on a valid preset");

            let parsed = PresetParser::parse(&toml1)
                .expect("parse should not fail on a valid TOML string");

            let toml2 = PresetParser::to_toml(&parsed)
                .expect("to_toml should not fail on re-parsed preset");

            prop_assert_eq!(
                &toml1,
                &toml2,
                "TOML round-trip mismatch:\nfirst:  {}\nsecond: {}",
                toml1,
                toml2
            );
        }
    }

    // ---------------------------------------------------------------------------
    // Property 32: Preset validation error descriptiveness
    // Validates: Requirements 24.5, 29.4
    // ---------------------------------------------------------------------------

    proptest! {
        /// **Validates: Requirements 24.5, 29.4**
        ///
        /// Property 32a: Invalid `budget.warning_threshold` produces an error
        /// mentioning "budget.warning_threshold".
        #[test]
        fn prop_invalid_warning_threshold_error_mentions_field(
            invalid_wt in arb_outside_open_unit()
        ) {
            let mut preset = Preset::default();
            preset.budget.warning_threshold = invalid_wt;
            // `validate` checks warning_threshold before ceiling_threshold and
            // returns the first failure, so the error is attributed to the
            // warning field even when invalid_wt >= the ceiling below.
            preset.budget.ceiling_threshold = 0.85;
            expect_rejected(&preset, "budget.warning_threshold")?;
        }

        /// **Validates: Requirements 24.5, 29.4**
        ///
        /// Property 32b: Out-of-range `budget.ceiling_threshold` produces an
        /// error mentioning "budget.ceiling_threshold".
        #[test]
        fn prop_invalid_ceiling_threshold_error_mentions_field(
            invalid_ct in arb_outside_open_unit()
        ) {
            let mut preset = Preset::default();
            // A valid warning threshold ensures the ceiling check is reached.
            preset.budget.warning_threshold = 0.70;
            preset.budget.ceiling_threshold = invalid_ct;
            expect_rejected(&preset, "budget.ceiling_threshold")?;
        }

        /// **Validates: Requirements 24.5, 29.4**
        ///
        /// Property 32b': An in-range `budget.ceiling_threshold` that is not
        /// above `budget.warning_threshold` is reported against the ceiling field.
        #[test]
        fn prop_ceiling_not_above_warning_error_mentions_field(
            a in arb_open_unit(),
            b in arb_open_unit(),
        ) {
            let mut preset = Preset::default();
            preset.budget.warning_threshold = a.max(b);
            preset.budget.ceiling_threshold = a.min(b);
            expect_rejected(&preset, "budget.ceiling_threshold")?;
        }

        /// **Validates: Requirements 24.5, 29.4**
        ///
        /// Property 32d: Invalid `tool_selection.max_tools` (0 or >50) produces
        /// an error mentioning "tool_selection.max_tools".
        #[test]
        fn prop_invalid_max_tools_error_mentions_field(
            invalid_mt in arb_invalid_max_tools()
        ) {
            let mut preset = Preset::default();
            preset.tool_selection.max_tools = invalid_mt;
            expect_rejected(&preset, "tool_selection.max_tools")?;
        }

        /// **Validates: Requirements 24.5, 29.4**
        ///
        /// Property 32d': Invalid `tool_selection.similarity_threshold` produces
        /// an error mentioning "tool_selection.similarity_threshold".
        #[test]
        fn prop_invalid_similarity_threshold_error_mentions_field(
            invalid_st in arb_outside_open_unit()
        ) {
            let mut preset = Preset::default();
            preset.tool_selection.similarity_threshold = invalid_st;
            expect_rejected(&preset, "tool_selection.similarity_threshold")?;
        }

        /// **Validates: Requirements 24.5, 29.4**
        ///
        /// Property 32e: Invalid `model.complexity_threshold` produces an error
        /// mentioning "model.complexity_threshold".
        #[test]
        fn prop_invalid_complexity_threshold_error_mentions_field(
            invalid_cxt in arb_outside_open_unit()
        ) {
            let mut preset = Preset::default();
            preset.model.complexity_threshold = invalid_cxt;
            expect_rejected(&preset, "model.complexity_threshold")?;
        }
    }

    /// **Validates: Requirements 24.5, 29.4**
    ///
    /// Property 32c: Empty `preset.name` produces an error mentioning "preset.name".
    #[test]
    fn empty_preset_name_error_mentions_field() {
        let mut preset = Preset::default();
        preset.preset.name = String::new();
        expect_rejected(&preset, "preset.name").unwrap();
    }

    /// Empty `preset.version` produces an error mentioning "preset.version".
    #[test]
    fn empty_preset_version_error_mentions_field() {
        let mut preset = Preset::default();
        preset.preset.version = String::new();
        expect_rejected(&preset, "preset.version").unwrap();
    }

    /// The built-in preset must satisfy its own validation rules.
    #[test]
    fn default_preset_is_valid() {
        PresetParser::validate(&Preset::default())
            .expect("Preset::default() must pass validation");
    }

    /// The built-in preset survives a serialize/parse cycle. Fields are compared
    /// directly rather than as TOML strings, so the check is independent of
    /// map key order.
    #[test]
    fn default_preset_parses_back_from_toml() {
        let original = Preset::default();
        let toml = PresetParser::to_toml(&original)
            .expect("to_toml should not fail on the default preset");
        let parsed = PresetParser::parse(&toml)
            .expect("parse should accept the serialized default preset");

        assert_eq!(parsed.preset.name, original.preset.name);
        assert_eq!(parsed.preset.version, original.preset.version);
        assert_eq!(parsed.compression.stages, original.compression.stages);
        assert_eq!(parsed.budget.agents, original.budget.agents);
        assert_eq!(parsed.terse_mode.level, original.terse_mode.level);
        assert_eq!(parsed.model.family, original.model.family);
    }
}