// latch_core/config.rs

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::TokenEstimator;

5#[derive(Clone, Debug, Serialize, Deserialize, Default)]
6pub struct LatchConfig {
7    pub compression: Option<CompressionConfig>,
8    pub cache: Option<CacheConfig>,
9    pub router: Option<RouterConfig>,
10    pub retry: Option<RetryConfig>,
11    pub meter: Option<MeterConfig>,
12}
13
14#[derive(Clone, Serialize, Deserialize)]
15pub struct CompressionConfig {
16    /// Below this token count, no compression is applied
17    pub min_tokens_to_compress: usize,
18    /// SlidingWindow: maximum number of turns to keep
19    pub max_turns: usize,
20    /// SlidingWindow: maximum system prompt tokens (0 = no limit)
21    pub max_system_tokens: usize,
22    /// Compression strategy to use
23    pub strategy: CompressionStrategy,
24    /// Dedup: maximum merged chars for adjacent same-role messages (0 = no limit)
25    pub dedup_max_merged_chars: usize,
26    /// Injectable token estimator (skipped during serialization)
27    #[serde(skip, default = "super::default_token_estimator")]
28    pub token_estimator: TokenEstimator,
29}
30
31impl std::fmt::Debug for CompressionConfig {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("CompressionConfig")
34            .field("min_tokens_to_compress", &self.min_tokens_to_compress)
35            .field("max_turns", &self.max_turns)
36            .field("max_system_tokens", &self.max_system_tokens)
37            .field("strategy", &self.strategy)
38            .field("dedup_max_merged_chars", &self.dedup_max_merged_chars)
39            .field("token_estimator", &"<TokenEstimator>")
40            .finish()
41    }
42}
43
44#[derive(Clone, Debug, Serialize, Deserialize)]
45pub enum CompressionStrategy {
46    None,
47    SlidingWindow,
48    Dedup,
49    DedupThenWindow,
50}
51
52#[derive(Clone, Debug, Serialize, Deserialize)]
53pub struct CacheConfig {
54    pub prompt_caching: PromptCachingConfig,
55}
56
57#[derive(Clone, Debug, Serialize, Deserialize)]
58#[serde(default = "PromptCachingConfig::default")]
59pub struct PromptCachingConfig {
60    pub enabled: bool,
61    pub provider: PromptCacheProvider,
62    /// Roles that should receive cache markers (default: ["system"])
63    pub cache_roles: Vec<String>,
64    /// Minimum content length for caching to be meaningful (default: 0)
65    pub min_content_chars: usize,
66}
67
68impl Default for PromptCachingConfig {
69    fn default() -> Self {
70        Self {
71            enabled: false,
72            provider: PromptCacheProvider::default(), // None
73            cache_roles: vec!["system".to_string()],
74            min_content_chars: 0,
75        }
76    }
77}
78
79#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
80pub enum PromptCacheProvider {
81    #[default]
82    None,
83    Anthropic,
84    OpenAiCompatible,
85}
86
87#[derive(Clone, Serialize, Deserialize)]
88pub struct RouterConfig {
89    /// All available pool definitions
90    pub pools: Vec<PoolRoute>,
91    /// Return Uncertain if confidence is below this threshold
92    pub confidence_threshold: f32,
93    /// Requests with more tokens than this get Premium preference
94    pub long_request_tokens: usize,
95    /// Penalty multiplier per failure: 1 failure → 0.8, 2 → 0.64, etc.
96    #[serde(default = "default_penalty_per_failure")]
97    pub penalty_per_failure: f32,
98    /// Injectable token estimator (skipped during serialization)
99    #[serde(skip, default = "super::default_token_estimator")]
100    pub token_estimator: TokenEstimator,
101}
102
/// Serde default for `RouterConfig::penalty_per_failure`: a multiplier of 0.8
/// per failure.
fn default_penalty_per_failure() -> f32 {
    const PENALTY_PER_FAILURE: f32 = 0.8;
    PENALTY_PER_FAILURE
}

107impl std::fmt::Debug for RouterConfig {
108    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109        f.debug_struct("RouterConfig")
110            .field("pools", &self.pools)
111            .field("confidence_threshold", &self.confidence_threshold)
112            .field("long_request_tokens", &self.long_request_tokens)
113            .field("token_estimator", &"<TokenEstimator>")
114            .finish()
115    }
116}
117
118#[derive(Clone, Debug, Serialize, Deserialize)]
119pub struct PoolRoute {
120    /// Pool identifier, corresponds to ugate's pool_id
121    pub pool_id: String,
122    /// Pool tier (affects scoring)
123    pub tier: PoolTier,
124    /// Priority weight when matching (0.0-1.0)
125    pub weight: f32,
126    /// Keywords that increase this pool's score when matched
127    pub match_keywords: Vec<String>,
128    /// Score bonus per matched keyword
129    pub keyword_score: f32,
130    /// Whether this pool is suitable for images (None = no restriction)
131    pub images: Option<bool>,
132}
133
134#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
135pub enum PoolTier {
136    Fast,
137    Standard,
138    Premium,
139    Backup,
140}
141
142#[derive(Clone, Debug, Serialize, Deserialize)]
143pub struct PoolFeedback {
144    /// Number of recent failures for each pool
145    #[serde(default)]
146    pub recent_failures: std::collections::HashMap<String, u32>,
147    /// Endpoint scores from latch-score (endpoint_id -> score in 0.0..=100.0)
148    #[serde(default)]
149    pub endpoint_scores: std::collections::HashMap<String, f64>,
150}
151
152impl Default for PoolFeedback {
153    fn default() -> Self {
154        PoolFeedback {
155            recent_failures: std::collections::HashMap::new(),
156            endpoint_scores: std::collections::HashMap::new(),
157        }
158    }
159}
160
161#[derive(Clone, Debug, Serialize, Deserialize)]
162pub struct RetryConfig {
163    pub max_attempts: usize,
164    pub backoff_ms: u64,
165    pub max_backoff_ms: Option<u64>,
166    pub fallback_provider: Option<String>,
167    pub circuit_breaker: Option<CircuitBreakerConfig>,
168    /// Jitter range ratio 0.0-1.0. 0.2 means random factor in [0.8, 1.2]
169    #[serde(default = "default_jitter_ratio")]
170    pub jitter_ratio: f64,
171}
172
/// Serde default for `RetryConfig::jitter_ratio`: a random factor in
/// [0.8, 1.2], i.e. ±20%.
fn default_jitter_ratio() -> f64 {
    const JITTER_RATIO: f64 = 0.2;
    JITTER_RATIO
}

177#[derive(Clone, Debug, Serialize, Deserialize)]
178pub struct CircuitBreakerConfig {
179    pub failure_threshold: usize,
180    pub open_ms: u64,
181    pub half_open_max_attempts: usize,
182}
183
184#[derive(Clone, Debug, Serialize, Deserialize)]
185pub struct MeterConfig {
186    pub session_token_limit: Option<u64>,
187    pub session_request_limit: Option<u64>,
188    pub price_per_1k_input_tokens: f64,
189    pub price_per_1k_output_tokens: f64,
190    pub currency: String,
191}
192