Skip to main content

gaia_assembler/backends/
sass.rs

1use crate::{
2    backends::{Backend, GeneratedFiles},
3    config::GaiaConfig,
4    instruction::{DomainInstruction, GaiaInstruction},
5    program::{GaiaModule, GaiaTerminator},
6};
7use gaia_types::{
8    helpers::{AbiCompatible, ApiCompatible, Architecture, CompilationTarget},
9    neural::NeuralNode,
10    Result,
11};
12use sass_assembler::{
13    instructions::{SassInstruction, SassReg},
14    program::{SassKernel, SassProgram},
15    SassWriter,
16};
17use std::collections::HashMap;
18
19pub struct SassBackend {
20    writer: SassWriter,
21}
22
23impl SassBackend {
24    pub fn new() -> Self {
25        Self { writer: SassWriter::new() }
26    }
27}
28
29impl Backend for SassBackend {
30    fn name(&self) -> &'static str {
31        "NVIDIA SASS"
32    }
33
34    fn primary_target(&self) -> CompilationTarget {
35        CompilationTarget { build: Architecture::NvSass, host: AbiCompatible::PTX, target: ApiCompatible::Unknown }
36    }
37
38    fn match_score(&self, target: &CompilationTarget) -> f32 {
39        if target.build == Architecture::NvSass {
40            return 100.0;
41        }
42        0.0
43    }
44
45    fn generate(&self, program: &GaiaModule, _config: &GaiaConfig) -> Result<GeneratedFiles> {
46        let mut files = HashMap::new();
47
48        // 转换 GaiaModule -> SassProgram
49        let sass_program = convert_gaia_to_sass(program)?;
50
51        // 写入二进制
52        let binary = self.writer.write(&sass_program)?;
53        files.insert(format!("{}.cubin", program.name), binary);
54
55        Ok(GeneratedFiles { files, diagnostics: vec![] })
56    }
57}
58
59fn convert_gaia_to_sass(module: &GaiaModule) -> Result<SassProgram> {
60    let mut sass_program = SassProgram::new(module.name.clone());
61
62    for function in &module.functions {
63        let mut instructions = Vec::new();
64        for block in &function.blocks {
65            for inst in &block.instructions {
66                match inst {
67                    GaiaInstruction::Domain(DomainInstruction::Neural(node)) => {
68                        // 映射神经网络算子到 SASS 指令
69                        if let NeuralNode::Convolution { .. } = node {
70                            instructions.push(SassInstruction::Imma {
71                                dst: SassReg::R(0),
72                                src0: SassReg::R(1),
73                                src1: SassReg::R(2),
74                                src2: SassReg::R(0),
75                            });
76                        }
77                    }
78                    _ => {
79                        instructions.push(SassInstruction::Nop);
80                    }
81                }
82            }
83
84            match &block.terminator {
85                GaiaTerminator::Return => {
86                    instructions.push(SassInstruction::Exit);
87                }
88                _ => {}
89            }
90        }
91
92        sass_program.kernels.push(SassKernel { name: function.name.clone(), instructions });
93    }
94
95    Ok(sass_program)
96}