Skip to main content

obj_alloc/
id_map.rs

1//! 极简版 IdMap:自动生成递增 Id + Id 透明序列化 + 无条件编译
2//! 核心特性:插入值自动返回递增 Id、Id 浅包装 u64、无任何条件编译
3
4use core::fmt;
5use std::collections::HashMap;
6use std::marker::PhantomData;
7use std::ops::{Index, IndexMut};
8use serde::{Deserialize, Serialize};
9
10// ============================ 核心 Id 定义 ============================
11/// Id 基础 trait,所有自定义 Id 需实现此 trait
12pub trait Id: Copy + Clone + Eq + PartialEq + fmt::Debug + Into<u64> + From<u64> {
13    const NONE: Self;
14
15    fn none() -> Self {
16        Self::NONE
17    }
18    
19    fn is_none(&self) -> bool {
20        self.eq(&Self::NONE)
21    }
22    
23    /// 快速转换为 u64
24    fn as_u64(&self) -> u64 {
25        (*self).into()
26    }
27    
28    /// 从 u64 构建 Id
29    fn from_u64(val: u64) -> Self {
30        Self::from(val)
31    }
32}
33
34
35// ============================ 自定义 Id 生成宏 ============================
36/// 生成自定义 Id 类型的极简宏
37#[macro_export]
38macro_rules! new_id_type {
39    // 递归终止条件:无剩余参数时结束
40    () => {};
41
42    // 核心匹配模式:单个 ID 结构体定义(带可选 vis + 属性 + 名称)
43    (
44        $(#[$meta:meta])*
45        $vis:vis struct $name:ident;
46        $($rest:tt)*
47    ) => {
48        // 生成单个 ID 结构体的完整定义
49        $(#[$meta])*
50        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
51        $vis struct $name(pub u64);
52
53        impl From<u64> for $name {
54            #[inline]
55            fn from(val: u64) -> Self {
56                Self(val)
57            }
58        }
59
60        impl From<$name> for u64 {
61            #[inline]
62            fn from(id: $name) -> Self {
63                id.0
64            }
65        }
66
67        impl $crate::Id for $name {
68            const NONE: Self = Self(0);
69        }
70
71        impl serde::Serialize for $name {
72            #[inline]
73            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
74            where
75                S: serde::Serializer,
76            {
77                self.0.serialize(serializer)
78            }
79        }
80
81        impl<'de> serde::Deserialize<'de> for $name {
82            #[inline]
83            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
84            where
85                D: serde::Deserializer<'de>,
86            {
87                let val = u64::deserialize(deserializer)?;
88                Ok(Self(val))
89            }
90        }
91
92        $crate::new_id_type!($($rest)*);
93    };
94}
95
96new_id_type!{
97    pub struct DefaultId;
98}
99
100
101// ============================ IdMap 核心实现(自动生成递增 Id) ============================
102/// 极简版 IdMap:自动生成递增 Id + HashMap 存储 + 无条件编译
103#[derive(Debug, Clone)]
104#[derive(Serialize, Deserialize)]
105#[serde(transparent)]
106pub struct IdMap<K: Id, V> {
107    pub(crate) inner: HashMap<u64, V>, // 底层存储:u64 -> V
108    #[serde(skip)]
109    max_id: u64,            // 记录最大 Id,用于生成递增 Id
110    #[serde(skip)]
111    _marker: PhantomData<K>,
112}
113
114impl<V> IdMap<DefaultId, V> {
115    /// 创建空的 IdMap(初始 max_id = 0)
116    pub fn new() -> Self { Self::with_id_capacity(0) }
117    
118    /// 创建指定初始容量的 IdMap
119    pub fn with_capacity(capacity: usize) -> Self { Self::with_id_capacity(capacity) }
120}
121
122impl<K: Id ,V> Default for IdMap<K, V> {
123    fn default() -> Self {
124        Self::with_id()
125    }
126}
127
128impl<K: Id, V> IdMap<K, V> {
129    /// 为自定义 Id 类型创建空 IdMap
130    pub fn with_id() -> Self {
131        Self {
132            inner: HashMap::new(),
133            max_id: 0,
134            _marker: PhantomData,
135        }
136    }
137    
138    /// 自定义 Id 类型创建指定初始容量的 IdMap
139    pub fn with_id_capacity(capacity: usize) -> Self {
140        Self {
141            inner: HashMap::with_capacity(capacity),
142            max_id: 0,
143            _marker: PhantomData,
144        }
145    }
146    
147    /// 插入值,自动生成递增 Id 并返回
148    pub fn insert(&mut self, value: V) -> K {
149        self.max_id += 1; // 递增生成新 Id(从 1 开始,避免 0 作为初始值)
150        let id_u64 = self.max_id;
151        self.inner.insert(id_u64, value); // 存储值
152        K::from_u64(id_u64) // 转换为指定 Id 类型并返回
153    }
154    
155    
156    /// 【手动指定 Id】插入键值对,返回旧值(若存在)
157    ///
158    /// 注意:若手动传入的 Id 大于当前 max_id,会更新 max_id 以保证自动生成的 Id 不重复
159    pub fn insert_with_id(&mut self, id: K, value: V) -> Option<V> {
160        let id_u64 = id.as_u64();
161        // 若手动传入的 Id 更大,更新 max_id,避免自动生成 Id 重复
162        if id_u64 > self.max_id {
163            self.max_id = id_u64;
164        }
165        self.inner.insert(id_u64, value)
166    }
167    
168    /// 从 Vec<V> 批量插入值,自动生成递增 Id,返回对应的 Id 列表
169    /// 生成的 Id 从当前 max_id + 1 开始连续递增
170    pub fn from_vec(values: Vec<V>) -> (Self, Vec<K>) {
171        let mut map = Self {
172            inner: HashMap::with_capacity(values.len()),
173            max_id: 0,
174            _marker: PhantomData,
175        };
176        let ids = values
177            .into_iter()
178            .map(|val| {
179                map.max_id += 1;
180                let id_u64 = map.max_id;
181                map.inner.insert(id_u64, val);
182                K::from_u64(id_u64)
183            })
184            .collect();
185        (map, ids)
186    }
187    
188    /// 循环插入:先生成递增 Id,再通过闭包(Id → V)生成值并插入
189    /// 适用于值需要依赖自身 Id 的场景(如循环引用/关联 Id 的场景)
190    pub fn insert_cyclic<F>(&mut self, f: F) -> K
191    where
192        F: FnOnce(K) -> V,
193    {
194        self.max_id += 1;
195        let new_id = K::from_u64(self.max_id);
196        let value = f(new_id);
197        self.inner.insert(self.max_id, value);
198        new_id
199    }
200    
201    /// 根据 Id 查询值
202    pub fn get(&self, id: K) -> Option<&V> {
203        self.inner.get(&id.as_u64())
204    }
205    
206    /// 根据 Id 查询可变值
207    pub fn get_mut(&mut self, id: K) -> Option<&mut V> {
208        self.inner.get_mut(&id.as_u64())
209    }
210    
211    /// 根据 Id 删除值
212    pub fn remove(&mut self, id: K) -> Option<V> {
213        self.inner.remove(&id.as_u64())
214    }
215    
216    /// 判断是否包含指定 Id
217    pub fn contains_id(&self, id: K) -> bool {
218        self.inner.contains_key(&id.as_u64())
219    }
220    
221    /// 获取当前最大 Id(仅用于参考,删除 Id 后不会回退)
222    pub fn max_id(&self) -> K {
223        K::from_u64(self.max_id)
224    }
225    
226    /// 获取元素数量
227    pub fn len(&self) -> usize {
228        self.inner.len()
229    }
230    
231    /// 判断是否为空
232    pub fn is_empty(&self) -> bool {
233        self.inner.is_empty()
234    }
235    
236    /// 清空所有元素(保留 max_id 不变,避免 Id 重复)
237    pub fn clear(&mut self) {
238        self.inner.clear();
239    }
240}
241
242// ============================ Index/IndexMut 实现 ============================
243impl<K: Id, V> Index<K> for IdMap<K, V> {
244    type Output = V;
245    
246    fn index(&self, id: K) -> &Self::Output {
247        self.get(id).expect("invalid IdMap id")
248    }
249}
250
251impl<K: Id, V> IndexMut<K> for IdMap<K, V> {
252    fn index_mut(&mut self, id: K) -> &mut Self::Output {
253        self.get_mut(id).expect("invalid IdMap id")
254    }
255}
256
257// ============================ 测试用例 ============================
258#[cfg(test)]
259mod tests {
260    use super::*;
261    use serde_json;
262    
263    // 测试默认 Id + 自动递增生成
264    #[test]
265    fn test_default_id_auto_generate() {
266        let mut map = IdMap::new();
267        
268        // 插入值,自动返回递增 Id
269        let id1 = map.insert("hello");
270        let id2 = map.insert("world");
271        let id3 = map.insert("rust");
272        
273        // 验证 Id 递增(从 1 开始)
274        assert_eq!(id1, DefaultId(1));
275        assert_eq!(id2, DefaultId(2));
276        assert_eq!(id3, DefaultId(3));
277        
278        // 验证值查询
279        assert_eq!(map.get(id1), Some(&"hello"));
280        assert_eq!(map[id2], "world");
281        assert_eq!(map.max_id(), DefaultId(3));
282        
283        // 删除值后,max_id 不回退
284        map.remove(id2);
285        assert_eq!(map.max_id(), DefaultId(3));
286        let id4 = map.insert("new value");
287        assert_eq!(id4, DefaultId(4)); // 继续递增
288        
289        // 数量/空判断
290        assert_eq!(map.len(), 3);
291        map.clear();
292        assert!(map.is_empty());
293    }
294    
295    // 测试自定义 Id
296    new_id_type! {
297        struct MyId;
298    }
299    
300    #[test]
301    fn test_custom_id() {
302        let mut map = IdMap::<MyId, u32>::with_id();
303        
304        let id1 = map.insert(42);
305        let id2 = map.insert(100);
306        
307        assert_eq!(id1, MyId(1));
308        assert_eq!(id2, MyId(2));
309        assert_eq!(map.get(id1), Some(&42));
310        
311        // 删除测试
312        map.remove(id1);
313        assert!(!map.contains_id(id1));
314    }
315    
316    // 测试 Id 透明序列化
317    #[test]
318    fn test_id_serde() {
319        // 测试默认 Id
320        let id = DefaultId(123456789);
321        let json = serde_json::to_string(&id).unwrap();
322        assert_eq!(json, "123456789"); // 直接输出 u64 字符串
323        let id2: DefaultId = serde_json::from_str(&json).unwrap();
324        assert_eq!(id2, id);
325        
326        // 测试自定义 Id
327        let my_id = MyId(987654321);
328        let json = serde_json::to_string(&my_id).unwrap();
329        let my_id2: MyId = serde_json::from_str(&json).unwrap();
330        assert_eq!(my_id2, my_id);
331    }
332}