1use crate::prelude::*;
20use std::{
21 hash::{Hash, Hasher},
22 sync::{atomic::Ordering, Arc, LazyLock, Mutex, MutexGuard},
23};
24
/// Identifier under which a value is stored in a [`Table`].
///
/// Wraps a single `u32`; because of `#[serde(transparent)]` it is serialized
/// as that bare number rather than as a struct with an `id` field.
#[derive_group(Serializers)]
#[derive(Default, Clone, Debug, JsonSchema, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[serde(transparent)]
pub struct Id {
    /// Raw numeric identifier.
    id: u32,
}
32
/// State needed to create [`Node`]s: the next identifier to hand out and the
/// [`Table`] of values registered so far.
#[derive(Default, Debug)]
pub struct Session {
    /// Identifier that the next allocation will return.
    next_id: Id,
    /// Values registered through [`Node::new`], keyed by their [`Id`].
    table: Table,
}
39
impl Session {
    /// Returns a shared reference to the table of values registered in this
    /// session.
    pub fn table(&self) -> &Table {
        &self.table
    }
}
45
/// The values a [`Table`] can hold: one variant per type implementing
/// [`SupportedType<Value>`] in this file.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum Value {
    /// A shared [`TyKind`].
    Ty(Arc<TyKind>),
    /// A shared [`DefIdContents`].
    DefId(Arc<DefIdContents>),
}
52
53impl SupportedType<Value> for TyKind {
54 fn to_types(value: Arc<Self>) -> Value {
55 Value::Ty(value)
56 }
57 fn from_types(t: &Value) -> Option<Arc<Self>> {
58 match t {
59 Value::Ty(value) => Some(value.clone()),
60 _ => None,
61 }
62 }
63}
64
65impl SupportedType<Value> for DefIdContents {
66 fn to_types(value: Arc<Self>) -> Value {
67 Value::DefId(value)
68 }
69 fn from_types(t: &Value) -> Option<Arc<Self>> {
70 match t {
71 Value::DefId(value) => Some(value.clone()),
72 _ => None,
73 }
74 }
75}
76
/// A shared value of type `T` paired with the [`Id`] it was registered under.
///
/// (De)serialization goes through `serde_repr::NodeRepr<T>` (see the
/// `into` / `try_from` serde attributes below) instead of the fields directly.
#[derive(Deserialize, Serialize, Debug, JsonSchema, PartialEq, Eq, PartialOrd, Ord)]
#[serde(into = "serde_repr::NodeRepr<T>")]
#[serde(try_from = "serde_repr::NodeRepr<T>")]
pub struct Node<T: 'static + SupportedType<Value>> {
    /// Identifier of `value`.
    id: Id,
    /// The shared value itself.
    value: Arc<T>,
}
85
86impl<T: SupportedType<Value>> std::ops::Deref for Node<T> {
87 type Target = T;
88 fn deref(&self) -> &Self::Target {
89 self.value.as_ref()
90 }
91}
92
93impl<T: SupportedType<Value> + Hash> Hash for Node<T> {
97 fn hash<H: Hasher>(&self, state: &mut H) {
98 self.value.as_ref().hash(state);
99 }
100}
101
102impl<T: SupportedType<Value>> Clone for Node<T> {
105 fn clone(&self) -> Self {
106 Self {
107 id: self.id.clone(),
108 value: self.value.clone(),
109 }
110 }
111}
112
/// A map from [`Id`] to [`Value`].
///
/// It is (de)serialized through `serde_repr::SortedIdValuePairs`, i.e. as a
/// list of `(id, value)` pairs rather than as a map.
#[derive(Default, Debug, Clone, Deserialize, Serialize)]
#[serde(into = "serde_repr::SortedIdValuePairs")]
#[serde(from = "serde_repr::SortedIdValuePairs")]
pub struct Table(HeterogeneousMap<Id, Value>);
121
mod heterogeneous_map {
    //! A hash map that stores a single `Value` type but is read and written
    //! through the concrete types that `Value` can wrap.
    use std::collections::HashMap;
    use std::hash::Hash;
    use std::sync::Arc;

    /// Map from `Key` to `Value`, with typed accessors for every `T` that
    /// implements [`SupportedType<Value>`].
    #[derive(Clone, Debug)]
    pub struct HeterogeneousMap<Key, Value>(HashMap<Key, Value>);

    impl<Key, Value> Default for HeterogeneousMap<Key, Value> {
        /// Creates an empty map; no bounds on `Key` or `Value` are needed.
        fn default() -> Self {
            Self(HashMap::new())
        }
    }

    impl<Key: Hash + Eq + PartialEq, Value> HeterogeneousMap<Key, Value> {
        /// Converts `value` to a `Value` with [`SupportedType::to_types`] and
        /// stores it under `key`, replacing any previous entry.
        pub(super) fn insert<T>(&mut self, key: Key, value: Arc<T>)
        where
            T: SupportedType<Value>,
        {
            let wrapped = T::to_types(value);
            self.0.insert(key, wrapped);
        }

        /// Stores an already-wrapped `value` under `key`, replacing any
        /// previous entry.
        pub(super) fn insert_raw_value(&mut self, key: Key, value: Value) {
            let _previous = self.0.insert(key, value);
        }

        /// Builds a map from `(key, value)` pairs; for duplicate keys the
        /// last pair wins.
        pub(super) fn from_iter(it: impl Iterator<Item = (Key, Value)>) -> Self {
            Self(it.collect())
        }

        /// Consumes the map and yields its `(key, value)` pairs in the
        /// underlying hash map's (unspecified) order.
        pub(super) fn into_iter(self) -> impl Iterator<Item = (Key, Value)> {
            let Self(entries) = self;
            entries.into_iter()
        }

        /// Looks up `key` and tries to view the stored value as a `T`.
        ///
        /// * `None`: no entry for `key`.
        /// * `Some(None)`: an entry exists but [`SupportedType::from_types`]
        ///   rejected it (it holds another type).
        /// * `Some(Some(v))`: the entry holds a `T`.
        pub(super) fn get<T>(&self, key: &Key) -> Option<Option<Arc<T>>>
        where
            T: SupportedType<Value>,
        {
            let stored = self.0.get(key)?;
            Some(T::from_types(stored))
        }
    }

    /// A type that can be wrapped into, and recovered from, a `Value`.
    pub trait SupportedType<Value>: std::fmt::Debug {
        /// Wraps a shared `Self` into a `Value`.
        fn to_types(value: Arc<Self>) -> Value;
        /// Recovers a shared `Self` from `t`, or `None` if `t` does not hold
        /// a `Self`.
        fn from_types(t: &Value) -> Option<Arc<Self>>;
    }
}
172use heterogeneous_map::*;
173
174impl Session {
175 fn fresh_id(&mut self) -> Id {
176 let id = self.next_id.id;
177 self.next_id.id += 1;
178 Id { id }
179 }
180}
181
182impl<T: Sync + Send + 'static + SupportedType<Value>> Node<T> {
183 pub fn new(value: T, session: &mut Session) -> Self {
184 let id = session.fresh_id();
185 let value = Arc::new(value);
186 session.table.0.insert(id.clone(), value.clone());
187 Self { id, value }
188 }
189
190 pub fn inner(&self) -> &Arc<T> {
191 &self.value
192 }
193}
194
/// A value of type `T` bundled with a [`Table`].
///
/// It is serialized as a two-field tuple struct, table first, then value.
#[derive(Debug)]
pub struct WithTable<T> {
    /// The table that accompanies `value`.
    table: Table,
    /// The wrapped value.
    value: T,
}
205
/// Global table used while deserializing: deserializing a `serde_repr::Pair`
/// inserts into it, and converting a `serde_repr::NodeRepr` that has no inline
/// value into a [`Node`] looks the value up in it.
static DESERIALIZATION_STATE: LazyLock<Mutex<Table>> =
    LazyLock::new(|| Mutex::new(Table::default()));
/// Guard taken (with `try_lock`) for the whole duration of a [`WithTable`]
/// deserialization, so that a second one running at the same time panics
/// instead of sharing `DESERIALIZATION_STATE`.
static DESERIALIZATION_STATE_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));

/// When `true`, a [`Node`] is serialized without its value (identifier only).
/// [`WithTable::run`] sets it for the duration of its callback.
static SERIALIZATION_MODE_USE_IDS: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);
214
/// Reads the `SERIALIZATION_MODE_USE_IDS` flag (relaxed load): `true` means
/// nodes are currently serialized by identifier only.
fn serialize_use_id() -> bool {
    SERIALIZATION_MODE_USE_IDS.load(Ordering::Relaxed)
}
218
219impl<T> WithTable<T> {
220 pub fn run<R>(map: Table, value: T, f: impl FnOnce(&Self) -> R) -> R {
224 if serialize_use_id() {
225 panic!("CACHE_MAP_LOCK: only one WithTable serialization can occur at a time (nesting is forbidden)")
226 }
227 SERIALIZATION_MODE_USE_IDS.store(true, Ordering::Relaxed);
228 let result = f(&Self { table: map, value });
229 SERIALIZATION_MODE_USE_IDS.store(false, Ordering::Relaxed);
230 result
231 }
232 pub fn destruct(self) -> (T, Table) {
233 let Self { value, table: map } = self;
234 (value, map)
235 }
236}
237
238impl<T: Serialize> Serialize for WithTable<T> {
239 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
240 let mut ts = serializer.serialize_tuple_struct("WithTable", 2)?;
241 use serde::ser::SerializeTupleStruct;
242 ts.serialize_field(&self.table)?;
243 ts.serialize_field(&self.value)?;
244 ts.end()
245 }
246}
247
248impl<'de, T: Deserialize<'de>> serde::Deserialize<'de> for WithTable<T> {
254 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
255 where
256 D: serde::Deserializer<'de>,
257 {
258 let _lock: MutexGuard<_> = DESERIALIZATION_STATE_LOCK.try_lock().expect("CACHE_MAP_LOCK: only one WithTable deserialization can occur at a time (nesting is forbidden)");
259 use serde_repr::WithTableRepr;
260 let previous = std::mem::take(&mut *DESERIALIZATION_STATE.lock().unwrap());
261 let with_table_repr = WithTableRepr::deserialize(deserializer);
262 *DESERIALIZATION_STATE.lock().unwrap() = previous;
263 let WithTableRepr(table, value) = with_table_repr?;
264 Ok(Self { table, value })
265 }
266}
267
268mod serde_repr {
271 use super::*;
272
273 #[derive(Serialize, Deserialize, JsonSchema, Debug)]
274 pub(super) struct NodeRepr<T> {
275 id: Id,
276 value: Option<Arc<T>>,
277 }
278
279 #[derive(Serialize)]
280 pub(super) struct Pair(Id, Value);
281 pub(super) type SortedIdValuePairs = Vec<Pair>;
282
283 #[derive(Serialize, Deserialize)]
284 pub(super) struct WithTableRepr<T>(pub(super) Table, pub(super) T);
285
286 impl<T: SupportedType<Value>> Into<NodeRepr<T>> for Node<T> {
287 fn into(self) -> NodeRepr<T> {
288 let value = if serialize_use_id() {
289 None
290 } else {
291 Some(self.value.clone())
292 };
293 let id = self.id;
294 NodeRepr { value, id }
295 }
296 }
297
298 impl<T: 'static + SupportedType<Value>> TryFrom<NodeRepr<T>> for Node<T> {
299 type Error = serde::de::value::Error;
300
301 fn try_from(cached: NodeRepr<T>) -> Result<Self, Self::Error> {
302 use serde::de::Error;
303 let table = DESERIALIZATION_STATE.lock().unwrap();
304 let id = cached.id;
305 let kind = if let Some(kind) = cached.value {
306 kind
307 } else {
308 table
309 .0
310 .get(&id)
311 .ok_or_else(|| {
312 Self::Error::custom(&format!(
313 "Stateful deserialization failed for id {:?}: not found in cache",
314 id
315 ))
316 })?
317 .ok_or_else(|| {
318 Self::Error::custom(&format!(
319 "Stateful deserialization failed for id {:?}: wrong type",
320 id
321 ))
322 })?
323 };
324 Ok(Self { value: kind, id })
325 }
326 }
327
328 impl<'de> serde::Deserialize<'de> for Pair {
329 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
330 where
331 D: serde::Deserializer<'de>,
332 {
333 let (id, v) = <(Id, Value)>::deserialize(deserializer)?;
334 DESERIALIZATION_STATE
335 .lock()
336 .unwrap()
337 .0
338 .insert_raw_value(id.clone(), v.clone());
339 Ok(Pair(id, v))
340 }
341 }
342
343 impl Into<SortedIdValuePairs> for Table {
344 fn into(self) -> SortedIdValuePairs {
345 let mut vec: Vec<_> = self.0.into_iter().map(|(x, y)| Pair(x, y)).collect();
346 vec.sort_by_key(|o| o.0.clone());
347 vec
348 }
349 }
350
351 impl From<SortedIdValuePairs> for Table {
352 fn from(t: SortedIdValuePairs) -> Self {
353 Self(HeterogeneousMap::from_iter(
354 t.into_iter().map(|Pair(x, y)| (x, y)),
355 ))
356 }
357 }
358}