comet_rs_impl/
lib.rs

1extern crate proc_macro;
2use proc_macro::TokenStream;
3use proc_macro_error::{proc_macro_error,abort};
4use quote::quote;
5use std::collections::HashMap;
6use syn::{braced, Attribute, Ident, Stmt, Token};
7
8mod index;
9use index::IndexStruct;
10
11mod tensor;
12use tensor::*;
13
14mod scalar;
15use scalar::ScalarStruct;
16
17mod cometexpr;
18use cometexpr::*;
19
20use std::env;
21use std::path::PathBuf;
22use std::process::Command;
23
/// Outcome of compiling one `comet_fn!` body into a shared library.
enum CometResult {
    /// Compilation succeeded; carries the path of the generated `.so` file.
    Success(String),
    /// A pipeline stage failed (only produced when the
    /// `comet_errors_as_warnings` feature is enabled); carries the error text.
    Failure(String),
}
28
29fn create_lib(func_name: &str, mlir_str: &str, comet_opts: Vec<CometOption>, mlir_opts: Vec<MlirOption>) -> CometResult{
30    // println!("{}",env::var("CARGO_CRATE_NAME").unwrap());
31    // for (key,value) in env::vars() {
32    //     println!("{}={}",key,value);
33    // }
34
35    let comet_base = if let Ok(base) = env::var("COMET_DIR") {
36        PathBuf::from(base)
37    } else {
38        PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()).join("../../..")
39    };
40
41    let comet_bin = if let Ok(dir) = env::var("COMET_BIN_DIR"){
42        PathBuf::from(dir).join("comet-opt")
43    }
44    else{
45        comet_base.clone().join("build/bin/comet-opt")
46    };
47    if !comet_bin.exists(){
48        panic!("Cannot find comet-opt at {}\n Please set the COMET_BIN_DIR envrionement variable to the directory containing the comet-opt binary", comet_bin.display())
49    }
50
51    let comet_lib = if let Ok(dir) = env::var("COMET_LIB_DIR"){
52        PathBuf::from(dir)
53    }
54    else{
55        comet_base.clone().join("build/lib")
56    };
57    if !comet_lib.exists(){
58        panic!("Cannot find comet lib directory at {}\n Please set the COMET_LIB_DIR envrionement variable to the directory containing the comet shared libraries", comet_lib.display())
59    }
60
61    let mlir_bin = if let Ok(dir) = env::var("MLIR_BIN_DIR"){
62        PathBuf::from(dir)
63    }
64    else{
65        comet_base.clone().join("llvm/build/bin")
66    };
67    if !mlir_bin.exists(){
68        panic!("Cannot find mlir-opt at {}\n Please set the MLIR_OPT_DIR envrionement variable to the directory containing the mlir-opt binary", mlir_bin.display())
69    }
70
71    let mlir_lib = if let Ok(dir) = env::var("MLIR_LIB_DIR"){
72        PathBuf::from(dir)
73    }
74    else{
75        comet_base.clone().join("llvm/build/lib")
76    };
77    if !mlir_lib.exists(){
78        panic!("Cannot find mlir lib directory at {}\n Please set the MLIR_LIB_DIR envrionement variable to the directory containing the mlir shared libraries", mlir_lib.display())
79    }
80
81
82    let base_libpath = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap())
83        .join("comet_libs")
84        .join(env::var("CARGO_CRATE_NAME").unwrap());
85    std::fs::create_dir_all(base_libpath.clone()).expect("Could not create directory");
86    let base_filepath = base_libpath.clone().join(func_name);
87
88    let comet_mlir_file = base_filepath.clone().with_extension("comet.mlir");
89    std::fs::write(&comet_mlir_file, mlir_str).expect("Unable to write file");
90    // println!("{}", comet_mlir_file.to_str().unwrap());
91
92    let mut cmd_str = String::new();
93    let mut cmd = Command::new(comet_bin);
94    for opt in comet_opts {
95        cmd_str += opt.as_str();
96        cmd.arg(opt.as_str());
97    }
98    let output = cmd
99        .arg(comet_mlir_file.to_str().unwrap())
100        .output()
101        .expect("Could not run comet-opt");
102    
103
104
105    if !output.status.success() {
106        if cfg!(feature = "comet_errors_as_warnings") {
107            println!("ERROR: failed to convert from rs to mlir {:?} option str {}", String::from_utf8_lossy(&output.stderr),cmd_str);
108            return CometResult::Failure(format!("ERROR: failed to convert from rs to mlir {:?} option str {}", String::from_utf8_lossy(&output.stderr),cmd_str));
109        }
110        else  {
111            panic!("ERROR: failed to convert from rs to mlir {:?} option str {}", String::from_utf8_lossy(&output.stderr),cmd_str);
112        }
113    }
114    let mlir_file = base_filepath.clone().with_extension("mlir");
115    std::fs::write(&mlir_file, &output.stderr).expect("Unable to write file");
116
117    let mut cmd = Command::new(mlir_bin.clone().join("mlir-opt"));
118    for opt in mlir_opts {
119        cmd.arg(opt.as_str());
120    }
121    let output = cmd
122        .arg(mlir_file.to_str().unwrap())
123        .output()
124        .expect("Could not run mlir-opt");
125
126   
127    if !output.status.success() {
128        if cfg!(feature = "comet_errors_as_warnings") {
129            println!("ERROR: failed to convert from mlir to llvm {:?}", String::from_utf8_lossy(&output.stderr));
130            return CometResult::Failure(String::from_utf8_lossy(&output.stderr).to_string());
131        }
132        else {
133            println!("here");
134            panic!("ERROR: failed to convert from mlir to llvm {:?}", String::from_utf8_lossy(&output.stderr));
135        }
136    }
137    // assert!(
138    //     output.status.success(),
139    //     "{:?}",
140    //     String::from_utf8_lossy(&output.stderr)
141    // );
142    let llvm_file = base_filepath.clone().with_extension("llvm");
143    std::fs::write(&llvm_file, &output.stdout).expect("Unable to write file");
144
145    let bc_file = base_filepath.clone().with_extension("bc");
146    let _status = Command::new(mlir_bin.clone().join("mlir-translate"))
147        .arg("--mlir-to-llvmir")
148        .arg(llvm_file.to_str().unwrap())
149        .arg("-o")
150        .arg(bc_file.to_str().unwrap())
151        .status()
152        .expect("Could not run mlir-translate");
153
154    let obj_file = base_filepath.clone().with_extension("comet.o");
155    let _status = Command::new(mlir_bin.clone().join("llc"))
156        .arg(bc_file.to_str().unwrap())
157        .arg("-o")
158        .arg(obj_file.to_str().unwrap())
159        .arg("-filetype=obj")
160        .status()
161        .expect("Could not run llc");
162
163    let so_file = base_filepath.clone().with_extension("so");
164    // let llvm_lib = comet_base.clone().join("llvm/build/lib").into_os_string().into_string().unwrap();
165    // let comet_lib = comet_base.clone().join("build/lib").into_os_string().into_string().unwrap();
166    let _status = Command::new("gcc")
167        .arg(format!{"-Wl,-rpath,{}", mlir_lib.display()})
168        .arg(format!{"-L{}", mlir_lib.display()})
169        .arg("-lmlir_runner_utils")
170        .arg(format!{"-Wl,-rpath,{}",comet_lib.display()})
171        .arg(format!{"-L{}",comet_lib.display()})
172        .arg("-lcomet_runner_utils")
173        .arg("-shared")
174        .arg("-o")
175        .arg(so_file.to_str().unwrap())
176        .arg(obj_file.to_str().unwrap())
177        .status()
178        .expect("Could not run gcc");
179
180
181    CometResult::Success(so_file.to_str().unwrap().to_string())
182}
183
184use syn::parse_macro_input;
185use syn::parse::discouraged::Speculative;
186use syn::parse::{Parse, ParseStream};
187
188#[derive(Debug)]
189enum Comet {
190    Index(Vec<IndexStruct>),
191    Tensor(TensorStruct),
192    Scalar(ScalarStruct),
193    Expr(CometExpr),
194}
195
196#[derive(Debug)]
197pub(crate) struct CometVars {
198    indices: HashMap<Ident, IndexStruct>,
199    tensors: HashMap<Ident, TensorStruct>,
200    scalars: HashMap<Ident, ScalarStruct>,
201    custom_ops: HashMap<String, String>,
202}
203
204impl Comet {
205    fn my_parse(
206        input: ParseStream,
207        vars: &mut CometVars,
208        object_id: &mut usize,
209    ) -> syn::Result<Self> {
210        let mut errs = vec![];
211        match IndexStruct::my_parse(input, vars, object_id) {
212            Ok(indices) => {
213                return Ok(Comet::Index(indices));
214            }
215            Err(e) => errs.push(e),
216        }
217        match TensorStruct::my_parse(input, vars, object_id) {
218            Ok(tensor) => {
219                return Ok(Comet::Tensor(tensor));
220            }
221            Err(e) => errs.push(e),
222        }
223
224        match ScalarStruct::my_parse(input, vars, object_id) {
225            Ok(scalar) => {
226                return Ok(Comet::Scalar(scalar));
227            }
228            Err(e) => errs.push(e),
229        }
230        match CometExpr::my_parse(input, vars, object_id) {
231            Ok(expr) => {
232                return Ok(Comet::Expr(expr));
233            }
234            Err(e) => errs.push(e),
235        }
236        if errs.len() >= 1 {
237            return Err(errs.remove(0));
238        } else {
239            return Err(syn::Error::new(
240                input.span(),
241                "Could not parse comet expression",
242            )); //this should never happen
243        }
244    }
245}
246
/// A `comet-opt` pass selectable from `comet_fn!` via `CometOption::[...]`.
///
/// The variant names are the user-facing spelling (including the lowercase `k`
/// in `MatMulkernel`), so they must not be renamed.
#[derive(Debug, Clone, PartialEq, Eq)]
enum CometOption {
    /// `--convert-tc-to-ttgt`
    TcToTtgt,
    /// `--convert-to-loops`
    ToLoops,
    /// `--convert-ta-to-it`
    TaToIt,
    /// `-opt-bestperm-ttgt`
    BestPermTtgt,
    /// `-opt-matmul-tiling`
    MatMulTiling,
    /// `-opt-matmul-mkernel`
    MatMulkernel,
    /// `-opt-dense-transpose`
    DenseTranspose,
    /// `-opt-comp-workspace`
    CompWorkspace,
}

impl CometOption {
    /// Returns the command-line flag passed to `comet-opt` for this pass.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::TcToTtgt => "--convert-tc-to-ttgt",
            Self::ToLoops => "--convert-to-loops",
            Self::TaToIt => "--convert-ta-to-it",
            Self::BestPermTtgt => "-opt-bestperm-ttgt",
            Self::MatMulTiling => "-opt-matmul-tiling",
            Self::MatMulkernel => "-opt-matmul-mkernel",
            Self::DenseTranspose => "-opt-dense-transpose",
            Self::CompWorkspace => "-opt-comp-workspace",
        }
    }
}
273
274impl CometOption {
275    fn my_parse(input: ParseStream) -> syn::Result<Vec<Self>> {
276        let mut options = vec![];
277        let fork = input.fork();
278        if let Ok(comet_opt) = fork.parse::<Ident>(){
279            if comet_opt.to_string().as_str() == "CometOption"{
280                fork.parse::<Token![::]>()?;
281                let content;
282                syn::bracketed!(content in fork);
283                while !content.is_empty() {
284                    if let Ok(opt) = content.parse::<Ident>(){                
285                        match opt.to_string().as_str() {
286                            "TcToTtgt" => {
287                                if !options.contains(&CometOption::TcToTtgt){
288                                    options.push(CometOption::TcToTtgt);
289                                }
290                            },
291                            "ToLoops" => {
292                                if !options.contains(&CometOption::ToLoops){
293                                    options.push(CometOption::ToLoops);
294                                }
295                            },
296                            "TaToIt" => {
297                                if !options.contains(&CometOption::TaToIt){
298                                    options.push(CometOption::TaToIt);
299                                }
300                            },
301                            "BestPermTtgt" => {
302                                if !options.contains(&CometOption::BestPermTtgt){
303                                    options.push(CometOption::BestPermTtgt);
304                                }
305                            },
306                            "MatMulTiling" => {
307                                if !options.contains(&CometOption::MatMulTiling){
308                                    options.push(CometOption::MatMulTiling);
309                                }
310                            },
311                            "MatMulkernel" => {
312                                if !options.contains(&CometOption::MatMulkernel){
313                                    options.push(CometOption::MatMulkernel);
314                                }
315                            },
316                            "DenseTranspose" => {
317                                if !options.contains(&CometOption::DenseTranspose){
318                                    options.push(CometOption::DenseTranspose);
319                                }
320                            },
321                            "CompWorkspace" => {
322                                if !options.contains(&CometOption::CompWorkspace){
323                                    options.push(CometOption::CompWorkspace);
324                                }
325                            },                        
326                            _ => abort!(opt.span(),"Unknown CometOption"),
327                        }
328                        if content.peek(Token![,]){
329                            content.parse::<Token![,]>()?;
330                        }
331                    }
332                    else{
333                        abort!(content.span(),"Unknown CometOption");
334                    }
335                }
336                
337            }
338            else{
339                
340                return Err(syn::Error::new(
341                    comet_opt.span(),
342                    "expected CometOption",
343                ))
344            }
345        }
346        else{
347            return Err(syn::Error::new(
348                input.span(),
349                "expected CometOption",
350            ))
351        }
352        input.advance_to(&fork);
353        Ok(options)
354    }
355}
356
/// An `mlir-opt` pass selectable from `comet_fn!` via `MlirOption::[...]`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum MlirOption {
    /// `--convert-linalg-to-loops`
    ConvertLinalgToLoops,
    /// `--convert-linalg-to-std`
    ConvertLinalgToStd,
    /// `--convert-linalg-to-llvm`
    ConvertLinalgToLlvm,
    /// `--convert-to-loops`
    /// NOTE(review): this is the same flag as `CometOption::ToLoops`; confirm
    /// that `mlir-opt` (rather than `comet-opt`) actually registers it.
    ToLoops,
    /// `--convert-scf-to-std`
    ConvertScfToStd,
    /// `--convert-std-to-llvm`
    ConvertStdToLlvm,
    /// `--lower-affine`
    LowerAffine,
}

impl MlirOption {
    /// Returns the command-line flag passed to `mlir-opt` for this pass.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::ConvertLinalgToLoops => "--convert-linalg-to-loops",
            Self::ConvertLinalgToStd => "--convert-linalg-to-std",
            Self::ConvertLinalgToLlvm => "--convert-linalg-to-llvm",
            Self::ToLoops => "--convert-to-loops",
            Self::ConvertScfToStd => "--convert-scf-to-std",
            Self::ConvertStdToLlvm => "--convert-std-to-llvm",
            Self::LowerAffine => "--lower-affine",
        }
    }
}
381
382impl MlirOption {
383    fn my_parse(input: ParseStream) -> syn::Result<Vec<Self>> {
384        let mut options = vec![];
385        let fork = input.fork();
386        if let Ok(mlir_opt) = fork.parse::<Ident>(){
387            if mlir_opt.to_string().as_str() == "MlirOption"{
388                fork.parse::<Token![::]>()?;
389                let content;
390                syn::bracketed!(content in fork);
391                while !content.is_empty() {
392                    if let Ok(opt) = content.parse::<Ident>(){                
393                        match opt.to_string().as_str() {
394                            "ConvertLinalgToLoops" => {
395                                if !options.contains(&MlirOption::ConvertLinalgToLoops){
396                                    options.push(MlirOption::ConvertLinalgToLoops);
397                                }
398                            },
399                            "ConvertLinalgToStd" => {
400                                if !options.contains(&MlirOption::ConvertLinalgToStd){
401                                    options.push(MlirOption::ConvertLinalgToStd);
402                                }
403                            },
404                            "ConvertLinalgToLlvm" => {
405                                if !options.contains(&MlirOption::ConvertLinalgToLlvm){
406                                    options.push(MlirOption::ConvertLinalgToLlvm);
407                                }
408                            },
409                            "ToLoops" => {
410                                if !options.contains(&MlirOption::ToLoops){
411                                    options.push(MlirOption::ToLoops);
412                                }
413                            },
414                            "ConvertScfToStd" => {
415                                if !options.contains(&MlirOption::ConvertScfToStd){
416                                    options.push(MlirOption::ConvertScfToStd);
417                                }
418                            },
419                            "ConvertStdToLlvm" => {
420                                if !options.contains(&MlirOption::ConvertStdToLlvm){
421                                    options.push(MlirOption::ConvertStdToLlvm);
422                                }
423                            },
424                            "LowerAffine" => {
425                                if !options.contains(&MlirOption::LowerAffine){
426                                    options.push(MlirOption::LowerAffine);
427                                }
428                            },
429                            _ => abort!(opt.span(),"Unknown MlirOption"),
430                        }
431                        if content.peek(Token![,]){
432                            content.parse::<Token![,]>()?;
433                        }
434                    }
435                    else{
436                        abort!(content.span(),"Unknown MlirOption");
437                    }
438                }
439            }
440            else{
441                return Err(syn::Error::new(
442                    mlir_opt.span(),
443                    "expected MlirOption",
444                ))
445            }
446        }
447        else{
448            return Err(syn::Error::new(
449                input.span(),
450                "expected MlirOption",
451            ))
452        }
453        input.advance_to(&fork);
454        Ok(options)
455    }
456}
457
458pub(crate) struct CometFn {
459    pub name: Ident,
460    lib_path: CometResult,
461    sparse_env: proc_macro2::TokenStream,
462}
463
464impl Parse for CometFn {
465    fn parse(input: ParseStream) -> syn::Result<Self> {
466        let name = input.parse::<Ident>()?;
467        input.parse::<Token![,]>()?;
468        let content;
469        let _brace_token = braced!(content in input);
470        let _inner_attrs = content.call(Attribute::parse_inner)?;
471        let mut vars = CometVars {
472            indices: HashMap::new(), //index name, mlir_id
473            tensors: HashMap::new(), //tensor name, mlir_id
474            scalars: HashMap::new(), //scalar name, mlir_id
475            custom_ops: HashMap::new(), //custom op name, op str
476        };
477        let mut object_id = 0;
478
479        let mut mlir_str = format!("module  {{\nfunc @{}() {{\n", name);
480        while !content.is_empty() {
481            let fork = content.fork();
482            match Comet::my_parse(&fork, &mut vars, &mut object_id) {
483                Ok(comet) => {
484                    content.advance_to(&fork);
485                    match comet {
486                        Comet::Index(indices) => {
487                            // println!("index {:?}", index);
488                            for index in indices {
489                                mlir_str += &index.emit_mlir();
490                            }                            
491                        }
492                        Comet::Tensor(tensor) => {
493                            // println!("tensor {:?}", tensor);
494                            mlir_str += &tensor.emit_mlir();
495                        }
496                        Comet::Scalar(mut scalar) => {
497                            // println!("scalar {:?}", scalar);
498                            mlir_str += &scalar.emit_mlir();
499                        }
500                        Comet::Expr(mut expr) => {
501                            // println!("expr {:?}", expr);
502                            mlir_str += &expr.emit_mlir()?;
503                        }
504                    }
505                }
506                Err(_) => {
507                    let stmt = content.parse::<Stmt>()?;
508                    println!("not a comet stmt{:?}", stmt);
509                }
510            }
511        }
512        
513        mlir_str += "\"ta.return\"(): () -> ()\n}\n}\n\n";
514        // println!("{}", mlir_str);
515        let mut opt_comp_workspace = true;
516        // let mut sparse_env = String::new();
517        let mut sparse_env = quote!{};
518        for tensor in vars.tensors.values() {
519            // println!("{} {}",tensor.format,tensor.format != TensorFormat::Csr);
520            if tensor.format != TensorFormat::Csr {
521                opt_comp_workspace = false;
522            }
523
524            if let TensorFill::FillFromFile(val,env) = &tensor.fill {
525                // let stmt = syn::parse_str( &format!("std::env::set_var({},{});\n",env,val.value())).unwrap();
526                let p = val.value();
527                sparse_env.extend(quote!{std::env::set_var( #env, #p );});
528                // sparse_env = env.clone();
529                // sparse_env +=;
530            }
531        }
532        if input.peek(Token![,]){
533            input.parse::<Token![,]>()?;
534        }
535        let mut comet_opts = match CometOption::my_parse(&input){
536            Ok(opts) => {
537                if input.peek(Token![,]){
538                    input.parse::<Token![,]>()?;
539                }
540                opts
541            }
542            Err(_) => {
543                vec![CometOption::TaToIt,CometOption::ToLoops]
544            }
545        };
546        let mlir_opts = match MlirOption::my_parse(&input){
547            Ok(opts) => {
548                if input.peek(Token![,]){
549                    input.parse::<Token![,]>()?;
550                }
551                opts
552            }
553            Err(_) => {
554                vec![MlirOption::ConvertScfToStd,MlirOption::ConvertStdToLlvm]
555            }
556        };
557       
558        if opt_comp_workspace{
559            if !comet_opts.contains(&CometOption::CompWorkspace){
560                comet_opts.insert(0,CometOption::CompWorkspace);
561            }
562        }
563        // println!("{} {} {:?}",name,opt_comp_workspace,comet_opts);
564        let lib_path = create_lib(&name.clone().to_string(), &mlir_str,comet_opts,mlir_opts);
565
566        // println!("sparse_env: {}", sparse_env);
567        Ok(CometFn {
568            name,
569            lib_path,
570            sparse_env,
571        })
572    }
573}
574
575
576#[proc_macro_error]
577#[proc_macro]
578pub fn comet_fn(input: TokenStream) -> TokenStream {
579    let func = parse_macro_input!(input as CometFn);
580
581    let name = func.name;
582    let sparse_env = func.sparse_env;
583    // let sparse_env: Vect<syn::Stmt> = Vec!;
584    
585    match func.lib_path {
586        CometResult::Success(lib_name) => {
587            // println!("lib_name: {}", lib_name);
588            quote! {
589                comet_rs::inventory::submit!{
590                    comet_rs::CometFunc{
591                        name: stringify!(#name)
592                    }
593
594                }
595                comet_rs::inventory::submit!{
596                    comet_rs::CometLib{
597                        name: #lib_name
598                    }
599
600                }
601                fn #name(){
602                    #sparse_env
603                    unsafe {comet_rs::COMET_FUNCS.get(stringify!(#name)).unwrap().0();}
604                }
605            }
606            .into()
607        }
608        CometResult::Failure(msg) =>{
609            quote! {
610                fn #name(){
611                    println!("ERROR: {}",#msg);
612                }
613            }
614            .into()
615        }
616    }
617}