cao_lang/compiler/
module.rs

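//! The public representation of a program: [Module]s with their functions, submodules and
//! imports, plus [CardIndex] for addressing individual cards inside them.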
3
4#[cfg(test)]
5mod tests;
6
7use crate::compiler::Function;
8use crate::prelude::{CompilationErrorPayload, Handle};
9use smallvec::SmallVec;
10use std::collections::hash_map::DefaultHasher;
11use std::hash::Hasher;
12use std::rc::Rc;
13use thiserror::Error;
14
15use super::function_ir::FunctionIr;
16use super::{Card, ImportsIr};
17
18#[derive(Debug, Clone, Error)]
19pub enum IntoStreamError {
20    #[error("Main function by name {0} was not found")]
21    MainFnNotFound(String),
22    #[error("{0:?} is not a valid name")]
23    BadName(String),
24}
25
26pub type CaoProgram = Module;
27pub type CaoIdentifier = String;
28pub type Imports = Vec<CaoIdentifier>;
29pub type Functions = Vec<(CaoIdentifier, Function)>;
30pub type Submodules = Vec<(CaoIdentifier, Module)>;
31
32#[derive(Debug, Clone, Default)]
33#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
34pub struct Module {
35    pub submodules: Submodules,
36    pub functions: Functions,
37    /// _functions_ to import from submodules
38    ///
39    /// e.g. importing `foo.bar` allows you to use a `Jump("bar")` [Card]
40    pub imports: Imports,
41}
42
43/// Uniquely index a card in a module
44#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
45#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
46pub struct CardIndex {
47    pub function: usize,
48    pub card_index: FunctionCardIndex,
49}
50
51impl PartialOrd for CardIndex {
52    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
53        Some(self.cmp(other))
54    }
55}
56
57impl Ord for CardIndex {
58    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
59        match self.function.cmp(&other.function) {
60            std::cmp::Ordering::Equal => {}
61            c @ std::cmp::Ordering::Less | c @ std::cmp::Ordering::Greater => return c,
62        }
63        for (lhs, rhs) in self
64            .card_index
65            .indices
66            .iter()
67            .zip(other.card_index.indices.iter())
68        {
69            match lhs.cmp(&rhs) {
70                std::cmp::Ordering::Equal => {}
71                c @ std::cmp::Ordering::Less | c @ std::cmp::Ordering::Greater => return c,
72            }
73        }
74        self.card_index
75            .indices
76            .len()
77            .cmp(&other.card_index.indices.len())
78    }
79}
80
81impl CardIndex {
82    pub fn function(function: usize) -> Self {
83        Self {
84            function,
85            ..Default::default()
86        }
87    }
88
89    pub fn new(function: usize, card_index: usize) -> Self {
90        Self {
91            function,
92            card_index: FunctionCardIndex::new(card_index),
93        }
94    }
95
96    pub fn from_slice(function: usize, indices: &[u32]) -> Self {
97        let mut card_index = FunctionCardIndex {
98            indices: SmallVec::with_capacity(indices.len()),
99        };
100        card_index.indices.extend_from_slice(indices);
101        Self {
102            function,
103            card_index,
104        }
105    }
106
107    pub fn push_subindex(&mut self, i: u32) {
108        self.card_index.indices.push(i);
109    }
110
111    pub fn pop_subindex(&mut self) {
112        self.card_index.indices.pop();
113    }
114
115    pub fn as_handle(&self) -> crate::prelude::Handle {
116        let function_handle = crate::prelude::Handle::from_u64(self.function as u64);
117        let subindices = self.card_index.indices.as_slice();
118        let sub_handle = unsafe {
119            crate::prelude::Handle::from_bytes(std::slice::from_raw_parts(
120                subindices.as_ptr().cast(),
121                subindices.len() * 4,
122            ))
123        };
124        function_handle + sub_handle
125    }
126
127    /// pushes a new sub-index to the bottom layer
128    #[must_use]
129    pub fn with_sub_index(mut self, card_index: usize) -> Self {
130        self.push_subindex(card_index as u32);
131        self
132    }
133
134    pub fn current_index(&self) -> usize {
135        self.card_index.current_index()
136    }
137
138    /// Replaces the card index of the leaf node
139    pub fn with_current_index(mut self, card_index: usize) -> Self {
140        self.card_index.set_current_index(card_index);
141        self
142    }
143
144    pub fn set_current_index(&mut self, card_index: usize) {
145        self.card_index.set_current_index(card_index);
146    }
147
148    /// first card's index in the function
149    pub fn begin(&self) -> Result<usize, CardFetchError> {
150        self.card_index.begin()
151    }
152
153    /// Return wether this index points to a 'top level' card in the function.
154    /// Instead of a nested card.
155    pub fn is_top_level_card(&self) -> bool {
156        self.card_index.indices.len() == 1
157    }
158}
159
160impl std::fmt::Display for CardIndex {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        write!(f, "{}", self.function)?;
163        for i in self.card_index.indices.iter() {
164            write!(f, ".{}", i)?;
165        }
166        Ok(())
167    }
168}
169
170#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
171#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
172pub struct FunctionCardIndex {
173    pub indices: SmallVec<[u32; 4]>,
174}
175
176impl FunctionCardIndex {
177    #[must_use]
178    pub fn new(card_index: usize) -> Self {
179        Self {
180            indices: smallvec::smallvec![card_index as u32],
181        }
182    }
183
184    pub fn depth(&self) -> usize {
185        self.indices.len()
186    }
187
188    /// pushes a new sub-index to the bottom layer
189    #[must_use]
190    pub fn with_sub_index(mut self, card_index: usize) -> Self {
191        self.push_sub_index(card_index);
192        self
193    }
194
195    pub fn push_sub_index(&mut self, card_index: usize) {
196        self.indices.push(card_index as u32);
197    }
198
199    #[must_use]
200    pub fn current_index(&self) -> usize {
201        self.indices.last().copied().unwrap_or(0) as usize
202    }
203
204    /// Replaces the card index of the leaf node
205    #[must_use]
206    pub fn with_current_index(mut self, card_index: usize) -> Self {
207        self.set_current_index(card_index);
208        self
209    }
210
211    pub fn set_current_index(&mut self, card_index: usize) {
212        if let Some(x) = self.indices.last_mut() {
213            *x = card_index as u32;
214        }
215    }
216
217    pub fn begin(&self) -> Result<usize, CardFetchError> {
218        let i = self.indices.first().ok_or(CardFetchError::InvalidIndex)?;
219        Ok(*i as usize)
220    }
221}
222
223#[derive(Debug, Clone, Error)]
224pub enum CardFetchError {
225    #[error("Function not found")]
226    FunctionNotFound,
227    #[error("Card at depth {depth} not found")]
228    CardNotFound { depth: usize },
229    #[error("The card at depth {depth} has no nested functions, but the index tried to fetch one")]
230    NoSubFunction { depth: usize },
231    #[error("The provided index is not valid")]
232    InvalidIndex,
233}
234
235#[derive(Debug, Clone, Error)]
236pub enum SwapError {
237    #[error("Failed to find card {0}: {1}")]
238    FetchError(CardIndex, CardFetchError),
239    #[error("These cards can not be swapped")]
240    InvalidSwap,
241}
242
243impl Module {
244    pub fn get_card_mut<'a>(&'a mut self, idx: &CardIndex) -> Result<&'a mut Card, CardFetchError> {
245        let (_, function) = self
246            .functions
247            .get_mut(idx.function)
248            .ok_or(CardFetchError::FunctionNotFound)?;
249        let mut card = function
250            .cards
251            .get_mut(idx.begin()?)
252            .ok_or(CardFetchError::CardNotFound { depth: 0 })?;
253
254        for (depth, i) in idx.card_index.indices[1..].iter().enumerate() {
255            card = card
256                .get_child_mut(*i as usize)
257                .ok_or(CardFetchError::CardNotFound { depth: depth + 1 })?;
258        }
259
260        Ok(card)
261    }
262
263    pub fn get_card<'a>(&'a self, idx: &CardIndex) -> Result<&'a Card, CardFetchError> {
264        let (_, function) = self
265            .functions
266            .get(idx.function)
267            .ok_or(CardFetchError::FunctionNotFound)?;
268
269        let mut depth = 0;
270        let mut card = function
271            .cards
272            .get(idx.begin()?)
273            .ok_or(CardFetchError::CardNotFound { depth })?;
274
275        for i in &idx.card_index.indices[1..] {
276            depth += 1;
277            card = card
278                .get_child(*i as usize)
279                .ok_or(CardFetchError::CardNotFound { depth })?;
280        }
281
282        Ok(card)
283    }
284
285    /// swapping a parent and child is an error
286    pub fn swap_cards<'a>(
287        &mut self,
288        mut lhs: &'a CardIndex,
289        mut rhs: &'a CardIndex,
290    ) -> Result<(), SwapError> {
291        if lhs < rhs {
292            std::mem::swap(&mut lhs, &mut rhs);
293        }
294
295        let rhs_card = self
296            .replace_card(rhs, Card::ScalarNil)
297            .map_err(|err| SwapError::FetchError(rhs.clone(), err))?;
298
299        // check if lhs is reachable
300        // run the check after taking rhs_card, as this can fail if lhs is a child of rhs
301        if let Err(_) = self.get_card(lhs) {
302            self.replace_card(rhs, rhs_card).unwrap();
303            return Err(SwapError::InvalidSwap);
304        }
305
306        // we know that lhs is reachable so this mustn't err
307        let lhs_card = self.replace_card(lhs, rhs_card).unwrap();
308
309        // we know that rhs is reachable so this mustn't err
310        self.replace_card(rhs, lhs_card).unwrap();
311        Ok(())
312    }
313
314    pub fn remove_card(&mut self, idx: &CardIndex) -> Result<Card, CardFetchError> {
315        let (_, function) = self
316            .functions
317            .get_mut(idx.function)
318            .ok_or(CardFetchError::FunctionNotFound)?;
319        if idx.card_index.indices.len() == 1 {
320            if function.cards.len() <= idx.card_index.indices[0] as usize {
321                return Err(CardFetchError::CardNotFound { depth: 0 });
322            }
323            return Ok(function.cards.remove(idx.card_index.indices[0] as usize));
324        }
325        let mut card = function
326            .cards
327            .get_mut(idx.begin()?)
328            .ok_or(CardFetchError::CardNotFound { depth: 0 })?;
329
330        // len is at least 1
331        let len = idx.card_index.indices.len();
332        for (depth, i) in idx.card_index.indices[1..(len - 1).max(1)]
333            .iter()
334            .enumerate()
335        {
336            card = card
337                .get_child_mut(*i as usize)
338                .ok_or(CardFetchError::CardNotFound { depth: depth + 1 })?;
339        }
340        let i = *idx.card_index.indices.last().unwrap() as usize;
341        card.remove_child(i)
342            .ok_or(CardFetchError::CardNotFound { depth: len - 1 })
343    }
344
345    /// Return the old card
346    pub fn replace_card(&mut self, idx: &CardIndex, child: Card) -> Result<Card, CardFetchError> {
347        self.get_card_mut(idx).map(|c| std::mem::replace(c, child))
348    }
349
350    pub fn insert_card(&mut self, idx: &CardIndex, child: Card) -> Result<(), CardFetchError> {
351        let (_, function) = self
352            .functions
353            .get_mut(idx.function)
354            .ok_or(CardFetchError::FunctionNotFound)?;
355        if idx.card_index.indices.len() == 1 {
356            if function.cards.len() < idx.card_index.indices[0] as usize {
357                return Err(CardFetchError::CardNotFound { depth: 0 });
358            }
359            function
360                .cards
361                .insert(idx.card_index.indices[0] as usize, child);
362            return Ok(());
363        }
364        let mut card = function
365            .cards
366            .get_mut(idx.begin()?)
367            .ok_or(CardFetchError::CardNotFound { depth: 0 })?;
368
369        // len is at least 1
370        let len = idx.card_index.indices.len();
371        for (depth, i) in idx.card_index.indices[1..(len - 1).max(1)]
372            .iter()
373            .enumerate()
374        {
375            card = card
376                .get_child_mut(*i as usize)
377                .ok_or(CardFetchError::CardNotFound { depth: depth + 1 })?;
378        }
379        let i = *idx.card_index.indices.last().unwrap() as usize;
380        card.insert_child(i, child)
381            .map_err(|_| CardFetchError::CardNotFound { depth: len - 1 })
382    }
383
384    /// flatten this program into a vec of functions
385    ///
386    /// called on the root module
387    pub(crate) fn into_ir_stream(
388        mut self,
389        recursion_limit: u32,
390    ) -> Result<Vec<FunctionIr>, CompilationErrorPayload> {
391        // inject the standard library
392        self.submodules
393            .push(("std".to_string(), crate::stdlib::standard_library()));
394
395        self.ensure_invariants(&mut Default::default())?;
396        // the first function is special
397        //
398        let (main_index, _) = self
399            .functions
400            .iter()
401            .enumerate()
402            .find(|(_, (name, _))| name == "main")
403            .ok_or(CompilationErrorPayload::NoMain)?;
404
405        let mut result = Vec::with_capacity(self.functions.len() * self.submodules.len() * 2); // just some dumb heuristic
406
407        let mut namespace = SmallVec::<[_; 16]>::new();
408
409        flatten_module(&self, recursion_limit, &mut namespace, &mut result)?;
410
411        // move the main function to the front
412        result.swap(0, main_index);
413        Ok(result)
414    }
415
416    fn ensure_invariants<'a>(
417        &'a self,
418        aux: &mut std::collections::HashSet<&'a str>,
419    ) -> Result<(), CompilationErrorPayload> {
420        // test that submodule names are unique
421        for (name, _) in self.submodules.iter() {
422            if aux.contains(name.as_str()) {
423                return Err(CompilationErrorPayload::DuplicateModule(name.to_string()));
424            }
425            aux.insert(name.as_str());
426        }
427        for (_, module) in self.submodules.iter() {
428            aux.clear();
429            module.ensure_invariants(aux)?;
430        }
431        Ok(())
432    }
433
434    fn execute_imports(&self) -> Result<ImportsIr, CompilationErrorPayload> {
435        let mut result = ImportsIr::with_capacity(self.imports.len());
436
437        for import in self.imports.iter() {
438            let import = import.as_str();
439
440            match import.rsplit_once('.') {
441                Some((_, name)) => {
442                    if result.contains_key(name) {
443                        return Err(CompilationErrorPayload::AmbigousImport(import.to_string()));
444                    }
445                    result.insert(name.to_string(), import.to_string());
446                }
447                None => {
448                    return Err(CompilationErrorPayload::BadImport(import.to_string()));
449                }
450            }
451        }
452
453        Ok(result)
454    }
455
456    /// Hash the keys in the program.
457    ///
458    /// Keys = functions, submodules, card names.
459    pub fn compute_keys_hash(&self) -> u64 {
460        let mut hasher = DefaultHasher::new();
461        hash_module(&mut hasher, self);
462        hasher.finish()
463    }
464
465    pub fn lookup_submodule(&self, target: &str) -> Option<&Module> {
466        let mut current = self;
467        for submodule_name in target.split('.') {
468            current = current
469                .submodules
470                .iter()
471                .find(|(name, _)| name == submodule_name)
472                .map(|(_, m)| m)?;
473        }
474        Some(current)
475    }
476
477    pub fn lookup_submodule_mut(&mut self, target: &str) -> Option<&mut Module> {
478        let mut current = self;
479        for submodule_name in target.split('.') {
480            current = current
481                .submodules
482                .iter_mut()
483                .find(|(name, _)| name == submodule_name)
484                .map(|(_, m)| m)?;
485        }
486        Some(current)
487    }
488
489    pub fn lookup_function(&self, target: &str) -> Option<&Function> {
490        let Some((submodule, function)) = target.rsplit_once('.') else {
491            return self
492                .functions
493                .iter()
494                .find(|(name, _)| name == target)
495                .map(|(_, l)| l);
496        };
497        let module = self.lookup_submodule(submodule)?;
498        module.lookup_function(function)
499    }
500
501    pub fn lookup_function_mut(&mut self, target: &str) -> Option<&mut Function> {
502        let Some((submodule, function)) = target.rsplit_once('.') else {
503            return self
504                .functions
505                .iter_mut()
506                .find(|(name, _)| name == target)
507                .map(|(_, l)| l);
508        };
509        let module = self.lookup_submodule_mut(submodule)?;
510        module.lookup_function_mut(function)
511    }
512
513    /// Visits all cards in the module recursively
514    ///
515    /// ```
516    /// use cao_lang::prelude::*;
517    /// # use std::collections::HashSet;
518    /// # use cao_lang::compiler::FunctionCardIndex;
519    /// # use smallvec::smallvec;
520    ///
521    /// let mut program = CaoProgram {
522    ///     imports: Default::default(),
523    ///     submodules: Default::default(),
524    ///     functions: [
525    ///         (
526    ///             "main".into(),
527    ///             Function::default().with_card(Card::IfTrue(Box::new([
528    ///                 Card::ScalarInt(42),
529    ///                 Card::call_function("pooh", vec![]),
530    ///             ]))),
531    ///         ),
532    ///         (
533    ///             "pooh".into(),
534    ///             Function::default().with_card(Card::set_global_var("result", Card::ScalarInt(69))),
535    ///         ),
536    ///     ]
537    ///     .into(),
538    /// };
539    ///
540    /// # let mut visited = HashSet::new();
541    /// program.walk_cards_mut(|id, card| {
542    ///     // use id, card
543    /// #   visited.insert(id.clone());
544    /// });
545    ///
546    /// # assert_eq!(visited.len(), 5);
547    /// # let expected: HashSet<_> = [ CardIndex {
548    /// #      function: 0,
549    /// #      card_index: FunctionCardIndex {
550    /// #          indices: smallvec![
551    /// #              0,
552    /// #              0,
553    /// #          ],
554    /// #      },
555    /// #  },
556    /// #  CardIndex {
557    /// #      function: 0,
558    /// #      card_index: FunctionCardIndex {
559    /// #          indices: smallvec![
560    /// #              0,
561    /// #              1,
562    /// #          ],
563    /// #      },
564    /// #  },
565    /// #  CardIndex {
566    /// #      function: 1,
567    /// #      card_index: FunctionCardIndex {
568    /// #          indices: smallvec![
569    /// #              0,
570    /// #          ],
571    /// #      },
572    /// #  },
573    /// #  CardIndex {
574    /// #      function: 1,
575    /// #      card_index: FunctionCardIndex {
576    /// #          indices: smallvec![
577    /// #              0,
578    /// #              0,
579    /// #          ],
580    /// #      },
581    /// #  },
582    /// #  CardIndex {
583    /// #      function: 0,
584    /// #      card_index: FunctionCardIndex {
585    /// #          indices: smallvec![
586    /// #              0,
587    /// #          ],
588    /// #      },
589    /// #  },
590    /// # ].into();
591    /// # assert_eq!(visited, expected);
592    /// ```
593    pub fn walk_cards_mut(&mut self, mut op: impl FnMut(&CardIndex, &mut Card)) {
594        let mut id = CardIndex::function(0);
595
596        for (i, (_, f)) in self.functions.iter_mut().enumerate() {
597            id.function = i;
598            for (j, c) in f.cards.iter_mut().enumerate() {
599                id.push_subindex(j as u32);
600                op(&id, c);
601                visit_children_mut(c, &mut id, &mut op);
602                id.pop_subindex();
603            }
604        }
605    }
606
607    pub fn walk_cards(&mut self, mut op: impl FnMut(&CardIndex, &Card)) {
608        let mut id = CardIndex::function(0);
609
610        for (i, (_, f)) in self.functions.iter_mut().enumerate() {
611            id.function = i;
612            for (j, c) in f.cards.iter_mut().enumerate() {
613                id.push_subindex(j as u32);
614                op(&id, c);
615                visit_children(c, &mut id, &mut op);
616                id.pop_subindex();
617            }
618        }
619    }
620}
621
622fn visit_children_mut(
623    card: &mut Card,
624    id: &mut CardIndex,
625    op: &mut impl FnMut(&CardIndex, &mut Card),
626) {
627    id.push_subindex(0);
628    for (k, child) in card.iter_children_mut().enumerate() {
629        id.set_current_index(k);
630        op(&id, child);
631        visit_children_mut(child, id, op);
632    }
633    id.pop_subindex();
634}
635
636fn visit_children(card: &Card, id: &mut CardIndex, op: &mut impl FnMut(&CardIndex, &Card)) {
637    id.push_subindex(0);
638    for (k, child) in card.iter_children().enumerate() {
639        id.set_current_index(k);
640        op(&id, child);
641        visit_children(child, id, op);
642    }
643    id.pop_subindex();
644}
645
646fn hash_module(hasher: &mut impl Hasher, module: &Module) {
647    for (name, function) in module.functions.iter() {
648        hasher.write(name.as_str().as_bytes());
649        hash_function(hasher, function);
650    }
651    for (name, submodule) in module.submodules.iter() {
652        hasher.write(name.as_str().as_bytes());
653        hash_module(hasher, submodule);
654    }
655}
656
657fn hash_function(hasher: &mut impl Hasher, function: &Function) {
658    for card in function.cards.iter() {
659        hasher.write(card.name().as_bytes());
660    }
661}
662
663fn flatten_module<'a>(
664    module: &'a Module,
665    recursion_limit: u32,
666    namespace: &mut SmallVec<[&'a str; 16]>,
667    out: &mut Vec<FunctionIr>,
668) -> Result<(), CompilationErrorPayload> {
669    if namespace.len() >= recursion_limit as usize {
670        return Err(CompilationErrorPayload::RecursionLimitReached(
671            recursion_limit,
672        ));
673    }
674    if out.capacity() - out.len() < module.functions.len() {
675        out.reserve(module.functions.len() - (out.capacity() - out.len()));
676    }
677    let imports = Rc::new(module.execute_imports()?);
678    for (function_id, (name, function)) in module.functions.iter().enumerate() {
679        if !is_name_valid(name.as_ref()) {
680            return Err(CompilationErrorPayload::BadFunctionName(name.to_string()));
681        }
682        namespace.push(name.as_ref());
683        out.push(function_to_function_ir(
684            out.len(),
685            function_id,
686            function,
687            namespace,
688            Rc::clone(&imports),
689        ));
690        namespace.pop();
691    }
692    for (name, submod) in module.submodules.iter() {
693        namespace.push(name.as_ref());
694        flatten_module(submod, recursion_limit, namespace, out)?;
695        namespace.pop();
696    }
697    Ok(())
698}
699
700fn function_to_function_ir(
701    i: usize,
702    function_id: usize,
703    function: &Function,
704    namespace: &[&str],
705    imports: Rc<ImportsIr>,
706) -> FunctionIr {
707    assert!(
708        !namespace.is_empty(),
709        "Assume that function name is the last entry in namespace"
710    );
711
712    let mut cl = FunctionIr {
713        function_index: function_id,
714        name: namespace.last().unwrap().to_string().into_boxed_str(),
715        arguments: function.arguments.clone().into_boxed_slice(),
716        cards: function.cards.clone().into_boxed_slice(),
717        imports,
718        namespace: Default::default(),
719        handle: Handle::from_u64(i as u64),
720    };
721    cl.namespace.extend(
722        namespace
723            .iter()
724            .take(namespace.len() - 1)
725            .map(|x| x.to_string().into_boxed_str()),
726    );
727    cl
728}
729
/// Reports whether `name` is acceptable as a function name.
///
/// A valid name is non-empty, is not the reserved identifier `super`, and contains only
/// alphanumeric characters and `_`.
fn is_name_valid(name: &str) -> bool {
    if name.is_empty() || name == "super" {
        return false;
    }
    name.chars().all(|c| c == '_' || c.is_alphanumeric())
}