canic_core/ops/storage/sharding/registry.rs

1use crate::{
2    Error, ThisError,
3    cdk::{types::Principal, utils::time::now_secs},
4    ids::CanisterRole,
5    model::memory::sharding::{ShardEntry, ShardKey, ShardingRegistry},
6    ops::storage::StorageOpsError,
7};
8
///
/// ShardingRegistryOps
/// Zero-sized handle grouping the storage-layer operations on `ShardingRegistry`:
/// shard entry CRUD, tenant-to-shard assignments, and the consistency rules that
/// tie them together (slot uniqueness per pool, pool matching on assignment,
/// and the per-shard tenant `count`).
///

pub struct ShardingRegistryOps;
14
///
/// ShardingRegistryOpsError
/// Storage-layer errors for sharding registry CRUD and consistency checks.
///

#[derive(Debug, ThisError)]
pub enum ShardingRegistryOpsError {
    /// No registry entry exists for the given shard principal.
    #[error("shard not found: {0}")]
    ShardNotFound(Principal),

    /// A pool/tenant key or a new shard entry failed validation; carries the
    /// validator's message (from `ShardKey::try_new` / `ShardEntry::try_new`).
    #[error("invalid sharding key: {0}")]
    InvalidKey(String),

    /// The target shard is registered under a different pool than requested.
    /// `expected` is the requested pool; `actual` is the pool stored on the entry.
    #[error("shard {pid} belongs to pool '{actual}', not '{expected}'")]
    PoolMismatch {
        pid: Principal,
        expected: String,
        actual: String,
    },

    /// Another shard in `pool` already holds `slot`; `pid` is that other shard.
    #[error("slot {slot} in pool '{pool}' already assigned to shard {pid}")]
    SlotOccupied {
        pool: String,
        slot: u32,
        pid: Principal,
    },
}
42
43impl From<ShardingRegistryOpsError> for Error {
44    fn from(err: ShardingRegistryOpsError) -> Self {
45        StorageOpsError::from(err).into()
46    }
47}
48
49impl ShardingRegistryOps {
50    /// Create a new shard entry in the registry.
51    pub fn create(
52        pid: Principal,
53        pool: &str,
54        slot: u32,
55        canister_role: &CanisterRole,
56        capacity: u32,
57    ) -> Result<(), Error> {
58        ShardingRegistry::with_mut(|core| {
59            if slot != ShardEntry::UNASSIGNED_SLOT {
60                for (other_pid, other_entry) in core.all_entries() {
61                    if other_pid != pid
62                        && other_entry.pool.as_ref() == pool
63                        && other_entry.slot == slot
64                    {
65                        return Err(ShardingRegistryOpsError::SlotOccupied {
66                            pool: pool.to_string(),
67                            slot,
68                            pid: other_pid,
69                        }
70                        .into());
71                    }
72                }
73            }
74
75            let entry =
76                ShardEntry::try_new(pool, slot, canister_role.clone(), capacity, now_secs())
77                    .map_err(ShardingRegistryOpsError::InvalidKey)?;
78            core.insert_entry(pid, entry);
79
80            Ok(())
81        })
82    }
83
84    /// Fetch a shard entry by principal.
85    #[must_use]
86    pub fn get(pid: Principal) -> Option<ShardEntry> {
87        ShardingRegistry::with(|core| core.get_entry(&pid))
88    }
89
90    /// Export all shard entries.
91    #[must_use]
92    pub fn export() -> Vec<(Principal, ShardEntry)> {
93        ShardingRegistry::export()
94    }
95
96    /// Returns the shard assigned to the given tenant (if any).
97    #[must_use]
98    pub fn tenant_shard(pool: &str, tenant: &str) -> Option<Principal> {
99        ShardingRegistry::tenant_shard(pool, tenant)
100    }
101
102    /// Lookup the slot index for a given shard principal.
103    #[must_use]
104    pub fn slot_for_shard(pool: &str, shard: Principal) -> Option<u32> {
105        ShardingRegistry::slot_for_shard(pool, shard)
106    }
107
108    /// Lists all tenants currently assigned to the specified shard.
109    #[must_use]
110    pub fn tenants_in_shard(pool: &str, shard: Principal) -> Vec<String> {
111        ShardingRegistry::tenants_in_shard(pool, shard)
112    }
113
114    /// Assign (or reassign) a tenant to a shard.
115    ///
116    /// Storage responsibilities:
117    /// - enforce referential integrity (target shard must exist)
118    /// - enforce pool consistency (assignment pool must match shard entry pool)
119    /// - maintain derived counters (`ShardEntry.count`)
120    pub fn assign(pool: &str, tenant: &str, shard: Principal) -> Result<(), Error> {
121        ShardingRegistry::with_mut(|core| {
122            let mut entry = core
123                .get_entry(&shard)
124                .ok_or(ShardingRegistryOpsError::ShardNotFound(shard))?;
125
126            if entry.pool.as_ref() != pool {
127                return Err(ShardingRegistryOpsError::PoolMismatch {
128                    pid: shard,
129                    expected: pool.to_string(),
130                    actual: entry.pool.to_string(),
131                }
132                .into());
133            }
134
135            let key =
136                ShardKey::try_new(pool, tenant).map_err(ShardingRegistryOpsError::InvalidKey)?;
137
138            // If tenant is already assigned, decrement the old shard count.
139            if let Some(current) = core.get_assignment(&key) {
140                if current == shard {
141                    return Ok(());
142                }
143
144                if let Some(mut old_entry) = core.get_entry(&current) {
145                    old_entry.count = old_entry.count.saturating_sub(1);
146                    core.insert_entry(current, old_entry);
147                }
148            }
149
150            // Overwrite the assignment and increment the new shard count.
151            core.insert_assignment(key, shard);
152            entry.count = entry.count.saturating_add(1);
153            core.insert_entry(shard, entry);
154
155            Ok(())
156        })
157    }
158
159    /// Remove a tenant assignment, if present.
160    ///
161    /// Returns the shard principal that previously held the assignment.
162    pub fn unassign(pool: &str, tenant: &str) -> Result<Option<Principal>, Error> {
163        ShardingRegistry::with_mut(|core| {
164            let key =
165                ShardKey::try_new(pool, tenant).map_err(ShardingRegistryOpsError::InvalidKey)?;
166            let Some(shard) = core.remove_assignment(&key) else {
167                return Ok(None);
168            };
169
170            if let Some(mut entry) = core.get_entry(&shard) {
171                entry.count = entry.count.saturating_sub(1);
172                core.insert_entry(shard, entry);
173            }
174
175            Ok(Some(shard))
176        })
177    }
178
179    /// Update the logical slot index for a shard entry.
180    pub fn set_slot(pid: Principal, slot: u32) -> Result<(), Error> {
181        ShardingRegistry::with_mut(|core| {
182            let mut entry = core
183                .get_entry(&pid)
184                .ok_or(ShardingRegistryOpsError::ShardNotFound(pid))?;
185
186            if slot != ShardEntry::UNASSIGNED_SLOT {
187                for (other_pid, other_entry) in core.all_entries() {
188                    if other_pid != pid
189                        && other_entry.pool == entry.pool
190                        && other_entry.slot == slot
191                    {
192                        return Err(ShardingRegistryOpsError::SlotOccupied {
193                            pool: entry.pool.to_string(),
194                            slot,
195                            pid: other_pid,
196                        }
197                        .into());
198                    }
199                }
200            }
201
202            entry.slot = slot;
203            core.insert_entry(pid, entry);
204
205            Ok(())
206        })
207    }
208
209    #[cfg(test)]
210    pub(crate) fn clear_for_test() {
211        ShardingRegistry::clear();
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a deterministic test principal from a single repeated byte.
    fn p(id: u8) -> Principal {
        Principal::from_slice(&[id; 29])
    }

    #[test]
    fn assign_and_unassign_updates_count() {
        ShardingRegistryOps::clear_for_test();
        let role = CanisterRole::new("alpha");
        let shard_pid = p(1);

        ShardingRegistryOps::create(shard_pid, "poolA", 0, &role, 2).unwrap();
        ShardingRegistryOps::assign("poolA", "tenant1", shard_pid).unwrap();
        let count_after = ShardingRegistryOps::get(shard_pid).unwrap().count;
        assert_eq!(count_after, 1);

        assert_eq!(
            ShardingRegistryOps::unassign("poolA", "tenant1").unwrap(),
            Some(shard_pid)
        );
        let count_final = ShardingRegistryOps::get(shard_pid).unwrap().count;
        assert_eq!(count_final, 0);
    }

    #[test]
    fn reassign_moves_count_between_shards() {
        ShardingRegistryOps::clear_for_test();
        let role = CanisterRole::new("alpha");
        let (a, b) = (p(1), p(2));

        ShardingRegistryOps::create(a, "poolA", 0, &role, 2).unwrap();
        ShardingRegistryOps::create(b, "poolA", 1, &role, 2).unwrap();

        ShardingRegistryOps::assign("poolA", "tenant1", a).unwrap();
        // Assigning to the same shard again must not double-count.
        ShardingRegistryOps::assign("poolA", "tenant1", a).unwrap();
        assert_eq!(ShardingRegistryOps::get(a).unwrap().count, 1);

        ShardingRegistryOps::assign("poolA", "tenant1", b).unwrap();
        assert_eq!(ShardingRegistryOps::get(a).unwrap().count, 0);
        assert_eq!(ShardingRegistryOps::get(b).unwrap().count, 1);
        assert_eq!(ShardingRegistryOps::tenant_shard("poolA", "tenant1"), Some(b));
    }

    #[test]
    fn assign_rejects_pool_mismatch_and_unknown_shard() {
        ShardingRegistryOps::clear_for_test();
        let role = CanisterRole::new("alpha");
        let shard_pid = p(1);

        ShardingRegistryOps::create(shard_pid, "poolA", 0, &role, 2).unwrap();

        assert!(ShardingRegistryOps::assign("poolB", "tenant1", shard_pid).is_err());
        assert!(ShardingRegistryOps::assign("poolA", "tenant1", p(9)).is_err());
        // Failed assignments must leave the counter untouched.
        assert_eq!(ShardingRegistryOps::get(shard_pid).unwrap().count, 0);
    }

    #[test]
    fn slot_must_be_unique_within_pool() {
        ShardingRegistryOps::clear_for_test();
        let role = CanisterRole::new("alpha");

        ShardingRegistryOps::create(p(1), "poolA", 0, &role, 2).unwrap();
        assert!(ShardingRegistryOps::create(p(2), "poolA", 0, &role, 2).is_err());
        // The same slot index in a different pool is allowed.
        ShardingRegistryOps::create(p(3), "poolB", 0, &role, 2).unwrap();

        ShardingRegistryOps::create(p(2), "poolA", 1, &role, 2).unwrap();
        assert!(ShardingRegistryOps::set_slot(p(2), 0).is_err());
        assert_eq!(ShardingRegistryOps::get(p(2)).unwrap().slot, 1);

        ShardingRegistryOps::set_slot(p(2), 2).unwrap();
        assert_eq!(ShardingRegistryOps::get(p(2)).unwrap().slot, 2);
    }
}