protospec_build/coder/
encode.rs

1use indexmap::IndexMap;
2
3use super::*;
4use crate::asg::*;
5use std::{sync::Arc, collections::HashMap};
6
/// One operation in the instruction list that [`Context`] builds while walking a field.
///
/// Bare `usize` operands are register numbers handed out by the context's register
/// allocator; `Target` operands name the sink an instruction writes to.
#[derive(Debug)]
pub enum Instruction {
    /// dest register, expression to evaluate into it
    Eval(usize, Expression),
    /// dest, source, op (the field path to follow from `source`)
    GetField(usize, usize, Vec<FieldRef>),
    /// buf handle, len handle
    AllocBuf(usize, usize),
    /// buf handle
    AllocDynBuf(usize),
    /// stream, new stream, transformer, arguments
    WrapStream(Target, usize, Arc<Transform>, Vec<usize>),
    /// condition, prelude, stream, new stream, owned_new_stream, transformer, arguments
    ///
    /// The encoder fills the prelude with the instructions that evaluate the transformer
    /// arguments.
    ConditionalWrapStream(
        usize,
        Vec<Instruction>,
        Target,
        usize,
        usize,
        Arc<Transform>,
        Vec<usize>,
    ),
    /// stream, new stream
    ProxyStream(Target, usize),
    /// stream register; emitted by the encoder once per wrapped stream, after the wrapped
    /// field body
    EndStream(usize),

    /// target, buf handle whose contents go to the target
    EmitBuf(Target, usize),

    /// target, source register, foreign type, argument registers
    EncodeForeign(Target, usize, Arc<ForeignType>, Vec<usize>),
    /// target, source register, argument registers
    EncodeRef(Target, usize, Vec<usize>),
    /// representation type, target, source register
    EncodeEnum(PrimitiveType, Target, usize),
    /// target, source register
    EncodeBitfield(Target, usize),
    /// target, source register, primitive type
    EncodePrimitive(Target, usize, PrimitiveType),
    /// target, source register, element type, optional length register
    EncodePrimitiveArray(Target, usize, PrimitiveType, Option<usize>),
    /// target, register of length
    Pad(Target, usize),

    /// register representing iterator from -> term, term, inner
    Loop(usize, usize, Vec<Instruction>),
    /// len target <- buffer, cast_type
    GetLen(usize, usize, Option<ScalarType>),
    /// register to drop; the encoder emits it for the owned stream of a conditional wrap
    Drop(usize),
    /// original, checked, message
    NullCheck(usize, usize, String),
    /// condition, if_true, if_false
    Conditional(usize, Vec<Instruction>, Vec<Instruction>),
    /// enum name, discriminant, original, checked, message
    UnwrapEnum(String, String, usize, usize, String),
    /// enum name, discriminant, original, checked: (enumstruct field name, checked), message
    UnwrapEnumStruct(String, String, usize, Vec<(String, usize)>, String),
    /// Block of instructions wrapped around the per-variant code of an enum container;
    /// the encoder ends every variant's code with `Break`.
    BreakBlock(Vec<Instruction>),
    /// Emitted at the end of each enum variant's code inside a `BreakBlock`.
    /// NOTE(review): presumably exits the enclosing `BreakBlock` - confirm against the
    /// instruction consumer.
    Break,
}
52
/// Callback that turns a field name into the register holding that field's value.
///
/// It receives the context mutably, so an implementation may allocate registers and emit
/// instructions to produce the value (the top-level resolver emits a `GetField`), or it may
/// simply look the name up in a precomputed table (the enum-variant resolver does this).
type Resolver = Box<dyn Fn(&mut Context, &str) -> usize>;
54
/// Accumulates the instruction list and register bookkeeping for one encoder.
#[derive(Debug)]
pub struct Context {
    /// Number of registers handed out so far; also the index of the next register.
    pub register_count: usize,
    /// Instructions emitted so far, in execution order.
    pub instructions: Vec<Instruction>,
    /// Auto fields that have been resolved, keyed by field name, mapped to the register
    /// that receives the `GetLen` result for them.
    pub resolved_autos: IndexMap<String, usize>,
}
61
62impl Context {
63    fn alloc_register(&mut self) -> usize {
64        let x = self.register_count;
65        self.register_count += 1;
66        x
67    }
68}
69
70impl Context {
71    pub fn new() -> Context {
72        Context {
73            instructions: vec![],
74            register_count: 0,
75            resolved_autos: IndexMap::new(),
76        }
77    }
78
79    pub fn encode_field_top(&mut self, field: &Arc<Field>) {
80        let top = self.alloc_register(); // implicitly set to self/equivalent
81        match &*field.type_.borrow() {
82            Type::Foreign(_) => return,
83            Type::Container(_) => (),
84            Type::Enum(_) => (),
85            Type::Bitfield(_) => (),
86            _ => {
87                self.instructions
88                    .push(Instruction::GetField(0, 0, vec![FieldRef::TupleAccess(0)]))
89            }
90        }
91        let resolver: Resolver = Box::new(move |context: &mut Context, name: &str| {
92            let value = context.alloc_register();
93            context.instructions.push(Instruction::GetField(
94                value,
95                top,
96                vec![FieldRef::Name(name.to_string())],
97            ));
98            value
99        });
100        self.encode_field(Target::Direct, &resolver, top, field);
101    }
102
103    fn encode_field_condition(&mut self, field: &Arc<Field>) -> Option<usize> {
104        if let Some(condition) = field.condition.borrow().as_ref() {
105            let value = self.alloc_register();
106            self.instructions
107                .push(Instruction::Eval(value, condition.clone()));
108            Some(value)
109        } else {
110            None
111        }
112    }
113
114    pub fn encode_field(
115        &mut self,
116        target: Target,
117        resolver: &Resolver,
118        source: usize,
119        field: &Arc<Field>,
120    ) {
121        let field_condition = self.encode_field_condition(field);
122        let start = self.instructions.len();
123        
124        self.encode_field_unconditional(target, resolver, source, field, field_condition.is_some());
125
126        if let Some(field_condition) = field_condition {
127            let drained = self.instructions.drain(start..).collect();
128            self.instructions
129                .push(Instruction::Conditional(field_condition, drained, vec![]));
130        }
131    }
132
133    fn encode_container_items(&mut self, container: &ContainerType, buf_target: Target, resolver: &Resolver, source: usize) {
134        let mut auto_targets = vec![];
135        for (name, child) in container.items.iter() {
136            if child.is_auto.get() {
137                let new_target = self.alloc_register();
138                self.instructions.push(Instruction::AllocDynBuf(new_target));
139                auto_targets.push((new_target, child));
140                continue;
141            }
142            let (real_target, _) = auto_targets
143                .last()
144                .map(|x| (Target::Buf(x.0), Some(&x.1)))
145                .unwrap_or_else(|| (buf_target, None));
146            if matches!(&*child.type_.borrow(), Type::Container(_)) || child.is_pad.get() {
147                self.encode_field(real_target, resolver, source, child);
148            } else {
149                let resolved = resolver(self, &**name);
150                self.encode_field(real_target, resolver, resolved, child);
151            }
152
153            for (i, (auto_target, auto_field)) in auto_targets.clone().into_iter().enumerate().rev() {
154                if let Some(resolved) = self.resolved_autos.get(&auto_field.name).copied() {
155                    auto_targets.remove(i);
156                    let target = auto_targets.get(i).map(|(target, _)| Target::Buf(*target)).unwrap_or(buf_target);
157                    self.encode_field(target, resolver, resolved, auto_field);
158                    self.instructions
159                        .push(Instruction::EmitBuf(target, auto_target));
160                }
161            }
162        }
163        for (_, auto_field) in auto_targets {
164            panic!("unused auto field: {}", auto_field.name);
165        }
166    }
167
    /// Emits the instructions that encode `field` from register `source` into `target`,
    /// without looking at the field's own condition (the caller handles that).
    ///
    /// Emission order:
    /// 1. one `WrapStream` / `ConditionalWrapStream` per transform, each wrapping the
    ///    target produced by the previous one;
    /// 2. a `NullCheck` of `source` when `was_conditional` is set, after which the checked
    ///    register is used as the source;
    /// 3. the body: `Pad` for pad fields, the item/variant expansion for containers, and
    ///    `encode_type` for every other type;
    /// 4. `EndStream` for each wrapped stream in reverse wrap order, followed by `Drop` of
    ///    the owned stream register when the wrap was conditional.
    ///
    /// # Panics
    ///
    /// Panics with `invalid type for pad` when a pad field is not an array, and on the
    /// `unwrap` of the pad array's length when it has no length value. The resolver given
    /// to enum variants panics on unknown field names (`illegal field ref`) or on any
    /// field reference inside a raw enum value.
    fn encode_field_unconditional(
        &mut self,
        mut target: Target,
        resolver: &Resolver,
        source: usize,
        field: &Arc<Field>,
        was_conditional: bool,
    ) {
        // (stream register, owned stream register for conditional wraps) per transform.
        let mut new_streams = vec![];

        for transform in field.transforms.borrow().iter() {
            let condition = if let Some(condition) = &transform.condition {
                let value = self.alloc_register();
                self.instructions
                    .push(Instruction::Eval(value, condition.clone()));
                Some(value)
            } else {
                None
            };

            // Argument evaluation starts here; for a conditional transform these
            // instructions are moved into the wrap's prelude below.
            let argument_start = self.instructions.len();
            let mut args = vec![];
            for arg in transform.arguments.iter() {
                let r = self.alloc_register();
                self.instructions.push(Instruction::Eval(r, arg.clone()));
                args.push(r);
            }
            let new_stream = self.alloc_register();
            // Only conditional transforms get a second, "owned" stream register.
            let new_owned_stream = condition.map(|_| self.alloc_register());
            new_streams.push((new_stream, new_owned_stream));

            if let Some(condition) = condition {
                let drained = self.instructions.drain(argument_start..).collect();
                self.instructions.push(Instruction::ConditionalWrapStream(
                    condition,
                    drained,
                    target,
                    new_stream,
                    // Always `Some` here: it was derived from `condition` above.
                    new_owned_stream.unwrap(),
                    transform.transform.clone(),
                    args,
                ));
            } else {
                self.instructions.push(Instruction::WrapStream(
                    target,
                    new_stream,
                    transform.transform.clone(),
                    args,
                ));
            }
            // Later transforms and the field body write into the newly wrapped stream.
            target = Target::Stream(new_stream);
        }

        // A conditional field's source is null-checked into a fresh register first.
        let source = if was_conditional {
            let real_source = self.alloc_register();
            self.instructions.push(Instruction::NullCheck(
                source,
                real_source,
                "failed null check for conditional field".to_string(),
            ));
            real_source
        } else {
            source
        };

        match &*field.type_.borrow() {
            // Pad fields: evaluate the array length and emit that much padding.
            _ if field.is_pad.get() => {
                let array_type = field.type_.borrow();
                let array_type = match &*array_type {
                    Type::Array(a) => &**a,
                    _ => panic!("invalid type for pad"),
                };
                let len = array_type.length.value.as_ref().cloned().unwrap();
                let length_register = self.alloc_register();
                self.instructions.push(Instruction::Eval(length_register, len));
                self.instructions.push(Instruction::Pad(target, length_register));
            },
            Type::Container(c) => {
                // A container with an explicit length is encoded into its own buffer,
                // which is emitted into `target` once the contents are complete.
                let buf_target = if let Some(length) = &c.length {
                    //todo: use limited stream
                    let len_register = self.alloc_register();
                    let buf = self.alloc_register();
                    self.instructions
                        .push(Instruction::Eval(len_register, length.clone()));
                    self.instructions
                        .push(Instruction::AllocBuf(buf, len_register));
                    Target::Buf(buf)
                } else {
                    target
                };
                if c.is_enum.get() {
                    // Everything emitted for the variants ends up inside one `BreakBlock`.
                    let break_start = self.instructions.len();
                    for (name, child) in c.items.iter() {
                        // The condition's `Eval` is emitted before `start`, so it stays
                        // outside the `Conditional` built at the end of this iteration.
                        let condition = self.encode_field_condition(child);
                        let start = self.instructions.len();
                        // NOTE(review): this register is only referenced by the raw-value
                        // arm below; the container arm shadows the name and leaves the
                        // register unused.
                        let unwrapped = self.alloc_register();

                        let subtype = child.type_.borrow();
                        match &*subtype {
                            // Struct-like variant: unwrap each flattened field into its
                            // own register, then encode the variant's items from those.
                            Type::Container(c) => {

                                let mut unwrapped = vec![];
                                for (subname, subchild) in c.flatten_view() {
                                    // Pad fields and nested containers get no register.
                                    if subchild.is_pad.get() || matches!(&*subchild.type_.borrow(), Type::Container(_)) {
                                        continue;
                                    }
                                    let alloced = self.alloc_register();
                                    unwrapped.push((
                                        subname.clone(),
                                        alloced,
                                    ));
                                }

                                self.instructions.push(Instruction::UnwrapEnumStruct(
                                    field.name.clone(),
                                    name.clone(),
                                    source,
                                    unwrapped.clone(),
                                    "mismatch betweeen condition and enum discriminant".to_string(),
                                ));

                                let map = unwrapped.into_iter().collect::<HashMap<_, _>>();

                                // Field references resolve to the unwrapped registers.
                                let resolver: Resolver = Box::new(move |_context, name| *map.get(name).expect("illegal field ref"));
                                self.encode_container_items(c, buf_target, &resolver, source);
                                self.instructions.push(Instruction::Break);
                            },
                            // Raw-value variant: unwrap the payload into one register and
                            // encode it as the child's type.
                            _ => {
                                self.instructions.push(Instruction::UnwrapEnum(
                                    field.name.clone(),
                                    name.clone(),
                                    source,
                                    unwrapped,
                                    "mismatch betweeen condition and enum discriminant".to_string(),
                                ));

                                let resolver: Resolver = Box::new(|_, _| panic!("fields refs illegal in raw enum value"));
                                self.encode_field_unconditional(buf_target, &resolver, unwrapped, child, false);
                                self.instructions.push(Instruction::Break);
                            },
                        }

                        // Guard the variant's instructions with its condition, if any.
                        if let Some(condition) = condition {
                            let drained = self.instructions.drain(start..).collect();
                            self.instructions
                                .push(Instruction::Conditional(condition, drained, vec![]));
                        }
                    }
                    let drained = self.instructions.drain(break_start..).collect();
                    self.instructions
                        .push(Instruction::BreakBlock(drained));

                } else {
                    self.encode_container_items(c, buf_target, resolver, source);
                }

                if let Some(length) = &c.length {
                    // If the length expression is an auto field, resolve it from the
                    // buffer; then emit the buffer into the real target.
                    self.check_auto(length, buf_target.unwrap_buf());
                    self.instructions
                        .push(Instruction::EmitBuf(target, buf_target.unwrap_buf()));
                }
            }
            t => self.encode_type(target, resolver, source, t),
        }

        // End the wrapped streams innermost-first.
        for (stream, owned_stream) in new_streams.iter().rev() {
            self.instructions.push(Instruction::EndStream(*stream));
            if let Some(owned_stream) = owned_stream {
                self.instructions.push(Instruction::Drop(*owned_stream));
            }
        }
    }
340
341    fn resolve_auto(&mut self, field: &Arc<Field>, source: usize) -> Option<usize> {
342        let type_ = field.type_.borrow();
343        let cast_type = match type_.resolved().as_ref() {
344            Type::Scalar(s) => *s,
345            Type::Foreign(f) => match f.obj.can_receive_auto() {
346                Some(s) => s,
347                None => unimplemented!("bad ffi type for auto field"),
348            },
349            _ => unimplemented!("bad type for auto field"),
350        };
351
352        let target = self.alloc_register();
353        self.instructions.push(Instruction::GetLen(
354            target,
355            source,
356            Some(cast_type),
357        ));
358        self.resolved_autos.insert(field.name.clone(), target);
359        Some(target)
360    }
361
362    fn check_auto(&mut self, base: &Expression, source: usize) -> Option<usize> {
363        match base {
364            Expression::FieldRef(f) if f.is_auto.get() => {
365                self.resolve_auto(f, source)
366            },
367            Expression::FieldRef(_) => None,
368            Expression::Binary(_) => None,
369            Expression::Member(_) => None,
370            Expression::Unary(_) => None,
371            Expression::Cast(expr) => self.check_auto(&*expr.inner, source),
372            Expression::ArrayIndex(_) => None,
373            Expression::EnumAccess(_) => None,
374            Expression::Int(_) => None,
375            Expression::ConstRef(_) => None,
376            Expression::InputRef(_) => None,
377            Expression::Str(_) => None,
378            Expression::Ternary(_) => None,
379            Expression::Bool(_) => None,
380            Expression::Call(_) => None,
381        }
382    }
383
384    pub fn encode_type(&mut self, target: Target, resolver: &Resolver, source: usize, type_: &Type) {
385        match type_ {
386            Type::Container(_) => unimplemented!(),
387            Type::Array(c) => {
388                let terminator = if c.length.expandable && c.length.value.is_some() {
389                    let len = c.length.value.as_ref().cloned().unwrap();
390                    let r = self.alloc_register();
391                    self.instructions.push(Instruction::Eval(r, len));
392                    Some(r)
393                } else {
394                    None
395                };
396
397                let mut len = if terminator.is_none() {
398                    if let Some(expr) = &c.length.value {
399                        self.check_auto(expr, source)
400                    } else {
401                        None
402                    }
403                } else {
404                    None
405                };
406
407                if len.is_none() && !c.length.expandable {
408                    len = {
409                        let len = c.length.value.as_ref().cloned().unwrap();
410                        let r = self.alloc_register();
411                        self.instructions.push(Instruction::Eval(r, len));
412                        Some(r)
413                    };
414                }
415
416                if c.element.condition.borrow().is_none()
417                    && c.element.transforms.borrow().len() == 0
418                    && terminator.is_none()
419                {
420                    let type_ = c.element.type_.borrow();
421                    let type_ = type_.resolved();
422                    match &*type_ {
423                        // todo: const-length type optimizations for container/array/foreign
424                        Type::Container(_) | Type::Array(_) | Type::Foreign(_) | Type::Ref(_) => (),
425                        Type::Enum(e) => {
426                            self.instructions.push(Instruction::EncodePrimitiveArray(
427                                target,
428                                source,
429                                PrimitiveType::Scalar(e.rep),
430                                len,
431                            ));
432                            return;
433                        },
434                        Type::Bitfield(e) => {
435                            self.instructions.push(Instruction::EncodePrimitiveArray(
436                                target,
437                                source,
438                                PrimitiveType::Scalar(e.rep),
439                                len,
440                            ));
441                            return;
442                        },
443                        Type::Scalar(x) => {
444                            self.instructions.push(Instruction::EncodePrimitiveArray(
445                                target,
446                                source,
447                                PrimitiveType::Scalar(*x),
448                                len,
449                            ));
450                            return;
451                        }
452                        Type::F32 => {
453                            self.instructions.push(Instruction::EncodePrimitiveArray(
454                                target,
455                                source,
456                                PrimitiveType::F32,
457                                len,
458                            ));
459                            return;
460                        }
461                        Type::F64 => {
462                            self.instructions.push(Instruction::EncodePrimitiveArray(
463                                target,
464                                source,
465                                PrimitiveType::F64,
466                                len,
467                            ));
468                            return;
469                        }
470                        Type::Bool => {
471                            self.instructions.push(Instruction::EncodePrimitiveArray(
472                                target,
473                                source,
474                                PrimitiveType::Bool,
475                                len,
476                            ));
477                            return;
478                        }
479                    }
480                }
481
482                let current_pos = self.instructions.len();
483                let iter_index = self.alloc_register();
484                let new_source = self.alloc_register();
485                self.instructions.push(Instruction::GetField(
486                    new_source,
487                    source,
488                    vec![FieldRef::ArrayAccess(iter_index)],
489                ));
490                self.encode_field(target, resolver, new_source, &c.element);
491                let drained = self.instructions.drain(current_pos..).collect();
492                let len = if let Some(len) = len {
493                    len
494                } else {
495                    let len = self.alloc_register();
496                    self.instructions
497                        .push(Instruction::GetLen(len, source, None));
498                    len
499                };
500                self.instructions
501                    .push(Instruction::Loop(iter_index, len, drained));
502                if let Some(terminator) = terminator {
503                    self.instructions.push(Instruction::EncodePrimitiveArray(
504                        target,
505                        terminator,
506                        PrimitiveType::Scalar(ScalarType::U8),
507                        None,
508                    ));
509                }
510            }
511            Type::Enum(e) => {
512                self.instructions.push(Instruction::EncodeEnum(
513                    PrimitiveType::Scalar(e.rep.clone()),
514                    target,
515                    source,
516                ));
517            }
518            Type::Bitfield(_) => {
519                self.instructions.push(Instruction::EncodeBitfield(
520                    target,
521                    source,
522                ));
523            }
524            Type::Scalar(s) => {
525                self.instructions.push(Instruction::EncodePrimitive(
526                    target,
527                    source,
528                    PrimitiveType::Scalar(*s),
529                ));
530            }
531            Type::F32 => {
532                self.instructions.push(Instruction::EncodePrimitive(
533                    target,
534                    source,
535                    PrimitiveType::F32,
536                ));
537            }
538            Type::F64 => {
539                self.instructions.push(Instruction::EncodePrimitive(
540                    target,
541                    source,
542                    PrimitiveType::F64,
543                ));
544            }
545            Type::Bool => {
546                self.instructions.push(Instruction::EncodePrimitive(
547                    target,
548                    source,
549                    PrimitiveType::Bool,
550                ));
551            }
552            Type::Foreign(f) => {
553                self.instructions.push(Instruction::EncodeForeign(
554                    target,
555                    source,
556                    f.clone(),
557                    vec![],
558                ));
559            }
560            Type::Ref(r) => {
561                let mut args = vec![];
562                for arg in r.arguments.iter() {
563                    let r = self.alloc_register();
564                    self.instructions.push(Instruction::Eval(r, arg.clone()));
565                    args.push(r);
566                }
567                if let Type::Foreign(f) = &*r.target.type_.borrow() {
568                    let arguments = f.obj.arguments();
569                    for (expr, arg) in r.arguments.iter().zip(arguments.iter()) {
570                        if arg.can_resolve_auto {
571                            self.check_auto(expr, source);
572                        }
573                    }
574                    self.instructions.push(Instruction::EncodeForeign(
575                        target,
576                        source,
577                        f.clone(),
578                        args,
579                    ));
580                } else {
581                    self.instructions
582                        .push(Instruction::EncodeRef(target, source, args));
583                }
584            }
585        }
586    }
587}