1use crate::common_traits::Verify;
39use crate::context::{Arena, Context, Ptr, private::ArenaObj};
40use crate::dialect::DialectName;
41use crate::identifier::Identifier;
42use crate::irfmt::parsers::spaced;
43use crate::location::Located;
44use crate::parsable::{Parsable, ParseResult, ParserFn, StateStream};
45use crate::printable::{self, Printable};
46use crate::result::Result;
47use crate::storage_uniquer::TypeValueHash;
48use crate::{arg_err_noloc, impl_printable_for_display, input_err};
49
50use combine::{Parser, parser};
51use downcast_rs::{Downcast, impl_downcast};
52use linkme::distributed_slice;
53use rustc_hash::FxHashMap;
54use std::cell::Ref;
55use std::fmt::Debug;
56use std::fmt::Display;
57use std::hash::{Hash, Hasher};
58use std::marker::PhantomData;
59use std::ops::Deref;
60use std::sync::LazyLock;
61use thiserror::Error;
62
63pub trait Type: Printable + Verify + Downcast + Sync + Send + Debug {
97 fn hash_type(&self) -> TypeValueHash;
100 fn eq_type(&self, other: &dyn Type) -> bool;
102
103 fn get_self_ptr(&self, ctx: &Context) -> Ptr<TypeObj> {
108 let is = |other: &TypeObj| self.eq_type(&**other);
109 let idx = ctx
110 .type_store
111 .get(self.hash_type(), &is)
112 .expect("Unregistered type object in existence");
113 Ptr {
114 idx,
115 _dummy: PhantomData::<TypeObj>,
116 }
117 }
118
119 fn register_instance(t: Self, ctx: &mut Context) -> TypePtr<Self>
123 where
124 Self: Sized,
125 {
126 let hash = t.hash_type();
127 let idx = ctx
128 .type_store
129 .get_or_create_unique(Box::new(t), hash, &TypeObj::eq);
130 let ptr = Ptr {
131 idx,
132 _dummy: PhantomData::<TypeObj>,
133 };
134 TypePtr(ptr, PhantomData::<Self>)
135 }
136
137 fn get_instance(t: Self, ctx: &Context) -> Option<TypePtr<Self>>
140 where
141 Self: Sized,
142 {
143 let is = |other: &TypeObj| t.eq_type(&**other);
144 ctx.type_store.get(t.hash_type(), &is).map(|idx| {
145 let ptr = Ptr {
146 idx,
147 _dummy: PhantomData::<TypeObj>,
148 };
149 TypePtr(ptr, PhantomData::<Self>)
150 })
151 }
152
153 fn get_type_id(&self) -> TypeId;
157
158 fn get_type_id_static() -> TypeId
160 where
161 Self: Sized;
162
163 fn verify_interfaces(&self, ctx: &Context) -> Result<()>;
165
166 fn register_type_in_dialect(ctx: &mut Context, parser: ParserFn<(), TypePtr<Self>>)
168 where
169 Self: Sized,
170 {
171 fn constrain<F>(f: F) -> F
174 where
175 F: for<'a> Fn(
176 &'a (),
177 ) -> Box<
178 dyn Parser<StateStream<'a>, Output = Ptr<TypeObj>, PartialState = ()> + 'a,
179 >,
180 {
181 f
182 }
183 let ptr_parser = constrain(move |_| {
184 combine::parser(move |parsable_state: &mut StateStream<'_>| {
185 parser(&(), ())
186 .parse_stream(parsable_state)
187 .map(|typtr| typtr.to_ptr())
188 .into_result()
189 })
190 .boxed()
191 });
192 let typeid = Self::get_type_id_static();
193 let dialect = ctx
194 .dialects
195 .get_mut(&typeid.dialect)
196 .unwrap_or_else(|| panic!("Unregistered dialect {}", &typeid.dialect));
197 dialect.add_type(typeid, Box::new(ptr_parser));
198 }
199}
200impl_downcast!(Type);
201
/// Boxed, type-erased parser factory stored per registered type: given a
/// `&()` it produces a boxed parser whose output is a `Ptr<TypeObj>`.
pub(crate) type TypeParserFn = Box<
    dyn for<'a> Fn(
        &'a (),
    ) -> Box<
        dyn Parser<StateStream<'a>, Output = Ptr<TypeObj>, PartialState = ()> + 'a,
    >,
>;
210
/// Implemented by anything that has an IR type.
pub trait Typed {
    /// Returns the pointer to the type of `self`, resolved through `ctx`.
    fn get_type(&self, ctx: &Context) -> Ptr<TypeObj>;
}
216
217impl Typed for Ptr<TypeObj> {
218 fn get_type(&self, _ctx: &Context) -> Ptr<TypeObj> {
219 *self
220 }
221}
222
223impl Typed for dyn Type {
224 fn get_type(&self, ctx: &Context) -> Ptr<TypeObj> {
225 self.get_self_ptr(ctx)
226 }
227}
228
229impl<T: Typed + ?Sized> Typed for &T {
230 fn get_type(&self, ctx: &Context) -> Ptr<TypeObj> {
231 (*self).get_type(ctx)
232 }
233}
234
235impl<T: Typed + ?Sized> Typed for &mut T {
236 fn get_type(&self, ctx: &Context) -> Ptr<TypeObj> {
237 (**self).get_type(ctx)
238 }
239}
240
241impl<T: Typed + ?Sized> Typed for Box<T> {
242 fn get_type(&self, ctx: &Context) -> Ptr<TypeObj> {
243 (**self).get_type(ctx)
244 }
245}
246
/// Name of a type within its dialect; a newtype over [`Identifier`].
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct TypeName(Identifier);
250
251impl TypeName {
252 pub fn new(name: &str) -> TypeName {
254 TypeName(name.try_into().expect("Invalid Identifier for TypeName"))
255 }
256}
257
258impl Deref for TypeName {
259 type Target = Identifier;
260
261 fn deref(&self) -> &Self::Target {
262 &self.0
263 }
264}
265
266impl_printable_for_display!(TypeName);
267
268impl Display for TypeName {
269 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270 write!(f, "{}", self.0)
271 }
272}
273
274impl Parsable for TypeName {
275 type Arg = ();
276 type Parsed = TypeName;
277
278 fn parse<'a>(
279 state_stream: &mut crate::parsable::StateStream<'a>,
280 _arg: Self::Arg,
281 ) -> ParseResult<'a, Self::Parsed>
282 where
283 Self: Sized,
284 {
285 Identifier::parser(())
286 .map(|name| TypeName::new(&name))
287 .parse_stream(state_stream)
288 .into()
289 }
290}
291
/// Identifies a kind of type by its dialect and its name within that dialect.
/// Displayed and parsed as `<dialect>.<name>`.
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct TypeId {
    /// Dialect the type belongs to.
    pub dialect: DialectName,
    /// Name of the type within `dialect`.
    pub name: TypeName,
}
298
299impl Parsable for TypeId {
300 type Arg = ();
301 type Parsed = TypeId;
302
303 fn parse<'a>(
305 state_stream: &mut StateStream<'a>,
306 _arg: Self::Arg,
307 ) -> ParseResult<'a, Self::Parsed>
308 where
309 Self: Sized,
310 {
311 let mut parser = DialectName::parser(())
312 .skip(parser::char::char('.'))
313 .and(TypeName::parser(()))
314 .map(|(dialect, name)| TypeId { dialect, name });
315 parser.parse_stream(state_stream).into()
316 }
317}
318
319impl_printable_for_display!(TypeId);
320
321impl Display for TypeId {
322 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323 write!(f, "{}.{}", self.dialect, self.name)
324 }
325}
326
/// A boxed, type-erased [`Type`]; the object kind kept in the context's type store.
pub type TypeObj = Box<dyn Type>;
330
331impl PartialEq for TypeObj {
332 fn eq(&self, other: &Self) -> bool {
333 (**self).eq_type(&**other)
334 }
335}
336
337impl Eq for TypeObj {}
338
339impl Hash for TypeObj {
340 fn hash<H: Hasher>(&self, state: &mut H) {
341 state.write(&u64::from(self.hash_type()).to_ne_bytes())
342 }
343}
344
345impl ArenaObj for TypeObj {
346 fn get_arena(ctx: &Context) -> &Arena<Self> {
347 &ctx.type_store.unique_store
348 }
349
350 fn get_arena_mut(ctx: &mut Context) -> &mut Arena<Self> {
351 &mut ctx.type_store.unique_store
352 }
353
354 fn get_self_ptr(&self, ctx: &Context) -> Ptr<Self> {
355 self.as_ref().get_self_ptr(ctx)
356 }
357
358 fn dealloc_sub_objects(_ptr: Ptr<Self>, _ctx: &mut Context) {
359 panic!("Cannot dealloc arena sub-objects of types")
360 }
361}
362
363impl Printable for TypeObj {
364 fn fmt(
365 &self,
366 ctx: &Context,
367 state: &printable::State,
368 f: &mut std::fmt::Formatter<'_>,
369 ) -> std::fmt::Result {
370 write!(f, "{} ", self.get_type_id())?;
371 Printable::fmt(self.deref(), ctx, state, f)
372 }
373}
374
/// Parses any registered type: first a [`TypeId`], then whatever the parser
/// registered for that type id in its dialect accepts.
impl Parsable for Ptr<TypeObj> {
    type Arg = ();
    type Parsed = Self;

    fn parse<'a>(
        state_stream: &mut StateStream<'a>,
        _arg: Self::Arg,
    ) -> ParseResult<'a, Self::Parsed> {
        // Location captured before parsing, used when reporting an unregistered type.
        let loc = state_stream.loc();
        // NOTE(review): `spaced` presumably tolerates surrounding whitespace — confirm.
        let type_id_parser = spaced(TypeId::parser(()));

        let mut type_parser = type_id_parser.then(move |type_id: TypeId| {
            // Each inner parser gets its own copy of the location.
            let loc = loc.clone();
            combine::parser(move |parsable_state: &mut StateStream<'a>| {
                let state = &parsable_state.state;
                // Panics if the parsed dialect name has no registered dialect.
                let dialect = state
                    .ctx
                    .dialects
                    .get(&type_id.dialect)
                    .expect("Dialect name parsed but dialect isn't registered");
                // An unknown type name within a known dialect is an input error.
                let Some(type_parser) = dialect.types.get(&type_id) else {
                    input_err!(loc.clone(), "Unregistered type {}", type_id.disp(state.ctx))?
                };
                // Hand over to the type-specific parser.
                type_parser(&()).parse_stream(parsable_state).into()
            })
        });

        type_parser.parse_stream(state_stream).into_result()
    }
}
406
407pub fn verify_type(ty: &dyn Type, ctx: &Context) -> Result<()> {
411 ty.verify_interfaces(ctx)?;
413
414 Verify::verify(ty, ctx)
416}
417
418impl Verify for TypeObj {
419 fn verify(&self, ctx: &Context) -> Result<()> {
420 verify_type(self.as_ref(), ctx)
421 }
422}
423
/// A [`Ptr<TypeObj>`] that additionally records, at compile time, the
/// concrete [`Type`] `T` it points to.
#[derive(Debug)]
pub struct TypePtr<T: Type>(Ptr<TypeObj>, PhantomData<T>);
427
/// Error returned by `TypePtr::from_ptr` when the pointed-to type is not the
/// requested one.
#[derive(Error, Debug)]
#[error("TypePtr mismatch: Constructing {expected} but provided {provided}")]
pub struct TypePtrErr {
    /// Printed type id of the type that was requested.
    pub expected: String,
    /// Printed form of the pointer that was actually supplied.
    pub provided: String,
}
434
435impl<T: Type> TypePtr<T> {
436 pub fn deref<'a>(&self, ctx: &'a Context) -> Ref<'a, T> {
440 Ref::map(self.0.deref(ctx), |t| {
441 t.downcast_ref::<T>()
442 .expect("Type mistmatch, inconsistent TypePtr")
443 })
444 }
445
446 pub fn from_ptr(ptr: Ptr<TypeObj>, ctx: &Context) -> Result<TypePtr<T>> {
448 if ptr.deref(ctx).is::<T>() {
449 Ok(TypePtr(ptr, PhantomData::<T>))
450 } else {
451 arg_err_noloc!(TypePtrErr {
452 expected: T::get_type_id_static().disp(ctx).to_string(),
453 provided: ptr.disp(ctx).to_string()
454 })
455 }
456 }
457
458 pub fn to_ptr(&self) -> Ptr<TypeObj> {
460 self.0
461 }
462}
463
464impl<T: Type> From<TypePtr<T>> for Ptr<TypeObj> {
465 fn from(value: TypePtr<T>) -> Self {
466 value.to_ptr()
467 }
468}
469
470impl<T: Type> Clone for TypePtr<T> {
471 fn clone(&self) -> TypePtr<T> {
472 *self
473 }
474}
475
476impl<T: Type> Copy for TypePtr<T> {}
477
478impl<T: Type> PartialEq for TypePtr<T> {
479 fn eq(&self, other: &Self) -> bool {
480 self.0 == other.0
481 }
482}
483
484impl<T: Type> Eq for TypePtr<T> {}
485
486impl<T: Type> Hash for TypePtr<T> {
487 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
488 self.0.hash(state);
489 }
490}
491
492impl<T: Type> Printable for TypePtr<T> {
493 fn fmt(
494 &self,
495 ctx: &Context,
496 state: &printable::State,
497 f: &mut core::fmt::Formatter<'_>,
498 ) -> core::fmt::Result {
499 Printable::fmt(&self.0, ctx, state, f)
500 }
501}
502
/// Parses a type that must be a `T`: a [`TypeId`] equal to `T`'s static type
/// id, followed by whatever `T`'s own parser accepts.
impl<T: Type + Parsable<Arg = (), Parsed = TypePtr<T>>> Parsable for TypePtr<T> {
    type Arg = ();
    type Parsed = Self;

    fn parse<'a>(
        state_stream: &mut StateStream<'a>,
        arg: Self::Arg,
    ) -> ParseResult<'a, Self::Parsed> {
        // Location captured before parsing, used in the mismatch error.
        let loc = state_stream.loc();
        spaced(TypeId::parser(()))
            .then(move |type_id| {
                let loc = loc.clone();
                combine::parser(move |parsable_state: &mut StateStream<'a>| {
                    // Reject any type id other than the one belonging to `T`.
                    if type_id != T::get_type_id_static() {
                        input_err!(
                            loc.clone(),
                            "Expected type {}, but found {}",
                            T::get_type_id_static().disp(parsable_state.state.ctx),
                            type_id.disp(parsable_state.state.ctx)
                        )?
                    }
                    // The type id matches: parse the rest with `T`'s parser.
                    T::parser(arg).parse_stream(parsable_state).into()
                })
            })
            .parse_stream(state_stream)
            .into_result()
    }
}
531
532impl<T: Type> Verify for TypePtr<T> {
533 fn verify(&self, ctx: &Context) -> Result<()> {
534 self.0.verify(ctx)
535 }
536}
537
538pub fn type_cast<T: ?Sized + Type>(ty: &dyn Type) -> Option<&T> {
540 crate::utils::trait_cast::any_to_trait::<T>(ty.as_any())
541}
542
543pub fn type_impls<T: ?Sized + Type>(ty: &dyn Type) -> bool {
545 type_cast::<T>(ty).is_some()
546}
547
/// Signature of a function that verifies one interface on a type object.
pub type TypeInterfaceVerifier = fn(&dyn Type, &Context) -> Result<()>;
550
/// Distributed slice of interface verifier registrations. Each entry pairs a
/// type's [`TypeId`] with an interface (identified by its `std::any::TypeId`)
/// and that interface's verifier.
#[doc(hidden)]
#[distributed_slice]
pub static TYPE_INTERFACE_VERIFIERS: [LazyLock<(
    TypeId,
    (std::any::TypeId, TypeInterfaceVerifier),
)>];
558
/// Distributed slice of interface dependency lists: each entry is an
/// interface's `std::any::TypeId` and the ids of the interfaces it depends on.
#[doc(hidden)]
#[distributed_slice]
pub static TYPE_INTERFACE_DEPS: [LazyLock<(std::any::TypeId, Vec<std::any::TypeId>)>];
563
564#[doc(hidden)]
565pub static TYPE_INTERFACE_VERIFIERS_MAP: LazyLock<
568 FxHashMap<TypeId, Vec<(std::any::TypeId, TypeInterfaceVerifier)>>,
569> = LazyLock::new(|| {
570 use std::any::TypeId;
571 let mut type_intr_verifiers = FxHashMap::default();
573 for lazy in TYPE_INTERFACE_VERIFIERS {
574 let (ty_id, (type_id, verifier)) = (**lazy).clone();
575 type_intr_verifiers
576 .entry(ty_id)
577 .and_modify(|verifiers: &mut Vec<(TypeId, TypeInterfaceVerifier)>| {
578 verifiers.push((type_id, verifier))
579 })
580 .or_insert(vec![(type_id, verifier)]);
581 }
582
583 let interface_deps: FxHashMap<_, _> = TYPE_INTERFACE_DEPS
585 .iter()
586 .map(|lazy| (**lazy).clone())
587 .collect();
588
589 let mut dep_sort_idx = FxHashMap::<TypeId, u32>::default();
592 let mut sort_idx = 0;
593 fn assign_idx_to_intr(
594 interface_deps: &FxHashMap<TypeId, Vec<TypeId>>,
595 dep_sort_idx: &mut FxHashMap<TypeId, u32>,
596 sort_idx: &mut u32,
597 intr: &TypeId,
598 ) {
599 if dep_sort_idx.contains_key(intr) {
600 return;
601 }
602
603 let deps = interface_deps
606 .get(intr)
607 .expect("Expect every interface to have a (possibly empty) list of dependences");
608 for dep in deps {
609 assign_idx_to_intr(interface_deps, dep_sort_idx, sort_idx, dep);
610 }
611
612 dep_sort_idx.insert(*intr, *sort_idx);
614 *sort_idx += 1;
615 }
616
617 for lazy in TYPE_INTERFACE_DEPS.iter() {
619 let (intr, _deps) = &**lazy;
620 assign_idx_to_intr(&interface_deps, &mut dep_sort_idx, &mut sort_idx, intr);
621 }
622
623 for verifiers in type_intr_verifiers.values_mut() {
624 verifiers.sort_by(|(a, _), (b, _)| dep_sort_idx[a].cmp(&dep_sort_idx[b]));
626 }
627
628 type_intr_verifiers
629});
630
#[cfg(test)]
mod tests {

    use rustc_hash::{FxHashMap, FxHashSet};
    use std::any::TypeId;

    use crate::result::Result;
    use crate::verify_err_noloc;

    use super::{TYPE_INTERFACE_DEPS, TYPE_INTERFACE_VERIFIERS_MAP};

    /// For every type, checks that its interface verifiers are ordered so
    /// that each interface comes after all interfaces it depends on.
    #[test]
    fn check_verifiers_deps() -> Result<()> {
        // Interface -> interfaces it depends on.
        let interface_deps: FxHashMap<_, _> = TYPE_INTERFACE_DEPS
            .iter()
            .map(|lazy| (**lazy).clone())
            .collect();

        for (ty, intrs) in TYPE_INTERFACE_VERIFIERS_MAP.iter() {
            // Interfaces seen so far in this type's verifier order.
            let mut satisfied_deps = FxHashSet::<TypeId>::default();
            for (intr, _) in intrs {
                let deps = interface_deps.get(intr).ok_or_else(|| {
                    let err: Result<()> = verify_err_noloc!(
                        "Missing deps list for TypeId {:?} when checking verifier dependences for {}",
                        intr,
                        ty
                    );
                    err.unwrap_err()
                })?;
                // Every dependency must have appeared earlier in the order.
                for dep in deps {
                    if !satisfied_deps.contains(dep) {
                        return verify_err_noloc!(
                            "For {}, dependence {:?} not satisfied for {:?}",
                            ty,
                            dep,
                            intr
                        );
                    }
                }
                satisfied_deps.insert(*intr);
            }
        }

        Ok(())
    }
}