1use std::fs::{self, File};
2use std::io::Write;
3use std::path::{Path, PathBuf};
4use std::{env, fmt};
5
6use anyhow::{Error, Result};
7use xmas_elf::sections::{SectionData, ShType};
8use xmas_elf::symbol_table::{Binding, Visibility};
9use xmas_elf::{symbol_table, ElfFile};
10
/// Name of the environment variable through which [`Symgen::run`] publishes the path of
/// the generated symbols file (it is emitted as a `cargo:rustc-env` directive).
pub const VAR_SYMBOLS_FILE: &str = "EMBUILD_GENERATED_SYMBOLS_FILE";
12
/// A single entry of an ELF symbol table, reduced to the properties that decide whether
/// (and how) a Rust pointer constant is generated for it.
///
/// The string slices borrow from the parsed ELF data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Symbol<'a> {
    /// Symbol name as read from the ELF file.
    name: &'a str,
    /// Name of the section the symbol belongs to; `None` when the section (or its name)
    /// could not be resolved.
    section_name: Option<&'a str>,
    /// `true` when the symbol has default ELF visibility.
    visible: bool,
    /// `true` when the symbol has global binding.
    global: bool,
}
20
/// Describes an ELF section whose symbols should be exported.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Section {
    /// Section name that a symbol's section name is compared against (e.g. `.bss`).
    pub name: String,
    /// Optional text prepended to the generated name of every symbol matched to this section.
    pub prefix: Option<String>,
    /// Marks the section as holding mutable data; set by [`Section::data`], cleared by
    /// [`Section::code`].
    pub mutable: bool,
}

impl Section {
    /// Creates a section description from its name, optional name prefix and mutability flag.
    pub fn new(name: impl Into<String>, prefix: Option<String>, mutable: bool) -> Self {
        let name = name.into();

        Self {
            name,
            prefix,
            mutable,
        }
    }

    /// Creates an un-prefixed, immutable section description.
    pub fn code(name: impl Into<String>) -> Self {
        Self::new(name, None, false)
    }

    /// Creates an un-prefixed, mutable section description.
    pub fn data(name: impl Into<String>) -> Self {
        Self::new(name, None, true)
    }
}
45
46impl<'a> Symbol<'a> {
47    pub fn name(&self) -> &'a str {
49        self.name
50    }
51
52    pub fn section_name(&self) -> Option<&'a str> {
54        self.section_name
55    }
56
57    pub fn visible(&self) -> bool {
59        self.visible
60    }
61
62    pub fn global(&self) -> bool {
64        self.global
65    }
66
67    pub fn default_pointer_gen(&self) -> Option<RustPointer> {
68        if self.section_name().is_some() && self.global() && self.visible() {
69            let valid_identifier = self.name().char_indices().all(|(index, ch)| {
70                ch == '_' || index == 0 && ch.is_alphabetic() || index > 0 && ch.is_alphanumeric()
71            });
72
73            if valid_identifier {
74                return Some(RustPointer {
75                    name: self.name().to_owned(),
76                    mutable: true,
77                    r#type: None,
78                });
79            }
80        }
81
82        None
83    }
84
85    pub fn default_sections(&self) -> Option<RustPointer> {
86        self.sections(&[Section::data(".bss"), Section::data(".data")])
87    }
88
89    pub fn sections<'b>(
90        &'b self,
91        sections: impl IntoIterator<Item = &'b Section>,
92    ) -> Option<RustPointer> {
93        self.default_pointer_gen().and_then(move |mut pointer| {
94            sections
95                .into_iter()
96                .find(|section| self.section_name() == Some(§ion.name))
97                .map(|section| {
98                    if let Some(section_prefix) = §ion.prefix {
99                        pointer.name = format!("{}{}", section_prefix, pointer.name);
100                    }
101
102                    pointer
103                })
104        })
105    }
106}
107
/// Describes the Rust pointer constant to generate for a symbol.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RustPointer {
    /// Identifier of the generated constant.
    pub name: String,
    /// Whether the constant is emitted as a `*mut` (`true`) or `*const` (`false`) pointer.
    pub mutable: bool,
    /// Pointee type written into the generated code; `core::ffi::c_void` is used when `None`.
    pub r#type: Option<String>,
}
114
115#[allow(clippy::type_complexity)]
116pub struct Symgen {
117    elf: PathBuf,
118    start_addr: u64,
119    rust_pointer_gen: Box<dyn for<'a> Fn(&Symbol<'a>) -> Option<RustPointer>>,
120}
121
122impl Symgen {
123    pub fn new(elf: impl Into<PathBuf>, start_addr: u64) -> Self {
124        Self::new_with_pointer_gen(elf, start_addr, |symbol| symbol.default_sections())
125    }
126
127    pub fn new_with_pointer_gen(
128        elf: impl Into<PathBuf>,
129        start_addr: u64,
130        rust_pointer_gen: impl for<'a> Fn(&Symbol<'a>) -> Option<RustPointer> + 'static,
131    ) -> Self {
132        Self {
133            elf: elf.into(),
134            start_addr,
135            rust_pointer_gen: Box::new(rust_pointer_gen),
136        }
137    }
138
139    pub fn run(&self) -> Result<PathBuf> {
140        let output_file = PathBuf::from(env::var("OUT_DIR")?).join("symbols.rs");
141
142        self.run_for_file(&output_file)?;
143
144        println!(
145            "cargo:rustc-env={}={}",
146            VAR_SYMBOLS_FILE,
147            output_file.display()
148        );
149
150        Ok(output_file)
151    }
152
153    pub fn run_for_file(&self, output_file: impl AsRef<Path>) -> Result<()> {
154        let output_file = output_file.as_ref();
155
156        eprintln!("Output: {output_file:?}");
157
158        self.write(&mut File::create(output_file)?)
159    }
160
161    pub fn write(&self, output: &mut impl Write) -> Result<()> {
162        eprintln!("Input: {:?}", self.elf);
163
164        let elf_data = fs::read(&self.elf)?;
165        let elf = ElfFile::new(&elf_data).map_err(Error::msg)?;
166
167        for symtable in self.get_symtables(&elf) {
168            match symtable.1 {
169                SectionData::SymbolTable32(entries) => {
170                    self.write_symbols(&elf, symtable.0, entries.iter().enumerate(), output)?
171                }
172                SectionData::SymbolTable64(entries) => {
173                    self.write_symbols(&elf, symtable.0, entries.iter().enumerate(), output)?
174                }
175                _ => unimplemented!(),
176            }
177        }
178
179        Ok(())
180    }
181
182    fn write_symbols<'a, W: Write>(
183        &self,
184        elf: &'a ElfFile<'a>,
185        symtable_index: usize,
186        symbols: impl Iterator<Item = (usize, &'a (impl symbol_table::Entry + fmt::Debug + 'a))>,
187        output: &mut W,
188    ) -> Result<()> {
189        for (_index, sym) in symbols {
190            eprintln!("Found symbol: {sym:?}");
191
192            let sym_type = sym.get_type().map_err(Error::msg)?;
193
194            if sym_type == symbol_table::Type::Object || sym_type == symbol_table::Type::NoType {
195                let name = sym.get_name(elf).map_err(Error::msg)?;
196
197                let section_name = sym
198                    .get_section_header(elf, symtable_index)
199                    .and_then(|sh| sh.get_name(elf))
200                    .ok();
201
202                let global = sym.get_binding().map_err(Error::msg)? == Binding::Global;
203                let visible = matches!(sym.get_other(), Visibility::Default);
204
205                let symbol = Symbol {
206                    name,
207                    section_name,
208                    global,
209                    visible,
210                };
211
212                let pointer = (self.rust_pointer_gen)(&symbol);
213
214                if let Some(pointer) = pointer {
215                    eprintln!("Writing symbol: {name} [{symbol:?}] as [{pointer:?}]");
216                    write!(
217                        output,
218                        "#[allow(dead_code, non_upper_case_globals)]\npub const {name}: *{mutable} {typ} = 0x{addr:x} as *{mutable} {typ};\n",
219                        name = pointer.name,
220                        mutable = if pointer.mutable { "mut" } else {"const" },
221                        typ = pointer.r#type.unwrap_or_else(|| "core::ffi::c_void".to_owned()),
222                        addr = self.start_addr + sym.value()
223                    )?;
224                } else {
225                    eprintln!("Skipping symbol: {name} [{sym:?}]");
226                }
227            }
228        }
229
230        Ok(())
231    }
232
233    fn get_symtables<'a, 'b>(
234        &self,
235        elf: &'b ElfFile<'a>,
236    ) -> impl Iterator<Item = (usize, SectionData<'a>)> + 'b {
237        elf.section_iter()
238            .enumerate()
239            .filter(|(_, header)| header.get_type().map_err(Error::msg).unwrap() == ShType::SymTab)
240            .map(move |(index, header)| (index, header.get_data(elf).unwrap()))
241    }
242}