lust/jit/specialization/
mod.rs

1use crate::ast::{Type, TypeKind};
2use crate::number::{LustFloat, LustInt};
3use alloc::boxed::Box;
4use alloc::vec::Vec;
5use core::mem::{align_of, size_of};
6use hashbrown::HashMap;
7
/// How a value is laid out once a JIT trace has specialized (unboxed) it.
#[derive(Debug, Clone, PartialEq)]
pub enum SpecializedLayout {
    /// A plain scalar stored directly in its stack slot.
    Scalar { size: usize, align: usize },

    /// A specialized vector, held on the stack as its (ptr, len, cap) triple.
    Vec {
        element_layout: Box<SpecializedLayout>,
        element_size: usize,
    },

    /// A specialized hash map whose keys and values both have specialized layouts.
    HashMap {
        key_layout: Box<SpecializedLayout>,
        value_layout: Box<SpecializedLayout>,
    },

    /// An unboxed struct made of specialized fields.
    ///
    /// NOTE(review): `field_offsets` is presumably parallel to `field_layouts`
    /// (byte offset of each field), mirroring how `Tuple` is built — confirm
    /// at the construction site.
    Struct {
        field_layouts: Vec<SpecializedLayout>,
        field_offsets: Vec<usize>,
        total_size: usize,
    },

    /// An unboxed tuple made of specialized elements.
    Tuple {
        element_layouts: Vec<SpecializedLayout>,
        element_offsets: Vec<usize>,
        total_size: usize,
    },
}

impl SpecializedLayout {
    /// Number of bytes of stack space a value with this layout occupies.
    ///
    /// A `Vec` takes three machine words (ptr, len, cap). A `HashMap` uses a
    /// fixed, provisional 32-byte estimate because its real footprint has not
    /// been pinned down yet. Aggregates report their precomputed `total_size`.
    pub fn stack_size(&self) -> usize {
        // (ptr, len, cap)
        const VEC_WORDS: usize = 3;
        // Provisional, conservative guess for the map header.
        const HASHMAP_ESTIMATE: usize = 32;

        match self {
            Self::Scalar { size, .. } => *size,
            Self::Vec { .. } => VEC_WORDS * size_of::<usize>(),
            Self::HashMap { .. } => HASHMAP_ESTIMATE,
            Self::Struct { total_size, .. } | Self::Tuple { total_size, .. } => *total_size,
        }
    }

    /// Required alignment, in bytes, for a value with this layout.
    ///
    /// Pointer-backed containers align like `usize`. Aggregates take the
    /// largest alignment among their members, and fall back to 1 when empty.
    pub fn alignment(&self) -> usize {
        match self {
            Self::Scalar { align, .. } => *align,
            Self::Vec { .. } | Self::HashMap { .. } => align_of::<usize>(),
            Self::Struct {
                field_layouts: members,
                ..
            }
            | Self::Tuple {
                element_layouts: members,
                ..
            } => members.iter().map(Self::alignment).max().unwrap_or(1),
        }
    }
}
87
88/// Maps TypeKind to its specialized representation
89pub struct SpecializationRegistry {
90    cache: HashMap<TypeKind, Option<SpecializedLayout>>,
91}
92
93impl SpecializationRegistry {
94    pub fn new() -> Self {
95        Self {
96            cache: HashMap::new(),
97        }
98    }
99
100    /// Try to get a specialized layout for a type
101    pub fn get_specialization(&mut self, type_kind: &TypeKind) -> Option<SpecializedLayout> {
102        if let Some(cached) = self.cache.get(type_kind) {
103            return cached.clone();
104        }
105
106        let layout = self.compute_specialization(type_kind);
107        self.cache.insert(type_kind.clone(), layout.clone());
108        layout
109    }
110
111    fn compute_specialization(&self, type_kind: &TypeKind) -> Option<SpecializedLayout> {
112        match type_kind {
113            // Primitives can be unboxed to raw values
114            TypeKind::Int => Some(SpecializedLayout::Scalar {
115                size: size_of::<LustInt>(),
116                align: align_of::<LustInt>(),
117            }),
118
119            TypeKind::Float => Some(SpecializedLayout::Scalar {
120                size: size_of::<LustFloat>(),
121                align: align_of::<LustFloat>(),
122            }),
123
124            TypeKind::Bool => Some(SpecializedLayout::Scalar { size: 1, align: 1 }),
125
126            // Array<T> where T is specializable
127            TypeKind::Array(element_type) => {
128                self.get_specialization_for_type(element_type)
129                    .map(|elem_layout| SpecializedLayout::Vec {
130                        element_size: elem_layout.stack_size(),
131                        element_layout: Box::new(elem_layout),
132                    })
133            }
134
135            // Map<K, V> where K and V are specializable
136            TypeKind::Map(key_type, value_type) => {
137                let key_layout = self.get_specialization_for_type(key_type)?;
138                let value_layout = self.get_specialization_for_type(value_type)?;
139                Some(SpecializedLayout::HashMap {
140                    key_layout: Box::new(key_layout),
141                    value_layout: Box::new(value_layout),
142                })
143            }
144
145            // Tuple with specializable elements
146            TypeKind::Tuple(elements) => {
147                let mut element_layouts = Vec::new();
148                let mut element_offsets = Vec::new();
149                let mut offset = 0;
150
151                for elem_type in elements {
152                    let layout = self.get_specialization_for_type(elem_type)?;
153                    let align = layout.alignment();
154                    // Align offset to element's alignment requirement
155                    offset = (offset + align - 1) & !(align - 1);
156                    element_offsets.push(offset);
157                    offset += layout.stack_size();
158                    element_layouts.push(layout);
159                }
160
161                Some(SpecializedLayout::Tuple {
162                    element_layouts,
163                    element_offsets,
164                    total_size: offset,
165                })
166            }
167
168            // GenericInstance (e.g., Array<int>, Map<string, int>)
169            TypeKind::GenericInstance { name, type_args } => {
170                // Reconstruct the proper TypeKind and recurse
171                match name.as_str() {
172                    "Array" if type_args.len() == 1 => self
173                        .compute_specialization(&TypeKind::Array(Box::new(type_args[0].clone()))),
174                    "Map" if type_args.len() == 2 => self.compute_specialization(&TypeKind::Map(
175                        Box::new(type_args[0].clone()),
176                        Box::new(type_args[1].clone()),
177                    )),
178                    _ => None,
179                }
180            }
181
182            // Not specializable
183            _ => None,
184        }
185    }
186
187    /// Helper to get specialization for a Type (not just TypeKind)
188    fn get_specialization_for_type(&self, ty: &Type) -> Option<SpecializedLayout> {
189        self.compute_specialization(&ty.kind)
190    }
191}
192
193impl Default for SpecializationRegistry {
194    fn default() -> Self {
195        Self::new()
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Span;

    /// Wraps a `TypeKind` in a `Type` carrying a dummy span.
    fn ty(kind: TypeKind) -> Type {
        Type::new(kind, Span::dummy())
    }

    #[test]
    fn test_primitive_specialization() {
        let mut registry = SpecializationRegistry::new();

        // Comparing `Option<usize>` checks both "is specialized" and the size at once.
        let int_size = registry
            .get_specialization(&TypeKind::Int)
            .map(|layout| layout.stack_size());
        assert_eq!(int_size, Some(size_of::<LustInt>()));

        let float_size = registry
            .get_specialization(&TypeKind::Float)
            .map(|layout| layout.stack_size());
        assert_eq!(float_size, Some(size_of::<LustFloat>()));

        let bool_size = registry
            .get_specialization(&TypeKind::Bool)
            .map(|layout| layout.stack_size());
        assert_eq!(bool_size, Some(1));
    }

    #[test]
    fn test_array_int_specialization() {
        let mut registry = SpecializationRegistry::new();

        let array_int_type = TypeKind::Array(Box::new(ty(TypeKind::Int)));
        match registry.get_specialization(&array_int_type) {
            Some(SpecializedLayout::Vec {
                element_size,
                element_layout,
            }) => {
                assert_eq!(element_size, size_of::<LustInt>());
                assert!(matches!(*element_layout, SpecializedLayout::Scalar { .. }));
            }
            other => panic!("Expected Vec layout, got {:?}", other),
        }
    }

    #[test]
    fn test_map_specialization() {
        let mut registry = SpecializationRegistry::new();

        let map_type = TypeKind::Map(Box::new(ty(TypeKind::Int)), Box::new(ty(TypeKind::Bool)));
        match registry.get_specialization(&map_type) {
            Some(SpecializedLayout::HashMap {
                key_layout,
                value_layout,
            }) => {
                assert_eq!(key_layout.stack_size(), size_of::<LustInt>());
                assert_eq!(value_layout.stack_size(), 1);
            }
            other => panic!("Expected HashMap layout, got {:?}", other),
        }
    }

    #[test]
    fn test_cache_works() {
        let mut registry = SpecializationRegistry::new();

        // The first call computes the layout; the second is served from the cache.
        let first = registry.get_specialization(&TypeKind::Int);
        let second = registry.get_specialization(&TypeKind::Int);

        assert_eq!(first, second);
    }
}