1use std::collections::{HashSet, VecDeque};
2
3use codespan::Span;
4use codespan_reporting::{
5 diagnostic::{Diagnostic, Label},
6};
7use indexmap::IndexMap;
8use crate::{
9 Diagnostics,
10 hir::{self, HirId, Node, Primitive, Resource, ResourceSource, Rune, Slot},
11 utils::{Builtins, HirIds, range_span},
12 yaml::*,
13};
14
15pub fn analyse(doc: &Document, diags: &mut Diagnostics) -> Rune {
16 let mut ctx = Context::new(diags);
17
18 match doc {
19 Document::V1 {
20 image,
21 pipeline,
22 resources,
23 } => {
24 ctx.rune.base_image = Some(image.clone().into());
25
26 ctx.register_resources(resources);
27 ctx.register_stages(pipeline);
28 ctx.register_output_slots(pipeline);
29 ctx.construct_pipeline(pipeline);
30 ctx.check_for_loops();
31 },
32 }
33
34 ctx.rune
35}
36
/// Mutable state shared by the analysis passes.
#[derive(Debug)]
struct Context<'diag> {
    /// Sink that the passes push their diagnostics into.
    diags: &'diag mut Diagnostics,
    /// The [`Rune`] being built up; this is what [`analyse`] returns.
    rune: Rune,
    /// Source of fresh [`HirId`]s (via `ids.next()`).
    ids: HirIds,
    /// Builtins created from `ids` in [`Context::new`] and copied into
    /// `rune`; also consulted for `unknown_type` and primitive type ids.
    builtins: Builtins,
}
44
45impl<'diag> Context<'diag> {
46 fn new(diags: &'diag mut Diagnostics) -> Self {
47 let mut rune = Rune::default();
48 let mut ids = HirIds::new();
49 let builtins = Builtins::new(&mut ids);
50 builtins.copy_into(&mut rune);
51
52 Context {
53 ids,
54 builtins,
55 rune,
56 diags,
57 }
58 }
59
60 fn register_name(&mut self, name: &str, id: HirId, definition: Span) {
61 if let Err(original_definition_id) = self.rune.names.register(name, id)
62 {
63 let duplicate = Label::primary((), range_span(definition))
64 .with_message("Original definition here");
65 let mut labels = vec![duplicate];
66
67 if let Some(original_definition) =
68 self.rune.spans.get(&original_definition_id)
69 {
70 let original =
71 Label::secondary((), range_span(*original_definition))
72 .with_message("Original definition here");
73 labels.push(original);
74 }
75
76 let diag = Diagnostic::error()
77 .with_message(format!("\"{}\" is already defined", name))
78 .with_labels(labels);
79 self.diags.push(diag);
80 }
81 }
82
83 fn register_resources(
84 &mut self,
85 resources: &IndexMap<String, ResourceDeclaration>,
86 ) {
87 for (name, declaration) in resources {
88 let source = match declaration {
89 ResourceDeclaration {
90 inline: Some(inline),
91 path: None,
92 ..
93 } => Some(ResourceSource::Inline(inline.clone())),
94 ResourceDeclaration {
95 inline: None,
96 path: Some(path),
97 ..
98 } => Some(ResourceSource::FromDisk(path.into())),
99 ResourceDeclaration {
100 inline: None,
101 path: None,
102 ..
103 } => None,
104 ResourceDeclaration {
105 inline: Some(_),
106 path: Some(_),
107 ..
108 } => {
109 let diag = Diagnostic::error().with_message(format!("The resource \"{}\" can't specify both a \"path\" and \"inline\" value", name));
110 self.diags.push(diag);
111 continue;
112 },
113 };
114 let id = self.ids.next();
115 let resource = Resource {
116 source,
117 ty: declaration.ty,
118 };
119 self.register_name(name, id, resource.span());
120 self.rune.resources.insert(id, resource);
121 }
122 }
123
124 fn register_stages(&mut self, pipeline: &IndexMap<String, Stage>) {
125 for (name, stage) in pipeline {
126 let span = stage.span();
127
128 match hir::Stage::from_yaml(
129 stage.clone(),
130 &self.rune.resources,
131 &self.rune.names,
132 ) {
133 Ok(stage) => {
134 let id = self.ids.next();
135 self.rune.stages.insert(
136 id,
137 Node {
138 stage,
139 input_slots: Vec::new(),
140 output_slots: Vec::new(),
141 },
142 );
143 self.register_name(name, id, span);
144 },
145 Err(e) => {
146 let diag = Diagnostic::error()
147 .with_message(e.to_string())
148 .with_labels(vec![Label::primary(
149 (),
150 range_span(span),
151 )]);
152 self.diags.push(diag);
153 },
154 }
155 }
156 }
157
158 fn register_output_slots(&mut self, pipeline: &IndexMap<String, Stage>) {
159 for (name, stage) in pipeline {
160 let node_id = match self.rune.names.get_id(name) {
161 Some(id) => id,
162 None => continue,
163 };
164
165 let mut output_slots = Vec::new();
166
167 for ty in stage.output_types() {
168 let element_type = self.intern_type(ty);
169 let id = self.ids.next();
170 self.rune.slots.insert(
171 id,
172 Slot {
173 element_type,
174 input_node: node_id,
175 output_node: HirId::ERROR,
176 },
177 );
178 output_slots.push(id);
179 }
180
181 let node = self.rune.stages.get_mut(&node_id).unwrap();
182 node.output_slots = output_slots;
183 }
184 }
185
186 fn construct_pipeline(&mut self, pipeline: &IndexMap<String, Stage>) {
187 for (name, stage) in pipeline {
188 let node_id = match self.rune.names.get_id(name) {
189 Some(id) => id,
190 None => continue,
191 };
192
193 let mut input_slots = Vec::new();
194
195 for input in stage.inputs() {
196 let incoming_node_id = match self.rune.names.get_id(&input.name)
197 {
198 Some(id) => id,
199 None => {
200 let diag = Diagnostic::error().with_message(format!(
201 "No node associated with \"{}\"",
202 input
203 ));
204 self.diags.push(diag);
205 input_slots.push(HirId::ERROR);
206 continue;
207 },
208 };
209
210 let incoming_node = &self.rune.stages[&incoming_node_id];
211
212 if incoming_node.output_slots.is_empty() {
213 let diag = Diagnostic::error().with_message(format!(
214 "The \"{}\" stage tried to connect to \"{}\" but that stage doesn't have any outputs",
215 name,
216 input
217 ));
218 self.diags.push(diag);
219 input_slots.push(HirId::ERROR);
220 continue;
221 }
222
223 let input_index = input.index.unwrap_or(0);
224 match incoming_node.output_slots.get(input_index) {
225 Some(slot_id) => {
226 input_slots.push(*slot_id);
227 let slot = self.rune.slots.get_mut(slot_id).unwrap();
228 slot.output_node = node_id;
229 },
230 None => {
231 let diag = Diagnostic::error().with_message(format!(
232 "The \"{}\" stage tried to connect to \"{}\" but that stage only has {} outputs",
233 name,
234 input,
235 incoming_node.output_slots.len(),
236 ));
237 self.diags.push(diag);
238 input_slots.push(HirId::ERROR);
239 continue;
240 },
241 }
242 }
243
244 let node = self.rune.stages.get_mut(&node_id).unwrap();
245 node.input_slots = input_slots;
246 }
247 }
248
249 fn intern_type(&mut self, ty: &Type) -> HirId {
250 let underlying_type = match self.primitive_type(&ty.name) {
251 Some(p) => p,
252 None => {
253 let msg = format!("Unknown type: {}", ty.name);
254 let diag = Diagnostic::warning().with_message(msg);
255 self.diags.push(diag);
256 return self.builtins.unknown_type;
257 },
258 };
259
260 let ty = if ty.dimensions.is_empty() {
261 hir::Type::Primitive(underlying_type)
262 } else {
263 hir::Type::Buffer {
264 underlying_type: self.builtins.get_id(underlying_type),
265 dimensions: ty.dimensions.clone(),
266 }
267 };
268
269 match self.rune.types.iter().find(|(_, t)| **t == ty) {
270 Some((id, _)) => *id,
271 None => {
272 let id = self.ids.next();
274 self.rune.types.insert(id, ty);
275 id
276 },
277 }
278 }
279
280 fn primitive_type(&mut self, name: &str) -> Option<Primitive> {
281 match name {
282 "u8" | "U8" => Some(Primitive::U8),
283 "i8" | "I8" => Some(Primitive::I8),
284 "u16" | "U16" => Some(Primitive::U16),
285 "i16" | "I16" => Some(Primitive::I16),
286 "u32" | "U32" => Some(Primitive::U32),
287 "i32" | "I32" => Some(Primitive::I32),
288 "u64" | "U64" => Some(Primitive::U64),
289 "i64" | "I64" => Some(Primitive::I64),
290 "f32" | "F32" => Some(Primitive::F32),
291 "f64" | "F64" => Some(Primitive::F64),
292 "utf8" | "UTF8" => Some(Primitive::String),
293 _ => None,
294 }
295 }
296
297 fn check_for_loops(&mut self) {
298 if let Some(cycle) = self.next_cycle() {
299 let (first, middle) = match cycle.as_slice() {
300 [first, middle @ ..] => (first, middle),
301 _ => unreachable!("A cycle must have at least 2 items"),
302 };
303
304 let mut diag = Diagnostic::error().with_message(format!(
305 "Cycle detected when checking \"{}\"",
306 self.rune.names.get_name(*first).unwrap()
307 ));
308
309 if let Some(span) = self.rune.spans.get(first) {
310 diag = diag.with_labels(vec![Label::primary((), *span)]);
311 }
312
313 let mut notes = Vec::new();
314
315 for middle_id in middle {
316 let msg = format!(
317 "... which receives input from \"{}\"...",
318 self.rune.names.get_name(*middle_id).unwrap()
319 );
320 notes.push(msg);
321 }
322
323 let closing_message = format!(
324 "... which receives input from \"{}\", completing the cycle.",
325 self.rune.names.get_name(*first).unwrap()
326 );
327 notes.push(closing_message);
328
329 self.diags.push(diag.with_notes(notes));
330 }
331 }
332
333 fn next_cycle(&self) -> Option<Vec<HirId>> {
334 let mut stack = VecDeque::new();
336 let mut visited = HashSet::new();
337
338 for id in self.rune.stages.keys().copied() {
339 if detect_cycles(id, &self.rune, &mut visited, &mut stack) {
340 return Some(stack.into());
341 }
342 }
343
344 None
345 }
346}
347
348fn detect_cycles(
349 id: HirId,
350 rune: &Rune,
351 visited: &mut HashSet<HirId>,
352 stack: &mut VecDeque<HirId>,
353) -> bool {
354 if stack.contains(&id) {
355 while stack.front() != Some(&id) {
358 stack.pop_front();
359 }
360
361 return true;
362 } else if visited.contains(&id) {
363 return false;
364 }
365
366 visited.insert(id);
367 stack.push_back(id);
368
369 let incoming_nodes = rune.stages[&id]
370 .input_slots
371 .iter()
372 .map(|slot_id| rune.slots[slot_id].input_node);
373
374 for incoming_node in incoming_nodes {
375 if detect_cycles(incoming_node, rune, visited, stack) {
376 return true;
377 }
378 }
379
380 let got = stack.pop_back();
381 debug_assert_eq!(got, Some(id));
382
383 false
384}
385
386#[cfg(test)]
387mod tests {
388 use super::*;
389
390 macro_rules! map {
391 ($($k:ident : $v:expr),* $(,)?) => {
393 std::iter::Iterator::collect(std::array::IntoIter::new([
394 $(
395 (String::from(stringify!($k)), $v)
396 ),*
397 ]))
398 };
399 ($($v:expr),* $(,)?) => {
401 std::iter::Iterator::collect(std::array::IntoIter::new([$($v,)*]))
402 };
403 }
404
405 macro_rules! ty {
406 ($type:ident [$($dim:expr),*]) => {
407 Type {
408 name: String::from(stringify!($type)),
409 dimensions: vec![ $($dim),*],
410 }
411 };
412 ($type:ident) => {
413 Type {
414 name: String::from(stringify!($type)),
415 dimensions: vec![],
416 }
417 }
418 }
419
    /// Parses a complete five-stage Runefile and checks it against a
    /// hand-built [`Document`].
    #[test]
    fn parse_yaml_pipeline() {
        let src = r#"
version: 1
image: "runicos/base"

pipeline:
  audio:
    capability: SOUND
    outputs:
    - type: i16
      dimensions: [16000]
    args:
      hz: 16000

  fft:
    proc-block: "hotg-ai/rune#proc_blocks/fft"
    inputs:
    - audio
    outputs:
    - type: i8
      dimensions: [1960]

  model:
    model: "./model.tflite"
    inputs:
    - fft
    outputs:
    - type: i8
      dimensions: [6]

  label:
    proc-block: "hotg-ai/rune#proc_blocks/ohv_label"
    inputs:
    - model
    outputs:
    - type: utf8
    args:
      labels: ["silence", "unknown", "up", "down", "left", "right"]

  output:
    out: SERIAL
    inputs:
    - label
        "#;
        // The stages below are deliberately not listed in the same order as
        // in the YAML above.
        // NOTE(review): this presumably relies on the pipeline map's equality
        // ignoring insertion order - confirm against the map type used.
        let should_be = Document::V1 {
            image: Path::new("runicos/base", None, None),
            pipeline: map! {
                audio: Stage::Capability {
                    capability: String::from("SOUND"),
                    outputs: vec![ty!(i16[16000])],
                    args: map! { hz: Value::Int(16000) },
                },
                output: Stage::Out {
                    out: String::from("SERIAL"),
                    args: IndexMap::new(),
                    inputs: vec!["label".parse().unwrap()],
                },
                label: Stage::ProcBlock {
                    proc_block: "hotg-ai/rune#proc_blocks/ohv_label".parse().unwrap(),
                    inputs: vec!["model".parse().unwrap()],
                    outputs: vec![Type { name: String::from("utf8"), dimensions: Vec::new() }],
                    args: map! {
                        labels: Value::from(vec![
                            Value::from("silence"),
                            Value::from("unknown"),
                            Value::from("up"),
                            Value::from("down"),
                            Value::from("left"),
                            Value::from("right"),
                        ]),
                    },
                },
                fft: Stage::ProcBlock {
                    proc_block: "hotg-ai/rune#proc_blocks/fft".parse().unwrap(),
                    inputs: vec!["audio".parse().unwrap()],
                    outputs: vec![ty!(i8[1960])],
                    args: IndexMap::new(),
                },
                model: Stage::Model {
                    model: "./model.tflite".into(),
                    inputs: vec!["fft".parse().unwrap()],
                    outputs: vec![ty!(i8[6])],
                },
            },
            resources: map![],
        };

        let got = Document::parse(src).unwrap();

        assert_eq!(got, should_be);
    }
512
    /// Deserializes a single capability stage straight from YAML, without
    /// going through [`Document::parse`].
    #[test]
    fn parse_audio_block() {
        let src = r#"
            capability: SOUND
            outputs:
            - type: i16
              dimensions: [16000]
            args:
              hz: 16000
        "#;
        let should_be = Stage::Capability {
            capability: String::from("SOUND"),
            outputs: vec![Type {
                name: String::from("i16"),
                dimensions: vec![16000],
            }],
            args: map! { hz: Value::Int(16000) },
        };

        let got: Stage = serde_yaml::from_str(src).unwrap();

        assert_eq!(got, should_be);
    }
536
537 #[test]
538 fn parse_values() {
539 let inputs = vec![
540 ("42", Value::Int(42)),
541 ("3.14", Value::Float(3.14)),
542 ("\"42\"", Value::String("42".into())),
543 (
544 "[1, 2.0, \"asdf\"]",
545 Value::List(vec![
546 Value::Int(1),
547 Value::Float(2.0),
548 Value::String("asdf".into()),
549 ]),
550 ),
551 ];
552
553 for (src, should_be) in inputs {
554 let got: Value = serde_yaml::from_str(src).unwrap();
555 assert_eq!(got, should_be);
556 }
557 }
558
559 #[test]
560 fn parse_paths() {
561 let inputs = vec![
562 ("asdf", Path::new("asdf", None, None)),
563 ("runicos/base", Path::new("runicos/base", None, None)),
564 (
565 "runicos/base@0.1.2",
566 Path::new("runicos/base", None, "0.1.2".to_string()),
567 ),
568 (
569 "runicos/base@latest",
570 Path::new("runicos/base", None, "latest".to_string()),
571 ),
572 (
573 "https://github.com/hotg-ai/rune",
574 Path::new("https://github.com/hotg-ai/rune", None, None),
575 ),
576 (
577 "https://github.com/hotg-ai/rune@2",
578 Path::new(
579 "https://github.com/hotg-ai/rune",
580 None,
581 "2".to_string(),
582 ),
583 ),
584 (
585 "hotg-ai/rune@v1.2#proc_blocks/normalize",
586 Path::new(
587 "hotg-ai/rune",
588 "proc_blocks/normalize".to_string(),
589 "v1.2".to_string(),
590 ),
591 ),
592 ];
593
594 for (src, should_be) in inputs {
595 let got: Path = src.parse().unwrap();
596 assert_eq!(got, should_be);
597 }
598 }
599
600 fn dummy_document() -> Document {
601 Document::V1 {
602 image: Path::new("runicos/base".to_string(), None, None),
603 pipeline: map! {
604 audio: Stage::Capability {
605 capability: String::from("SOUND"),
606 outputs: vec![
607 ty!(i16[16000]),
608 ],
609 args: map! {
610 hz: Value::from(16000),
611 },
612 },
613 fft: Stage::ProcBlock {
614 proc_block: "hotg-ai/rune#proc_blocks/fft".parse().unwrap(),
615 inputs: vec!["audio".parse().unwrap()],
616 outputs: vec![
617 ty!(i8[1960]),
618 ],
619 args: IndexMap::new(),
620 },
621 model: Stage::Model {
622 model: "./model.tflite".into(),
623 inputs: vec!["fft".parse().unwrap()],
624 outputs: vec![
625 ty!(i8[6]),
626 ],
627 },
628 label: Stage::ProcBlock {
629 proc_block: "hotg-ai/rune#proc_blocks/ohv_label".parse().unwrap(),
630 inputs: vec!["model".parse().unwrap()],
631 outputs: vec![
632 ty!(utf8),
633 ],
634 args: map! {
635 labels: Value::List(vec![
636 Value::from("silence"),
637 Value::from("unknown"),
638 Value::from("up"),
639 ]),
640 },
641 },
642 output: Stage::Out {
643 out: String::from("SERIAL"),
644 inputs: vec!["label".parse().unwrap()],
645 args: IndexMap::default(),
646 }
647 },
648 resources: map![],
649 }
650 }
651
652 #[test]
653 fn register_all_stages() {
654 let pipeline = match dummy_document() {
655 Document::V1 { pipeline, .. } => pipeline,
656 };
657 let mut diags = Diagnostics::new();
658 let mut ctx = Context::new(&mut diags);
659 let stages = vec!["audio", "fft", "model", "label", "output"];
660
661 ctx.register_stages(&pipeline);
662
663 for stage_name in stages {
664 let id = ctx.rune.names.get_id(stage_name).unwrap();
665 assert!(ctx.rune.stages.contains_key(&id));
666 }
667
668 assert!(diags.is_empty());
669 }
670
671 #[test]
672 fn construct_the_pipeline() {
673 let pipeline = match dummy_document() {
674 Document::V1 { pipeline, .. } => pipeline,
675 };
676 let mut diags = Diagnostics::new();
677 let mut ctx = Context::new(&mut diags);
678 ctx.register_stages(&pipeline);
679 ctx.register_output_slots(&pipeline);
680 let edges = vec![
681 ("audio", "fft"),
682 ("fft", "model"),
683 ("model", "label"),
684 ("label", "output"),
685 ];
686
687 ctx.construct_pipeline(&pipeline);
688
689 assert!(ctx.diags.is_empty(), "{:?}", ctx.diags);
690 for (from, to) in edges {
691 println!("{:?} => {:?}", from, to);
692 let from_id = ctx.rune.names.get_id(from).unwrap();
693 let to_id = ctx.rune.names.get_id(to).unwrap();
694
695 assert!(ctx.rune.has_connection(from_id, to_id));
696 }
697 }
698
699 #[test]
700 fn construct_pipeline_graph_with_multiple_inputs_and_outputs() {
701 let doc = Document::V1 {
702 image: "runicos/base@latest".parse().unwrap(),
703 pipeline: map! {
704 audio: Stage::Capability {
705 capability: String::from("SOUND"),
706 outputs: vec![
707 ty!(i16[16000]),
708 ],
709 args: map! {
710 hz: Value::from(16000),
711 },
712 },
713 fft: Stage::ProcBlock {
714 proc_block: "hotg-ai/rune#proc_blocks/fft".parse().unwrap(),
715 inputs: vec![
716 "audio".parse().unwrap(),
717 "audio".parse().unwrap(),
718 "audio".parse().unwrap(),
719 ],
720 outputs: vec![
721 ty!(i8[1960]),
722 ty!(i8[1960]),
723 ty!(i8[1960]),
724 ],
725 args: IndexMap::new(),
726 },
727 serial: Stage::Out {
728 out: String::from("SERIAL"),
729 inputs: vec![
730 "fft.0".parse().unwrap(),
731 "fft.1".parse().unwrap(),
732 "fft.2".parse().unwrap(),
733 ],
734 args: IndexMap::new(),
735 },
736 },
737 resources: map![],
738 };
739 let mut diags = Diagnostics::new();
740
741 let rune = analyse(&doc, &mut diags);
742
743 assert!(!diags.has_errors() && !diags.has_warnings(), "{:#?}", diags);
744
745 let audio_id = rune.names["audio"];
746 let audio_node = &rune.stages[&audio_id];
747 assert!(audio_node.input_slots.is_empty());
748 assert_eq!(audio_node.output_slots.len(), 1);
749 let audio_output = audio_node.output_slots[0];
750
751 let fft_id = rune.names["fft"];
752 let fft_node = &rune.stages[&fft_id];
753 assert_eq!(
754 fft_node.input_slots,
755 &[audio_output, audio_output, audio_output]
756 );
757
758 let output_id = rune.names["serial"];
759 let output_node = &rune.stages[&output_id];
760 assert_eq!(fft_node.output_slots, output_node.input_slots);
761 }
762
763 #[test]
764 fn topological_sorting() {
765 let doc = dummy_document();
766 let mut diags = Diagnostics::new();
767 let rune = analyse(&doc, &mut diags);
768 let should_be = ["audio", "fft", "model", "label", "output"];
769
770 let got: Vec<_> = rune.sorted_pipeline().collect();
771
772 let should_be: Vec<_> = should_be
773 .iter()
774 .copied()
775 .map(|name| rune.names.get_id(name).unwrap())
776 .map(|id| (id, &rune.stages[&id]))
777 .collect();
778 assert_eq!(got, should_be);
779 }
780
    /// A pipeline where `audio` <- `model` <- `fft` <- `audio` should
    /// produce exactly one error describing the whole cycle.
    #[test]
    fn detect_pipeline_cycle() {
        let src = r#"
image: runicos/base
version: 1

pipeline:
  audio:
    proc-block: "hotg-ai/rune#proc_blocks/fft"
    inputs:
    - model
    outputs:
    - type: i16
      dimensions: [16000]

  fft:
    proc-block: "hotg-ai/rune#proc_blocks/fft"
    inputs:
    - audio
    outputs:
    - type: i8
      dimensions: [1960]

  model:
    model: "./model.tflite"
    inputs:
    - fft
    outputs:
    - type: i8
      dimensions: [6]
        "#;
        let doc = Document::parse(src).unwrap();
        let mut diags = Diagnostics::new();

        let _ = analyse(&doc, &mut diags);

        assert!(diags.has_errors());
        let errors: Vec<_> = diags
            .iter_severity(codespan_reporting::diagnostic::Severity::Error)
            .collect();
        assert_eq!(errors.len(), 1);
        let diag = errors[0];
        assert_eq!(diag.message, "Cycle detected when checking \"audio\"");
        // The notes follow the inputs backwards from the first stage:
        // audio <- model <- fft <- audio.
        assert!(diag.notes[0].contains("model"));
        assert!(diag.notes[1].contains("fft"));
        assert_eq!(
            diag.notes[2],
            "... which receives input from \"audio\", completing the cycle."
        );
    }
831}