1use alloc::{boxed::Box, rc::Rc, sync::Arc, vec::Vec};
2use core::{
3 cell::{Cell, RefCell},
4 mem::MaybeUninit,
5};
6
7use blink_alloc::Blink;
8use midenc_session::Session;
9use traits::BranchOpInterface;
10
11use super::{traits::BuildableTypeConstraint, *};
12use crate::{
13 constants::{ConstantData, ConstantId, ConstantPool},
14 FxHashMap,
15};
16
17pub struct Context {
30 session: Rc<Session>,
31 allocator: Rc<Blink>,
32 registered_dialects: RefCell<FxHashMap<interner::Symbol, Rc<dyn Dialect>>>,
33 dialect_hooks: RefCell<FxHashMap<interner::Symbol, Vec<DialectRegistrationHook>>>,
34 constants: RefCell<ConstantPool>,
35 type_cache: RefCell<FxHashMap<core::any::TypeId, Arc<Type>>>,
36 next_block_id: Cell<u32>,
37 next_value_id: Cell<u32>,
38}
39
40impl Default for Context {
41 fn default() -> Self {
42 use alloc::sync::Arc;
43
44 use midenc_session::diagnostics::DefaultSourceManager;
45
46 let target_dir = std::env::current_dir().unwrap();
47 let options = midenc_session::Options::default();
48 let source_manager = Arc::new(DefaultSourceManager::default());
49 let session =
50 Rc::new(Session::new([], None, None, target_dir, options, None, source_manager));
51 Self::new(session)
52 }
53}
54
55impl Context {
56 pub fn new(session: Rc<Session>) -> Self {
58 let allocator = Rc::new(Blink::new());
59 Self {
60 session,
61 allocator,
62 registered_dialects: Default::default(),
63 dialect_hooks: Default::default(),
64 constants: Default::default(),
65 type_cache: Default::default(),
66 next_block_id: Cell::new(0),
67 next_value_id: Cell::new(0),
68 }
69 }
70
71 #[inline]
72 pub fn session(&self) -> &Session {
73 &self.session
74 }
75
76 #[inline]
77 pub fn session_rc(&self) -> Rc<Session> {
78 self.session.clone()
79 }
80
81 #[inline]
82 pub fn diagnostics(&self) -> &::midenc_session::DiagnosticsHandler {
83 &self.session.diagnostics
84 }
85
86 pub fn registered_dialects(
87 &self,
88 ) -> core::cell::Ref<'_, FxHashMap<interner::Symbol, Rc<dyn Dialect>>> {
89 self.registered_dialects.borrow()
90 }
91
92 pub fn get_registered_dialect(&self, dialect: impl Into<interner::Symbol>) -> Rc<dyn Dialect> {
93 let dialect = dialect.into();
94 self.registered_dialects.borrow()[&dialect].clone()
95 }
96
97 pub fn get_or_register_dialect<T>(&self) -> Rc<dyn Dialect>
98 where
99 T: DialectRegistration,
100 {
101 let dialect_name = <T as DialectRegistration>::NAMESPACE.into();
102 if let Some(dialect) = self.registered_dialects.borrow().get(&dialect_name).cloned() {
103 return dialect;
104 }
105
106 let mut info = DialectInfo::new::<T>();
107
108 let dialect_hooks = self.dialect_hooks.borrow();
109 if let Some(hooks) = dialect_hooks.get(&dialect_name) {
110 for hook in hooks {
111 hook(&mut info, self);
112 }
113 }
114
115 <T as DialectRegistration>::register_operations(&mut info);
116
117 let dialect = Rc::new(T::init(info)) as Rc<dyn Dialect>;
118 self.registered_dialects.borrow_mut().insert(dialect_name, Rc::clone(&dialect));
119 dialect
120 }
121
122 pub fn register_dialect_hook<T, F>(&self, hook: F)
123 where
124 T: DialectRegistration,
125 F: Fn(&mut DialectInfo, &Context) + 'static,
126 {
127 let dialect_name = <T as DialectRegistration>::NAMESPACE.into();
128 let mut dialect_hooks = self.dialect_hooks.borrow_mut();
129 let registered_hooks =
130 dialect_hooks.entry(dialect_name).or_insert_with(|| Vec::with_capacity(1));
131 registered_hooks.push(Box::new(hook));
132 }
133
134 pub fn create_constant(&self, data: impl Into<ConstantData>) -> ConstantId {
135 let mut constants = self.constants.borrow_mut();
136 constants.insert(data.into())
137 }
138
139 pub fn get_constant(&self, id: ConstantId) -> Arc<ConstantData> {
140 self.constants.borrow().get(id)
141 }
142
143 pub fn get_constant_size_in_bytes(&self, id: ConstantId) -> usize {
144 self.constants.borrow().get_by_ref(id).len()
145 }
146
147 pub fn get_cached_type<T: BuildableTypeConstraint>(&self) -> Option<Arc<Type>> {
148 self.type_cache.borrow().get(&core::any::TypeId::of::<T>()).cloned()
149 }
150
151 pub fn get_or_insert_type<T: BuildableTypeConstraint>(&self) -> Arc<Type> {
152 match self.get_cached_type::<T>() {
153 Some(ty) => ty,
154 None => {
155 let ty = Arc::new(<T as BuildableTypeConstraint>::build(self));
156 let mut types = self.type_cache.borrow_mut();
157 types.insert(core::any::TypeId::of::<T>(), Arc::clone(&ty));
158 ty
159 }
160 }
161 }
162
163 pub fn builder(self: Rc<Self>) -> OpBuilder {
165 OpBuilder::new(Rc::clone(&self))
166 }
167
168 pub fn create_block(&self) -> BlockRef {
170 let block = Block::new(self.alloc_block_id());
171 self.alloc_tracked(block)
172 }
173
174 pub fn create_block_with_params<I>(&self, tys: I) -> BlockRef
176 where
177 I: IntoIterator<Item = Type>,
178 {
179 let block = Block::new(self.alloc_block_id());
180 let mut block = self.alloc_tracked(block);
181 let owner = block;
182 let args = tys.into_iter().enumerate().map(|(index, ty)| {
183 let id = self.alloc_value_id();
184 let arg = BlockArgument::new(
185 SourceSpan::default(),
186 id,
187 ty,
188 owner,
189 index.try_into().expect("too many block arguments"),
190 );
191 self.alloc(arg)
192 });
193 block.borrow_mut().arguments_mut().extend(args);
194 block
195 }
196
197 pub fn append_block_argument(
201 &self,
202 mut block: BlockRef,
203 ty: Type,
204 span: SourceSpan,
205 ) -> ValueRef {
206 let next_index = block.borrow().num_arguments();
207 let id = self.alloc_value_id();
208 let arg = BlockArgument::new(
209 span,
210 id,
211 ty,
212 block,
213 next_index.try_into().expect("too many block arguments"),
214 );
215 let arg = self.alloc(arg);
216 block.borrow_mut().arguments_mut().push(arg);
217 arg.upcast()
218 }
219
220 pub fn make_operand(&self, mut value: ValueRef, owner: OperationRef, index: u8) -> OpOperand {
226 let op_operand = self.alloc_tracked(OpOperandImpl::new(value, owner, index));
227 let mut value = value.borrow_mut();
228 value.insert_use(op_operand);
229 op_operand
230 }
231
232 pub fn make_block_operand(
238 &self,
239 mut block: BlockRef,
240 owner: OperationRef,
241 index: u8,
242 ) -> BlockOperandRef {
243 let block_operand = self.alloc_tracked(BlockOperand::new(owner, index));
244 let mut block = block.borrow_mut();
245 block.insert_use(block_operand);
246 block_operand
247 }
248
249 pub fn make_result(
254 &self,
255 span: SourceSpan,
256 ty: Type,
257 owner: OperationRef,
258 index: u8,
259 ) -> OpResultRef {
260 let id = self.alloc_value_id();
261 self.alloc(OpResult::new(span, id, ty, owner, index))
262 }
263
264 pub fn append_branch_destination_argument(
269 &self,
270 mut branch_inst: OperationRef,
271 dest: BlockRef,
272 value: ValueRef,
273 ) {
274 let mut borrow = branch_inst.borrow_mut();
275 let op = borrow.as_mut().as_operation_mut();
276 assert!(
277 op.as_trait::<dyn BranchOpInterface>().is_some(),
278 "expected branch instruction, got {branch_inst:?}"
279 );
280 let dest_operand_groups: Vec<usize> = op
281 .successors()
282 .iter()
283 .filter(|succ| succ.block.borrow().successor() == dest)
284 .map(|succ| succ.operand_group as usize)
285 .collect();
286 for dest_group in dest_operand_groups {
287 let current_dest_operands_len = op.operands.group(dest_group).len();
288 let operand = self.make_operand(
289 value,
290 op.as_operation_ref(),
291 (current_dest_operands_len + 1) as u8,
292 );
293 op.operands_mut().extend_group(dest_group, [operand]);
294 }
295 }
296
297 pub fn alloc_uninit<T: 'static>(&self) -> UnsafeEntityRef<MaybeUninit<T>> {
303 UnsafeEntityRef::new_uninit(&self.allocator)
304 }
305
306 pub fn alloc_uninit_tracked<T: 'static>(&self) -> UnsafeIntrusiveEntityRef<MaybeUninit<T>> {
313 UnsafeIntrusiveEntityRef::<T>::new_uninit_with_metadata(Default::default(), &self.allocator)
314 }
315
316 pub fn alloc<T: 'static>(&self, value: T) -> UnsafeEntityRef<T> {
322 UnsafeEntityRef::new(value, &self.allocator)
323 }
324
325 pub fn alloc_tracked<T: 'static>(&self, value: T) -> UnsafeIntrusiveEntityRef<T> {
332 UnsafeIntrusiveEntityRef::new_with_metadata(value, Default::default(), &self.allocator)
333 }
334
335 fn alloc_block_id(&self) -> BlockId {
336 let id = self.next_block_id.get();
337 self.next_block_id.set(id + 1);
338 BlockId::from_u32(id)
339 }
340
341 fn alloc_value_id(&self) -> ValueId {
342 let id = self.next_value_id.get();
343 self.next_value_id.set(id + 1);
344 ValueId::from_u32(id)
345 }
346}