1use crate::{Result, ServerlessError};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::time::Instant;
7
/// Persisted state of a single serverless function: identity, timestamps,
/// model/cache/execution sub-states, and free-form metadata.
///
/// `StateSerializer` encodes this type with `serde_json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionState {
    /// Initialized to 1 by `FunctionState::new`; nothing shown here changes it.
    /// NOTE(review): presumably a state/schema version — confirm.
    pub version: u32,
    /// Name given to `FunctionState::new`.
    pub function_name: String,
    /// Creation time in milliseconds since the Unix epoch.
    pub created_at: u64,
    /// Last-modification time in milliseconds since the Unix epoch; refreshed
    /// by `FunctionState::touch` and the mutating helpers that call it.
    pub modified_at: u64,
    /// Description of the model associated with the function.
    pub model_state: ModelState,
    /// Snapshot of cache statistics.
    pub cache_state: CacheState,
    /// Request counters and latency statistics.
    pub execution_state: ExecutionState,
    /// Free-form string key/value pairs (see `set_metadata` / `get_metadata`).
    pub metadata: HashMap<String, String>,
}
28
/// Describes the model associated with a function.
///
/// The derived `Default` yields empty strings, no layers and no quantization.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelState {
    /// Model identifier.
    pub model_name: String,
    /// Model version string; its format is not constrained here.
    pub model_version: String,
    /// Layer identifiers. NOTE(review): presumably the layers currently
    /// loaded — confirm.
    pub loaded_layers: Vec<String>,
    /// Hash of the model weights. NOTE(review): hash algorithm and encoding
    /// are not specified in this file.
    pub weights_hash: String,
    /// Quantization parameters, if any. NOTE(review): `None` presumably means
    /// the model is unquantized — confirm.
    pub quantization: Option<QuantizationInfo>,
}
43
/// Quantization parameters attached to a `ModelState`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationInfo {
    /// Quantization scheme label. NOTE(review): the accepted values are not
    /// defined in this file.
    pub qtype: String,
    /// Bit width. NOTE(review): presumably bits per quantized weight — confirm.
    pub bits: u8,
    /// Group size. NOTE(review): presumably the number of weights sharing one
    /// set of quantization parameters — confirm.
    pub group_size: usize,
}
54
/// Snapshot of cache statistics.
///
/// Nothing shown here computes these values; they are replaced wholesale via
/// `StateManager::update_cache`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheState {
    /// Size of the KV cache. NOTE(review): unit (bytes?) is not established here.
    pub kv_cache_size: u64,
    /// Number of KV cache entries.
    pub kv_entries: usize,
    /// Size of the fragment cache. NOTE(review): unit (bytes?) is not
    /// established here.
    pub fragment_cache_size: u64,
    /// Number of fragment cache entries.
    pub fragment_entries: usize,
    /// Cache hit rate. NOTE(review): presumably a fraction in [0, 1] — confirm.
    pub hit_rate: f64,
}
69
/// Request counters and latency statistics, maintained by
/// `FunctionState::record_success` and `FunctionState::record_failure`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ExecutionState {
    /// All recorded requests (successes plus failures).
    pub total_requests: u64,
    /// Requests recorded via `record_success`.
    pub successful_requests: u64,
    /// Requests recorded via `record_failure`.
    pub failed_requests: u64,
    /// Running mean, in milliseconds, of the latencies passed to
    /// `record_success`; failures do not contribute.
    pub avg_latency_ms: f64,
    /// Milliseconds since the Unix epoch, stamped by `record_success`;
    /// `None` until it is first set.
    pub last_request_at: Option<u64>,
}
84
85impl FunctionState {
86 pub fn new(function_name: impl Into<String>) -> Self {
88 let now = std::time::SystemTime::now()
89 .duration_since(std::time::UNIX_EPOCH)
90 .unwrap_or_default()
91 .as_millis() as u64;
92
93 Self {
94 version: 1,
95 function_name: function_name.into(),
96 created_at: now,
97 modified_at: now,
98 model_state: ModelState::default(),
99 cache_state: CacheState::default(),
100 execution_state: ExecutionState::default(),
101 metadata: HashMap::new(),
102 }
103 }
104
105 pub fn touch(&mut self) {
107 self.modified_at = std::time::SystemTime::now()
108 .duration_since(std::time::UNIX_EPOCH)
109 .unwrap_or_default()
110 .as_millis() as u64;
111 }
112
113 pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
115 self.metadata.insert(key.into(), value.into());
116 self.touch();
117 }
118
119 pub fn get_metadata(&self, key: &str) -> Option<&String> {
121 self.metadata.get(key)
122 }
123
124 pub fn record_success(&mut self, latency_ms: f64) {
126 self.execution_state.total_requests += 1;
127 self.execution_state.successful_requests += 1;
128
129 let count = self.execution_state.successful_requests as f64;
130 self.execution_state.avg_latency_ms =
131 (self.execution_state.avg_latency_ms * (count - 1.0) + latency_ms) / count;
132
133 self.execution_state.last_request_at = Some(
134 std::time::SystemTime::now()
135 .duration_since(std::time::UNIX_EPOCH)
136 .unwrap_or_default()
137 .as_millis() as u64,
138 );
139
140 self.touch();
141 }
142
143 pub fn record_failure(&mut self) {
145 self.execution_state.total_requests += 1;
146 self.execution_state.failed_requests += 1;
147 self.touch();
148 }
149
150 pub fn success_rate(&self) -> f64 {
152 if self.execution_state.total_requests == 0 {
153 1.0
154 } else {
155 self.execution_state.successful_requests as f64
156 / self.execution_state.total_requests as f64
157 }
158 }
159}
160
/// Encodes and decodes `FunctionState` as JSON bytes.
///
/// NOTE(review): `compression` and `compression_level` are stored, but the
/// `serialize`/`deserialize` methods shown here emit and accept plain JSON
/// regardless — no codec is applied yet.
#[derive(Debug)]
pub struct StateSerializer {
    /// Whether compression was requested; `StateSerializer::default()` sets `true`.
    compression: bool,
    /// Requested compression level; `StateSerializer::new` sets 3.
    /// NOTE(review): codec and valid range are not specified here.
    compression_level: i32,
}
169
170impl StateSerializer {
171 pub fn new(compression: bool) -> Self {
173 Self {
174 compression,
175 compression_level: 3,
176 }
177 }
178
179 pub fn serialize(&self, state: &FunctionState) -> Result<Vec<u8>> {
181 let json = serde_json::to_vec(state)
182 .map_err(|e| ServerlessError::SerializationError(e.to_string()))?;
183
184 if self.compression {
185 Ok(json)
187 } else {
188 Ok(json)
189 }
190 }
191
192 pub fn deserialize(&self, data: &[u8]) -> Result<FunctionState> {
194 let json_data = if self.compression {
195 data.to_vec()
197 } else {
198 data.to_vec()
199 };
200
201 serde_json::from_slice(&json_data)
202 .map_err(|e| ServerlessError::DeserializationError(e.to_string()))
203 }
204
205 pub fn with_compression_level(mut self, level: i32) -> Self {
207 self.compression_level = level;
208 self
209 }
210}
211
212impl Default for StateSerializer {
213 fn default() -> Self {
214 Self::new(true)
215 }
216}
217
/// Owns a `FunctionState`, tracks whether it has unsaved changes, and reports
/// when an auto-save is due.
#[derive(Debug)]
pub struct StateManager {
    /// The managed state.
    state: FunctionState,
    /// Serializer used by `save`; the constructors shown use
    /// `StateSerializer::default()`.
    serializer: StateSerializer,
    /// Whole seconds that must elapse since `last_save` before
    /// `should_auto_save` can report true (60 by default).
    auto_save_interval: u64,
    /// When the manager was constructed or last saved successfully.
    last_save: Instant,
    /// Set when the state may have changed since the last save; cleared by `save`.
    dirty: bool,
}
232
233impl StateManager {
234 pub fn new(function_name: impl Into<String>) -> Self {
236 Self {
237 state: FunctionState::new(function_name),
238 serializer: StateSerializer::default(),
239 auto_save_interval: 60,
240 last_save: Instant::now(),
241 dirty: false,
242 }
243 }
244
245 pub fn load(data: &[u8]) -> Result<Self> {
247 let serializer = StateSerializer::default();
248 let state = serializer.deserialize(data)?;
249
250 Ok(Self {
251 state,
252 serializer,
253 auto_save_interval: 60,
254 last_save: Instant::now(),
255 dirty: false,
256 })
257 }
258
259 pub fn state(&self) -> &FunctionState {
261 &self.state
262 }
263
264 pub fn state_mut(&mut self) -> &mut FunctionState {
266 self.dirty = true;
267 &mut self.state
268 }
269
270 pub fn save(&mut self) -> Result<Vec<u8>> {
272 let data = self.serializer.serialize(&self.state)?;
273 self.last_save = Instant::now();
274 self.dirty = false;
275 Ok(data)
276 }
277
278 pub fn should_auto_save(&self) -> bool {
280 self.dirty && self.last_save.elapsed().as_secs() >= self.auto_save_interval
281 }
282
283 pub fn set_auto_save_interval(&mut self, seconds: u64) {
285 self.auto_save_interval = seconds;
286 }
287
288 pub fn update_model(&mut self, model_state: ModelState) {
290 self.state.model_state = model_state;
291 self.state.touch();
292 self.dirty = true;
293 }
294
295 pub fn update_cache(&mut self, cache_state: CacheState) {
297 self.state.cache_state = cache_state;
298 self.state.touch();
299 self.dirty = true;
300 }
301
302 pub fn record_request(&mut self, success: bool, latency_ms: Option<f64>) {
304 if success {
305 self.state.record_success(latency_ms.unwrap_or(0.0));
306 } else {
307 self.state.record_failure();
308 }
309 self.dirty = true;
310 }
311}
312
/// A set of field-level changes plus the time the diff was created.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateDiff {
    /// Field name → new JSON value; a later `add` for the same name replaces
    /// the earlier value. NOTE(review): the naming scheme for fields (e.g.
    /// dotted paths into `FunctionState`) is not defined here.
    pub changes: HashMap<String, serde_json::Value>,
    /// Creation time in milliseconds since the Unix epoch (set by `StateDiff::new`).
    pub timestamp: u64,
}
321
322impl StateDiff {
323 pub fn new() -> Self {
325 Self {
326 changes: HashMap::new(),
327 timestamp: std::time::SystemTime::now()
328 .duration_since(std::time::UNIX_EPOCH)
329 .unwrap_or_default()
330 .as_millis() as u64,
331 }
332 }
333
334 pub fn add(&mut self, field: impl Into<String>, value: serde_json::Value) {
336 self.changes.insert(field.into(), value);
337 }
338
339 pub fn is_empty(&self) -> bool {
341 self.changes.is_empty()
342 }
343}
344
345impl Default for StateDiff {
346 fn default() -> Self {
347 Self::new()
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_function_state_creation() {
        let state = FunctionState::new("test-function");

        assert_eq!(state.function_name, "test-function");
        assert_eq!(state.version, 1);
        assert!(state.created_at > 0);
        // `new` stamps both timestamps with the same instant.
        assert_eq!(state.created_at, state.modified_at);
        assert!(state.metadata.is_empty());
    }

    #[test]
    fn test_metadata() {
        let mut state = FunctionState::new("test");
        assert_eq!(state.get_metadata("region"), None);

        state.set_metadata("region", "eu-west-1");
        assert_eq!(
            state.get_metadata("region").map(String::as_str),
            Some("eu-west-1")
        );

        state.set_metadata("region", "us-east-1");
        assert_eq!(
            state.get_metadata("region").map(String::as_str),
            Some("us-east-1")
        );
        assert_eq!(state.metadata.len(), 1);
    }

    #[test]
    fn test_record_requests() {
        let mut state = FunctionState::new("test");

        state.record_success(100.0);
        state.record_success(200.0);
        state.record_failure();

        assert_eq!(state.execution_state.total_requests, 3);
        assert_eq!(state.execution_state.successful_requests, 2);
        assert_eq!(state.execution_state.failed_requests, 1);
        assert_eq!(state.execution_state.avg_latency_ms, 150.0);
        assert!(state.execution_state.last_request_at.is_some());
    }

    #[test]
    fn test_success_rate() {
        let mut state = FunctionState::new("test");
        // No requests yet: reported as fully successful.
        assert_eq!(state.success_rate(), 1.0);

        state.record_success(100.0);
        state.record_success(100.0);
        state.record_failure();

        assert!((state.success_rate() - 2.0 / 3.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_serialization() {
        let mut state = FunctionState::new("test");
        state.set_metadata("owner", "team-a");
        state.record_success(42.0);
        state.record_failure();
        let serializer = StateSerializer::new(false);

        let data = serializer.serialize(&state).unwrap();
        let restored = serializer.deserialize(&data).unwrap();

        assert_eq!(restored.function_name, state.function_name);
        assert_eq!(restored.version, state.version);
        assert_eq!(restored.created_at, state.created_at);
        assert_eq!(restored.modified_at, state.modified_at);
        assert_eq!(restored.metadata, state.metadata);
        assert_eq!(restored.execution_state.total_requests, 2);
        assert_eq!(restored.execution_state.successful_requests, 1);
        assert_eq!(restored.execution_state.failed_requests, 1);
        assert_eq!(restored.execution_state.avg_latency_ms, 42.0);
        assert_eq!(
            restored.execution_state.last_request_at,
            state.execution_state.last_request_at
        );
    }

    #[test]
    fn test_deserialize_rejects_garbage() {
        let serializer = StateSerializer::default();
        assert!(serializer.deserialize(b"not json").is_err());
        assert!(serializer.deserialize(b"").is_err());
    }

    #[test]
    fn test_state_manager() {
        let mut manager = StateManager::new("test");

        manager.record_request(true, Some(50.0));
        manager.record_request(true, Some(100.0));

        let state = manager.state();
        assert_eq!(state.execution_state.successful_requests, 2);
    }

    #[test]
    fn test_state_manager_save_and_load() {
        let mut manager = StateManager::new("persisted");
        // Clean manager never needs saving.
        assert!(!manager.should_auto_save());

        manager.record_request(false, None);
        manager.set_auto_save_interval(0);
        assert!(manager.should_auto_save());

        let data = manager.save().unwrap();
        assert!(!manager.should_auto_save());

        let loaded = StateManager::load(&data).unwrap();
        assert_eq!(loaded.state().function_name, "persisted");
        assert_eq!(loaded.state().execution_state.failed_requests, 1);
        assert!(!loaded.should_auto_save());
    }

    #[test]
    fn test_state_diff() {
        assert!(StateDiff::default().is_empty());

        let mut diff = StateDiff::new();

        diff.add("field1", serde_json::json!(42));
        diff.add("field2", serde_json::json!("value"));

        assert!(!diff.is_empty());
        assert_eq!(diff.changes.len(), 2);

        // Re-adding a field replaces its value rather than adding an entry.
        diff.add("field1", serde_json::json!(7));
        assert_eq!(diff.changes.len(), 2);
        assert_eq!(diff.changes["field1"], serde_json::json!(7));
    }
}