1use std::rc::Rc;
8
9use self::extension_op::{ExtensionOpFn, ExtensionOpMap};
10use hugr_core::{
11 HugrView, Node,
12 extension::{ExtensionId, simple_op::MakeOpDef},
13 ops::{ExtensionOp, OpName, constant::CustomConst},
14};
15
16use strum::IntoEnumIterator;
17use types::CustomTypeKey;
18
19use self::load_constant::{LoadConstantFn, LoadConstantsMap};
20use self::types::LLVMCustomTypeFn;
21use anyhow::Result;
22
23use crate::{
24 emit::{EmitOpArgs, func::EmitFuncContext},
25 types::TypeConverter,
26};
27
28pub mod extension_op;
29pub mod load_constant;
30pub mod types;
31
32pub trait CodegenExtension {
39 fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
42 self,
43 builder: CodegenExtsBuilder<'a, H>,
44 ) -> CodegenExtsBuilder<'a, H>
45 where
46 Self: 'a;
47}
48
/// Builder that accumulates codegen handlers until they are frozen into a
/// [`CodegenExtsMap`] by [`CodegenExtsBuilder::finish`].
///
/// `'a` is the lifetime bound placed on every registered handler, and `H` is
/// the HUGR view type the handlers operate on.
///
/// NOTE: `#[derive(Default)]` adds an `H: Default` bound to the generated
/// `Default` impl, so `CodegenExtsBuilder::<H>::default()` is only available
/// for such `H`.
#[derive(Default)]
pub struct CodegenExtsBuilder<'a, H> {
    /// Handlers registered through [`CodegenExtsBuilder::custom_const`].
    load_constant_handlers: LoadConstantsMap<'a, H>,
    /// Handlers registered through [`CodegenExtsBuilder::extension_op`] and
    /// [`CodegenExtsBuilder::simple_extension_op`].
    extension_op_handlers: ExtensionOpMap<'a, H>,
    /// Type converter that receives the handlers registered through
    /// [`CodegenExtsBuilder::custom_type`].
    type_converter: TypeConverter<'a>,
}
69
70impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
71 pub fn add_extension(self, ext: impl CodegenExtension + 'a) -> Self {
79 ext.add_extension(self)
80 }
81
82 pub fn custom_type(
87 mut self,
88 custom_type: CustomTypeKey,
89 handler: impl LLVMCustomTypeFn<'a>,
90 ) -> Self {
91 self.type_converter.custom_type(custom_type, handler);
92 self
93 }
94
95 pub fn extension_op(
98 mut self,
99 extension: ExtensionId,
100 op: OpName,
101 handler: impl ExtensionOpFn<'a, H>,
102 ) -> Self {
103 self.extension_op_handlers
104 .extension_op(extension, op, handler);
105 self
106 }
107
108 pub fn simple_extension_op<Op: MakeOpDef + IntoEnumIterator>(
111 mut self,
112 handler: impl 'a
113 + for<'c> Fn(
114 &mut EmitFuncContext<'c, 'a, H>,
115 EmitOpArgs<'c, '_, ExtensionOp, H>,
116 Op,
117 ) -> Result<()>,
118 ) -> Self {
119 self.extension_op_handlers
120 .simple_extension_op::<Op>(handler);
121 self
122 }
123
124 pub fn custom_const<CC: CustomConst>(
126 mut self,
127 handler: impl LoadConstantFn<'a, H, CC>,
128 ) -> Self {
129 self.load_constant_handlers.custom_const(handler);
130 self
131 }
132
133 #[must_use]
136 pub fn finish(self) -> CodegenExtsMap<'a, H> {
137 CodegenExtsMap {
138 load_constant_handlers: Rc::new(self.load_constant_handlers),
139 extension_op_handlers: Rc::new(self.extension_op_handlers),
140 type_converter: Rc::new(self.type_converter),
141 }
142 }
143}
144
/// The finished set of codegen handlers, as produced by
/// [`CodegenExtsBuilder::finish`].
///
/// Every collection is held behind an [`Rc`]. The struct is
/// `#[non_exhaustive]`, so code outside this crate cannot build it with a
/// struct literal or match it exhaustively.
///
/// NOTE: `#[derive(Default)]` adds an `H: Default` bound to the generated
/// `Default` impl.
#[derive(Default)]
#[non_exhaustive]
pub struct CodegenExtsMap<'a, H> {
    /// Handlers for loading custom constants.
    pub load_constant_handlers: Rc<LoadConstantsMap<'a, H>>,
    /// Handlers for emitting extension operations.
    pub extension_op_handlers: Rc<ExtensionOpMap<'a, H>>,
    /// Type converter holding the registered custom-type handlers.
    pub type_converter: Rc<TypeConverter<'a>>,
}
156
#[cfg(test)]
mod test {
    use hugr_core::{
        Hugr,
        extension::prelude::{ConstString, PRELUDE_ID, PRINT_OP_ID, STRING_TYPE_NAME, string_type},
    };
    use inkwell::{
        context::Context,
        types::BasicType,
        values::{BasicMetadataValueEnum, BasicValue},
    };
    use itertools::Itertools as _;

    use crate::{CodegenExtsBuilder, emit::libc::emit_libc_printf};

    /// A custom-type handler may borrow a local (non-`'static`) `String`.
    /// The struct type it produces must carry that string as its name.
    #[test]
    fn types_with_lifetimes() {
        let struct_name = String::from("name_with_lifetime");

        let exts = CodegenExtsBuilder::<Hugr>::default()
            .custom_type((PRELUDE_ID, STRING_TYPE_NAME), |session, _| {
                let iw_ctx = session.iw_context();
                // Reuse the named struct if the context already has it,
                // otherwise declare it as an opaque struct.
                let struct_ty = match iw_ctx.get_struct_type(struct_name.as_str()) {
                    Some(existing) => existing,
                    None => iw_ctx.opaque_struct_type(struct_name.as_str()),
                };
                Ok(struct_ty.as_basic_type_enum())
            })
            .finish();

        let context = Context::create();

        let llvm_ty = exts
            .type_converter
            .session(&context)
            .llvm_type(&string_type())
            .unwrap()
            .into_struct_type();
        let raw_name = llvm_ty.get_name().unwrap();
        let llvm_name = raw_name.to_str().unwrap();
        assert_eq!(llvm_name, struct_name);
    }

    /// A custom-const handler may borrow a `Context` owned by the caller.
    /// There are no assertions: building the map is the whole test.
    #[test]
    fn custom_const_lifetime_of_context() {
        let context = Context::create();

        let _ = CodegenExtsBuilder::<Hugr>::default()
            .custom_const::<ConstString>(|_, string_const| {
                let bytes = string_const.value().as_bytes();
                let llvm_string = context.const_string(bytes, true);
                Ok(llvm_string.as_basic_value_enum())
            })
            .finish();
    }

    /// An extension-op handler may borrow a `Context` owned by the caller.
    /// There are no assertions: building the map is the whole test.
    #[test]
    fn extension_op_lifetime() {
        let context = Context::create();

        let _ = CodegenExtsBuilder::<Hugr>::default()
            .extension_op(PRELUDE_ID, PRINT_OP_ID, |emit_ctx, op_args| {
                // printf arguments: the format string first, then every op input.
                let format_string = context.const_string("%s".as_bytes(), true);
                let mut printf_args: Vec<BasicMetadataValueEnum> = vec![format_string.into()];
                printf_args.extend(
                    op_args
                        .inputs
                        .into_iter()
                        .map_into::<BasicMetadataValueEnum>(),
                );
                emit_libc_printf(emit_ctx, &printf_args)?;
                op_args.outputs.finish(emit_ctx.builder(), [])
            })
            .finish();
    }
}