Skip to main content

sqz_engine/
engine.rs

1use std::path::Path;
2use std::sync::{Arc, Mutex};
3
4use crate::ast_parser::AstParser;
5use crate::budget_tracker::{BudgetTracker, UsageReport};
6use crate::cache_manager::CacheManager;
7use crate::confidence_router::ConfidenceRouter;
8use crate::cost_calculator::{CostCalculator, SessionCostSummary};
9use crate::ctx_format::CtxFormat;
10use crate::error::{Result, SqzError};
11use crate::model_router::ModelRouter;
12use crate::pin_manager::PinManager;
13use crate::pipeline::CompressionPipeline;
14use crate::plugin_api::PluginLoader;
15use crate::preset::{Preset, PresetParser};
16use crate::session_store::{SessionStore, SessionSummary};
17use crate::terse_mode::TerseMode;
18use crate::types::{CompressedContent, PinEntry, Provenance, SessionId};
19use crate::verifier::Verifier;
20
21/// Top-level facade that wires all sqz_engine modules together.
22///
23/// # Concurrency design
24///
25/// `SqzEngine` is designed for single-threaded use on the main thread.
26/// The only cross-thread sharing happens during preset hot-reload: the
27/// file-watcher callback runs on a background thread and needs to update
28/// the preset, pipeline, and model router. These three fields are wrapped
29/// in `Arc<Mutex<>>` specifically for that purpose. All other fields are
30/// owned directly — no unnecessary synchronization.
31pub struct SqzEngine {
32    // --- Hot-reloadable state (shared with file-watcher thread) ---
33    preset: Arc<Mutex<Preset>>,
34    pipeline: Arc<Mutex<CompressionPipeline>>,
35    model_router: Arc<Mutex<ModelRouter>>,
36
37    // --- Single-owner state (no cross-thread sharing needed) ---
38    session_store: SessionStore,
39    cache_manager: CacheManager,
40    budget_tracker: BudgetTracker,
41    cost_calculator: CostCalculator,
42    ast_parser: AstParser,
43    terse_mode: TerseMode,
44    pin_manager: PinManager,
45    confidence_router: ConfidenceRouter,
46    _plugin_loader: PluginLoader,
47}
48
49impl SqzEngine {
50    /// Create a new engine with the default preset and a persistent session store.
51    ///
52    /// Sessions are stored in `~/.sqz/sessions.db` for cross-session continuity.
53    /// Falls back to a temp-file store if the home directory is unavailable.
54    pub fn new() -> Result<Self> {
55        let preset = Preset::default();
56        let store_path = Self::default_store_path();
57        Self::with_preset_and_store(preset, &store_path)
58    }
59
60    /// Resolve the default session store path: `~/.sqz/sessions.db`.
61    /// Falls back to a temp-file path if home dir is unavailable.
62    fn default_store_path() -> std::path::PathBuf {
63        if let Some(home) = dirs_next::home_dir() {
64            let sqz_dir = home.join(".sqz");
65            if std::fs::create_dir_all(&sqz_dir).is_ok() {
66                // Harden permissions on Unix: ~/.sqz/ contains session data
67                // and cached content that may include sensitive output.
68                #[cfg(unix)]
69                {
70                    use std::os::unix::fs::PermissionsExt;
71                    let _ = std::fs::set_permissions(
72                        &sqz_dir,
73                        std::fs::Permissions::from_mode(0o700),
74                    );
75                }
76                return sqz_dir.join("sessions.db");
77            }
78        }
79        // Fallback: temp dir with unique name
80        let dir = std::env::temp_dir();
81        dir.join(format!(
82            "sqz_session_{}_{}.db",
83            std::process::id(),
84            std::time::SystemTime::now()
85                .duration_since(std::time::UNIX_EPOCH)
86                .map(|d| d.as_nanos())
87                .unwrap_or(0)
88        ))
89    }
90
91    /// Create with a custom preset and a file-backed session store.
92    ///
93    /// Opens a single SQLite connection for the session store. The cache
94    /// manager and pin manager share the same store via separate connections
95    /// (SQLite WAL mode supports concurrent readers).
96    pub fn with_preset_and_store(preset: Preset, store_path: &Path) -> Result<Self> {
97        let pipeline = CompressionPipeline::new(&preset);
98        let window_size = preset.budget.default_window_size;
99
100        // One connection per consumer. SQLite WAL mode handles concurrency.
101        let session_store = SessionStore::open_or_create(store_path)?;
102        let cache_store = SessionStore::open_or_create(store_path)?;
103        let pin_store = SessionStore::open_or_create(store_path)?;
104
105        Ok(SqzEngine {
106            preset: Arc::new(Mutex::new(preset.clone())),
107            pipeline: Arc::new(Mutex::new(pipeline)),
108            model_router: Arc::new(Mutex::new(ModelRouter::new(&preset))),
109            session_store,
110            cache_manager: CacheManager::new(cache_store, 512 * 1024 * 1024),
111            budget_tracker: BudgetTracker::new(window_size, &preset),
112            cost_calculator: CostCalculator::with_defaults(),
113            ast_parser: AstParser::new(),
114            terse_mode: TerseMode,
115            pin_manager: PinManager::new(pin_store),
116            confidence_router: ConfidenceRouter::new(),
117            _plugin_loader: PluginLoader::new(Path::new("plugins")),
118        })
119    }
120
121    /// Compress input text using the current preset.
122    ///
123    /// Two-pass pipeline:
124    /// 1. Route to compression mode based on content entropy and risk patterns.
125    /// 2. Compress using the pipeline (safe preset for Safe mode, default otherwise).
126    /// 3. Verify invariants (error lines, JSON keys, diff hunks, etc.).
127    /// 4. If verification confidence is low, fall back to safe mode and re-compress.
128    pub fn compress(&self, input: &str) -> Result<CompressedContent> {
129        let preset = self.preset.lock()
130            .map_err(|_| SqzError::Other("preset lock poisoned".into()))?;
131        let pipeline = self.pipeline.lock()
132            .map_err(|_| SqzError::Other("pipeline lock poisoned".into()))?;
133        let ctx = crate::pipeline::SessionContext {
134            session_id: "engine".to_string(),
135        };
136
137        // Step 1: Route — check content risk before compressing
138        let mode = self.confidence_router.route(input);
139
140        // Step 2: If Safe mode, skip aggressive pipeline and go straight to safe compress
141        if mode == crate::confidence_router::CompressionMode::Safe {
142            eprintln!("[sqz] fallback: safe mode — content classified as high-risk (stack trace / migration / secret)");
143            return self.compress_safe(input, &pipeline, &ctx);
144        }
145
146        // Step 3: Compress with the configured pipeline
147        let mut result = pipeline.compress(input, &ctx, &preset)?;
148
149        // Step 4: Verify invariants
150        let verify = Verifier::verify(input, &result.data);
151        let fallback = verify.fallback_triggered;
152        result.verify = Some(verify);
153
154        // Step 5: If verifier signals low confidence, re-compress with safe settings
155        if fallback && result.data != input {
156            eprintln!("[sqz] fallback: verifier confidence {:.2} below threshold — re-compressing in safe mode",
157                result.verify.as_ref().map(|v| v.confidence).unwrap_or(0.0));
158            let safe_result = self.compress_safe(input, &pipeline, &ctx)?;
159            return Ok(safe_result);
160        }
161
162        Ok(result)
163    }
164
165    /// Defensive compression: any input in, `CompressedContent` out, guaranteed.
166    ///
167    /// Unlike `compress()` which returns `Result`, this method never returns
168    /// an error. On any internal failure it returns the original input
169    /// unchanged with a 1.0 compression ratio. This makes it safe to call
170    /// from contexts where error handling is impractical (e.g. shell hooks,
171    /// browser extension bridges).
172    pub fn compress_or_passthrough(&self, input: &str) -> CompressedContent {
173        match self.compress(input) {
174            Ok(result) => result,
175            Err(_) => {
176                let tokens = (input.len() as u32 + 3) / 4;
177                CompressedContent {
178                    data: input.to_string(),
179                    tokens_compressed: tokens,
180                    tokens_original: tokens,
181                    stages_applied: vec![],
182                    compression_ratio: 1.0,
183                    provenance: crate::types::Provenance::default(),
184                    verify: None,
185                }
186            }
187        }
188    }
189
190    /// Compress with explicit mode override, bypassing the confidence router.
191    ///
192    /// - `CompressionMode::Safe` → safe pipeline only (ANSI strip + condense)
193    /// - `CompressionMode::Default` → standard pipeline
194    /// - `CompressionMode::Aggressive` → standard pipeline (aggressive preset TBD)
195    pub fn compress_with_mode(&self, input: &str, mode: crate::confidence_router::CompressionMode) -> Result<CompressedContent> {
196        let pipeline = self.pipeline.lock()
197            .map_err(|_| SqzError::Other("pipeline lock poisoned".into()))?;
198        let ctx = crate::pipeline::SessionContext {
199            session_id: "engine".to_string(),
200        };
201
202        match mode {
203            crate::confidence_router::CompressionMode::Safe => {
204                self.compress_safe(input, &pipeline, &ctx)
205            }
206            _ => {
207                // Default and Aggressive: run normal pipeline + verify
208                drop(pipeline); // release lock before calling compress()
209                self.compress(input)
210            }
211        }
212    }
213
214    /// Safe-mode compression: minimal transforms only (ANSI strip + condense).
215    fn compress_safe(
216        &self,
217        input: &str,
218        pipeline: &crate::pipeline::CompressionPipeline,
219        ctx: &crate::pipeline::SessionContext,
220    ) -> Result<CompressedContent> {
221        use crate::preset::{
222            CompressionConfig, CondenseConfig, CustomTransformsConfig, BudgetConfig,
223            ModelConfig, PresetMeta, TerseModeConfig, TerseLevel, ToolSelectionConfig,
224        };
225
226        let safe_preset = Preset {
227            preset: PresetMeta {
228                name: "safe".to_string(),
229                version: "1.0".to_string(),
230                description: "Safe fallback — minimal compression".to_string(),
231            },
232            compression: CompressionConfig {
233                stages: vec!["condense".to_string()],
234                keep_fields: None,
235                strip_fields: None,
236                condense: Some(CondenseConfig { enabled: true, max_repeated_lines: 3 }),
237                git_diff_fold: None,
238                strip_nulls: None,
239                flatten: None,
240                truncate_strings: None,
241                collapse_arrays: None,
242                custom_transforms: Some(CustomTransformsConfig { enabled: false }),
243            },
244            tool_selection: ToolSelectionConfig {
245                max_tools: 5,
246                similarity_threshold: 0.7,
247                default_tools: vec![],
248            },
249            budget: BudgetConfig {
250                warning_threshold: 0.70,
251                ceiling_threshold: 0.85,
252                default_window_size: 200_000,
253                agents: Default::default(),
254            },
255            terse_mode: TerseModeConfig { enabled: false, level: TerseLevel::Moderate },
256            model: ModelConfig {
257                family: "anthropic".to_string(),
258                primary: String::new(),
259                local: String::new(),
260                complexity_threshold: 0.4,
261                pricing: None,
262            },
263        };
264
265        let mut result = pipeline.compress(input, ctx, &safe_preset)?;
266        let verify = Verifier::verify(input, &result.data);
267        result.verify = Some(verify);
268        result.provenance = Provenance {
269            label: Some("safe-fallback".to_string()),
270            ..Default::default()
271        };
272        Ok(result)
273    }
274
275    /// Compress with explicit provenance metadata attached to the result.
276    pub fn compress_with_provenance(
277        &self,
278        input: &str,
279        provenance: Provenance,
280    ) -> Result<CompressedContent> {
281        let mut result = self.compress(input)?;
282        result.provenance = provenance;
283        Ok(result)
284    }
285
286    /// Export a session to CTX format.
287    pub fn export_ctx(&self, session_id: &str) -> Result<String> {
288        let session = self.session_store.load_session(session_id.to_string())?;
289        CtxFormat::serialize(&session)
290    }
291
292    /// Import a CTX string and save as a new session.
293    pub fn import_ctx(&self, ctx: &str) -> Result<SessionId> {
294        let session = CtxFormat::deserialize(ctx)?;
295        self.session_store.save_session(&session)
296    }
297
298    /// Pin a conversation turn.
299    pub fn pin(&self, session_id: &str, turn_index: usize, reason: &str, tokens: u32) -> Result<PinEntry> {
300        self.pin_manager.pin(session_id, turn_index, reason, tokens)
301    }
302
303    /// Unpin a conversation turn.
304    pub fn unpin(&self, session_id: &str, turn_index: usize) -> Result<()> {
305        self.pin_manager.unpin(session_id, turn_index)
306    }
307
308    /// Search sessions by keyword.
309    pub fn search_sessions(&self, query: &str) -> Result<Vec<SessionSummary>> {
310        self.session_store.search(query)
311    }
312
313    /// Get usage report for an agent.
314    pub fn usage_report(&self, agent_id: &str) -> UsageReport {
315        self.budget_tracker.usage_report(agent_id.to_string())
316    }
317
318    /// Get cost summary for a session.
319    pub fn cost_summary(&self, session_id: &str) -> Result<SessionCostSummary> {
320        let session = self.session_store.load_session(session_id.to_string())?;
321        Ok(self.cost_calculator.session_summary(&session))
322    }
323
324    /// Reload the preset from a TOML string (hot-reload support).
325    pub fn reload_preset(&mut self, toml: &str) -> Result<()> {
326        let new_preset = PresetParser::parse(toml)?;
327        if let Ok(mut pipeline) = self.pipeline.lock() {
328            pipeline.reload_preset(&new_preset)?;
329        }
330        if let Ok(mut router) = self.model_router.lock() {
331            *router = ModelRouter::new(&new_preset);
332        }
333        if let Ok(mut preset) = self.preset.lock() {
334            *preset = new_preset;
335        }
336        Ok(())
337    }
338
339    /// Spawn a background thread that watches `path` for preset file changes.
340    ///
341    /// Only the preset, pipeline, and model_router are shared with the watcher
342    /// thread (via `Arc<Mutex<>>`). All other engine state stays on the main thread.
343    pub fn watch_preset_file(&self, path: &Path) -> Result<notify::RecommendedWatcher> {
344        use notify::{Event, EventKind, RecursiveMode, Watcher};
345
346        let preset_arc = Arc::clone(&self.preset);
347        let pipeline_arc = Arc::clone(&self.pipeline);
348        let router_arc = Arc::clone(&self.model_router);
349        let watched_path = path.to_owned();
350
351        let mut watcher = notify::recommended_watcher(move |res: notify::Result<Event>| {
352            if let Ok(event) = res {
353                if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_)) {
354                    match std::fs::read_to_string(&watched_path) {
355                        Ok(toml_str) => match PresetParser::parse(&toml_str) {
356                            Ok(new_preset) => {
357                                if let Ok(mut p) = pipeline_arc.lock() {
358                                    let _ = p.reload_preset(&new_preset);
359                                }
360                                if let Ok(mut r) = router_arc.lock() {
361                                    *r = ModelRouter::new(&new_preset);
362                                }
363                                if let Ok(mut pr) = preset_arc.lock() {
364                                    *pr = new_preset;
365                                }
366                            }
367                            Err(e) => eprintln!("[sqz] invalid preset: {e}"),
368                        },
369                        Err(e) => eprintln!("[sqz] preset read error: {e}"),
370                    }
371                }
372            }
373        })
374        .map_err(|e| SqzError::Other(format!("watcher error: {e}")))?;
375
376        watcher
377            .watch(path, RecursiveMode::NonRecursive)
378            .map_err(|e| SqzError::Other(format!("watch error: {e}")))?;
379
380        Ok(watcher)
381    }
382
383    /// Access the underlying `SessionStore`.
384    pub fn session_store(&self) -> &SessionStore {
385        &self.session_store
386    }
387
388    /// Access the `CacheManager` for persistent dedup.
389    pub fn cache_manager(&self) -> &CacheManager {
390        &self.cache_manager
391    }
392
393    /// Access the `AstParser`.
394    pub fn ast_parser(&self) -> &AstParser {
395        &self.ast_parser
396    }
397
398    /// Access the `TerseMode` helper.
399    pub fn terse_mode(&self) -> &TerseMode {
400        &self.terse_mode
401    }
402
403    /// Reorder context sections using the LITM positioner to mitigate
404    /// the "Lost In The Middle" attention bias in long-context models.
405    ///
406    /// Places highest-priority sections at the beginning and end of the
407    /// context window, lowest-priority in the middle.
408    pub fn reorder_context(
409        &self,
410        sections: &mut Vec<crate::litm_positioner::ContextSection>,
411        strategy: crate::litm_positioner::LitmStrategy,
412    ) {
413        let positioner = crate::litm_positioner::LitmPositioner::new(strategy);
414        positioner.reorder(sections);
415    }
416
417    /// Route content to the appropriate compression mode based on entropy
418    /// and risk pattern analysis.
419    pub fn route_compression_mode(&self, content: &str) -> crate::confidence_router::CompressionMode {
420        self.confidence_router.route(content)
421    }
422}
423
424
#[cfg(test)]
mod tests {
    use super::*;
    use crate::confidence_router::CompressionMode;
    use crate::types::{BudgetState, CorrectionLog, ModelFamily, SessionState};
    use chrono::Utc;
    use std::path::PathBuf;

    /// Build an engine backed by a fresh store in a private temp directory.
    ///
    /// Tests use this instead of `SqzEngine::new()` so they never write to the
    /// developer's real `~/.sqz/sessions.db` and never share one database file
    /// with tests running in parallel. Keep the returned `TempDir` bound (e.g.
    /// `_dir`, not `_`) for as long as the engine is used: dropping it deletes
    /// the store directory.
    fn temp_engine() -> (tempfile::TempDir, SqzEngine) {
        let dir = tempfile::tempdir().expect("create temp dir");
        let store_path = dir.path().join("store.db");
        let engine = SqzEngine::with_preset_and_store(Preset::default(), &store_path)
            .expect("open temp-backed engine");
        (dir, engine)
    }

    fn make_session(id: &str) -> SessionState {
        let now = Utc::now();
        SessionState {
            id: id.to_string(),
            project_dir: PathBuf::from("/tmp/test"),
            conversation: vec![],
            corrections: CorrectionLog::default(),
            pins: vec![],
            learnings: vec![],
            compressed_summary: "test session".to_string(),
            budget: BudgetState {
                window_size: 200_000,
                consumed: 0,
                pinned: 0,
                model_family: ModelFamily::AnthropicClaude,
            },
            tool_usage: vec![],
            created_at: now,
            updated_at: now,
        }
    }

    /// Deliberately exercises the real default store path, since that is
    /// exactly what `SqzEngine::new()` is responsible for.
    #[test]
    fn test_engine_new() {
        let engine = SqzEngine::new();
        assert!(engine.is_ok(), "SqzEngine::new() should succeed");
    }

    #[test]
    fn test_compress_or_passthrough_returns_result_on_valid_input() {
        let (_dir, engine) = temp_engine();
        let result = engine.compress_or_passthrough("hello world");
        assert_eq!(result.data, "hello world");
        assert!(result.tokens_original > 0);
    }

    #[test]
    fn test_compress_or_passthrough_never_panics_on_empty() {
        let (_dir, engine) = temp_engine();
        let result = engine.compress_or_passthrough("");
        assert_eq!(result.data, "");
        assert_eq!(result.compression_ratio, 1.0);
    }

    #[test]
    fn test_compress_or_passthrough_handles_json() {
        let (_dir, engine) = temp_engine();
        let result = engine.compress_or_passthrough(r#"{"key":"value"}"#);
        // Should compress successfully — data may be TOON-encoded
        assert!(!result.data.is_empty());
    }

    #[test]
    fn test_compress_or_passthrough_handles_binary_garbage() {
        let (_dir, engine) = temp_engine();
        // Feed it something weird — should never panic, always return something
        let garbage = "\x00\x01\x02\x7f invalid control chars \t\n\r";
        let result = engine.compress_or_passthrough(garbage);
        assert!(!result.data.is_empty());
    }

    #[test]
    fn test_compress_plain_text() {
        let (_dir, engine) = temp_engine();
        let result = engine.compress("hello world");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().data, "hello world");
    }

    #[test]
    fn test_compress_json_applies_toon() {
        let (_dir, engine) = temp_engine();
        let result = engine.compress(r#"{"name":"Alice","age":30}"#).unwrap();
        assert!(result.data.starts_with("TOON:"), "JSON should be TOON-encoded");
    }

    #[test]
    fn test_compress_with_safe_mode_labels_provenance() {
        let (_dir, engine) = temp_engine();
        let result = engine
            .compress_with_mode("hello world", CompressionMode::Safe)
            .unwrap();
        assert_eq!(result.provenance.label.as_deref(), Some("safe-fallback"));
        assert!(result.verify.is_some(), "safe mode should still verify");
    }

    #[test]
    fn test_export_import_ctx_round_trip() {
        let (_dir, engine) = temp_engine();

        let session = make_session("sess-rt");
        engine.session_store().save_session(&session).unwrap();

        let ctx = engine.export_ctx("sess-rt").unwrap();
        let imported_id = engine.import_ctx(&ctx).unwrap();
        assert_eq!(imported_id, "sess-rt");
    }

    #[test]
    fn test_search_sessions() {
        let (_dir, engine) = temp_engine();

        let mut session = make_session("sess-search");
        session.compressed_summary = "authentication refactor".to_string();
        engine.session_store().save_session(&session).unwrap();

        let results = engine.search_sessions("authentication").unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "sess-search");
    }

    #[test]
    fn test_usage_report_starts_at_zero() {
        let (_dir, engine) = temp_engine();
        let report = engine.usage_report("default");
        assert_eq!(report.consumed, 0);
        assert_eq!(report.available, report.allocated);
    }

    #[test]
    fn test_cost_summary() {
        let (_dir, engine) = temp_engine();

        let session = make_session("sess-cost");
        engine.session_store().save_session(&session).unwrap();

        let summary = engine.cost_summary("sess-cost").unwrap();
        assert_eq!(summary.total_tokens, 0);
        assert!((summary.total_usd - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_reload_preset_updates_state() {
        let (_dir, mut engine) = temp_engine();
        let toml = r#"
[preset]
name = "reloaded"
version = "2.0"

[compression]
stages = []

[tool_selection]
max_tools = 5
similarity_threshold = 0.7

[budget]
warning_threshold = 0.70
ceiling_threshold = 0.85
default_window_size = 200000

[terse_mode]
enabled = false
level = "moderate"

[model]
family = "anthropic"
primary = "claude-sonnet-4-20250514"
complexity_threshold = 0.4
"#;
        assert!(engine.reload_preset(toml).is_ok());
        // Verify the preset was actually updated
        let preset = engine.preset.lock().unwrap();
        assert_eq!(preset.preset.name, "reloaded");
    }

    #[test]
    fn test_reload_invalid_preset_returns_error() {
        let (_dir, mut engine) = temp_engine();
        let before = engine.preset.lock().unwrap().preset.name.clone();
        let result = engine.reload_preset("not valid toml [[[");
        assert!(result.is_err(), "invalid TOML should return error");
        // A failed reload must leave the previous preset in place.
        assert_eq!(engine.preset.lock().unwrap().preset.name, before);
    }

    #[test]
    fn test_export_nonexistent_session_returns_error() {
        let (_dir, engine) = temp_engine();
        let result = engine.export_ctx("does-not-exist");
        assert!(result.is_err());
    }

    #[test]
    fn test_import_invalid_ctx_returns_error() {
        let (_dir, engine) = temp_engine();
        let result = engine.import_ctx("not valid json {{{");
        assert!(result.is_err());
    }
}