Skip to main content

sqz_engine/
preset.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4use crate::error::{Result, SqzError};
5
6/// Parses, validates, and serializes TOML Preset files.
7pub struct PresetParser;
8
9impl PresetParser {
10    /// Parse a TOML string into a validated `Preset`.
11    pub fn parse(toml_str: &str) -> Result<Preset> {
12        let preset: Preset = toml::from_str(toml_str)?;
13        Self::validate(&preset)?;
14        Ok(preset)
15    }
16
17    /// Serialize a `Preset` back to a pretty-printed TOML string.
18    pub fn to_toml(preset: &Preset) -> Result<String> {
19        Ok(toml::to_string_pretty(preset)?)
20    }
21
22    /// Validate all fields of a `Preset`, returning descriptive errors.
23    pub fn validate(preset: &Preset) -> Result<()> {
24        if preset.preset.name.is_empty() {
25            return Err(SqzError::PresetValidation {
26                field: "preset.name".to_string(),
27                message: "must not be empty".to_string(),
28            });
29        }
30
31        if preset.preset.version.is_empty() {
32            return Err(SqzError::PresetValidation {
33                field: "preset.version".to_string(),
34                message: "must not be empty".to_string(),
35            });
36        }
37
38        let wt = preset.budget.warning_threshold;
39        if !(wt > 0.0 && wt < 1.0) {
40            return Err(SqzError::PresetValidation {
41                field: "budget.warning_threshold".to_string(),
42                message: "must be between 0.0 and 1.0".to_string(),
43            });
44        }
45
46        let ct = preset.budget.ceiling_threshold;
47        if !(ct > 0.0 && ct < 1.0) || ct <= wt {
48            return Err(SqzError::PresetValidation {
49                field: "budget.ceiling_threshold".to_string(),
50                message: "must be between 0.0 and 1.0 and greater than warning_threshold"
51                    .to_string(),
52            });
53        }
54
55        let max_tools = preset.tool_selection.max_tools;
56        if !(1..=50).contains(&max_tools) {
57            return Err(SqzError::PresetValidation {
58                field: "tool_selection.max_tools".to_string(),
59                message: "must be between 1 and 50".to_string(),
60            });
61        }
62
63        let st = preset.tool_selection.similarity_threshold;
64        if !(st > 0.0 && st < 1.0) {
65            return Err(SqzError::PresetValidation {
66                field: "tool_selection.similarity_threshold".to_string(),
67                message: "must be between 0.0 and 1.0".to_string(),
68            });
69        }
70
71        let cxt = preset.model.complexity_threshold;
72        if !(cxt > 0.0 && cxt < 1.0) {
73            return Err(SqzError::PresetValidation {
74                field: "model.complexity_threshold".to_string(),
75                message: "must be between 0.0 and 1.0".to_string(),
76            });
77        }
78
79        Ok(())
80    }
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct Preset {
85    pub preset: PresetMeta,
86    pub compression: CompressionConfig,
87    pub tool_selection: ToolSelectionConfig,
88    pub budget: BudgetConfig,
89    pub terse_mode: TerseModeConfig,
90    pub model: ModelConfig,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct PresetMeta {
95    pub name: String,
96    pub version: String,
97    #[serde(default)]
98    pub description: String,
99}
100
101// --- Compression ---
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct CompressionConfig {
105    #[serde(default)]
106    pub stages: Vec<String>,
107    pub keep_fields: Option<KeepFieldsConfig>,
108    pub strip_fields: Option<StripFieldsConfig>,
109    pub condense: Option<CondenseConfig>,
110    pub strip_nulls: Option<StripNullsConfig>,
111    pub flatten: Option<FlattenConfig>,
112    pub truncate_strings: Option<TruncateStringsConfig>,
113    pub collapse_arrays: Option<CollapseArraysConfig>,
114    pub custom_transforms: Option<CustomTransformsConfig>,
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct KeepFieldsConfig {
119    pub enabled: bool,
120    #[serde(default)]
121    pub fields: Vec<String>,
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct StripFieldsConfig {
126    pub enabled: bool,
127    #[serde(default)]
128    pub fields: Vec<String>,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct CondenseConfig {
133    pub enabled: bool,
134    #[serde(default = "default_max_repeated_lines")]
135    pub max_repeated_lines: u32,
136}
137
/// Serde fallback for `CondenseConfig::max_repeated_lines`.
fn default_max_repeated_lines() -> u32 {
    const FALLBACK: u32 = 3;
    FALLBACK
}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct StripNullsConfig {
144    pub enabled: bool,
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct FlattenConfig {
149    pub enabled: bool,
150    #[serde(default = "default_max_depth")]
151    pub max_depth: u32,
152}
153
/// Serde fallback for `FlattenConfig::max_depth`.
fn default_max_depth() -> u32 {
    const FALLBACK: u32 = 3;
    FALLBACK
}
157
158#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct TruncateStringsConfig {
160    pub enabled: bool,
161    #[serde(default = "default_max_length")]
162    pub max_length: u32,
163}
164
/// Serde fallback for `TruncateStringsConfig::max_length`.
fn default_max_length() -> u32 {
    const FALLBACK: u32 = 500;
    FALLBACK
}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct CollapseArraysConfig {
171    pub enabled: bool,
172    #[serde(default = "default_max_items")]
173    pub max_items: u32,
174    #[serde(default)]
175    pub summary_template: String,
176}
177
/// Serde fallback for `CollapseArraysConfig::max_items`.
fn default_max_items() -> u32 {
    const FALLBACK: u32 = 5;
    FALLBACK
}
181
182#[derive(Debug, Clone, Serialize, Deserialize)]
183pub struct CustomTransformsConfig {
184    pub enabled: bool,
185}
186
187// --- Tool selection ---
188
189#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct ToolSelectionConfig {
191    #[serde(default = "default_max_tools")]
192    pub max_tools: usize,
193    #[serde(default = "default_similarity_threshold")]
194    pub similarity_threshold: f64,
195    #[serde(default)]
196    pub default_tools: Vec<String>,
197}
198
/// Serde fallback for `ToolSelectionConfig::max_tools`.
fn default_max_tools() -> usize {
    const FALLBACK: usize = 5;
    FALLBACK
}
202
/// Serde fallback for `ToolSelectionConfig::similarity_threshold`.
fn default_similarity_threshold() -> f64 {
    const FALLBACK: f64 = 0.7;
    FALLBACK
}
206
207// --- Budget ---
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct BudgetConfig {
211    #[serde(default = "default_warning_threshold")]
212    pub warning_threshold: f64,
213    #[serde(default = "default_ceiling_threshold")]
214    pub ceiling_threshold: f64,
215    #[serde(default = "default_window_size")]
216    pub default_window_size: u32,
217    #[serde(default)]
218    pub agents: HashMap<String, f64>,
219}
220
/// Serde fallback for `BudgetConfig::warning_threshold`.
fn default_warning_threshold() -> f64 {
    const FALLBACK: f64 = 0.70;
    FALLBACK
}
224
/// Serde fallback for `BudgetConfig::ceiling_threshold`.
fn default_ceiling_threshold() -> f64 {
    const FALLBACK: f64 = 0.85;
    FALLBACK
}
228
/// Serde fallback for `BudgetConfig::default_window_size`.
fn default_window_size() -> u32 {
    const FALLBACK: u32 = 200_000;
    FALLBACK
}
232
233// --- Terse mode ---
234
235#[derive(Debug, Clone, Serialize, Deserialize)]
236pub struct TerseModeConfig {
237    pub enabled: bool,
238    #[serde(default = "default_terse_level")]
239    pub level: TerseLevel,
240}
241
242#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
243#[serde(rename_all = "lowercase")]
244pub enum TerseLevel {
245    Minimal,
246    Moderate,
247    Verbose,
248}
249
250fn default_terse_level() -> TerseLevel {
251    TerseLevel::Moderate
252}
253
254// --- Model ---
255
256#[derive(Debug, Clone, Serialize, Deserialize)]
257pub struct ModelConfig {
258    pub family: String,
259    #[serde(default)]
260    pub primary: String,
261    #[serde(default)]
262    pub local: String,
263    #[serde(default = "default_complexity_threshold")]
264    pub complexity_threshold: f64,
265    pub pricing: Option<ModelPricingConfig>,
266}
267
/// Serde fallback for `ModelConfig::complexity_threshold`.
fn default_complexity_threshold() -> f64 {
    const FALLBACK: f64 = 0.4;
    FALLBACK
}
271
272#[derive(Debug, Clone, Serialize, Deserialize)]
273pub struct ModelPricingConfig {
274    pub input_per_1k: f64,
275    pub output_per_1k: f64,
276    #[serde(default)]
277    pub cache_read_discount: f64,
278}
279
280impl Default for Preset {
281    fn default() -> Self {
282        Preset {
283            preset: PresetMeta {
284                name: "default".to_string(),
285                version: "1.0".to_string(),
286                description: "Default compression preset for general development".to_string(),
287            },
288            compression: CompressionConfig {
289                stages: vec![
290                    "keep_fields".to_string(),
291                    "strip_fields".to_string(),
292                    "condense".to_string(),
293                    "strip_nulls".to_string(),
294                    "flatten".to_string(),
295                    "truncate_strings".to_string(),
296                    "collapse_arrays".to_string(),
297                    "custom_transforms".to_string(),
298                ],
299                keep_fields: Some(KeepFieldsConfig {
300                    enabled: false,
301                    fields: vec![
302                        "id".to_string(),
303                        "name".to_string(),
304                        "type".to_string(),
305                        "status".to_string(),
306                        "error".to_string(),
307                        "message".to_string(),
308                    ],
309                }),
310                strip_fields: Some(StripFieldsConfig {
311                    enabled: true,
312                    fields: vec![
313                        "metadata.internal_id".to_string(),
314                        "debug_info".to_string(),
315                        "trace_id".to_string(),
316                    ],
317                }),
318                condense: Some(CondenseConfig {
319                    enabled: true,
320                    max_repeated_lines: 3,
321                }),
322                strip_nulls: Some(StripNullsConfig { enabled: true }),
323                flatten: Some(FlattenConfig {
324                    enabled: true,
325                    max_depth: 3,
326                }),
327                truncate_strings: Some(TruncateStringsConfig {
328                    enabled: true,
329                    max_length: 500,
330                }),
331                collapse_arrays: Some(CollapseArraysConfig {
332                    enabled: true,
333                    max_items: 5,
334                    summary_template: "... and {remaining} more items".to_string(),
335                }),
336                custom_transforms: Some(CustomTransformsConfig { enabled: true }),
337            },
338            tool_selection: ToolSelectionConfig {
339                max_tools: 5,
340                similarity_threshold: 0.7,
341                default_tools: vec![
342                    "read_file".to_string(),
343                    "write_file".to_string(),
344                    "search".to_string(),
345                ],
346            },
347            budget: BudgetConfig {
348                warning_threshold: 0.70,
349                ceiling_threshold: 0.85,
350                default_window_size: 200_000,
351                agents: {
352                    let mut m = HashMap::new();
353                    m.insert("parent".to_string(), 0.60);
354                    m.insert("child".to_string(), 0.20);
355                    m
356                },
357            },
358            terse_mode: TerseModeConfig {
359                enabled: true,
360                level: TerseLevel::Moderate,
361            },
362            model: ModelConfig {
363                family: "anthropic".to_string(),
364                primary: "claude-sonnet-4-20250514".to_string(),
365                local: "llama-3.1-8b".to_string(),
366                complexity_threshold: 0.4,
367                pricing: Some(ModelPricingConfig {
368                    input_per_1k: 0.003,
369                    output_per_1k: 0.015,
370                    cache_read_discount: 0.9,
371                }),
372            },
373        }
374    }
375}
376
377// ---------------------------------------------------------------------------
378// Tests
379// ---------------------------------------------------------------------------
380
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // ---------------------------------------------------------------------------
    // Shared helpers
    // ---------------------------------------------------------------------------

    /// Converts a whole number of ten-thousandths into the matching `f64`.
    fn ten_thousandths(raw: u32) -> f64 {
        f64::from(raw) / 10_000.0
    }

    /// Asserts that `preset` fails validation and that the rendered error
    /// mentions `field`.
    ///
    /// Returns a `TestCaseError` so property tests can propagate it with `?`.
    fn expect_rejection(
        preset: &Preset,
        field: &str,
    ) -> std::result::Result<(), TestCaseError> {
        let err_msg = match PresetParser::validate(preset) {
            Ok(()) => {
                return Err(TestCaseError::fail(format!(
                    "expected validation error mentioning '{}'",
                    field
                )));
            }
            Err(err) => err.to_string(),
        };
        prop_assert!(
            err_msg.contains(field),
            "error message '{}' does not mention '{}'",
            err_msg,
            field
        );
        Ok(())
    }

    // ---------------------------------------------------------------------------
    // Strategies for generating valid Preset objects
    // ---------------------------------------------------------------------------

    /// Non-empty string of 1 to 32 characters drawn from ASCII letters,
    /// digits, `_`, `-` and `.`.
    fn arb_nonempty_string() -> impl Strategy<Value = String> {
        "[a-zA-Z0-9_\\-\\.]{1,32}"
    }

    /// f64 strictly inside (0.0, 1.0), on a grid of ten-thousandths.
    fn arb_open_unit() -> impl Strategy<Value = f64> {
        (1u32..=9999u32).prop_map(ten_thousandths)
    }

    /// Strategy for a valid BudgetConfig: ceiling > warning, both in (0, 1).
    fn arb_budget_config() -> impl Strategy<Value = BudgetConfig> {
        // Warning is drawn from [0.0001, 0.8999]; the ceiling is then drawn
        // from the grid points strictly above it, up to 0.9999.
        (1u32..=8999u32).prop_flat_map(|warning_raw| {
            ((warning_raw + 1)..=9999u32).prop_map(move |ceiling_raw| BudgetConfig {
                warning_threshold: ten_thousandths(warning_raw),
                ceiling_threshold: ten_thousandths(ceiling_raw),
                default_window_size: 200_000,
                agents: Default::default(),
            })
        })
    }

    /// Strategy for a valid ToolSelectionConfig.
    fn arb_tool_selection_config() -> impl Strategy<Value = ToolSelectionConfig> {
        (1usize..=50usize, arb_open_unit()).prop_map(|(max_tools, similarity_threshold)| {
            ToolSelectionConfig {
                max_tools,
                similarity_threshold,
                default_tools: vec![],
            }
        })
    }

    /// Strategy for a valid ModelConfig.
    fn arb_model_config() -> impl Strategy<Value = ModelConfig> {
        (arb_nonempty_string(), arb_open_unit()).prop_map(|(family, complexity_threshold)| {
            ModelConfig {
                family,
                primary: String::new(),
                local: String::new(),
                complexity_threshold,
                pricing: None,
            }
        })
    }

    /// Strategy for a valid Preset.
    ///
    /// Only the validated fields vary; the compression section is left empty
    /// and terse mode is fixed.
    fn arb_preset() -> impl Strategy<Value = Preset> {
        (
            arb_nonempty_string(), // name
            arb_nonempty_string(), // version
            arb_budget_config(),
            arb_tool_selection_config(),
            arb_model_config(),
        )
            .prop_map(|(name, version, budget, tool_selection, model)| Preset {
                preset: PresetMeta {
                    name,
                    version,
                    description: String::new(),
                },
                compression: CompressionConfig {
                    stages: vec![],
                    keep_fields: None,
                    strip_fields: None,
                    condense: None,
                    strip_nulls: None,
                    flatten: None,
                    truncate_strings: None,
                    collapse_arrays: None,
                    custom_transforms: None,
                },
                tool_selection,
                budget,
                terse_mode: TerseModeConfig {
                    enabled: false,
                    level: TerseLevel::Moderate,
                },
                model,
            })
    }

    // ---------------------------------------------------------------------------
    // Property 31: TOML Preset round-trip
    // Validates: Requirements 29.1, 29.2, 29.3
    // ---------------------------------------------------------------------------

    proptest! {
        /// **Validates: Requirements 29.1, 29.2, 29.3**
        ///
        /// Property 31: TOML Preset round-trip.
        ///
        /// For all valid Preset objects, serializing to TOML then deserializing
        /// SHALL produce an equivalent Preset object.
        ///
        /// We compare by double-serializing: serialize the original to TOML,
        /// parse it back, serialize again, and assert the two TOML strings are
        /// identical. This avoids f64 direct comparison issues while still
        /// verifying full fidelity.
        #[test]
        fn prop_preset_toml_round_trip(preset in arb_preset()) {
            let toml1 = PresetParser::to_toml(&preset)
                .expect("to_toml should not fail on a valid preset");

            let parsed = PresetParser::parse(&toml1)
                .expect("parse should not fail on a valid TOML string");

            let toml2 = PresetParser::to_toml(&parsed)
                .expect("to_toml should not fail on re-parsed preset");

            prop_assert_eq!(
                &toml1,
                &toml2,
                "TOML round-trip mismatch:\nfirst:  {}\nsecond: {}",
                toml1,
                toml2
            );
        }
    }

    // ---------------------------------------------------------------------------
    // Property 32: Preset validation error descriptiveness
    // Validates: Requirements 24.5, 29.4
    // ---------------------------------------------------------------------------

    /// Values that must fail every "strictly between 0.0 and 1.0" check:
    /// both boundaries, NaN, the infinities, negatives down to -1.0 and
    /// values from 1.0001 up to 2.0.
    fn arb_outside_open_unit() -> impl Strategy<Value = f64> {
        prop_oneof![
            Just(0.0_f64),
            Just(1.0_f64),
            Just(f64::NAN),
            Just(f64::INFINITY),
            Just(f64::NEG_INFINITY),
            (1u32..=10_000u32).prop_map(|raw| -ten_thousandths(raw)),
            (10_001u32..=20_000u32).prop_map(ten_thousandths),
        ]
    }

    /// Strategy for invalid max_tools values: 0 or >50.
    fn arb_invalid_max_tools() -> impl Strategy<Value = usize> {
        prop_oneof![Just(0usize), (51usize..=200usize),]
    }

    /// **Validates: Requirements 24.5, 29.4**
    ///
    /// Property 32c: Empty `preset.name` produces a descriptive error
    /// mentioning "preset.name".
    ///
    /// There is nothing to generate here, so this is an ordinary test; the
    /// historical `prop_` name is kept so existing test filters still match.
    #[test]
    fn prop_empty_preset_name_error_mentions_field() {
        let mut preset = Preset::default();
        preset.preset.name = String::new();

        if let Err(failure) = expect_rejection(&preset, "preset.name") {
            panic!("{}", failure);
        }
    }

    /// Empty `preset.version` produces a descriptive error mentioning
    /// "preset.version".
    #[test]
    fn empty_preset_version_error_mentions_field() {
        let mut preset = Preset::default();
        preset.preset.version = String::new();

        if let Err(failure) = expect_rejection(&preset, "preset.version") {
            panic!("{}", failure);
        }
    }

    proptest! {
        /// **Validates: Requirements 24.5, 29.4**
        ///
        /// Property 32a: Invalid `budget.warning_threshold` produces a descriptive error
        /// mentioning "budget.warning_threshold".
        #[test]
        fn prop_invalid_warning_threshold_error_mentions_field(
            invalid_wt in arb_outside_open_unit()
        ) {
            let mut preset = Preset::default();
            preset.budget.warning_threshold = invalid_wt;
            // validate checks warning_threshold before ceiling_threshold, so
            // the warning error is the one reported regardless of how the
            // ceiling compares; keep the ceiling at a value that is valid on
            // its own to make that explicit.
            preset.budget.ceiling_threshold = 0.85;

            expect_rejection(&preset, "budget.warning_threshold")?;
        }

        /// **Validates: Requirements 24.5, 29.4**
        ///
        /// Property 32b: Invalid `budget.ceiling_threshold` produces a descriptive error
        /// mentioning "budget.ceiling_threshold".
        #[test]
        fn prop_invalid_ceiling_threshold_error_mentions_field(
            invalid_ct in arb_outside_open_unit()
        ) {
            let mut preset = Preset::default();
            // Keep warning_threshold valid so ceiling_threshold error fires.
            preset.budget.warning_threshold = 0.70;
            preset.budget.ceiling_threshold = invalid_ct;

            expect_rejection(&preset, "budget.ceiling_threshold")?;
        }

        /// A ceiling that is a valid fraction but not above the warning
        /// threshold is rejected with an error mentioning
        /// "budget.ceiling_threshold".
        #[test]
        fn prop_ceiling_not_above_warning_error_mentions_field(
            ceiling_raw in 1u32..=7000u32
        ) {
            let mut preset = Preset::default();
            preset.budget.warning_threshold = 0.70;
            // At most 0.70, i.e. never strictly greater than the warning.
            preset.budget.ceiling_threshold = ten_thousandths(ceiling_raw);

            expect_rejection(&preset, "budget.ceiling_threshold")?;
        }

        /// **Validates: Requirements 24.5, 29.4**
        ///
        /// Property 32d: Invalid `tool_selection.max_tools` (0 or >50) produces a
        /// descriptive error mentioning "tool_selection.max_tools".
        #[test]
        fn prop_invalid_max_tools_error_mentions_field(
            invalid_mt in arb_invalid_max_tools()
        ) {
            let mut preset = Preset::default();
            preset.tool_selection.max_tools = invalid_mt;

            expect_rejection(&preset, "tool_selection.max_tools")?;
        }

        /// Invalid `tool_selection.similarity_threshold` produces a
        /// descriptive error mentioning "tool_selection.similarity_threshold".
        #[test]
        fn prop_invalid_similarity_threshold_error_mentions_field(
            invalid_st in arb_outside_open_unit()
        ) {
            let mut preset = Preset::default();
            preset.tool_selection.similarity_threshold = invalid_st;

            expect_rejection(&preset, "tool_selection.similarity_threshold")?;
        }

        /// **Validates: Requirements 24.5, 29.4**
        ///
        /// Property 32e: Invalid `model.complexity_threshold` produces a descriptive
        /// error mentioning "model.complexity_threshold".
        #[test]
        fn prop_invalid_complexity_threshold_error_mentions_field(
            invalid_cxt in arb_outside_open_unit()
        ) {
            let mut preset = Preset::default();
            preset.model.complexity_threshold = invalid_cxt;

            expect_rejection(&preset, "model.complexity_threshold")?;
        }
    }
}