lr2021 0.13.1

Driver for Semtech LR2021
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
#!/usr/bin/env python3
# pyright: reportAny=false

import yaml
import sys
from typing import Any
from dataclasses import dataclass
from pathlib import Path

import re

def to_snake(s: str) -> str:
    """Convert a PascalCase/camelCase identifier to snake_case.

    Two boundaries receive an underscore before the result is lowercased:
    the end of an uppercase run that is followed by a capitalized word
    (``HTTPServer`` -> ``HTTP_Server``) and any lowercase letter or digit
    directly followed by an uppercase letter (``getStatus`` -> ``get_Status``).
    """
    boundary_rules = (
        r'([A-Z]+)([A-Z][a-z])',  # acronym followed by a capitalized word
        r'([a-z\d])([A-Z])',      # lowercase/digit followed by uppercase
    )
    converted = s
    for rule in boundary_rules:
        converted = re.sub(rule, r'\1_\2', converted)
    return converted.lower()

def snake_to_pascal(snake_str: str) -> str:
    """Convert snake_case to PascalCase.

    An underscore sitting between two numbers is first turned into the
    letter ``p`` (read as "point"), so ``mode_n4_8`` becomes ``ModeN4p8``.
    """
    with_points = re.sub(r'(\d+)_(\d+)', r'\1p\2', snake_str)
    words = with_points.split('_')
    return ''.join(map(str.capitalize, words))

# Maps the PascalCase name derived from a field name to the name of the Rust
# enum that field should use instead. get_rust_type() applies the mapping when
# choosing a type, and gen_file() does not generate a dedicated enum for any
# name listed here as a key.
enum_remap : dict[str,str] = {
    'BitrateCh1' : 'ZwaveMode',
    'BitrateCh2' : 'ZwaveMode',
    'BitrateCh3' : 'ZwaveMode',
    'BitrateCh4' : 'ZwaveMode',
    'LastDetect' : 'ZwaveMode',
    'Sd1Sf' : 'Sf',
    'Sd2Sf' : 'Sf',
    'Sd3Sf' : 'Sf',
    'Sd1Ldro' : 'Ldro',
    'Sd2Ldro' : 'Ldro',
    'Sd3Ldro' : 'Ldro',
    'TriggerStop'  : 'CaptureTrigger',
}

@dataclass
class BytePosition:
    """Location of a group of bits: a byte index plus a bit range inside that byte.

    ``bit_range`` is either ``"msb:lsb"`` (for example ``"7:0"``) or a single
    bit number (for example ``"3"``).
    """
    byte_index: int
    bit_range: str

    def get_bit_range_tuple(self) -> tuple[int, int]:
        """Return the bit range as an ``(msb, lsb)`` pair; a lone bit gives ``(bit, bit)``."""
        if ':' not in self.bit_range:
            only_bit = int(self.bit_range)
            return (only_bit, only_bit)
        msb, lsb = (int(part) for part in self.bit_range.split(':'))
        return (msb, lsb)

@dataclass
class Field:
    """One command parameter or status field, as built by parse_field()."""
    name: str  # identifier used for the Rust argument / accessor name
    bit_width: int  # total width in bits; 0 marks a variable length field
    signed: bool  # selects iN instead of uN Rust integer types
    byte_positions: list[BytePosition]  # where the bits live in the buffer
    description: str  # emitted as the Rust doc comment
    little_endian: bool = False  # read from the 'le' key; changes chunk ordering in gen_req()
    optional: bool = False  # optional fields only appear in the "advanced" variants
    enum: dict[str, int] | None = None  # variant name -> numeric value, when the field is an enum

@dataclass
class Command:
    """One command, as built by parse_command()."""
    name: str  # command name; converted to snake_case for the Rust function name
    opcode: int  # written as two bytes (high byte first) by gen_req(); negative disables the request function
    description: str  # emitted as the Rust doc comment of the request function
    parameters: list[Field]  # fields packed into the request buffer
    status_fields: list[Field]  # fields exposed by the response struct; non-empty selects the "_req" suffix

class ValidationError(Exception):
    """Raised when a command or field description is malformed or inconsistent."""
    pass

def parse_byte_positions(positions_data: list[tuple[int,str]]) -> list[BytePosition]:
    """Turn the YAML ``[byte_index, bit_range]`` pairs into BytePosition objects.

    Raises:
        ValidationError: if an entry does not contain exactly two items.
    """
    def convert(entry: tuple[int, str]) -> BytePosition:
        if len(entry) != 2:
            raise ValidationError(f"Invalid byte position format: {entry}")
        index, bits = entry
        # The bit range may come out of YAML as an int (single bit): keep it as text.
        return BytePosition(index, str(bits))

    return [convert(entry) for entry in positions_data]

def validate_field(field: Field, context: str) -> None:
    """Check that a field's bit width agrees with its byte positions.

    Args:
        field: the field to check.
        context: text prepended to every error message.

    Raises:
        ValidationError: on a negative width, on a variable length field
            (width 0) that declares byte positions, on a bit range whose MSB
            is below its LSB, or when the ranges do not add up to the width.
    """
    if field.bit_width < 0:
        raise ValidationError(f"{context}: bit_width cannot be negative for field '{field.name}'")

    if field.bit_width == 0:
        # Variable length: nothing more to check besides the absence of positions.
        if field.byte_positions:
            raise ValidationError(f"{context}: variable length field '{field.name}' should not have byte_positions")
        return

    covered_bits = 0
    for position in field.byte_positions:
        high, low = position.get_bit_range_tuple()
        if high < low:
            raise ValidationError(f"{context}: invalid bit range '{position.bit_range}' for field '{field.name}' (MSB < LSB)")
        covered_bits += high - low + 1

    if covered_bits != field.bit_width:
        raise ValidationError(f"{context}: bit_width ({field.bit_width}) doesn't match byte_positions total ({covered_bits}) for field '{field.name}'")

def parse_field(field_data: dict[str, Any], context: str) -> Field:  # pyright: ignore[reportExplicitAny]
    """Parse and validate one field from its YAML mapping.

    Args:
        field_data: mapping with the required keys ``name`` and ``bit_width``
            and the optional keys ``byte_positions``, ``signed``,
            ``description``, ``optional``, ``enum`` and ``le``.
        context: text prepended to error messages.

    Returns:
        The validated Field.

    Raises:
        ValidationError: if a required key is missing, the field is
            inconsistent, or any other error occurs while parsing.
    """
    try:
        name : str = field_data['name']
        bit_width : int = field_data['bit_width']
        # `or []` also covers a key that is present but left empty (null) in the YAML
        byte_positions = parse_byte_positions(field_data.get('byte_positions') or [])
        signed : bool = field_data.get('signed', False)
        description : str = field_data.get('description', '')
        optional : bool = field_data.get('optional', False)
        enum: dict[str, int] | None = field_data.get('enum', None)
        le : bool = field_data.get('le', False)

        field = Field(name, bit_width, signed, byte_positions, description, le, optional, enum)
        validate_field(field, f"{context}.{name}")
        return field

    except ValidationError:
        # Already carries a precise message: do not wrap it a second time.
        raise
    except KeyError as e:
        raise ValidationError(f"{context}: missing required field property: {e}") from e
    except Exception as e:
        raise ValidationError(f"{context}: error parsing field: {e}") from e

def parse_command(cmd_name: str, cmd_data: dict[str, Any]) -> Command:  # pyright: ignore[reportExplicitAny]
    """Parse one command from its YAML mapping.

    Args:
        cmd_name: name of the command (the YAML key).
        cmd_data: mapping with the required key ``opcode`` and the optional
            keys ``description``, ``parameters`` and ``status_fields``.

    Returns:
        The parsed Command with all its fields validated.

    Raises:
        ValidationError: if ``opcode`` is missing, a field is invalid, or any
            other error occurs while parsing.
    """
    try:
        opcode : int = cmd_data['opcode']
        description : str = cmd_data.get('description', '')

        # `or []` also covers keys that are present but left empty (null) in the YAML
        parameters : list[Field] = [
            parse_field(param_data, f"command '{cmd_name}' parameter")
            for param_data in cmd_data.get('parameters') or []
        ]
        status_fields : list[Field] = [
            parse_field(field_data, f"command '{cmd_name}' status_field")
            for field_data in cmd_data.get('status_fields') or []
        ]

        return Command(cmd_name, opcode, description, parameters, status_fields)

    except ValidationError:
        # Field errors already name the command in their context: pass them through as-is.
        raise
    except KeyError as e:
        raise ValidationError(f"command '{cmd_name}': missing required property: {e}") from e
    except Exception as e:
        raise ValidationError(f"command '{cmd_name}': {e}") from e

def get_rust_type(field: Field) -> str:
    """Return the Rust type used for a field.

    Enum fields use the PascalCase form of their name (after the
    ``TriggerStart`` special case and the ``enum_remap`` table). Other fields
    get ``&[u8]`` for width 0, ``bool`` for width 1, and otherwise the
    smallest of the 8/16/32/64-bit integers that holds the width, signed or
    unsigned according to the field.
    """
    if field.enum:
        type_name = snake_to_pascal(field.name)
        if type_name == 'TriggerStart':
            type_name = 'CaptureTrigger'
        return enum_remap.get(type_name, type_name)

    width = field.bit_width
    if width == 0:
        return "&[u8]"  # Variable length
    if width == 1:
        return "bool"

    sign_prefix = "i" if field.signed else "u"
    for capacity in (8, 16, 32):
        if width <= capacity:
            return f"{sign_prefix}{capacity}"
    return f"{sign_prefix}64"

def gen_enum(field: Field) -> str:
    """Generate the Rust enum definition for an enum field.

    Returns an empty string when the field carries no enum. Otherwise the
    result is the Rust source of a `pub enum` whose variants come from
    `field.enum`, followed by hand-written `impl` blocks for a few specific
    enum names (ZwaveMode, LoraBw, LoraCr, WmbusMode).
    """
    if not field.enum:
        return ''
    enum_name = snake_to_pascal(field.name)
    # Same special case as in get_rust_type(): the enum is named after its role
    if enum_name == 'TriggerStart':
        enum_name = 'CaptureTrigger'
    lines = ["",f"/// {field.description}"]
    # Only the Sf enum derives an ordering; LoraBw gets a manual one further down
    more_derive = ', PartialOrd, Ord' if enum_name in ['Sf'] else ''
    lines.append(f"#[derive(Debug, Clone, Copy, PartialEq, Eq{more_derive})]")
    lines.append("#[cfg_attr(feature = \"defmt\", derive(defmt::Format))]")
    lines.append(f"pub enum {enum_name} {{")

    # One Rust variant per YAML entry, with its explicit discriminant
    for variant_name, value in field.enum.items():
        name = snake_to_pascal(variant_name)
        # print(f'Enum {enum_name} : {variant_name} -> {name}')
        lines.append(f"    {name} = {value},")
    
    lines.append("}")
    # Add methods for some enum
    if enum_name == 'ZwaveMode':
        # Constructor from a raw value; anything other than 2 or 3 yields R1
        lines.append("\nimpl ZwaveMode {")
        lines.append("    pub fn new(val: u8) -> Self{")
        lines.append("        match val {")
        lines.append("            3 => ZwaveMode::R3,")
        lines.append("            2 => ZwaveMode::R2,")
        lines.append("            _ => ZwaveMode::R1,")
        lines.append("        }")
        lines.append("    }")
        lines.append("}")
    elif enum_name == 'LoraBw' :
        # Bandwidth helpers plus an ordering based on the value returned by to_hz()
        lines.append("\nimpl LoraBw {")
        lines.append("    /// Return Bandwidth in Hz")
        lines.append("    pub fn to_hz(&self) -> u32 {")
        lines.append("        match self {")
        lines.append("            LoraBw::Bw1000 => 1_000_000,")
        lines.append("            LoraBw::Bw812  =>   812_500,")
        lines.append("            LoraBw::Bw500  =>   500_000,")
        lines.append("            LoraBw::Bw406  =>   406_250,")
        lines.append("            LoraBw::Bw250  =>   250_000,")
        lines.append("            LoraBw::Bw203  =>   203_125,")
        lines.append("            LoraBw::Bw125  =>   125_000,")
        lines.append("            LoraBw::Bw101  =>   101_562,")
        lines.append("            LoraBw::Bw83   =>    83_333,")
        lines.append("            LoraBw::Bw62   =>    62_500,")
        lines.append("            LoraBw::Bw41   =>    41_666,")
        lines.append("            LoraBw::Bw31   =>    31_250,")
        lines.append("            LoraBw::Bw20   =>    20_833,")
        lines.append("            LoraBw::Bw15   =>    15_625,")
        lines.append("            LoraBw::Bw10   =>    10_416,")
        lines.append("            LoraBw::Bw7    =>     7_812,")
        lines.append("        }")
        lines.append("    }\n")
        lines.append("    /// Flag Fractional bandwidth");
        lines.append("    /// Corresponds to band used in SX1280");
        lines.append("    pub fn is_fractional(&self) -> bool {");
        lines.append("        use LoraBw::*;");
        lines.append("        matches!(self, Bw812 | Bw406 | Bw203 | Bw101)");
        lines.append("    }");
        lines.append("}\n")
        lines.append("impl PartialOrd for LoraBw {")
        lines.append("    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {")
        lines.append("        Some(self.cmp(other))")
        lines.append("    }")
        lines.append("}\n")
        lines.append("impl Ord for LoraBw {")
        lines.append("    fn cmp(&self, other: &Self) -> core::cmp::Ordering {")
        lines.append("        self.to_hz().cmp(&other.to_hz())")
        lines.append("    }")
        lines.append("}")
    elif enum_name == 'LoraCr' :
        # Coding-rate helpers: long-interleaving flag and denominator
        lines.append("\nimpl LoraCr {")
        lines.append("    /// Return if Code-rate uses long interleaving")
        lines.append("    pub fn is_li(&self) -> bool {")
        lines.append("        use LoraCr::*;")
        lines.append("        matches!(self, Cr5Ham45Li|Cr6Ham23Li|Cr7Ham12Li|Cr8Cc23|Cr9Cc12)")
        lines.append("    }")
        lines.append("    /// Return denominator for the coding rate, supposing a numerator equal to 4")
        lines.append("    pub fn denominator(&self) -> u8 {")
        lines.append("        match self {")
        lines.append("            LoraCr::NoCoding   => 4,")
        lines.append("            // Code rate 4/5")
        lines.append("            LoraCr::Cr1Ham45Si |")
        lines.append("            LoraCr::Cr5Ham45Li => 5,")
        lines.append("            // Code rate 2/3 -> 4/6")
        lines.append("            LoraCr::Cr2Ham23Si |")
        lines.append("            LoraCr::Cr6Ham23Li |")
        lines.append("            LoraCr::Cr8Cc23    => 6,")
        lines.append("            // Code rate 4/7")
        lines.append("            LoraCr::Cr3Ham47Si => 7,")
        lines.append("            // Code rate 1/2 -> 4/8")
        lines.append("            LoraCr::Cr4Ham12Si |")
        lines.append("            LoraCr::Cr7Ham12Li |")
        lines.append("            LoraCr::Cr9Cc12    => 8,")
        lines.append("        }")
        lines.append("    }")
        lines.append("}")
    elif enum_name == 'WmbusMode' :
        # Extra hand-written enum (sub-band) and the mode -> center frequency helper
        lines.append("")
        lines.append("/// WM-Bus mode selection")
        lines.append("#[derive(Debug, Clone, Copy, PartialEq, Eq)]")
        lines.append("#[cfg_attr(feature = \"defmt\", derive(defmt::Format))]")
        lines.append("pub enum WmbusSubBand {A,B,C,D}")
        lines.append("")
        lines.append("impl WmbusMode {")
        lines.append("    /// Return center frequency associated with a mode")
        lines.append("    /// The channel index is only used in Mode R2 and N")
        lines.append("    /// The sub-band is only used in Mode N")
        lines.append("    pub fn rf(&self, channel: u8, subband: WmbusSubBand) -> u32 {")
        lines.append("        use {WmbusMode::*, WmbusSubBand::*};")
        lines.append("        match self {")
        lines.append("            ModeS     => 868_300_000,")
        lines.append("            ModeT1    |")
        lines.append("            ModeT2M2o => 868_950_000,")
        lines.append("            ModeT2O2m => 868_300_000,")
        lines.append("            ModeR2    => 868_030_000 + 60_000 * channel as u32,")
        lines.append("            ModeC1    |")
        lines.append("            ModeC2M2o => 868_950_000,")
        lines.append("            ModeC2O2m => 868_525_000,")
        lines.append("            ModeF2    => 433_820_000,")
        lines.append("            ModeN4p8 |")
        lines.append("            ModeN2p4 |")
        lines.append("            ModeN6p4 => {")
        lines.append("                match subband {")
        lines.append("                    A => 169_406_250 + 12_500 * channel as u32,")
        lines.append("                    B => 169_481_250 ,")
        lines.append("                    C => 169_493_750 + 12_500 * channel as u32,")
        lines.append("                    D => 169_593_250 + 12_500 * channel as u32,")
        lines.append("                }")
        lines.append("            }")
        lines.append("            ModeN19p2 => {")
        lines.append("                match subband {")
        lines.append("                    A => 169_437_500,")
        lines.append("                    B => 169_481_250,")
        lines.append("                    C => 169_493_750,")
        lines.append("                    D => 169_625_000 + 50_000 * channel as u32,")
        lines.append("                }")
        lines.append("            }")
        lines.append("        }")
        lines.append("    }")
        lines.append("}")


    return '\n'.join(lines)

def size_of(fields: list[Field]) -> int:
    """Return the buffer size in bytes needed to hold the given fields.

    Bytes 0 and 1 are always counted, so the result is at least 2; otherwise
    it is the highest byte index used by any field, plus one.
    """
    used_indices = (
        position.byte_index
        for field in fields
        for position in field.byte_positions
    )
    # Seeding with 1 accounts for the two leading bytes (indices 0 and 1)
    return max([1, *used_indices]) + 1

def gen_req(cmd: Command, _category: str, advanced: bool = False) -> str:
    """Generate the Rust function building the request buffer of a command.

    The function is named after the command in snake_case, with "_adv"
    appended for the advanced variant, then "_req" when the command has
    status fields or "_cmd" otherwise. It returns a fixed-size byte array
    holding the two opcode bytes followed by the packed parameters.

    Args:
        cmd: the command to generate.
        _category: unused; kept so the signature matches gen_rsp().
        advanced: when True optional parameters are included, otherwise only
            the mandatory ones.

    Returns:
        The Rust source of the function, an empty string when the opcode is
        negative, or a one-line TODO comment when a selected parameter has a
        variable length.
    """
    if cmd.opcode < 0 :
        return ''

    func_name = to_snake(cmd.name)
    if advanced:
        func_name += "_adv"
    func_name += "_req" if cmd.status_fields else "_cmd"

    # Optional parameters only appear in the advanced variant
    params = cmd.parameters if advanced else [p for p in cmd.parameters if not p.optional]

    # Variable length parameters are not supported yet
    skipped_params = [p for p in params if p.bit_width == 0]
    if skipped_params:
        return f"// TODO: Implement {func_name} (contains variable length parameters: {[p.name for p in skipped_params]})"

    # temp_format is forced to a fixed value below, so it is not an argument
    param_list = [
        f"{param.name}: {get_rust_type(param)}"
        for param in params
        if param.name != 'temp_format'
    ]
    buffer_size = size_of(params)

    lines : list[str] = []
    # Count the arguments actually present in the signature (temp_format is hidden)
    if len(param_list) > 7:
        lines.append('#[allow(clippy::too_many_arguments)]')
    # Square brackets would be read as links by rustdoc
    desc = cmd.description.replace('[', '(').replace(']', ')')
    lines.append(f"/// {desc}")
    lines.append(f"pub fn {func_name}({', '.join(param_list)}) -> [u8; {buffer_size}] {{")

    opcode_msb = (cmd.opcode >> 8) & 0xFF
    opcode_lsb = cmd.opcode & 0xFF

    if not params:
        # Nothing to pack: the request is just the opcode
        lines.append(f"    [0x{opcode_msb:02X}, 0x{opcode_lsb:02X}]")
        lines.append("}")
        return '\n'.join(lines)

    lines.append(f"    let mut cmd = [0u8; {buffer_size}];")
    lines.append(f"    cmd[0] = 0x{opcode_msb:02X};")
    lines.append(f"    cmd[1] = 0x{opcode_lsb:02X};")
    lines.append("")

    for param in params:
        # Anything that is not already a plain u8 needs a cast before being stored
        need_cast_u8 = param.signed or param.bit_width > 8
        bits_done = 0  # bits of the value covered by the chunks handled so far
        for pos in param.byte_positions:
            msb, lsb = pos.get_bit_range_tuple()
            chunk_bits = msb - lsb + 1
            # Position of this chunk inside the value: chunks are listed from
            # the least significant one when little endian, from the most
            # significant one otherwise.
            if param.little_endian:
                shift_right = bits_done
            else:
                shift_right = param.bit_width - bits_done - chunk_bits
            bits_done += chunk_bits

            if param.name == 'temp_format':
                lines.append(f"    cmd[2] |= 8; // Force format to Celsius")
                continue
            if param.bit_width == 1 and not param.enum:
                lines.append(f"    if {param.name} {{ cmd[{pos.byte_index}] |= {1<<lsb}; }}")
                continue

            if not param.enum:
                expr = param.name
            elif shift_right == 0 and lsb == 0 and param.bit_width == 8:
                expr = f"{param.name} as u8"
            else:
                expr = f"({param.name} as u8)"
            if shift_right != 0:
                expr = f"({expr} >> {shift_right})"
            # Mask with the width of this chunk (not of the whole value) so a
            # partial chunk can never spill into the neighbouring bits.
            if param.bit_width != 8 or chunk_bits != 8:
                expr += f" & 0x{(1 << chunk_bits) - 1:X}"
            if lsb != 0:
                expr = f"({expr}) << {lsb}"
            if need_cast_u8:
                expr = f"({expr}) as u8"
            lines.append(f"    cmd[{pos.byte_index}] |= {expr};")

    lines.append("    cmd")
    lines.append("}")
    return '\n'.join(lines)

def gen_rsp(cmd: Command, _category: str, advanced: bool = False) -> str:
    """Generate the Rust response struct of a command.

    The struct wraps a fixed-size byte array and gets a `new()` constructor,
    a `status()` accessor reading the first two bytes, one accessor per
    status field, and an `AsMut<[u8]>` implementation. A few commands get
    extra hand-written methods or a `defmt::Format` implementation.

    Args:
        cmd: the command to generate.
        _category: unused in this function.
        advanced: when True optional status fields are included and the
            struct name gets an "Adv" suffix.

    Returns:
        The Rust source, or an empty string when the command has no status
        fields.
    """
    if not cmd.status_fields:
        return ""
    
    # Struct name: <Command>Rsp[Adv], with a leading "Get" removed
    struct_name = f"{cmd.name}Rsp"
    if advanced:
        struct_name += "Adv"
    if struct_name.startswith("Get"):
        struct_name = struct_name[3:]
    
    # Filter fields based on advanced flag
    if advanced:
        fields = cmd.status_fields
    else:
        fields = [f for f in cmd.status_fields if not f.optional]
    
    buffer_size = size_of(fields)
    
    lines = [f"/// Response for {cmd.name} command"]
    lines.append("#[derive(Default)]")
    lines.append(f"pub struct {struct_name}([u8; {buffer_size}]);")
    lines.append("")
    lines.append(f"impl {struct_name} {{")
    lines.append("    /// Create a new response buffer")
    lines.append("    pub fn new() -> Self {")
    lines.append("        Self::default()")
    lines.append("    }")
    lines.append("")
    lines.append("    /// Return Status")
    lines.append("    pub fn status(&mut self) -> Status {")
    lines.append("        Status::from_slice(&self.0[..2])")
    lines.append("    }")

    # Generate accessor methods for each field
    for field in fields:
            
        if field.bit_width == 0:
            lines.append(f"    // TODO: Implement accessor for variable length field '{field.name}'")
            continue
        
        # NOTE(review): for enum fields this is the enum name, yet the accessor
        # bodies built below are plain integer expressions - confirm that every
        # enum status field is covered by a custom implementation.
        return_type = get_rust_type(field).replace('&[u8]', 'u32')  # Variable length becomes u32 for now
        
        lines.append(f"")
        # Square brackets are replaced by parentheses in the doc comment
        desc = field.description.replace('[', '(').replace(']', ')')
        lines.append(f"    /// {desc}")

        # Custom implementation
        if cmd.name == 'GetStatus' and field.name=='intr':
            lines.append(f"    pub fn {field.name}(&self) -> Intr {{")
            lines.append('        Intr::from_slice(&self.0[2..6])')
            lines.append('    }')
            continue
        if cmd.name == 'GetZwavePacketStatus' and field.name=='last_detect':
            lines.append('    pub fn last_detect(&self) -> ZwaveMode {')
            lines.append('        ZwaveMode::new(self.0[6] & 0x3)')
            lines.append('    }')
            continue

        # Implementation
        lines.append(f"    pub fn {field.name}(&self) -> {return_type} {{")
        # `l` accumulates the Rust expression forming the accessor body
        l = '        '
        
        if len(field.byte_positions) == 1 and field.bit_width <= 8:
            # Simple single byte case
            pos = field.byte_positions[0]
            msb, lsb = pos.get_bit_range_tuple()
            # A signed partial byte is wrapped in parentheses so the cast applies to the masked value
            if field.signed and field.bit_width!=8:
                l+='('
            if field.bit_width==8:
                l += f"self.0[{pos.byte_index}]"
            elif lsb == 0:
                bit_count = msb - lsb + 1
                mask = (1 << bit_count) - 1
                l += f"self.0[{pos.byte_index}] & 0x{mask:X}"
            else:
                bit_count = msb - lsb + 1
                mask = (1 << bit_count) - 1
                l += f"(self.0[{pos.byte_index}] >> {lsb}) & 0x{mask:X}"
            if field.signed and field.bit_width!=8:
                l+=')'
            if return_type=='bool':
                l += ' != 0'
            elif field.signed:
                # Sign extension: subtract 2^bit_width when the field's top bit is set
                mask = 1 << msb
                l += f' as i8'
                if field.bit_width<8:
                    l += f' - if (self.0[{pos.byte_index}] & {mask:#0x}) != 0 {{1<<{field.bit_width}}} else {{0}}'

        else:
            # Multi-byte or complex bit extraction
            # `shift` is the bit offset of the current chunk inside the assembled value
            shift = 0
            # Assemble in the unsigned type of the same width, convert at the end
            raw_type = return_type.replace('i','u') if not field.enum else return_type

            if field.signed:
                l += 'let raw = '
            # NOTE(review): positions are always walked last-to-first here and
            # field.little_endian is not consulted - confirm this is intended.
            for (i,pos) in enumerate(reversed(field.byte_positions)):  # Process LSB first
                msb, lsb = pos.get_bit_range_tuple()

                bit_count = msb - lsb + 1
                # No mask when the chunk reaches bit 7: the right shift already drops the lower bits
                has_mask = bit_count!=8 and msb!=7
                # One opening parenthesis per operation that will be closed below
                if shift != 0 : l += f'(';
                if has_mask   : l += f'(';
                if lsb!=0     : l += f'(';
                l += f"(self.0[{pos.byte_index}]"
                if lsb!=0:
                    l += f" >> {lsb})"
                if has_mask :
                    mask = (1 << bit_count) - 1
                    l += f" & 0x{mask:X})"
                l+= f' as {raw_type})'
                if shift != 0 :
                    l += f' << {shift})'
                shift += bit_count
                # Chunks are OR-ed together, one per line
                if i+1 != len(field.byte_positions):
                    l+= ' |\n        '
                    if field.signed: l+= '    ';
            if field.signed:
                # The sign bit is the top bit of the first listed position
                bi = field.byte_positions[0].byte_index
                mask = 1 << field.byte_positions[0].get_bit_range_tuple()[0]
                l+= f';\n        raw as {return_type}'
                # No need to apply any offset when already aligned on a word boundary
                if field.bit_width!=16 and field.bit_width!=32 and field.bit_width!=64:
                    l+= f' - if (self.0[{bi}] & {mask:#0x}) != 0 {{1<<{field.bit_width}}} else {{0}}'
        lines.append(l)
        
        lines.append("    }")
    
    # Extra hand-written accessors for the GetErrors response
    if cmd.name == 'GetErrors':
        lines.append("    /// 16 bits value")
        lines.append("    pub fn value(&self) -> u16 {")
        lines.append("        u16::from_be_bytes([self.0[2], self.0[3]])")
        lines.append("    }\n")
        lines.append("    /// Flag when no error are present")
        lines.append("    pub fn none(&self) -> bool {")
        lines.append("        self.0[2] == 0 && self.0[3] == 0")
        lines.append("    }")

    lines.append("}")
    lines.append("")
    # Expose the underlying buffer as a mutable byte slice
    lines.append(f"impl AsMut<[u8]> for {struct_name} {{")
    lines.append("    fn as_mut(&mut self) -> &mut [u8] {")
    lines.append("        &mut self.0")
    lines.append("    }")
    lines.append("}")

    # Hand-written defmt formatting for two responses (struct names are hard-coded)
    if cmd.name == 'GetVersion':
        lines.append("#[cfg(feature = \"defmt\")]")
        lines.append("impl defmt::Format for VersionRsp {")
        lines.append("    fn format(&self, fmt: defmt::Formatter) {")
        lines.append("        defmt::write!(fmt, \"{:02x}.{:02x}\", self.major(), self.minor());")
        lines.append("    }")
        lines.append("}")
    elif cmd.name == 'GetErrors':
        lines.append("#[cfg(feature = \"defmt\")]")
        lines.append("impl defmt::Format for ErrorsRsp {")
        lines.append("    fn format(&self, f: defmt::Formatter) {")
        lines.append("        defmt::write!(f, \"Errors: \");")
        lines.append("        if self.none() {")
        lines.append("            defmt::write!(f, \"None\");")
        lines.append("            return;")
        lines.append("        }")
        lines.append("        if self.hf_xosc_start()       {defmt::write!(f, \"HfXoscStart \")};")
        lines.append("        if self.lf_xosc_start()       {defmt::write!(f, \"LfXoscStart \")};")
        lines.append("        if self.pll_lock()            {defmt::write!(f, \"PllLock \")};")
        lines.append("        if self.lf_rc_calib()         {defmt::write!(f, \"LfRcCalib \")};")
        lines.append("        if self.hf_rc_calib()         {defmt::write!(f, \"HfRcCalib \")};")
        lines.append("        if self.pll_calib()           {defmt::write!(f, \"PllCalib \")};")
        lines.append("        if self.aaf_calib()           {defmt::write!(f, \"AafCalib \")};")
        lines.append("        if self.img_calib()           {defmt::write!(f, \"ImgCalib \")};")
        lines.append("        if self.chip_busy()           {defmt::write!(f, \"ChipBusy \")};")
        lines.append("        if self.rxfreq_no_fe_cal()    {defmt::write!(f, \"RxfreqNoFeCal \")};")
        lines.append("        if self.meas_unit_adc_calib() {defmt::write!(f, \"MeasUnitAdcCalib \")};")
        lines.append("        if self.pa_offset_calib()     {defmt::write!(f, \"PaOffsetCalib \")};")
        lines.append("        if self.ppf_calib()           {defmt::write!(f, \"PpfCalib \")};")
        lines.append("        if self.src_calib()           {defmt::write!(f, \"SrcCalib \")};")
        lines.append("    }")
        lines.append("}")
    
    return '\n'.join(lines)

def gen_file(category: str, commands: list[Command], output_dir: Path) -> None:
    """Generate the complete Rust file of a category.

    The file `cmd_<category>.rs` is written in `output_dir` and contains, in
    order: the `use` lines needed by the category, the enums found in the
    commands, the request functions, the response structs and a list of the
    commands skipped because they have variable length parameters.

    Args:
        category: category name, used for the file name and the imports.
        commands: commands belonging to the category.
        output_dir: directory receiving the generated file.
    """
    file_path = output_dir / f"cmd_{category}.rs"

    lines = [f"// {category.title()} commands API\n"]

    has_rsp = any(cmd.status_fields for cmd in commands)

    if category=='system':
        lines.append("use crate::status::{Status,Intr};")
    elif has_rsp:
        lines.append("use crate::status::Status;")
    if category in ['ble', 'ook', 'zigbee', 'zwave', 'wisun', 'wmbus', 'raw']:
        lines.append("use super::RxBw;")
    if category in ['flrc', 'bpsk', 'ook']:
        lines.append("use super::PulseShape;")
    if category in ['lora']:
        lines.append("use super::cmd_system::DioNum;")

    enum_kind : dict[str,list[str]] = {}  # enum name -> variant names already generated
    enums : list[str] = []
    skipped_cmd : list[str] = []
    supported_cmd : list[Command] = []

    def register_enum(field: Field, type_name: str, cmd_name: str) -> None:
        """Generate the enum of a field once; report a later definition that differs."""
        variants = list(field.enum or {})
        known = enum_kind.get(type_name)
        if known is not None:
            if variants != known:
                print(f'Conflicting type definition for {type_name} in {cmd_name}')
            return
        enum_kind[type_name] = variants
        enums.append(gen_enum(field))

    def is_external_param_enum(type_name: str) -> bool:
        """Tell if a parameter enum is remapped, imported or handled elsewhere."""
        return (
            type_name in enum_remap
            or (type_name in ('RxBw', 'PulseShape') and category != 'fsk')
            or type_name == 'TempFormat'
            or (type_name == 'DioNum' and category == 'lora')
        )

    # Collect all enums first
    for cmd in commands:
        # Commands with variable length parameters are not generated
        if any(p.bit_width == 0 for p in cmd.parameters):
            skipped_cmd.append(cmd.name)
            continue
        supported_cmd.append(cmd)

        for param in cmd.parameters:
            if param.enum:
                n = snake_to_pascal(param.name)
                if not is_external_param_enum(n):
                    register_enum(param, n, cmd.name)

        for status in cmd.status_fields:
            if status.enum:
                n = snake_to_pascal(status.name)
                if n not in enum_remap:
                    register_enum(status, n, cmd.name)

    if enums:
        lines.extend(enums)
        lines.append("")

    # Request functions: a base version, plus an advanced one when the
    # command has optional parameters
    for cmd in supported_cmd:
        lines.append(gen_req(cmd, category, advanced=False))
        lines.append("")

        if any(p.optional for p in cmd.parameters):
            lines.append(gen_req(cmd, category, advanced=True))
            lines.append("")

    # Response structs: a base version, plus an advanced one when the
    # command has optional status fields
    if has_rsp:
        lines.append("// Response structs")
        lines.append("")

        for cmd in commands:
            if not cmd.status_fields:
                continue

            variants = [False]
            if any(f.optional for f in cmd.status_fields):
                variants.append(True)
            for is_advanced in variants:
                struct_code = gen_rsp(cmd, category, advanced=is_advanced)
                if struct_code:
                    lines.append(struct_code)
                    lines.append("")

    # Add summary of unimplemented commands
    if skipped_cmd:
        lines.append("// Commands with variable length parameters (not implemented):")
        lines.extend(f"// - {cmd_name}" for cmd_name in skipped_cmd)
        lines.append("")

    # Rust sources must be UTF-8: do not depend on the locale's default encoding
    with open(file_path, 'w', encoding='utf-8') as f:
        _ = f.write('\n'.join(lines))

def main() -> None:
    """Command line entry point: generate one Rust file per YAML category.

    Arguments (both optional): the path of the YAML description (default
    ``./commands.yaml``) and the output directory (default ``../src/cmd``).
    The process exits with status 1 on any error.
    """
    yaml_path = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("./commands.yaml")
    output_dir = Path(sys.argv[2]) if len(sys.argv) > 2 else Path("../src/cmd")

    # Create output directory if it doesn't exist
    output_dir.mkdir(parents=True, exist_ok=True)

    try:
        if not yaml_path.is_file():
            print(f"Error: {yaml_path} is not a file", file=sys.stderr)
            sys.exit(1)

        # Read as UTF-8 rather than with the locale's default encoding
        with open(yaml_path, encoding="utf-8") as f:
            # An empty document loads as None
            data = yaml.safe_load(f) or {}

        # `or {}` tolerates keys that are present but left empty in the YAML
        for category, category_data in (data.get('categories') or {}).items():
            print(f'Category {category}')
            commands : list[Command] = []
            for cmd_name, cmd_data in ((category_data or {}).get('commands') or {}).items():
                try:
                    commands.append(parse_command(cmd_name, cmd_data))
                except ValidationError as e:
                    print(f"Error in {yaml_path}:{cmd_name}: {e}", file=sys.stderr)
                    sys.exit(1)

            gen_file(category, commands, output_dir)

    except Exception as e:
        # sys.exit() raises SystemExit, which is not an Exception and passes through
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

    print("Code generation completed successfully!")

# Run the generator only when executed as a script, not when imported
if __name__ == "__main__":
    main()