hotg_rune_syntax/
analysis.rs

1use std::collections::{HashSet, VecDeque};
2
3use codespan::Span;
4use codespan_reporting::{
5    diagnostic::{Diagnostic, Label},
6};
7use indexmap::IndexMap;
8use crate::{
9    Diagnostics,
10    hir::{self, HirId, Node, Primitive, Resource, ResourceSource, Rune, Slot},
11    utils::{Builtins, HirIds, range_span},
12    yaml::*,
13};
14
15pub fn analyse(doc: &Document, diags: &mut Diagnostics) -> Rune {
16    let mut ctx = Context::new(diags);
17
18    match doc {
19        Document::V1 {
20            image,
21            pipeline,
22            resources,
23        } => {
24            ctx.rune.base_image = Some(image.clone().into());
25
26            ctx.register_resources(resources);
27            ctx.register_stages(pipeline);
28            ctx.register_output_slots(pipeline);
29            ctx.construct_pipeline(pipeline);
30            ctx.check_for_loops();
31        },
32    }
33
34    ctx.rune
35}
36
/// Mutable state threaded through each of the analysis passes.
#[derive(Debug)]
struct Context<'diag> {
    /// Where every error and warning produced during analysis is pushed.
    diags: &'diag mut Diagnostics,
    /// The rune being built up by the analysis passes.
    rune: Rune,
    /// Source of fresh [`HirId`]s for resources, stages, slots and types.
    ids: HirIds,
    /// IDs allocated for the builtin types when the context was created.
    builtins: Builtins,
}
44
45impl<'diag> Context<'diag> {
46    fn new(diags: &'diag mut Diagnostics) -> Self {
47        let mut rune = Rune::default();
48        let mut ids = HirIds::new();
49        let builtins = Builtins::new(&mut ids);
50        builtins.copy_into(&mut rune);
51
52        Context {
53            ids,
54            builtins,
55            rune,
56            diags,
57        }
58    }
59
60    fn register_name(&mut self, name: &str, id: HirId, definition: Span) {
61        if let Err(original_definition_id) = self.rune.names.register(name, id)
62        {
63            let duplicate = Label::primary((), range_span(definition))
64                .with_message("Original definition here");
65            let mut labels = vec![duplicate];
66
67            if let Some(original_definition) =
68                self.rune.spans.get(&original_definition_id)
69            {
70                let original =
71                    Label::secondary((), range_span(*original_definition))
72                        .with_message("Original definition here");
73                labels.push(original);
74            }
75
76            let diag = Diagnostic::error()
77                .with_message(format!("\"{}\" is already defined", name))
78                .with_labels(labels);
79            self.diags.push(diag);
80        }
81    }
82
83    fn register_resources(
84        &mut self,
85        resources: &IndexMap<String, ResourceDeclaration>,
86    ) {
87        for (name, declaration) in resources {
88            let source = match declaration {
89                ResourceDeclaration {
90                    inline: Some(inline),
91                    path: None,
92                    ..
93                } => Some(ResourceSource::Inline(inline.clone())),
94                ResourceDeclaration {
95                    inline: None,
96                    path: Some(path),
97                    ..
98                } => Some(ResourceSource::FromDisk(path.into())),
99                ResourceDeclaration {
100                    inline: None,
101                    path: None,
102                    ..
103                } => None,
104                ResourceDeclaration {
105                    inline: Some(_),
106                    path: Some(_),
107                    ..
108                } => {
109                    let diag = Diagnostic::error().with_message(format!("The resource \"{}\" can't specify both a \"path\" and \"inline\" value", name));
110                    self.diags.push(diag);
111                    continue;
112                },
113            };
114            let id = self.ids.next();
115            let resource = Resource {
116                source,
117                ty: declaration.ty,
118            };
119            self.register_name(name, id, resource.span());
120            self.rune.resources.insert(id, resource);
121        }
122    }
123
124    fn register_stages(&mut self, pipeline: &IndexMap<String, Stage>) {
125        for (name, stage) in pipeline {
126            let span = stage.span();
127
128            match hir::Stage::from_yaml(
129                stage.clone(),
130                &self.rune.resources,
131                &self.rune.names,
132            ) {
133                Ok(stage) => {
134                    let id = self.ids.next();
135                    self.rune.stages.insert(
136                        id,
137                        Node {
138                            stage,
139                            input_slots: Vec::new(),
140                            output_slots: Vec::new(),
141                        },
142                    );
143                    self.register_name(name, id, span);
144                },
145                Err(e) => {
146                    let diag = Diagnostic::error()
147                        .with_message(e.to_string())
148                        .with_labels(vec![Label::primary(
149                            (),
150                            range_span(span),
151                        )]);
152                    self.diags.push(diag);
153                },
154            }
155        }
156    }
157
158    fn register_output_slots(&mut self, pipeline: &IndexMap<String, Stage>) {
159        for (name, stage) in pipeline {
160            let node_id = match self.rune.names.get_id(name) {
161                Some(id) => id,
162                None => continue,
163            };
164
165            let mut output_slots = Vec::new();
166
167            for ty in stage.output_types() {
168                let element_type = self.intern_type(ty);
169                let id = self.ids.next();
170                self.rune.slots.insert(
171                    id,
172                    Slot {
173                        element_type,
174                        input_node: node_id,
175                        output_node: HirId::ERROR,
176                    },
177                );
178                output_slots.push(id);
179            }
180
181            let node = self.rune.stages.get_mut(&node_id).unwrap();
182            node.output_slots = output_slots;
183        }
184    }
185
186    fn construct_pipeline(&mut self, pipeline: &IndexMap<String, Stage>) {
187        for (name, stage) in pipeline {
188            let node_id = match self.rune.names.get_id(name) {
189                Some(id) => id,
190                None => continue,
191            };
192
193            let mut input_slots = Vec::new();
194
195            for input in stage.inputs() {
196                let incoming_node_id = match self.rune.names.get_id(&input.name)
197                {
198                    Some(id) => id,
199                    None => {
200                        let diag = Diagnostic::error().with_message(format!(
201                            "No node associated with \"{}\"",
202                            input
203                        ));
204                        self.diags.push(diag);
205                        input_slots.push(HirId::ERROR);
206                        continue;
207                    },
208                };
209
210                let incoming_node = &self.rune.stages[&incoming_node_id];
211
212                if incoming_node.output_slots.is_empty() {
213                    let diag = Diagnostic::error().with_message(format!(
214                            "The \"{}\" stage tried to connect to \"{}\" but that stage doesn't have any outputs",
215                            name,
216                            input
217                        ));
218                    self.diags.push(diag);
219                    input_slots.push(HirId::ERROR);
220                    continue;
221                }
222
223                let input_index = input.index.unwrap_or(0);
224                match incoming_node.output_slots.get(input_index) {
225                    Some(slot_id) => {
226                        input_slots.push(*slot_id);
227                        let slot = self.rune.slots.get_mut(slot_id).unwrap();
228                        slot.output_node = node_id;
229                    },
230                    None => {
231                        let diag = Diagnostic::error().with_message(format!(
232                            "The \"{}\" stage tried to connect to \"{}\" but that stage only has {} outputs",
233                            name,
234                            input,
235                            incoming_node.output_slots.len(),
236                        ));
237                        self.diags.push(diag);
238                        input_slots.push(HirId::ERROR);
239                        continue;
240                    },
241                }
242            }
243
244            let node = self.rune.stages.get_mut(&node_id).unwrap();
245            node.input_slots = input_slots;
246        }
247    }
248
249    fn intern_type(&mut self, ty: &Type) -> HirId {
250        let underlying_type = match self.primitive_type(&ty.name) {
251            Some(p) => p,
252            None => {
253                let msg = format!("Unknown type: {}", ty.name);
254                let diag = Diagnostic::warning().with_message(msg);
255                self.diags.push(diag);
256                return self.builtins.unknown_type;
257            },
258        };
259
260        let ty = if ty.dimensions.is_empty() {
261            hir::Type::Primitive(underlying_type)
262        } else {
263            hir::Type::Buffer {
264                underlying_type: self.builtins.get_id(underlying_type),
265                dimensions: ty.dimensions.clone(),
266            }
267        };
268
269        match self.rune.types.iter().find(|(_, t)| **t == ty) {
270            Some((id, _)) => *id,
271            None => {
272                // new buffer type
273                let id = self.ids.next();
274                self.rune.types.insert(id, ty);
275                id
276            },
277        }
278    }
279
280    fn primitive_type(&mut self, name: &str) -> Option<Primitive> {
281        match name {
282            "u8" | "U8" => Some(Primitive::U8),
283            "i8" | "I8" => Some(Primitive::I8),
284            "u16" | "U16" => Some(Primitive::U16),
285            "i16" | "I16" => Some(Primitive::I16),
286            "u32" | "U32" => Some(Primitive::U32),
287            "i32" | "I32" => Some(Primitive::I32),
288            "u64" | "U64" => Some(Primitive::U64),
289            "i64" | "I64" => Some(Primitive::I64),
290            "f32" | "F32" => Some(Primitive::F32),
291            "f64" | "F64" => Some(Primitive::F64),
292            "utf8" | "UTF8" => Some(Primitive::String),
293            _ => None,
294        }
295    }
296
297    fn check_for_loops(&mut self) {
298        if let Some(cycle) = self.next_cycle() {
299            let (first, middle) = match cycle.as_slice() {
300                [first, middle @ ..] => (first, middle),
301                _ => unreachable!("A cycle must have at least 2 items"),
302            };
303
304            let mut diag = Diagnostic::error().with_message(format!(
305                "Cycle detected when checking \"{}\"",
306                self.rune.names.get_name(*first).unwrap()
307            ));
308
309            if let Some(span) = self.rune.spans.get(first) {
310                diag = diag.with_labels(vec![Label::primary((), *span)]);
311            }
312
313            let mut notes = Vec::new();
314
315            for middle_id in middle {
316                let msg = format!(
317                    "... which receives input from \"{}\"...",
318                    self.rune.names.get_name(*middle_id).unwrap()
319                );
320                notes.push(msg);
321            }
322
323            let closing_message = format!(
324                "... which receives input from \"{}\", completing the cycle.",
325                self.rune.names.get_name(*first).unwrap()
326            );
327            notes.push(closing_message);
328
329            self.diags.push(diag.with_notes(notes));
330        }
331    }
332
333    fn next_cycle(&self) -> Option<Vec<HirId>> {
334        // https://www.geeksforgeeks.org/detect-cycle-in-a-graph/
335        let mut stack = VecDeque::new();
336        let mut visited = HashSet::new();
337
338        for id in self.rune.stages.keys().copied() {
339            if detect_cycles(id, &self.rune, &mut visited, &mut stack) {
340                return Some(stack.into());
341            }
342        }
343
344        None
345    }
346}
347
348fn detect_cycles(
349    id: HirId,
350    rune: &Rune,
351    visited: &mut HashSet<HirId>,
352    stack: &mut VecDeque<HirId>,
353) -> bool {
354    if stack.contains(&id) {
355        // We've detected a cycle, remove everything before our id so the stack
356        // is left just containing the cycle
357        while stack.front() != Some(&id) {
358            stack.pop_front();
359        }
360
361        return true;
362    } else if visited.contains(&id) {
363        return false;
364    }
365
366    visited.insert(id);
367    stack.push_back(id);
368
369    let incoming_nodes = rune.stages[&id]
370        .input_slots
371        .iter()
372        .map(|slot_id| rune.slots[slot_id].input_node);
373
374    for incoming_node in incoming_nodes {
375        if detect_cycles(incoming_node, rune, visited, stack) {
376            return true;
377        }
378    }
379
380    let got = stack.pop_back();
381    debug_assert_eq!(got, Some(id));
382
383    false
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Build any `FromIterator` collection from a literal list.
    ///
    /// The `key: value` form yields `(String, V)` pairs, with each identifier
    /// key stringified. The plain form just collects the listed values.
    macro_rules! map {
        // map-like
        ($($k:ident : $v:expr),* $(,)?) => {
            std::iter::Iterator::collect(std::array::IntoIter::new([
                $(
                    (String::from(stringify!($k)), $v)
                ),*
            ]))
        };
        // set-like
        ($($v:expr),* $(,)?) => {
            std::iter::Iterator::collect(std::array::IntoIter::new([$($v,)*]))
        };
    }

    /// Build a YAML `Type` from either `name[dim, ...]` or a bare `name`.
    macro_rules! ty {
        ($type:ident [$($dim:expr),*]) => {
            Type {
                name: String::from(stringify!($type)),
                dimensions: vec![ $($dim),*],
            }
        };
        ($type:ident) => {
            Type {
                name: String::from(stringify!($type)),
                dimensions: vec![],
            }
        }
    }

    /// A complete V1 document parses into the expected `Document`.
    #[test]
    fn parse_yaml_pipeline() {
        let src = r#"
version: 1
image: "runicos/base"

pipeline:
  audio:
    capability: SOUND
    outputs:
    - type: i16
      dimensions: [16000]
    args:
      hz: 16000

  fft:
    proc-block: "hotg-ai/rune#proc_blocks/fft"
    inputs:
    - audio
    outputs:
    - type: i8
      dimensions: [1960]

  model:
    model: "./model.tflite"
    inputs:
    - fft
    outputs:
    - type: i8
      dimensions: [6]

  label:
    proc-block: "hotg-ai/rune#proc_blocks/ohv_label"
    inputs:
    - model
    outputs:
    - type: utf8
    args:
      labels: ["silence", "unknown", "up", "down", "left", "right"]

  output:
    out: SERIAL
    inputs:
    - label
        "#;
        let should_be = Document::V1 {
            image: Path::new("runicos/base", None, None),
            pipeline: map! {
                audio: Stage::Capability {
                    capability: String::from("SOUND"),
                    outputs: vec![ty!(i16[16000])],
                    args: map! { hz: Value::Int(16000) },
                },
                output: Stage::Out {
                    out: String::from("SERIAL"),
                    args: IndexMap::new(),
                    inputs: vec!["label".parse().unwrap()],
                },
                label: Stage::ProcBlock {
                    proc_block: "hotg-ai/rune#proc_blocks/ohv_label".parse().unwrap(),
                    inputs: vec!["model".parse().unwrap()],
                    outputs: vec![Type { name: String::from("utf8"), dimensions: Vec::new() }],
                    args: map! {
                        labels: Value::from(vec![
                            Value::from("silence"),
                            Value::from("unknown"),
                            Value::from("up"),
                            Value::from("down"),
                            Value::from("left"),
                            Value::from("right"),
                        ]),
                    },
                },
                fft: Stage::ProcBlock {
                    proc_block: "hotg-ai/rune#proc_blocks/fft".parse().unwrap(),
                    inputs: vec!["audio".parse().unwrap()],
                    outputs: vec![ty!(i8[1960])],
                    args: IndexMap::new(),
                },
                model: Stage::Model {
                    model: "./model.tflite".into(),
                    inputs: vec!["fft".parse().unwrap()],
                    outputs: vec![ty!(i8[6])],
                },
            },
            resources: map![],
        };

        let got = Document::parse(src).unwrap();

        assert_eq!(got, should_be);
    }

    /// A single capability stage can be deserialized on its own.
    #[test]
    fn parse_audio_block() {
        let src = r#"
              capability: SOUND
              outputs:
              - type: i16
                dimensions: [16000]
              args:
                hz: 16000
        "#;
        let should_be = Stage::Capability {
            capability: String::from("SOUND"),
            outputs: vec![Type {
                name: String::from("i16"),
                dimensions: vec![16000],
            }],
            args: map! { hz: Value::Int(16000) },
        };

        let got: Stage = serde_yaml::from_str(src).unwrap();

        assert_eq!(got, should_be);
    }

    /// Integer, float, string and list YAML values map onto the matching
    /// `Value` variants.
    #[test]
    fn parse_values() {
        let inputs = vec![
            ("42", Value::Int(42)),
            ("3.14", Value::Float(3.14)),
            ("\"42\"", Value::String("42".into())),
            (
                "[1, 2.0, \"asdf\"]",
                Value::List(vec![
                    Value::Int(1),
                    Value::Float(2.0),
                    Value::String("asdf".into()),
                ]),
            ),
        ];

        for (src, should_be) in inputs {
            let got: Value = serde_yaml::from_str(src).unwrap();
            assert_eq!(got, should_be);
        }
    }

    /// Path strings are split into their base, the part after `#` (second
    /// `Path::new` argument) and the part after `@` (third argument).
    #[test]
    fn parse_paths() {
        let inputs = vec![
            ("asdf", Path::new("asdf", None, None)),
            ("runicos/base", Path::new("runicos/base", None, None)),
            (
                "runicos/base@0.1.2",
                Path::new("runicos/base", None, "0.1.2".to_string()),
            ),
            (
                "runicos/base@latest",
                Path::new("runicos/base", None, "latest".to_string()),
            ),
            (
                "https://github.com/hotg-ai/rune",
                Path::new("https://github.com/hotg-ai/rune", None, None),
            ),
            (
                "https://github.com/hotg-ai/rune@2",
                Path::new(
                    "https://github.com/hotg-ai/rune",
                    None,
                    "2".to_string(),
                ),
            ),
            (
                "hotg-ai/rune@v1.2#proc_blocks/normalize",
                Path::new(
                    "hotg-ai/rune",
                    "proc_blocks/normalize".to_string(),
                    "v1.2".to_string(),
                ),
            ),
        ];

        for (src, should_be) in inputs {
            let got: Path = src.parse().unwrap();
            assert_eq!(got, should_be);
        }
    }

    /// A small linear pipeline (audio -> fft -> model -> label -> output)
    /// shared by several tests below.
    fn dummy_document() -> Document {
        Document::V1 {
            image: Path::new("runicos/base".to_string(), None, None),
            pipeline: map! {
                audio: Stage::Capability {
                    capability: String::from("SOUND"),
                    outputs: vec![
                        ty!(i16[16000]),
                    ],
                    args: map! {
                        hz: Value::from(16000),
                    },
                },
                fft: Stage::ProcBlock {
                    proc_block: "hotg-ai/rune#proc_blocks/fft".parse().unwrap(),
                    inputs: vec!["audio".parse().unwrap()],
                    outputs: vec![
                        ty!(i8[1960]),
                    ],
                    args: IndexMap::new(),
                },
                model: Stage::Model {
                    model: "./model.tflite".into(),
                    inputs: vec!["fft".parse().unwrap()],
                    outputs: vec![
                        ty!(i8[6]),
                    ],
                },
                label: Stage::ProcBlock {
                    proc_block: "hotg-ai/rune#proc_blocks/ohv_label".parse().unwrap(),
                    inputs: vec!["model".parse().unwrap()],
                    outputs: vec![
                        ty!(utf8),
                    ],
                    args: map! {
                        labels: Value::List(vec![
                            Value::from("silence"),
                            Value::from("unknown"),
                            Value::from("up"),
                        ]),
                    },
                },
                output: Stage::Out {
                    out: String::from("SERIAL"),
                    inputs: vec!["label".parse().unwrap()],
                    args: IndexMap::default(),
                }
            },
            resources: map![],
        }
    }

    /// Every stage in the dummy pipeline is registered under its name, with no
    /// diagnostics.
    #[test]
    fn register_all_stages() {
        let pipeline = match dummy_document() {
            Document::V1 { pipeline, .. } => pipeline,
        };
        let mut diags = Diagnostics::new();
        let mut ctx = Context::new(&mut diags);
        let stages = vec!["audio", "fft", "model", "label", "output"];

        ctx.register_stages(&pipeline);

        for stage_name in stages {
            let id = ctx.rune.names.get_id(stage_name).unwrap();
            assert!(ctx.rune.stages.contains_key(&id));
        }

        assert!(diags.is_empty());
    }

    /// Running the first three passes by hand connects each consecutive pair
    /// of stages in the dummy pipeline.
    #[test]
    fn construct_the_pipeline() {
        let pipeline = match dummy_document() {
            Document::V1 { pipeline, .. } => pipeline,
        };
        let mut diags = Diagnostics::new();
        let mut ctx = Context::new(&mut diags);
        ctx.register_stages(&pipeline);
        ctx.register_output_slots(&pipeline);
        let edges = vec![
            ("audio", "fft"),
            ("fft", "model"),
            ("model", "label"),
            ("label", "output"),
        ];

        ctx.construct_pipeline(&pipeline);

        assert!(ctx.diags.is_empty(), "{:?}", ctx.diags);
        for (from, to) in edges {
            println!("{:?} => {:?}", from, to);
            let from_id = ctx.rune.names.get_id(from).unwrap();
            let to_id = ctx.rune.names.get_id(to).unwrap();

            assert!(ctx.rune.has_connection(from_id, to_id));
        }
    }

    /// One output can feed the same stage several times, and `name.N` inputs
    /// pick out individual outputs of a multi-output stage.
    #[test]
    fn construct_pipeline_graph_with_multiple_inputs_and_outputs() {
        let doc = Document::V1 {
            image: "runicos/base@latest".parse().unwrap(),
            pipeline: map! {
                audio: Stage::Capability {
                    capability: String::from("SOUND"),
                    outputs: vec![
                        ty!(i16[16000]),
                    ],
                    args: map! {
                        hz: Value::from(16000),
                    },
                },
                fft: Stage::ProcBlock {
                    proc_block: "hotg-ai/rune#proc_blocks/fft".parse().unwrap(),
                    inputs: vec![
                        "audio".parse().unwrap(),
                        "audio".parse().unwrap(),
                        "audio".parse().unwrap(),
                        ],
                    outputs: vec![
                        ty!(i8[1960]),
                        ty!(i8[1960]),
                        ty!(i8[1960]),
                    ],
                    args: IndexMap::new(),
                },
                serial: Stage::Out {
                    out: String::from("SERIAL"),
                    inputs: vec![
                        "fft.0".parse().unwrap(),
                        "fft.1".parse().unwrap(),
                        "fft.2".parse().unwrap(),
                    ],
                    args: IndexMap::new(),
                },
            },
            resources: map![],
        };
        let mut diags = Diagnostics::new();

        let rune = analyse(&doc, &mut diags);

        assert!(!diags.has_errors() && !diags.has_warnings(), "{:#?}", diags);

        let audio_id = rune.names["audio"];
        let audio_node = &rune.stages[&audio_id];
        assert!(audio_node.input_slots.is_empty());
        assert_eq!(audio_node.output_slots.len(), 1);
        let audio_output = audio_node.output_slots[0];

        let fft_id = rune.names["fft"];
        let fft_node = &rune.stages[&fft_id];
        assert_eq!(
            fft_node.input_slots,
            &[audio_output, audio_output, audio_output]
        );

        let output_id = rune.names["serial"];
        let output_node = &rune.stages[&output_id];
        // `fft.0`, `fft.1`, `fft.2` map one-to-one onto fft's output slots.
        assert_eq!(fft_node.output_slots, output_node.input_slots);
    }

    /// `sorted_pipeline()` yields the dummy pipeline's stages in dependency
    /// order.
    #[test]
    fn topological_sorting() {
        let doc = dummy_document();
        let mut diags = Diagnostics::new();
        let rune = analyse(&doc, &mut diags);
        let should_be = ["audio", "fft", "model", "label", "output"];

        let got: Vec<_> = rune.sorted_pipeline().collect();

        let should_be: Vec<_> = should_be
            .iter()
            .copied()
            .map(|name| rune.names.get_id(name).unwrap())
            .map(|id| (id, &rune.stages[&id]))
            .collect();
        assert_eq!(got, should_be);
    }

    /// A three-stage loop (audio <- model <- fft <- audio) produces exactly
    /// one error, whose notes walk around the cycle back to the start.
    #[test]
    fn detect_pipeline_cycle() {
        let src = r#"
image: runicos/base
version: 1

pipeline:
  audio:
    proc-block: "hotg-ai/rune#proc_blocks/fft"
    inputs:
    - model
    outputs:
    - type: i16
      dimensions: [16000]

  fft:
    proc-block: "hotg-ai/rune#proc_blocks/fft"
    inputs:
    - audio
    outputs:
    - type: i8
      dimensions: [1960]

  model:
    model: "./model.tflite"
    inputs:
    - fft
    outputs:
    - type: i8
      dimensions: [6]
            "#;
        let doc = Document::parse(src).unwrap();
        let mut diags = Diagnostics::new();

        let _ = analyse(&doc, &mut diags);

        assert!(diags.has_errors());
        let errors: Vec<_> = diags
            .iter_severity(codespan_reporting::diagnostic::Severity::Error)
            .collect();
        assert_eq!(errors.len(), 1);
        let diag = errors[0];
        assert_eq!(diag.message, "Cycle detected when checking \"audio\"");
        assert!(diag.notes[0].contains("model"));
        assert!(diag.notes[1].contains("fft"));
        assert_eq!(
            diag.notes[2],
            "... which receives input from \"audio\", completing the cycle."
        );
    }
}