1use serde::{Deserialize, Serialize};
9use std::hash::{Hash, Hasher};
10
/// A value used to decide which partition a row belongs to.
///
/// The variant declaration order matters: the derived `Hash` implementation
/// feeds the variant discriminant into the hasher before the payload, so
/// reordering variants changes the result of `hash_value`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PartitionKey {
    /// Text key.
    String(String),
    /// Signed 64-bit integer key.
    Int(i64),
    /// Ordered list of sub-keys treated as a single key.
    Composite(Vec<PartitionKey>),
    /// UUID kept as text; nothing in this file validates the format.
    Uuid(String),
    /// Raw byte key.
    Bytes(Vec<u8>),
}
29
30impl PartitionKey {
31 pub fn string(value: impl Into<String>) -> Self {
33 Self::String(value.into())
34 }
35
36 pub fn int(value: i64) -> Self {
38 Self::Int(value)
39 }
40
41 pub fn composite(keys: Vec<PartitionKey>) -> Self {
43 Self::Composite(keys)
44 }
45
46 pub fn uuid(value: impl Into<String>) -> Self {
48 Self::Uuid(value.into())
49 }
50
51 pub fn hash_value(&self) -> u64 {
53 let mut hasher = PartitionHasher::new();
54 self.hash(&mut hasher);
55 hasher.finish()
56 }
57
58 pub fn to_bytes(&self) -> Vec<u8> {
60 match self {
61 Self::String(s) => s.as_bytes().to_vec(),
62 Self::Int(i) => i.to_le_bytes().to_vec(),
63 Self::Composite(keys) => {
64 let mut bytes = Vec::new();
65 for key in keys {
66 bytes.extend(key.to_bytes());
67 bytes.push(0); }
69 bytes
70 }
71 Self::Uuid(u) => u.as_bytes().to_vec(),
72 Self::Bytes(b) => b.clone(),
73 }
74 }
75}
76
77impl From<String> for PartitionKey {
78 fn from(value: String) -> Self {
79 Self::String(value)
80 }
81}
82
83impl From<&str> for PartitionKey {
84 fn from(value: &str) -> Self {
85 Self::String(value.to_string())
86 }
87}
88
89impl From<i64> for PartitionKey {
90 fn from(value: i64) -> Self {
91 Self::Int(value)
92 }
93}
94
95impl From<i32> for PartitionKey {
96 fn from(value: i32) -> Self {
97 Self::Int(value as i64)
98 }
99}
100
/// Describes how keys are assigned to partitions.
///
/// `partition_for_key` interprets each variant; the notes below describe what
/// that method does with the fields in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PartitionStrategy {
    /// The partition index is the key's hash modulo `num_partitions`.
    Hash {
        /// Columns the strategy is declared over; `partition_for_key` does not read this field.
        columns: Vec<String>,
        /// Divisor applied to the key hash.
        num_partitions: u32,
    },
    /// The key's hash is compared against `boundaries` in order; the first
    /// boundary whose `upper_bound` is greater than the hash wins.
    Range {
        /// Column the strategy is declared over; `partition_for_key` does not read this field.
        column: String,
        /// Boundaries scanned front to back.
        boundaries: Vec<RangeBoundary>,
    },
    /// The key is looked up in explicit value lists.
    List {
        /// Column the strategy is declared over; `partition_for_key` does not read this field.
        column: String,
        /// Mappings scanned front to back; the first one containing the key wins.
        mappings: Vec<ListMapping>,
    },
    /// Handled like `Hash` by `partition_for_key`: the key's hash modulo
    /// `num_partitions` (no rotating counter is kept in this file).
    RoundRobin {
        /// Divisor applied to the key hash.
        num_partitions: u32,
    },
    /// Integer keys are divided by the interval length in seconds to pick a bucket.
    Time {
        /// Column the strategy is declared over; `partition_for_key` does not read this field.
        column: String,
        /// Bucket width.
        interval: TimeInterval,
    },
}
142
143impl PartitionStrategy {
144 pub fn hash(columns: Vec<String>, num_partitions: u32) -> Self {
146 Self::Hash {
147 columns,
148 num_partitions,
149 }
150 }
151
152 pub fn range(column: String, boundaries: Vec<RangeBoundary>) -> Self {
154 Self::Range { column, boundaries }
155 }
156
157 pub fn list(column: String, mappings: Vec<ListMapping>) -> Self {
159 Self::List { column, mappings }
160 }
161
162 pub fn round_robin(num_partitions: u32) -> Self {
164 Self::RoundRobin { num_partitions }
165 }
166
167 pub fn time(column: String, interval: TimeInterval) -> Self {
169 Self::Time { column, interval }
170 }
171
172 pub fn partition_for_key(&self, key: &PartitionKey) -> u32 {
174 match self {
175 Self::Hash { num_partitions, .. } => {
176 (key.hash_value() % *num_partitions as u64) as u32
177 }
178 Self::Range { boundaries, .. } => {
179 let hash = key.hash_value();
180 for (i, boundary) in boundaries.iter().enumerate() {
181 if hash < boundary.upper_bound {
182 return i as u32;
183 }
184 }
185 boundaries.len() as u32
186 }
187 Self::List { mappings, .. } => {
188 let key_str = match key {
189 PartitionKey::String(s) => s.clone(),
190 _ => format!("{:?}", key),
191 };
192 for mapping in mappings {
193 if mapping.values.contains(&key_str) {
194 return mapping.partition;
195 }
196 }
197 0 }
199 Self::RoundRobin { num_partitions } => {
200 (key.hash_value() % *num_partitions as u64) as u32
202 }
203 Self::Time { interval, .. } => {
204 if let PartitionKey::Int(ts) = key {
205 (*ts as u64 / interval.to_seconds()) as u32
206 } else {
207 0
208 }
209 }
210 }
211 }
212}
213
/// One entry in the boundary list of a range partitioning strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RangeBoundary {
    /// Label for the partition; no code in this file reads it.
    pub partition_name: String,
    /// Exclusive upper limit: `partition_for_key` selects this boundary when
    /// the key's hash is strictly less than this value.
    pub upper_bound: u64,
}
224
225impl RangeBoundary {
226 pub fn new(name: impl Into<String>, upper_bound: u64) -> Self {
227 Self {
228 partition_name: name.into(),
229 upper_bound,
230 }
231 }
232}
233
/// Associates a set of key values with one partition for list partitioning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListMapping {
    /// Partition index returned when a key matches one of `values`.
    pub partition: u32,
    /// Key texts that belong to `partition`; compared for exact equality.
    pub values: Vec<String>,
}
244
245impl ListMapping {
246 pub fn new(partition: u32, values: Vec<String>) -> Self {
247 Self { partition, values }
248 }
249}
250
/// Width of a time bucket; `to_seconds` converts each variant to seconds.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TimeInterval {
    /// 3600 seconds.
    Hour,
    /// 86400 seconds.
    Day,
    /// 604800 seconds (7 days).
    Week,
    /// 2592000 seconds, i.e. a fixed 30-day month rather than a calendar month.
    Month,
    /// 31536000 seconds, i.e. a fixed 365-day year (no leap days).
    Year,
    /// Caller-supplied width in seconds.
    Custom(u64),
}
265
266impl TimeInterval {
267 pub fn to_seconds(&self) -> u64 {
269 match self {
270 Self::Hour => 3600,
271 Self::Day => 86400,
272 Self::Week => 604800,
273 Self::Month => 2592000, Self::Year => 31536000, Self::Custom(s) => *s,
276 }
277 }
278}
279
/// An interval of 64-bit hash values with independently open or closed ends.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartitionRange {
    /// Lower endpoint.
    pub start: u64,
    /// Upper endpoint.
    pub end: u64,
    /// Whether `start` itself belongs to the range.
    pub inclusive_start: bool,
    /// Whether `end` itself belongs to the range.
    pub inclusive_end: bool,
}
292
293impl PartitionRange {
294 pub fn new(start: u64, end: u64) -> Self {
296 Self {
297 start,
298 end,
299 inclusive_start: true,
300 inclusive_end: false,
301 }
302 }
303
304 pub fn full() -> Self {
306 Self::new(0, u64::MAX)
307 }
308
309 pub fn contains(&self, hash: u64) -> bool {
311 let start_check = if self.inclusive_start {
312 hash >= self.start
313 } else {
314 hash > self.start
315 };
316
317 let end_check = if self.inclusive_end {
318 hash <= self.end
319 } else {
320 hash < self.end
321 };
322
323 start_check && end_check
324 }
325
326 pub fn split(&self, num_parts: u32) -> Vec<PartitionRange> {
328 if num_parts == 0 {
329 return vec![self.clone()];
330 }
331
332 let range_size = (self.end - self.start) / num_parts as u64;
333 let mut ranges = Vec::with_capacity(num_parts as usize);
334
335 for i in 0..num_parts {
336 let start = self.start + (i as u64 * range_size);
337 let end = if i == num_parts - 1 {
338 self.end
339 } else {
340 self.start + ((i as u64 + 1) * range_size)
341 };
342
343 ranges.push(PartitionRange {
344 start,
345 end,
346 inclusive_start: true,
347 inclusive_end: i == num_parts - 1,
348 });
349 }
350
351 ranges
352 }
353
354 pub fn merge(&self, other: &PartitionRange) -> Option<PartitionRange> {
356 if self.end == other.start {
357 Some(PartitionRange {
358 start: self.start,
359 end: other.end,
360 inclusive_start: self.inclusive_start,
361 inclusive_end: other.inclusive_end,
362 })
363 } else if other.end == self.start {
364 Some(PartitionRange {
365 start: other.start,
366 end: self.end,
367 inclusive_start: other.inclusive_start,
368 inclusive_end: self.inclusive_end,
369 })
370 } else {
371 None
372 }
373 }
374
375 pub fn size(&self) -> u64 {
377 self.end - self.start
378 }
379}
380
/// Minimal `Hasher` used for partition keys; its whole state is one 64-bit word.
struct PartitionHasher {
    /// Running hash value, updated by every `write` and returned by `finish`.
    state: u64,
}
389
390impl PartitionHasher {
391 fn new() -> Self {
392 Self {
393 state: 0x517cc1b727220a95,
394 }
395 }
396}
397
398impl Hasher for PartitionHasher {
399 fn finish(&self) -> u64 {
400 self.state
401 }
402
403 fn write(&mut self, bytes: &[u8]) {
404 for byte in bytes {
405 self.state ^= *byte as u64;
406 self.state = self.state.wrapping_mul(0x5851f42d4c957f2d);
407 }
408 }
409}
410
/// Builds a `PartitionKey` from a row by reading a fixed list of columns.
pub struct KeyExtractor {
    /// Column names to read, in the order their values appear in the key.
    columns: Vec<String>,
}
419
420impl KeyExtractor {
421 pub fn new(columns: Vec<String>) -> Self {
423 Self { columns }
424 }
425
426 pub fn extract(&self, values: &std::collections::HashMap<String, String>) -> Option<PartitionKey> {
428 if self.columns.len() == 1 {
429 values
430 .get(&self.columns[0])
431 .map(|v| PartitionKey::String(v.clone()))
432 } else {
433 let mut keys = Vec::new();
434 for col in &self.columns {
435 if let Some(v) = values.get(col) {
436 keys.push(PartitionKey::String(v.clone()));
437 } else {
438 return None;
439 }
440 }
441 Some(PartitionKey::Composite(keys))
442 }
443 }
444
445 pub fn columns(&self) -> &[String] {
447 &self.columns
448 }
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// Equal string keys hash to the same non-zero value.
    #[test]
    fn test_partition_key_string() {
        let first = PartitionKey::string("user_123");
        assert!(first.hash_value() > 0);

        let second = PartitionKey::string("user_123");
        assert_eq!(first.hash_value(), second.hash_value());
    }

    /// An integer key hashes to a non-zero value.
    #[test]
    fn test_partition_key_int() {
        assert!(PartitionKey::int(12345).hash_value() > 0);
    }

    /// A composite key of mixed sub-keys hashes to a non-zero value.
    #[test]
    fn test_partition_key_composite() {
        let parts = vec![PartitionKey::string("tenant_1"), PartitionKey::int(100)];
        let composite = PartitionKey::composite(parts);
        assert!(composite.hash_value() > 0);
    }

    /// Hash partitioning stays within range and is stable for a given key.
    #[test]
    fn test_partition_strategy_hash() {
        let strategy = PartitionStrategy::hash(vec![String::from("id")], 16);
        let keys = [PartitionKey::string("key1"), PartitionKey::string("key2")];

        let assigned: Vec<u32> = keys
            .iter()
            .map(|key| strategy.partition_for_key(key))
            .collect();

        assert!(assigned[0] < 16);
        assert!(assigned[1] < 16);

        assert_eq!(assigned[0], strategy.partition_for_key(&keys[0]));
    }

    /// List partitioning routes each region to its configured partition.
    #[test]
    fn test_partition_strategy_list() {
        let us_regions = vec![String::from("us-east"), String::from("us-west")];
        let eu_regions = vec![String::from("eu-west")];
        let strategy = PartitionStrategy::list(
            String::from("region"),
            vec![ListMapping::new(0, us_regions), ListMapping::new(1, eu_regions)],
        );

        assert_eq!(strategy.partition_for_key(&PartitionKey::string("us-east")), 0);
        assert_eq!(strategy.partition_for_key(&PartitionKey::string("eu-west")), 1);
    }

    /// A range built with `new` includes its start and excludes its end.
    #[test]
    fn test_partition_range() {
        let bounds = PartitionRange::new(100, 200);

        assert!(bounds.contains(100));
        assert!(bounds.contains(150));
        assert!(!bounds.contains(200));
        assert!(!bounds.contains(50));
    }

    /// Splitting produces the requested number of evenly sized pieces.
    #[test]
    fn test_partition_range_split() {
        let pieces = PartitionRange::new(0, 1000).split(4);

        assert_eq!(pieces.len(), 4);
        assert_eq!(pieces[0].start, 0);
        assert_eq!(pieces[0].end, 250);
        assert_eq!(pieces[3].end, 1000);
    }

    /// Adjacent ranges merge into one range spanning both.
    #[test]
    fn test_partition_range_merge() {
        let lower = PartitionRange::new(0, 100);
        let upper = PartitionRange::new(100, 200);

        let combined = lower.merge(&upper).unwrap();
        assert_eq!(combined.start, 0);
        assert_eq!(combined.end, 200);
    }

    /// A single-column extractor yields a plain string key.
    #[test]
    fn test_key_extractor() {
        let extractor = KeyExtractor::new(vec![String::from("user_id")]);

        let mut row = std::collections::HashMap::new();
        row.insert(String::from("user_id"), String::from("123"));
        row.insert(String::from("name"), String::from("Alice"));

        let extracted = extractor.extract(&row).unwrap();
        assert_eq!(extracted, PartitionKey::String(String::from("123")));
    }

    /// A multi-column extractor yields a composite key with one part per column.
    #[test]
    fn test_key_extractor_composite() {
        let columns = vec![String::from("tenant_id"), String::from("user_id")];
        let extractor = KeyExtractor::new(columns);

        let mut row = std::collections::HashMap::new();
        row.insert(String::from("tenant_id"), String::from("t1"));
        row.insert(String::from("user_id"), String::from("u1"));

        if let PartitionKey::Composite(parts) = extractor.extract(&row).unwrap() {
            assert_eq!(parts.len(), 2);
        } else {
            panic!("Expected composite key");
        }
    }

    /// Named and custom intervals convert to the expected number of seconds.
    #[test]
    fn test_time_interval() {
        assert_eq!(TimeInterval::Hour.to_seconds(), 3600);
        assert_eq!(TimeInterval::Day.to_seconds(), 86400);
        assert_eq!(TimeInterval::Custom(7200).to_seconds(), 7200);
    }
}