1use std::rc::Rc;
8
9use self::extension_op::{ExtensionOpFn, ExtensionOpMap};
10use hugr_core::{
11 HugrView, Node,
12 extension::{ExtensionId, simple_op::MakeOpDef},
13 ops::{ExtensionOp, OpName, constant::CustomConst},
14};
15
16use strum::IntoEnumIterator;
17use types::CustomTypeKey;
18
19use self::load_constant::{LoadConstantFn, LoadConstantsMap};
20use self::types::LLVMCustomTypeFn;
21use anyhow::Result;
22
23use crate::{
24 emit::{EmitOpArgs, func::EmitFuncContext},
25 types::TypeConverter,
26};
27
28pub mod extension_op;
29pub mod load_constant;
30pub mod types;
31
32pub trait CodegenExtension {
39 fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
42 self,
43 builder: CodegenExtsBuilder<'a, H>,
44 ) -> CodegenExtsBuilder<'a, H>
45 where
46 Self: 'a;
47}
48
49#[derive(Default)]
64pub struct CodegenExtsBuilder<'a, H> {
65 load_constant_handlers: LoadConstantsMap<'a, H>,
66 extension_op_handlers: ExtensionOpMap<'a, H>,
67 type_converter: TypeConverter<'a>,
68}
69
70impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
71 pub fn add_extension(self, ext: impl CodegenExtension + 'a) -> Self {
79 ext.add_extension(self)
80 }
81
82 pub fn custom_type(
87 mut self,
88 custom_type: CustomTypeKey,
89 handler: impl LLVMCustomTypeFn<'a>,
90 ) -> Self {
91 self.type_converter.custom_type(custom_type, handler);
92 self
93 }
94
95 pub fn extension_op(
98 mut self,
99 extension: ExtensionId,
100 op: OpName,
101 handler: impl ExtensionOpFn<'a, H>,
102 ) -> Self {
103 self.extension_op_handlers
104 .extension_op(extension, op, handler);
105 self
106 }
107
108 pub fn simple_extension_op<Op: MakeOpDef + IntoEnumIterator>(
111 mut self,
112 handler: impl 'a
113 + for<'c> Fn(
114 &mut EmitFuncContext<'c, 'a, H>,
115 EmitOpArgs<'c, '_, ExtensionOp, H>,
116 Op,
117 ) -> Result<()>,
118 ) -> Self {
119 self.extension_op_handlers
120 .simple_extension_op::<Op>(handler);
121 self
122 }
123
124 pub fn custom_const<CC: CustomConst>(
126 mut self,
127 handler: impl LoadConstantFn<'a, H, CC>,
128 ) -> Self {
129 self.load_constant_handlers.custom_const(handler);
130 self
131 }
132
133 #[must_use]
136 pub fn finish(self) -> CodegenExtsMap<'a, H> {
137 CodegenExtsMap {
138 load_constant_handlers: Rc::new(self.load_constant_handlers),
139 extension_op_handlers: Rc::new(self.extension_op_handlers),
140 type_converter: Rc::new(self.type_converter),
141 }
142 }
143}
144
145#[derive(Default)]
151#[non_exhaustive]
152pub struct CodegenExtsMap<'a, H> {
153 pub load_constant_handlers: Rc<LoadConstantsMap<'a, H>>,
154 pub extension_op_handlers: Rc<ExtensionOpMap<'a, H>>,
155 pub type_converter: Rc<TypeConverter<'a>>,
156}
157
#[cfg(test)]
mod test {
    use hugr_core::{
        Hugr,
        extension::prelude::{ConstString, PRELUDE_ID, PRINT_OP_ID, STRING_TYPE_NAME, string_type},
    };
    use inkwell::{
        context::Context,
        types::BasicType,
        values::{BasicMetadataValueEnum, BasicValue},
    };
    use itertools::Itertools as _;

    use crate::{CodegenExtsBuilder, emit::libc::emit_libc_printf};

    /// A custom-type handler may borrow non-`'static` data (here `name`).
    #[test]
    fn types_with_lifetimes() {
        let name = "name_with_lifetime".to_string();

        let exts = CodegenExtsBuilder::<Hugr>::default()
            .custom_type((PRELUDE_ID, STRING_TYPE_NAME), |session, _| {
                let iw_ctx = session.iw_context();
                let struct_ty = iw_ctx
                    .get_struct_type(name.as_str())
                    .unwrap_or_else(|| iw_ctx.opaque_struct_type(name.as_str()));
                Ok(struct_ty.as_basic_type_enum())
            })
            .finish();

        let context = Context::create();

        let llvm_ty = exts
            .type_converter
            .session(&context)
            .llvm_type(&string_type())
            .unwrap()
            .into_struct_type();
        assert_eq!(llvm_ty.get_name().unwrap().to_str().unwrap(), name);
    }

    /// A custom-const handler may borrow a caller-owned LLVM `Context`.
    #[test]
    fn custom_const_lifetime_of_context() {
        let context = Context::create();

        let _ = CodegenExtsBuilder::<Hugr>::default()
            .custom_const::<ConstString>(|_, konst| {
                let bytes = konst.value().as_bytes();
                Ok(context.const_string(bytes, true).as_basic_value_enum())
            })
            .finish();
    }

    /// An extension-op handler may borrow a caller-owned LLVM `Context`.
    #[test]
    fn extension_op_lifetime() {
        let context = Context::create();

        let _ = CodegenExtsBuilder::<Hugr>::default()
            .extension_op(PRELUDE_ID, PRINT_OP_ID, |emit_ctx, args| {
                let format_str: BasicMetadataValueEnum =
                    context.const_string(b"%s", true).into();
                let print_args: Vec<BasicMetadataValueEnum> = std::iter::once(format_str)
                    .chain(args.inputs.into_iter().map_into())
                    .collect();
                emit_libc_printf(emit_ctx, &print_args)?;
                args.outputs.finish(emit_ctx.builder(), [])
            })
            .finish();
    }
}