Skip to main content

haagenti_serverless/
state.rs

1//! Function state management for serverless hibernation
2
3use crate::{Result, ServerlessError};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::time::Instant;
7
8/// Function state for hibernation/resume
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct FunctionState {
11    /// State version
12    pub version: u32,
13    /// Function name
14    pub function_name: String,
15    /// Created timestamp (unix ms)
16    pub created_at: u64,
17    /// Last modified timestamp (unix ms)
18    pub modified_at: u64,
19    /// Model state
20    pub model_state: ModelState,
21    /// Cache state
22    pub cache_state: CacheState,
23    /// Execution state
24    pub execution_state: ExecutionState,
25    /// Custom metadata
26    pub metadata: HashMap<String, String>,
27}
28
29/// Model state
30#[derive(Debug, Clone, Default, Serialize, Deserialize)]
31pub struct ModelState {
32    /// Model name
33    pub model_name: String,
34    /// Model version
35    pub model_version: String,
36    /// Loaded layers
37    pub loaded_layers: Vec<String>,
38    /// Weights hash
39    pub weights_hash: String,
40    /// Quantization info
41    pub quantization: Option<QuantizationInfo>,
42}
43
44/// Quantization info
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct QuantizationInfo {
47    /// Quantization type
48    pub qtype: String,
49    /// Bits per weight
50    pub bits: u8,
51    /// Group size
52    pub group_size: usize,
53}
54
55/// Cache state
56#[derive(Debug, Clone, Default, Serialize, Deserialize)]
57pub struct CacheState {
58    /// KV cache size
59    pub kv_cache_size: u64,
60    /// KV cache entries
61    pub kv_entries: usize,
62    /// Fragment cache size
63    pub fragment_cache_size: u64,
64    /// Fragment entries
65    pub fragment_entries: usize,
66    /// Cache hit rate
67    pub hit_rate: f64,
68}
69
70/// Execution state
71#[derive(Debug, Clone, Default, Serialize, Deserialize)]
72pub struct ExecutionState {
73    /// Total requests processed
74    pub total_requests: u64,
75    /// Successful requests
76    pub successful_requests: u64,
77    /// Failed requests
78    pub failed_requests: u64,
79    /// Average latency ms
80    pub avg_latency_ms: f64,
81    /// Last request timestamp
82    pub last_request_at: Option<u64>,
83}
84
85impl FunctionState {
86    /// Create new function state
87    pub fn new(function_name: impl Into<String>) -> Self {
88        let now = std::time::SystemTime::now()
89            .duration_since(std::time::UNIX_EPOCH)
90            .unwrap_or_default()
91            .as_millis() as u64;
92
93        Self {
94            version: 1,
95            function_name: function_name.into(),
96            created_at: now,
97            modified_at: now,
98            model_state: ModelState::default(),
99            cache_state: CacheState::default(),
100            execution_state: ExecutionState::default(),
101            metadata: HashMap::new(),
102        }
103    }
104
105    /// Update modified timestamp
106    pub fn touch(&mut self) {
107        self.modified_at = std::time::SystemTime::now()
108            .duration_since(std::time::UNIX_EPOCH)
109            .unwrap_or_default()
110            .as_millis() as u64;
111    }
112
113    /// Set metadata
114    pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
115        self.metadata.insert(key.into(), value.into());
116        self.touch();
117    }
118
119    /// Get metadata
120    pub fn get_metadata(&self, key: &str) -> Option<&String> {
121        self.metadata.get(key)
122    }
123
124    /// Record successful request
125    pub fn record_success(&mut self, latency_ms: f64) {
126        self.execution_state.total_requests += 1;
127        self.execution_state.successful_requests += 1;
128
129        let count = self.execution_state.successful_requests as f64;
130        self.execution_state.avg_latency_ms =
131            (self.execution_state.avg_latency_ms * (count - 1.0) + latency_ms) / count;
132
133        self.execution_state.last_request_at = Some(
134            std::time::SystemTime::now()
135                .duration_since(std::time::UNIX_EPOCH)
136                .unwrap_or_default()
137                .as_millis() as u64,
138        );
139
140        self.touch();
141    }
142
143    /// Record failed request
144    pub fn record_failure(&mut self) {
145        self.execution_state.total_requests += 1;
146        self.execution_state.failed_requests += 1;
147        self.touch();
148    }
149
150    /// Success rate
151    pub fn success_rate(&self) -> f64 {
152        if self.execution_state.total_requests == 0 {
153            1.0
154        } else {
155            self.execution_state.successful_requests as f64
156                / self.execution_state.total_requests as f64
157        }
158    }
159}
160
/// Encodes and decodes `FunctionState` snapshots.
///
/// Holds the compression settings requested by the caller. These are plain
/// configuration values, so `Clone`, `PartialEq` and `Eq` are derived to let
/// configurations be copied and compared.
///
/// NOTE: compression is still a placeholder. The current `serialize` and
/// `deserialize` produce and expect plain JSON regardless of these settings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StateSerializer {
    /// Whether compressed output was requested.
    compression: bool,
    /// Requested compression level.
    compression_level: i32,
}
169
170impl StateSerializer {
171    /// Create new serializer
172    pub fn new(compression: bool) -> Self {
173        Self {
174            compression,
175            compression_level: 3,
176        }
177    }
178
179    /// Serialize state to bytes
180    pub fn serialize(&self, state: &FunctionState) -> Result<Vec<u8>> {
181        let json = serde_json::to_vec(state)
182            .map_err(|e| ServerlessError::SerializationError(e.to_string()))?;
183
184        if self.compression {
185            // In real implementation, use zstd compression
186            Ok(json)
187        } else {
188            Ok(json)
189        }
190    }
191
192    /// Deserialize state from bytes
193    pub fn deserialize(&self, data: &[u8]) -> Result<FunctionState> {
194        let json_data = if self.compression {
195            // In real implementation, decompress with zstd
196            data.to_vec()
197        } else {
198            data.to_vec()
199        };
200
201        serde_json::from_slice(&json_data)
202            .map_err(|e| ServerlessError::DeserializationError(e.to_string()))
203    }
204
205    /// Set compression level
206    pub fn with_compression_level(mut self, level: i32) -> Self {
207        self.compression_level = level;
208        self
209    }
210}
211
212impl Default for StateSerializer {
213    fn default() -> Self {
214        Self::new(true)
215    }
216}
217
218/// State manager for persisting function state
219#[derive(Debug)]
220pub struct StateManager {
221    /// Current state
222    state: FunctionState,
223    /// Serializer
224    serializer: StateSerializer,
225    /// Auto-save interval (seconds)
226    auto_save_interval: u64,
227    /// Last save time
228    last_save: Instant,
229    /// State changed since last save
230    dirty: bool,
231}
232
233impl StateManager {
234    /// Create new state manager
235    pub fn new(function_name: impl Into<String>) -> Self {
236        Self {
237            state: FunctionState::new(function_name),
238            serializer: StateSerializer::default(),
239            auto_save_interval: 60,
240            last_save: Instant::now(),
241            dirty: false,
242        }
243    }
244
245    /// Load state from bytes
246    pub fn load(data: &[u8]) -> Result<Self> {
247        let serializer = StateSerializer::default();
248        let state = serializer.deserialize(data)?;
249
250        Ok(Self {
251            state,
252            serializer,
253            auto_save_interval: 60,
254            last_save: Instant::now(),
255            dirty: false,
256        })
257    }
258
259    /// Get current state
260    pub fn state(&self) -> &FunctionState {
261        &self.state
262    }
263
264    /// Get mutable state
265    pub fn state_mut(&mut self) -> &mut FunctionState {
266        self.dirty = true;
267        &mut self.state
268    }
269
270    /// Save state to bytes
271    pub fn save(&mut self) -> Result<Vec<u8>> {
272        let data = self.serializer.serialize(&self.state)?;
273        self.last_save = Instant::now();
274        self.dirty = false;
275        Ok(data)
276    }
277
278    /// Check if should auto-save
279    pub fn should_auto_save(&self) -> bool {
280        self.dirty && self.last_save.elapsed().as_secs() >= self.auto_save_interval
281    }
282
283    /// Set auto-save interval
284    pub fn set_auto_save_interval(&mut self, seconds: u64) {
285        self.auto_save_interval = seconds;
286    }
287
288    /// Update model state
289    pub fn update_model(&mut self, model_state: ModelState) {
290        self.state.model_state = model_state;
291        self.state.touch();
292        self.dirty = true;
293    }
294
295    /// Update cache state
296    pub fn update_cache(&mut self, cache_state: CacheState) {
297        self.state.cache_state = cache_state;
298        self.state.touch();
299        self.dirty = true;
300    }
301
302    /// Record request
303    pub fn record_request(&mut self, success: bool, latency_ms: Option<f64>) {
304        if success {
305            self.state.record_success(latency_ms.unwrap_or(0.0));
306        } else {
307            self.state.record_failure();
308        }
309        self.dirty = true;
310    }
311}
312
313/// State diff for incremental updates
314#[derive(Debug, Clone, Serialize, Deserialize)]
315pub struct StateDiff {
316    /// Changed fields
317    pub changes: HashMap<String, serde_json::Value>,
318    /// Timestamp
319    pub timestamp: u64,
320}
321
322impl StateDiff {
323    /// Create new diff
324    pub fn new() -> Self {
325        Self {
326            changes: HashMap::new(),
327            timestamp: std::time::SystemTime::now()
328                .duration_since(std::time::UNIX_EPOCH)
329                .unwrap_or_default()
330                .as_millis() as u64,
331        }
332    }
333
334    /// Add change
335    pub fn add(&mut self, field: impl Into<String>, value: serde_json::Value) {
336        self.changes.insert(field.into(), value);
337    }
338
339    /// Is empty
340    pub fn is_empty(&self) -> bool {
341        self.changes.is_empty()
342    }
343}
344
345impl Default for StateDiff {
346    fn default() -> Self {
347        Self::new()
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_function_state_creation() {
        let state = FunctionState::new("test-function");

        assert_eq!(state.function_name, "test-function");
        assert_eq!(state.version, 1);
        assert!(state.created_at > 0);
        // Both timestamps come from the same clock reading at construction.
        assert_eq!(state.created_at, state.modified_at);
        assert!(state.metadata.is_empty());
        assert!(state.execution_state.last_request_at.is_none());
    }

    #[test]
    fn test_metadata() {
        let mut state = FunctionState::new("test");

        assert!(state.get_metadata("region").is_none());
        state.set_metadata("region", "eu-west-1");
        state.set_metadata("region", "us-east-1");

        assert_eq!(
            state.get_metadata("region").map(String::as_str),
            Some("us-east-1")
        );
        assert_eq!(state.metadata.len(), 1);
    }

    #[test]
    fn test_record_requests() {
        let mut state = FunctionState::new("test");

        state.record_success(100.0);
        state.record_success(200.0);
        state.record_failure();

        assert_eq!(state.execution_state.total_requests, 3);
        assert_eq!(state.execution_state.successful_requests, 2);
        assert_eq!(state.execution_state.failed_requests, 1);
        // Only successes contribute to the latency average.
        assert_eq!(state.execution_state.avg_latency_ms, 150.0);
    }

    #[test]
    fn test_success_rate() {
        // With no requests recorded, the rate is defined as 1.0.
        assert_eq!(FunctionState::new("idle").success_rate(), 1.0);

        let mut state = FunctionState::new("test");
        state.record_success(100.0);
        state.record_success(100.0);
        state.record_failure();

        assert!((state.success_rate() - 2.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_serialization() {
        let mut state = FunctionState::new("test");
        state.set_metadata("region", "eu-west-1");
        state.record_success(42.0);
        state.model_state.quantization = Some(QuantizationInfo {
            qtype: "q4".to_string(),
            bits: 4,
            group_size: 32,
        });

        // Round-trip through both serializer configurations.
        for &compression in &[false, true] {
            let serializer = StateSerializer::new(compression);
            let data = serializer.serialize(&state).unwrap();
            let restored = serializer.deserialize(&data).unwrap();

            assert_eq!(restored.function_name, state.function_name);
            assert_eq!(restored.created_at, state.created_at);
            assert_eq!(restored.modified_at, state.modified_at);
            assert_eq!(restored.execution_state.successful_requests, 1);
            assert_eq!(restored.execution_state.avg_latency_ms, 42.0);
            assert_eq!(
                restored.get_metadata("region").map(String::as_str),
                Some("eu-west-1")
            );
            let quant = restored
                .model_state
                .quantization
                .expect("quantization info should survive a round-trip");
            assert_eq!(quant.qtype, "q4");
            assert_eq!(quant.bits, 4);
            assert_eq!(quant.group_size, 32);
        }
    }

    #[test]
    fn test_deserialize_rejects_invalid_input() {
        let serializer = StateSerializer::default();

        assert!(serializer.deserialize(b"not json").is_err());
        assert!(serializer.deserialize(b"").is_err());
    }

    #[test]
    fn test_state_manager() {
        let mut manager = StateManager::new("test");
        assert!(!manager.should_auto_save());

        manager.record_request(true, Some(50.0));
        manager.record_request(true, Some(100.0));

        let state = manager.state();
        assert_eq!(state.execution_state.successful_requests, 2);
        assert_eq!(state.execution_state.avg_latency_ms, 75.0);
    }

    #[test]
    fn test_state_manager_save_load_roundtrip() {
        let mut manager = StateManager::new("persisted");
        manager.record_request(true, Some(10.0));
        manager.record_request(false, None);

        // A zero interval makes any dirty state eligible for auto-save.
        manager.set_auto_save_interval(0);
        assert!(manager.should_auto_save());

        let bytes = manager.save().unwrap();
        assert!(!manager.should_auto_save());

        let restored = StateManager::load(&bytes).unwrap();
        let exec = &restored.state().execution_state;
        assert_eq!(restored.state().function_name, "persisted");
        assert_eq!(exec.total_requests, 2);
        assert_eq!(exec.successful_requests, 1);
        assert_eq!(exec.failed_requests, 1);
        assert_eq!(exec.avg_latency_ms, 10.0);
        assert!(!restored.should_auto_save());
    }

    #[test]
    fn test_state_diff() {
        assert!(StateDiff::default().is_empty());

        let mut diff = StateDiff::new();

        diff.add("field1", serde_json::json!(42));
        diff.add("field2", serde_json::json!("value"));

        assert!(!diff.is_empty());
        assert_eq!(diff.changes.len(), 2);

        // Re-adding a field replaces its value rather than adding an entry.
        diff.add("field1", serde_json::json!(7));
        assert_eq!(diff.changes.len(), 2);
        assert_eq!(diff.changes["field1"], serde_json::json!(7));
    }
}