use crate::{
dialect::{self, Dialect},
string_ref::StringRef,
};
use mlir_sys::{
mlirContextAppendDialectRegistry, mlirContextCreate, mlirContextDestroy,
mlirContextEnableMultithreading, mlirContextEqual, mlirContextGetAllowUnregisteredDialects,
mlirContextGetNumLoadedDialects, mlirContextGetNumRegisteredDialects,
mlirContextGetOrLoadDialect, mlirContextIsRegisteredOperation,
mlirContextLoadAllAvailableDialects, mlirContextSetAllowUnregisteredDialects, MlirContext,
};
use std::{marker::PhantomData, ops::Deref};
/// Owning handle to an MLIR context.
///
/// A `Context` is created with [`Context::new`] and destroys the underlying
/// `MlirContext` when it is dropped. All query/configuration methods live on
/// [`ContextRef`], which this type exposes through `Deref`.
#[derive(Debug)]
pub struct Context {
    // Non-owning view over the raw handle; ownership is expressed by the
    // `Drop` impl on `Context`, not by this field.
    r#ref: ContextRef<'static>,
}
impl Context {
    /// Creates a new MLIR context.
    ///
    /// The returned value owns the context and releases it on drop.
    pub fn new() -> Self {
        // SAFETY: `mlirContextCreate` hands back a fresh context handle; this
        // `Context` takes ownership of it and destroys it exactly once in `Drop`.
        let handle = unsafe { ContextRef::from_raw(mlirContextCreate()) };

        Self { r#ref: handle }
    }
}
impl Drop for Context {
    /// Destroys the MLIR context owned by this value.
    fn drop(&mut self) {
        let raw = self.r#ref.raw;

        // SAFETY: `raw` was produced by `mlirContextCreate` in `Context::new`
        // and is owned solely by this `Context`, so it is destroyed only here.
        unsafe { mlirContextDestroy(raw) }
    }
}
impl Default for Context {
fn default() -> Self {
Self::new()
}
}
impl Deref for Context {
type Target = ContextRef<'static>;
fn deref(&self) -> &Self::Target {
&self.r#ref
}
}
/// Non-owning, copyable handle to an MLIR context.
///
/// The lifetime `'a` is carried only through `PhantomData`. No Rust reference
/// is stored; the struct holds just the raw C API handle.
#[derive(Clone, Copy, Debug)]
pub struct ContextRef<'a> {
    // Raw `mlir_sys` context handle.
    raw: MlirContext,
    // Zero-sized marker tying this handle to a borrow of a `Context`.
    _reference: PhantomData<&'a Context>,
}
impl<'a> ContextRef<'a> {
pub fn registered_dialect_count(&self) -> usize {
unsafe { mlirContextGetNumRegisteredDialects(self.raw) as usize }
}
pub fn loaded_dialect_count(&self) -> usize {
unsafe { mlirContextGetNumLoadedDialects(self.raw) as usize }
}
pub fn get_or_load_dialect(&self, name: &str) -> Dialect {
unsafe {
Dialect::from_raw(mlirContextGetOrLoadDialect(
self.raw,
StringRef::from(name).to_raw(),
))
}
}
pub fn append_dialect_registry(&self, registry: &dialect::Registry) {
unsafe { mlirContextAppendDialectRegistry(self.raw, registry.to_raw()) }
}
pub fn load_all_available_dialects(&self) {
unsafe { mlirContextLoadAllAvailableDialects(self.raw) }
}
pub fn enable_multi_threading(&self, enabled: bool) {
unsafe { mlirContextEnableMultithreading(self.raw, enabled) }
}
pub fn allow_unregistered_dialects(&self) -> bool {
unsafe { mlirContextGetAllowUnregisteredDialects(self.raw) }
}
pub fn set_allow_unregistered_dialects(&self, allowed: bool) {
unsafe { mlirContextSetAllowUnregisteredDialects(self.raw, allowed) }
}
pub fn is_registered_operation(&self, name: &str) -> bool {
unsafe { mlirContextIsRegisteredOperation(self.raw, StringRef::from(name).to_raw()) }
}
pub(crate) unsafe fn to_raw(self) -> MlirContext {
self.raw
}
pub(crate) unsafe fn from_raw(raw: MlirContext) -> Self {
Self {
raw,
_reference: Default::default(),
}
}
}
impl PartialEq for ContextRef<'_> {
    /// Two handles compare equal when `mlirContextEqual` reports them equal.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.raw, other.raw);

        // SAFETY: both operands are context handles supplied via `from_raw`.
        unsafe { mlirContextEqual(lhs, rhs) }
    }
}

// Equality delegates to the C API's identity comparison, which is reflexive.
impl Eq for ContextRef<'_> {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new() {
        // Creating and immediately dropping a context must not crash.
        drop(Context::new());
    }

    #[test]
    fn registered_dialect_count() {
        let ctx = Context::new();
        // A fresh context reports exactly one registered dialect.
        let expected = 1;
        assert_eq!(ctx.registered_dialect_count(), expected);
    }

    #[test]
    fn loaded_dialect_count() {
        let ctx = Context::new();
        // A fresh context reports exactly one loaded dialect.
        let expected = 1;
        assert_eq!(ctx.loaded_dialect_count(), expected);
    }

    #[test]
    fn append_dialect_registry() {
        let ctx = Context::new();
        let registry = dialect::Registry::new();
        ctx.append_dialect_registry(&registry);
    }

    #[test]
    fn is_registered_operation() {
        let ctx = Context::new();
        let registered = ctx.is_registered_operation("builtin.module");
        assert!(registered);
    }

    #[test]
    fn is_not_registered_operation() {
        let ctx = Context::new();
        // `func` is not among the dialects available in a fresh context.
        let registered = ctx.is_registered_operation("func.func");
        assert!(!registered);
    }

    #[test]
    fn enable_multi_threading() {
        let ctx = Context::new();
        ctx.enable_multi_threading(true);
    }

    #[test]
    fn disable_multi_threading() {
        let ctx = Context::new();
        ctx.enable_multi_threading(false);
    }

    #[test]
    fn allow_unregistered_dialects() {
        let ctx = Context::new();
        // Unregistered dialects are rejected by default.
        let allowed = ctx.allow_unregistered_dialects();
        assert!(!allowed);
    }

    #[test]
    fn set_allow_unregistered_dialects() {
        let ctx = Context::new();
        ctx.set_allow_unregistered_dialects(true);
        let allowed = ctx.allow_unregistered_dialects();
        assert!(allowed);
    }
}