Skip to main content

obj_alloc/
id_map.rs

1//! 极简版 IdMap:自动生成递增 Id + Id 透明序列化 + 无条件编译
2//! 核心特性:插入值自动返回递增 Id、Id 浅包装 u64、无任何条件编译
3
4use core::fmt;
5use std::collections::HashMap;
6use std::marker::PhantomData;
7use std::ops::{Index, IndexMut};
8
9// ============================ 核心 Id 定义 ============================
10/// Id 基础 trait,所有自定义 Id 需实现此 trait
11pub trait Id: Copy + Clone + Eq + PartialEq + fmt::Debug + Into<u64> + From<u64> {
12    /// 快速转换为 u64
13    fn as_u64(&self) -> u64 {
14        (*self).into()
15    }
16    
17    /// 从 u64 构建 Id
18    fn from_u64(val: u64) -> Self {
19        Self::from(val)
20    }
21}
22
23
24// ============================ 自定义 Id 生成宏 ============================
25/// 生成自定义 Id 类型的极简宏
26#[macro_export]
27macro_rules! new_id_type {
28    // 递归终止条件:无剩余参数时结束
29    () => {};
30
31    // 核心匹配模式:单个 ID 结构体定义(带可选 vis + 属性 + 名称)
32    (
33        $(#[$meta:meta])*
34        $vis:vis struct $name:ident;
35        $($rest:tt)*
36    ) => {
37        // 生成单个 ID 结构体的完整定义
38        $(#[$meta])*
39        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
40        $vis struct $name(pub u64);
41
42        impl From<u64> for $name {
43            #[inline]
44            fn from(val: u64) -> Self {
45                Self(val)
46            }
47        }
48
49        impl From<$name> for u64 {
50            #[inline]
51            fn from(id: $name) -> Self {
52                id.0
53            }
54        }
55
56        impl $crate::Id for $name {}
57
58        impl serde::Serialize for $name {
59            #[inline]
60            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
61            where
62                S: serde::Serializer,
63            {
64                self.0.serialize(serializer)
65            }
66        }
67
68        impl<'de> serde::Deserialize<'de> for $name {
69            #[inline]
70            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
71            where
72                D: serde::Deserializer<'de>,
73            {
74                let val = u64::deserialize(deserializer)?;
75                Ok(Self(val))
76            }
77        }
78
79        $crate::new_id_type!($($rest)*);
80    };
81}
82
83new_id_type!{
84    pub struct DefaultId;
85}
86
87
88// ============================ IdMap 核心实现(自动生成递增 Id) ============================
89/// 极简版 IdMap:自动生成递增 Id + HashMap 存储 + 无条件编译
90#[derive(Debug, Clone)]
91pub struct IdMap<K: Id, V> {
92    pub(crate) inner: HashMap<u64, V>, // 底层存储:u64 -> V
93    max_id: u64,            // 记录最大 Id,用于生成递增 Id
94    _marker: PhantomData<K>,
95}
96
97impl<V> IdMap<DefaultId, V> {
98    /// 创建空的 IdMap(初始 max_id = 0)
99    pub fn new() -> Self { Self::with_id_capacity(0) }
100    
101    /// 创建指定初始容量的 IdMap
102    pub fn with_capacity(capacity: usize) -> Self { Self::with_id_capacity(capacity) }
103}
104
105impl<K: Id, V> IdMap<K, V> {
106    /// 为自定义 Id 类型创建空 IdMap
107    pub fn with_id() -> Self {
108        Self {
109            inner: HashMap::new(),
110            max_id: 0,
111            _marker: PhantomData,
112        }
113    }
114    
115    /// 自定义 Id 类型创建指定初始容量的 IdMap
116    pub fn with_id_capacity(capacity: usize) -> Self {
117        Self {
118            inner: HashMap::with_capacity(capacity),
119            max_id: 0,
120            _marker: PhantomData,
121        }
122    }
123    
124    /// 插入值,自动生成递增 Id 并返回
125    pub fn insert(&mut self, value: V) -> K {
126        self.max_id += 1; // 递增生成新 Id(从 1 开始,避免 0 作为初始值)
127        let id_u64 = self.max_id;
128        self.inner.insert(id_u64, value); // 存储值
129        K::from_u64(id_u64) // 转换为指定 Id 类型并返回
130    }
131    
132    
133    /// 【手动指定 Id】插入键值对,返回旧值(若存在)
134    ///
135    /// 注意:若手动传入的 Id 大于当前 max_id,会更新 max_id 以保证自动生成的 Id 不重复
136    pub fn insert_with_id(&mut self, id: K, value: V) -> Option<V> {
137        let id_u64 = id.as_u64();
138        // 若手动传入的 Id 更大,更新 max_id,避免自动生成 Id 重复
139        if id_u64 > self.max_id {
140            self.max_id = id_u64;
141        }
142        self.inner.insert(id_u64, value)
143    }
144    
145    /// 从 Vec<V> 批量插入值,自动生成递增 Id,返回对应的 Id 列表
146    /// 生成的 Id 从当前 max_id + 1 开始连续递增
147    pub fn from_vec(values: Vec<V>) -> (Self, Vec<K>) {
148        let mut map = Self {
149            inner: HashMap::with_capacity(values.len()),
150            max_id: 0,
151            _marker: PhantomData,
152        };
153        let ids = values
154            .into_iter()
155            .map(|val| {
156                map.max_id += 1;
157                let id_u64 = map.max_id;
158                map.inner.insert(id_u64, val);
159                K::from_u64(id_u64)
160            })
161            .collect();
162        (map, ids)
163    }
164    
165    /// 循环插入:先生成递增 Id,再通过闭包(Id → V)生成值并插入
166    /// 适用于值需要依赖自身 Id 的场景(如循环引用/关联 Id 的场景)
167    pub fn insert_cyclic<F>(&mut self, f: F) -> K
168    where
169        F: FnOnce(K) -> V,
170    {
171        self.max_id += 1;
172        let new_id = K::from_u64(self.max_id);
173        let value = f(new_id);
174        self.inner.insert(self.max_id, value);
175        new_id
176    }
177    
178    /// 根据 Id 查询值
179    pub fn get(&self, id: K) -> Option<&V> {
180        self.inner.get(&id.as_u64())
181    }
182    
183    /// 根据 Id 查询可变值
184    pub fn get_mut(&mut self, id: K) -> Option<&mut V> {
185        self.inner.get_mut(&id.as_u64())
186    }
187    
188    /// 根据 Id 删除值
189    pub fn remove(&mut self, id: K) -> Option<V> {
190        self.inner.remove(&id.as_u64())
191    }
192    
193    /// 判断是否包含指定 Id
194    pub fn contains_id(&self, id: K) -> bool {
195        self.inner.contains_key(&id.as_u64())
196    }
197    
198    /// 获取当前最大 Id(仅用于参考,删除 Id 后不会回退)
199    pub fn max_id(&self) -> K {
200        K::from_u64(self.max_id)
201    }
202    
203    /// 获取元素数量
204    pub fn len(&self) -> usize {
205        self.inner.len()
206    }
207    
208    /// 判断是否为空
209    pub fn is_empty(&self) -> bool {
210        self.inner.is_empty()
211    }
212    
213    /// 清空所有元素(保留 max_id 不变,避免 Id 重复)
214    pub fn clear(&mut self) {
215        self.inner.clear();
216    }
217}
218
219// ============================ Index/IndexMut 实现 ============================
220impl<K: Id, V> Index<K> for IdMap<K, V> {
221    type Output = V;
222    
223    fn index(&self, id: K) -> &Self::Output {
224        self.get(id).expect("invalid IdMap id")
225    }
226}
227
228impl<K: Id, V> IndexMut<K> for IdMap<K, V> {
229    fn index_mut(&mut self, id: K) -> &mut Self::Output {
230        self.get_mut(id).expect("invalid IdMap id")
231    }
232}
233
234// ============================ 测试用例 ============================
235#[cfg(test)]
236mod tests {
237    use super::*;
238    use serde_json;
239    
240    // 测试默认 Id + 自动递增生成
241    #[test]
242    fn test_default_id_auto_generate() {
243        let mut map = IdMap::new();
244        
245        // 插入值,自动返回递增 Id
246        let id1 = map.insert("hello");
247        let id2 = map.insert("world");
248        let id3 = map.insert("rust");
249        
250        // 验证 Id 递增(从 1 开始)
251        assert_eq!(id1, DefaultId(1));
252        assert_eq!(id2, DefaultId(2));
253        assert_eq!(id3, DefaultId(3));
254        
255        // 验证值查询
256        assert_eq!(map.get(id1), Some(&"hello"));
257        assert_eq!(map[id2], "world");
258        assert_eq!(map.max_id(), DefaultId(3));
259        
260        // 删除值后,max_id 不回退
261        map.remove(id2);
262        assert_eq!(map.max_id(), DefaultId(3));
263        let id4 = map.insert("new value");
264        assert_eq!(id4, DefaultId(4)); // 继续递增
265        
266        // 数量/空判断
267        assert_eq!(map.len(), 3);
268        map.clear();
269        assert!(map.is_empty());
270    }
271    
272    // 测试自定义 Id
273    new_id_type! {
274        struct MyId;
275    }
276    
277    #[test]
278    fn test_custom_id() {
279        let mut map = IdMap::<MyId, u32>::with_id();
280        
281        let id1 = map.insert(42);
282        let id2 = map.insert(100);
283        
284        assert_eq!(id1, MyId(1));
285        assert_eq!(id2, MyId(2));
286        assert_eq!(map.get(id1), Some(&42));
287        
288        // 删除测试
289        map.remove(id1);
290        assert!(!map.contains_id(id1));
291    }
292    
293    // 测试 Id 透明序列化
294    #[test]
295    fn test_id_serde() {
296        // 测试默认 Id
297        let id = DefaultId(123456789);
298        let json = serde_json::to_string(&id).unwrap();
299        assert_eq!(json, "123456789"); // 直接输出 u64 字符串
300        let id2: DefaultId = serde_json::from_str(&json).unwrap();
301        assert_eq!(id2, id);
302        
303        // 测试自定义 Id
304        let my_id = MyId(987654321);
305        let json = serde_json::to_string(&my_id).unwrap();
306        let my_id2: MyId = serde_json::from_str(&json).unwrap();
307        assert_eq!(my_id2, my_id);
308    }
309}