1use serde::{Deserialize, Serialize};
9use std::hash::{Hash, Hasher};
10
/// Key that decides which partition a record is routed to.
///
/// The derived `Hash` implementation is what `PartitionKey::hash_value` feeds into the
/// partition hasher; as with any derived enum `Hash`, the variant discriminant is part of the
/// hashed data, not just the payload.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PartitionKey {
    /// Arbitrary text key.
    String(String),
    /// Signed 64-bit integer key; the `Time` strategy interprets it as a timestamp.
    Int(i64),
    /// Ordered combination of sub-keys (e.g. one per column, as built by `KeyExtractor`).
    Composite(Vec<PartitionKey>),
    /// UUID kept in textual form; the `uuid` constructor does not parse or validate it.
    Uuid(String),
    /// Raw binary key.
    Bytes(Vec<u8>),
}
29
30impl PartitionKey {
31 pub fn string(value: impl Into<String>) -> Self {
33 Self::String(value.into())
34 }
35
36 pub fn int(value: i64) -> Self {
38 Self::Int(value)
39 }
40
41 pub fn composite(keys: Vec<PartitionKey>) -> Self {
43 Self::Composite(keys)
44 }
45
46 pub fn uuid(value: impl Into<String>) -> Self {
48 Self::Uuid(value.into())
49 }
50
51 pub fn hash_value(&self) -> u64 {
53 let mut hasher = PartitionHasher::new();
54 self.hash(&mut hasher);
55 hasher.finish()
56 }
57
58 pub fn to_bytes(&self) -> Vec<u8> {
60 match self {
61 Self::String(s) => s.as_bytes().to_vec(),
62 Self::Int(i) => i.to_le_bytes().to_vec(),
63 Self::Composite(keys) => {
64 let mut bytes = Vec::new();
65 for key in keys {
66 bytes.extend(key.to_bytes());
67 bytes.push(0); }
69 bytes
70 }
71 Self::Uuid(u) => u.as_bytes().to_vec(),
72 Self::Bytes(b) => b.clone(),
73 }
74 }
75}
76
77impl From<String> for PartitionKey {
78 fn from(value: String) -> Self {
79 Self::String(value)
80 }
81}
82
83impl From<&str> for PartitionKey {
84 fn from(value: &str) -> Self {
85 Self::String(value.to_string())
86 }
87}
88
89impl From<i64> for PartitionKey {
90 fn from(value: i64) -> Self {
91 Self::Int(value)
92 }
93}
94
95impl From<i32> for PartitionKey {
96 fn from(value: i32) -> Self {
97 Self::Int(value as i64)
98 }
99}
100
/// How records are assigned to partitions; `partition_for_key` maps a key to a partition index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PartitionStrategy {
    /// Partition by `key.hash_value() % num_partitions`.
    Hash {
        /// Columns the key is built from. Not consulted by `partition_for_key`; presumably
        /// used by callers to build the key (e.g. via `KeyExtractor`) — TODO confirm.
        columns: Vec<String>,
        /// Number of hash buckets (the modulus).
        num_partitions: u32,
    },
    /// Partition by comparing the key's *hash* (not its raw value) against ordered bounds.
    Range {
        /// Name of the partitioning column (not consulted by `partition_for_key`).
        column: String,
        /// Checked in order; the first boundary whose `upper_bound` exceeds the hash wins.
        boundaries: Vec<RangeBoundary>,
    },
    /// Partition by explicit value → partition mappings.
    List {
        /// Name of the partitioning column (not consulted by `partition_for_key`).
        column: String,
        /// Checked in order; the first mapping listing the key's text wins.
        mappings: Vec<ListMapping>,
    },
    /// Intended round-robin distribution. `partition_for_key` keeps no counter and assigns
    /// by `key.hash_value() % num_partitions`, i.e. the same as `Hash`.
    RoundRobin {
        /// Number of partitions (the modulus).
        num_partitions: u32,
    },
    /// Partition integer keys into fixed-width time buckets.
    Time {
        /// Name of the timestamp column (not consulted by `partition_for_key`).
        column: String,
        /// Bucket width; see `TimeInterval::to_seconds`.
        interval: TimeInterval,
    },
}
142
143impl PartitionStrategy {
144 pub fn hash(columns: Vec<String>, num_partitions: u32) -> Self {
146 Self::Hash {
147 columns,
148 num_partitions,
149 }
150 }
151
152 pub fn range(column: String, boundaries: Vec<RangeBoundary>) -> Self {
154 Self::Range { column, boundaries }
155 }
156
157 pub fn list(column: String, mappings: Vec<ListMapping>) -> Self {
159 Self::List { column, mappings }
160 }
161
162 pub fn round_robin(num_partitions: u32) -> Self {
164 Self::RoundRobin { num_partitions }
165 }
166
167 pub fn time(column: String, interval: TimeInterval) -> Self {
169 Self::Time { column, interval }
170 }
171
172 pub fn partition_for_key(&self, key: &PartitionKey) -> u32 {
174 match self {
175 Self::Hash { num_partitions, .. } => (key.hash_value() % *num_partitions as u64) as u32,
176 Self::Range { boundaries, .. } => {
177 let hash = key.hash_value();
178 for (i, boundary) in boundaries.iter().enumerate() {
179 if hash < boundary.upper_bound {
180 return i as u32;
181 }
182 }
183 boundaries.len() as u32
184 }
185 Self::List { mappings, .. } => {
186 let key_str = match key {
187 PartitionKey::String(s) => s.clone(),
188 _ => format!("{:?}", key),
189 };
190 for mapping in mappings {
191 if mapping.values.contains(&key_str) {
192 return mapping.partition;
193 }
194 }
195 0 }
197 Self::RoundRobin { num_partitions } => {
198 (key.hash_value() % *num_partitions as u64) as u32
200 }
201 Self::Time { interval, .. } => {
202 if let PartitionKey::Int(ts) = key {
203 (*ts as u64 / interval.to_seconds()) as u32
204 } else {
205 0
206 }
207 }
208 }
209 }
210}
211
/// One bound of a `PartitionStrategy::Range` layout.
///
/// `partition_for_key` walks boundaries in order and picks the first whose `upper_bound` is
/// strictly greater than the key's hash.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RangeBoundary {
    /// Human-readable partition name (not consulted by `partition_for_key`).
    pub partition_name: String,
    /// Exclusive upper bound, compared against `PartitionKey::hash_value`.
    pub upper_bound: u64,
}
222
223impl RangeBoundary {
224 pub fn new(name: impl Into<String>, upper_bound: u64) -> Self {
225 Self {
226 partition_name: name.into(),
227 upper_bound,
228 }
229 }
230}
231
/// One entry of a `PartitionStrategy::List` layout: a set of values routed to one partition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListMapping {
    /// Partition index returned when the key's text matches one of `values`.
    pub partition: u32,
    /// Values (compared as exact strings) that route to `partition`.
    pub values: Vec<String>,
}
242
243impl ListMapping {
244 pub fn new(partition: u32, values: Vec<String>) -> Self {
245 Self { partition, values }
246 }
247}
248
/// Bucket width for `PartitionStrategy::Time`.
///
/// `to_seconds` gives each variant's length; `Month` and `Year` are fixed 30-day and 365-day
/// approximations, not calendar-aware.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TimeInterval {
    /// 3600 seconds.
    Hour,
    /// 86400 seconds.
    Day,
    /// 604800 seconds (7 days).
    Week,
    /// 2592000 seconds (30 days).
    Month,
    /// 31536000 seconds (365 days).
    Year,
    /// An explicit width in seconds.
    Custom(u64),
}
263
264impl TimeInterval {
265 pub fn to_seconds(&self) -> u64 {
267 match self {
268 Self::Hour => 3600,
269 Self::Day => 86400,
270 Self::Week => 604800,
271 Self::Month => 2592000, Self::Year => 31536000, Self::Custom(s) => *s,
274 }
275 }
276}
277
/// A span of the `u64` hash space between `start` and `end`, with per-endpoint inclusivity.
///
/// `PartitionRange::new` produces the half-open form `[start, end)`. The type itself does not
/// enforce `start <= end`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartitionRange {
    /// Lower endpoint.
    pub start: u64,
    /// Upper endpoint.
    pub end: u64,
    /// Whether `start` itself is inside the range.
    pub inclusive_start: bool,
    /// Whether `end` itself is inside the range.
    pub inclusive_end: bool,
}
290
291impl PartitionRange {
292 pub fn new(start: u64, end: u64) -> Self {
294 Self {
295 start,
296 end,
297 inclusive_start: true,
298 inclusive_end: false,
299 }
300 }
301
302 pub fn full() -> Self {
304 Self::new(0, u64::MAX)
305 }
306
307 pub fn contains(&self, hash: u64) -> bool {
309 let start_check = if self.inclusive_start {
310 hash >= self.start
311 } else {
312 hash > self.start
313 };
314
315 let end_check = if self.inclusive_end {
316 hash <= self.end
317 } else {
318 hash < self.end
319 };
320
321 start_check && end_check
322 }
323
324 pub fn split(&self, num_parts: u32) -> Vec<PartitionRange> {
326 if num_parts == 0 {
327 return vec![self.clone()];
328 }
329
330 let range_size = (self.end - self.start) / num_parts as u64;
331 let mut ranges = Vec::with_capacity(num_parts as usize);
332
333 for i in 0..num_parts {
334 let start = self.start + (i as u64 * range_size);
335 let end = if i == num_parts - 1 {
336 self.end
337 } else {
338 self.start + ((i as u64 + 1) * range_size)
339 };
340
341 ranges.push(PartitionRange {
342 start,
343 end,
344 inclusive_start: true,
345 inclusive_end: i == num_parts - 1,
346 });
347 }
348
349 ranges
350 }
351
352 pub fn merge(&self, other: &PartitionRange) -> Option<PartitionRange> {
354 if self.end == other.start {
355 Some(PartitionRange {
356 start: self.start,
357 end: other.end,
358 inclusive_start: self.inclusive_start,
359 inclusive_end: other.inclusive_end,
360 })
361 } else if other.end == self.start {
362 Some(PartitionRange {
363 start: other.start,
364 end: self.end,
365 inclusive_start: other.inclusive_start,
366 inclusive_end: self.inclusive_end,
367 })
368 } else {
369 None
370 }
371 }
372
373 pub fn size(&self) -> u64 {
375 self.end - self.start
376 }
377}
378
/// Deterministic 64-bit hasher behind `PartitionKey::hash_value`.
///
/// It starts from a fixed seed (unlike `std`'s randomly seeded hasher), so hashes, and
/// therefore partition assignments, are reproducible across runs.
struct PartitionHasher {
    state: u64,
}

impl PartitionHasher {
    /// Fixed non-zero starting state.
    const SEED: u64 = 0x517cc1b727220a95;
    /// Odd multiplier applied after every input byte.
    const MULTIPLIER: u64 = 0x5851f42d4c957f2d;

    fn new() -> Self {
        Self { state: Self::SEED }
    }

    /// MurmurHash3 `fmix64` finalizer: a bijective mix that spreads every input bit across
    /// the whole output word.
    fn avalanche(mut x: u64) -> u64 {
        x ^= x >> 33;
        x = x.wrapping_mul(0xff51afd7ed558ccd);
        x ^= x >> 33;
        x = x.wrapping_mul(0xc4ceb9fe1a85ec53);
        x ^= x >> 33;
        x
    }
}

impl Hasher for PartitionHasher {
    fn finish(&self) -> u64 {
        // Multiplication by an odd constant only carries information upward. Without this
        // step, the low `k` bits of the state depend only on the low `k` bits of each input
        // byte, so `hash % 2^k` (e.g. 16 partitions) would ignore the high bits of every byte.
        Self::avalanche(self.state)
    }

    fn write(&mut self, bytes: &[u8]) {
        for &byte in bytes {
            self.state ^= u64::from(byte);
            self.state = self.state.wrapping_mul(Self::MULTIPLIER);
        }
    }

    // The provided integer methods use native endianness and platform-sized `usize`/`isize`.
    // Derived `Hash` impls call them (for ints, lengths and enum discriminants), so encode every
    // integer as fixed-width little-endian to keep partitioning identical across platforms.
    fn write_u16(&mut self, i: u16) {
        self.write(&i.to_le_bytes());
    }

    fn write_u32(&mut self, i: u32) {
        self.write(&i.to_le_bytes());
    }

    fn write_u64(&mut self, i: u64) {
        self.write(&i.to_le_bytes());
    }

    fn write_u128(&mut self, i: u128) {
        self.write(&i.to_le_bytes());
    }

    fn write_usize(&mut self, i: usize) {
        self.write_u64(i as u64);
    }

    fn write_i16(&mut self, i: i16) {
        self.write_u16(i as u16);
    }

    fn write_i32(&mut self, i: i32) {
        self.write_u32(i as u32);
    }

    fn write_i64(&mut self, i: i64) {
        self.write_u64(i as u64);
    }

    fn write_i128(&mut self, i: i128) {
        self.write_u128(i as u128);
    }

    fn write_isize(&mut self, i: isize) {
        // Sign-extend through i64 so 32- and 64-bit targets agree on negative values.
        self.write_i64(i as i64);
    }
}
408
409pub struct KeyExtractor {
415 columns: Vec<String>,
416}
417
418impl KeyExtractor {
419 pub fn new(columns: Vec<String>) -> Self {
421 Self { columns }
422 }
423
424 pub fn extract(
426 &self,
427 values: &std::collections::HashMap<String, String>,
428 ) -> Option<PartitionKey> {
429 if self.columns.len() == 1 {
430 values
431 .get(&self.columns[0])
432 .map(|v| PartitionKey::String(v.clone()))
433 } else {
434 let mut keys = Vec::new();
435 for col in &self.columns {
436 if let Some(v) = values.get(col) {
437 keys.push(PartitionKey::String(v.clone()));
438 } else {
439 return None;
440 }
441 }
442 Some(PartitionKey::Composite(keys))
443 }
444 }
445
446 pub fn columns(&self) -> &[String] {
448 &self.columns
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_partition_key_string() {
        let key = PartitionKey::string("user_123");
        let hash = key.hash_value();
        assert!(hash > 0);

        let key2 = PartitionKey::string("user_123");
        assert_eq!(key.hash_value(), key2.hash_value());
    }

    #[test]
    fn test_partition_key_int() {
        let key = PartitionKey::int(12345);
        let hash = key.hash_value();
        assert!(hash > 0);
    }

    #[test]
    fn test_partition_key_composite() {
        let key = PartitionKey::composite(vec![
            PartitionKey::string("tenant_1"),
            PartitionKey::int(100),
        ]);
        let hash = key.hash_value();
        assert!(hash > 0);
    }

    #[test]
    fn test_partition_key_from_conversions() {
        assert_eq!(PartitionKey::from("a"), PartitionKey::String("a".to_string()));
        assert_eq!(PartitionKey::from("a".to_string()), PartitionKey::string("a"));
        assert_eq!(PartitionKey::from(7i64), PartitionKey::Int(7));
        assert_eq!(PartitionKey::from(-7i32), PartitionKey::Int(-7));
    }

    #[test]
    fn test_partition_key_to_bytes() {
        assert_eq!(PartitionKey::string("ab").to_bytes(), b"ab".to_vec());
        assert_eq!(PartitionKey::int(1).to_bytes(), 1i64.to_le_bytes().to_vec());
        assert_eq!(
            PartitionKey::composite(vec![PartitionKey::string("a"), PartitionKey::string("b")])
                .to_bytes(),
            vec![b'a', 0, b'b', 0]
        );
    }

    #[test]
    fn test_partition_strategy_hash() {
        let strategy = PartitionStrategy::hash(vec!["id".to_string()], 16);

        let key1 = PartitionKey::string("key1");
        let key2 = PartitionKey::string("key2");

        let p1 = strategy.partition_for_key(&key1);
        let p2 = strategy.partition_for_key(&key2);

        assert!(p1 < 16);
        assert!(p2 < 16);

        assert_eq!(p1, strategy.partition_for_key(&key1));
    }

    #[test]
    fn test_partition_strategy_range_overflow() {
        let strategy =
            PartitionStrategy::range("id".to_string(), vec![RangeBoundary::new("p0", 0)]);
        // No hash is below 0, so every key lands in the overflow partition `boundaries.len()`.
        assert_eq!(strategy.partition_for_key(&PartitionKey::string("k")), 1);
    }

    #[test]
    fn test_partition_strategy_list() {
        let strategy = PartitionStrategy::list(
            "region".to_string(),
            vec![
                ListMapping::new(0, vec!["us-east".to_string(), "us-west".to_string()]),
                ListMapping::new(1, vec!["eu-west".to_string()]),
            ],
        );

        let key_us = PartitionKey::string("us-east");
        let key_eu = PartitionKey::string("eu-west");

        assert_eq!(strategy.partition_for_key(&key_us), 0);
        assert_eq!(strategy.partition_for_key(&key_eu), 1);
    }

    #[test]
    fn test_partition_strategy_list_unmapped_defaults_to_zero() {
        let strategy = PartitionStrategy::list(
            "region".to_string(),
            vec![ListMapping::new(1, vec!["eu-west".to_string()])],
        );
        assert_eq!(strategy.partition_for_key(&PartitionKey::string("ap-south")), 0);
    }

    #[test]
    fn test_partition_strategy_time() {
        let strategy = PartitionStrategy::time("ts".to_string(), TimeInterval::Day);
        assert_eq!(strategy.partition_for_key(&PartitionKey::int(3 * 86400 + 5)), 3);
        assert_eq!(strategy.partition_for_key(&PartitionKey::string("not a ts")), 0);
    }

    #[test]
    fn test_partition_range() {
        let range = PartitionRange::new(100, 200);

        assert!(range.contains(100));
        assert!(range.contains(150));
        assert!(!range.contains(200));
        assert!(!range.contains(50));
    }

    #[test]
    fn test_partition_range_split() {
        let range = PartitionRange::new(0, 1000);
        let parts = range.split(4);

        assert_eq!(parts.len(), 4);
        assert_eq!(parts[0].start, 0);
        assert_eq!(parts[0].end, 250);
        assert_eq!(parts[3].end, 1000);
    }

    #[test]
    fn test_partition_range_split_zero_parts() {
        let parts = PartitionRange::new(10, 20).split(0);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0].start, 10);
        assert_eq!(parts[0].end, 20);
    }

    #[test]
    fn test_partition_range_merge() {
        let r1 = PartitionRange::new(0, 100);
        let r2 = PartitionRange::new(100, 200);

        let merged = r1.merge(&r2).unwrap();
        assert_eq!(merged.start, 0);
        assert_eq!(merged.end, 200);
    }

    #[test]
    fn test_partition_range_merge_either_order_and_gap() {
        let r1 = PartitionRange::new(0, 100);
        let r2 = PartitionRange::new(100, 200);

        let merged = r2.merge(&r1).unwrap();
        assert_eq!(merged.start, 0);
        assert_eq!(merged.end, 200);

        assert!(r1.merge(&PartitionRange::new(150, 200)).is_none());
    }

    #[test]
    fn test_key_extractor() {
        let extractor = KeyExtractor::new(vec!["user_id".to_string()]);

        let mut values = std::collections::HashMap::new();
        values.insert("user_id".to_string(), "123".to_string());
        values.insert("name".to_string(), "Alice".to_string());

        let key = extractor.extract(&values).unwrap();
        assert_eq!(key, PartitionKey::String("123".to_string()));
    }

    #[test]
    fn test_key_extractor_composite() {
        let extractor = KeyExtractor::new(vec!["tenant_id".to_string(), "user_id".to_string()]);

        let mut values = std::collections::HashMap::new();
        values.insert("tenant_id".to_string(), "t1".to_string());
        values.insert("user_id".to_string(), "u1".to_string());

        let key = extractor.extract(&values).unwrap();
        match key {
            PartitionKey::Composite(keys) => {
                assert_eq!(keys.len(), 2);
            }
            _ => panic!("Expected composite key"),
        }
    }

    #[test]
    fn test_key_extractor_missing_column() {
        let extractor = KeyExtractor::new(vec!["tenant_id".to_string(), "user_id".to_string()]);

        let mut values = std::collections::HashMap::new();
        values.insert("tenant_id".to_string(), "t1".to_string());

        assert!(extractor.extract(&values).is_none());
        assert_eq!(extractor.columns(), ["tenant_id".to_string(), "user_id".to_string()]);
    }

    #[test]
    fn test_time_interval() {
        assert_eq!(TimeInterval::Hour.to_seconds(), 3600);
        assert_eq!(TimeInterval::Day.to_seconds(), 86400);
        assert_eq!(TimeInterval::Week.to_seconds(), 604800);
        assert_eq!(TimeInterval::Month.to_seconds(), 2592000);
        assert_eq!(TimeInterval::Year.to_seconds(), 31536000);
        assert_eq!(TimeInterval::Custom(7200).to_seconds(), 7200);
    }
}