use std::io::{Error, ErrorKind, Result, Write};
use std::sync::atomic::{AtomicU64, Ordering};
use std::{
    hash::{BuildHasher, Hash},
    io::BufWriter,
};

use serde::Deserialize;

use crate::{
    MassMap, MassMapBucketMeta, MassMapDefaultHashLoader, MassMapHashConfig, MassMapHashLoader,
    MassMapHeader, MassMapInfo, MassMapMeta, MassMapReader, MassMapWriter,
};

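/// Builder for writing a massmap: entries are hashed into `bucket_count` buckets,
/// each bucket is serialized with MessagePack, bucket metadata is appended after the
/// data, and a fixed-size header is written at offset 0 so the map can later be
/// loaded with [`MassMap`].
///
/// A minimal usage sketch (assumes `std::fs::File` implements [`MassMapWriter`], as
/// the tests in this module suggest, and that the caller returns `std::io::Result`):
///
/// ```ignore
/// let file = std::fs::File::create("example.massmap")?;
/// let entries = (0..1_000u64).map(|i| (i, i * 2));
/// MassMapBuilder::default()
///     .with_bucket_count(64)
///     .build(&file, entries)?;
/// ```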
#[derive(Debug)]
pub struct MassMapBuilder<H: MassMapHashLoader = MassMapDefaultHashLoader> {
    hash_config: MassMapHashConfig,
    bucket_count: u64,
    writer_buffer_size: usize,
    field_names: bool,
    bucket_size_limit: u32,
    phantom: std::marker::PhantomData<H>,
}

impl<H: MassMapHashLoader> Default for MassMapBuilder<H> {
    fn default() -> Self {
        Self {
            hash_config: MassMapHashConfig::default(),
            bucket_count: 1024,
            writer_buffer_size: 16 << 20, // 16 MiB
            field_names: false,
            bucket_size_limit: u32::MAX,
            phantom: std::marker::PhantomData,
        }
    }
}

impl MassMapBuilder {
    #[allow(clippy::should_implement_trait)]
    /// Creates a builder that uses the default hash loader.
    pub fn default() -> Self {
        <Self as Default>::default()
    }
}

impl<H: MassMapHashLoader> MassMapBuilder<H> {
    /// Sets the hash configuration used to assign keys to buckets.
    pub fn with_hash_config(mut self, config: MassMapHashConfig) -> Self {
        self.hash_config = config;
        self
    }

    /// Sets the `seed` parameter of the hash configuration.
    pub fn with_hash_seed(mut self, seed: u64) -> Self {
        self.hash_config.parameters["seed"] = serde_json::json!(seed);
        self
    }

    /// Sets the number of buckets. Defaults to 1024.
    pub fn with_bucket_count(mut self, count: u64) -> Self {
        self.bucket_count = count;
        self
    }

    /// Sets the size of the internal write buffer in bytes. Defaults to 16 MiB.
    pub fn with_writer_buffer_size(mut self, size: usize) -> Self {
        self.writer_buffer_size = size;
        self
    }

    /// When enabled, buckets are serialized with MessagePack field names
    /// (`rmp_serde::encode::write_named`). Defaults to `false`.
    pub fn with_field_names(mut self, value: bool) -> Self {
        self.field_names = value;
        self
    }

    /// Sets the maximum serialized size of a single bucket in bytes; `build` fails
    /// if any bucket exceeds it. Defaults to `u32::MAX`.
    pub fn with_bucket_size_limit(mut self, limit: u32) -> Self {
        self.bucket_size_limit = limit;
        self
    }

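    /// Consumes `entries`, assigns each entry to a bucket by hashing its key with
    /// the configured [`MassMapHashLoader`], and writes the map through `writer`:
    /// bucket data is appended after the fixed-size header, each non-empty bucket
    /// serialized as one MessagePack value; the `(meta, bucket metas)` block follows
    /// the bucket data; and the header (meta offset and length) is written last, at
    /// offset 0.
    ///
    /// Fails if any serialized bucket exceeds `bucket_size_limit`, or if
    /// serialization or writing fails.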
    pub fn build<W, K, V>(
        self,
        writer: &W,
        entries: impl Iterator<Item = impl std::borrow::Borrow<(K, V)>>,
    ) -> std::io::Result<MassMapInfo>
    where
        W: MassMapWriter,
        K: serde::Serialize + Clone + std::hash::Hash + Eq,
        V: serde::Serialize + Clone,
    {
        let build_hasher = H::load(&self.hash_config)?;

        // Distribute entries into in-memory buckets keyed by hash.
        let mut buckets: Vec<Vec<(K, V)>> = vec![Vec::new(); self.bucket_count as usize];
        let mut entry_count: u64 = 0;
        for entry in entries {
            let (key, value) = entry.borrow();
            let bucket_index = build_hasher.hash_one(key) % self.bucket_count;
            buckets[bucket_index as usize].push((key.clone(), value.clone()));
            entry_count += 1;
        }

        let mut bucket_metas: Vec<MassMapBucketMeta> =
            Vec::with_capacity(self.bucket_count as usize);

        // Bucket data starts right after the fixed-size header.
        let offset = AtomicU64::new(MassMapHeader::SIZE as u64);
        let mut buf_writer = BufWriter::with_capacity(
            self.writer_buffer_size,
            MassMapWriterWrapper {
                inner: writer,
                offset: &offset,
            },
        );
        let mut occupied_bucket_count = 0;
        let mut occupied_bucket_range = 0..0;
        for (i, bucket) in buckets.into_iter().enumerate() {
            if bucket.is_empty() {
                bucket_metas.push(MassMapBucketMeta {
                    offset: 0,
                    length: 0,
                    count: 0,
                });
                continue;
            }

            occupied_bucket_count += 1;
            if occupied_bucket_range.is_empty() {
                occupied_bucket_range.start = i as u64;
            }
            occupied_bucket_range.end = i as u64 + 1;

            // Account for bytes still buffered in the BufWriter when computing offsets.
            let begin_offset = offset.load(Ordering::Relaxed) + buf_writer.buffer().len() as u64;

            let result = if self.field_names {
                rmp_serde::encode::write_named(&mut buf_writer, &bucket)
            } else {
                rmp_serde::encode::write(&mut buf_writer, &bucket)
            };
            result.map_err(|e| Error::other(format!("Failed to serialize bucket: {}", e)))?;

            let end_offset = offset.load(Ordering::Relaxed) + buf_writer.buffer().len() as u64;
            if end_offset - begin_offset > self.bucket_size_limit as u64 {
                return Err(Error::new(
                    ErrorKind::InvalidData,
                    format!("bucket size exceeds {}", self.bucket_size_limit),
                ));
            }

            bucket_metas.push(MassMapBucketMeta {
                offset: begin_offset,
                length: (end_offset - begin_offset) as u32,
                count: bucket.len() as u32,
            });
        }

        let meta = MassMapMeta {
            hash_config: self.hash_config,
            entry_count,
            bucket_count: self.bucket_count,
            occupied_bucket_count,
            occupied_bucket_range,
            key_type: std::any::type_name::<K>().to_string(),
            value_type: std::any::type_name::<V>().to_string(),
        };

        // Append the (meta, bucket metas) block, then write the header at offset 0.
        let meta_offset = offset.load(Ordering::Relaxed) + buf_writer.buffer().len() as u64;
        rmp_serde::encode::write(&mut buf_writer, &(meta.clone(), bucket_metas))
            .map_err(|e| Error::other(format!("Failed to serialize meta: {}", e)))?;
        let finished_offset = offset.load(Ordering::Relaxed) + buf_writer.buffer().len() as u64;
        buf_writer.flush()?;

        let meta_length = finished_offset - meta_offset;
        let header = MassMapHeader {
            meta_offset,
            meta_length,
        };
        writer.write_all_at(&header.serialize(), 0)?;

        Ok(MassMapInfo { header, meta })
    }
}

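/// Adapts a positional [`MassMapWriter`] to [`std::io::Write`] by tracking the
/// current write offset in an [`AtomicU64`], so it can be placed behind a
/// [`BufWriter`].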
pub struct MassMapWriterWrapper<'a, W: MassMapWriter> {
    inner: &'a W,
    offset: &'a AtomicU64,
}

impl<'a, W: MassMapWriter> std::io::Write for MassMapWriterWrapper<'a, W> {
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        let offset = self.offset.fetch_add(buf.len() as u64, Ordering::Relaxed);
        self.inner.write_all_at(buf, offset)?;
        Ok(buf.len())
    }

    fn flush(&mut self) -> Result<()> {
        Ok(())
    }
}

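/// Merges several massmaps that share a hash configuration and bucket count, but
/// occupy disjoint bucket ranges, into a single massmap.
///
/// Illustrative usage sketch (assumes `map_a` and `map_b` are already-loaded
/// [`MassMap`]s with non-overlapping occupied bucket ranges, and that
/// `std::fs::File` implements [`MassMapWriter`]):
///
/// ```ignore
/// let out = std::fs::File::create("merged.massmap")?;
/// MassMapMerger::default()
///     .with_writer_buffer_size(1 << 20)
///     .merge(&out, vec![map_a, map_b])?;
/// ```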
#[derive(Debug)]
pub struct MassMapMerger {
    writer_buffer_size: usize,
}

impl Default for MassMapMerger {
    fn default() -> Self {
        Self {
            writer_buffer_size: 16 << 20, // 16 MiB
        }
    }
}

impl MassMapMerger {
    /// Sets the size of the copy/write buffer in bytes. Defaults to 16 MiB.
    pub fn with_writer_buffer_size(mut self, size: usize) -> Self {
        self.writer_buffer_size = size;
        self
    }
}

impl MassMapMerger {
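    /// Merges `maps` into `writer`: each map's bucket data is copied verbatim and its
    /// bucket offsets are shifted, then a combined `(meta, bucket metas)` block and a
    /// header are written.
    ///
    /// All input maps must share the same hash configuration and bucket count, and
    /// their occupied bucket ranges must not overlap; otherwise an
    /// [`ErrorKind::InvalidData`] error is returned. At least one map is required.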
    pub fn merge<W, K, V, R: MassMapReader, H: MassMapHashLoader>(
        self,
        writer: &W,
        mut maps: Vec<MassMap<K, V, R, H>>,
    ) -> Result<MassMapInfo>
    where
        W: MassMapWriter,
        K: for<'de> Deserialize<'de> + Eq + Hash,
        V: for<'de> Deserialize<'de> + Clone,
    {
        if maps.is_empty() {
            return Err(Error::new(
                ErrorKind::InvalidData,
                "No massmaps provided for merging",
            ));
        }

        maps.sort_by_key(|m| m.meta().occupied_bucket_range.start);

        let mut entry_count = 0;
        let mut bucket_metas =
            vec![MassMapBucketMeta::default(); maps[0].meta().bucket_count as usize];
        let hash_config = maps[0].meta().hash_config.clone();
        let mut occupied_bucket_count = 0;
        let mut occupied_bucket_range = 0..0;
        let mut global_offset = 0u64;

        for map in &maps {
            if map.meta().hash_config != hash_config {
                return Err(Error::new(
                    ErrorKind::InvalidData,
                    "Incompatible hash configurations between massmaps",
                ));
            }
            if map.meta().bucket_count != bucket_metas.len() as u64 {
                return Err(Error::new(
                    ErrorKind::InvalidData,
                    "Incompatible bucket counts between massmaps",
                ));
            }

            if map.meta().entry_count == 0 {
                continue;
            }

            occupied_bucket_count += map.meta().occupied_bucket_count;
            if occupied_bucket_range.is_empty() {
                occupied_bucket_range = map.meta().occupied_bucket_range.clone();
            } else if occupied_bucket_range.end <= map.meta().occupied_bucket_range.start {
                occupied_bucket_range.end = map.meta().occupied_bucket_range.end;
            } else {
                return Err(Error::new(
                    ErrorKind::InvalidData,
                    "Overlapping occupied bucket ranges between massmaps",
                ));
            }

            // Take over this map's bucket metas, shifting offsets by the data already copied.
            for idx in map.meta().occupied_bucket_range.clone() {
                let bucket_meta = &mut bucket_metas[idx as usize];
                *bucket_meta = map.bucket_metas()[idx as usize];
                if bucket_meta.count > 0 {
                    bucket_meta.offset += global_offset;
                }
            }
            entry_count += map.meta().entry_count;

            // Copy this map's bucket data (everything between the header and its meta block).
            let mut current_offset = MassMapHeader::SIZE as u64;
            let finished_offset = map.header().meta_offset;
            while current_offset < finished_offset {
                let chunk = std::cmp::min(
                    finished_offset - current_offset,
                    self.writer_buffer_size as u64,
                );
                map.reader().read_exact_at(current_offset, chunk, |data| {
                    writer.write_all_at(data, global_offset + MassMapHeader::SIZE as u64)?;
                    Ok(())
                })?;
                current_offset += chunk;
                global_offset += chunk;
            }
        }

        let meta = MassMapMeta {
            hash_config,
            entry_count,
            bucket_count: bucket_metas.len() as u64,
            occupied_bucket_count,
            occupied_bucket_range,
            key_type: std::any::type_name::<K>().to_string(),
            value_type: std::any::type_name::<V>().to_string(),
        };

        let meta_offset = global_offset + MassMapHeader::SIZE as u64;
        let offset = AtomicU64::new(meta_offset);
        let mut buf_writer = BufWriter::with_capacity(
            self.writer_buffer_size,
            MassMapWriterWrapper {
                inner: writer,
                offset: &offset,
            },
        );

        rmp_serde::encode::write(&mut buf_writer, &(meta.clone(), bucket_metas))
            .map_err(|e| Error::other(format!("Failed to serialize meta: {}", e)))?;
        buf_writer.flush()?;
        let finished_offset = offset.load(Ordering::Relaxed);

        let meta_length = finished_offset - meta_offset;
        let header = MassMapHeader {
            meta_offset,
            meta_length,
        };
        writer.write_all_at(&header.serialize(), 0)?;

        Ok(MassMapInfo { header, meta })
    }
}

#[cfg(test)]
mod tests {
    use std::{fs::File, hash::Hasher, sync::Arc};

    use crate::*;

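    /// An in-memory stand-in for a file: positional reads and writes against a
    /// `Vec<u8>`, where writes are truncated or rejected once they reach `limit`
    /// bytes.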
    #[derive(Debug)]
    struct MemoryWriter {
        data: std::sync::Mutex<Vec<u8>>,
        limit: u64,
    }

    impl MemoryWriter {
        fn new(limit: u64) -> Self {
            Self {
                data: std::sync::Mutex::new(Vec::new()),
                limit,
            }
        }

        fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
            let data = self.data.lock().unwrap();
            // Clamp the start so reads past the end return 0 bytes instead of panicking.
            let start = std::cmp::min(offset as usize, data.len());
            let to_read = std::cmp::min(buf.len(), data.len() - start);
            buf[..to_read].copy_from_slice(&data[start..start + to_read]);
            Ok(to_read)
        }
    }

    #[cfg(unix)]
    impl std::os::unix::fs::FileExt for MemoryWriter {
        fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
            self.read_at(buf, offset)
        }
        fn write_at(&self, mut buf: &[u8], offset: u64) -> std::io::Result<usize> {
            if offset > self.limit {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "Write exceeds limit",
                ));
            }
            if buf.len() as u64 + offset > self.limit {
                buf = &buf[..(self.limit - offset) as usize];
            }

            let mut data = self.data.lock().unwrap();
            if data.len() < (offset as usize + buf.len()) {
                data.resize(offset as usize + buf.len(), 0);
            }
            data[offset as usize..offset as usize + buf.len()].copy_from_slice(buf);
            Ok(buf.len())
        }
    }

    #[cfg(windows)]
    impl std::os::windows::fs::FileExt for MemoryWriter {
        fn seek_read(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
            self.read_at(buf, offset)
        }
        fn seek_write(&self, mut buf: &[u8], offset: u64) -> std::io::Result<usize> {
            if offset > self.limit {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "Write exceeds limit",
                ));
            }
            if buf.len() as u64 + offset > self.limit {
                buf = &buf[..(self.limit - offset) as usize];
            }

            let mut data = self.data.lock().unwrap();
            if data.len() < (offset as usize + buf.len()) {
                data.resize(offset as usize + buf.len(), 0);
            }
            data[offset as usize..offset as usize + buf.len()].copy_from_slice(buf);
            Ok(buf.len())
        }
    }

    #[test]
    fn test_shorter_write() {
        // Building into a writer that cannot hold the whole map must surface an error,
        // while a writer with sufficient capacity round-trips every entry.
        const SUFFICIENT_CAPACITY: u64 = 6400;
        const INSUFFICIENT_CAPACITY: u64 = 6000;
        const N: u64 = 1000;

        let entries = (0..N).map(|i| (i, i));
        let writer = MemoryWriter::new(SUFFICIENT_CAPACITY);
        let hash_config = MassMapHashConfig {
            name: "foldhash".to_string(),
            parameters: serde_json::json!({ "seed": 42 }),
        };
        let builder = MassMapBuilder::default()
            .with_bucket_count(1)
            .with_hash_config(hash_config);
        builder.build(&writer, entries).unwrap();

        let map = MassMap::<u64, u64, _>::load(writer).unwrap();
        for i in 0..N {
            let value = map.get(&i).unwrap().unwrap();
            assert_eq!(value, i);
        }

        let entries = (0..N).map(|i| (i, i));
        let writer = MemoryWriter::new(INSUFFICIENT_CAPACITY);
        let builder = MassMapBuilder::default().with_bucket_count(1);
        builder.build(&writer, entries).unwrap_err();
    }

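    /// A hasher that interprets the written bytes (in reverse order) as an integer
    /// and reduces it modulo `modulo`; on little-endian targets this makes the hash
    /// of a `u64` key `k` simply `k % modulo`, so tests can predict bucket placement.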
    pub struct SimpleHasher {
        state: u64,
        modulo: u64,
    }

    impl SimpleHasher {
        pub fn new(modulo: u64) -> Self {
            SimpleHasher { state: 0, modulo }
        }
    }

    impl Hasher for SimpleHasher {
        fn finish(&self) -> u64 {
            self.state % self.modulo
        }

        fn write(&mut self, bytes: &[u8]) {
            for &byte in bytes.iter().rev() {
                self.state = self.state.wrapping_mul(256).wrapping_add(byte as u64);
            }
        }
    }

    struct SimpleBuildHasher {
        modulo: u64,
    }

    impl std::hash::BuildHasher for SimpleBuildHasher {
        type Hasher = SimpleHasher;

        fn build_hasher(&self) -> Self::Hasher {
            SimpleHasher::new(self.modulo)
        }
    }

    struct SimpleHashLoader;

    impl MassMapHashLoader for SimpleHashLoader {
        type BuildHasher = SimpleBuildHasher;

        fn load(config: &MassMapHashConfig) -> std::io::Result<Self::BuildHasher> {
            let modulo = config
                .parameters
                .get("modulo")
                .and_then(|v| v.as_u64())
                .unwrap_or(10000);
            Ok(SimpleBuildHasher { modulo })
        }
    }

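    /// Builds a massmap in memory from `entries` using [`SimpleHashLoader`] with the
    /// given bucket count and hash modulo, then loads it back for reading.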
    fn create_simple_map(
        entries: impl Iterator<Item = (u64, u64)>,
        bucket_count: u64,
        hash_modulo: u64,
    ) -> MassMap<u64, u64, MemoryWriter, SimpleHashLoader> {
        let writer = MemoryWriter::new(10 << 20); // 10 MiB
        let hash_config = MassMapHashConfig {
            name: "simplehash".to_string(),
            parameters: serde_json::json!({
                "modulo": hash_modulo
            }),
        };
        let builder = MassMapBuilder::<SimpleHashLoader>::default()
            .with_bucket_count(bucket_count)
            .with_hash_config(hash_config);
        builder.build(&writer, entries).unwrap();

        MassMap::<u64, u64, _, SimpleHashLoader>::load(writer).unwrap()
    }

    #[test]
    fn test_normal_merge() {
        let dir = tempfile::tempdir().unwrap();

        const M: u64 = 10000;
        const N: u64 = 100_000;
        const P: u64 = 10;

        // With hash modulo M, key v lands in bucket v % M, so thread i owns the
        // disjoint bucket range [i * M / P, (i + 1) * M / P).
        let mut threads = Vec::with_capacity(P as usize);
        for i in 0..P {
            threads.push(std::thread::spawn(move || {
                let entries = (0..N).filter(|v| (v % M) / (M / P) == i).map(|v| (v, v));
                let map = create_simple_map(entries, M, M);
                assert_eq!(map.meta().occupied_bucket_count, M / P);
                assert_eq!(map.meta().entry_count, N / P);
                assert_eq!(map.meta().occupied_bucket_range.start, (M / P) * i);

                for item in map.iter() {
                    let (k, v) = item.unwrap();
                    assert_eq!(k, v);
                }
                map
            }));
        }

        let mut maps = threads
            .into_iter()
            .map(|t| t.join().unwrap())
            .collect::<Vec<_>>();
        maps.push(create_simple_map((0..0).map(|v| (v, v)), M, M));

        let path = dir.path().join("merge.massmap");
        let writer = std::fs::File::create(&path).unwrap();
        MassMapMerger::default().merge(&writer, maps).unwrap();

        let reader = std::fs::File::open(&path).unwrap();
        let map = MassMap::<u64, u64, _, SimpleHashLoader>::load(reader).unwrap();
        assert_eq!(map.len(), N);
        let map = Arc::new(map);

        let mut threads = Vec::with_capacity(P as usize);
        for i in 0..P {
            const CHUNK: u64 = N / P;
            let range = CHUNK * i..CHUNK * (i + 1);
            let map = map.clone();
            threads.push(std::thread::spawn(move || {
                for v in range {
                    assert_eq!(map.get(&v).unwrap().unwrap(), v);
                }
            }));
        }

        for thread in threads {
            thread.join().unwrap();
        }
    }

    #[test]
    fn test_invalid_merge() {
        {
            // Incompatible hash configurations (different modulo parameters).
            let map1 = create_simple_map((0..1000).map(|i| (i, i)), 1024, 10000);
            let map2 = create_simple_map((1000..2000).map(|i| (i, i)), 1024, 20000);
            let writer = MemoryWriter::new(10 << 20);
            MassMapMerger::default()
                .with_writer_buffer_size(1 << 20)
                .merge(&writer, vec![map1, map2])
                .unwrap_err();
        }

        {
            // Incompatible bucket counts.
            let map1 = create_simple_map((0..1000).map(|i| (i, i)), 1024, 10000);
            let map2 = create_simple_map((1000..2000).map(|i| (i, i)), 2048, 10000);
            let writer = MemoryWriter::new(10 << 20);
            MassMapMerger::default()
                .merge(&writer, vec![map1, map2])
                .unwrap_err();
        }

        {
            // Overlapping occupied bucket ranges.
            let map1 = create_simple_map((0..1000).map(|i| (i, i)), 1024, 10000);
            let map2 = create_simple_map((500..1500).map(|i| (i, i)), 1024, 10000);
            let writer = MemoryWriter::new(10 << 20);
            MassMapMerger::default()
                .merge(&writer, vec![map1, map2])
                .unwrap_err();
        }

        {
            // No maps provided.
            let writer = MemoryWriter::new(10 << 20);
            MassMapMerger::default()
                .merge::<_, u64, u64, File, SimpleHashLoader>(&writer, vec![])
                .unwrap_err();
        }
    }
}