hax_frontend_exporter/
id_table.rs

//! This module provides a notion of tables, identifiers and nodes. A
//! `Node<T>` is an `Arc<T>` bundled with a unique identifier such that
//! there exists an entry in a table for that identifier.
//!
//! The type `WithTable<T>` bundles a table with a value of type
//! `T`. That value of type `T` may hold an arbitrary number of
//! `Node<_>`s. In the context of a `WithTable<T>`, the type `Node<_>`
//! serializes and deserializes using a table as a state. In this
//! case, serializing a `Node<U>` produces only an identifier, without
//! any data of type `U`. Deserializing a `Node<U>` under a
//! `WithTable<T>` will recover `U` data from the table held by
//! `WithTable`.
//!
//! Serde is not designed for stateful (de)serialization. There is no
//! way of deriving `serde::de::DeserializeSeed` systematically. This
//! module thus makes use of global state to achieve serialization and
//! deserialization. This module provides an API that hides this
//! global state.
19use crate::prelude::*;
20use std::{
21    hash::{Hash, Hasher},
22    sync::{atomic::Ordering, Arc, LazyLock, Mutex, MutexGuard},
23};
24
/// Unique identifier of an entry in an ID table (see `Table`).
///
/// Thanks to `#[serde(transparent)]`, an `Id` is serialized as its bare
/// inner `u32`.
#[derive_group(Serializers)]
#[derive(Default, Clone, Debug, JsonSchema, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[serde(transparent)]
pub struct Id {
    /// Raw numeric value; allocated in increasing order by `Session::fresh_id`.
    id: u32,
}
32
/// A session providing fresh IDs for an ID table.
///
/// It owns the counter used to allocate new `Id`s and the `Table` in
/// which every node created through `Node::new` registers its value.
#[derive(Default, Debug)]
pub struct Session {
    /// The next ID to hand out; incremented on every allocation.
    next_id: Id,
    /// Every value registered in this session so far, keyed by ID.
    table: Table,
}
39
40impl Session {
41    pub fn table(&self) -> &Table {
42        &self.table
43    }
44}
45
/// The different types of values one can store in an ID table.
///
/// Each variant wraps a shared value of one type implementing
/// `SupportedType<Value>`; those impls convert between `Arc<T>` and the
/// matching variant.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum Value {
    /// Payload of `TyKind` entries.
    Ty(Arc<TyKind>),
    /// Payload of `DefIdContents` entries.
    DefId(Arc<DefIdContents>),
}
52
53impl SupportedType<Value> for TyKind {
54    fn to_types(value: Arc<Self>) -> Value {
55        Value::Ty(value)
56    }
57    fn from_types(t: &Value) -> Option<Arc<Self>> {
58        match t {
59            Value::Ty(value) => Some(value.clone()),
60            _ => None,
61        }
62    }
63}
64
65impl SupportedType<Value> for DefIdContents {
66    fn to_types(value: Arc<Self>) -> Value {
67        Value::DefId(value)
68    }
69    fn from_types(t: &Value) -> Option<Arc<Self>> {
70        match t {
71            Value::DefId(value) => Some(value.clone()),
72            _ => None,
73        }
74    }
75}
76
/// A node is a bundle of an ID with a shared value.
///
/// (De)serialization goes through `serde_repr::NodeRepr<T>`. Outside of
/// `WithTable::run` the value is serialized inline. Inside it, only the
/// ID is emitted, and deserialization recovers the value from the table
/// being deserialized.
///
/// Note: the derived `PartialEq`/`Eq`/`PartialOrd`/`Ord` compare the `id`
/// first and then the value, whereas the manual `Hash` impl ignores `id`.
#[derive(Deserialize, Serialize, Debug, JsonSchema, PartialEq, Eq, PartialOrd, Ord)]
#[serde(into = "serde_repr::NodeRepr<T>")]
#[serde(try_from = "serde_repr::NodeRepr<T>")]
pub struct Node<T: 'static + SupportedType<Value>> {
    /// Identifier of this node's entry in the ID table.
    id: Id,
    /// The shared value; the same `Arc` is stored in the session table.
    value: Arc<T>,
}
85
86impl<T: SupportedType<Value>> std::ops::Deref for Node<T> {
87    type Target = T;
88    fn deref(&self) -> &Self::Target {
89        self.value.as_ref()
90    }
91}
92
93/// Hax relies on hashes being deterministic for predicates
94/// ids. Identifiers are not deterministic: we implement hash for
95/// `Node` manually, discarding the field `id`.
96impl<T: SupportedType<Value> + Hash> Hash for Node<T> {
97    fn hash<H: Hasher>(&self, state: &mut H) {
98        self.value.as_ref().hash(state);
99    }
100}
101
102/// Manual implementation of `Clone` that doesn't require a `Clone`
103/// bound on `T`.
104impl<T: SupportedType<Value>> Clone for Node<T> {
105    fn clone(&self) -> Self {
106        Self {
107            id: self.id.clone(),
108            value: self.value.clone(),
109        }
110    }
111}
112
/// A table is a map from IDs to `Value`s.
///
/// When serialized, a table is represented as a vector of `(Id, Value)`
/// pairs *sorted by ID*. The values stored in the table may reference each
/// other (without cycles), so the order matters: while deserializing, each
/// pair is registered as soon as it is decoded, and later pairs may refer
/// to it. IDs are allocated in increasing order and a value can only
/// contain nodes that already exist, so sorting by ID should put referenced
/// entries first.
#[derive(Default, Debug, Clone, Deserialize, Serialize)]
#[serde(into = "serde_repr::SortedIdValuePairs")]
#[serde(from = "serde_repr::SortedIdValuePairs")]
pub struct Table(HeterogeneousMap<Id, Value>);
121
mod heterogeneous_map {
    //! This module provides a heterogeneous map that can store any type
    //! implementing the trait `SupportedType`.

    use std::collections::HashMap;
    use std::hash::Hash;
    use std::sync::Arc;

    /// A heterogeneous map is a map from `Key` to `Value`. It provides the
    /// methods `insert` and `get` for any type `T` that implements
    /// `SupportedType<Value>`, converting to and from `Value` on the fly.
    #[derive(Clone, Debug)]
    pub struct HeterogeneousMap<Key, Value>(HashMap<Key, Value>);

    /// Written by hand so that neither `Key` nor `Value` has to implement
    /// `Default`.
    impl<Key, Value> Default for HeterogeneousMap<Key, Value> {
        fn default() -> Self {
            HeterogeneousMap(HashMap::new())
        }
    }

    impl<Key: Hash + Eq + PartialEq, Value> HeterogeneousMap<Key, Value> {
        /// Stores `value` under `key` after converting it with
        /// `T::to_types`; any previous entry for `key` is replaced.
        pub(super) fn insert<T>(&mut self, key: Key, value: Arc<T>)
        where
            T: SupportedType<Value>,
        {
            let raw = <T as SupportedType<Value>>::to_types(value);
            self.insert_raw_value(key, raw);
        }

        /// Stores an already-converted `value` under `key`; any previous
        /// entry for `key` is replaced.
        pub(super) fn insert_raw_value(&mut self, key: Key, value: Value) {
            let HeterogeneousMap(entries) = self;
            entries.insert(key, value);
        }

        /// Builds a map from `(key, value)` pairs; for duplicate keys the
        /// last pair wins.
        pub(super) fn from_iter(it: impl Iterator<Item = (Key, Value)>) -> Self {
            HeterogeneousMap(it.collect())
        }

        /// Consumes the map and yields its `(key, value)` pairs in an
        /// unspecified order.
        pub(super) fn into_iter(self) -> impl Iterator<Item = (Key, Value)> {
            let HeterogeneousMap(entries) = self;
            entries.into_iter()
        }

        /// Looks `key` up and tries to view the stored value as a `T`.
        ///
        /// Returns `None` when `key` is absent, and `Some(T::from_types(v))`
        /// otherwise. An inner `None` therefore means the stored value could
        /// not be viewed as a `T`.
        pub(super) fn get<T>(&self, key: &Key) -> Option<Option<Arc<T>>>
        where
            T: SupportedType<Value>,
        {
            let raw = self.0.get(key)?;
            Some(T::from_types(raw))
        }
    }

    /// A type that can be mapped to `Value` and optionally
    /// reconstructed back.
    pub trait SupportedType<Value>: std::fmt::Debug {
        /// Wraps a shared `Self` into a `Value`.
        fn to_types(value: Arc<Self>) -> Value;
        /// Recovers a shared `Self` from `t`, or `None` when `t` does not
        /// hold a `Self`.
        fn from_types(t: &Value) -> Option<Arc<Self>>;
    }
}
172use heterogeneous_map::*;
173
174impl Session {
175    fn fresh_id(&mut self) -> Id {
176        let id = self.next_id.id;
177        self.next_id.id += 1;
178        Id { id }
179    }
180}
181
182impl<T: Sync + Send + 'static + SupportedType<Value>> Node<T> {
183    pub fn new(value: T, session: &mut Session) -> Self {
184        let id = session.fresh_id();
185        let value = Arc::new(value);
186        session.table.0.insert(id.clone(), value.clone());
187        Self { id, value }
188    }
189
190    pub fn inner(&self) -> &Arc<T> {
191        &self.value
192    }
193}
194
/// Wrapper for a type `T` that bundles an ID table with a value of type
/// `T`. That value may contain `Node` values. A `WithTable<T>` serialized
/// from within `WithTable::run` emits node IDs only and skips their
/// values. Deserializing a `WithTable<T>` uses the table and the IDs to
/// reconstruct the skipped values automatically.
#[derive(Debug)]
pub struct WithTable<T> {
    /// The ID table used to resolve the nodes found in `value`.
    table: Table,
    /// The wrapped value.
    value: T,
}
205
/// The state used for deserialization: a table.
///
/// While a `WithTable<_>` is being deserialized, each decoded `(Id, Value)`
/// pair of its table is inserted here. Nodes serialized as bare IDs are
/// then looked up here.
static DESERIALIZATION_STATE: LazyLock<Mutex<Table>> =
    LazyLock::new(|| Mutex::new(Table::default()));
/// Held for the whole deserialization of a `WithTable<_>`. It is acquired
/// with `try_lock`, so a concurrent or nested deserialization panics
/// instead of waiting.
static DESERIALIZATION_STATE_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));

/// The mode of serialization: should `Node<T>` ship values of type `T` or not?
///
/// `true` means nodes are serialized as bare IDs. `WithTable::run` sets it
/// before calling its closure and clears it afterwards.
static SERIALIZATION_MODE_USE_IDS: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);
214
215fn serialize_use_id() -> bool {
216    SERIALIZATION_MODE_USE_IDS.load(Ordering::Relaxed)
217}
218
219impl<T> WithTable<T> {
220    /// Runs `f` with a `WithTable<T>` created out of `map` and
221    /// `value`. Any serialization of values of type `Node<_>` will
222    /// skip the field `value`.
223    pub fn run<R>(map: Table, value: T, f: impl FnOnce(&Self) -> R) -> R {
224        if serialize_use_id() {
225            panic!("CACHE_MAP_LOCK: only one WithTable serialization can occur at a time (nesting is forbidden)")
226        }
227        SERIALIZATION_MODE_USE_IDS.store(true, Ordering::Relaxed);
228        let result = f(&Self { table: map, value });
229        SERIALIZATION_MODE_USE_IDS.store(false, Ordering::Relaxed);
230        result
231    }
232    pub fn destruct(self) -> (T, Table) {
233        let Self { value, table: map } = self;
234        (value, map)
235    }
236}
237
238impl<T: Serialize> Serialize for WithTable<T> {
239    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
240        let mut ts = serializer.serialize_tuple_struct("WithTable", 2)?;
241        use serde::ser::SerializeTupleStruct;
242        ts.serialize_field(&self.table)?;
243        ts.serialize_field(&self.value)?;
244        ts.end()
245    }
246}
247
248/// The deserializer of `WithTable<T>` is special. We first decode the
249/// table in order: each `(Id, Value)` pair of the table populates the
250/// global table state found in `DESERIALIZATION_STATE`. Only then we
251/// can decode the value itself, knowing `DESERIALIZATION_STATE` is
252/// complete.
253impl<'de, T: Deserialize<'de>> serde::Deserialize<'de> for WithTable<T> {
254    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
255    where
256        D: serde::Deserializer<'de>,
257    {
258        let _lock: MutexGuard<_> = DESERIALIZATION_STATE_LOCK.try_lock().expect("CACHE_MAP_LOCK: only one WithTable deserialization can occur at a time (nesting is forbidden)");
259        use serde_repr::WithTableRepr;
260        let previous = std::mem::take(&mut *DESERIALIZATION_STATE.lock().unwrap());
261        let with_table_repr = WithTableRepr::deserialize(deserializer);
262        *DESERIALIZATION_STATE.lock().unwrap() = previous;
263        let WithTableRepr(table, value) = with_table_repr?;
264        Ok(Self { table, value })
265    }
266}
267
268/// Defines representations for various types when serializing or/and
269/// deserializing via serde
270mod serde_repr {
271    use super::*;
272
273    #[derive(Serialize, Deserialize, JsonSchema, Debug)]
274    pub(super) struct NodeRepr<T> {
275        id: Id,
276        value: Option<Arc<T>>,
277    }
278
279    #[derive(Serialize)]
280    pub(super) struct Pair(Id, Value);
281    pub(super) type SortedIdValuePairs = Vec<Pair>;
282
283    #[derive(Serialize, Deserialize)]
284    pub(super) struct WithTableRepr<T>(pub(super) Table, pub(super) T);
285
286    impl<T: SupportedType<Value>> Into<NodeRepr<T>> for Node<T> {
287        fn into(self) -> NodeRepr<T> {
288            let value = if serialize_use_id() {
289                None
290            } else {
291                Some(self.value.clone())
292            };
293            let id = self.id;
294            NodeRepr { value, id }
295        }
296    }
297
298    impl<T: 'static + SupportedType<Value>> TryFrom<NodeRepr<T>> for Node<T> {
299        type Error = serde::de::value::Error;
300
301        fn try_from(cached: NodeRepr<T>) -> Result<Self, Self::Error> {
302            use serde::de::Error;
303            let table = DESERIALIZATION_STATE.lock().unwrap();
304            let id = cached.id;
305            let kind = if let Some(kind) = cached.value {
306                kind
307            } else {
308                table
309                    .0
310                    .get(&id)
311                    .ok_or_else(|| {
312                        Self::Error::custom(&format!(
313                            "Stateful deserialization failed for id {:?}: not found in cache",
314                            id
315                        ))
316                    })?
317                    .ok_or_else(|| {
318                        Self::Error::custom(&format!(
319                            "Stateful deserialization failed for id {:?}: wrong type",
320                            id
321                        ))
322                    })?
323            };
324            Ok(Self { value: kind, id })
325        }
326    }
327
328    impl<'de> serde::Deserialize<'de> for Pair {
329        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
330        where
331            D: serde::Deserializer<'de>,
332        {
333            let (id, v) = <(Id, Value)>::deserialize(deserializer)?;
334            DESERIALIZATION_STATE
335                .lock()
336                .unwrap()
337                .0
338                .insert_raw_value(id.clone(), v.clone());
339            Ok(Pair(id, v))
340        }
341    }
342
343    impl Into<SortedIdValuePairs> for Table {
344        fn into(self) -> SortedIdValuePairs {
345            let mut vec: Vec<_> = self.0.into_iter().map(|(x, y)| Pair(x, y)).collect();
346            vec.sort_by_key(|o| o.0.clone());
347            vec
348        }
349    }
350
351    impl From<SortedIdValuePairs> for Table {
352        fn from(t: SortedIdValuePairs) -> Self {
353            Self(HeterogeneousMap::from_iter(
354                t.into_iter().map(|Pair(x, y)| (x, y)),
355            ))
356        }
357    }
358}