1use crate::hash::CryptoHash;
14use crate::types::{AccountId, NumShards};
15use borsh::{BorshDeserialize, BorshSerialize};
16use itertools::Itertools;
17use near_primitives_core::types::{ShardId, ShardIndex};
18use near_schema_checker_lib::ProtocolSchema;
19use std::collections::{BTreeMap, BTreeSet};
20use std::{fmt, str};
21
/// Version number of a shard layout. It is stored in every layout variant and
/// in every [`ShardUId`].
pub type ShardVersion = u32;
29
/// Versioned description of how accounts are divided between shards.
///
/// Each variant carries an explicit discriminant and the enum is annotated with
/// `#[borsh(use_discriminant = true)]` and `#[repr(u8)]`.
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Clone,
    Debug,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
#[borsh(use_discriminant = true)]
#[repr(u8)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub enum ShardLayout {
    /// Accounts are assigned to shards by hashing the account id.
    V0(ShardLayoutV0) = 0,
    /// Accounts are assigned by comparing them against boundary accounts;
    /// shard ids coincide with shard indexes.
    V1(ShardLayoutV1) = 1,
    /// Accounts are assigned by comparing them against boundary accounts;
    /// shard ids are stored explicitly and may be arbitrary.
    V2(ShardLayoutV2) = 2,
}
57
/// Layout in which the shard of an account is derived from a hash of its id
/// (see `ShardLayout::account_id_to_shard_id`).
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Clone,
    Debug,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct ShardLayoutV0 {
    /// Number of shards; the account hash is reduced modulo this value.
    num_shards: NumShards,
    /// Version of this shard layout.
    version: ShardVersion,
}
81
/// Split map of `ShardLayoutV1`: the entry at position `i` lists the shards
/// whose parent is the shard with id `i` (see `ShardLayout::v1`).
type ShardsSplitMap = Vec<Vec<ShardId>>;

/// Split map of `ShardLayoutV2`: parent shard id -> ids of its child shards.
type ShardsSplitMapV2 = BTreeMap<ShardId, Vec<ShardId>>;

/// Inverse of `ShardsSplitMapV2`: child shard id -> parent shard id.
type ShardsParentMapV2 = BTreeMap<ShardId, ShardId>;
100
101pub fn shard_uids_to_ids(shard_uids: &[ShardUId]) -> Vec<ShardId> {
102 shard_uids.iter().map(|shard_uid| shard_uid.shard_id()).collect_vec()
103}
104
/// Layout in which shards are delimited by boundary accounts and the shard id
/// of an account equals the number of boundary accounts that are not greater
/// than it.
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Clone,
    Debug,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct ShardLayoutV1 {
    /// Accounts that separate consecutive shards.
    boundary_accounts: Vec<AccountId>,
    /// For each parent shard (by position), the shards of this layout that have
    /// it as parent. `None` when no split information was supplied.
    shards_split_map: Option<ShardsSplitMap>,
    /// Parent shard id of every shard, indexed by shard; derived from
    /// `shards_split_map` in `ShardLayout::v1`.
    to_parent_shard_map: Option<Vec<ShardId>>,
    /// Version of this shard layout.
    version: ShardVersion,
}
133
134impl ShardLayoutV1 {
135 fn account_id_to_shard_id(&self, account_id: &AccountId) -> ShardId {
138 let mut shard_id: u64 = 0;
139 for boundary_account in &self.boundary_accounts {
140 if account_id < boundary_account {
141 break;
142 }
143 shard_id += 1;
144 }
145 shard_id.into()
146 }
147}
148
/// Layout in which shards are delimited by boundary accounts and every shard
/// has an explicitly stored id.
///
/// serde (de)serialization is implemented by hand through `SerdeShardLayoutV2`.
#[derive(BorshSerialize, BorshDeserialize, Clone, Debug, PartialEq, Eq, ProtocolSchema)]
pub struct ShardLayoutV2 {
    /// Accounts that separate consecutive shards.
    boundary_accounts: Vec<AccountId>,

    /// Shard ids ordered by shard index; the account lookup indexes into this
    /// vector with the number of boundary accounts passed.
    shard_ids: Vec<ShardId>,

    /// Shard id -> position of that id in `shard_ids`.
    id_to_index_map: BTreeMap<ShardId, ShardIndex>,

    /// Position in `shard_ids` -> shard id.
    index_to_id_map: BTreeMap<ShardIndex, ShardId>,

    /// Parent shard id -> child shard ids, if split information is available.
    shards_split_map: Option<ShardsSplitMapV2>,
    /// Child shard id -> parent shard id; derived from `shards_split_map`.
    shards_parent_map: Option<ShardsParentMapV2>,

    /// Version of this shard layout.
    version: ShardVersion,
}
187
/// serde counterpart of `ShardLayoutV2` in which every map is keyed by `String`.
///
/// `ShardLayoutV2` is serialized and deserialized through this type, and its
/// JSON schema (with the `schemars` feature) is the schema of this type.
#[derive(serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
struct SerdeShardLayoutV2 {
    boundary_accounts: Vec<AccountId>,
    shard_ids: Vec<ShardId>,
    /// Keys are shard ids rendered as strings.
    id_to_index_map: BTreeMap<String, ShardIndex>,
    /// Keys are shard indexes rendered as strings.
    index_to_id_map: BTreeMap<String, ShardId>,
    /// Keys are parent shard ids rendered as strings.
    shards_split_map: Option<BTreeMap<String, Vec<ShardId>>>,
    /// Keys are child shard ids rendered as strings.
    shards_parent_map: Option<BTreeMap<String, ShardId>>,
    version: ShardVersion,
}
201
202impl From<&ShardLayoutV2> for SerdeShardLayoutV2 {
203 fn from(layout: &ShardLayoutV2) -> Self {
204 fn key_to_string<K, V>(map: &BTreeMap<K, V>) -> BTreeMap<String, V>
205 where
206 K: std::fmt::Display,
207 V: Clone,
208 {
209 map.iter().map(|(k, v)| (k.to_string(), v.clone())).collect()
210 }
211
212 Self {
213 boundary_accounts: layout.boundary_accounts.clone(),
214 shard_ids: layout.shard_ids.clone(),
215 id_to_index_map: key_to_string(&layout.id_to_index_map),
216 index_to_id_map: key_to_string(&layout.index_to_id_map),
217 shards_split_map: layout.shards_split_map.as_ref().map(key_to_string),
218 shards_parent_map: layout.shards_parent_map.as_ref().map(key_to_string),
219 version: layout.version,
220 }
221 }
222}
223
224impl TryFrom<SerdeShardLayoutV2> for ShardLayoutV2 {
225 type Error = Box<dyn std::error::Error + Send + Sync>;
226
227 fn try_from(layout: SerdeShardLayoutV2) -> Result<Self, Self::Error> {
228 fn key_to_shard_id<V>(
229 map: BTreeMap<String, V>,
230 ) -> Result<BTreeMap<ShardId, V>, Box<dyn std::error::Error + Send + Sync>> {
231 map.into_iter().map(|(k, v)| Ok((k.parse::<u64>()?.into(), v))).collect()
232 }
233
234 let SerdeShardLayoutV2 {
235 boundary_accounts,
236 shard_ids,
237 id_to_index_map,
238 index_to_id_map,
239 shards_split_map,
240 shards_parent_map,
241 version,
242 } = layout;
243
244 let id_to_index_map = key_to_shard_id(id_to_index_map)?;
245 let shards_split_map = shards_split_map.map(key_to_shard_id).transpose()?;
246 let shards_parent_map = shards_parent_map.map(key_to_shard_id).transpose()?;
247 let index_to_id_map = index_to_id_map
248 .into_iter()
249 .map(|(k, v)| Ok((k.parse()?, v)))
250 .collect::<Result<_, Self::Error>>()?;
251
252 match (&shards_split_map, &shards_parent_map) {
253 (None, None) => {}
254 (Some(shard_split_map), Some(shards_parent_map)) => {
255 let expected_shards_parent_map =
256 validate_and_derive_shard_parent_map_v2(&shard_ids, &shard_split_map);
257 if &expected_shards_parent_map != shards_parent_map {
258 return Err("shards_parent_map does not match the expected value".into());
259 }
260 }
261 _ => {
262 return Err(
263 "shards_split_map and shards_parent_map must be both present or both absent"
264 .into(),
265 );
266 }
267 }
268
269 Ok(Self {
270 boundary_accounts,
271 shard_ids,
272 id_to_index_map,
273 index_to_id_map,
274 shards_split_map,
275 shards_parent_map,
276 version,
277 })
278 }
279}
280
281impl serde::Serialize for ShardLayoutV2 {
282 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
283 where
284 S: serde::Serializer,
285 {
286 SerdeShardLayoutV2::from(self).serialize(serializer)
287 }
288}
289
290impl<'de> serde::Deserialize<'de> for ShardLayoutV2 {
291 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
292 where
293 D: serde::Deserializer<'de>,
294 {
295 let serde_layout = SerdeShardLayoutV2::deserialize(deserializer)?;
296 ShardLayoutV2::try_from(serde_layout).map_err(serde::de::Error::custom)
297 }
298}
299
/// JSON schema of `ShardLayoutV2`: named "ShardLayoutV2", with the structure of
/// `SerdeShardLayoutV2`.
#[cfg(feature = "schemars")]
impl schemars::JsonSchema for ShardLayoutV2 {
    fn schema_name() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Owned("ShardLayoutV2".to_string())
    }

    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
        <SerdeShardLayoutV2 as schemars::JsonSchema>::json_schema(generator)
    }
}
310
311impl ShardLayoutV2 {
312 pub fn account_id_to_shard_id(&self, account_id: &AccountId) -> ShardId {
313 let mut shard_id_index = 0;
316 for boundary_account in &self.boundary_accounts {
317 if account_id < boundary_account {
318 break;
319 }
320 shard_id_index += 1;
321 }
322 self.shard_ids[shard_id_index]
323 }
324
325 pub fn shards_split_map(&self) -> &Option<ShardsSplitMapV2> {
326 &self.shards_split_map
327 }
328
329 pub fn boundary_accounts(&self) -> &Vec<AccountId> {
330 &self.boundary_accounts
331 }
332}
333
/// Errors returned by the lookup methods of `ShardLayout`.
#[derive(Debug)]
pub enum ShardLayoutError {
    /// The shard id is not part of the shard layout.
    InvalidShardIdError { shard_id: ShardId },
    /// The shard index is outside of the shard layout.
    InvalidShardIndexError { shard_index: ShardIndex },
    /// The shard exists but the layout stores no parent for it.
    NoParentError { shard_id: ShardId },
}
340
341impl fmt::Display for ShardLayoutError {
342 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
343 write!(f, "{:?}", self)
344 }
345}
346
// Marker impl: all `Error` methods keep their default implementations.
impl std::error::Error for ShardLayoutError {}
348
349impl ShardLayout {
350 pub fn single_shard() -> Self {
352 let shard_id = ShardId::new(0);
353 Self::V2(ShardLayoutV2 {
354 boundary_accounts: vec![],
355 shard_ids: vec![shard_id],
356 id_to_index_map: [(shard_id, 0)].into(),
357 index_to_id_map: [(0, shard_id)].into(),
358 shards_split_map: None,
359 shards_parent_map: None,
360 version: 0,
361 })
362 }
363
364 #[cfg(all(feature = "test_utils", feature = "rand"))]
369 pub fn multi_shard(num_shards: NumShards, version: ShardVersion) -> Self {
370 assert!(num_shards > 0, "at least 1 shard is required");
371
372 let boundary_accounts = (1..num_shards)
373 .map(|i| format!("test{}", i).parse().unwrap())
374 .collect::<Vec<AccountId>>();
375
376 Self::multi_shard_custom(boundary_accounts, version)
377 }
378
379 #[cfg(all(feature = "test_utils", feature = "rand"))]
384 pub fn multi_shard_custom(boundary_accounts: Vec<AccountId>, version: ShardVersion) -> Self {
385 use rand::{SeedableRng, rngs::StdRng, seq::SliceRandom};
386
387 let num_shards = (boundary_accounts.len() + 1) as u64;
388
389 let mut rng = StdRng::seed_from_u64(42);
392 let mut shard_ids = (0..num_shards).map(ShardId::new).collect::<Vec<ShardId>>();
393 shard_ids.shuffle(&mut rng);
394
395 let (id_to_index_map, index_to_id_map) = shard_ids
396 .iter()
397 .enumerate()
398 .map(|(i, &shard_id)| ((shard_id, i), (i, shard_id)))
399 .unzip();
400
401 Self::V2(ShardLayoutV2 {
402 boundary_accounts,
403 shard_ids,
404 id_to_index_map,
405 index_to_id_map,
406 shards_split_map: None,
407 shards_parent_map: None,
408 version,
409 })
410 }
411
412 #[deprecated(note = "Use multi_shard() instead")]
414 pub fn v0(num_shards: NumShards, version: ShardVersion) -> Self {
415 Self::V0(ShardLayoutV0 { num_shards, version })
416 }
417
418 #[deprecated(note = "Use multi_shard() instead")]
420 pub fn v1(
421 boundary_accounts: Vec<AccountId>,
422 shards_split_map: Option<ShardsSplitMap>,
423 version: ShardVersion,
424 ) -> Self {
425 let to_parent_shard_map = if let Some(shards_split_map) = &shards_split_map {
426 let mut to_parent_shard_map = BTreeMap::new();
427 let num_shards = (boundary_accounts.len() + 1) as NumShards;
428 for (parent_shard_id, shard_ids) in shards_split_map.iter().enumerate() {
429 let parent_shard_id = ShardId::new(parent_shard_id as u64);
430 for &shard_id in shard_ids {
431 let prev = to_parent_shard_map.insert(shard_id, parent_shard_id);
432 assert!(prev.is_none(), "no shard should appear in the map twice");
433 let shard_id: u64 = shard_id.into();
434 assert!(shard_id < num_shards, "shard id should be valid");
435 }
436 }
437 Some((0..num_shards).map(|shard_id| to_parent_shard_map[&shard_id.into()]).collect())
438 } else {
439 None
440 };
441 Self::V1(ShardLayoutV1 {
442 boundary_accounts,
443 shards_split_map,
444 to_parent_shard_map,
445 version,
446 })
447 }
448
449 pub fn v2(
451 boundary_accounts: Vec<AccountId>,
452 shard_ids: Vec<ShardId>,
453 shards_split_map: Option<ShardsSplitMapV2>,
454 ) -> Self {
455 const VERSION: ShardVersion = 3;
457
458 assert_eq!(boundary_accounts.len() + 1, shard_ids.len());
459 assert_eq!(boundary_accounts, boundary_accounts.iter().sorted().cloned().collect_vec());
460
461 let mut id_to_index_map = BTreeMap::new();
462 let mut index_to_id_map = BTreeMap::new();
463 for (shard_index, &shard_id) in shard_ids.iter().enumerate() {
464 id_to_index_map.insert(shard_id, shard_index);
465 index_to_id_map.insert(shard_index, shard_id);
466 }
467
468 let shards_parent_map = shards_split_map.as_ref().map(|shards_split_map| {
469 validate_and_derive_shard_parent_map_v2(&shard_ids, &shards_split_map)
470 });
471
472 Self::V2(ShardLayoutV2 {
473 boundary_accounts,
474 shard_ids,
475 id_to_index_map,
476 index_to_id_map,
477 shards_split_map,
478 shards_parent_map,
479 version: VERSION,
480 })
481 }
482
483 pub fn account_id_to_shard_id(&self, account_id: &AccountId) -> ShardId {
489 match self {
490 ShardLayout::V0(v0) => {
491 let hash = CryptoHash::hash_bytes(account_id.as_bytes());
492 let (bytes, _) = stdx::split_array::<32, 8, 24>(hash.as_bytes());
493 let shard_id = u64::from_le_bytes(*bytes) % v0.num_shards;
494 shard_id.into()
495 }
496 ShardLayout::V1(v1) => v1.account_id_to_shard_id(account_id),
497 ShardLayout::V2(v2) => v2.account_id_to_shard_id(account_id),
498 }
499 }
500
501 #[inline]
503 pub fn account_id_to_shard_uid(&self, account_id: &AccountId) -> ShardUId {
504 ShardUId::from_shard_id_and_layout(self.account_id_to_shard_id(account_id), self)
505 }
506
507 #[inline]
510 pub fn get_children_shards_uids(&self, parent_shard_id: ShardId) -> Option<Vec<ShardUId>> {
511 self.get_children_shards_ids(parent_shard_id).map(|shards| {
512 shards.into_iter().map(|id| ShardUId::from_shard_id_and_layout(id, self)).collect()
513 })
514 }
515
516 pub fn get_children_shards_ids(&self, parent_shard_id: ShardId) -> Option<Vec<ShardId>> {
519 match self {
520 Self::V0(_) => None,
521 Self::V1(v1) => match &v1.shards_split_map {
522 Some(shards_split_map) => {
523 let parent_shard_index: ShardIndex = parent_shard_id.into();
527 shards_split_map.get(parent_shard_index).cloned()
528 }
529 None => None,
530 },
531 Self::V2(v2) => match &v2.shards_split_map {
532 Some(shards_split_map) => shards_split_map.get(&parent_shard_id).cloned(),
533 None => None,
534 },
535 }
536 }
537
538 pub fn try_get_parent_shard_id(
542 &self,
543 shard_id: ShardId,
544 ) -> Result<Option<ShardId>, ShardLayoutError> {
545 if !self.shard_ids().any(|id| id == shard_id) {
546 return Err(ShardLayoutError::InvalidShardIdError { shard_id });
547 }
548 let parent_shard_id = match self {
549 Self::V0(_) => None,
550 Self::V1(v1) => match &v1.to_parent_shard_map {
551 Some(to_parent_shard_map) => {
554 let shard_index = self.get_shard_index(shard_id).unwrap();
555 let parent_shard_id = to_parent_shard_map.get(shard_index).unwrap();
556 Some(*parent_shard_id)
557 }
558 None => None,
559 },
560 Self::V2(v2) => match &v2.shards_parent_map {
561 Some(to_parent_shard_map) => {
564 let parent_shard_id = to_parent_shard_map.get(&shard_id).unwrap();
565 Some(*parent_shard_id)
566 }
567 None => None,
568 },
569 };
570 Ok(parent_shard_id)
571 }
572
573 pub fn get_parent_shard_id(&self, shard_id: ShardId) -> Result<ShardId, ShardLayoutError> {
578 let parent_shard_id = self.try_get_parent_shard_id(shard_id)?;
579 parent_shard_id.ok_or(ShardLayoutError::NoParentError { shard_id })
580 }
581
582 pub fn derive_shard_layout(
584 base_shard_layout: &ShardLayout,
585 new_boundary_account: AccountId,
586 ) -> ShardLayout {
587 let mut boundary_accounts = base_shard_layout.boundary_accounts().clone();
588 let mut shard_ids = base_shard_layout.shard_ids().collect::<Vec<_>>();
589 let mut shards_split_map = shard_ids
590 .iter()
591 .map(|id| (*id, vec![*id]))
592 .collect::<BTreeMap<ShardId, Vec<ShardId>>>();
593
594 assert!(!boundary_accounts.contains(&new_boundary_account), "duplicated boundary account");
595
596 boundary_accounts.push(new_boundary_account.clone());
598 boundary_accounts.sort();
599 let new_boundary_account_index = boundary_accounts
600 .iter()
601 .position(|acc| acc == &new_boundary_account)
602 .expect("account should be guaranteed to exist at this point");
603
604 let max_shard_id =
606 *shard_ids.iter().max().expect("there should always be at least one shard");
607 let new_shards = vec![max_shard_id + 1, max_shard_id + 2];
608 let parent_shard_id = shard_ids
609 .splice(new_boundary_account_index..new_boundary_account_index + 1, new_shards.clone())
610 .collect::<Vec<_>>();
611 let [parent_shard_id] = parent_shard_id.as_slice() else {
612 panic!("should only splice one shard");
613 };
614 shards_split_map.insert(*parent_shard_id, new_shards);
615
616 ShardLayout::v2(boundary_accounts, shard_ids, Some(shards_split_map))
617 }
618
619 #[inline]
620 pub fn version(&self) -> ShardVersion {
621 match self {
622 Self::V0(v0) => v0.version,
623 Self::V1(v1) => v1.version,
624 Self::V2(v2) => v2.version,
625 }
626 }
627
628 pub fn boundary_accounts(&self) -> &Vec<AccountId> {
629 match self {
630 Self::V1(v1) => &v1.boundary_accounts,
631 Self::V2(v2) => &v2.boundary_accounts,
632 _ => panic!("ShardLayout::V0 doesn't have boundary accounts"),
633 }
634 }
635
636 pub fn num_shards(&self) -> NumShards {
637 match self {
638 Self::V0(v0) => v0.num_shards,
639 Self::V1(v1) => (v1.boundary_accounts.len() + 1) as NumShards,
640 Self::V2(v2) => v2.shard_ids.len() as NumShards,
641 }
642 }
643
644 pub fn shard_ids(&self) -> impl Iterator<Item = ShardId> {
648 match self {
649 Self::V0(_) => (0..self.num_shards()).map(Into::into).collect_vec().into_iter(),
650 Self::V1(_) => (0..self.num_shards()).map(Into::into).collect_vec().into_iter(),
651 Self::V2(v2) => v2.shard_ids.clone().into_iter(),
652 }
653 }
654
655 pub fn shard_uids(&self) -> impl Iterator<Item = ShardUId> + '_ {
658 self.shard_ids().map(|shard_id| ShardUId::from_shard_id_and_layout(shard_id, self))
659 }
660
661 pub fn shard_indexes(&self) -> impl Iterator<Item = ShardIndex> + 'static {
662 let num_shards: usize =
663 self.num_shards().try_into().expect("Number of shards doesn't fit in usize");
664 match self {
665 Self::V0(_) | Self::V1(_) | Self::V2(_) => (0..num_shards).into_iter(),
666 }
667 }
668
669 pub fn shard_infos(&self) -> impl Iterator<Item = ShardInfo> + '_ {
674 self.shard_uids()
675 .enumerate()
676 .map(|(shard_index, shard_uid)| ShardInfo { shard_index, shard_uid })
677 }
678
679 pub fn get_shard_index(&self, shard_id: ShardId) -> Result<ShardIndex, ShardLayoutError> {
682 match self {
683 Self::V0(_) => Ok(shard_id.into()),
685 Self::V1(_) => Ok(shard_id.into()),
687 Self::V2(v2) => v2
689 .id_to_index_map
690 .get(&shard_id)
691 .copied()
692 .ok_or(ShardLayoutError::InvalidShardIdError { shard_id }),
693 }
694 }
695
696 pub fn get_shard_id(&self, shard_index: ShardIndex) -> Result<ShardId, ShardLayoutError> {
699 let num_shards = self.num_shards() as usize;
700 match self {
701 Self::V0(_) | Self::V1(_) => {
702 if shard_index >= num_shards {
703 return Err(ShardLayoutError::InvalidShardIndexError { shard_index });
704 }
705 Ok(ShardId::new(shard_index as u64))
706 }
707 Self::V2(v2) => v2
708 .shard_ids
709 .get(shard_index)
710 .copied()
711 .ok_or(ShardLayoutError::InvalidShardIndexError { shard_index }),
712 }
713 }
714
715 pub fn get_shard_uid(&self, shard_index: ShardIndex) -> Result<ShardUId, ShardLayoutError> {
716 let shard_id = self.get_shard_id(shard_index)?;
717 Ok(ShardUId::from_shard_id_and_layout(shard_id, self))
718 }
719
720 pub fn get_split_parent_shard_ids(&self) -> BTreeSet<ShardId> {
723 let mut parent_shard_ids = BTreeSet::new();
724 for shard_id in self.shard_ids() {
725 let parent_shard_id = self
726 .try_get_parent_shard_id(shard_id)
727 .expect("shard_id belongs to the shard layout");
728 let Some(parent_shard_id) = parent_shard_id else {
729 continue;
730 };
731 if parent_shard_id == shard_id {
732 continue;
733 }
734 parent_shard_ids.insert(parent_shard_id);
735 }
736 parent_shard_ids
737 }
738
739 pub fn get_split_parent_shard_uids(&self) -> BTreeSet<ShardUId> {
742 let parent_shard_ids = self.get_split_parent_shard_ids();
743 parent_shard_ids
744 .into_iter()
745 .map(|shard_id| ShardUId::new(self.version(), shard_id))
746 .collect()
747 }
748}
749
750fn validate_and_derive_shard_parent_map_v2(
752 shard_ids: &Vec<ShardId>,
753 shards_split_map: &ShardsSplitMapV2,
754) -> ShardsParentMapV2 {
755 let mut shards_parent_map = ShardsParentMapV2::new();
756 for (&parent_shard_id, child_shard_ids) in shards_split_map {
757 for &child_shard_id in child_shard_ids {
758 let prev = shards_parent_map.insert(child_shard_id, parent_shard_id);
759 assert!(prev.is_none(), "no shard should appear in the map twice");
760 }
761 if let &[child_shard_id] = child_shard_ids.as_slice() {
762 assert_eq!(parent_shard_id, child_shard_id);
765 } else {
766 assert!(!shard_ids.contains(&parent_shard_id));
769 }
770 }
771
772 assert_eq!(
773 shard_ids.iter().copied().sorted().collect_vec(),
774 shards_parent_map.keys().copied().collect_vec()
775 );
776 shards_parent_map
777}
778
/// Identifier of a shard made of a shard layout version and a shard id.
///
/// Converts to and from 8 bytes (see `to_bytes`) and to and from the string
/// form `s{shard_id}.v{version}` (see the `Display` and `FromStr` impls).
#[derive(
    BorshSerialize,
    BorshDeserialize,
    Hash,
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    ProtocolSchema,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct ShardUId {
    /// Version of the shard layout the shard belongs to.
    pub version: ShardVersion,
    /// Shard id, stored as `u32`.
    pub shard_id: u32,
}
805
806impl ShardUId {
807 pub fn new(version: ShardVersion, shard_id: ShardId) -> Self {
808 Self { version, shard_id: shard_id.into() }
809 }
810
811 #[cfg(feature = "test_utils")]
814 pub fn single_shard() -> Self {
815 ShardLayout::single_shard().shard_uids().next().unwrap()
816 }
817
818 pub fn to_bytes(&self) -> [u8; 8] {
820 let mut res = [0; 8];
821 res[0..4].copy_from_slice(&u32::to_le_bytes(self.version));
822 res[4..].copy_from_slice(&u32::to_le_bytes(self.shard_id));
823 res
824 }
825
826 pub fn get_upper_bound_db_key(shard_uid_bytes: &[u8; 8]) -> [u8; 8] {
832 let mut result = *shard_uid_bytes;
833 for i in (0..8).rev() {
834 if result[i] == u8::MAX {
835 result[i] = 0;
836 } else {
837 result[i] += 1;
838 return result;
839 }
840 }
841 panic!("Next shard prefix for shard bytes {shard_uid_bytes:?} does not exist");
842 }
843
    /// Builds the `ShardUId` of `shard_id` using the version of `shard_layout`.
    ///
    /// # Panics
    ///
    /// Panics if `shard_id` is not one of the shard ids of `shard_layout`.
    pub fn from_shard_id_and_layout(shard_id: ShardId, shard_layout: &ShardLayout) -> Self {
        // Membership is checked by scanning all shard ids of the layout.
        assert!(shard_layout.shard_ids().any(|i| i == shard_id));
        Self::new(shard_layout.version(), shard_id)
    }
849
850 pub fn shard_id(&self) -> ShardId {
852 self.shard_id.into()
853 }
854}
855
856impl TryFrom<&[u8]> for ShardUId {
857 type Error = Box<dyn std::error::Error + Send + Sync>;
858
859 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
861 if bytes.len() != 8 {
862 return Err("incorrect length for ShardUId".into());
863 }
864 let version = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
865 let shard_id = u32::from_le_bytes(bytes[4..8].try_into().unwrap());
866 Ok(Self { version, shard_id })
867 }
868}
869
870pub fn get_block_shard_uid(block_hash: &CryptoHash, shard_uid: &ShardUId) -> Vec<u8> {
872 let mut res = Vec::with_capacity(40);
873 res.extend_from_slice(block_hash.as_ref());
874 res.extend_from_slice(&shard_uid.to_bytes());
875 res
876}
877
878pub fn get_block_shard_uid_rev(
880 key: &[u8],
881) -> Result<(CryptoHash, ShardUId), Box<dyn std::error::Error + Send + Sync>> {
882 if key.len() != 40 {
883 return Err(
884 std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid key length").into()
885 );
886 }
887 let block_hash = CryptoHash::try_from(&key[..32])?;
888 let shard_id = ShardUId::try_from(&key[32..])?;
889 Ok((block_hash, shard_id))
890}
891
892impl fmt::Display for ShardUId {
893 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
894 write!(f, "s{}.v{}", self.shard_id, self.version)
895 }
896}
897
898impl fmt::Debug for ShardUId {
899 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
900 fmt::Display::fmt(self, f)
901 }
902}
903
904impl str::FromStr for ShardUId {
905 type Err = String;
906
907 fn from_str(s: &str) -> Result<Self, Self::Err> {
908 let (shard_str, version_str) = s
909 .split_once(".")
910 .ok_or_else(|| "shard version and number must be separated by \".\"".to_string())?;
911
912 let version = version_str
913 .strip_prefix("v")
914 .ok_or_else(|| "shard version must start with \"v\"".to_string())?
915 .parse::<ShardVersion>()
916 .map_err(|e| format!("shard version after \"v\" must be a number, {e}"))?;
917
918 let shard_str = shard_str
919 .strip_prefix("s")
920 .ok_or_else(|| "shard id must start with \"s\"".to_string())?;
921 let shard_id = shard_str
922 .parse::<u32>()
923 .map_err(|e| format!("shard id after \"s\" must be a number, {e}"))?;
924
925 Ok(ShardUId { shard_id, version })
926 }
927}
928
929impl<'de> serde::Deserialize<'de> for ShardUId {
930 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
931 where
932 D: serde::Deserializer<'de>,
933 {
934 deserializer.deserialize_any(ShardUIdVisitor)
935 }
936}
937
938impl serde::Serialize for ShardUId {
939 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
940 where
941 S: serde::Serializer,
942 {
943 serializer.serialize_str(&self.to_string())
944 }
945}
946
947struct ShardUIdVisitor;
948impl<'de> serde::de::Visitor<'de> for ShardUIdVisitor {
949 type Value = ShardUId;
950
951 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
952 write!(
953 formatter,
954 "either string format of `ShardUId` like 's0.v3' for shard 0 version 3, or a map"
955 )
956 }
957
958 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
959 where
960 E: serde::de::Error,
961 {
962 v.parse().map_err(|e| E::custom(e))
963 }
964
965 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
966 where
967 A: serde::de::MapAccess<'de>,
968 {
969 let mut version = None;
973 let mut shard_id = None;
974
975 while let Some((field, value)) = map.next_entry()? {
976 match field {
977 "version" => version = Some(value),
978 "shard_id" => shard_id = Some(value),
979 _ => return Err(serde::de::Error::unknown_field(field, &["version", "shard_id"])),
980 }
981 }
982
983 match (version, shard_id) {
984 (None, _) => Err(serde::de::Error::missing_field("version")),
985 (_, None) => Err(serde::de::Error::missing_field("shard_id")),
986 (Some(version), Some(shard_id)) => Ok(ShardUId { version, shard_id }),
987 }
988 }
989}
990
/// A shard of a layout described by its index and its `ShardUId`
/// (see `ShardLayout::shard_infos`).
#[derive(Clone, Copy)]
pub struct ShardInfo {
    /// Position of the shard within the layout.
    pub shard_index: ShardIndex,
    /// Unique id of the shard.
    pub shard_uid: ShardUId,
}
996
997impl ShardInfo {
998 pub fn shard_index(&self) -> ShardIndex {
999 self.shard_index
1000 }
1001
1002 pub fn shard_id(&self) -> ShardId {
1003 self.shard_uid.shard_id()
1004 }
1005
1006 pub fn shard_uid(&self) -> ShardUId {
1007 self.shard_uid
1008 }
1009}
1010
1011#[cfg(test)]
1012mod tests {
1013 use crate::epoch_manager::EpochConfigStore;
1014 use crate::shard_layout::{ShardLayout, ShardUId};
1015 use itertools::Itertools;
1016 use near_primitives_core::types::ProtocolVersion;
1017 use near_primitives_core::types::{AccountId, ShardId};
1018 use rand::distributions::Alphanumeric;
1019 use rand::rngs::StdRng;
1020 use rand::{Rng, SeedableRng};
1021 use std::collections::{BTreeMap, HashMap};
1022
1023 use super::{ShardsSplitMap, ShardsSplitMapV2};
1024
1025 fn new_shard_ids_vec(shard_ids: Vec<u64>) -> Vec<ShardId> {
1026 shard_ids.into_iter().map(Into::into).collect()
1027 }
1028
1029 fn new_shards_split_map(shards_split_map: Vec<Vec<u64>>) -> ShardsSplitMap {
1030 shards_split_map.into_iter().map(new_shard_ids_vec).collect()
1031 }
1032
1033 fn new_shards_split_map_v2(shards_split_map: BTreeMap<u64, Vec<u64>>) -> ShardsSplitMapV2 {
1034 shards_split_map.into_iter().map(|(k, v)| (k.into(), new_shard_ids_vec(v))).collect()
1035 }
1036
1037 impl ShardLayout {
1038 pub fn for_protocol_version(protocol_version: ProtocolVersion) -> Self {
1040 let config_store = EpochConfigStore::for_chain_id("mainnet", None).unwrap();
1041 config_store.get_config(protocol_version).shard_layout.clone()
1042 }
1043 }
1044
    // Checks that a V0 layout only produces valid shard ids and that 1000
    // pseudo-random accounts are spread over the 4 shards in the expected way.
    #[test]
    fn test_shard_layout_v0() {
        let num_shards = 4;
        #[allow(deprecated)]
        let shard_layout = ShardLayout::v0(num_shards, 0);
        // Start every shard of the layout with a count of zero.
        let mut shard_id_distribution: HashMap<ShardId, _> =
            shard_layout.shard_ids().map(|shard_id| (shard_id.into(), 0)).collect();
        // The generator is seeded with a constant; the expected counts below
        // depend on the accounts it produces.
        let mut rng = StdRng::from_seed([0; 32]);
        for _i in 0..1000 {
            // Random 10-character alphanumeric string, lowercased before parsing.
            let s: Vec<u8> = (&mut rng).sample_iter(&Alphanumeric).take(10).collect();
            let s = String::from_utf8(s).unwrap();
            let account_id = s.to_lowercase().parse().unwrap();
            let shard_id = shard_layout.account_id_to_shard_id(&account_id);
            *shard_id_distribution.get_mut(&shard_id).unwrap() += 1;

            let shard_id: u64 = shard_id.into();
            assert!(shard_id < num_shards);
        }
        let expected_distribution: HashMap<ShardId, _> = [
            (ShardId::new(0), 247),
            (ShardId::new(1), 268),
            (ShardId::new(2), 233),
            (ShardId::new(3), 252),
        ]
        .into_iter()
        .collect();
        assert_eq!(shard_id_distribution, expected_distribution);
    }
1073
    // Checks children/parent lookups and the account -> shard mapping of a V1
    // layout with five boundary accounts (six shards).
    #[test]
    fn test_shard_layout_v1() {
        let aid = |s: &str| s.parse().unwrap();
        let sid = |s: u64| ShardId::new(s);

        let boundary_accounts =
            ["aurora", "bar", "foo", "foo.baz", "paz"].iter().map(|a| a.parse().unwrap()).collect();
        #[allow(deprecated)]
        let shard_layout = ShardLayout::v1(
            boundary_accounts,
            Some(new_shards_split_map(vec![vec![0, 1, 2], vec![3, 4, 5]])),
            1,
        );
        // Parent 0 maps to shards 0..3 and parent 1 to shards 3..6.
        assert_eq!(
            shard_layout.get_children_shards_uids(ShardId::new(0)).unwrap(),
            (0..3).map(|x| ShardUId { version: 1, shard_id: x }).collect::<Vec<_>>()
        );
        assert_eq!(
            shard_layout.get_children_shards_uids(ShardId::new(1)).unwrap(),
            (3..6).map(|x| ShardUId { version: 1, shard_id: x }).collect::<Vec<_>>()
        );
        for x in 0..3 {
            assert_eq!(shard_layout.get_parent_shard_id(ShardId::new(x)).unwrap(), sid(0));
            assert_eq!(shard_layout.get_parent_shard_id(ShardId::new(x + 3)).unwrap(), sid(1));
        }

        // Accounts equal to or related to the boundary accounts.
        assert_eq!(shard_layout.account_id_to_shard_id(&aid("aurora")), sid(1));
        assert_eq!(shard_layout.account_id_to_shard_id(&aid("foo.aurora")), sid(3));
        assert_eq!(shard_layout.account_id_to_shard_id(&aid("bar.foo.aurora")), sid(2));
        assert_eq!(shard_layout.account_id_to_shard_id(&aid("bar")), sid(2));
        assert_eq!(shard_layout.account_id_to_shard_id(&aid("bar.bar")), sid(2));
        assert_eq!(shard_layout.account_id_to_shard_id(&aid("foo")), sid(3));
        assert_eq!(shard_layout.account_id_to_shard_id(&aid("baz.foo")), sid(2));
        assert_eq!(shard_layout.account_id_to_shard_id(&aid("foo.baz")), sid(4));
        assert_eq!(shard_layout.account_id_to_shard_id(&aid("a.foo.baz")), sid(0));

        // Other accounts.
        assert_eq!(shard_layout.account_id_to_shard_id(&aid("aaa")), sid(0));
        assert_eq!(shard_layout.account_id_to_shard_id(&aid("abc")), sid(0));
        assert_eq!(shard_layout.account_id_to_shard_id(&aid("bbb")), sid(2));
        assert_eq!(shard_layout.account_id_to_shard_id(&aid("foo.goo")), sid(4));
        assert_eq!(shard_layout.account_id_to_shard_id(&aid("goo")), sid(4));
        assert_eq!(shard_layout.account_id_to_shard_id(&aid("zoo")), sid(5));
    }
1117
    // Checks account mapping, shard id/uid listing and parent/children lookups
    // of the V2 layout built by `get_test_shard_layout_v2`.
    #[test]
    fn test_shard_layout_v2() {
        let sid = |s: u64| ShardId::new(s);
        let shard_layout = get_test_shard_layout_v2();

        // Accounts that fall between the boundary accounts.
        assert_eq!(shard_layout.account_id_to_shard_id(&"aaa".parse().unwrap()), sid(3));
        assert_eq!(shard_layout.account_id_to_shard_id(&"ddd".parse().unwrap()), sid(8));
        assert_eq!(shard_layout.account_id_to_shard_id(&"mmm".parse().unwrap()), sid(4));
        assert_eq!(shard_layout.account_id_to_shard_id(&"rrr".parse().unwrap()), sid(7));

        // The boundary accounts themselves belong to the shard on their right.
        assert_eq!(shard_layout.account_id_to_shard_id(&"ccc".parse().unwrap()), sid(8));
        assert_eq!(shard_layout.account_id_to_shard_id(&"kkk".parse().unwrap()), sid(4));
        assert_eq!(shard_layout.account_id_to_shard_id(&"ppp".parse().unwrap()), sid(7));

        // Shard ids are listed in the order they were given.
        assert_eq!(shard_layout.shard_ids().collect_vec(), new_shard_ids_vec(vec![3, 8, 4, 7]));

        // `ShardLayout::v2` always uses version 3.
        let version = 3;
        let u = |shard_id| ShardUId { shard_id, version };
        assert_eq!(shard_layout.shard_uids().collect_vec(), vec![u(3), u(8), u(4), u(7)]);

        // Shards 3 and 4 are their own parents; 7 and 8 are children of 1.
        assert_eq!(shard_layout.get_parent_shard_id(ShardId::new(3)).unwrap(), sid(3));
        assert_eq!(shard_layout.get_parent_shard_id(ShardId::new(8)).unwrap(), sid(1));
        assert_eq!(shard_layout.get_parent_shard_id(ShardId::new(4)).unwrap(), sid(4));
        assert_eq!(shard_layout.get_parent_shard_id(ShardId::new(7)).unwrap(), sid(1));

        // Children are returned exactly as stored in the split map.
        assert_eq!(
            shard_layout.get_children_shards_ids(ShardId::new(1)).unwrap(),
            new_shard_ids_vec(vec![7, 8])
        );
        assert_eq!(
            shard_layout.get_children_shards_ids(ShardId::new(3)).unwrap(),
            new_shard_ids_vec(vec![3])
        );
        assert_eq!(
            shard_layout.get_children_shards_ids(ShardId::new(4)).unwrap(),
            new_shard_ids_vec(vec![4])
        );
    }
1162
1163 fn get_test_shard_layout_v2() -> ShardLayout {
1164 let b0 = "ccc".parse().unwrap();
1165 let b1 = "kkk".parse().unwrap();
1166 let b2 = "ppp".parse().unwrap();
1167
1168 let boundary_accounts = vec![b0, b1, b2];
1169 let shard_ids = vec![3, 8, 4, 7];
1170 let shard_ids = new_shard_ids_vec(shard_ids);
1171
1172 let shards_split_map = BTreeMap::from([(1, vec![7, 8]), (3, vec![3]), (4, vec![4])]);
1175 let shards_split_map = new_shards_split_map_v2(shards_split_map);
1176 let shards_split_map = Some(shards_split_map);
1177
1178 ShardLayout::v2(boundary_accounts, shard_ids, shards_split_map)
1179 }
1180
    // Derives four layouts in a row from a single-shard V2 layout and compares
    // each result with a layout built directly through `ShardLayout::v2`.
    #[test]
    fn test_deriving_shard_layout() {
        fn to_boundary_accounts<const N: usize>(accounts: [&str; N]) -> Vec<AccountId> {
            accounts.into_iter().map(|a| a.parse().unwrap()).collect()
        }

        fn to_shard_ids<const N: usize>(ids: [u32; N]) -> Vec<ShardId> {
            ids.into_iter().map(|id| ShardId::new(id as u64)).collect()
        }

        fn to_shards_split_map<const N: usize>(
            xs: [(u32, Vec<u32>); N],
        ) -> BTreeMap<ShardId, Vec<ShardId>> {
            xs.into_iter()
                .map(|(k, xs)| {
                    (
                        ShardId::new(k as u64),
                        xs.into_iter().map(|x| ShardId::new(x as u64)).collect(),
                    )
                })
                .collect()
        }

        // Step 1: shard 0 is split into shards 1 and 2.
        let base_layout = ShardLayout::v2(vec![], vec![ShardId::new(0)], None);
        let derived_layout =
            ShardLayout::derive_shard_layout(&base_layout, "test1.near".parse().unwrap());
        assert_eq!(
            derived_layout,
            ShardLayout::v2(
                to_boundary_accounts(["test1.near"]),
                to_shard_ids([1, 2]),
                Some(to_shards_split_map([(0, vec![1, 2])])),
            ),
        );

        // Step 2: shard 2 is split into shards 3 and 4.
        let base_layout = derived_layout;
        let derived_layout =
            ShardLayout::derive_shard_layout(&base_layout, "test3.near".parse().unwrap());
        assert_eq!(
            derived_layout,
            ShardLayout::v2(
                to_boundary_accounts(["test1.near", "test3.near"]),
                to_shard_ids([1, 3, 4]),
                Some(to_shards_split_map([(1, vec![1]), (2, vec![3, 4])])),
            ),
        );

        // Step 3: shard 1 is split into shards 5 and 6.
        let base_layout = derived_layout;
        let derived_layout =
            ShardLayout::derive_shard_layout(&base_layout, "test0.near".parse().unwrap());
        assert_eq!(
            derived_layout,
            ShardLayout::v2(
                to_boundary_accounts(["test0.near", "test1.near", "test3.near"]),
                to_shard_ids([5, 6, 3, 4]),
                Some(to_shards_split_map([(1, vec![5, 6]), (3, vec![3]), (4, vec![4]),])),
            ),
        );

        // Step 4: shard 3 is split into shards 7 and 8.
        let base_layout = derived_layout;
        let derived_layout =
            ShardLayout::derive_shard_layout(&base_layout, "test2.near".parse().unwrap());
        assert_eq!(
            derived_layout,
            ShardLayout::v2(
                to_boundary_accounts(["test0.near", "test1.near", "test2.near", "test3.near"]),
                to_shard_ids([5, 6, 7, 8, 4]),
                Some(to_shards_split_map([
                    (5, vec![5]),
                    (6, vec![6]),
                    (3, vec![7, 8]),
                    (4, vec![4]),
                ])),
            )
        );

        // Deriving a layout keeps the version at 3.
        assert_eq!(base_layout.version(), 3);
        assert_eq!(base_layout.version(), derived_layout.version());
    }
1275
    // Checks that `ShardLayout::multi_shard` produces shard ids that are not in
    // sorted order for every shard count from 2 to 9.
    // NOTE(review): `multi_shard` is only compiled with the `test_utils` and
    // `rand` features - confirm that both are enabled for test builds.
    #[test]
    fn test_multi_shard_non_contiguous() {
        for n in 2..10 {
            let shard_layout = ShardLayout::multi_shard(n, 0);
            assert!(!shard_layout.shard_ids().is_sorted());
        }
    }
1286}