import yaml
import sys
from typing import Any
from dataclasses import dataclass
from pathlib import Path
import re
def to_snake(s: str) -> str:
    """Convert a CamelCase/PascalCase identifier to snake_case.

    Two regex passes are applied in order: first an acronym run is split before
    the capital that starts the next word ("HTTPServer" -> "HTTP_Server"), then a
    lowercase letter or digit followed by a capital is split ("SetRf" -> "Set_Rf").
    The result is lowercased.
    """
    result = s
    for pattern, replacement in (
        (r'([A-Z]+)([A-Z][a-z])', r'\1_\2'),
        (r'([a-z\d])([A-Z])', r'\1_\2'),
    ):
        result = re.sub(pattern, replacement, result)
    return result.lower()
def snake_to_pascal(snake_str: str) -> str:
    """Convert a snake_case name to PascalCase.

    A digit run separated from another digit run by an underscore is first
    joined with a 'p' ("n4_8" -> "n4p8"), so it is not split into two words.
    Every '_'-separated word is then capitalized (rest lowercased) and the
    words are concatenated.
    """
    joined = re.sub(r'(\d+)_(\d+)', r'\1p\2', snake_str)
    words = joined.split('_')
    return ''.join(map(str.capitalize, words))
# Maps the PascalCase type name derived from a field name to the shared enum type
# that should be used instead. Names listed here are not generated as standalone
# enums (see gen_file); fields with these names reuse the target type.
enum_remap: dict[str, str] = {
    **{f'BitrateCh{ch}': 'ZwaveMode' for ch in range(1, 5)},
    'LastDetect': 'ZwaveMode',
    **{f'Sd{sd}Sf': 'Sf' for sd in range(1, 4)},
    **{f'Sd{sd}Ldro': 'Ldro' for sd in range(1, 4)},
    'TriggerStop': 'CaptureTrigger',
}
@dataclass
class BytePosition:
    """Location of (part of) a field inside a command/response buffer.

    ``byte_index`` is the index into the buffer; ``bit_range`` is the string
    form of either ``"msb:lsb"`` or a single bit number.
    """

    byte_index: int
    bit_range: str

    def get_bit_range_tuple(self) -> tuple[int, int]:
        """Return ``(msb, lsb)``; a single bit ``"n"`` yields ``(n, n)``.

        Raises ValueError if the range text is not integer(s).
        """
        high, colon, low = self.bit_range.partition(':')
        if not colon:
            single = int(high)
            return (single, single)
        return (int(high), int(low))
@dataclass
class Field:
    """One parameter or status field of a command, parsed from the YAML file."""

    # Field name; used verbatim as the Rust argument/accessor name.
    name: str
    # Total width in bits; 0 marks a variable-length field (no byte_positions).
    bit_width: int
    # Selects iN vs uN in get_rust_type and enables sign handling in the generators.
    signed: bool
    # Where the field's bits live in the buffer, one entry per (partial) byte.
    byte_positions: list[BytePosition]
    description: str
    # Set from the YAML ``le`` key; chunk order for multi-byte fields.
    little_endian: bool = False
    # Optional fields are only included in the "advanced" (_adv / Adv) variants.
    optional: bool = False
    # Variant name -> discriminant; None (or empty) for plain integer/bool fields.
    enum: dict[str, int] | None = None
@dataclass
class Command:
    """A command parsed from the YAML description."""

    # YAML key of the command (e.g. "GetStatus"); drives Rust function/struct names.
    name: str
    # Written MSB-first into bytes 0 and 1 of the request (see gen_req);
    # a negative value suppresses request generation.
    opcode: int
    description: str
    # Fields encoded into the request buffer.
    parameters: list[Field]
    # Fields decoded from the response; when non-empty a response struct is generated.
    status_fields: list[Field]
class ValidationError(Exception):
    """Raised when the YAML command description is malformed or inconsistent."""
def parse_byte_positions(positions_data: list[Any] | None) -> list[BytePosition]:
    """Convert raw ``[byte_index, bit_range]`` pairs from YAML into BytePosition objects.

    ``bit_range`` may be given as ``"msb:lsb"`` or as a single bit number; it is
    stored as a string. A missing/``null`` list yields no positions.

    Raises:
        ValidationError: if an entry is not a two-element list/tuple, or its byte
            index is not a non-negative integer.
    """
    positions: list[BytePosition] = []
    for pos_data in positions_data or []:
        # Strings have a length too, so check the container type explicitly.
        if not isinstance(pos_data, (list, tuple)) or len(pos_data) != 2:
            raise ValidationError(f"Invalid byte position format: {pos_data}")
        byte_index, bit_range = pos_data
        # bool is an int subclass; reject it along with non-integers and negatives,
        # all of which would otherwise end up verbatim in generated Rust indexing.
        if isinstance(byte_index, bool) or not isinstance(byte_index, int) or byte_index < 0:
            raise ValidationError(f"Invalid byte index in byte position: {pos_data}")
        positions.append(BytePosition(byte_index, str(bit_range)))
    return positions
def validate_field(field: Field, context: str) -> None:
    """Check that a parsed field's width and bit layout are self-consistent.

    Rules:
      * ``bit_width`` is not negative;
      * a variable-length field (``bit_width == 0``) declares no byte positions;
      * for fixed-width fields every bit range satisfies ``7 >= msb >= lsb >= 0``
        (the generators shift/mask individual u8 bytes) and the ranges add up to
        exactly ``bit_width`` bits.

    Raises:
        ValidationError: on the first violated rule.
        ValueError: if a bit range string is not numeric (from get_bit_range_tuple).
    """
    if field.bit_width < 0:
        raise ValidationError(f"{context}: bit_width cannot be negative for field '{field.name}'")
    if field.bit_width == 0:
        if field.byte_positions:
            raise ValidationError(f"{context}: variable length field '{field.name}' should not have byte_positions")
        return
    total_bits = 0
    for pos in field.byte_positions:
        msb, lsb = pos.get_bit_range_tuple()
        if msb < lsb:
            raise ValidationError(f"{context}: invalid bit range '{pos.bit_range}' for field '{field.name}' (MSB < LSB)")
        if lsb < 0 or msb > 7:
            # Bits outside 0..7 would produce out-of-range u8 shifts in the generated Rust.
            raise ValidationError(f"{context}: bit range '{pos.bit_range}' for field '{field.name}' must lie within bits 7..0")
        total_bits += msb - lsb + 1
    if total_bits != field.bit_width:
        raise ValidationError(f"{context}: bit_width ({field.bit_width}) doesn't match byte_positions total ({total_bits}) for field '{field.name}'")
def parse_field(field_data: dict[str, Any], context: str) -> Field:
    """Build and validate a Field from one YAML field mapping.

    Required keys: ``name``, ``bit_width``. Optional keys: ``byte_positions``,
    ``signed``, ``description``, ``optional``, ``enum`` (variant -> value mapping)
    and ``le`` (little-endian chunk order).

    Raises:
        ValidationError: if a required key is missing, a value has the wrong type,
            or the field fails validate_field. Errors from validate_field already
            carry the full context and are propagated unchanged.
    """
    try:
        name: str = field_data['name']
        bit_width: int = field_data['bit_width']
        byte_positions = parse_byte_positions(field_data.get('byte_positions', []))
        signed: bool = field_data.get('signed', False)
        # An explicit YAML null would otherwise break the generators' str methods.
        description: str = field_data.get('description') or ''
        optional: bool = field_data.get('optional', False)
        enum: dict[str, int] | None = field_data.get('enum', None)
        le: bool = field_data.get('le', False)
    except KeyError as e:
        raise ValidationError(f"{context}: missing required field property: {e}") from e
    except Exception as e:
        # Includes ValidationError from parse_byte_positions, which lacks field context.
        raise ValidationError(f"{context}: error parsing field: {e}") from e

    if isinstance(bit_width, bool) or not isinstance(bit_width, int):
        raise ValidationError(f"{context}: bit_width must be an integer for field '{name}'")
    if enum is not None and not isinstance(enum, dict):
        raise ValidationError(f"{context}: enum must be a mapping for field '{name}'")

    field = Field(name, bit_width, signed, byte_positions, description, le, optional, enum)
    try:
        validate_field(field, f"{context}.{name}")
    except ValidationError:
        raise
    except Exception as e:
        # e.g. a non-numeric bit range makes int() raise ValueError.
        raise ValidationError(f"{context}: error parsing field: {e}") from e
    return field
def parse_command(cmd_name: str, cmd_data: dict[str, Any]) -> Command:
    """Build a Command from its YAML mapping.

    Required key: ``opcode`` (integer). Optional keys: ``description``,
    ``parameters`` and ``status_fields`` (lists of field mappings, see parse_field).

    Raises:
        ValidationError: if the command body is not a mapping, the opcode is
            missing or not an integer, or any field is invalid. Field errors are
            propagated unchanged since they already name the command.
    """
    if not isinstance(cmd_data, dict):
        raise ValidationError(f"command '{cmd_name}': expected a mapping, got {type(cmd_data).__name__}")
    try:
        opcode = cmd_data['opcode']
    except KeyError as e:
        raise ValidationError(f"command '{cmd_name}': missing required property: {e}") from e
    # Rejecting strings here avoids a TypeError much later in gen_req.
    if isinstance(opcode, bool) or not isinstance(opcode, int):
        raise ValidationError(f"command '{cmd_name}': opcode must be an integer, got {opcode!r}")
    description: str = cmd_data.get('description') or ''

    try:
        # `or []` also covers keys present with a YAML null value.
        parameters: list[Field] = [
            parse_field(param_data, f"command '{cmd_name}' parameter")
            for param_data in cmd_data.get('parameters') or []
        ]
        status_fields: list[Field] = [
            parse_field(field_data, f"command '{cmd_name}' status_field")
            for field_data in cmd_data.get('status_fields') or []
        ]
    except ValidationError:
        raise
    except Exception as e:
        # e.g. `parameters` is a scalar instead of a list.
        raise ValidationError(f"command '{cmd_name}': {e}") from e
    return Command(cmd_name, opcode, description, parameters, status_fields)
def get_rust_type(field: Field) -> str:
    """Return the Rust type used for ``field`` in generated signatures.

    Enum fields use the PascalCase of their name (``TriggerStart`` becomes
    ``CaptureTrigger``, then enum_remap is applied). Otherwise:
      * width 0 -> ``&[u8]``,
      * width 1 -> ``bool``,
      * anything else -> the smallest of 8/16/32 bits that fits, else 64,
        prefixed ``i`` when signed and ``u`` otherwise.
    """
    if field.enum:
        type_name = snake_to_pascal(field.name)
        if type_name == 'TriggerStart':
            type_name = 'CaptureTrigger'
        return enum_remap.get(type_name, type_name)
    width = field.bit_width
    if width == 0:
        return "&[u8]"
    if width == 1:
        return "bool"
    size = next((bits for bits in (8, 16, 32) if width <= bits), 64)
    prefix = "i" if field.signed else "u"
    return f"{prefix}{size}"
# Hand-written helper items appended after specific generated enums. Every string is
# emitted verbatim as one line of Rust.
_ENUM_EXTRA_IMPLS: dict[str, tuple[str, ...]] = {
    'ZwaveMode': (
        "\nimpl ZwaveMode {",
        " pub fn new(val: u8) -> Self{",
        " match val {",
        " 3 => ZwaveMode::R3,",
        " 2 => ZwaveMode::R2,",
        " _ => ZwaveMode::R1,",
        " }",
        " }",
        "}",
    ),
    'LoraBw': (
        "\nimpl LoraBw {",
        " /// Return Bandwidth in Hz",
        " pub fn to_hz(&self) -> u32 {",
        " match self {",
        " LoraBw::Bw1000 => 1_000_000,",
        " LoraBw::Bw812 => 812_500,",
        " LoraBw::Bw500 => 500_000,",
        " LoraBw::Bw406 => 406_250,",
        " LoraBw::Bw250 => 250_000,",
        " LoraBw::Bw203 => 203_125,",
        " LoraBw::Bw125 => 125_000,",
        " LoraBw::Bw101 => 101_562,",
        " LoraBw::Bw83 => 83_333,",
        " LoraBw::Bw62 => 62_500,",
        " LoraBw::Bw41 => 41_666,",
        " LoraBw::Bw31 => 31_250,",
        " LoraBw::Bw20 => 20_833,",
        " LoraBw::Bw15 => 15_625,",
        " LoraBw::Bw10 => 10_416,",
        " LoraBw::Bw7 => 7_812,",
        " }",
        " }\n",
        " /// Flag Fractional bandwidth",
        " /// Corresponds to band used in SX1280",
        " pub fn is_fractional(&self) -> bool {",
        " use LoraBw::*;",
        " matches!(self, Bw812 | Bw406 | Bw203 | Bw101)",
        " }",
        "}\n",
        "impl PartialOrd for LoraBw {",
        " fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {",
        " Some(self.cmp(other))",
        " }",
        "}\n",
        "impl Ord for LoraBw {",
        " fn cmp(&self, other: &Self) -> core::cmp::Ordering {",
        " self.to_hz().cmp(&other.to_hz())",
        " }",
        "}",
    ),
    'LoraCr': (
        "\nimpl LoraCr {",
        " /// Return if Code-rate uses long interleaving",
        " pub fn is_li(&self) -> bool {",
        " use LoraCr::*;",
        " matches!(self, Cr5Ham45Li|Cr6Ham23Li|Cr7Ham12Li|Cr8Cc23|Cr9Cc12)",
        " }",
        " /// Return denominator for the coding rate, supposing a numerator equal to 4",
        " pub fn denominator(&self) -> u8 {",
        " match self {",
        " LoraCr::NoCoding => 4,",
        " // Code rate 4/5",
        " LoraCr::Cr1Ham45Si |",
        " LoraCr::Cr5Ham45Li => 5,",
        " // Code rate 2/3 -> 4/6",
        " LoraCr::Cr2Ham23Si |",
        " LoraCr::Cr6Ham23Li |",
        " LoraCr::Cr8Cc23 => 6,",
        " // Code rate 4/7",
        " LoraCr::Cr3Ham47Si => 7,",
        " // Code rate 1/2 -> 4/8",
        " LoraCr::Cr4Ham12Si |",
        " LoraCr::Cr7Ham12Li |",
        " LoraCr::Cr9Cc12 => 8,",
        " }",
        " }",
        "}",
    ),
    'WmbusMode': (
        "",
        "/// WM-Bus mode selection",
        "#[derive(Debug, Clone, Copy, PartialEq, Eq)]",
        "#[cfg_attr(feature = \"defmt\", derive(defmt::Format))]",
        "pub enum WmbusSubBand {A,B,C,D}",
        "",
        "impl WmbusMode {",
        " /// Return center frequency associated with a mode",
        " /// The channel index is only used in Mode R2 and N",
        " /// The sub-band is only used in Mode N",
        " pub fn rf(&self, channel: u8, subband: WmbusSubBand) -> u32 {",
        " use {WmbusMode::*, WmbusSubBand::*};",
        " match self {",
        " ModeS => 868_300_000,",
        " ModeT1 |",
        " ModeT2M2o => 868_950_000,",
        " ModeT2O2m => 868_300_000,",
        " ModeR2 => 868_030_000 + 60_000 * channel as u32,",
        " ModeC1 |",
        " ModeC2M2o => 868_950_000,",
        " ModeC2O2m => 868_525_000,",
        " ModeF2 => 433_820_000,",
        " ModeN4p8 |",
        " ModeN2p4 |",
        " ModeN6p4 => {",
        " match subband {",
        " A => 169_406_250 + 12_500 * channel as u32,",
        " B => 169_481_250 ,",
        " C => 169_493_750 + 12_500 * channel as u32,",
        " D => 169_593_250 + 12_500 * channel as u32,",
        " }",
        " }",
        " ModeN19p2 => {",
        " match subband {",
        " A => 169_437_500,",
        " B => 169_481_250,",
        " C => 169_493_750,",
        " D => 169_625_000 + 50_000 * channel as u32,",
        " }",
        " }",
        " }",
        " }",
        "}",
    ),
}


def gen_enum(field: Field) -> str:
    """Return the Rust source of the enum type for an enumerated field.

    The type name is the PascalCase form of ``field.name`` (``TriggerStart`` is
    renamed ``CaptureTrigger``). Each YAML variant becomes a PascalCase variant with
    an explicit discriminant. The ZwaveMode, LoraBw, LoraCr and WmbusMode enums also
    get the hand-written items from _ENUM_EXTRA_IMPLS appended. Returns an empty
    string when the field has no enum.
    """
    if not field.enum:
        return ''
    enum_name = snake_to_pascal(field.name)
    if enum_name == 'TriggerStart':
        enum_name = 'CaptureTrigger'
    # Square brackets would be read by rustdoc as intra-doc links; sanitize them the
    # same way the request/response generators do.
    desc = (field.description or '').replace('[', '(').replace(']', ')')
    lines = ["", f"/// {desc}"]
    # Only Sf derives the ordering traits; LoraBw gets a manual Ord impl instead.
    more_derive = ', PartialOrd, Ord' if enum_name in ['Sf'] else ''
    lines.append(f"#[derive(Debug, Clone, Copy, PartialEq, Eq{more_derive})]")
    lines.append("#[cfg_attr(feature = \"defmt\", derive(defmt::Format))]")
    lines.append(f"pub enum {enum_name} {{")
    for variant_name, value in field.enum.items():
        lines.append(f" {snake_to_pascal(variant_name)} = {value},")
    lines.append("}")
    lines.extend(_ENUM_EXTRA_IMPLS.get(enum_name, ()))
    return '\n'.join(lines)
def size_of(fields: list[Field]) -> int:
    """Return the byte length of a buffer holding ``fields``.

    The length always covers indices 0 and 1 (so it is at least 2) and
    extends to the highest ``byte_index`` used by any field's positions.
    """
    highest = max([1, *(pos.byte_index for fld in fields for pos in fld.byte_positions)])
    return highest + 1
def _req_param_lines(param: Field) -> list[str]:
    """Return the Rust statements that OR one parameter into the request buffer ``cmd``."""
    lines: list[str] = []
    need_cast_u8 = param.signed or param.bit_width > 8
    placed = 0  # bits of the parameter handled by earlier byte positions
    for pos in param.byte_positions:
        msb, lsb = pos.get_bit_range_tuple()
        pos_bits = msb - lsb + 1
        # Right shift that brings this position's chunk down to bit 0: little-endian
        # fields list the least-significant chunk first, big-endian ones the most.
        if param.little_endian:
            shift_right = placed
        else:
            shift_right = param.bit_width - placed - pos_bits
        placed += pos_bits

        # temp_format is not exposed as an argument; check it first so no branch
        # below can reference the undeclared Rust variable.
        if param.name == 'temp_format':
            lines.append(" cmd[2] |= 8; // Force format to Celsius")
            continue
        if param.bit_width == 1 and not param.enum:
            lines.append(f" if {param.name} {{ cmd[{pos.byte_index}] |= {1 << lsb}; }}")
            continue

        if param.enum:
            plain_byte = shift_right == 0 and lsb == 0 and param.bit_width == 8
            value = f"{param.name} as u8" if plain_byte else f"({param.name} as u8)"
        else:
            value = param.name
        if shift_right != 0:
            value = f"({value} >> {shift_right})"
        # Mask to this position's own width so out-of-range or negative (sign-extended)
        # values cannot spill into neighbouring bits of the same byte.
        if param.bit_width != 8 or pos_bits < 8:
            mask = (1 << min(pos_bits, 8)) - 1
            value = f"{value} & 0x{mask:X}"
        if lsb != 0:
            value = f"({value}) << {lsb}"
        if need_cast_u8:
            value = f"({value}) as u8"
        lines.append(f" cmd[{pos.byte_index}] |= {value};")
    return lines


def gen_req(cmd: Command, _category: str, advanced: bool = False) -> str:
    """Return the Rust function that builds the request buffer for ``cmd``.

    The function is named ``<snake_name>[_adv]_req`` when the command has status
    fields, else ``..._cmd``. The non-advanced variant omits optional parameters.
    Bytes 0-1 hold the opcode (MSB first); parameters are OR-ed into their byte
    positions. Returns '' for a negative opcode, and a TODO comment when a
    selected parameter is variable-length.
    """
    if cmd.opcode < 0:
        return ''
    func_name = to_snake(cmd.name)
    if advanced:
        func_name += "_adv"
    func_name += "_req" if cmd.status_fields else "_cmd"

    params = cmd.parameters if advanced else [p for p in cmd.parameters if not p.optional]
    skipped_params = [p for p in params if p.bit_width == 0]
    if skipped_params:
        return f"// TODO: Implement {func_name} (contains variable length parameters: {[p.name for p in skipped_params]})"

    # temp_format is forced to a fixed value, so it is not a function argument.
    param_list = [f"{p.name}: {get_rust_type(p)}" for p in params if p.name != 'temp_format']
    buffer_size = size_of(params)

    lines: list[str] = []
    # Clippy's default threshold: more than 7 arguments triggers the lint.
    if len(param_list) > 7:
        lines.append('#[allow(clippy::too_many_arguments)]')
    desc = cmd.description.replace('[', '(').replace(']', ')')
    lines.append(f"/// {desc}")
    lines.append(f"pub fn {func_name}({', '.join(param_list)}) -> [u8; {buffer_size}] {{")

    opcode_msb = (cmd.opcode >> 8) & 0xFF
    opcode_lsb = cmd.opcode & 0xFF
    if not params:
        lines.append(f" [0x{opcode_msb:02X}, 0x{opcode_lsb:02X}]")
    else:
        lines.append(f" let mut cmd = [0u8; {buffer_size}];")
        lines.append(f" cmd[0] = 0x{opcode_msb:02X};")
        lines.append(f" cmd[1] = 0x{opcode_lsb:02X};")
        lines.append("")
        for param in params:
            lines.extend(_req_param_lines(param))
        lines.append(" cmd")
    lines.append("}")
    return '\n'.join(lines)
def _rsp_single_byte_expr(field: Field, return_type: str) -> str:
    """Return the accessor expression for a field stored in one byte (width <= 8)."""
    pos = field.byte_positions[0]
    msb, lsb = pos.get_bit_range_tuple()
    byte = f"self.0[{pos.byte_index}]"
    if return_type == 'i8' and field.bit_width < 8:
        # Sign-extend by moving the field's top bit to bit 7, reinterpreting as i8 and
        # shifting back arithmetically. The former `v as i8 - (1 << width)` form
        # overflowed i8 for 7-bit fields.
        if msb != 7:
            byte = f"({byte} << {7 - msb})"
        return f"({byte} as i8) >> {8 - field.bit_width}"
    mask = (1 << (msb - lsb + 1)) - 1
    if field.bit_width == 8:
        expr = byte
    elif lsb == 0:
        expr = f"{byte} & 0x{mask:X}"
    else:
        expr = f"({byte} >> {lsb}) & 0x{mask:X}"
    if field.signed and field.bit_width != 8:
        expr = f"({expr})"
    if return_type == 'bool':
        expr += ' != 0'
    elif field.signed:
        expr += ' as i8'
    return expr


def _rsp_multi_byte_expr(field: Field, return_type: str) -> str:
    """Return the accessor expression assembling a field from several byte positions."""
    raw_type = return_type if field.enum else return_type.replace('i', 'u')
    # Walk chunks from least to most significant: the last position for big-endian
    # fields, the first for little-endian ones (same convention as gen_req).
    ordered = field.byte_positions if field.little_endian else field.byte_positions[::-1]
    terms: list[str] = []
    shift = 0
    for pos in ordered:
        msb, lsb = pos.get_bit_range_tuple()
        bit_count = msb - lsb + 1
        term = f"self.0[{pos.byte_index}]"
        if lsb != 0:
            term = f"({term} >> {lsb})"
        # Bits above msb must be cleared unless msb is already the byte's top bit.
        if bit_count != 8 and msb != 7:
            term = f"({term} & 0x{(1 << bit_count) - 1:X})"
        term = f"({term} as {raw_type})"
        if shift != 0:
            term = f"({term} << {shift})"
        terms.append(term)
        shift += bit_count
    if not field.signed:
        return ' |\n '.join(terms)
    value = 'let raw = ' + ' |\n  '.join(terms)
    container = int(return_type[1:]) if return_type in ('i8', 'i16', 'i32', 'i64') else field.bit_width
    spare = container - field.bit_width
    if spare > 0:
        # Shift-based sign extension; never overflows the signed container.
        return f"{value};\n ((raw << {spare}) as {return_type}) >> {spare}"
    return f"{value};\n raw as {return_type}"


def _rsp_accessor_lines(cmd_name: str, field: Field) -> list[str]:
    """Return the lines of the accessor method for one status field."""
    if field.bit_width == 0:
        return [f" // TODO: Implement accessor for variable length field '{field.name}'"]
    desc = field.description.replace('[', '(').replace(']', ')')
    lines = ["", f" /// {desc}"]
    if cmd_name == 'GetStatus' and field.name == 'intr':
        lines += [f" pub fn {field.name}(&self) -> Intr {{", ' Intr::from_slice(&self.0[2..6])', ' }']
        return lines
    if cmd_name == 'GetZwavePacketStatus' and field.name == 'last_detect':
        lines += [' pub fn last_detect(&self) -> ZwaveMode {', ' ZwaveMode::new(self.0[6] & 0x3)', ' }']
        return lines
    return_type = get_rust_type(field)
    if len(field.byte_positions) == 1 and field.bit_width <= 8:
        expr = _rsp_single_byte_expr(field, return_type)
    else:
        expr = _rsp_multi_byte_expr(field, return_type)
    lines += [f" pub fn {field.name}(&self) -> {return_type} {{", ' ' + expr, " }"]
    return lines


def gen_rsp(cmd: Command, _category: str, advanced: bool = False) -> str:
    """Return the Rust response struct (byte-buffer newtype plus accessors) for ``cmd``.

    The struct is named ``<cmd.name>Rsp`` (``Adv`` suffix for the advanced variant,
    leading ``Get`` dropped). It exposes ``status()`` plus one accessor per status
    field; only non-optional fields unless ``advanced``. Returns '' when the command
    has no status fields.
    """
    if not cmd.status_fields:
        return ""
    struct_name = f"{cmd.name}Rsp"
    if advanced:
        struct_name += "Adv"
    if struct_name.startswith("Get"):
        struct_name = struct_name[3:]
    fields = cmd.status_fields if advanced else [f for f in cmd.status_fields if not f.optional]
    buffer_size = size_of(fields)

    lines = [f"/// Response for {cmd.name} command"]
    # NOTE(review): derive(Default) needs `[u8; N]: Default`, which Rust's standard
    # library only implements for N <= 32 — confirm no response buffer is larger.
    lines.append("#[derive(Default)]")
    lines.append(f"pub struct {struct_name}([u8; {buffer_size}]);")
    lines.append("")
    lines.append(f"impl {struct_name} {{")
    lines.append(" /// Create a new response buffer")
    lines.append(" pub fn new() -> Self {")
    lines.append(" Self::default()")
    lines.append(" }")
    lines.append("")
    lines.append(" /// Return Status")
    # Read-only accessor: a shared borrow is sufficient.
    lines.append(" pub fn status(&self) -> Status {")
    lines.append(" Status::from_slice(&self.0[..2])")
    lines.append(" }")
    for field in fields:
        lines.extend(_rsp_accessor_lines(cmd.name, field))
    if cmd.name == 'GetErrors':
        lines.append(" /// 16 bits value")
        lines.append(" pub fn value(&self) -> u16 {")
        lines.append(" u16::from_be_bytes([self.0[2], self.0[3]])")
        lines.append(" }\n")
        lines.append(" /// Flag when no error are present")
        lines.append(" pub fn none(&self) -> bool {")
        lines.append(" self.0[2] == 0 && self.0[3] == 0")
        lines.append(" }")
    lines.append("}")
    lines.append("")
    lines.append(f"impl AsMut<[u8]> for {struct_name} {{")
    lines.append(" fn as_mut(&mut self) -> &mut [u8] {")
    lines.append(" &mut self.0")
    lines.append(" }")
    lines.append("}")

    # These defmt impls target the non-advanced struct names; emitting them for an
    # Adv variant as well would duplicate the impl.
    if not advanced and cmd.name == 'GetVersion':
        lines.append("#[cfg(feature = \"defmt\")]")
        lines.append("impl defmt::Format for VersionRsp {")
        lines.append(" fn format(&self, fmt: defmt::Formatter) {")
        lines.append(" defmt::write!(fmt, \"{:02x}.{:02x}\", self.major(), self.minor());")
        lines.append(" }")
        lines.append("}")
    elif not advanced and cmd.name == 'GetErrors':
        error_flags = (
            ('hf_xosc_start', 'HfXoscStart'),
            ('lf_xosc_start', 'LfXoscStart'),
            ('pll_lock', 'PllLock'),
            ('lf_rc_calib', 'LfRcCalib'),
            ('hf_rc_calib', 'HfRcCalib'),
            ('pll_calib', 'PllCalib'),
            ('aaf_calib', 'AafCalib'),
            ('img_calib', 'ImgCalib'),
            ('chip_busy', 'ChipBusy'),
            ('rxfreq_no_fe_cal', 'RxfreqNoFeCal'),
            ('meas_unit_adc_calib', 'MeasUnitAdcCalib'),
            ('pa_offset_calib', 'PaOffsetCalib'),
            ('ppf_calib', 'PpfCalib'),
            ('src_calib', 'SrcCalib'),
        )
        lines.append("#[cfg(feature = \"defmt\")]")
        lines.append("impl defmt::Format for ErrorsRsp {")
        lines.append(" fn format(&self, f: defmt::Formatter) {")
        lines.append(" defmt::write!(f, \"Errors: \");")
        lines.append(" if self.none() {")
        lines.append(" defmt::write!(f, \"None\");")
        lines.append(" return;")
        lines.append(" }")
        for method, label in error_flags:
            lines.append(f" if self.{method}() {{defmt::write!(f, \"{label} \")}};")
        lines.append(" }")
        lines.append("}")
    return '\n'.join(lines)
def _skip_param_enum(type_name: str, category: str) -> bool:
    """Return True when a parameter enum must not be generated in this category's file.

    Skipped: names remapped to another enum, types imported from elsewhere in the
    crate (RxBw/PulseShape outside fsk, DioNum in lora), and TempFormat (forced).
    """
    return (
        type_name in enum_remap
        or (type_name in ['RxBw', 'PulseShape'] and category != 'fsk')
        or type_name == 'TempFormat'
        or (type_name == 'DioNum' and category == 'lora')
    )


def _collect_enums(category: str, commands: list[Command]) -> list[str]:
    """Generate each enum used by ``commands`` once, warning on conflicting variant lists."""
    enum_kind: dict[str, list[str]] = {}
    enums: list[str] = []
    for cmd in commands:
        # Parameters first, then status fields, to keep the original emission order.
        candidates = [
            p for p in cmd.parameters
            if p.enum and not _skip_param_enum(snake_to_pascal(p.name), category)
        ] + [
            s for s in cmd.status_fields
            if s.enum and snake_to_pascal(s.name) not in enum_remap
        ]
        for fld in candidates:
            type_name = snake_to_pascal(fld.name)
            variants = list(fld.enum.keys())
            if type_name in enum_kind:
                if variants != enum_kind[type_name]:
                    print(f'Conflicting type definition for {type_name} in {cmd.name}', file=sys.stderr)
                continue
            enum_kind[type_name] = variants
            enums.append(gen_enum(fld))
    return enums


def gen_file(category: str, commands: list[Command], output_dir: Path) -> None:
    """Write ``<output_dir>/cmd_<category>.rs`` with enums, request builders and responses.

    Commands with a variable-length parameter (``bit_width == 0``) get no enums or
    request builder and are listed in a trailing comment instead. Their response
    struct, if any, is still generated. The file is written as UTF-8.
    """
    file_path = output_dir / f"cmd_{category}.rs"
    lines = [f"// {category.title()} commands API\n"]
    has_rsp = any(cmd.status_fields for cmd in commands)
    if category == 'system':
        lines.append("use crate::status::{Status,Intr};")
    elif has_rsp:
        lines.append("use crate::status::Status;")
    # Shared types defined elsewhere in the crate (see _skip_param_enum).
    if category in ['ble', 'ook', 'zigbee', 'zwave', 'wisun', 'wmbus', 'raw']:
        lines.append("use super::RxBw;")
    if category in ['flrc', 'bpsk', 'ook']:
        lines.append("use super::PulseShape;")
    if category in ['lora']:
        lines.append("use super::cmd_system::DioNum;")

    supported: list[Command] = []
    skipped_cmd: list[str] = []
    for cmd in commands:
        if any(p.bit_width == 0 for p in cmd.parameters):
            skipped_cmd.append(cmd.name)
        else:
            supported.append(cmd)

    enums = _collect_enums(category, supported)
    if enums:
        lines.extend(enums)
        lines.append("")

    for cmd in supported:
        lines.append(gen_req(cmd, category, advanced=False))
        lines.append("")
        if any(p.optional for p in cmd.parameters):
            lines.append(gen_req(cmd, category, advanced=True))
            lines.append("")

    if has_rsp:
        lines.append("// Response structs")
        lines.append("")
        for cmd in commands:
            if not cmd.status_fields:
                continue
            variants = [False, True] if any(f.optional for f in cmd.status_fields) else [False]
            for advanced in variants:
                struct_code = gen_rsp(cmd, category, advanced=advanced)
                if struct_code:
                    lines.append(struct_code)
                    lines.append("")

    if skipped_cmd:
        lines.append("// Commands with variable length parameters (not implemented):")
        lines.extend(f"// - {cmd_name}" for cmd_name in skipped_cmd)
        lines.append("")

    # Explicit encoding: descriptions may contain non-ASCII text.
    file_path.write_text('\n'.join(lines), encoding='utf-8')
def main():
    """Generate ``cmd_<category>.rs`` files from a YAML command description.

    Usage: ``script [commands.yaml] [output_dir]`` (defaults: ``./commands.yaml``
    and ``../src/cmd``). Prints each processed category and exits with status 1
    on the first error.
    """
    yaml_path = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("./commands.yaml")
    output_dir = Path(sys.argv[2]) if len(sys.argv) > 2 else Path("../src/cmd")
    if not yaml_path.is_file():
        print(f"Error: {yaml_path} is not a file", file=sys.stderr)
        sys.exit(1)
    try:
        # Created only once the input is known to exist; failures are reported below.
        output_dir.mkdir(parents=True, exist_ok=True)
        with open(yaml_path, encoding='utf-8') as f:
            # safe_load: never construct arbitrary Python objects from the file.
            data = yaml.safe_load(f) or {}
        for category, category_data in (data.get('categories') or {}).items():
            print(f'Category {category}')
            cmd_table = (category_data or {}).get('commands') or {}
            commands: list[Command] = []
            for cmd_name, cmd_data in cmd_table.items():
                try:
                    commands.append(parse_command(cmd_name, cmd_data))
                except ValidationError as e:
                    print(f"Error in {yaml_path}:{cmd_name}: {e}", file=sys.stderr)
                    sys.exit(1)
            gen_file(category, commands, output_dir)
    except Exception as e:
        # Top-level boundary: report any unexpected failure cleanly.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
    print("Code generation completed successfully!")
if __name__ == "__main__":
    # Run the generator only when executed as a script, not when imported.
    main()