1use crate::hash::CryptoHash;
14use crate::types::{AccountId, NumShards};
15use borsh::{BorshDeserialize, BorshSerialize};
16use itertools::Itertools;
17use near_primitives_core::types::{ShardId, ShardIndex};
18use near_schema_checker_lib::ProtocolSchema;
19use std::collections::{BTreeMap, BTreeSet};
20use std::{fmt, str};
21
/// Version number carried by every shard layout variant and by [`ShardUId`].
pub type ShardVersion = u32;
29
30#[derive(
39 BorshSerialize,
40 BorshDeserialize,
41 serde::Serialize,
42 serde::Deserialize,
43 Clone,
44 Debug,
45 PartialEq,
46 Eq,
47 ProtocolSchema,
48)]
49#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
50pub enum ShardLayout {
51 V0(ShardLayoutV0),
52 V1(ShardLayoutV1),
53 V2(ShardLayoutV2),
54}
55
56#[derive(
62 BorshSerialize,
63 BorshDeserialize,
64 serde::Serialize,
65 serde::Deserialize,
66 Clone,
67 Debug,
68 PartialEq,
69 Eq,
70 ProtocolSchema,
71)]
72#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
73pub struct ShardLayoutV0 {
74 num_shards: NumShards,
76 version: ShardVersion,
78}
79
80type ShardsSplitMap = Vec<Vec<ShardId>>;
87
88type ShardsSplitMapV2 = BTreeMap<ShardId, Vec<ShardId>>;
95
96type ShardsParentMapV2 = BTreeMap<ShardId, ShardId>;
98
99pub fn shard_uids_to_ids(shard_uids: &[ShardUId]) -> Vec<ShardId> {
100 shard_uids.iter().map(|shard_uid| shard_uid.shard_id()).collect_vec()
101}
102
103#[derive(
104 BorshSerialize,
105 BorshDeserialize,
106 serde::Serialize,
107 serde::Deserialize,
108 Clone,
109 Debug,
110 PartialEq,
111 Eq,
112 ProtocolSchema,
113)]
114#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
115pub struct ShardLayoutV1 {
116 boundary_accounts: Vec<AccountId>,
121 shards_split_map: Option<ShardsSplitMap>,
125 to_parent_shard_map: Option<Vec<ShardId>>,
128 version: ShardVersion,
130}
131
132impl ShardLayoutV1 {
133 fn account_id_to_shard_id(&self, account_id: &AccountId) -> ShardId {
136 let mut shard_id: u64 = 0;
137 for boundary_account in &self.boundary_accounts {
138 if account_id < boundary_account {
139 break;
140 }
141 shard_id += 1;
142 }
143 shard_id.into()
144 }
145}
146
147#[derive(BorshSerialize, BorshDeserialize, Clone, Debug, PartialEq, Eq, ProtocolSchema)]
149pub struct ShardLayoutV2 {
150 boundary_accounts: Vec<AccountId>,
157
158 shard_ids: Vec<ShardId>,
165
166 id_to_index_map: BTreeMap<ShardId, ShardIndex>,
168
169 index_to_id_map: BTreeMap<ShardIndex, ShardId>,
172
173 shards_split_map: Option<ShardsSplitMapV2>,
176 shards_parent_map: Option<ShardsParentMapV2>,
179
180 version: ShardVersion,
184}
185
186#[derive(serde::Serialize, serde::Deserialize)]
189#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
190struct SerdeShardLayoutV2 {
191 boundary_accounts: Vec<AccountId>,
192 shard_ids: Vec<ShardId>,
193 id_to_index_map: BTreeMap<String, ShardIndex>,
194 index_to_id_map: BTreeMap<String, ShardId>,
195 shards_split_map: Option<BTreeMap<String, Vec<ShardId>>>,
196 shards_parent_map: Option<BTreeMap<String, ShardId>>,
197 version: ShardVersion,
198}
199
200impl From<&ShardLayoutV2> for SerdeShardLayoutV2 {
201 fn from(layout: &ShardLayoutV2) -> Self {
202 fn key_to_string<K, V>(map: &BTreeMap<K, V>) -> BTreeMap<String, V>
203 where
204 K: std::fmt::Display,
205 V: Clone,
206 {
207 map.iter().map(|(k, v)| (k.to_string(), v.clone())).collect()
208 }
209
210 Self {
211 boundary_accounts: layout.boundary_accounts.clone(),
212 shard_ids: layout.shard_ids.clone(),
213 id_to_index_map: key_to_string(&layout.id_to_index_map),
214 index_to_id_map: key_to_string(&layout.index_to_id_map),
215 shards_split_map: layout.shards_split_map.as_ref().map(key_to_string),
216 shards_parent_map: layout.shards_parent_map.as_ref().map(key_to_string),
217 version: layout.version,
218 }
219 }
220}
221
222impl TryFrom<SerdeShardLayoutV2> for ShardLayoutV2 {
223 type Error = Box<dyn std::error::Error + Send + Sync>;
224
225 fn try_from(layout: SerdeShardLayoutV2) -> Result<Self, Self::Error> {
226 fn key_to_shard_id<V>(
227 map: BTreeMap<String, V>,
228 ) -> Result<BTreeMap<ShardId, V>, Box<dyn std::error::Error + Send + Sync>> {
229 map.into_iter().map(|(k, v)| Ok((k.parse::<u64>()?.into(), v))).collect()
230 }
231
232 let SerdeShardLayoutV2 {
233 boundary_accounts,
234 shard_ids,
235 id_to_index_map,
236 index_to_id_map,
237 shards_split_map,
238 shards_parent_map,
239 version,
240 } = layout;
241
242 let id_to_index_map = key_to_shard_id(id_to_index_map)?;
243 let shards_split_map = shards_split_map.map(key_to_shard_id).transpose()?;
244 let shards_parent_map = shards_parent_map.map(key_to_shard_id).transpose()?;
245 let index_to_id_map = index_to_id_map
246 .into_iter()
247 .map(|(k, v)| Ok((k.parse()?, v)))
248 .collect::<Result<_, Self::Error>>()?;
249
250 match (&shards_split_map, &shards_parent_map) {
251 (None, None) => {}
252 (Some(shard_split_map), Some(shards_parent_map)) => {
253 let expected_shards_parent_map =
254 validate_and_derive_shard_parent_map_v2(&shard_ids, &shard_split_map);
255 if &expected_shards_parent_map != shards_parent_map {
256 return Err("shards_parent_map does not match the expected value".into());
257 }
258 }
259 _ => {
260 return Err(
261 "shards_split_map and shards_parent_map must be both present or both absent"
262 .into(),
263 );
264 }
265 }
266
267 Ok(Self {
268 boundary_accounts,
269 shard_ids,
270 id_to_index_map,
271 index_to_id_map,
272 shards_split_map,
273 shards_parent_map,
274 version,
275 })
276 }
277}
278
279impl serde::Serialize for ShardLayoutV2 {
280 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
281 where
282 S: serde::Serializer,
283 {
284 SerdeShardLayoutV2::from(self).serialize(serializer)
285 }
286}
287
288impl<'de> serde::Deserialize<'de> for ShardLayoutV2 {
289 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
290 where
291 D: serde::Deserializer<'de>,
292 {
293 let serde_layout = SerdeShardLayoutV2::deserialize(deserializer)?;
294 ShardLayoutV2::try_from(serde_layout).map_err(serde::de::Error::custom)
295 }
296}
297
#[cfg(feature = "schemars")]
impl schemars::JsonSchema for ShardLayoutV2 {
    /// The schema is published under the name `ShardLayoutV2`.
    fn schema_name() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Owned(String::from("ShardLayoutV2"))
    }

    /// Reuses the schema of `SerdeShardLayoutV2`, the type that is actually (de)serialized.
    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
        <SerdeShardLayoutV2 as schemars::JsonSchema>::json_schema(generator)
    }
}
308
309impl ShardLayoutV2 {
310 pub fn account_id_to_shard_id(&self, account_id: &AccountId) -> ShardId {
311 let mut shard_id_index = 0;
314 for boundary_account in &self.boundary_accounts {
315 if account_id < boundary_account {
316 break;
317 }
318 shard_id_index += 1;
319 }
320 self.shard_ids[shard_id_index]
321 }
322
323 pub fn shards_split_map(&self) -> &Option<ShardsSplitMapV2> {
324 &self.shards_split_map
325 }
326
327 pub fn boundary_accounts(&self) -> &Vec<AccountId> {
328 &self.boundary_accounts
329 }
330}
331
332#[derive(Debug)]
333pub enum ShardLayoutError {
334 InvalidShardIdError { shard_id: ShardId },
335 InvalidShardIndexError { shard_index: ShardIndex },
336 NoParentError { shard_id: ShardId },
337}
338
339impl fmt::Display for ShardLayoutError {
340 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
341 write!(f, "{:?}", self)
342 }
343}
344
345impl std::error::Error for ShardLayoutError {}
346
347impl ShardLayout {
348 pub fn single_shard() -> Self {
350 let shard_id = ShardId::new(0);
351 Self::V2(ShardLayoutV2 {
352 boundary_accounts: vec![],
353 shard_ids: vec![shard_id],
354 id_to_index_map: [(shard_id, 0)].into(),
355 index_to_id_map: [(0, shard_id)].into(),
356 shards_split_map: None,
357 shards_parent_map: None,
358 version: 0,
359 })
360 }
361
362 #[cfg(all(feature = "test_utils", feature = "rand"))]
367 pub fn multi_shard(num_shards: NumShards, version: ShardVersion) -> Self {
368 assert!(num_shards > 0, "at least 1 shard is required");
369
370 let boundary_accounts = (1..num_shards)
371 .map(|i| format!("test{}", i).parse().unwrap())
372 .collect::<Vec<AccountId>>();
373
374 Self::multi_shard_custom(boundary_accounts, version)
375 }
376
377 #[cfg(all(feature = "test_utils", feature = "rand"))]
382 pub fn multi_shard_custom(boundary_accounts: Vec<AccountId>, version: ShardVersion) -> Self {
383 use rand::{SeedableRng, rngs::StdRng, seq::SliceRandom};
384
385 let num_shards = (boundary_accounts.len() + 1) as u64;
386
387 let mut rng = StdRng::seed_from_u64(42);
390 let mut shard_ids = (0..num_shards).map(ShardId::new).collect::<Vec<ShardId>>();
391 shard_ids.shuffle(&mut rng);
392
393 let (id_to_index_map, index_to_id_map) = shard_ids
394 .iter()
395 .enumerate()
396 .map(|(i, &shard_id)| ((shard_id, i), (i, shard_id)))
397 .unzip();
398
399 Self::V2(ShardLayoutV2 {
400 boundary_accounts,
401 shard_ids,
402 id_to_index_map,
403 index_to_id_map,
404 shards_split_map: None,
405 shards_parent_map: None,
406 version,
407 })
408 }
409
410 #[deprecated(note = "Use multi_shard() instead")]
412 pub fn v0(num_shards: NumShards, version: ShardVersion) -> Self {
413 Self::V0(ShardLayoutV0 { num_shards, version })
414 }
415
416 #[deprecated(note = "Use multi_shard() instead")]
418 pub fn v1(
419 boundary_accounts: Vec<AccountId>,
420 shards_split_map: Option<ShardsSplitMap>,
421 version: ShardVersion,
422 ) -> Self {
423 let to_parent_shard_map = if let Some(shards_split_map) = &shards_split_map {
424 let mut to_parent_shard_map = BTreeMap::new();
425 let num_shards = (boundary_accounts.len() + 1) as NumShards;
426 for (parent_shard_id, shard_ids) in shards_split_map.iter().enumerate() {
427 let parent_shard_id = ShardId::new(parent_shard_id as u64);
428 for &shard_id in shard_ids {
429 let prev = to_parent_shard_map.insert(shard_id, parent_shard_id);
430 assert!(prev.is_none(), "no shard should appear in the map twice");
431 let shard_id: u64 = shard_id.into();
432 assert!(shard_id < num_shards, "shard id should be valid");
433 }
434 }
435 Some((0..num_shards).map(|shard_id| to_parent_shard_map[&shard_id.into()]).collect())
436 } else {
437 None
438 };
439 Self::V1(ShardLayoutV1 {
440 boundary_accounts,
441 shards_split_map,
442 to_parent_shard_map,
443 version,
444 })
445 }
446
447 pub fn v2(
449 boundary_accounts: Vec<AccountId>,
450 shard_ids: Vec<ShardId>,
451 shards_split_map: Option<ShardsSplitMapV2>,
452 ) -> Self {
453 const VERSION: ShardVersion = 3;
455
456 assert_eq!(boundary_accounts.len() + 1, shard_ids.len());
457 assert_eq!(boundary_accounts, boundary_accounts.iter().sorted().cloned().collect_vec());
458
459 let mut id_to_index_map = BTreeMap::new();
460 let mut index_to_id_map = BTreeMap::new();
461 for (shard_index, &shard_id) in shard_ids.iter().enumerate() {
462 id_to_index_map.insert(shard_id, shard_index);
463 index_to_id_map.insert(shard_index, shard_id);
464 }
465
466 let shards_parent_map = shards_split_map.as_ref().map(|shards_split_map| {
467 validate_and_derive_shard_parent_map_v2(&shard_ids, &shards_split_map)
468 });
469
470 Self::V2(ShardLayoutV2 {
471 boundary_accounts,
472 shard_ids,
473 id_to_index_map,
474 index_to_id_map,
475 shards_split_map,
476 shards_parent_map,
477 version: VERSION,
478 })
479 }
480
481 pub fn account_id_to_shard_id(&self, account_id: &AccountId) -> ShardId {
487 match self {
488 ShardLayout::V0(v0) => {
489 let hash = CryptoHash::hash_bytes(account_id.as_bytes());
490 let (bytes, _) = stdx::split_array::<32, 8, 24>(hash.as_bytes());
491 let shard_id = u64::from_le_bytes(*bytes) % v0.num_shards;
492 shard_id.into()
493 }
494 ShardLayout::V1(v1) => v1.account_id_to_shard_id(account_id),
495 ShardLayout::V2(v2) => v2.account_id_to_shard_id(account_id),
496 }
497 }
498
499 #[inline]
501 pub fn account_id_to_shard_uid(&self, account_id: &AccountId) -> ShardUId {
502 ShardUId::from_shard_id_and_layout(self.account_id_to_shard_id(account_id), self)
503 }
504
505 #[inline]
508 pub fn get_children_shards_uids(&self, parent_shard_id: ShardId) -> Option<Vec<ShardUId>> {
509 self.get_children_shards_ids(parent_shard_id).map(|shards| {
510 shards.into_iter().map(|id| ShardUId::from_shard_id_and_layout(id, self)).collect()
511 })
512 }
513
514 pub fn get_children_shards_ids(&self, parent_shard_id: ShardId) -> Option<Vec<ShardId>> {
517 match self {
518 Self::V0(_) => None,
519 Self::V1(v1) => match &v1.shards_split_map {
520 Some(shards_split_map) => {
521 let parent_shard_index: ShardIndex = parent_shard_id.into();
525 shards_split_map.get(parent_shard_index).cloned()
526 }
527 None => None,
528 },
529 Self::V2(v2) => match &v2.shards_split_map {
530 Some(shards_split_map) => shards_split_map.get(&parent_shard_id).cloned(),
531 None => None,
532 },
533 }
534 }
535
536 pub fn try_get_parent_shard_id(
540 &self,
541 shard_id: ShardId,
542 ) -> Result<Option<ShardId>, ShardLayoutError> {
543 if !self.shard_ids().any(|id| id == shard_id) {
544 return Err(ShardLayoutError::InvalidShardIdError { shard_id });
545 }
546 let parent_shard_id = match self {
547 Self::V0(_) => None,
548 Self::V1(v1) => match &v1.to_parent_shard_map {
549 Some(to_parent_shard_map) => {
552 let shard_index = self.get_shard_index(shard_id).unwrap();
553 let parent_shard_id = to_parent_shard_map.get(shard_index).unwrap();
554 Some(*parent_shard_id)
555 }
556 None => None,
557 },
558 Self::V2(v2) => match &v2.shards_parent_map {
559 Some(to_parent_shard_map) => {
562 let parent_shard_id = to_parent_shard_map.get(&shard_id).unwrap();
563 Some(*parent_shard_id)
564 }
565 None => None,
566 },
567 };
568 Ok(parent_shard_id)
569 }
570
571 pub fn get_parent_shard_id(&self, shard_id: ShardId) -> Result<ShardId, ShardLayoutError> {
576 let parent_shard_id = self.try_get_parent_shard_id(shard_id)?;
577 parent_shard_id.ok_or(ShardLayoutError::NoParentError { shard_id })
578 }
579
580 pub fn derive_shard_layout(
582 base_shard_layout: &ShardLayout,
583 new_boundary_account: AccountId,
584 ) -> ShardLayout {
585 let mut boundary_accounts = base_shard_layout.boundary_accounts().clone();
586 let mut shard_ids = base_shard_layout.shard_ids().collect::<Vec<_>>();
587 let mut shards_split_map = shard_ids
588 .iter()
589 .map(|id| (*id, vec![*id]))
590 .collect::<BTreeMap<ShardId, Vec<ShardId>>>();
591
592 assert!(!boundary_accounts.contains(&new_boundary_account), "duplicated boundary account");
593
594 boundary_accounts.push(new_boundary_account.clone());
596 boundary_accounts.sort();
597 let new_boundary_account_index = boundary_accounts
598 .iter()
599 .position(|acc| acc == &new_boundary_account)
600 .expect("account should be guaranteed to exist at this point");
601
602 let max_shard_id =
604 *shard_ids.iter().max().expect("there should always be at least one shard");
605 let new_shards = vec![max_shard_id + 1, max_shard_id + 2];
606 let parent_shard_id = shard_ids
607 .splice(new_boundary_account_index..new_boundary_account_index + 1, new_shards.clone())
608 .collect::<Vec<_>>();
609 let [parent_shard_id] = parent_shard_id.as_slice() else {
610 panic!("should only splice one shard");
611 };
612 shards_split_map.insert(*parent_shard_id, new_shards);
613
614 ShardLayout::v2(boundary_accounts, shard_ids, Some(shards_split_map))
615 }
616
617 #[inline]
618 pub fn version(&self) -> ShardVersion {
619 match self {
620 Self::V0(v0) => v0.version,
621 Self::V1(v1) => v1.version,
622 Self::V2(v2) => v2.version,
623 }
624 }
625
626 pub fn boundary_accounts(&self) -> &Vec<AccountId> {
627 match self {
628 Self::V1(v1) => &v1.boundary_accounts,
629 Self::V2(v2) => &v2.boundary_accounts,
630 _ => panic!("ShardLayout::V0 doesn't have boundary accounts"),
631 }
632 }
633
634 pub fn num_shards(&self) -> NumShards {
635 match self {
636 Self::V0(v0) => v0.num_shards,
637 Self::V1(v1) => (v1.boundary_accounts.len() + 1) as NumShards,
638 Self::V2(v2) => v2.shard_ids.len() as NumShards,
639 }
640 }
641
642 pub fn shard_ids(&self) -> impl Iterator<Item = ShardId> {
646 match self {
647 Self::V0(_) => (0..self.num_shards()).map(Into::into).collect_vec().into_iter(),
648 Self::V1(_) => (0..self.num_shards()).map(Into::into).collect_vec().into_iter(),
649 Self::V2(v2) => v2.shard_ids.clone().into_iter(),
650 }
651 }
652
653 pub fn shard_uids(&self) -> impl Iterator<Item = ShardUId> + '_ {
656 self.shard_ids().map(|shard_id| ShardUId::from_shard_id_and_layout(shard_id, self))
657 }
658
659 pub fn shard_indexes(&self) -> impl Iterator<Item = ShardIndex> + 'static {
660 let num_shards: usize =
661 self.num_shards().try_into().expect("Number of shards doesn't fit in usize");
662 match self {
663 Self::V0(_) | Self::V1(_) | Self::V2(_) => (0..num_shards).into_iter(),
664 }
665 }
666
667 pub fn shard_infos(&self) -> impl Iterator<Item = ShardInfo> + '_ {
672 self.shard_uids()
673 .enumerate()
674 .map(|(shard_index, shard_uid)| ShardInfo { shard_index, shard_uid })
675 }
676
677 pub fn get_shard_index(&self, shard_id: ShardId) -> Result<ShardIndex, ShardLayoutError> {
680 match self {
681 Self::V0(_) => Ok(shard_id.into()),
683 Self::V1(_) => Ok(shard_id.into()),
685 Self::V2(v2) => v2
687 .id_to_index_map
688 .get(&shard_id)
689 .copied()
690 .ok_or(ShardLayoutError::InvalidShardIdError { shard_id }),
691 }
692 }
693
694 pub fn get_shard_id(&self, shard_index: ShardIndex) -> Result<ShardId, ShardLayoutError> {
697 let num_shards = self.num_shards() as usize;
698 match self {
699 Self::V0(_) | Self::V1(_) => {
700 if shard_index >= num_shards {
701 return Err(ShardLayoutError::InvalidShardIndexError { shard_index });
702 }
703 Ok(ShardId::new(shard_index as u64))
704 }
705 Self::V2(v2) => v2
706 .shard_ids
707 .get(shard_index)
708 .copied()
709 .ok_or(ShardLayoutError::InvalidShardIndexError { shard_index }),
710 }
711 }
712
713 pub fn get_shard_uid(&self, shard_index: ShardIndex) -> Result<ShardUId, ShardLayoutError> {
714 let shard_id = self.get_shard_id(shard_index)?;
715 Ok(ShardUId::from_shard_id_and_layout(shard_id, self))
716 }
717
718 pub fn get_split_parent_shard_ids(&self) -> BTreeSet<ShardId> {
721 let mut parent_shard_ids = BTreeSet::new();
722 for shard_id in self.shard_ids() {
723 let parent_shard_id = self
724 .try_get_parent_shard_id(shard_id)
725 .expect("shard_id belongs to the shard layout");
726 let Some(parent_shard_id) = parent_shard_id else {
727 continue;
728 };
729 if parent_shard_id == shard_id {
730 continue;
731 }
732 parent_shard_ids.insert(parent_shard_id);
733 }
734 parent_shard_ids
735 }
736
737 pub fn get_split_parent_shard_uids(&self) -> BTreeSet<ShardUId> {
740 let parent_shard_ids = self.get_split_parent_shard_ids();
741 parent_shard_ids
742 .into_iter()
743 .map(|shard_id| ShardUId::new(self.version(), shard_id))
744 .collect()
745 }
746}
747
748fn validate_and_derive_shard_parent_map_v2(
750 shard_ids: &Vec<ShardId>,
751 shards_split_map: &ShardsSplitMapV2,
752) -> ShardsParentMapV2 {
753 let mut shards_parent_map = ShardsParentMapV2::new();
754 for (&parent_shard_id, child_shard_ids) in shards_split_map {
755 for &child_shard_id in child_shard_ids {
756 let prev = shards_parent_map.insert(child_shard_id, parent_shard_id);
757 assert!(prev.is_none(), "no shard should appear in the map twice");
758 }
759 if let &[child_shard_id] = child_shard_ids.as_slice() {
760 assert_eq!(parent_shard_id, child_shard_id);
763 } else {
764 assert!(!shard_ids.contains(&parent_shard_id));
767 }
768 }
769
770 assert_eq!(
771 shard_ids.iter().copied().sorted().collect_vec(),
772 shards_parent_map.keys().copied().collect_vec()
773 );
774 shards_parent_map
775}
776
777#[derive(
787 BorshSerialize,
788 BorshDeserialize,
789 Hash,
790 Clone,
791 Copy,
792 PartialEq,
793 Eq,
794 PartialOrd,
795 Ord,
796 ProtocolSchema,
797)]
798pub struct ShardUId {
799 pub version: ShardVersion,
800 pub shard_id: u32,
801}
802
803impl ShardUId {
804 pub fn new(version: ShardVersion, shard_id: ShardId) -> Self {
805 Self { version, shard_id: shard_id.into() }
806 }
807
808 #[cfg(feature = "test_utils")]
811 pub fn single_shard() -> Self {
812 ShardLayout::single_shard().shard_uids().next().unwrap()
813 }
814
815 pub fn to_bytes(&self) -> [u8; 8] {
817 let mut res = [0; 8];
818 res[0..4].copy_from_slice(&u32::to_le_bytes(self.version));
819 res[4..].copy_from_slice(&u32::to_le_bytes(self.shard_id));
820 res
821 }
822
823 pub fn get_upper_bound_db_key(shard_uid_bytes: &[u8; 8]) -> [u8; 8] {
829 let mut result = *shard_uid_bytes;
830 for i in (0..8).rev() {
831 if result[i] == u8::MAX {
832 result[i] = 0;
833 } else {
834 result[i] += 1;
835 return result;
836 }
837 }
838 panic!("Next shard prefix for shard bytes {shard_uid_bytes:?} does not exist");
839 }
840
841 pub fn from_shard_id_and_layout(shard_id: ShardId, shard_layout: &ShardLayout) -> Self {
843 assert!(shard_layout.shard_ids().any(|i| i == shard_id));
844 Self::new(shard_layout.version(), shard_id)
845 }
846
847 pub fn shard_id(&self) -> ShardId {
849 self.shard_id.into()
850 }
851}
852
853impl TryFrom<&[u8]> for ShardUId {
854 type Error = Box<dyn std::error::Error + Send + Sync>;
855
856 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
858 if bytes.len() != 8 {
859 return Err("incorrect length for ShardUId".into());
860 }
861 let version = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
862 let shard_id = u32::from_le_bytes(bytes[4..8].try_into().unwrap());
863 Ok(Self { version, shard_id })
864 }
865}
866
867pub fn get_block_shard_uid(block_hash: &CryptoHash, shard_uid: &ShardUId) -> Vec<u8> {
869 let mut res = Vec::with_capacity(40);
870 res.extend_from_slice(block_hash.as_ref());
871 res.extend_from_slice(&shard_uid.to_bytes());
872 res
873}
874
875pub fn get_block_shard_uid_rev(
877 key: &[u8],
878) -> Result<(CryptoHash, ShardUId), Box<dyn std::error::Error + Send + Sync>> {
879 if key.len() != 40 {
880 return Err(
881 std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid key length").into()
882 );
883 }
884 let block_hash = CryptoHash::try_from(&key[..32])?;
885 let shard_id = ShardUId::try_from(&key[32..])?;
886 Ok((block_hash, shard_id))
887}
888
889impl fmt::Display for ShardUId {
890 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
891 write!(f, "s{}.v{}", self.shard_id, self.version)
892 }
893}
894
895impl fmt::Debug for ShardUId {
896 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
897 fmt::Display::fmt(self, f)
898 }
899}
900
901impl str::FromStr for ShardUId {
902 type Err = String;
903
904 fn from_str(s: &str) -> Result<Self, Self::Err> {
905 let (shard_str, version_str) = s
906 .split_once(".")
907 .ok_or_else(|| "shard version and number must be separated by \".\"".to_string())?;
908
909 let version = version_str
910 .strip_prefix("v")
911 .ok_or_else(|| "shard version must start with \"v\"".to_string())?
912 .parse::<ShardVersion>()
913 .map_err(|e| format!("shard version after \"v\" must be a number, {e}"))?;
914
915 let shard_str = shard_str
916 .strip_prefix("s")
917 .ok_or_else(|| "shard id must start with \"s\"".to_string())?;
918 let shard_id = shard_str
919 .parse::<u32>()
920 .map_err(|e| format!("shard id after \"s\" must be a number, {e}"))?;
921
922 Ok(ShardUId { shard_id, version })
923 }
924}
925
926impl<'de> serde::Deserialize<'de> for ShardUId {
927 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
928 where
929 D: serde::Deserializer<'de>,
930 {
931 deserializer.deserialize_any(ShardUIdVisitor)
932 }
933}
934
935impl serde::Serialize for ShardUId {
936 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
937 where
938 S: serde::Serializer,
939 {
940 serializer.serialize_str(&self.to_string())
941 }
942}
943
944struct ShardUIdVisitor;
945impl<'de> serde::de::Visitor<'de> for ShardUIdVisitor {
946 type Value = ShardUId;
947
948 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
949 write!(
950 formatter,
951 "either string format of `ShardUId` like 's0.v3' for shard 0 version 3, or a map"
952 )
953 }
954
955 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
956 where
957 E: serde::de::Error,
958 {
959 v.parse().map_err(|e| E::custom(e))
960 }
961
962 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
963 where
964 A: serde::de::MapAccess<'de>,
965 {
966 let mut version = None;
970 let mut shard_id = None;
971
972 while let Some((field, value)) = map.next_entry()? {
973 match field {
974 "version" => version = Some(value),
975 "shard_id" => shard_id = Some(value),
976 _ => return Err(serde::de::Error::unknown_field(field, &["version", "shard_id"])),
977 }
978 }
979
980 match (version, shard_id) {
981 (None, _) => Err(serde::de::Error::missing_field("version")),
982 (_, None) => Err(serde::de::Error::missing_field("shard_id")),
983 (Some(version), Some(shard_id)) => Ok(ShardUId { version, shard_id }),
984 }
985 }
986}
987
988pub struct ShardInfo {
989 pub shard_index: ShardIndex,
990 pub shard_uid: ShardUId,
991}
992
993impl ShardInfo {
994 pub fn shard_index(&self) -> ShardIndex {
995 self.shard_index
996 }
997
998 pub fn shard_id(&self) -> ShardId {
999 self.shard_uid.shard_id()
1000 }
1001
1002 pub fn shard_uid(&self) -> ShardUId {
1003 self.shard_uid
1004 }
1005}
1006
#[cfg(test)]
mod tests {
    use crate::epoch_manager::EpochConfigStore;
    use crate::shard_layout::{ShardLayout, ShardUId};
    use itertools::Itertools;
    use near_primitives_core::types::ProtocolVersion;
    use near_primitives_core::types::{AccountId, ShardId};
    use rand::distributions::Alphanumeric;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use std::collections::{BTreeMap, HashMap};

    use super::{ShardsSplitMap, ShardsSplitMapV2};

    /// Converts raw `u64` ids into `ShardId`s.
    fn new_shard_ids_vec(shard_ids: Vec<u64>) -> Vec<ShardId> {
        shard_ids.into_iter().map(ShardId::from).collect()
    }

    /// Converts a raw V1 split map into a `ShardsSplitMap`.
    fn new_shards_split_map(shards_split_map: Vec<Vec<u64>>) -> ShardsSplitMap {
        let mut converted = ShardsSplitMap::new();
        for children in shards_split_map {
            converted.push(new_shard_ids_vec(children));
        }
        converted
    }

    /// Converts a raw V2 split map into a `ShardsSplitMapV2`.
    fn new_shards_split_map_v2(shards_split_map: BTreeMap<u64, Vec<u64>>) -> ShardsSplitMapV2 {
        let mut converted = ShardsSplitMapV2::new();
        for (parent, children) in shards_split_map {
            converted.insert(parent.into(), new_shard_ids_vec(children));
        }
        converted
    }

    impl ShardLayout {
        /// Test helper: returns the shard layout that the mainnet epoch config uses for
        /// `protocol_version`.
        pub fn for_protocol_version(protocol_version: ProtocolVersion) -> Self {
            let config_store = EpochConfigStore::for_chain_id("mainnet", None).unwrap();
            let epoch_config = config_store.get_config(protocol_version);
            epoch_config.shard_layout.clone()
        }
    }

    /// Checks the distribution of 1000 pseudo-random accounts over a 4-shard V0 layout.
    #[test]
    fn test_shard_layout_v0() {
        let num_shards = 4;
        #[allow(deprecated)]
        let shard_layout = ShardLayout::v0(num_shards, 0);
        let mut shard_id_distribution: HashMap<ShardId, _> =
            shard_layout.shard_ids().map(|shard_id| (shard_id, 0)).collect();
        let mut rng = StdRng::from_seed([0; 32]);
        for _ in 0..1000 {
            let random_bytes: Vec<u8> = (&mut rng).sample_iter(&Alphanumeric).take(10).collect();
            let random_name = String::from_utf8(random_bytes).unwrap();
            let account_id = random_name.to_lowercase().parse().unwrap();
            let shard_id = shard_layout.account_id_to_shard_id(&account_id);
            *shard_id_distribution.get_mut(&shard_id).unwrap() += 1;

            let shard_id: u64 = shard_id.into();
            assert!(shard_id < num_shards);
        }
        let expected_distribution: HashMap<ShardId, _> = HashMap::from([
            (ShardId::new(0), 247),
            (ShardId::new(1), 268),
            (ShardId::new(2), 233),
            (ShardId::new(3), 252),
        ]);
        assert_eq!(shard_id_distribution, expected_distribution);
    }

    /// Checks children, parents and account assignment of a V1 layout with a split map.
    #[test]
    fn test_shard_layout_v1() {
        let boundary_accounts =
            ["aurora", "bar", "foo", "foo.baz", "paz"].iter().map(|a| a.parse().unwrap()).collect();
        #[allow(deprecated)]
        let shard_layout = ShardLayout::v1(
            boundary_accounts,
            Some(new_shards_split_map(vec![vec![0, 1, 2], vec![3, 4, 5]])),
            1,
        );

        let uids = |ids: std::ops::Range<u32>| -> Vec<ShardUId> {
            ids.map(|shard_id| ShardUId { version: 1, shard_id }).collect()
        };
        assert_eq!(shard_layout.get_children_shards_uids(ShardId::new(0)).unwrap(), uids(0..3));
        assert_eq!(shard_layout.get_children_shards_uids(ShardId::new(1)).unwrap(), uids(3..6));

        for x in 0..3 {
            assert_eq!(shard_layout.get_parent_shard_id(ShardId::new(x)).unwrap(), ShardId::new(0));
            assert_eq!(
                shard_layout.get_parent_shard_id(ShardId::new(x + 3)).unwrap(),
                ShardId::new(1)
            );
        }

        let expected_shards: [(&str, u64); 15] = [
            ("aurora", 1),
            ("foo.aurora", 3),
            ("bar.foo.aurora", 2),
            ("bar", 2),
            ("bar.bar", 2),
            ("foo", 3),
            ("baz.foo", 2),
            ("foo.baz", 4),
            ("a.foo.baz", 0),
            ("aaa", 0),
            ("abc", 0),
            ("bbb", 2),
            ("foo.goo", 4),
            ("goo", 4),
            ("zoo", 5),
        ];
        for (account, expected_shard) in expected_shards {
            let account_id: AccountId = account.parse().unwrap();
            assert_eq!(
                shard_layout.account_id_to_shard_id(&account_id),
                ShardId::new(expected_shard),
                "unexpected shard for account {account}"
            );
        }
    }

    /// Checks account assignment, shard ids, shard uids, parents and children of a V2 layout.
    #[test]
    fn test_shard_layout_v2() {
        let shard_layout = get_test_shard_layout_v2();

        // Accounts inside the ranges first, then accounts equal to a boundary account.
        let expected_shards: [(&str, u64); 7] =
            [("aaa", 3), ("ddd", 8), ("mmm", 4), ("rrr", 7), ("ccc", 8), ("kkk", 4), ("ppp", 7)];
        for (account, expected_shard) in expected_shards {
            let account_id: AccountId = account.parse().unwrap();
            assert_eq!(
                shard_layout.account_id_to_shard_id(&account_id),
                ShardId::new(expected_shard),
                "unexpected shard for account {account}"
            );
        }

        assert_eq!(shard_layout.shard_ids().collect_vec(), new_shard_ids_vec(vec![3, 8, 4, 7]));

        let expected_uids =
            [3, 8, 4, 7].map(|shard_id| ShardUId { shard_id, version: 3 }).to_vec();
        assert_eq!(shard_layout.shard_uids().collect_vec(), expected_uids);

        let expected_parents: [(u64, u64); 4] = [(3, 3), (8, 1), (4, 4), (7, 1)];
        for (child, parent) in expected_parents {
            assert_eq!(
                shard_layout.get_parent_shard_id(ShardId::new(child)).unwrap(),
                ShardId::new(parent),
                "unexpected parent for shard {child}"
            );
        }

        let expected_children: [(u64, Vec<u64>); 3] =
            [(1, vec![7, 8]), (3, vec![3]), (4, vec![4])];
        for (parent, children) in expected_children {
            assert_eq!(
                shard_layout.get_children_shards_ids(ShardId::new(parent)).unwrap(),
                new_shard_ids_vec(children),
                "unexpected children for shard {parent}"
            );
        }
    }

    /// Builds the V2 layout used by `test_shard_layout_v2`: boundaries `ccc`, `kkk`, `ppp`,
    /// shards `[3, 8, 4, 7]`, where shard 1 was split into 7 and 8.
    fn get_test_shard_layout_v2() -> ShardLayout {
        let boundary_accounts =
            ["ccc", "kkk", "ppp"].into_iter().map(|account| account.parse().unwrap()).collect();
        let shard_ids = new_shard_ids_vec(vec![3, 8, 4, 7]);
        let shards_split_map = new_shards_split_map_v2(BTreeMap::from([
            (1, vec![7, 8]),
            (3, vec![3]),
            (4, vec![4]),
        ]));

        ShardLayout::v2(boundary_accounts, shard_ids, Some(shards_split_map))
    }

    /// Derives four layouts in a row and compares each with the expected V2 layout.
    #[test]
    fn test_deriving_shard_layout() {
        fn to_boundary_accounts<const N: usize>(accounts: [&str; N]) -> Vec<AccountId> {
            accounts.iter().map(|account| account.parse().unwrap()).collect()
        }

        fn to_shard_ids<const N: usize>(ids: [u32; N]) -> Vec<ShardId> {
            ids.iter().map(|&id| ShardId::new(id as u64)).collect()
        }

        fn to_shards_split_map<const N: usize>(
            xs: [(u32, Vec<u32>); N],
        ) -> BTreeMap<ShardId, Vec<ShardId>> {
            let mut split_map = BTreeMap::new();
            for (parent, children) in xs {
                let children = children.into_iter().map(|child| ShardId::new(child as u64));
                split_map.insert(ShardId::new(parent as u64), children.collect());
            }
            split_map
        }

        // Step 1: a single shard 0 is split at "test1.near" into shards 1 and 2.
        let base_layout = ShardLayout::v2(vec![], vec![ShardId::new(0)], None);
        let derived_layout =
            ShardLayout::derive_shard_layout(&base_layout, "test1.near".parse().unwrap());
        let expected_layout = ShardLayout::v2(
            to_boundary_accounts(["test1.near"]),
            to_shard_ids([1, 2]),
            Some(to_shards_split_map([(0, vec![1, 2])])),
        );
        assert_eq!(derived_layout, expected_layout);

        // Step 2: shard 2 is split at "test3.near" into shards 3 and 4.
        let base_layout = derived_layout;
        let derived_layout =
            ShardLayout::derive_shard_layout(&base_layout, "test3.near".parse().unwrap());
        let expected_layout = ShardLayout::v2(
            to_boundary_accounts(["test1.near", "test3.near"]),
            to_shard_ids([1, 3, 4]),
            Some(to_shards_split_map([(1, vec![1]), (2, vec![3, 4])])),
        );
        assert_eq!(derived_layout, expected_layout);

        // Step 3: shard 1 is split at "test0.near" into shards 5 and 6.
        let base_layout = derived_layout;
        let derived_layout =
            ShardLayout::derive_shard_layout(&base_layout, "test0.near".parse().unwrap());
        let expected_layout = ShardLayout::v2(
            to_boundary_accounts(["test0.near", "test1.near", "test3.near"]),
            to_shard_ids([5, 6, 3, 4]),
            Some(to_shards_split_map([(1, vec![5, 6]), (3, vec![3]), (4, vec![4])])),
        );
        assert_eq!(derived_layout, expected_layout);

        // Step 4: shard 3 is split at "test2.near" into shards 7 and 8.
        let base_layout = derived_layout;
        let derived_layout =
            ShardLayout::derive_shard_layout(&base_layout, "test2.near".parse().unwrap());
        let expected_layout = ShardLayout::v2(
            to_boundary_accounts(["test0.near", "test1.near", "test2.near", "test3.near"]),
            to_shard_ids([5, 6, 7, 8, 4]),
            Some(to_shards_split_map([
                (5, vec![5]),
                (6, vec![6]),
                (3, vec![7, 8]),
                (4, vec![4]),
            ])),
        );
        assert_eq!(derived_layout, expected_layout);

        // Derived layouts keep version 3.
        assert_eq!(base_layout.version(), 3);
        assert_eq!(base_layout.version(), derived_layout.version());
    }

    /// Checks that `ShardLayout::multi_shard` never returns sorted shard ids for 2 to 9 shards.
    #[test]
    fn test_multi_shard_non_contiguous() {
        for n in 2..10 {
            let shard_layout = ShardLayout::multi_shard(n, 0);
            assert!(!shard_layout.shard_ids().is_sorted());
        }
    }
}