1extern crate proc_macro;
2use proc_macro::TokenStream;
3use proc_macro_error::{proc_macro_error,abort};
4use quote::quote;
5use std::collections::HashMap;
6use syn::{braced, Attribute, Ident, Stmt, Token};
7
8mod index;
9use index::IndexStruct;
10
11mod tensor;
12use tensor::*;
13
14mod scalar;
15use scalar::ScalarStruct;
16
17mod cometexpr;
18use cometexpr::*;
19
20use std::env;
21use std::path::PathBuf;
22use std::process::Command;
23
/// Outcome of building the shared library for one `comet_fn!` invocation.
enum CometResult {
    /// The build succeeded; carries the path of the produced shared library
    /// as returned by `create_lib`.
    Success(String),
    /// The build failed; carries the message describing the failure.
    Failure(String),
}
28
29fn create_lib(func_name: &str, mlir_str: &str, comet_opts: Vec<CometOption>, mlir_opts: Vec<MlirOption>) -> CometResult{
30 let comet_base = if let Ok(base) = env::var("COMET_DIR") {
36 PathBuf::from(base)
37 } else {
38 PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()).join("../../..")
39 };
40
41 let comet_bin = if let Ok(dir) = env::var("COMET_BIN_DIR"){
42 PathBuf::from(dir).join("comet-opt")
43 }
44 else{
45 comet_base.clone().join("build/bin/comet-opt")
46 };
47 if !comet_bin.exists(){
48 panic!("Cannot find comet-opt at {}\n Please set the COMET_BIN_DIR envrionement variable to the directory containing the comet-opt binary", comet_bin.display())
49 }
50
51 let comet_lib = if let Ok(dir) = env::var("COMET_LIB_DIR"){
52 PathBuf::from(dir)
53 }
54 else{
55 comet_base.clone().join("build/lib")
56 };
57 if !comet_lib.exists(){
58 panic!("Cannot find comet lib directory at {}\n Please set the COMET_LIB_DIR envrionement variable to the directory containing the comet shared libraries", comet_lib.display())
59 }
60
61 let mlir_bin = if let Ok(dir) = env::var("MLIR_BIN_DIR"){
62 PathBuf::from(dir)
63 }
64 else{
65 comet_base.clone().join("llvm/build/bin")
66 };
67 if !mlir_bin.exists(){
68 panic!("Cannot find mlir-opt at {}\n Please set the MLIR_OPT_DIR envrionement variable to the directory containing the mlir-opt binary", mlir_bin.display())
69 }
70
71 let mlir_lib = if let Ok(dir) = env::var("MLIR_LIB_DIR"){
72 PathBuf::from(dir)
73 }
74 else{
75 comet_base.clone().join("llvm/build/lib")
76 };
77 if !mlir_lib.exists(){
78 panic!("Cannot find mlir lib directory at {}\n Please set the MLIR_LIB_DIR envrionement variable to the directory containing the mlir shared libraries", mlir_lib.display())
79 }
80
81
82 let base_libpath = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap())
83 .join("comet_libs")
84 .join(env::var("CARGO_CRATE_NAME").unwrap());
85 std::fs::create_dir_all(base_libpath.clone()).expect("Could not create directory");
86 let base_filepath = base_libpath.clone().join(func_name);
87
88 let comet_mlir_file = base_filepath.clone().with_extension("comet.mlir");
89 std::fs::write(&comet_mlir_file, mlir_str).expect("Unable to write file");
90 let mut cmd_str = String::new();
93 let mut cmd = Command::new(comet_bin);
94 for opt in comet_opts {
95 cmd_str += opt.as_str();
96 cmd.arg(opt.as_str());
97 }
98 let output = cmd
99 .arg(comet_mlir_file.to_str().unwrap())
100 .output()
101 .expect("Could not run comet-opt");
102
103
104
105 if !output.status.success() {
106 if cfg!(feature = "comet_errors_as_warnings") {
107 println!("ERROR: failed to convert from rs to mlir {:?} option str {}", String::from_utf8_lossy(&output.stderr),cmd_str);
108 return CometResult::Failure(format!("ERROR: failed to convert from rs to mlir {:?} option str {}", String::from_utf8_lossy(&output.stderr),cmd_str));
109 }
110 else {
111 panic!("ERROR: failed to convert from rs to mlir {:?} option str {}", String::from_utf8_lossy(&output.stderr),cmd_str);
112 }
113 }
114 let mlir_file = base_filepath.clone().with_extension("mlir");
115 std::fs::write(&mlir_file, &output.stderr).expect("Unable to write file");
116
117 let mut cmd = Command::new(mlir_bin.clone().join("mlir-opt"));
118 for opt in mlir_opts {
119 cmd.arg(opt.as_str());
120 }
121 let output = cmd
122 .arg(mlir_file.to_str().unwrap())
123 .output()
124 .expect("Could not run mlir-opt");
125
126
127 if !output.status.success() {
128 if cfg!(feature = "comet_errors_as_warnings") {
129 println!("ERROR: failed to convert from mlir to llvm {:?}", String::from_utf8_lossy(&output.stderr));
130 return CometResult::Failure(String::from_utf8_lossy(&output.stderr).to_string());
131 }
132 else {
133 println!("here");
134 panic!("ERROR: failed to convert from mlir to llvm {:?}", String::from_utf8_lossy(&output.stderr));
135 }
136 }
137 let llvm_file = base_filepath.clone().with_extension("llvm");
143 std::fs::write(&llvm_file, &output.stdout).expect("Unable to write file");
144
145 let bc_file = base_filepath.clone().with_extension("bc");
146 let _status = Command::new(mlir_bin.clone().join("mlir-translate"))
147 .arg("--mlir-to-llvmir")
148 .arg(llvm_file.to_str().unwrap())
149 .arg("-o")
150 .arg(bc_file.to_str().unwrap())
151 .status()
152 .expect("Could not run mlir-translate");
153
154 let obj_file = base_filepath.clone().with_extension("comet.o");
155 let _status = Command::new(mlir_bin.clone().join("llc"))
156 .arg(bc_file.to_str().unwrap())
157 .arg("-o")
158 .arg(obj_file.to_str().unwrap())
159 .arg("-filetype=obj")
160 .status()
161 .expect("Could not run llc");
162
163 let so_file = base_filepath.clone().with_extension("so");
164 let _status = Command::new("gcc")
167 .arg(format!{"-Wl,-rpath,{}", mlir_lib.display()})
168 .arg(format!{"-L{}", mlir_lib.display()})
169 .arg("-lmlir_runner_utils")
170 .arg(format!{"-Wl,-rpath,{}",comet_lib.display()})
171 .arg(format!{"-L{}",comet_lib.display()})
172 .arg("-lcomet_runner_utils")
173 .arg("-shared")
174 .arg("-o")
175 .arg(so_file.to_str().unwrap())
176 .arg(obj_file.to_str().unwrap())
177 .status()
178 .expect("Could not run gcc");
179
180
181 CometResult::Success(so_file.to_str().unwrap().to_string())
182}
183
184use syn::parse_macro_input;
185use syn::parse::discouraged::Speculative;
186use syn::parse::{Parse, ParseStream};
187
188#[derive(Debug)]
189enum Comet {
190 Index(Vec<IndexStruct>),
191 Tensor(TensorStruct),
192 Scalar(ScalarStruct),
193 Expr(CometExpr),
194}
195
196#[derive(Debug)]
197pub(crate) struct CometVars {
198 indices: HashMap<Ident, IndexStruct>,
199 tensors: HashMap<Ident, TensorStruct>,
200 scalars: HashMap<Ident, ScalarStruct>,
201 custom_ops: HashMap<String, String>,
202}
203
204impl Comet {
205 fn my_parse(
206 input: ParseStream,
207 vars: &mut CometVars,
208 object_id: &mut usize,
209 ) -> syn::Result<Self> {
210 let mut errs = vec![];
211 match IndexStruct::my_parse(input, vars, object_id) {
212 Ok(indices) => {
213 return Ok(Comet::Index(indices));
214 }
215 Err(e) => errs.push(e),
216 }
217 match TensorStruct::my_parse(input, vars, object_id) {
218 Ok(tensor) => {
219 return Ok(Comet::Tensor(tensor));
220 }
221 Err(e) => errs.push(e),
222 }
223
224 match ScalarStruct::my_parse(input, vars, object_id) {
225 Ok(scalar) => {
226 return Ok(Comet::Scalar(scalar));
227 }
228 Err(e) => errs.push(e),
229 }
230 match CometExpr::my_parse(input, vars, object_id) {
231 Ok(expr) => {
232 return Ok(Comet::Expr(expr));
233 }
234 Err(e) => errs.push(e),
235 }
236 if errs.len() >= 1 {
237 return Err(errs.remove(0));
238 } else {
239 return Err(syn::Error::new(
240 input.span(),
241 "Could not parse comet expression",
242 )); }
244 }
245}
246
/// Flags that can be forwarded to `comet-opt`; the flag text of each variant is
/// given by `CometOption::as_str`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum CometOption {
    /// `--convert-tc-to-ttgt`
    TcToTtgt,
    /// `--convert-to-loops`
    ToLoops,
    /// `--convert-ta-to-it`
    TaToIt,
    /// `-opt-bestperm-ttgt`
    BestPermTtgt,
    /// `-opt-matmul-tiling`
    MatMulTiling,
    /// `-opt-matmul-mkernel`
    MatMulkernel,
    /// `-opt-dense-transpose`
    DenseTranspose,
    /// `-opt-comp-workspace`
    CompWorkspace,
}
258
259impl CometOption{
260 fn as_str(&self) -> &'static str{
261 match self{
262 CometOption::TcToTtgt =>"--convert-tc-to-ttgt",
263 CometOption::ToLoops =>"--convert-to-loops",
264 CometOption::TaToIt =>"--convert-ta-to-it",
265 CometOption::BestPermTtgt =>"-opt-bestperm-ttgt",
266 CometOption::MatMulTiling =>"-opt-matmul-tiling",
267 CometOption::MatMulkernel =>"-opt-matmul-mkernel",
268 CometOption::DenseTranspose =>"-opt-dense-transpose",
269 CometOption::CompWorkspace =>"-opt-comp-workspace",
270 }
271 }
272}
273
274impl CometOption {
275 fn my_parse(input: ParseStream) -> syn::Result<Vec<Self>> {
276 let mut options = vec![];
277 let fork = input.fork();
278 if let Ok(comet_opt) = fork.parse::<Ident>(){
279 if comet_opt.to_string().as_str() == "CometOption"{
280 fork.parse::<Token![::]>()?;
281 let content;
282 syn::bracketed!(content in fork);
283 while !content.is_empty() {
284 if let Ok(opt) = content.parse::<Ident>(){
285 match opt.to_string().as_str() {
286 "TcToTtgt" => {
287 if !options.contains(&CometOption::TcToTtgt){
288 options.push(CometOption::TcToTtgt);
289 }
290 },
291 "ToLoops" => {
292 if !options.contains(&CometOption::ToLoops){
293 options.push(CometOption::ToLoops);
294 }
295 },
296 "TaToIt" => {
297 if !options.contains(&CometOption::TaToIt){
298 options.push(CometOption::TaToIt);
299 }
300 },
301 "BestPermTtgt" => {
302 if !options.contains(&CometOption::BestPermTtgt){
303 options.push(CometOption::BestPermTtgt);
304 }
305 },
306 "MatMulTiling" => {
307 if !options.contains(&CometOption::MatMulTiling){
308 options.push(CometOption::MatMulTiling);
309 }
310 },
311 "MatMulkernel" => {
312 if !options.contains(&CometOption::MatMulkernel){
313 options.push(CometOption::MatMulkernel);
314 }
315 },
316 "DenseTranspose" => {
317 if !options.contains(&CometOption::DenseTranspose){
318 options.push(CometOption::DenseTranspose);
319 }
320 },
321 "CompWorkspace" => {
322 if !options.contains(&CometOption::CompWorkspace){
323 options.push(CometOption::CompWorkspace);
324 }
325 },
326 _ => abort!(opt.span(),"Unknown CometOption"),
327 }
328 if content.peek(Token![,]){
329 content.parse::<Token![,]>()?;
330 }
331 }
332 else{
333 abort!(content.span(),"Unknown CometOption");
334 }
335 }
336
337 }
338 else{
339
340 return Err(syn::Error::new(
341 comet_opt.span(),
342 "expected CometOption",
343 ))
344 }
345 }
346 else{
347 return Err(syn::Error::new(
348 input.span(),
349 "expected CometOption",
350 ))
351 }
352 input.advance_to(&fork);
353 Ok(options)
354 }
355}
356
/// Flags that can be forwarded to `mlir-opt`; the flag text of each variant is
/// given by `MlirOption::as_str`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum MlirOption {
    /// `--convert-linalg-to-loops`
    ConvertLinalgToLoops,
    /// `--convert-linalg-to-std`
    ConvertLinalgToStd,
    /// `--convert-linalg-to-llvm`
    ConvertLinalgToLlvm,
    /// `--convert-to-loops`
    ToLoops,
    /// `--convert-scf-to-std`
    ConvertScfToStd,
    /// `--convert-std-to-llvm`
    ConvertStdToLlvm,
    /// `--lower-affine`
    LowerAffine,
}
367
368impl MlirOption{
369 fn as_str(&self) -> &'static str{
370 match self{
371 MlirOption::ConvertLinalgToLoops =>"--convert-linalg-to-loops",
372 MlirOption::ConvertLinalgToStd =>"--convert-linalg-to-std",
373 MlirOption::ConvertLinalgToLlvm =>"--convert-linalg-to-llvm",
374 MlirOption::ToLoops =>"--convert-to-loops",
375 MlirOption::ConvertScfToStd =>"--convert-scf-to-std",
376 MlirOption::ConvertStdToLlvm =>"--convert-std-to-llvm",
377 MlirOption::LowerAffine =>"--lower-affine",
378 }
379 }
380}
381
382impl MlirOption {
383 fn my_parse(input: ParseStream) -> syn::Result<Vec<Self>> {
384 let mut options = vec![];
385 let fork = input.fork();
386 if let Ok(mlir_opt) = fork.parse::<Ident>(){
387 if mlir_opt.to_string().as_str() == "MlirOption"{
388 fork.parse::<Token![::]>()?;
389 let content;
390 syn::bracketed!(content in fork);
391 while !content.is_empty() {
392 if let Ok(opt) = content.parse::<Ident>(){
393 match opt.to_string().as_str() {
394 "ConvertLinalgToLoops" => {
395 if !options.contains(&MlirOption::ConvertLinalgToLoops){
396 options.push(MlirOption::ConvertLinalgToLoops);
397 }
398 },
399 "ConvertLinalgToStd" => {
400 if !options.contains(&MlirOption::ConvertLinalgToStd){
401 options.push(MlirOption::ConvertLinalgToStd);
402 }
403 },
404 "ConvertLinalgToLlvm" => {
405 if !options.contains(&MlirOption::ConvertLinalgToLlvm){
406 options.push(MlirOption::ConvertLinalgToLlvm);
407 }
408 },
409 "ToLoops" => {
410 if !options.contains(&MlirOption::ToLoops){
411 options.push(MlirOption::ToLoops);
412 }
413 },
414 "ConvertScfToStd" => {
415 if !options.contains(&MlirOption::ConvertScfToStd){
416 options.push(MlirOption::ConvertScfToStd);
417 }
418 },
419 "ConvertStdToLlvm" => {
420 if !options.contains(&MlirOption::ConvertStdToLlvm){
421 options.push(MlirOption::ConvertStdToLlvm);
422 }
423 },
424 "LowerAffine" => {
425 if !options.contains(&MlirOption::LowerAffine){
426 options.push(MlirOption::LowerAffine);
427 }
428 },
429 _ => abort!(opt.span(),"Unknown MlirOption"),
430 }
431 if content.peek(Token![,]){
432 content.parse::<Token![,]>()?;
433 }
434 }
435 else{
436 abort!(content.span(),"Unknown MlirOption");
437 }
438 }
439 }
440 else{
441 return Err(syn::Error::new(
442 mlir_opt.span(),
443 "expected MlirOption",
444 ))
445 }
446 }
447 else{
448 return Err(syn::Error::new(
449 input.span(),
450 "expected MlirOption",
451 ))
452 }
453 input.advance_to(&fork);
454 Ok(options)
455 }
456}
457
458pub(crate) struct CometFn {
459 pub name: Ident,
460 lib_path: CometResult,
461 sparse_env: proc_macro2::TokenStream,
462}
463
464impl Parse for CometFn {
465 fn parse(input: ParseStream) -> syn::Result<Self> {
466 let name = input.parse::<Ident>()?;
467 input.parse::<Token![,]>()?;
468 let content;
469 let _brace_token = braced!(content in input);
470 let _inner_attrs = content.call(Attribute::parse_inner)?;
471 let mut vars = CometVars {
472 indices: HashMap::new(), tensors: HashMap::new(), scalars: HashMap::new(), custom_ops: HashMap::new(), };
477 let mut object_id = 0;
478
479 let mut mlir_str = format!("module {{\nfunc @{}() {{\n", name);
480 while !content.is_empty() {
481 let fork = content.fork();
482 match Comet::my_parse(&fork, &mut vars, &mut object_id) {
483 Ok(comet) => {
484 content.advance_to(&fork);
485 match comet {
486 Comet::Index(indices) => {
487 for index in indices {
489 mlir_str += &index.emit_mlir();
490 }
491 }
492 Comet::Tensor(tensor) => {
493 mlir_str += &tensor.emit_mlir();
495 }
496 Comet::Scalar(mut scalar) => {
497 mlir_str += &scalar.emit_mlir();
499 }
500 Comet::Expr(mut expr) => {
501 mlir_str += &expr.emit_mlir()?;
503 }
504 }
505 }
506 Err(_) => {
507 let stmt = content.parse::<Stmt>()?;
508 println!("not a comet stmt{:?}", stmt);
509 }
510 }
511 }
512
513 mlir_str += "\"ta.return\"(): () -> ()\n}\n}\n\n";
514 let mut opt_comp_workspace = true;
516 let mut sparse_env = quote!{};
518 for tensor in vars.tensors.values() {
519 if tensor.format != TensorFormat::Csr {
521 opt_comp_workspace = false;
522 }
523
524 if let TensorFill::FillFromFile(val,env) = &tensor.fill {
525 let p = val.value();
527 sparse_env.extend(quote!{std::env::set_var( #env, #p );});
528 }
531 }
532 if input.peek(Token![,]){
533 input.parse::<Token![,]>()?;
534 }
535 let mut comet_opts = match CometOption::my_parse(&input){
536 Ok(opts) => {
537 if input.peek(Token![,]){
538 input.parse::<Token![,]>()?;
539 }
540 opts
541 }
542 Err(_) => {
543 vec![CometOption::TaToIt,CometOption::ToLoops]
544 }
545 };
546 let mlir_opts = match MlirOption::my_parse(&input){
547 Ok(opts) => {
548 if input.peek(Token![,]){
549 input.parse::<Token![,]>()?;
550 }
551 opts
552 }
553 Err(_) => {
554 vec![MlirOption::ConvertScfToStd,MlirOption::ConvertStdToLlvm]
555 }
556 };
557
558 if opt_comp_workspace{
559 if !comet_opts.contains(&CometOption::CompWorkspace){
560 comet_opts.insert(0,CometOption::CompWorkspace);
561 }
562 }
563 let lib_path = create_lib(&name.clone().to_string(), &mlir_str,comet_opts,mlir_opts);
565
566 Ok(CometFn {
568 name,
569 lib_path,
570 sparse_env,
571 })
572 }
573}
574
575
576#[proc_macro_error]
577#[proc_macro]
578pub fn comet_fn(input: TokenStream) -> TokenStream {
579 let func = parse_macro_input!(input as CometFn);
580
581 let name = func.name;
582 let sparse_env = func.sparse_env;
583 match func.lib_path {
586 CometResult::Success(lib_name) => {
587 quote! {
589 comet_rs::inventory::submit!{
590 comet_rs::CometFunc{
591 name: stringify!(#name)
592 }
593
594 }
595 comet_rs::inventory::submit!{
596 comet_rs::CometLib{
597 name: #lib_name
598 }
599
600 }
601 fn #name(){
602 #sparse_env
603 unsafe {comet_rs::COMET_FUNCS.get(stringify!(#name)).unwrap().0();}
604 }
605 }
606 .into()
607 }
608 CometResult::Failure(msg) =>{
609 quote! {
610 fn #name(){
611 println!("ERROR: {}",#msg);
612 }
613 }
614 .into()
615 }
616 }
617}