gaia_assembler/backends/
sass.rs1use crate::{
2 backends::{Backend, GeneratedFiles},
3 config::GaiaConfig,
4 instruction::{DomainInstruction, GaiaInstruction},
5 program::{GaiaModule, GaiaTerminator},
6};
7use gaia_types::{
8 helpers::{AbiCompatible, ApiCompatible, Architecture, CompilationTarget},
9 neural::NeuralNode,
10 Result,
11};
12use sass_assembler::{
13 instructions::{SassInstruction, SassReg},
14 program::{SassKernel, SassProgram},
15 SassWriter,
16};
17use std::collections::HashMap;
18
19pub struct SassBackend {
20 writer: SassWriter,
21}
22
23impl SassBackend {
24 pub fn new() -> Self {
25 Self { writer: SassWriter::new() }
26 }
27}
28
29impl Backend for SassBackend {
30 fn name(&self) -> &'static str {
31 "NVIDIA SASS"
32 }
33
34 fn primary_target(&self) -> CompilationTarget {
35 CompilationTarget { build: Architecture::NvSass, host: AbiCompatible::PTX, target: ApiCompatible::Unknown }
36 }
37
38 fn match_score(&self, target: &CompilationTarget) -> f32 {
39 if target.build == Architecture::NvSass {
40 return 100.0;
41 }
42 0.0
43 }
44
45 fn generate(&self, program: &GaiaModule, _config: &GaiaConfig) -> Result<GeneratedFiles> {
46 let mut files = HashMap::new();
47
48 let sass_program = convert_gaia_to_sass(program)?;
50
51 let binary = self.writer.write(&sass_program)?;
53 files.insert(format!("{}.cubin", program.name), binary);
54
55 Ok(GeneratedFiles { files, diagnostics: vec![] })
56 }
57}
58
59fn convert_gaia_to_sass(module: &GaiaModule) -> Result<SassProgram> {
60 let mut sass_program = SassProgram::new(module.name.clone());
61
62 for function in &module.functions {
63 let mut instructions = Vec::new();
64 for block in &function.blocks {
65 for inst in &block.instructions {
66 match inst {
67 GaiaInstruction::Domain(DomainInstruction::Neural(node)) => {
68 if let NeuralNode::Convolution { .. } = node {
70 instructions.push(SassInstruction::Imma {
71 dst: SassReg::R(0),
72 src0: SassReg::R(1),
73 src1: SassReg::R(2),
74 src2: SassReg::R(0),
75 });
76 }
77 }
78 _ => {
79 instructions.push(SassInstruction::Nop);
80 }
81 }
82 }
83
84 match &block.terminator {
85 GaiaTerminator::Return => {
86 instructions.push(SassInstruction::Exit);
87 }
88 _ => {}
89 }
90 }
91
92 sass_program.kernels.push(SassKernel { name: function.name.clone(), instructions });
93 }
94
95 Ok(sass_program)
96}