// device_driver_generation/mir/lir_transform.rs

1use std::ops::{Add, Not};
2
3use anyhow::ensure;
4
5use crate::{
6    lir,
7    mir::{self, passes::search_object},
8};
9
10use super::{
11    Integer,
12    passes::{find_min_max_addresses, recurse_objects},
13};
14
15pub fn transform(device: mir::Device, driver_name: &str) -> anyhow::Result<lir::Device> {
16    let lenient_pascal_converter = convert_case::Converter::new()
17        .set_boundaries(&convert_case::Boundary::list_from("aA:AAa:_:-: :a1:A1:1A"))
18        .set_pattern(convert_case::Pattern::Capital);
19    let converted_driver_name = lenient_pascal_converter.convert(driver_name);
20
21    ensure!(
22        driver_name == converted_driver_name,
23        "The device name must be given in PascalCase, e.g. \"{}\"",
24        converted_driver_name
25    );
26
27    let mir_enums = collect_enums(&device)?;
28    let lir_enums = mir_enums
29        .iter()
30        .map(|(e, base_type, size_bits)| transform_enum(e, *base_type, *size_bits))
31        .collect::<Result<_, anyhow::Error>>()?;
32
33    let field_sets = transform_field_sets(&device, mir_enums.iter().map(|(e, _, _)| e))?;
34
35    // Create a root block and pass the device objects to it
36    let blocks = collect_into_blocks(
37        BorrowedBlock {
38            cfg_attr: &mir::Cfg::new(None),
39            description: &format!("Root block of the {driver_name} driver"),
40            name: &driver_name.into(),
41            address_offset: &0,
42            repeat: &None,
43            objects: &device.objects,
44        },
45        true,
46        &device.global_config,
47        &device.objects,
48    )?;
49
50    Ok(lir::Device {
51        internal_address_type: find_best_internal_address_type(&device),
52        register_address_type: device
53            .global_config
54            .register_address_type
55            .unwrap_or(mir::Integer::U8),
56        blocks,
57        field_sets,
58        enums: lir_enums,
59        defmt_feature: device.global_config.defmt_feature,
60    })
61}
62
63fn collect_into_blocks(
64    block: BorrowedBlock,
65    is_root: bool,
66    global_config: &mir::GlobalConfig,
67    device_objects: &[mir::Object],
68) -> anyhow::Result<Vec<lir::Block>> {
69    let mut blocks = Vec::new();
70
71    let BorrowedBlock {
72        cfg_attr,
73        description,
74        name,
75        address_offset: _,
76        repeat: _,
77        objects,
78    } = block;
79
80    let mut methods = Vec::new();
81
82    for object in objects {
83        let method = get_method(
84            object,
85            &mut blocks,
86            global_config,
87            device_objects,
88            "new".to_string(),
89        )?;
90
91        methods.push(method);
92    }
93
94    let new_block = lir::Block {
95        cfg_attr: cfg_attr.to_string(),
96        description: description.clone(),
97        root: is_root,
98        name: name.to_string(),
99        methods,
100    };
101
102    blocks.insert(0, new_block);
103
104    Ok(blocks)
105}
106
107fn get_method(
108    object: &mir::Object,
109    blocks: &mut Vec<lir::Block>,
110    global_config: &mir::GlobalConfig,
111    device_objects: &[mir::Object],
112    mut register_reset_value_function: String,
113) -> Result<lir::BlockMethod, anyhow::Error> {
114    use convert_case::Casing;
115
116    Ok(match object {
117        mir::Object::Block(
118            b @ mir::Block {
119                cfg_attr,
120                description,
121                name,
122                address_offset,
123                repeat,
124                ..
125            },
126        ) => {
127            blocks.extend(collect_into_blocks(
128                b.into(),
129                false,
130                global_config,
131                device_objects,
132            )?);
133
134            lir::BlockMethod {
135                cfg_attr: cfg_attr.to_string(),
136                description: description.clone(),
137                name: name.to_case(convert_case::Case::Snake),
138                address: *address_offset,
139                allow_address_overlap: false,
140                kind: repeat_to_method_kind(repeat),
141                method_type: lir::BlockMethodType::Block {
142                    name: name.to_string(),
143                },
144            }
145        }
146        mir::Object::Register(mir::Register {
147            cfg_attr,
148            description,
149            name,
150            allow_address_overlap,
151            address,
152            access,
153            repeat,
154            ..
155        }) => lir::BlockMethod {
156            cfg_attr: cfg_attr.to_string(),
157            description: description.clone(),
158            name: name.to_case(convert_case::Case::Snake),
159            address: *address,
160            allow_address_overlap: *allow_address_overlap,
161            kind: repeat_to_method_kind(repeat),
162            method_type: lir::BlockMethodType::Register {
163                field_set_name: name.to_string(),
164                access: *access,
165                address_type: global_config
166                    .register_address_type
167                    .expect("The presence of the address type is already checked in a mir pass"),
168                reset_value_function: register_reset_value_function.clone(),
169            },
170        },
171        mir::Object::Command(mir::Command {
172            cfg_attr,
173            description,
174            name,
175            allow_address_overlap,
176            address,
177            repeat,
178            in_fields,
179            out_fields,
180            ..
181        }) => lir::BlockMethod {
182            cfg_attr: cfg_attr.to_string(),
183            description: description.clone(),
184            name: name.to_case(convert_case::Case::Snake),
185            address: *address,
186            allow_address_overlap: *allow_address_overlap,
187            kind: repeat_to_method_kind(repeat),
188            method_type: lir::BlockMethodType::Command {
189                field_set_name_in: in_fields
190                    .is_empty()
191                    .not()
192                    .then(|| format!("{name}FieldsIn")),
193                field_set_name_out: out_fields
194                    .is_empty()
195                    .not()
196                    .then(|| format!("{name}FieldsOut")),
197                address_type: global_config
198                    .command_address_type
199                    .expect("The presence of the address type is already checked in a mir pass"),
200            },
201        },
202        mir::Object::Buffer(mir::Buffer {
203            cfg_attr,
204            description,
205            name,
206            access,
207            address,
208        }) => lir::BlockMethod {
209            cfg_attr: cfg_attr.to_string(),
210            description: description.clone(),
211            name: name.to_case(convert_case::Case::Snake),
212            address: *address,
213            allow_address_overlap: false,
214            kind: lir::BlockMethodKind::Normal, // Buffers can't be repeated (for now?)
215            method_type: lir::BlockMethodType::Buffer {
216                access: *access,
217                address_type: global_config
218                    .buffer_address_type
219                    .expect("The presence of the address type is already checked in a mir pass"),
220            },
221        },
222        mir::Object::Ref(mir::RefObject {
223            cfg_attr,
224            description,
225            name,
226            object_override,
227        }) => {
228            let mut reffed_object = search_object(object_override.name(), device_objects)
229                .expect("All refs are validated in a mir pass")
230                .clone();
231
232            match object_override {
233                mir::ObjectOverride::Block(override_values) => {
234                    let reffed_object = reffed_object
235                        .as_block_mut()
236                        .expect("All refs are validated in a mir pass");
237                    reffed_object.cfg_attr = cfg_attr.clone();
238                    reffed_object.description = description.clone();
239
240                    if let Some(address_offset) = override_values.address_offset {
241                        reffed_object.address_offset = address_offset;
242                    }
243                    if let Some(repeat) = override_values.repeat {
244                        reffed_object.repeat = Some(repeat);
245                    }
246                }
247                mir::ObjectOverride::Register(override_values) => {
248                    let reffed_object = reffed_object
249                        .as_register_mut()
250                        .expect("All refs are validated in a mir pass");
251                    reffed_object.cfg_attr = cfg_attr.clone();
252                    reffed_object.description = description.clone();
253
254                    if let Some(access) = override_values.access {
255                        reffed_object.access = access;
256                    }
257                    if let Some(address) = override_values.address {
258                        reffed_object.address = address;
259                    }
260                    if let Some(reset_value) = override_values.reset_value.clone() {
261                        reffed_object.reset_value = Some(reset_value);
262                        register_reset_value_function =
263                            format!("new_as_{}", name.to_case(convert_case::Case::Snake));
264                    }
265                    if let Some(repeat) = override_values.repeat {
266                        reffed_object.repeat = Some(repeat);
267                    }
268                }
269                mir::ObjectOverride::Command(override_values) => {
270                    let reffed_object = reffed_object
271                        .as_command_mut()
272                        .expect("All refs are validated in a mir pass");
273                    reffed_object.cfg_attr = cfg_attr.clone();
274                    reffed_object.description = description.clone();
275
276                    if let Some(address) = override_values.address {
277                        reffed_object.address = address;
278                    }
279                    if let Some(repeat) = override_values.repeat {
280                        reffed_object.repeat = Some(repeat);
281                    }
282                }
283            }
284
285            let mut method = get_method(
286                &reffed_object,
287                blocks,
288                global_config,
289                device_objects,
290                register_reset_value_function,
291            )?;
292
293            // We kept the old name in the reffed object so it generates with the correct field sets.
294            // But we do want to have the name of ref to be the method name.
295            method.name = name.to_case(convert_case::Case::Snake);
296
297            method
298        }
299    })
300}
301
302fn transform_field_sets<'a>(
303    device: &mir::Device,
304    mir_enums: impl Iterator<Item = &'a mir::Enum> + Clone,
305) -> anyhow::Result<Vec<lir::FieldSet>> {
306    let mut field_sets = Vec::new();
307
308    recurse_objects(&device.objects, &mut |object| {
309        match object {
310            mir::Object::Register(r) => {
311                let ref_reset_overrides = find_refs(device, object)?
312                    .iter()
313                    .map(|r| {
314                        (
315                            &r.name,
316                            r.object_override
317                                .as_register()
318                                .expect("Ref must be register override"),
319                        )
320                    })
321                    .filter_map(|(ref_name, ro)| {
322                        ro.reset_value.as_ref().map(|reset_value| {
323                            (ref_name.clone(), reset_value.as_array().unwrap().clone())
324                        })
325                    })
326                    .collect();
327
328                field_sets.push(transform_field_set(
329                    &r.fields,
330                    r.name.clone(),
331                    &r.cfg_attr,
332                    &r.description,
333                    r.byte_order.unwrap(),
334                    r.bit_order,
335                    r.size_bits,
336                    r.reset_value
337                        .as_ref()
338                        .map(|rv| rv.as_array().unwrap().clone()),
339                    ref_reset_overrides,
340                    mir_enums.clone(),
341                )?);
342            }
343            mir::Object::Command(c) => {
344                if c.size_bits_in != 0 {
345                    field_sets.push(transform_field_set(
346                        &c.in_fields,
347                        format!("{}FieldsIn", c.name),
348                        &c.cfg_attr,
349                        &c.description,
350                        c.byte_order.unwrap(),
351                        c.bit_order,
352                        c.size_bits_in,
353                        None,
354                        Vec::new(),
355                        mir_enums.clone(),
356                    )?);
357                }
358                if c.size_bits_out != 0 {
359                    field_sets.push(transform_field_set(
360                        &c.out_fields,
361                        format!("{}FieldsOut", c.name),
362                        &c.cfg_attr,
363                        &c.description,
364                        c.byte_order.unwrap(),
365                        c.bit_order,
366                        c.size_bits_out,
367                        None,
368                        Vec::new(),
369                        mir_enums.clone(),
370                    )?);
371                }
372            }
373            _ => {}
374        }
375
376        Ok(())
377    })?;
378
379    Ok(field_sets)
380}
381
382#[allow(clippy::too_many_arguments)] // Though it is correct... it's too many args
383fn transform_field_set<'a>(
384    field_set: &[mir::Field],
385    field_set_name: String,
386    cfg_attr: &mir::Cfg,
387    description: &str,
388    byte_order: mir::ByteOrder,
389    bit_order: mir::BitOrder,
390    size_bits: u32,
391    reset_value: Option<Vec<u8>>,
392    ref_reset_overrides: Vec<(String, Vec<u8>)>,
393    enum_list: impl Iterator<Item = &'a mir::Enum> + Clone,
394) -> anyhow::Result<lir::FieldSet> {
395    let fields = field_set
396        .iter()
397        .map(|field| {
398            let mir::Field {
399                cfg_attr,
400                description,
401                name,
402                access,
403                base_type,
404                field_conversion,
405                field_address,
406            } = field;
407
408            let (base_type, conversion_method) =
409                match (base_type, field.field_address.len(), field_conversion) {
410                    (mir::BaseType::Unspecified | mir::BaseType::FixedSize(_), _, _) => todo!(),
411                    (mir::BaseType::Bool, 1, None) => {
412                        ("u8".to_string(), lir::FieldConversionMethod::Bool)
413                    }
414                    (mir::BaseType::Bool, _, _) => unreachable!(
415                        "Checked in a MIR pass. Bools can only be 1 bit and have no conversion"
416                    ),
417                    (mir::BaseType::Uint | mir::BaseType::Int, val, None) => (
418                        format!(
419                            "{}{}",
420                            match base_type {
421                                mir::BaseType::Uint => 'u',
422                                mir::BaseType::Int => 'i',
423                                _ => unreachable!(),
424                            },
425                            val.max(8).next_power_of_two()
426                        ),
427                        lir::FieldConversionMethod::None,
428                    ),
429                    (mir::BaseType::Uint | mir::BaseType::Int, val, Some(fc)) => (
430                        format!(
431                            "{}{}",
432                            match base_type {
433                                mir::BaseType::Uint => 'u',
434                                mir::BaseType::Int => 'i',
435                                _ => unreachable!(),
436                            },
437                            val.max(8).next_power_of_two()
438                        ),
439                        {
440                            match enum_list.clone().find(|e| e.name == fc.type_name()) {
441                                // Always use try if that's specified
442                                _ if fc.use_try() => {
443                                    lir::FieldConversionMethod::TryInto(fc.type_name().into())
444                                }
445                                // There is an enum we generate so we can look at its metadata
446                                Some(mir::Enum {
447                                    generation_style:
448                                        Some(mir::EnumGenerationStyle::Infallible { bit_size }),
449                                    ..
450                                }) if field.field_address.clone().count() <= *bit_size as usize => {
451                                    // This field is equal or smaller in bits than the infallible enum. So we can do the unsafe into
452                                    lir::FieldConversionMethod::UnsafeInto(fc.type_name().into())
453                                }
454                                // Fallback is to require the into trait
455                                _ => lir::FieldConversionMethod::Into(fc.type_name().into()),
456                            }
457                        },
458                    ),
459                };
460
461            Ok(lir::Field {
462                cfg_attr: cfg_attr.to_string(),
463                description: description.clone(),
464                name: name.clone(),
465                address: field_address.clone(),
466                base_type,
467                conversion_method,
468                access: *access,
469            })
470        })
471        .collect::<Result<_, anyhow::Error>>()?;
472
473    Ok(lir::FieldSet {
474        cfg_attr: cfg_attr.to_string(),
475        description: description.into(),
476        name: field_set_name.to_string(),
477        byte_order,
478        bit_order,
479        size_bits,
480        reset_value: reset_value.unwrap_or_else(|| vec![0; size_bits.div_ceil(8) as usize]),
481        ref_reset_overrides,
482        fields,
483    })
484}
485
486fn collect_enums(device: &mir::Device) -> anyhow::Result<Vec<(mir::Enum, mir::BaseType, usize)>> {
487    let mut enums = Vec::new();
488
489    recurse_objects(&device.objects, &mut |object| {
490        for field in object.field_sets().flatten() {
491            if let Some(mir::FieldConversion::Enum { enum_value, .. }) = &field.field_conversion {
492                enums.push((
493                    enum_value.clone(),
494                    field.base_type,
495                    field.field_address.clone().count(),
496                ))
497            }
498        }
499
500        Ok(())
501    })?;
502
503    Ok(enums)
504}
505
506fn transform_enum(
507    e: &mir::Enum,
508    base_type: mir::BaseType,
509    size_bits: usize,
510) -> anyhow::Result<lir::Enum> {
511    let mir::Enum {
512        cfg_attr,
513        description,
514        name,
515        variants,
516        generation_style: _,
517    } = e;
518
519    let base_type = match (base_type, size_bits) {
520        (mir::BaseType::Bool, _) => "u8".to_string(),
521        (mir::BaseType::Uint, val) => format!("u{}", val.max(8).next_power_of_two()),
522        (mir::BaseType::Int, val) => format!("i{}", val.max(8).next_power_of_two()),
523        (mir::BaseType::Unspecified, _) => {
524            todo!()
525        }
526        (mir::BaseType::FixedSize(_), _) => {
527            todo!()
528        }
529    };
530
531    let mut next_variant_number = None;
532    let variants = variants
533        .iter()
534        .map(|v| {
535            let mir::EnumVariant {
536                cfg_attr,
537                description,
538                name,
539                value,
540            } = v;
541
542            let number = match value {
543                mir::EnumValue::Unspecified
544                | mir::EnumValue::Default
545                | mir::EnumValue::CatchAll => {
546                    let val = next_variant_number.unwrap_or_default();
547                    next_variant_number = Some(val + 1);
548                    val
549                }
550                mir::EnumValue::Specified(num) => {
551                    next_variant_number = Some(*num + 1);
552                    *num
553                }
554            };
555
556            Ok(lir::EnumVariant {
557                cfg_attr: cfg_attr.to_string(),
558                description: description.clone(),
559                name: name.to_string(),
560                number,
561                default: matches!(value, mir::EnumValue::Default),
562                catch_all: matches!(value, mir::EnumValue::CatchAll),
563            })
564        })
565        .collect::<Result<_, anyhow::Error>>()?;
566
567    Ok(lir::Enum {
568        cfg_attr: cfg_attr.to_string(),
569        description: description.clone(),
570        name: name.to_string(),
571        base_type,
572        variants,
573    })
574}
575
576fn repeat_to_method_kind(repeat: &Option<mir::Repeat>) -> lir::BlockMethodKind {
577    match repeat {
578        Some(mir::Repeat { count, stride }) => lir::BlockMethodKind::Repeated {
579            count: *count,
580            stride: *stride,
581        },
582        None => lir::BlockMethodKind::Normal,
583    }
584}
585
586#[derive(Debug, Clone, PartialEq, Eq)]
587pub struct BorrowedBlock<'o> {
588    pub cfg_attr: &'o mir::Cfg,
589    pub description: &'o String,
590    pub name: &'o String,
591    pub address_offset: &'o i64,
592    pub repeat: &'o Option<mir::Repeat>,
593    pub objects: &'o [mir::Object],
594}
595
596impl<'o> From<&'o mir::Block> for BorrowedBlock<'o> {
597    fn from(value: &'o mir::Block) -> Self {
598        let mir::Block {
599            cfg_attr,
600            description,
601            name,
602            address_offset,
603            repeat,
604            objects,
605        } = value;
606
607        Self {
608            cfg_attr,
609            description,
610            name,
611            address_offset,
612            repeat,
613            objects,
614        }
615    }
616}
617
618fn find_best_internal_address_type(device: &mir::Device) -> Integer {
619    let (min_address_found, max_address_found) = find_min_max_addresses(&device.objects, |_| true);
620
621    let needs_signed = min_address_found < 0;
622    let needs_bits = (min_address_found
623        .unsigned_abs()
624        .max(max_address_found.unsigned_abs())
625        .add(1)
626        .next_power_of_two()
627        .ilog2()
628        + needs_signed as u32)
629        .next_power_of_two()
630        .max(8);
631
632    if needs_signed {
633        match needs_bits {
634            8 => Integer::I8,
635            16 => Integer::I16,
636            32 => Integer::I32,
637            64 => Integer::I64,
638            _ => unreachable!(),
639        }
640    } else {
641        match needs_bits {
642            8 => Integer::U8,
643            16 => Integer::U16,
644            32 => Integer::U32,
645            _ => unreachable!(),
646        }
647    }
648}
649
650fn find_refs<'d>(
651    device: &'d mir::Device,
652    source_object: &mir::Object,
653) -> anyhow::Result<Vec<&'d mir::RefObject>> {
654    let mut found_refs = Vec::new();
655
656    recurse_objects(&device.objects, &mut |object| {
657        if let mir::Object::Ref(ref_object) = object {
658            if ref_object.object_override.name() == source_object.name() {
659                found_refs.push(ref_object);
660            }
661        }
662
663        Ok(())
664    })?;
665
666    Ok(found_refs)
667}