1use std::collections::HashMap;
7
/// A single value stored under a string key inside a state partition.
///
/// Values compare with `==` variant-by-variant; two values of different
/// variants are never equal, even when their payloads look alike
/// (`Integer(1) != Counter(1)`).
#[derive(Clone, Debug, PartialEq)]
pub enum PartitionStateValue {
    /// Signed 64-bit integer (the variant written by `StateAggregator::increment`).
    Integer(i64),
    /// 64-bit float (the variant written by `StateAggregator::accumulate`).
    Float(f64),
    /// Arbitrary owned byte payload.
    Bytes(Vec<u8>),
    /// Owned UTF-8 string payload.
    StringVal(String),
    /// Unsigned 64-bit count (the variant written by `StateAggregator::window_count`).
    Counter(u64),
    /// A reading paired with the time it was recorded.
    ///
    /// `StateAggregator::update_gauge` fills `timestamp` with milliseconds
    /// since the Unix epoch; other writers may store any `i64`.
    Gauge { value: f64, timestamp: i64 },
}
20
21#[derive(Debug, Clone)]
25pub struct StatePartition {
26 pub partition_id: u32,
27 pub state: HashMap<String, PartitionStateValue>,
28 pub version: u64,
29 pub last_checkpointed: i64,
30}
31
32impl StatePartition {
33 pub fn new(partition_id: u32) -> Self {
34 Self {
35 partition_id,
36 state: HashMap::new(),
37 version: 0,
38 last_checkpointed: 0,
39 }
40 }
41
42 fn bump_version(&mut self) -> u64 {
43 self.version += 1;
44 self.version
45 }
46}
47
/// Identity of a node together with the list of peers registered on it.
#[derive(Clone, Debug)]
pub struct StateCoordinator {
    /// Name of the node this coordinator belongs to.
    pub node_id: String,
    /// Registered peers, in insertion order.
    pub peers: Vec<String>,
}

impl StateCoordinator {
    /// Creates a coordinator for `node_id` with an empty peer list.
    pub fn new(node_id: impl Into<String>) -> Self {
        let node_id = node_id.into();
        let peers = Vec::new();
        Self { node_id, peers }
    }

    /// Appends `peer` to the peer list.
    ///
    /// No de-duplication is performed: adding the same name twice stores it
    /// twice.
    pub fn add_peer(&mut self, peer: impl Into<String>) {
        let peer_name: String = peer.into();
        self.peers.push(peer_name);
    }
}
69
70pub struct DistributedStateStore {
77 pub(crate) partitions: Vec<StatePartition>,
78 replication_factor: usize,
79 coordinator: StateCoordinator,
80}
81
82impl DistributedStateStore {
83 pub fn new(partition_count: u32, replication_factor: usize) -> Self {
85 let partitions = (0..partition_count).map(StatePartition::new).collect();
86 Self {
87 partitions,
88 replication_factor,
89 coordinator: StateCoordinator::new("local"),
90 }
91 }
92
93 fn fnv_hash(key: &str) -> u64 {
95 const FNV_OFFSET: u64 = 14_695_981_039_346_656_037;
96 const FNV_PRIME: u64 = 1_099_511_628_211;
97 let mut hash = FNV_OFFSET;
98 for byte in key.as_bytes() {
99 hash ^= *byte as u64;
100 hash = hash.wrapping_mul(FNV_PRIME);
101 }
102 hash
103 }
104
105 pub fn partition_for(&self, key: &str) -> u32 {
107 let count = self.partitions.len() as u64;
108 if count == 0 {
109 return 0;
110 }
111 (Self::fnv_hash(key) % count) as u32
112 }
113
114 pub fn get(&self, key: &str) -> Option<&PartitionStateValue> {
116 let pid = self.partition_for(key) as usize;
117 self.partitions.get(pid)?.state.get(key)
118 }
119
120 pub fn set(&mut self, key: &str, value: PartitionStateValue) -> u64 {
122 let pid = self.partition_for(key) as usize;
123 let partition = &mut self.partitions[pid];
124 partition.state.insert(key.to_string(), value);
125 partition.bump_version()
126 }
127
128 pub fn delete(&mut self, key: &str) -> bool {
130 let pid = self.partition_for(key) as usize;
131 match self.partitions.get_mut(pid) {
132 Some(partition) => partition.state.remove(key).is_some(),
133 None => false,
134 }
135 }
136
137 pub fn replicate_to(
139 &self,
140 _peer: &str,
141 partition_id: u32,
142 ) -> Vec<(String, PartitionStateValue)> {
143 self.partitions
144 .iter()
145 .find(|p| p.partition_id == partition_id)
146 .map(|p| {
147 p.state
148 .iter()
149 .map(|(k, v)| (k.clone(), v.clone()))
150 .collect()
151 })
152 .unwrap_or_default()
153 }
154
155 pub fn checkpoint_partition(&mut self, partition_id: u32) -> StatePartition {
157 let now_ms = std::time::SystemTime::now()
158 .duration_since(std::time::UNIX_EPOCH)
159 .map(|d| d.as_millis() as i64)
160 .unwrap_or(0);
161 let partition = self
162 .partitions
163 .iter_mut()
164 .find(|p| p.partition_id == partition_id)
165 .expect("partition_id out of range");
166 partition.last_checkpointed = now_ms;
167 partition.clone()
168 }
169
170 pub fn restore_partition(&mut self, partition: StatePartition) {
172 if let Some(p) = self
173 .partitions
174 .iter_mut()
175 .find(|p| p.partition_id == partition.partition_id)
176 {
177 *p = partition;
178 }
179 }
180
181 pub fn partition_count(&self) -> u32 {
183 self.partitions.len() as u32
184 }
185
186 pub fn total_keys(&self) -> usize {
188 self.partitions.iter().map(|p| p.state.len()).sum()
189 }
190
191 pub fn replication_factor(&self) -> usize {
193 self.replication_factor
194 }
195
196 pub fn coordinator(&self) -> &StateCoordinator {
198 &self.coordinator
199 }
200
201 pub fn coordinator_mut(&mut self) -> &mut StateCoordinator {
203 &mut self.coordinator
204 }
205}
206
207pub struct StateAggregator {
214 store: DistributedStateStore,
215}
216
217impl StateAggregator {
218 pub fn new(partition_count: u32) -> Self {
220 Self {
221 store: DistributedStateStore::new(partition_count, 1),
222 }
223 }
224
225 pub fn increment(&mut self, key: &str, by: i64) -> i64 {
227 let current = match self.store.get(key) {
228 Some(PartitionStateValue::Integer(v)) => *v,
229 Some(PartitionStateValue::Counter(v)) => *v as i64,
230 _ => 0,
231 };
232 let next = current + by;
233 self.store.set(key, PartitionStateValue::Integer(next));
234 next
235 }
236
237 pub fn accumulate(&mut self, key: &str, value: f64) -> f64 {
239 let current = match self.store.get(key) {
240 Some(PartitionStateValue::Float(v)) => *v,
241 _ => 0.0,
242 };
243 let next = current + value;
244 self.store.set(key, PartitionStateValue::Float(next));
245 next
246 }
247
248 pub fn update_gauge(&mut self, key: &str, value: f64) {
250 let timestamp = std::time::SystemTime::now()
251 .duration_since(std::time::UNIX_EPOCH)
252 .map(|d| d.as_millis() as i64)
253 .unwrap_or(0);
254 self.store
255 .set(key, PartitionStateValue::Gauge { value, timestamp });
256 }
257
258 pub fn window_count(&mut self, window_key: &str, event_key: &str) -> u64 {
261 let key = format!("{window_key}:{event_key}");
262 let current = match self.store.get(&key) {
263 Some(PartitionStateValue::Counter(v)) => *v,
264 _ => 0,
265 };
266 let next = current + 1;
267 self.store.set(&key, PartitionStateValue::Counter(next));
268 next
269 }
270
271 pub fn merge_from(&mut self, other: &DistributedStateStore) {
273 for partition in &other.partitions {
274 for (key, value) in &partition.state {
275 self.store.set(key, value.clone());
276 }
277 }
278 }
279
280 pub fn store(&self) -> &DistributedStateStore {
282 &self.store
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    // ---- DistributedStateStore ------------------------------------------

    // A new store reports its configuration and holds no keys.
    #[test]
    fn test_new_store_empty() {
        let fresh = DistributedStateStore::new(4, 1);
        assert_eq!(fresh.partition_count(), 4);
        assert_eq!(fresh.total_keys(), 0);
        assert_eq!(fresh.replication_factor(), 1);
    }

    // A string value written with `set` is read back unchanged.
    #[test]
    fn test_set_and_get_string() {
        let mut store = DistributedStateStore::new(4, 1);
        let payload = PartitionStateValue::StringVal("world".to_string());
        store.set("hello", payload);

        let other = store.get("hello");
        if let Some(PartitionStateValue::StringVal(s)) = other {
            assert_eq!(s, "world");
        } else {
            panic!("unexpected: {other:?}");
        }
    }

    // Two writes to the same key yield strictly increasing versions.
    #[test]
    fn test_set_returns_version_increases() {
        let mut store = DistributedStateStore::new(4, 1);
        let first = store.set("k", PartitionStateValue::Integer(1));
        let second = store.set("k", PartitionStateValue::Integer(2));
        assert!(second > first, "version must increase on each write");
    }

    // Deleting a present key reports success and removes it.
    #[test]
    fn test_delete_existing_key() {
        let mut store = DistributedStateStore::new(4, 1);
        store.set("k", PartitionStateValue::Counter(10));
        let removed = store.delete("k");
        assert!(removed, "delete should return true for existing key");
        assert!(store.get("k").is_none());
    }

    // Deleting an absent key reports failure.
    #[test]
    fn test_delete_missing_key() {
        let mut store = DistributedStateStore::new(4, 1);
        let removed = store.delete("nonexistent");
        assert!(!removed);
    }

    // Routing is stable for a given key and stays within range.
    #[test]
    fn test_partition_for_deterministic() {
        let store = DistributedStateStore::new(8, 1);
        let p1 = store.partition_for("my_key");
        let p2 = store.partition_for("my_key");
        assert_eq!(p1, p2, "same key must always map to same partition");
        assert!(p1 < 8);
    }

    // Distinct keys do not all collapse onto one partition.
    #[test]
    fn test_partition_for_distributes_across_partitions() {
        let store = DistributedStateStore::new(8, 1);
        let keys = [
            "alpha", "beta", "gamma", "delta", "epsilon", "zeta", "eta", "theta",
        ];
        let mut partitions: std::collections::HashSet<u32> = std::collections::HashSet::new();
        for key in keys.iter() {
            partitions.insert(store.partition_for(key));
        }
        assert!(
            partitions.len() >= 2,
            "8 keys over 8 partitions must use at least 2 different partitions"
        );
    }

    // The key total follows inserts and deletes.
    #[test]
    fn test_total_keys_after_operations() {
        let mut store = DistributedStateStore::new(4, 1);
        let entries = vec![("a", 1), ("b", 2), ("c", 3)];
        for (key, number) in entries {
            store.set(key, PartitionStateValue::Integer(number));
        }
        assert_eq!(store.total_keys(), 3);
        store.delete("b");
        assert_eq!(store.total_keys(), 2);
    }

    // The replication payload contains the keys of the requested partition.
    #[test]
    fn test_replicate_to_returns_partition_contents() {
        let mut store = DistributedStateStore::new(4, 2);
        store.set("key1", PartitionStateValue::Integer(42));
        let owner = store.partition_for("key1");
        let replica = store.replicate_to("peer-node", owner);
        assert!(!replica.is_empty());
        let has_key = replica.iter().any(|(k, _)| k == "key1");
        assert!(has_key);
    }

    // Asking for an unknown partition yields an empty payload.
    #[test]
    fn test_replicate_to_nonexistent_partition() {
        let store = DistributedStateStore::new(4, 1);
        assert!(store.replicate_to("peer", 99).is_empty());
    }

    // A checkpoint taken before an overwrite restores the earlier value.
    #[test]
    fn test_checkpoint_and_restore() {
        let expected_val = 42.5_f64;
        let mut store = DistributedStateStore::new(4, 1);
        store.set("x", PartitionStateValue::Float(expected_val));
        let owner = store.partition_for("x");

        let snapshot = store.checkpoint_partition(owner);
        assert!(
            snapshot.last_checkpointed > 0,
            "last_checkpointed must be set"
        );

        // Clobber the value, then roll back to the snapshot.
        store.set("x", PartitionStateValue::Float(0.0));
        store.restore_partition(snapshot);

        let other = store.get("x");
        if let Some(PartitionStateValue::Float(v)) = other {
            assert!((v - expected_val).abs() < 1e-9);
        } else {
            panic!("unexpected after restore: {other:?}");
        }
    }

    // The built-in coordinator is named "local" and starts without peers.
    #[test]
    fn test_coordinator_default_node_id() {
        let store = DistributedStateStore::new(2, 1);
        let coordinator = store.coordinator();
        assert_eq!(coordinator.node_id, "local");
        assert!(coordinator.peers.is_empty());
    }

    // Peers added through the mutable accessor are retained.
    #[test]
    fn test_coordinator_add_peers() {
        let mut store = DistributedStateStore::new(2, 1);
        for peer in vec!["node-2", "node-3"] {
            store.coordinator_mut().add_peer(peer);
        }
        assert_eq!(store.coordinator().peers.len(), 2);
    }

    // Every value variant can be stored.
    #[test]
    fn test_all_value_variants() {
        let mut store = DistributedStateStore::new(8, 1);
        let entries = vec![
            ("int_k", PartitionStateValue::Integer(-10)),
            ("float_k", PartitionStateValue::Float(2.5)),
            ("bytes_k", PartitionStateValue::Bytes(vec![1, 2, 3])),
            ("str_k", PartitionStateValue::StringVal("hi".to_string())),
            ("ctr_k", PartitionStateValue::Counter(99)),
            (
                "gauge_k",
                PartitionStateValue::Gauge {
                    value: 1.0,
                    timestamp: 1000,
                },
            ),
        ];
        for (key, value) in entries {
            store.set(key, value);
        }
        assert_eq!(store.total_keys(), 6);
    }

    // With a single partition every key routes to partition 0.
    #[test]
    fn test_single_partition_all_keys_same_partition() {
        let store = DistributedStateStore::new(1, 1);
        assert_eq!(store.partition_for("anything"), 0);
        assert_eq!(store.partition_for("other_key"), 0);
    }

    // A second write to a key replaces the first value.
    #[test]
    fn test_overwrite_value() {
        let mut store = DistributedStateStore::new(4, 1);
        store.set("key", PartitionStateValue::Integer(1));
        store.set("key", PartitionStateValue::Integer(2));

        let other = store.get("key");
        if let Some(PartitionStateValue::Integer(v)) = other {
            assert_eq!(*v, 2);
        } else {
            panic!("unexpected: {other:?}");
        }
    }

    // ---- StatePartition -------------------------------------------------

    // A new partition carries its id and zeroed bookkeeping.
    #[test]
    fn test_state_partition_new() {
        let partition = StatePartition::new(5);
        assert_eq!(partition.partition_id, 5);
        assert_eq!(partition.version, 0);
        assert!(partition.state.is_empty());
        assert_eq!(partition.last_checkpointed, 0);
    }

    // ---- StateAggregator ------------------------------------------------

    // Positive increments accumulate.
    #[test]
    fn test_aggregator_increment_positive() {
        let mut agg = StateAggregator::new(4);
        let after_first = agg.increment("counter", 5);
        assert_eq!(after_first, 5);
        let after_second = agg.increment("counter", 3);
        assert_eq!(after_second, 8);
    }

    // Negative increments subtract.
    #[test]
    fn test_aggregator_increment_negative() {
        let mut agg = StateAggregator::new(4);
        agg.increment("counter", 10);
        let total = agg.increment("counter", -2);
        assert_eq!(total, 8);
    }

    // Float accumulation returns the running sum.
    #[test]
    fn test_aggregator_accumulate_floats() {
        let mut agg = StateAggregator::new(4);
        let v1 = agg.accumulate("sum", 1.5);
        let v2 = agg.accumulate("sum", 2.5);
        assert!((v1 - 1.5).abs() < 1e-9);
        assert!((v2 - 4.0).abs() < 1e-9);
    }

    // A gauge update stores the supplied reading.
    #[test]
    fn test_aggregator_update_gauge() {
        let mut agg = StateAggregator::new(4);
        agg.update_gauge("temperature", 98.6);

        let other = agg.store().get("temperature");
        if let Some(PartitionStateValue::Gauge { value, .. }) = other {
            assert!((value - 98.6).abs() < 1e-9);
        } else {
            panic!("unexpected: {other:?}");
        }
    }

    // Window counters are independent per (window, event) pair.
    #[test]
    fn test_aggregator_window_count_isolated() {
        let mut agg = StateAggregator::new(4);
        let steps = vec![
            ("win-1", "click", 1),
            ("win-1", "click", 2),
            ("win-1", "view", 1),
            ("win-2", "click", 1),
        ];
        for (window, event, expected) in steps {
            assert_eq!(agg.window_count(window, event), expected);
        }
    }

    // Merging copies entries from another store.
    #[test]
    fn test_aggregator_merge_from() {
        let mut source = DistributedStateStore::new(4, 1);
        source.set("shared_key", PartitionStateValue::Integer(100));

        let mut agg = StateAggregator::new(4);
        agg.merge_from(&source);

        let other = agg.store().get("shared_key");
        if let Some(PartitionStateValue::Integer(v)) = other {
            assert_eq!(*v, 100);
        } else {
            panic!("unexpected: {other:?}");
        }
    }

    // The accessor exposes the backing store.
    #[test]
    fn test_aggregator_store_accessor() {
        let agg = StateAggregator::new(4);
        let backing = agg.store();
        assert_eq!(backing.partition_count(), 4);
        assert_eq!(backing.total_keys(), 0);
    }
}