Skip to main content

obj_alloc/
id_map.rs

1//! 极简版 IdMap:自动生成递增 Id + Id 透明序列化 + 无条件编译
2//! 核心特性:插入值自动返回递增 Id、Id 浅包装 u64、无任何条件编译
3
4use core::fmt;
5use std::collections::HashMap;
6use std::marker::PhantomData;
7use std::ops::{Index, IndexMut};
8use serde::{Deserialize, Serialize};
9
10// ============================ 核心 Id 定义 ============================
11/// Id 基础 trait,所有自定义 Id 需实现此 trait
12pub trait Id: Copy + Clone + Eq + PartialEq + fmt::Debug + Into<u64> + From<u64> {
13    /// 快速转换为 u64
14    fn as_u64(&self) -> u64 {
15        (*self).into()
16    }
17    
18    /// 从 u64 构建 Id
19    fn from_u64(val: u64) -> Self {
20        Self::from(val)
21    }
22}
23
24
25// ============================ 自定义 Id 生成宏 ============================
26/// 生成自定义 Id 类型的极简宏
27#[macro_export]
28macro_rules! new_id_type {
29    // 递归终止条件:无剩余参数时结束
30    () => {};
31
32    // 核心匹配模式:单个 ID 结构体定义(带可选 vis + 属性 + 名称)
33    (
34        $(#[$meta:meta])*
35        $vis:vis struct $name:ident;
36        $($rest:tt)*
37    ) => {
38        // 生成单个 ID 结构体的完整定义
39        $(#[$meta])*
40        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
41        $vis struct $name(pub u64);
42
43        impl From<u64> for $name {
44            #[inline]
45            fn from(val: u64) -> Self {
46                Self(val)
47            }
48        }
49
50        impl From<$name> for u64 {
51            #[inline]
52            fn from(id: $name) -> Self {
53                id.0
54            }
55        }
56
57        impl $crate::Id for $name {}
58
59        impl serde::Serialize for $name {
60            #[inline]
61            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
62            where
63                S: serde::Serializer,
64            {
65                self.0.serialize(serializer)
66            }
67        }
68
69        impl<'de> serde::Deserialize<'de> for $name {
70            #[inline]
71            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
72            where
73                D: serde::Deserializer<'de>,
74            {
75                let val = u64::deserialize(deserializer)?;
76                Ok(Self(val))
77            }
78        }
79
80        $crate::new_id_type!($($rest)*);
81    };
82}
83
84new_id_type!{
85    pub struct DefaultId;
86}
87
88
89// ============================ IdMap 核心实现(自动生成递增 Id) ============================
90/// 极简版 IdMap:自动生成递增 Id + HashMap 存储 + 无条件编译
91#[derive(Debug, Clone)]
92#[derive(Serialize, Deserialize)]
93#[serde(transparent)]
94pub struct IdMap<K: Id, V> {
95    pub(crate) inner: HashMap<u64, V>, // 底层存储:u64 -> V
96    #[serde(skip)]
97    max_id: u64,            // 记录最大 Id,用于生成递增 Id
98    #[serde(skip)]
99    _marker: PhantomData<K>,
100}
101
102impl<V> IdMap<DefaultId, V> {
103    /// 创建空的 IdMap(初始 max_id = 0)
104    pub fn new() -> Self { Self::with_id_capacity(0) }
105    
106    /// 创建指定初始容量的 IdMap
107    pub fn with_capacity(capacity: usize) -> Self { Self::with_id_capacity(capacity) }
108}
109
110impl<K: Id ,V> Default for IdMap<K, V> {
111    fn default() -> Self {
112        Self::with_id()
113    }
114}
115
116impl<K: Id, V> IdMap<K, V> {
117    /// 为自定义 Id 类型创建空 IdMap
118    pub fn with_id() -> Self {
119        Self {
120            inner: HashMap::new(),
121            max_id: 0,
122            _marker: PhantomData,
123        }
124    }
125    
126    /// 自定义 Id 类型创建指定初始容量的 IdMap
127    pub fn with_id_capacity(capacity: usize) -> Self {
128        Self {
129            inner: HashMap::with_capacity(capacity),
130            max_id: 0,
131            _marker: PhantomData,
132        }
133    }
134    
135    /// 插入值,自动生成递增 Id 并返回
136    pub fn insert(&mut self, value: V) -> K {
137        self.max_id += 1; // 递增生成新 Id(从 1 开始,避免 0 作为初始值)
138        let id_u64 = self.max_id;
139        self.inner.insert(id_u64, value); // 存储值
140        K::from_u64(id_u64) // 转换为指定 Id 类型并返回
141    }
142    
143    
144    /// 【手动指定 Id】插入键值对,返回旧值(若存在)
145    ///
146    /// 注意:若手动传入的 Id 大于当前 max_id,会更新 max_id 以保证自动生成的 Id 不重复
147    pub fn insert_with_id(&mut self, id: K, value: V) -> Option<V> {
148        let id_u64 = id.as_u64();
149        // 若手动传入的 Id 更大,更新 max_id,避免自动生成 Id 重复
150        if id_u64 > self.max_id {
151            self.max_id = id_u64;
152        }
153        self.inner.insert(id_u64, value)
154    }
155    
156    /// 从 Vec<V> 批量插入值,自动生成递增 Id,返回对应的 Id 列表
157    /// 生成的 Id 从当前 max_id + 1 开始连续递增
158    pub fn from_vec(values: Vec<V>) -> (Self, Vec<K>) {
159        let mut map = Self {
160            inner: HashMap::with_capacity(values.len()),
161            max_id: 0,
162            _marker: PhantomData,
163        };
164        let ids = values
165            .into_iter()
166            .map(|val| {
167                map.max_id += 1;
168                let id_u64 = map.max_id;
169                map.inner.insert(id_u64, val);
170                K::from_u64(id_u64)
171            })
172            .collect();
173        (map, ids)
174    }
175    
176    /// 循环插入:先生成递增 Id,再通过闭包(Id → V)生成值并插入
177    /// 适用于值需要依赖自身 Id 的场景(如循环引用/关联 Id 的场景)
178    pub fn insert_cyclic<F>(&mut self, f: F) -> K
179    where
180        F: FnOnce(K) -> V,
181    {
182        self.max_id += 1;
183        let new_id = K::from_u64(self.max_id);
184        let value = f(new_id);
185        self.inner.insert(self.max_id, value);
186        new_id
187    }
188    
189    /// 根据 Id 查询值
190    pub fn get(&self, id: K) -> Option<&V> {
191        self.inner.get(&id.as_u64())
192    }
193    
194    /// 根据 Id 查询可变值
195    pub fn get_mut(&mut self, id: K) -> Option<&mut V> {
196        self.inner.get_mut(&id.as_u64())
197    }
198    
199    /// 根据 Id 删除值
200    pub fn remove(&mut self, id: K) -> Option<V> {
201        self.inner.remove(&id.as_u64())
202    }
203    
204    /// 判断是否包含指定 Id
205    pub fn contains_id(&self, id: K) -> bool {
206        self.inner.contains_key(&id.as_u64())
207    }
208    
209    /// 获取当前最大 Id(仅用于参考,删除 Id 后不会回退)
210    pub fn max_id(&self) -> K {
211        K::from_u64(self.max_id)
212    }
213    
214    /// 获取元素数量
215    pub fn len(&self) -> usize {
216        self.inner.len()
217    }
218    
219    /// 判断是否为空
220    pub fn is_empty(&self) -> bool {
221        self.inner.is_empty()
222    }
223    
224    /// 清空所有元素(保留 max_id 不变,避免 Id 重复)
225    pub fn clear(&mut self) {
226        self.inner.clear();
227    }
228}
229
230// ============================ Index/IndexMut 实现 ============================
231impl<K: Id, V> Index<K> for IdMap<K, V> {
232    type Output = V;
233    
234    fn index(&self, id: K) -> &Self::Output {
235        self.get(id).expect("invalid IdMap id")
236    }
237}
238
239impl<K: Id, V> IndexMut<K> for IdMap<K, V> {
240    fn index_mut(&mut self, id: K) -> &mut Self::Output {
241        self.get_mut(id).expect("invalid IdMap id")
242    }
243}
244
245// ============================ 测试用例 ============================
246#[cfg(test)]
247mod tests {
248    use super::*;
249    use serde_json;
250    
251    // 测试默认 Id + 自动递增生成
252    #[test]
253    fn test_default_id_auto_generate() {
254        let mut map = IdMap::new();
255        
256        // 插入值,自动返回递增 Id
257        let id1 = map.insert("hello");
258        let id2 = map.insert("world");
259        let id3 = map.insert("rust");
260        
261        // 验证 Id 递增(从 1 开始)
262        assert_eq!(id1, DefaultId(1));
263        assert_eq!(id2, DefaultId(2));
264        assert_eq!(id3, DefaultId(3));
265        
266        // 验证值查询
267        assert_eq!(map.get(id1), Some(&"hello"));
268        assert_eq!(map[id2], "world");
269        assert_eq!(map.max_id(), DefaultId(3));
270        
271        // 删除值后,max_id 不回退
272        map.remove(id2);
273        assert_eq!(map.max_id(), DefaultId(3));
274        let id4 = map.insert("new value");
275        assert_eq!(id4, DefaultId(4)); // 继续递增
276        
277        // 数量/空判断
278        assert_eq!(map.len(), 3);
279        map.clear();
280        assert!(map.is_empty());
281    }
282    
283    // 测试自定义 Id
284    new_id_type! {
285        struct MyId;
286    }
287    
288    #[test]
289    fn test_custom_id() {
290        let mut map = IdMap::<MyId, u32>::with_id();
291        
292        let id1 = map.insert(42);
293        let id2 = map.insert(100);
294        
295        assert_eq!(id1, MyId(1));
296        assert_eq!(id2, MyId(2));
297        assert_eq!(map.get(id1), Some(&42));
298        
299        // 删除测试
300        map.remove(id1);
301        assert!(!map.contains_id(id1));
302    }
303    
304    // 测试 Id 透明序列化
305    #[test]
306    fn test_id_serde() {
307        // 测试默认 Id
308        let id = DefaultId(123456789);
309        let json = serde_json::to_string(&id).unwrap();
310        assert_eq!(json, "123456789"); // 直接输出 u64 字符串
311        let id2: DefaultId = serde_json::from_str(&json).unwrap();
312        assert_eq!(id2, id);
313        
314        // 测试自定义 Id
315        let my_id = MyId(987654321);
316        let json = serde_json::to_string(&my_id).unwrap();
317        let my_id2: MyId = serde_json::from_str(&json).unwrap();
318        assert_eq!(my_id2, my_id);
319    }
320}