1use std::sync::Arc;
26
27use cranelift_codegen::ir::{
28 condcodes::IntCC, types, AbiParam, InstBuilder, MemFlags, Signature, Value,
29};
30use cranelift_codegen::isa::CallConv;
31use cranelift_codegen::settings::{self, Configurable};
32use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
33use cranelift_jit::{JITBuilder, JITModule};
34use cranelift_module::{Linkage, Module};
35
36use super::filter_compiler::JITModuleOwner;
37
/// One output column of a projection.
///
/// When a slice of specs is compiled, the `i`-th spec makes the generated
/// function copy `src[src_idx]` into `dst[i]`.
///
/// `PartialEq`/`Eq`/`Hash` are derived so that spec lists can be compared,
/// deduplicated, or used as cache keys for compiled projections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ProjectSpec {
    /// Index (in `f64` elements, not bytes) of the source column to read.
    pub src_idx: usize,
}
48
/// Native entry point of a compiled projection.
///
/// Arguments are `(src, src_len, dst, dst_len)`, with both lengths counted in
/// `f64` elements. The generated code returns `1` after copying every column
/// and `0` when its bounds check rejects the inputs.
type ProjectFn =
    unsafe extern "C" fn(src: *const f64, src_len: usize, dst: *mut f64, dst_len: usize) -> i8;
57
/// A projection that has been JIT-compiled to native code.
///
/// Each call to `extract` copies selected `f64` columns of a source row into a
/// dense destination row.
pub struct CompiledProject {
    /// Entry point of the generated function. It points into code owned by the
    /// module held in `_module_owner`.
    fn_ptr: ProjectFn,
    /// The specs the function was compiled from. `specs.len()` is the output width.
    specs: Vec<ProjectSpec>,
    /// Holds the `JITModule` that owns `fn_ptr`'s code, so that code is not
    /// released while this value is alive. It is never read directly.
    _module_owner: Arc<JITModuleOwner>,
}
71
// SAFETY: `fn_ptr` (a plain function pointer) and `specs` (owned `usize` data)
// are already `Send`. The generated code keeps no state of its own. It only
// loads from the `src` pointer and stores to the `dst` pointer passed on each
// call. `_module_owner` is only held to keep that code alive.
// NOTE(review): this assumes dropping `JITModuleOwner` (and the `JITModule`
// inside it) on a thread other than the creating one is sound; confirm
// against `filter_compiler::JITModuleOwner`.
unsafe impl Send for CompiledProject {}
// SAFETY: through `&CompiledProject`, the code shown here only reads immutable
// fields and calls the stateless generated function. Concurrent shared use
// therefore cannot race on anything owned by `CompiledProject`.
unsafe impl Sync for CompiledProject {}
76
77impl std::fmt::Debug for CompiledProject {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("CompiledProject")
80 .field("output_width", &self.specs.len())
81 .finish_non_exhaustive()
82 }
83}
84
85impl CompiledProject {
86 pub fn output_width(&self) -> usize {
88 self.specs.len()
89 }
90
91 pub fn extract(&self, src: &[f64], dst: &mut Vec<f64>) -> bool {
102 let out_len = self.specs.len();
103 dst.resize(out_len, 0.0);
104 let result = unsafe { (self.fn_ptr)(src.as_ptr(), src.len(), dst.as_mut_ptr(), dst.len()) };
110 result == 1
111 }
112}
113
/// Errors produced while building a JIT module or compiling a projection into it.
#[derive(Debug, thiserror::Error)]
pub enum ProjectCompilerError {
    /// Cranelift rejected a flag setting, or code generation (definition or
    /// finalization of the function) failed.
    #[error("JIT codegen error: {0}")]
    CodegenError(String),
    /// The native host ISA could not be detected or configured.
    #[error("JIT ISA init error: {0}")]
    IsaInitError(String),
    /// Declaring the projection function in the JIT module failed.
    #[error("JIT linkage error: {0}")]
    LinkageError(String),
}
131
/// Stateless front end that JIT-compiles `ProjectSpec` lists into
/// `CompiledProject`s. Every compilation uses its own fresh JIT module.
pub struct ProjectCompiler;
142
143impl Default for ProjectCompiler {
144 fn default() -> Self {
145 ProjectCompiler
146 }
147}
148
149impl ProjectCompiler {
150 pub fn new() -> Self {
152 ProjectCompiler
153 }
154
155 pub fn compile(
159 &mut self,
160 specs: &[ProjectSpec],
161 ) -> Result<CompiledProject, ProjectCompilerError> {
162 let module = build_jit_module()?;
163 let (fn_ptr, module) = compile_project_fn(module, specs)?;
164 let owner = Arc::new(JITModuleOwner::new(module));
165 Ok(CompiledProject {
166 fn_ptr,
167 specs: specs.to_vec(),
168 _module_owner: owner,
169 })
170 }
171}
172
173fn build_jit_module() -> Result<JITModule, ProjectCompilerError> {
178 let mut flag_builder = settings::builder();
179 flag_builder
180 .set("use_colocated_libcalls", "false")
181 .map_err(|e| ProjectCompilerError::CodegenError(e.to_string()))?;
182 flag_builder
183 .set("is_pic", "false")
184 .map_err(|e| ProjectCompilerError::CodegenError(e.to_string()))?;
185 flag_builder
186 .set("opt_level", "speed")
187 .map_err(|e| ProjectCompilerError::CodegenError(e.to_string()))?;
188
189 let flags = settings::Flags::new(flag_builder);
190 let isa = cranelift_native::builder()
191 .map_err(|e| ProjectCompilerError::IsaInitError(e.to_string()))?
192 .finish(flags)
193 .map_err(|e| ProjectCompilerError::IsaInitError(e.to_string()))?;
194
195 let builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
196 Ok(JITModule::new(builder))
197}
198
199fn compile_project_fn(
206 mut module: JITModule,
207 specs: &[ProjectSpec],
208) -> Result<(ProjectFn, JITModule), ProjectCompilerError> {
209 let ptr_type = module.isa().pointer_type();
211
212 let mut sig = Signature::new(CallConv::SystemV);
214 sig.params.push(AbiParam::new(ptr_type)); sig.params.push(AbiParam::new(ptr_type)); sig.params.push(AbiParam::new(ptr_type)); sig.params.push(AbiParam::new(ptr_type)); sig.returns.push(AbiParam::new(types::I8)); let func_id = module
221 .declare_function("project_fn", Linkage::Local, &sig)
222 .map_err(|e| ProjectCompilerError::LinkageError(e.to_string()))?;
223
224 {
225 let mut ctx = module.make_context();
226 ctx.func.signature = sig.clone();
227
228 let mut fn_builder_ctx = FunctionBuilderContext::new();
229 let mut builder = FunctionBuilder::new(&mut ctx.func, &mut fn_builder_ctx);
230
231 if specs.is_empty() {
232 emit_trivial_success(&mut builder, ptr_type);
234 } else {
235 emit_project_body(&mut builder, specs, ptr_type)?;
236 }
237
238 builder.finalize();
239
240 module
241 .define_function(func_id, &mut ctx)
242 .map_err(|e| ProjectCompilerError::CodegenError(format!("{e:?}")))?;
243 }
244
245 module
246 .finalize_definitions()
247 .map_err(|e| ProjectCompilerError::CodegenError(format!("finalize_definitions: {e:?}")))?;
248
249 let raw_ptr = module.get_finalized_function(func_id);
251 let fn_ptr: ProjectFn = unsafe { std::mem::transmute(raw_ptr) };
253
254 Ok((fn_ptr, module))
255}
256
257fn emit_trivial_success(builder: &mut FunctionBuilder<'_>, ptr_type: types::Type) {
261 let entry_block = builder.create_block();
262 builder.append_block_params_for_function_params(entry_block);
263 builder.switch_to_block(entry_block);
264 builder.seal_block(entry_block);
265
266 let _src_ptr: Value = builder.block_params(entry_block)[0];
269 let _src_len: Value = builder.block_params(entry_block)[1];
270 let _dst_ptr: Value = builder.block_params(entry_block)[2];
271 let _dst_len: Value = builder.block_params(entry_block)[3];
272
273 let _ = ptr_type;
275
276 let one = builder.ins().iconst(types::I8, 1);
277 builder.ins().return_(&[one]);
278}
279
280fn emit_project_body(
289 builder: &mut FunctionBuilder<'_>,
290 specs: &[ProjectSpec],
291 ptr_type: types::Type,
292) -> Result<(), ProjectCompilerError> {
293 let max_src_idx = specs.iter().map(|s| s.src_idx).max().unwrap_or(0);
296
297 let entry_block = builder.create_block();
299 let bounds_fail_block = builder.create_block();
300 let body_block = builder.create_block();
301
302 builder.append_block_params_for_function_params(entry_block);
304 builder.switch_to_block(entry_block);
305 builder.seal_block(entry_block);
306
307 let src_ptr = builder.block_params(entry_block)[0];
308 let src_len = builder.block_params(entry_block)[1];
309 let dst_ptr = builder.block_params(entry_block)[2];
310 let dst_len = builder.block_params(entry_block)[3];
311
312 let max_src_val = builder.ins().iconst(ptr_type, max_src_idx as i64);
316 let src_too_short = builder
317 .ins()
318 .icmp(IntCC::UnsignedLessThanOrEqual, src_len, max_src_val);
319
320 let dst_need = builder.ins().iconst(ptr_type, specs.len() as i64);
323 let dst_too_short = builder
324 .ins()
325 .icmp(IntCC::UnsignedLessThan, dst_len, dst_need);
326
327 let any_fail = builder.ins().bor(src_too_short, dst_too_short);
329
330 builder
332 .ins()
333 .brif(any_fail, bounds_fail_block, &[], body_block, &[]);
334
335 builder.switch_to_block(bounds_fail_block);
337 builder.seal_block(bounds_fail_block);
338 let zero_i8 = builder.ins().iconst(types::I8, 0);
339 builder.ins().return_(&[zero_i8]);
340
341 builder.switch_to_block(body_block);
343 builder.seal_block(body_block);
344
345 for (dst_i, spec) in specs.iter().enumerate() {
347 let src_offset = builder
349 .ins()
350 .iconst(ptr_type, (spec.src_idx * std::mem::size_of::<f64>()) as i64);
351 let src_addr = builder.ins().iadd(src_ptr, src_offset);
352 let val = builder.ins().load(types::F64, MemFlags::new(), src_addr, 0);
354
355 let dst_offset = builder
357 .ins()
358 .iconst(ptr_type, (dst_i * std::mem::size_of::<f64>()) as i64);
359 let dst_addr = builder.ins().iadd(dst_ptr, dst_offset);
360 builder.ins().store(MemFlags::new(), val, dst_addr, 0);
362 }
363
364 let one_i8 = builder.ins().iconst(types::I8, 1);
365 builder.ins().return_(&[one_i8]);
366
367 Ok(())
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    fn compiler() -> ProjectCompiler {
        ProjectCompiler::new()
    }

    #[test]
    fn test_empty_specs_compiles_and_succeeds() {
        let cp = compiler().compile(&[]).expect("compile ok");
        assert_eq!(cp.output_width(), 0);
        let src = [1.0f64, 2.0, 3.0];
        let mut dst = Vec::new();
        assert!(cp.extract(&src, &mut dst));
        assert!(dst.is_empty());
    }

    #[test]
    fn test_single_col_extract() {
        let specs = [ProjectSpec { src_idx: 1 }];
        let cp = compiler().compile(&specs).expect("compile ok");
        assert_eq!(cp.output_width(), 1);
        let src = [10.0f64, 20.0, 30.0];
        let mut dst = Vec::new();
        assert!(cp.extract(&src, &mut dst));
        assert_eq!(dst, vec![20.0]);
    }

    #[test]
    fn test_reorder_extract() {
        let specs = [
            ProjectSpec { src_idx: 2 },
            ProjectSpec { src_idx: 0 },
            ProjectSpec { src_idx: 1 },
        ];
        let cp = compiler().compile(&specs).expect("compile ok");
        let src = [1.0f64, 2.0, 3.0];
        let mut dst = Vec::new();
        assert!(cp.extract(&src, &mut dst));
        assert_eq!(dst, vec![3.0, 1.0, 2.0]);
    }

    #[test]
    fn test_src_bounds_fail() {
        let specs = [ProjectSpec { src_idx: 5 }];
        let cp = compiler().compile(&specs).expect("compile ok");
        let src = [1.0f64, 2.0];
        let mut dst = Vec::new();
        assert!(!cp.extract(&src, &mut dst));
    }

    /// `src_len == max_src_idx + 1` is the shortest source that is in bounds.
    #[test]
    fn test_src_exact_length_succeeds() {
        let specs = [ProjectSpec { src_idx: 2 }];
        let cp = compiler().compile(&specs).expect("compile ok");
        let src = [1.0f64, 2.0, 3.0];
        let mut dst = Vec::new();
        assert!(cp.extract(&src, &mut dst));
        assert_eq!(dst, vec![3.0]);
    }

    /// `src_len == src_idx` is one element short and must be rejected (guards
    /// the `<=` comparison in the generated bounds check).
    #[test]
    fn test_src_len_equal_to_idx_fails() {
        let specs = [ProjectSpec { src_idx: 2 }];
        let cp = compiler().compile(&specs).expect("compile ok");
        let src = [1.0f64, 2.0];
        let mut dst = Vec::new();
        assert!(!cp.extract(&src, &mut dst));
        // `extract` resizes before calling; the failed call writes nothing.
        assert_eq!(dst, vec![0.0]);
    }

    #[test]
    fn test_empty_src_fails_for_nonempty_specs() {
        let specs = [ProjectSpec { src_idx: 0 }];
        let cp = compiler().compile(&specs).expect("compile ok");
        let mut dst = Vec::new();
        assert!(!cp.extract(&[], &mut dst));
    }

    /// The same source column may feed several outputs. An oversized `dst` is
    /// shrunk to the output width.
    #[test]
    fn test_duplicate_columns_and_dst_reuse() {
        let specs = [ProjectSpec { src_idx: 0 }, ProjectSpec { src_idx: 0 }];
        let cp = compiler().compile(&specs).expect("compile ok");
        let mut dst = vec![9.0; 5];
        assert!(cp.extract(&[4.0], &mut dst));
        assert_eq!(dst, vec![4.0, 4.0]);
    }

    /// On failure the generated code stores nothing, so surviving slots keep
    /// their previous contents.
    #[test]
    fn test_failure_leaves_resized_dst_untouched() {
        let specs = [ProjectSpec { src_idx: 5 }];
        let cp = compiler().compile(&specs).expect("compile ok");
        let mut dst = vec![7.0, 8.0];
        assert!(!cp.extract(&[1.0], &mut dst));
        assert_eq!(dst, vec![7.0]);
    }

    #[test]
    fn test_compiled_project_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<CompiledProject>();
    }
}