import os
import argparse
import itertools
import subprocess
import multiprocessing
import xml.etree.ElementTree as ET
from fnmatch import fnmatch
# Fixed head of the generated crate's root module: crate-level docs,
# #![no_std], the optional cortex_m / cortex_m_rt extern crates and the
# register-type re-exports. NOTE(review): the consumer is outside this chunk;
# presumably written out as src/lib.rs by Crate.to_files — confirm.
CRATE_LIB_PREAMBLE = """\
// Copyright 2018 Adam Greig
// See LICENSE-APACHE and LICENSE-MIT for license details.
//! This project provides a register access layer (RAL) for all
//! STM32 microcontrollers.
//!
//! When built, you must specify a device feature, such as `stm32f405`.
//! This will cause all modules in that device's module to be re-exported
//! from the top level, so that for example `stm32ral::gpio` will resolve to
//! `stm32ral::stm32f4::stm32f405::gpio`.
//!
//! In the generated documentation, all devices are visible inside their family
//! modules, but when built for a specific device, only that devices' constants
//! will be available.
//!
//! See the
//! [README](https://github.com/adamgreig/stm32ral/blob/master/README.md)
//! for example usage.
#![no_std]
#[cfg(not(feature="nosync"))]
extern crate cortex_m as external_cortex_m;
#[cfg(feature="rt")]
extern crate cortex_m_rt;
#[macro_use]
mod register;
#[cfg(feature="rt")]
pub use cortex_m_rt::interrupt;
pub use register::{RORegister, WORegister, RWRegister};
pub use register::{UnsafeRORegister, UnsafeRWRegister, UnsafeWORegister};
"""
# Fixed head of the generated Cargo.toml: package metadata, dependencies and
# the non-device features (rt, inline-asm, nosync, doc). NOTE(review):
# presumably the per-device feature entries are appended after this by Crate;
# confirm.
CRATE_CARGO_TOML_PREAMBLE = """\
[package]
name = "stm32ral"
version = "0.1.0"
authors = ["Adam Greig <adam@adamgreig.com>"]
description = "Register access layer for all STM32 microcontrollers"
repository = "https://github.com/adamgreig/stm32ral"
documentation = "https://docs.rs/stm32ral"
readme = "README.md"
keywords = ["stm32", "embedded", "no_std"]
categories = ["embedded", "no-std"]
license = "MIT/Apache-2.0"
[package.metadata.docs.rs]
features = ["doc"]
no-default-features = true
[dependencies]
bare-metal = "0.2.4"
cortex-m = "0.5.8"
[dependencies.cortex-m-rt]
optional = true
version = "0.6.5"
[features]
rt = ["cortex-m-rt/device"]
inline-asm = ["cortex-m/inline-asm"]
default = []
nosync = []
doc = []
"""
# Template for the generated build.rs. Crate.write_build_script fills it via
# str.format(device_clauses=...), so literal Rust braces are doubled here.
# {device_clauses} becomes an if/else chain that picks the enabled device's
# device.x. That file is copied into OUT_DIR (and added to the linker search
# path) only when the "rt" feature is enabled.
BUILD_RS_TEMPLATE = """\
use std::env;
use std::fs;
use std::path::PathBuf;
fn main() {{
if env::var_os("CARGO_FEATURE_RT").is_some() {{
let out = &PathBuf::from(env::var_os("OUT_DIR").unwrap());
println!("cargo:rustc-link-search={{}}", out.display());
let device_file = {device_clauses};
fs::copy(device_file, out.join("device.x")).unwrap();
println!("cargo:rerun-if-changed={{}}", device_file);
}}
println!("cargo:rerun-if-changed=build.rs");
}}
"""
# fnmatch patterns ('?' matches one character, '*' any run) that
# Register.to_regtype() checks against register names. A matching register
# gets the Unsafe* variant of its register type. NOTE(review): the names look
# like DMA address and cache-maintenance registers, where an arbitrary write
# could break memory safety — presumably the reason; confirm.
UNSAFE_REGISTERS = [
"S?PAR", "S?M?AR", "CPAR?", "CMAR?",
"FGMAR", "BGMAR", "FGCMAR", "BGCMAR", "OMAR",
"L?CFBAR",
"DIEPDMA*", "DOEPDMA*", "HCDMA*",
"DMARDLAR", "DMATDLAR",
"ICIALLU", "?C?MVA?", "DC?SW", "DCCIMVAC", "DCCISW", "BPIALL",
]
class Node:
    """Marker base class shared by the SVD model classes in this module.

    It defines no attributes or methods of its own; it only groups the
    model classes under one common type.
    """
class EnumeratedValue(Node):
    """One named value a register field can take (SVD ``<enumeratedValue>``)."""

    def __init__(self, name, desc, value):
        self.name, self.desc, self.value = name, desc, value

    def to_dict(self):
        """Return a plain-dict representation of this value."""
        return dict(name=self.name, desc=self.desc, value=self.value)

    def to_rust(self, field_width):
        """Render this value as a documented Rust ``u32`` constant.

        The value is written in binary, zero-padded to ``field_width`` digits,
        both in the doc comment and in the constant itself.
        """
        bits = format(self.value, f"0{field_width}b")
        doc = escape_desc(self.desc)
        return (f"\n/// 0b{bits}: {doc}\n"
                f"pub const {self.name}: u32 = 0b{bits};")

    @classmethod
    def from_svd(cls, svd, node):
        """Build from an SVD ``<enumeratedValue>`` element (``svd`` is unused)."""
        return cls(get_string(node, 'name'),
                   get_string(node, 'description'),
                   get_int(node, 'value'))

    def __eq__(self, other):
        return (self.name == other.name
                and self.value == other.value
                and self.desc == other.desc)

    def __lt__(self, other):
        # Sorting orders values numerically.
        return self.value < other.value
class EnumeratedValues(Node):
    """The enumerated values of a field for one access mode.

    ``name`` is "R", "W" or "RW"; it names the generated Rust sub-module.
    ``values`` holds EnumeratedValue objects.
    """

    def __init__(self, name):
        self.name = name
        self.values = []

    def to_dict(self):
        """Return a plain-dict representation, recursing into the values."""
        serialised = [ev.to_dict() for ev in self.values]
        return {"name": self.name, "values": serialised}

    def to_rust(self, field_width):
        """Render the values as a documented Rust sub-module.

        The doc line notes "(empty)" when there are no values.
        """
        body = "\n".join([ev.to_rust(field_width) for ev in self.values])
        desc = {"R": "Read-only values",
                "W": "Write-only values"}.get(self.name, "Read-write values")
        if body == "":
            desc = desc + " (empty)"
        return f"/// {desc}\npub mod {self.name} {{\n{body}\n}}"

    @classmethod
    def from_svd(cls, svd, node):
        """Build from an SVD ``<enumeratedValues>`` element.

        ``<usage>`` "read" maps to "R", "write" to "W", anything else
        (including absent) to "RW".
        """
        usage = get_string(node, 'usage')
        evs = cls("R" if usage == "read" else "W" if usage == "write" else "RW")
        evs.values.extend(EnumeratedValue.from_svd(svd, child)
                          for child in node.findall('enumeratedValue'))
        return evs

    @classmethod
    def empty(cls, name):
        """Return a value set called ``name`` with no values."""
        return cls(name)

    def __eq__(self, other):
        # Order-insensitive: both sides are sorted before comparing.
        if self.name != other.name or len(self.values) != len(other.values):
            return False
        pairs = zip(sorted(self.values), sorted(other.values))
        return all(mine == theirs for mine, theirs in pairs)
class EnumeratedValuesLink(Node):
    """Stand-in for an EnumeratedValues set that a sibling field already has.

    Instead of emitting the values again, the generated Rust re-exports the
    sibling field's module. Comparison, ``name`` and ``values`` delegate to
    the linked set.
    """

    def __init__(self, field, evs):
        self.field = field  # the sibling field whose module owns the values
        self.evs = evs      # the shared EnumeratedValues set

    def to_dict(self):
        return {"field": self.field.name, "evs": self.evs.name}

    def to_rust(self, field_width):
        """Return a Rust ``use`` re-exporting the linked values module.

        The line is emitted inside the current field's module, so ``super``
        is the register module that contains the sibling field. ``super``
        must start the path; rustc rejects a leading ``::`` before it.

        ``field_width`` is unused. It is accepted so the link can stand in
        wherever an EnumeratedValues is rendered (Field.to_rust passes it).
        """
        return f"pub use super::{self.field.name}::{self.evs.name};"

    def __eq__(self, other):
        return self.evs.__eq__(other)

    @property
    def name(self):
        """Name ("R", "W" or "RW") of the linked set."""
        return self.evs.name

    @property
    def values(self):
        """Values of the linked set."""
        return self.evs.values
class Field(Node):
    """A bit field within a register.

    ``r``, ``w`` and ``rw`` hold the read-only, write-only and read-write
    enumerated values (EnumeratedValues, possibly empty, or an
    EnumeratedValuesLink once Register.refactor_common_field_values has run).
    """

    def __init__(self, name, desc, width, offset, access, r, w, rw):
        self.name = name
        self.desc = desc
        self.width = width    # number of bits
        self.offset = offset  # shift applied to the mask in the output
        self.access = access
        self.r = r
        self.w = w
        self.rw = rw

    def to_dict(self):
        """Return a plain-dict representation, recursing into r/w/rw."""
        result = {"name": self.name, "desc": self.desc, "width": self.width,
                  "offset": self.offset, "access": self.access}
        result["r"] = self.r.to_dict()
        result["w"] = self.w.to_dict()
        result["rw"] = self.rw.to_dict()
        return result

    def to_rust(self):
        """Render this field as a Rust module.

        The module holds the offset and mask constants plus the R/W/RW
        value sub-modules. A 1-bit mask is written as ``1``, widths below 6
        in binary and wider masks in hex.
        """
        width = self.width
        ones = 2 ** width - 1
        if width == 1:
            mask_text = "1"
        elif width < 6:
            mask_text = "0b" + format(ones, "b")
        else:
            mask_text = "0x" + format(ones, "x")
        unit = "bits" if width > 1 else "bit"
        doc = escape_desc(self.desc)
        r_mod = self.r.to_rust(width)
        w_mod = self.w.to_rust(width)
        rw_mod = self.rw.to_rust(width)
        return f"""
/// {doc}
pub mod {self.name} {{
/// Offset ({self.offset} bits)
pub const offset: u32 = {self.offset};
/// Mask ({width} {unit}: {mask_text} << {self.offset})
pub const mask: u32 = {mask_text} << offset;
{r_mod}
{w_mod}
{rw_mod}
}}"""

    @classmethod
    def from_svd(cls, svd, node, ctx):
        """Build a Field from an SVD ``<field>`` element.

        ``ctx`` (a RegisterCtx) supplies the access inherited from enclosing
        elements. ``svd`` is searched to resolve ``derivedFrom`` value sets.

        Raises:
            ValueError: if a ``derivedFrom`` reference cannot be resolved.
        """
        ctx = ctx.inherit(node)
        name = get_string(node, 'name')
        desc = get_string(node, 'description')
        width = get_int(node, 'bitWidth')
        offset = get_int(node, 'bitOffset')
        # Every field carries all three sets; the ones the SVD omits stay empty.
        sets = {"R": EnumeratedValues.empty("R"),
                "W": EnumeratedValues.empty("W"),
                "RW": EnumeratedValues.empty("RW")}
        for evs_node in node.findall('enumeratedValues'):
            if 'derivedFrom' in evs_node.attrib:
                df = evs_node.attrib['derivedFrom']
                evs_node = svd.find(f".//enumeratedValues[name='{df}']")
                if evs_node is None:
                    raise ValueError(f"Can't find derivedFrom {df}")
            evs = EnumeratedValues.from_svd(svd, evs_node)
            key = evs.name if evs.name in ("R", "W") else "RW"
            sets[key] = evs
        return cls(name, desc, width, offset, ctx.access,
                   sets["R"], sets["W"], sets["RW"])

    def __eq__(self, other):
        # desc is not compared: fields differing only in description are equal.
        return (self.name == other.name
                and self.width == other.width
                and self.offset == other.offset
                and self.access == other.access
                and self.r == other.r
                and self.w == other.w
                and self.rw == other.rw)

    def __lt__(self, other):
        # Sort by bit offset, then by name.
        return (self.offset, self.name) < (other.offset, other.name)
class FieldLink(Node):
    """Stand-in for a field identical to one in another register.

    The generated Rust re-exports the original field module
    (``pub use <path>::<name>;``). Descriptive attributes are read through
    from ``parent``.
    """

    def __init__(self, parent, path):
        self.parent = parent
        self.path = path  # Rust path of the register module owning parent
        # Plain attributes rather than pass-throughs, so they can be
        # reassigned (Register.refactor_common_field_values uses setattr).
        self.r, self.w, self.rw = parent.r, parent.w, parent.rw

    def to_dict(self):
        return {"parent": self.parent.name, "path": self.path}

    def to_rust(self):
        """Return a Rust ``use`` re-exporting the parent field's module."""
        return "pub use {}::{};".format(self.path, self.parent.name)

    def __lt__(self, other):
        return self.parent.__lt__(other)

    def __eq__(self, other):
        return self.parent.__eq__(other)

    @property
    def name(self):
        """Name of the linked field."""
        return self.parent.name

    @property
    def desc(self):
        """Description of the linked field."""
        return self.parent.desc

    @property
    def width(self):
        """Bit width of the linked field."""
        return self.parent.width

    @property
    def offset(self):
        """Bit offset of the linked field."""
        return self.parent.offset

    @property
    def access(self):
        """Access string of the linked field."""
        return self.parent.access
class RegisterCtx:
    """Register defaults that cascade down the SVD tree.

    size, access, resetValue and resetMask may be set at device, peripheral,
    register or field level. Each level keeps the values it inherits unless
    it sets its own.
    """

    def __init__(self, size, access, reset_value, reset_mask):
        self.size = size
        self.access = access
        self.reset_value = reset_value
        self.reset_mask = reset_mask

    @classmethod
    def empty(cls):
        """Return a context with every default unset (None)."""
        return cls(size=None, access=None, reset_value=None, reset_mask=None)

    def copy(self):
        """Return an independent RegisterCtx with the same values."""
        return RegisterCtx(size=self.size, access=self.access,
                           reset_value=self.reset_value,
                           reset_mask=self.reset_mask)

    def update_from_node(self, node):
        """Override defaults with the values present on ``node``; return self."""
        found = {
            "size": get_int(node, 'size'),
            "access": get_string(node, 'access'),
            "reset_value": get_int(node, 'resetValue'),
            "reset_mask": get_int(node, 'resetMask'),
        }
        for attr, value in found.items():
            if value is not None:
                setattr(self, attr, value)
        return self

    def inherit(self, node):
        """Return a new context: a copy of this one, updated from ``node``."""
        child = self.copy()
        return child.update_from_node(node)
class Register(Node):
    """A peripheral register: offset, size, access, reset value/mask, fields.

    ``fields`` holds Field objects, or FieldLink stand-ins once
    PeripheralPrototype.refactor_common_register_fields has run.
    """

    def __init__(self, name, desc, offset, size, access, reset_value,
                 reset_mask):
        self.name = name
        self.desc = desc
        self.offset = offset  # byte offset within the peripheral
        self.size = size      # width in bits (emitted as u<size>)
        self.access = access  # SVD access string, e.g. "read-write"
        self.reset_value = reset_value
        self.reset_mask = reset_mask
        self.fields = []

    def to_dict(self):
        return {"name": self.name, "desc": self.desc, "offset": self.offset,
                "size": self.size, "access": self.access,
                "reset_value": self.reset_value, "reset_mask": self.reset_mask,
                "fields": [x.to_dict() for x in self.fields]}

    def to_rust_mod(self):
        """Return this register's Rust module: one sub-module per field."""
        fields = "\n".join(f.to_rust() for f in self.fields)
        return f"""
/// {escape_desc(self.desc)}
pub mod {self.name} {{
{fields}
}}"""

    def to_regtype(self):
        """Return the Rust register type name for this register.

        The type follows ``self.access``. Registers whose name matches an
        UNSAFE_REGISTERS pattern get the ``Unsafe`` variant.

        Raises:
            KeyError: if ``self.access`` is not a CMSIS-SVD access string.
        """
        regtype = {
            "read-only": "RORegister",
            "write-only": "WORegister",
            "read-write": "RWRegister",
            # The remaining legal CMSIS-SVD access values. The write-once
            # restriction cannot be expressed, so map them to the nearest
            # plain type instead of failing with KeyError.
            "writeOnce": "WORegister",
            "read-writeOnce": "RWRegister",
        }[self.access]
        if any(fnmatch(self.name, pattern) for pattern in UNSAFE_REGISTERS):
            regtype = "Unsafe" + regtype
        return regtype

    def to_rust_struct_entry(self):
        """Return this register's member declaration for ``RegisterBlock``."""
        regtype = self.to_regtype()
        return f"""
/// {escape_desc(self.desc)}
pub {self.name}: {regtype}<u{self.size}>,
"""

    @classmethod
    def from_svd(cls, svd, node, ctx):
        """Build a Register and its fields from an SVD ``<register>`` element.

        ``ctx`` is the RegisterCtx inherited from enclosing elements; values
        set on ``node`` override it.
        """
        ctx = ctx.inherit(node)
        name = get_string(node, 'name')
        desc = get_string(node, 'description')
        offset = get_int(node, 'addressOffset')
        register = cls(name, desc, offset, ctx.size, ctx.access,
                       ctx.reset_value, ctx.reset_mask)
        fields = node.find('fields')
        if fields is not None:
            for field in fields.findall('field'):
                register.fields.append(Field.from_svd(svd, field, ctx))
        if register.access is None:
            # No access inherited: derive it from the fields. A register with
            # no fields satisfies the first all() and becomes read-only.
            field_accesses = [f.access for f in register.fields]
            if all(access == "read-only" for access in field_accesses):
                register.access = "read-only"
            elif all(access == "write-only" for access in field_accesses):
                register.access = "write-only"
            else:
                register.access = "read-write"
        return register

    def __eq__(self, other):
        # desc and reset values are not compared; field order is ignored.
        return (
            self.name == other.name and
            self.offset == other.offset and
            self.size == other.size and
            self.access == other.access and
            sorted(self.fields) == sorted(other.fields)
        )

    def __lt__(self, other):
        return (self.offset, self.name) < (other.offset, other.name)

    def refactor_common_field_values(self):
        """Share identical, non-empty enumerated-value sets between fields.

        When two fields have equal R, W or RW sets, the later field's set is
        replaced by an EnumeratedValuesLink to the earlier field. A field
        that has been linked is not itself used as a link target.
        """
        replace = []
        to_replace = set()
        fields = enumerate(self.fields)
        for (idx1, f1), (idx2, f2) in itertools.combinations(fields, 2):
            if f1 is f2 or idx1 in to_replace or idx2 in to_replace:
                continue
            if f1.r == f2.r and f1.r.values:
                replace.append((idx1, idx2, 'r'))
                to_replace.add(idx2)
            if f1.w == f2.w and f1.w.values:
                replace.append((idx1, idx2, 'w'))
                to_replace.add(idx2)
            if f1.rw == f2.rw and f1.rw.values:
                replace.append((idx1, idx2, 'rw'))
                to_replace.add(idx2)
        for idx1, idx2, name in replace:
            f1 = self.fields[idx1]
            evs1 = getattr(f1, name)
            f2 = EnumeratedValuesLink(f1, evs1)
            setattr(self.fields[idx2], name, f2)

    def consume(self, other, parent):
        """Merge ``other`` (a register at the same offset) into this one.

        Fields not already present by name are added, the descriptions are
        combined, size becomes the larger of the two and access becomes
        read-write. The name is changed to common_name() of both, unless
        another register in ``parent`` already uses that name.
        """
        my_field_names = set(f.name for f in self.fields)
        for field in other.fields:
            if field.name not in my_field_names:
                self.fields.append(field)
        self.desc = "\n/// ".join([
            f"{self.name} and {other.name}",
            f"{self.name}: {escape_desc(self.desc)}",
            f"{other.name}: {escape_desc(other.desc)}",
        ])
        self.size = max(self.size, other.size)
        self.access = "read-write"
        newname = common_name(self.name, other.name, parent.name)
        if newname != self.name[:len(newname)]:
            print(f"Warning [{parent.name}]: {self.name}+{other.name} "
                  f"-> {newname}: suspected name compaction failure")
        if newname != self.name:
            if newname not in [r.name for r in parent.registers]:
                self.name = newname
            else:
                print(f"Warning [{parent.name}]: {self.name} + {other.name} "
                      f"-> {newname}: name already exists, using {self.name}")
class PeripheralInstance(Node):
    """A concrete instance of a peripheral.

    Holds the instance's name, base address and reset values.
    ``reset_values`` maps register offset -> reset value.
    """

    def __init__(self, name, addr, reset_values):
        self.name = name
        self.addr = addr
        self.reset_values = reset_values

    def to_dict(self):
        """Return a plain-dict representation of this instance."""
        return dict(name=self.name, addr=self.addr,
                    reset_values=self.reset_values)

    def to_rust(self, registers):
        """Return the Rust module for this instance.

        ``registers`` supplies the register names. Each reset value is
        labelled with the name of the register at its offset (KeyError if
        there is none).

        The module contains:
        - the ``reset`` constant;
        - take()/release() functions guarded by a ``<NAME>_TAKEN`` flag,
          compiled out under the ``nosync`` feature;
        - a raw ``*const RegisterBlock`` constant.
        """
        name_by_offset = {}
        for reg in registers:
            name_by_offset[reg.offset] = reg.name
        resets = ", ".join(
            "{}: 0x{:08X}".format(name_by_offset[offset], value)
            for offset, value in self.reset_values.items())
        return f"""
/// Access functions for the {self.name} peripheral instance
pub mod {self.name} {{
#[cfg(not(feature="nosync"))]
use external_cortex_m;
use super::ResetValues;
#[cfg(not(feature="nosync"))]
use super::Instance;
#[cfg(not(feature="nosync"))]
const INSTANCE: Instance = Instance {{
addr: 0x{self.addr:08x},
_marker: ::core::marker::PhantomData,
}};
/// Reset values for each field in {self.name}
pub const reset: ResetValues = ResetValues {{
{resets}
}};
#[cfg(not(feature="nosync"))]
#[allow(renamed_and_removed_lints)]
#[allow(private_no_mangle_statics)]
#[no_mangle]
static mut {self.name}_TAKEN: bool = false;
/// Safe access to {self.name}
///
/// This function returns `Some(Instance)` if this instance is not
/// currently taken, and `None` if it is. This ensures that if you
/// do get `Some(Instance)`, you are ensured unique access to
/// the peripheral and there cannot be data races (unless other
/// code uses `unsafe`, of course). You can then pass the
/// `Instance` around to other functions as required. When you're
/// done with it, you can call `release(instance)` to return it.
///
/// `Instance` itself dereferences to a `RegisterBlock`, which
/// provides access to the peripheral's registers.
#[cfg(not(feature="nosync"))]
#[inline]
pub fn take() -> Option<Instance> {{
external_cortex_m::interrupt::free(|_| unsafe {{
if {self.name}_TAKEN {{
None
}} else {{
{self.name}_TAKEN = true;
Some(INSTANCE)
}}
}})
}}
/// Release exclusive access to {self.name}
///
/// This function allows you to return an `Instance` so that it
/// is available to `take()` again. This function will panic if
/// you return a different `Instance` or if this instance is not
/// already taken.
#[cfg(not(feature="nosync"))]
#[inline]
pub fn release(inst: Instance) {{
external_cortex_m::interrupt::free(|_| unsafe {{
if {self.name}_TAKEN && inst.addr == INSTANCE.addr {{
{self.name}_TAKEN = false;
}} else {{
panic!("Released a peripheral which was not taken");
}}
}});
}}
}}
/// Raw pointer to {self.name}
///
/// Dereferencing this is unsafe because you are not ensured unique
/// access to the peripheral, so you may encounter data races with
/// other users of this peripheral. It is up to you to ensure you
/// will not cause data races.
///
/// This constant is provided for ease of use in unsafe code: you can
/// simply call for example `write_reg!(gpio, GPIOA, ODR, 1);`.
pub const {self.name}: *const RegisterBlock =
0x{self.addr:08x} as *const _;"""

    def __lt__(self, other):
        return self.name < other.name

    def __eq__(self, other):
        if self.name != other.name or self.addr != other.addr:
            return False
        return self.reset_values == other.reset_values
class PeripheralPrototype(Node):
    """A peripheral definition: its registers plus every instance of it.

    Rendered as one Rust module file containing:
    - one sub-module per register;
    - the ``RegisterBlock`` and ``ResetValues`` structs;
    - the ``Instance`` type;
    - one module per instance.
    """

    def __init__(self, name, desc):
        self.name = name.lower()
        self.desc = desc
        self.registers = []
        self.instances = []
        # Devices using this peripheral; listed in the module docs when there
        # is more than one.
        self.parent_device_names = []

    def to_dict(self):
        return {"name": self.name, "desc": self.desc,
                "registers": [x.to_dict() for x in self.registers],
                "instances": [x.to_dict() for x in self.instances]}

    def to_rust_register_block(self):
        """Return the ``RegisterBlock`` struct with registers laid out by offset.

        Gaps are padded with ``_reservedN`` arrays, using u32 elements first,
        then u16, then u8, so that each register lands at its byte offset.

        Raises:
            RuntimeError: if a register starts before the previous one ends.
                Registers sharing an offset must be merged first, e.g. by
                refactor_aliased_registers.
        """
        lines = []
        address = 0  # next free byte offset in the block
        reservedctr = 1
        for register in sorted(self.registers):
            if register.offset < address:
                raise RuntimeError("Unexpected register aliasing")
            if register.offset != address:
                gaps = []
                u32s = (register.offset - address) // 4
                if u32s != 0:
                    gaps.append(f"[u32; {u32s}]")
                    address += u32s * 4
                u16s = (register.offset - address) // 2
                if u16s != 0:
                    gaps.append(f"[u16; {u16s}]")
                    address += u16s * 2
                u8s = register.offset - address
                if u8s != 0:
                    gaps.append(f"[u8; {u8s}]")
                    address += u8s
                for gaptype in gaps:
                    lines.append(f"_reserved{reservedctr}: {gaptype},")
                    reservedctr += 1
            lines.append(register.to_rust_struct_entry())
            address += register.size // 8
        lines = "\n".join(lines)
        return f"""
pub struct RegisterBlock {{
{lines}
}}"""

    def to_rust_reset_values(self):
        """Return the ``ResetValues`` struct: one field per register."""
        lines = []
        for register in sorted(self.registers):
            lines.append(f"pub {register.name}: u{register.size},")
        lines = "\n".join(lines)
        return f"""
pub struct ResetValues {{
{lines}
}}"""

    def to_rust_instance(self):
        """Return the ``Instance`` type, which derefs to ``RegisterBlock``."""
        return """
#[cfg(not(feature="nosync"))]
pub struct Instance {
pub(crate) addr: u32,
pub(crate) _marker: PhantomData<*const RegisterBlock>,
}
#[cfg(not(feature="nosync"))]
impl ::core::ops::Deref for Instance {
type Target = RegisterBlock;
#[inline(always)]
fn deref(&self) -> &RegisterBlock {
unsafe { &*(self.addr as *const _) }
}
}
"""

    def to_rust_file(self, path):
        """Write ``<path>/<name>.rs`` for this peripheral and rustfmt it."""
        # Sorted because iterating a set of strings follows hash order, which
        # is randomised per process; an unsorted join made the generated
        # `use` line, and so the output files, differ from run to run.
        regtypes = sorted(set(r.to_regtype() for r in self.registers))
        regtypes = ", ".join(regtypes)
        if self.desc is None:
            # Debug aid: dump the peripheral that is missing a description.
            print(self.to_dict())
        desc = "\n//! ".join(escape_desc(self.desc).split("\n"))
        if len(self.parent_device_names) > 1:
            desc += "\n//!\n"
            desc += "//! Used by: " + ', '.join(
                sorted(set(self.parent_device_names)))
        preamble = "\n".join([
            "#![allow(non_snake_case, non_upper_case_globals)]",
            "#![allow(non_camel_case_types)]",
            f"//! {desc}",
            "",
            "#[cfg(not(feature=\"nosync\"))]",
            "use core::marker::PhantomData;",
            f"use {{{regtypes}}};",
            "",
        ])
        modules = "\n".join(r.to_rust_mod() for r in self.registers)
        instances = "\n".join(i.to_rust(self.registers)
                              for i in sorted(self.instances))
        fname = os.path.join(path, f"{self.name}.rs")
        with open(fname, "w") as f:
            f.write(preamble)
            f.write(modules)
            f.write(self.to_rust_register_block())
            f.write(self.to_rust_reset_values())
            f.write(self.to_rust_instance())
            f.write(instances)
        rustfmt(fname)

    def to_parent_entry(self):
        """Return the ``pub mod`` line declaring this module in its parent."""
        return f"pub mod {self.name};\n"

    @classmethod
    def from_svd(cls, svd, node, register_ctx):
        """Build a PeripheralPrototype with one instance from an SVD
        ``<peripheral>`` element.

        For a ``derivedFrom`` peripheral, the registers, the register defaults
        and the description (passed node's own as get_string's default) come
        from the referenced peripheral. The name comes from ``node``.

        Raises:
            ValueError: if ``derivedFrom`` cannot be resolved or no registers
                are found.
        """
        name = get_string(node, 'name')
        addr = get_int(node, 'baseAddress')
        desc = get_string(node, 'description')
        registers = node.find('registers')
        if 'derivedFrom' in node.attrib:
            df = node.attrib['derivedFrom']
            df_node = svd.find(f".//peripheral[name='{df}']")
            if df_node is None:
                raise ValueError(f"Can't find derivedFrom[{df}]")
            desc = get_string(df_node, 'description', default=desc)
            # NOTE(review): re-reads node's own baseAddress (with the earlier
            # value as fallback); looks redundant with the read above.
            addr = get_int(node, 'baseAddress', addr)
            registers = df_node.find('registers')
            register_ctx = register_ctx.inherit(df_node)
        register_ctx = register_ctx.inherit(node)
        peripheral = cls(name, desc)
        if registers is None:
            raise ValueError(f"No registers found for peripheral {name}")
        ctx = register_ctx
        for register in registers.findall('register'):
            peripheral.registers.append(Register.from_svd(svd, register, ctx))
        resets = {r.offset: r.reset_value for r in peripheral.registers}
        peripheral.instances.append(PeripheralInstance(name, addr, resets))
        return peripheral

    def consume(self, other, parent):
        """Absorb ``other``'s instances into this peripheral.

        The name becomes common_name() of both, unless another peripheral in
        ``parent`` already uses that name.
        """
        self.instances += other.instances
        newname = common_name(self.name, other.name, parent.name)
        if newname != self.name:
            if newname not in [p.name for p in parent.peripherals]:
                self.name = newname
            else:
                print(f"Warning [{parent.name}]: {self.name} + {other.name} "
                      f"-> {newname}: name already exists, using {self.name}")

    def refactor_common_register_fields(self):
        """Replace duplicated field lists with links to the first register.

        When two registers have equal, non-empty field lists, the later
        register's fields become FieldLinks to ``super::<first register>``.
        """
        replace = []
        to_replace = set()
        registers = enumerate(self.registers)
        for (idx1, r1), (idx2, r2) in itertools.combinations(registers, 2):
            if r1 is r2 or idx1 in to_replace or idx2 in to_replace:
                continue
            if r1.fields == r2.fields and r1.fields:
                replace.append((idx1, idx2))
                to_replace.add(idx2)
        for idx1, idx2 in replace:
            r1 = self.registers[idx1]
            r2 = self.registers[idx2]
            path = f"super::{r1.name}"
            r2.fields = [FieldLink(f, path) for f in r1.fields]

    def refactor_aliased_registers(self):
        """Merge registers that share an offset into the first one of them."""
        to_delete = set()
        registers = enumerate(self.registers)
        for (idx1, r1), (idx2, r2) in itertools.combinations(registers, 2):
            if r1 is r2 or idx1 in to_delete or idx2 in to_delete:
                continue
            if r1.offset == r2.offset:
                r1.consume(r2, parent=self)
                to_delete.add(idx2)
        # Delete from the end so the remaining indices stay valid.
        for idx in sorted(to_delete, reverse=True):
            del self.registers[idx]

    def __lt__(self, other):
        return self.name < other.name
class PeripheralPrototypeLink(Node):
    """A peripheral whose registers are identical to another peripheral's.

    Its generated module re-exports ``RegisterBlock``, ``ResetValues``,
    ``Instance`` and the register modules from ``path`` (the prototype's
    module), and defines only its own instances.
    """

    def __init__(self, name, prototype, path):
        self.name = name
        self.prototype = prototype
        self.path = path  # Rust path of the prototype's module
        self.instances = []
        self.parent_device_names = []

    def to_dict(self):
        return {"prototype": self.prototype.name, "path": self.path,
                "instances": [x.to_dict() for x in self.instances]}

    def to_rust_file(self, path):
        """Write ``<path>/<name>.rs`` for this link and rustfmt it."""
        # Escaped the same way as PeripheralPrototype.to_rust_file, so the
        # same description renders identically in both kinds of module.
        desc = "\n//! ".join(escape_desc(self.prototype.desc).split("\n"))
        if len(self.parent_device_names) > 1:
            desc += "\n//!\n"
            desc += "//! Used by: " + ', '.join(
                sorted(set(self.parent_device_names)))
        preamble = "\n".join([
            "#![allow(non_snake_case, non_upper_case_globals)]",
            "#![allow(non_camel_case_types)]",
            f"//! {desc}",
            "",
            f"pub use {self.path}::{{RegisterBlock, ResetValues}};",
            "#[cfg(not(feature = \"nosync\"))]",
            f"pub use {self.path}::{{Instance}};",
            "",
        ])
        registers = ", ".join(m.name for m in self.prototype.registers)
        registers = f"pub use {self.path}::{{{registers}}};\n"
        instances = "\n".join(i.to_rust(self.registers)
                              for i in sorted(self.instances))
        fname = os.path.join(path, f"{self.name}.rs")
        with open(fname, "w") as f:
            f.write(preamble)
            f.write(registers)
            f.write("\n")
            f.write(instances)
        rustfmt(fname)

    def to_parent_entry(self):
        """Return the ``pub mod`` line declaring this module in its parent."""
        return f"pub mod {self.name};\n"

    @classmethod
    def from_peripherals(cls, p1, p2, path):
        """Link ``p2`` to prototype ``p1``, keeping p2's name and instances."""
        plink = cls(p2.name, p1, path)
        plink.instances = p2.instances
        return plink

    @property
    def registers(self):
        """Registers of the linked prototype."""
        return self.prototype.registers

    @property
    def desc(self):
        """Description of the linked prototype."""
        return self.prototype.desc

    # The refactoring passes have nothing to do on a link.
    def refactor_common_register_fields(self):
        pass

    def refactor_common_instances(self):
        pass

    def refactor_aliased_registers(self):
        pass

    def __lt__(self, other):
        return self.name < other.name
class PeripheralSharedInstanceLink(Node):
    """Device-level alias for a peripheral module shared by several devices.

    The shared module is generated once under the family's ``instances``
    module; the device only re-exports it, optionally under a different name.
    """

    def __init__(self, name, usename, prototype):
        self.name = name          # name exposed in the device module
        self.usename = usename    # name of the shared module
        self.prototype = prototype

    def to_parent_entry(self):
        """Return the ``pub use`` line re-exporting the shared module."""
        entry = f"pub use super::instances::{self.usename}"
        if self.usename != self.name:
            entry += f" as {self.name}"
        return entry + ";\n"

    def to_rust_file(self, path):
        """Write nothing: the shared module is generated by the family."""

    @property
    def registers(self):
        """Registers of the shared peripheral."""
        return self.prototype.registers

    @property
    def desc(self):
        """Description of the shared peripheral."""
        return self.prototype.desc

    # The refactoring passes have nothing to do on a link.
    def refactor_common_register_fields(self):
        pass

    def refactor_common_instances(self):
        pass

    def refactor_aliased_registers(self):
        pass

    def __lt__(self, other):
        return self.name < other.name
class CPU(Node):
    """The device's CPU core name and its ``nvicPrioBits`` value."""

    def __init__(self, name, nvic_prio_bits):
        self.name = name
        self.nvic_prio_bits = nvic_prio_bits

    def get_architecture(self):
        """Return the ARM architecture of this core, or None if unknown."""
        architectures = {
            "CM0": "ARMv6-M",
            "CM0+": "ARMv6-M",
            "CM3": "ARMv7-M",
            "CM4": "ARMv7E-M",
            "CM7": "ARMv7E-M",
        }
        return architectures.get(self.name)

    def to_dict(self):
        return dict(name=self.name, nvic_prio_bits=self.nvic_prio_bits)

    @classmethod
    def from_svd(cls, svd, node):
        """Build from an SVD ``<cpu>`` element.

        ``nvic_prio_bits`` is kept as the raw element text (a string).
        """
        core_name = get_string(node, 'name')
        prio_bits_text = node.find('nvicPrioBits').text
        return cls(core_name, prio_bits_text)
class Interrupt(Node):
    """One interrupt: name, description and vector number (``value``)."""

    def __init__(self, name, desc, value):
        self.name, self.desc, self.value = name, desc, value

    def to_dict(self):
        return dict(name=self.name, desc=self.desc, value=self.value)

    @classmethod
    def from_svd(cls, svd, node):
        """Build from an SVD ``<interrupt>`` element (``svd`` is unused)."""
        return cls(get_string(node, 'name'),
                   get_string(node, 'description'),
                   get_int(node, 'value'))

    def __lt__(self, other):
        # Sorting orders interrupts by vector number.
        return self.value < other.value
class Device(Node):
    """A single device parsed from one SVD file.

    Holds the CPU description, the device's peripherals (prototypes, or
    links to shared ones after refactoring) and its deduplicated, sorted
    interrupts. Writes the per-device Rust module directory.
    """

    def __init__(self, name, cpu):
        # Lower-cased with '-' -> '_'; the result is also used as the Cargo
        # feature name for this device.
        self.name = name.lower().replace("-", "_")
        self.cpu = cpu
        self.peripherals = []
        self.interrupts = []
        # When True, no interrupts.rs, device.x or NVIC_PRIO_BITS is emitted.
        # NOTE(review): set outside this chunk — confirm which devices use it.
        self.special = False

    def to_dict(self):
        return {"name": self.name, "cpu": self.cpu.to_dict(),
                "peripherals": [x.to_dict() for x in self.peripherals],
                "interrupts": [x.to_dict() for x in self.interrupts]}

    def to_interrupt_file(self, familypath):
        """Write ``<familypath>/<device>/interrupts.rs`` and rustfmt it.

        The file contains:
        - extern declarations of the handlers;
        - the ``__INTERRUPTS`` vector table;
        - the ``Interrupt`` enum implementing ``bare_metal::Nr``.

        The device directory must already exist (to_files creates it).
        """
        devicepath = os.path.join(familypath, self.name)
        iname = os.path.join(devicepath, "interrupts.rs")
        with open(iname, "w") as f:
            f.write("extern crate bare_metal;\n")
            f.write('#[cfg(feature="rt")]\nextern "C" {\n')
            for interrupt in self.interrupts:
                f.write(f' fn {interrupt.name}();\n')
            f.write('}\n\n')
            vectors = []
            offset = 0
            # self.interrupts is deduplicated and sorted by value (from_svd),
            # so reserved entries fill the gaps and array index == value.
            for interrupt in self.interrupts:
                while interrupt.value != offset:
                    vectors.append("Vector { _reserved: 0 },")
                    offset += 1
                vectors.append(f"Vector {{ _handler: {interrupt.name} }},")
                offset += 1
            nvectors = len(vectors)
            vectors = "\n".join(vectors)
            f.write(f"""\
#[doc(hidden)]
pub union Vector {{
_handler: unsafe extern "C" fn(),
_reserved: u32,
}}
#[cfg(feature="rt")]
#[doc(hidden)]
#[link_section=".vector_table.interrupts"]
#[no_mangle]
pub static __INTERRUPTS: [Vector; {nvectors}] = [
{vectors}
];
/// Available interrupts for this device
#[repr(u8)]
#[derive(Clone,Copy)]
#[allow(non_camel_case_types)]
pub enum Interrupt {{""")
            for interrupt in self.interrupts:
                f.write(f"/// {interrupt.value}: ")
                f.write(f"{escape_desc(interrupt.desc)}\n")
                f.write(f"{interrupt.name} = {interrupt.value},\n")
            f.write("}\n")
            f.write("""\
unsafe impl bare_metal::Nr for Interrupt {
#[inline]
fn nr(&self) -> u8 {
*self as u8
}
}\n""")
        # Run after the file is closed so rustfmt sees the flushed contents.
        rustfmt(iname)

    def to_files(self, familypath):
        """Write the device directory ``<familypath>/<device>/``.

        Contents:
        - each peripheral's file (shared-instance links write nothing);
        - interrupts.rs and device.x, unless ``special``;
        - mod.rs, which declares everything.

        Prints a warning if two peripherals share a name.
        """
        devicepath = os.path.join(familypath, self.name)
        os.makedirs(devicepath, exist_ok=True)
        for peripheral in self.peripherals:
            peripheral.to_rust_file(devicepath)
        pnames = [p.name for p in self.peripherals]
        dupnames = set(name for name in pnames if pnames.count(name) > 1)
        if dupnames:
            print(f"Warning [{self.name}]: duplicate peripherals: ", end='')
            print(dupnames)
        if not self.special:
            self.to_interrupt_file(familypath)
        mname = os.path.join(devicepath, "mod.rs")
        with open(mname, "w") as f:
            f.write(f"//! stm32ral module for {self.name}\n\n")
            prio_bits = self.cpu.nvic_prio_bits
            if not self.special:
                f.write("/// Number of priority bits implemented by the NVIC")
                f.write(f"\npub const NVIC_PRIO_BITS: u8 = {prio_bits};\n\n")
                f.write("/// Interrupt-related magic for this device\n")
                f.write("pub mod interrupts;\n")
                f.write("pub use self::interrupts::Interrupt;\n")
                f.write("pub use self::interrupts::Interrupt as interrupt;\n\n")
            for peripheral in self.peripherals:
                f.write(peripheral.to_parent_entry())
        rustfmt(mname)
        if not self.special:
            # Linker script: every interrupt handler defaults to DefaultHandler.
            dname = os.path.join(devicepath, "device.x")
            with open(dname, "w") as f:
                for interrupt in self.interrupts:
                    f.write(f"PROVIDE({interrupt.name} = DefaultHandler);\n")

    @classmethod
    def from_svd(cls, svd):
        """Build a Device from a parsed SVD document.

        Interrupts are deduplicated by number (the first occurrence wins)
        and then sorted. Every peripheral is parsed and tagged with this
        device's name.
        """
        name = get_string(svd, 'name')
        cpu = CPU.from_svd(svd, svd.find('cpu'))
        device = cls(name, cpu)
        register_ctx = RegisterCtx.empty()
        register_ctx = register_ctx.inherit(svd)
        interrupt_nums = set()
        for interrupt in svd.findall('.//interrupt'):
            interrupt = Interrupt.from_svd(svd, interrupt)
            if interrupt.value in interrupt_nums:
                continue
            device.interrupts.append(interrupt)
            interrupt_nums.add(interrupt.value)
        device.interrupts.sort()
        for peripheral in svd.findall('.//peripheral'):
            device.peripherals.append(
                PeripheralPrototype.from_svd(svd, peripheral, register_ctx))
        for peripheral in device.peripherals:
            peripheral.parent_device_names.append(device.name)
        return device

    @classmethod
    def from_svdfile(cls, svdfile):
        """Parse the SVD file ``svdfile`` and build a Device from it."""
        svd = ET.parse(svdfile)
        return cls.from_svd(svd)

    def refactor_peripheral_instances(self):
        """Merge peripherals of this device whose registers compare equal.

        Non-timer duplicates are folded into the first match
        (PeripheralPrototype.consume takes their instances) and removed.
        Timers (names starting "tim") are instead replaced by a
        PeripheralPrototypeLink to the first match, so each keeps its own
        module. NOTE(review): the reason for treating timers differently is
        not visible here.
        """
        to_delete = set()
        to_link = set()
        links = []
        periphs = enumerate(self.peripherals)
        for (idx1, p1), (idx2, p2) in itertools.combinations(periphs, 2):
            if p1 is p2 or idx1 in to_delete or idx2 in to_delete:
                continue
            elif idx1 in to_link or idx2 in to_link:
                continue
            elif p1.registers == p2.registers:
                if p1.name.startswith("tim"):
                    links.append((idx1, idx2))
                    to_link.add(idx2)
                else:
                    p1.consume(p2, parent=self)
                    to_delete.add(idx2)
        # Links are installed before any deletion, so their indices are valid.
        for idx1, idx2 in links:
            p1 = self.peripherals[idx1]
            p2 = self.peripherals[idx2]
            path = f"super::{p1.name}"
            plink = PeripheralPrototypeLink.from_peripherals(p1, p2, path)
            self.peripherals[idx2] = plink
        # Delete from the end so the remaining indices stay valid.
        for idx in sorted(to_delete, reverse=True):
            del self.peripherals[idx]
class Family(Node):
def __init__(self, name):
self.name = name
self.devices = []
self.peripherals = []
self.instances = []
def to_dict(self):
return {"name": self.name,
"devices": [d.to_dict() for d in self.devices],
"peripherals": [p.to_dict() for p in self.peripherals],
"instances": [i.to_dict()
for i in self.instances]}
def to_files(self, path, pool):
familypath = os.path.join(path, self.name)
os.makedirs(familypath, exist_ok=True)
periphpath = os.path.join(familypath, "peripherals")
instpath = os.path.join(familypath, "instances")
os.makedirs(periphpath, exist_ok=True)
os.makedirs(instpath, exist_ok=True)
pool_results = []
with open(os.path.join(familypath, "mod.rs"), "w") as f:
uname = self.name.upper()
f.write(f"//! Parent module for all {uname} devices.\n\n")
f.write("/// Peripherals shared by multiple devices\n")
f.write('pub mod peripherals;\n\n')
f.write("/// Peripheral instances shared by multiple devices\n")
f.write("pub(crate) mod instances;\n\n")
for device in self.devices:
dname = device.name
result = pool.apply_async(device.to_files, (familypath,))
pool_results.append(result)
f.write(f'#[cfg(any(feature="{dname}", feature="doc"))]\n')
f.write(f'pub mod {dname};\n\n')
with open(os.path.join(periphpath, "mod.rs"), "w") as f:
for peripheral in self.peripherals:
r = pool.apply_async(peripheral.to_rust_file, (periphpath,))
pool_results.append(r)
features = ", ".join(
f'feature="{d}"' for d in peripheral.parent_device_names)
f.write(f'#[cfg(any(feature="doc", {features}))]\n')
f.write(f'pub mod {peripheral.name};\n\n')
with open(os.path.join(instpath, "mod.rs"), "w") as f:
for instance in self.instances:
r = pool.apply_async(instance.to_rust_file, (instpath,))
pool_results.append(r)
features = ", ".join(
f'feature="{d}"' for d in instance.parent_device_names)
f.write(f'#[cfg(any(feature="doc", {features}))]\n')
f.write(f'pub mod {instance.name};\n\n')
return pool_results
def _enumerate_peripherals(self):
peripherals = []
for didx, device in enumerate(self.devices):
for pidx, peripheral in enumerate(device.peripherals):
peripherals.append((didx, pidx, peripheral))
return peripherals
def _match_peripherals(self):
to_link = set()
links = dict()
peripherals = self._enumerate_peripherals()
for pt1, pt2 in itertools.combinations(peripherals, 2):
didx1, pidx1, p1 = pt1
didx2, pidx2, p2 = pt2
idx1 = (didx1, pidx1)
idx2 = (didx2, pidx2)
if p1 is p2 or idx1 in to_link or idx2 in to_link:
continue
elif p1.registers == p2.registers:
to_link.add(idx2)
if idx1 not in links:
links[idx1] = []
links[idx1].append(idx2)
return links
def refactor_common_peripherals(self):
links = self._match_peripherals()
pnames = set()
dupnames = set()
for idx in links:
didx, pidx = idx
p = self.devices[didx].peripherals[pidx]
if p.name in pnames:
dupnames.add(p.name)
pnames.add(p.name)
versions = {}
for idx in links:
didx, pidx = idx
device = self.devices[didx]
p = device.peripherals[pidx]
name = p.name
if name in dupnames:
if name not in versions:
versions[name] = 0
versions[name] += 1
name = f'{name}_v{versions[name]}'
familyp = PeripheralPrototype(name, p.desc)
familyp.registers = p.registers
familyp.parent_device_names.append(device.name)
self.peripherals.append(familyp)
path = f"{self.name}::peripherals::{name}"
linkp = PeripheralPrototypeLink(p.name, familyp, path)
linkp.instances = p.instances
self.devices[didx].peripherals[pidx] = linkp
for childidx in links[idx]:
cdidx, cpidx = childidx
childd = self.devices[cdidx]
childp = childd.peripherals[cpidx]
familyp.parent_device_names.append(childd.name)
linkp = PeripheralPrototypeLink(childp.name, familyp, path)
linkp.instances = childp.instances
childd.peripherals[cpidx] = linkp
self.refactor_common_instances(links)
def refactor_common_instances(self, links):
    """Share instance definitions between linked peripherals that agree on them.

    `links` is a mapping of the shape produced by `_match_peripherals`:
    primary (device_idx, periph_idx) -> list of (device_idx, periph_idx).
    When called from `refactor_common_peripherals`, the entries at those
    indices have already been replaced by link objects that carry an
    `instances` attribute.

    Within each link group, members whose `instances` compare equal are
    grouped under the earliest such member. For each resulting group:
      - the group primary's `name` is overwritten with a shared name;
      - the primary is appended to `self.instances`;
      - every member's slot in its device is replaced by one shared
        `PeripheralSharedInstanceLink` object.
    """
    to_group = set()
    groups = dict()
    for primary, children in links.items():
        members = [primary] + list(children)
        # Pairwise comparison within one link group. A member already
        # attached to a group as a child is not compared again.
        for l1, l2 in itertools.combinations(members, 2):
            didx1, pidx1 = l1
            didx2, pidx2 = l2
            p1 = self.devices[didx1].peripherals[pidx1]
            p2 = self.devices[didx2].peripherals[pidx2]
            if p1 is p2 or l1 in to_group or l2 in to_group:
                continue
            elif p1.instances == p2.instances:
                to_group.add(l2)
                if l1 not in groups:
                    groups[l1] = []
                groups[l1].append(l2)
    # Collect group-primary names that occur more than once. Those groups
    # get device-derived suffixes below so their shared names differ.
    pnames = set()
    dupnames = set()
    for (didx, pidx) in groups:
        p = self.devices[didx].peripherals[pidx]
        if p.name in pnames:
            dupnames.add(p.name)
        pnames.add(p.name)
    for idx in groups:
        didx, pidx = idx
        d = self.devices[didx]
        p = d.peripherals[pidx]
        name = p.name
        if name in dupnames:
            # Append each member device's name with its first five characters
            # dropped (presumably a common "stm32"-style prefix; TODO confirm).
            name += "_" + d.name[5:]
            for cidx in groups[idx]:
                cdidx, _ = cidx
                cd = self.devices[cdidx]
                name += "_" + cd.name[5:]
        # The link gets the peripheral's original name and the new shared
        # name. It is built before p.name is overwritten below.
        linkp = PeripheralSharedInstanceLink(p.name, name, p)
        self.devices[didx].peripherals[pidx] = linkp
        # groupp aliases p (no copy is made), so the object just passed to
        # PeripheralSharedInstanceLink is the one renamed and recorded here.
        groupp = p
        groupp.name = name
        groupp.parent_device_names.append(d.name)
        self.instances.append(groupp)
        for cidx in groups[idx]:
            cdidx, cpidx = cidx
            cd = self.devices[cdidx]
            groupp.parent_device_names.append(cd.name)
            # Every member of the group points at the same linkp object.
            cd.peripherals[cpidx] = linkp
class Crate:
    """Top-level model of the generated stm32ral crate.

    Holds the device families to emit and any peripherals shared between
    families. It writes the crate's src/lib.rs, Cargo.toml, build.rs and
    src/peripherals/mod.rs, and queues per-module files on a process pool.
    """
    def __init__(self):
        # Family objects; each is written as its own module under src/.
        self.families = []
        # Peripherals shared between families, written under src/peripherals/.
        self.peripherals = []

    def to_dict(self):
        """Return a plain-dict view built from each child's own to_dict()."""
        return {"families": [x.to_dict() for x in self.families],
                "peripherals": [x.to_dict() for x in self.peripherals]}

    def write_build_script(self, path):
        """Write build.rs into the crate root `path`, then run rustfmt on it.

        The generated if/else chain selects `src/<family>/<device>/device.x`
        based on which `CARGO_FEATURE_<DEVICE>` environment variable is set,
        and panics when none is. Devices marked `special` are left out.
        """
        devices = []
        for family in self.families:
            for device in family.devices:
                if not device.special:
                    devices.append((family.name, device.name))
        clauses = " else ".join("""\
if env::var_os("CARGO_FEATURE_{}").is_some() {{
"src/{}/{}/device.x"
}}""".format(d.upper(), f, d) for (f, d) in sorted(devices))
        clauses += " else { panic!(\"No device features selected\"); }"
        fname = os.path.join(path, "build.rs")
        with open(fname, "w") as f:
            f.write(BUILD_RS_TEMPLATE.format(device_clauses=clauses))
        rustfmt(fname)

    def to_files(self, path, pool):
        """Write the crate's top-level files under the crate root `path`.

        src/lib.rs, Cargo.toml, build.rs and src/peripherals/mod.rs are
        written directly. Per-family files (via `family.to_files`) and shared
        peripheral modules (via `peripheral.to_rust_file`) are queued on `pool`.

        Returns:
            list of pending results from `pool`. The caller must `.get()`
            each one to wait for, and surface errors from, those jobs.

        Raises:
            ValueError: if `<path>/src` does not already exist.
        """
        srcpath = os.path.join(path, 'src')
        if not os.path.isdir(srcpath):
            raise ValueError(f"{srcpath} does not exist")
        periphpath = os.path.join(srcpath, "peripherals")
        os.makedirs(periphpath, exist_ok=True)
        pool_results = []
        # The context managers guarantee all three files are flushed and
        # closed, even if generation fails part-way. The original handles
        # were never closed.
        with open(os.path.join(srcpath, "lib.rs"), "w") as lib_f, \
                open(os.path.join(path, "Cargo.toml"), "w") as cargo_f, \
                open(os.path.join(periphpath, "mod.rs"), "w") as periph_f:
            lib_f.write(CRATE_LIB_PREAMBLE)
            cargo_f.write(CRATE_CARGO_TOML_PREAMBLE)
            self.write_build_script(path)
            for family in self.families:
                fname = family.name
                pool_results += family.to_files(srcpath, pool)
                features = [f'feature="{d.name}"' for d in family.devices]
                lib_f.write(f'#[cfg(any(feature="doc", {", ".join(features)}))]\n')
                lib_f.write(f'pub mod {fname};\n\n')
                for device in family.devices:
                    dname = device.name
                    arch = device.cpu.get_architecture().lower().replace("-", "_")
                    # Non-special devices enable their architecture's feature.
                    if device.special:
                        cargo_f.write(f'{dname} = []\n')
                    else:
                        cargo_f.write(f'{dname} = ["{arch}"]\n')
                    lib_f.write(f'#[cfg(feature="{dname}")]\n')
                    lib_f.write(f'pub use {fname}::{dname}::*;\n\n')
            if self.peripherals:
                # Outer doc comment (`///`): an inner `//!` comment is only
                # legal before any items, and items have already been written.
                lib_f.write("/// Peripherals shared between multiple families\n")
                lib_f.write("pub mod peripherals;\n\n")
            for peripheral in self.peripherals:
                result = pool.apply_async(peripheral.to_rust_file, (periphpath,))
                pool_results.append(result)
                features = ", ".join(
                    f'feature="{d}"' for d in peripheral.parent_device_names)
                periph_f.write(f'#[cfg(any(feature="doc", {features}))]\n')
                periph_f.write(f'pub mod {peripheral.name};\n\n')
        return pool_results
def get_int(node, tag, default=None):
    """Read the text of child element `tag` of `node` as an integer.

    Accepted spellings (case-insensitive, optional leading ``+``):
      - ``true`` / ``false`` -> 1 / 0
      - hexadecimal with a ``0x`` prefix
      - binary with a ``0b`` prefix, or the SVD ``#`` prefix
      - plain decimal

    Returns `default` when the element is missing (or when its collapsed
    text equals `default`).

    Raises:
        ValueError: if the text is not one of the accepted integer forms.
    """
    text = get_string(node, tag, default=default)
    if text == default:
        return text
    text = text.lower().strip()
    if text == "true":
        return 1
    if text == "false":
        return 0
    # SVD numeric patterns allow an explicit leading plus sign.
    if text.startswith("+"):
        text = text[1:]
    if text.startswith("0x"):
        return int(text[2:], 16)
    if text.startswith("0b"):
        return int(text[2:], 2)
    # SVD also uses '#' as a binary prefix; this previously raised ValueError.
    if text.startswith("#"):
        return int(text[1:], 2)
    return int(text, 10)
def get_string(node, tag, default=None):
    """Return the text of child element `tag` of `node` with whitespace collapsed.

    Runs of whitespace (including newlines) become single spaces, and
    leading/trailing whitespace is dropped. Returns `default` when the
    element is missing.
    """
    text = node.findtext(tag, default=default)
    if text == default:
        return text
    return " ".join(text.split())
def escape_desc(desc):
    """Return `desc` with every square bracket backslash-escaped.

    ``[`` becomes ``\\[`` and ``]`` becomes ``\\]``. Presumably this keeps
    bracketed text from being read as Markdown link syntax in generated
    docs; confirm at the call sites.
    """
    bracket_escapes = str.maketrans({"[": "\\[", "]": "\\]"})
    return desc.translate(bracket_escapes)
def rustfmt(fname):
    """Invoke the external ``rustfmt`` tool on the file `fname`.

    The exit status is deliberately not checked, so a formatting failure
    does not abort generation. Failing to launch the executable at all
    still raises from subprocess.
    """
    command = ["rustfmt", fname]
    subprocess.run(command)
def common_name(a, b, ctx=""):
    """Derive a single name covering the two peripheral names `a` and `b`.

    A few hard-coded pairs (checked in either argument order) map to fixed
    names. Otherwise the name is derived from the character positions where
    `a` and `b` differ:
      - no differing position: if the strings are equal, warn and return
        `a`; otherwise return the shorter one, which is a prefix of the other;
      - exactly one differing position: drop that character from both, and
        return the result if the two now agree;
      - several: return `a` if `b` is `a` with one extra character inserted.
        Otherwise return the common prefix, minus one trailing underscore,
        when it is non-empty.
    When no rule succeeds, a warning tagged with `ctx` is printed and `a`
    is returned.
    """
    diffs = [i for i, (ca, cb) in enumerate(zip(a, b)) if ca != cb]

    for first, second in ((a, b), (b, a)):
        if (first.startswith("i2s") and first.endswith("ext")
                and second.startswith("spi")):
            return "spi"
        if first.startswith("usart") and second.startswith("uart"):
            return "usart"
        if (first, second) in (("adc1_2", "adc3_4"),
                               ("adc12_common", "adc3_common")):
            return "adc_common"
        if second.startswith("delay_block_") and (
                first == "dlyb" or first.startswith("delay_block_")):
            return "dlyb"

    if not diffs:
        if a == b:
            print(f"Warning [{ctx}]: {a} and {b} are identical")
            return a
        # With no mismatching position, one string is a prefix of the other.
        return a if b.startswith(a) else b

    first_diff = diffs[0]
    if len(diffs) == 1:
        a_cut = a[:first_diff] + a[first_diff + 1:]
        b_cut = b[:first_diff] + b[first_diff + 1:]
        if a_cut != b_cut:
            print(f"Warning [{ctx}]: {a}->{a_cut} and {b}->{b_cut} failed")
            return a
        return a_cut

    if a == b[:first_diff] + b[first_diff + 1:]:
        return a
    prefix_a, prefix_b = a[:first_diff], b[:first_diff]
    if prefix_a.endswith("_"):
        prefix_a, prefix_b = prefix_a[:-1], prefix_b[:-1]
    if prefix_a and prefix_a == prefix_b:
        return prefix_a
    print(f"Warning [{ctx}]: {a}->{prefix_a} and {b}->{prefix_b} failed")
    return a
def parse_args():
    """Parse the command line: a crate root path followed by one or more SVD files."""
    cli = argparse.ArgumentParser()
    cli.add_argument("cratepath", help="Path to crate root")
    cli.add_argument("svdfiles", nargs="+", help="SVD files to parse")
    namespace = cli.parse_args()
    return namespace
def main():
    """Command-line entry point: parse SVD files and write the crate.

    1. Parse every SVD file in parallel into Device objects.
    2. Sort devices into families. Names starting with "armv" are marked
       special and go into the "cortex_m" family. Every other device goes
       into a family named after the first seven lowercase characters of
       its name.
    3. Run the per-device, per-peripheral, per-register and per-family
       refactors.
    4. Write the crate using a process pool, and wait for every queued job.
    """
    args = parse_args()
    crate = Crate()

    print("Parsing input files...")
    with multiprocessing.Pool() as workers:
        devices = workers.map(Device.from_svdfile, args.svdfiles)

    print("Collating families...")
    cortex_family = Family("cortex_m")
    crate.families.append(cortex_family)
    families_by_name = {cortex_family.name: cortex_family}
    for device in devices:
        if device.name.startswith("armv"):
            device.special = True
            cortex_family.devices.append(device)
            continue
        family_name = device.name[:7].lower()
        family = families_by_name.get(family_name)
        if family is None:
            family = Family(family_name)
            crate.families.append(family)
            families_by_name[family.name] = family
        family.devices.append(device)

    print("Running refactors...")
    for device in devices:
        device.refactor_peripheral_instances()
        for peripheral in device.peripherals:
            peripheral.refactor_aliased_registers()
            peripheral.refactor_common_register_fields()
            for register in peripheral.registers:
                register.refactor_common_field_values()
    for family in crate.families:
        family.refactor_common_peripherals()

    print("Outputting crate...")
    with multiprocessing.Pool() as workers:
        pending = crate.to_files(args.cratepath, workers)
        # Collect every async job (re-raising worker errors) before leaving
        # the block, because Pool.__exit__ terminates the workers.
        for job in pending:
            job.get()
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()