1#[cfg(test)]
5mod tests;
6
7use crate::compiler::Function;
8use crate::prelude::{CompilationErrorPayload, Handle};
9use smallvec::SmallVec;
10use std::collections::hash_map::DefaultHasher;
11use std::hash::Hasher;
12use std::rc::Rc;
13use thiserror::Error;
14
15use super::function_ir::FunctionIr;
16use super::{Card, ImportsIr};
17
/// Errors about locating the program's entry point or validating names.
///
/// NOTE(review): `Module::into_ir_stream` in this file reports
/// `CompilationErrorPayload`, not this type; confirm where it is produced.
#[derive(Debug, Clone, Error)]
pub enum IntoStreamError {
    /// No function with the given name exists to act as the entry point.
    #[error("Main function by name {0} was not found")]
    MainFnNotFound(String),
    /// The given identifier was rejected as a name.
    #[error("{0:?} is not a valid name")]
    BadName(String),
}
25
/// A whole program is represented by its root [`Module`].
pub type CaoProgram = Module;
/// Name of a module, function or import.
pub type CaoIdentifier = String;
/// Import paths of a module. Each must contain a `.`; the segment after the
/// last `.` becomes the local name (see `Module::execute_imports`).
pub type Imports = Vec<CaoIdentifier>;
/// Ordered `(name, function)` pairs; the position is the function index used
/// by [`CardIndex::function`].
pub type Functions = Vec<(CaoIdentifier, Function)>;
/// Ordered `(name, module)` pairs of nested modules.
pub type Submodules = Vec<(CaoIdentifier, Module)>;
31
/// A module: its nested submodules, its functions and its imports.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Module {
    /// Nested modules, addressed by name in dotted paths.
    pub submodules: Submodules,
    /// Functions defined directly in this module.
    pub functions: Functions,
    /// Dotted import paths resolved relative to this module.
    pub imports: Imports,
}
42
/// Address of a card inside a [`Module`]: a function index plus a path of
/// card positions within that function.
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CardIndex {
    /// Index into `Module::functions`.
    pub function: usize,
    /// Path to the card: the first entry selects a top-level card, each
    /// following entry selects a child of the previous card.
    pub card_index: FunctionCardIndex,
}
50
51impl PartialOrd for CardIndex {
52 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
53 Some(self.cmp(other))
54 }
55}
56
57impl Ord for CardIndex {
58 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
59 match self.function.cmp(&other.function) {
60 std::cmp::Ordering::Equal => {}
61 c @ std::cmp::Ordering::Less | c @ std::cmp::Ordering::Greater => return c,
62 }
63 for (lhs, rhs) in self
64 .card_index
65 .indices
66 .iter()
67 .zip(other.card_index.indices.iter())
68 {
69 match lhs.cmp(&rhs) {
70 std::cmp::Ordering::Equal => {}
71 c @ std::cmp::Ordering::Less | c @ std::cmp::Ordering::Greater => return c,
72 }
73 }
74 self.card_index
75 .indices
76 .len()
77 .cmp(&other.card_index.indices.len())
78 }
79}
80
81impl CardIndex {
82 pub fn function(function: usize) -> Self {
83 Self {
84 function,
85 ..Default::default()
86 }
87 }
88
89 pub fn new(function: usize, card_index: usize) -> Self {
90 Self {
91 function,
92 card_index: FunctionCardIndex::new(card_index),
93 }
94 }
95
96 pub fn from_slice(function: usize, indices: &[u32]) -> Self {
97 let mut card_index = FunctionCardIndex {
98 indices: SmallVec::with_capacity(indices.len()),
99 };
100 card_index.indices.extend_from_slice(indices);
101 Self {
102 function,
103 card_index,
104 }
105 }
106
107 pub fn push_subindex(&mut self, i: u32) {
108 self.card_index.indices.push(i);
109 }
110
111 pub fn pop_subindex(&mut self) {
112 self.card_index.indices.pop();
113 }
114
115 pub fn as_handle(&self) -> crate::prelude::Handle {
116 let function_handle = crate::prelude::Handle::from_u64(self.function as u64);
117 let subindices = self.card_index.indices.as_slice();
118 let sub_handle = unsafe {
119 crate::prelude::Handle::from_bytes(std::slice::from_raw_parts(
120 subindices.as_ptr().cast(),
121 subindices.len() * 4,
122 ))
123 };
124 function_handle + sub_handle
125 }
126
127 #[must_use]
129 pub fn with_sub_index(mut self, card_index: usize) -> Self {
130 self.push_subindex(card_index as u32);
131 self
132 }
133
134 pub fn current_index(&self) -> usize {
135 self.card_index.current_index()
136 }
137
138 pub fn with_current_index(mut self, card_index: usize) -> Self {
140 self.card_index.set_current_index(card_index);
141 self
142 }
143
144 pub fn set_current_index(&mut self, card_index: usize) {
145 self.card_index.set_current_index(card_index);
146 }
147
148 pub fn begin(&self) -> Result<usize, CardFetchError> {
150 self.card_index.begin()
151 }
152
153 pub fn is_top_level_card(&self) -> bool {
156 self.card_index.indices.len() == 1
157 }
158}
159
160impl std::fmt::Display for CardIndex {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 write!(f, "{}", self.function)?;
163 for i in self.card_index.indices.iter() {
164 write!(f, ".{}", i)?;
165 }
166 Ok(())
167 }
168}
169
/// Path to a card within one function.
///
/// The first entry is the position of a top-level card; every following
/// entry is a child position inside the previously addressed card.
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FunctionCardIndex {
    /// Path entries; up to 4 are stored inline without heap allocation.
    pub indices: SmallVec<[u32; 4]>,
}
175
176impl FunctionCardIndex {
177 #[must_use]
178 pub fn new(card_index: usize) -> Self {
179 Self {
180 indices: smallvec::smallvec![card_index as u32],
181 }
182 }
183
184 pub fn depth(&self) -> usize {
185 self.indices.len()
186 }
187
188 #[must_use]
190 pub fn with_sub_index(mut self, card_index: usize) -> Self {
191 self.push_sub_index(card_index);
192 self
193 }
194
195 pub fn push_sub_index(&mut self, card_index: usize) {
196 self.indices.push(card_index as u32);
197 }
198
199 #[must_use]
200 pub fn current_index(&self) -> usize {
201 self.indices.last().copied().unwrap_or(0) as usize
202 }
203
204 #[must_use]
206 pub fn with_current_index(mut self, card_index: usize) -> Self {
207 self.set_current_index(card_index);
208 self
209 }
210
211 pub fn set_current_index(&mut self, card_index: usize) {
212 if let Some(x) = self.indices.last_mut() {
213 *x = card_index as u32;
214 }
215 }
216
217 pub fn begin(&self) -> Result<usize, CardFetchError> {
218 let i = self.indices.first().ok_or(CardFetchError::InvalidIndex)?;
219 Ok(*i as usize)
220 }
221}
222
/// Reasons a [`CardIndex`] could not be resolved to a card.
#[derive(Debug, Clone, Error)]
pub enum CardFetchError {
    /// `CardIndex::function` is out of range for `Module::functions`.
    #[error("Function not found")]
    FunctionNotFound,
    /// No card exists at path position `depth` (0 = top-level card).
    #[error("Card at depth {depth} not found")]
    CardNotFound { depth: usize },
    /// NOTE(review): not produced by code in this file; kept for callers elsewhere.
    #[error("The card at depth {depth} has no nested functions, but the index tried to fetch one")]
    NoSubFunction { depth: usize },
    /// The card path is empty.
    #[error("The provided index is not valid")]
    InvalidIndex,
}
234
/// Errors returned by `Module::swap_cards`.
#[derive(Debug, Clone, Error)]
pub enum SwapError {
    /// The given card could not be fetched.
    #[error("Failed to find card {0}: {1}")]
    FetchError(CardIndex, CardFetchError),
    /// The two cards cannot be swapped. Raised when one card is nested inside
    /// the other, or when the second card does not exist.
    #[error("These cards can not be swapped")]
    InvalidSwap,
}
242
243impl Module {
244 pub fn get_card_mut<'a>(&'a mut self, idx: &CardIndex) -> Result<&'a mut Card, CardFetchError> {
245 let (_, function) = self
246 .functions
247 .get_mut(idx.function)
248 .ok_or(CardFetchError::FunctionNotFound)?;
249 let mut card = function
250 .cards
251 .get_mut(idx.begin()?)
252 .ok_or(CardFetchError::CardNotFound { depth: 0 })?;
253
254 for (depth, i) in idx.card_index.indices[1..].iter().enumerate() {
255 card = card
256 .get_child_mut(*i as usize)
257 .ok_or(CardFetchError::CardNotFound { depth: depth + 1 })?;
258 }
259
260 Ok(card)
261 }
262
263 pub fn get_card<'a>(&'a self, idx: &CardIndex) -> Result<&'a Card, CardFetchError> {
264 let (_, function) = self
265 .functions
266 .get(idx.function)
267 .ok_or(CardFetchError::FunctionNotFound)?;
268
269 let mut depth = 0;
270 let mut card = function
271 .cards
272 .get(idx.begin()?)
273 .ok_or(CardFetchError::CardNotFound { depth })?;
274
275 for i in &idx.card_index.indices[1..] {
276 depth += 1;
277 card = card
278 .get_child(*i as usize)
279 .ok_or(CardFetchError::CardNotFound { depth })?;
280 }
281
282 Ok(card)
283 }
284
285 pub fn swap_cards<'a>(
287 &mut self,
288 mut lhs: &'a CardIndex,
289 mut rhs: &'a CardIndex,
290 ) -> Result<(), SwapError> {
291 if lhs < rhs {
292 std::mem::swap(&mut lhs, &mut rhs);
293 }
294
295 let rhs_card = self
296 .replace_card(rhs, Card::ScalarNil)
297 .map_err(|err| SwapError::FetchError(rhs.clone(), err))?;
298
299 if let Err(_) = self.get_card(lhs) {
302 self.replace_card(rhs, rhs_card).unwrap();
303 return Err(SwapError::InvalidSwap);
304 }
305
306 let lhs_card = self.replace_card(lhs, rhs_card).unwrap();
308
309 self.replace_card(rhs, lhs_card).unwrap();
311 Ok(())
312 }
313
314 pub fn remove_card(&mut self, idx: &CardIndex) -> Result<Card, CardFetchError> {
315 let (_, function) = self
316 .functions
317 .get_mut(idx.function)
318 .ok_or(CardFetchError::FunctionNotFound)?;
319 if idx.card_index.indices.len() == 1 {
320 if function.cards.len() <= idx.card_index.indices[0] as usize {
321 return Err(CardFetchError::CardNotFound { depth: 0 });
322 }
323 return Ok(function.cards.remove(idx.card_index.indices[0] as usize));
324 }
325 let mut card = function
326 .cards
327 .get_mut(idx.begin()?)
328 .ok_or(CardFetchError::CardNotFound { depth: 0 })?;
329
330 let len = idx.card_index.indices.len();
332 for (depth, i) in idx.card_index.indices[1..(len - 1).max(1)]
333 .iter()
334 .enumerate()
335 {
336 card = card
337 .get_child_mut(*i as usize)
338 .ok_or(CardFetchError::CardNotFound { depth: depth + 1 })?;
339 }
340 let i = *idx.card_index.indices.last().unwrap() as usize;
341 card.remove_child(i)
342 .ok_or(CardFetchError::CardNotFound { depth: len - 1 })
343 }
344
345 pub fn replace_card(&mut self, idx: &CardIndex, child: Card) -> Result<Card, CardFetchError> {
347 self.get_card_mut(idx).map(|c| std::mem::replace(c, child))
348 }
349
350 pub fn insert_card(&mut self, idx: &CardIndex, child: Card) -> Result<(), CardFetchError> {
351 let (_, function) = self
352 .functions
353 .get_mut(idx.function)
354 .ok_or(CardFetchError::FunctionNotFound)?;
355 if idx.card_index.indices.len() == 1 {
356 if function.cards.len() < idx.card_index.indices[0] as usize {
357 return Err(CardFetchError::CardNotFound { depth: 0 });
358 }
359 function
360 .cards
361 .insert(idx.card_index.indices[0] as usize, child);
362 return Ok(());
363 }
364 let mut card = function
365 .cards
366 .get_mut(idx.begin()?)
367 .ok_or(CardFetchError::CardNotFound { depth: 0 })?;
368
369 let len = idx.card_index.indices.len();
371 for (depth, i) in idx.card_index.indices[1..(len - 1).max(1)]
372 .iter()
373 .enumerate()
374 {
375 card = card
376 .get_child_mut(*i as usize)
377 .ok_or(CardFetchError::CardNotFound { depth: depth + 1 })?;
378 }
379 let i = *idx.card_index.indices.last().unwrap() as usize;
380 card.insert_child(i, child)
381 .map_err(|_| CardFetchError::CardNotFound { depth: len - 1 })
382 }
383
384 pub(crate) fn into_ir_stream(
388 mut self,
389 recursion_limit: u32,
390 ) -> Result<Vec<FunctionIr>, CompilationErrorPayload> {
391 self.submodules
393 .push(("std".to_string(), crate::stdlib::standard_library()));
394
395 self.ensure_invariants(&mut Default::default())?;
396 let (main_index, _) = self
399 .functions
400 .iter()
401 .enumerate()
402 .find(|(_, (name, _))| name == "main")
403 .ok_or(CompilationErrorPayload::NoMain)?;
404
405 let mut result = Vec::with_capacity(self.functions.len() * self.submodules.len() * 2); let mut namespace = SmallVec::<[_; 16]>::new();
408
409 flatten_module(&self, recursion_limit, &mut namespace, &mut result)?;
410
411 result.swap(0, main_index);
413 Ok(result)
414 }
415
416 fn ensure_invariants<'a>(
417 &'a self,
418 aux: &mut std::collections::HashSet<&'a str>,
419 ) -> Result<(), CompilationErrorPayload> {
420 for (name, _) in self.submodules.iter() {
422 if aux.contains(name.as_str()) {
423 return Err(CompilationErrorPayload::DuplicateModule(name.to_string()));
424 }
425 aux.insert(name.as_str());
426 }
427 for (_, module) in self.submodules.iter() {
428 aux.clear();
429 module.ensure_invariants(aux)?;
430 }
431 Ok(())
432 }
433
434 fn execute_imports(&self) -> Result<ImportsIr, CompilationErrorPayload> {
435 let mut result = ImportsIr::with_capacity(self.imports.len());
436
437 for import in self.imports.iter() {
438 let import = import.as_str();
439
440 match import.rsplit_once('.') {
441 Some((_, name)) => {
442 if result.contains_key(name) {
443 return Err(CompilationErrorPayload::AmbigousImport(import.to_string()));
444 }
445 result.insert(name.to_string(), import.to_string());
446 }
447 None => {
448 return Err(CompilationErrorPayload::BadImport(import.to_string()));
449 }
450 }
451 }
452
453 Ok(result)
454 }
455
456 pub fn compute_keys_hash(&self) -> u64 {
460 let mut hasher = DefaultHasher::new();
461 hash_module(&mut hasher, self);
462 hasher.finish()
463 }
464
465 pub fn lookup_submodule(&self, target: &str) -> Option<&Module> {
466 let mut current = self;
467 for submodule_name in target.split('.') {
468 current = current
469 .submodules
470 .iter()
471 .find(|(name, _)| name == submodule_name)
472 .map(|(_, m)| m)?;
473 }
474 Some(current)
475 }
476
477 pub fn lookup_submodule_mut(&mut self, target: &str) -> Option<&mut Module> {
478 let mut current = self;
479 for submodule_name in target.split('.') {
480 current = current
481 .submodules
482 .iter_mut()
483 .find(|(name, _)| name == submodule_name)
484 .map(|(_, m)| m)?;
485 }
486 Some(current)
487 }
488
489 pub fn lookup_function(&self, target: &str) -> Option<&Function> {
490 let Some((submodule, function)) = target.rsplit_once('.') else {
491 return self
492 .functions
493 .iter()
494 .find(|(name, _)| name == target)
495 .map(|(_, l)| l);
496 };
497 let module = self.lookup_submodule(submodule)?;
498 module.lookup_function(function)
499 }
500
501 pub fn lookup_function_mut(&mut self, target: &str) -> Option<&mut Function> {
502 let Some((submodule, function)) = target.rsplit_once('.') else {
503 return self
504 .functions
505 .iter_mut()
506 .find(|(name, _)| name == target)
507 .map(|(_, l)| l);
508 };
509 let module = self.lookup_submodule_mut(submodule)?;
510 module.lookup_function_mut(function)
511 }
512
513 pub fn walk_cards_mut(&mut self, mut op: impl FnMut(&CardIndex, &mut Card)) {
594 let mut id = CardIndex::function(0);
595
596 for (i, (_, f)) in self.functions.iter_mut().enumerate() {
597 id.function = i;
598 for (j, c) in f.cards.iter_mut().enumerate() {
599 id.push_subindex(j as u32);
600 op(&id, c);
601 visit_children_mut(c, &mut id, &mut op);
602 id.pop_subindex();
603 }
604 }
605 }
606
607 pub fn walk_cards(&mut self, mut op: impl FnMut(&CardIndex, &Card)) {
608 let mut id = CardIndex::function(0);
609
610 for (i, (_, f)) in self.functions.iter_mut().enumerate() {
611 id.function = i;
612 for (j, c) in f.cards.iter_mut().enumerate() {
613 id.push_subindex(j as u32);
614 op(&id, c);
615 visit_children(c, &mut id, &mut op);
616 id.pop_subindex();
617 }
618 }
619 }
620}
621
622fn visit_children_mut(
623 card: &mut Card,
624 id: &mut CardIndex,
625 op: &mut impl FnMut(&CardIndex, &mut Card),
626) {
627 id.push_subindex(0);
628 for (k, child) in card.iter_children_mut().enumerate() {
629 id.set_current_index(k);
630 op(&id, child);
631 visit_children_mut(child, id, op);
632 }
633 id.pop_subindex();
634}
635
636fn visit_children(card: &Card, id: &mut CardIndex, op: &mut impl FnMut(&CardIndex, &Card)) {
637 id.push_subindex(0);
638 for (k, child) in card.iter_children().enumerate() {
639 id.set_current_index(k);
640 op(&id, child);
641 visit_children(child, id, op);
642 }
643 id.pop_subindex();
644}
645
646fn hash_module(hasher: &mut impl Hasher, module: &Module) {
647 for (name, function) in module.functions.iter() {
648 hasher.write(name.as_str().as_bytes());
649 hash_function(hasher, function);
650 }
651 for (name, submodule) in module.submodules.iter() {
652 hasher.write(name.as_str().as_bytes());
653 hash_module(hasher, submodule);
654 }
655}
656
657fn hash_function(hasher: &mut impl Hasher, function: &Function) {
658 for card in function.cards.iter() {
659 hasher.write(card.name().as_bytes());
660 }
661}
662
663fn flatten_module<'a>(
664 module: &'a Module,
665 recursion_limit: u32,
666 namespace: &mut SmallVec<[&'a str; 16]>,
667 out: &mut Vec<FunctionIr>,
668) -> Result<(), CompilationErrorPayload> {
669 if namespace.len() >= recursion_limit as usize {
670 return Err(CompilationErrorPayload::RecursionLimitReached(
671 recursion_limit,
672 ));
673 }
674 if out.capacity() - out.len() < module.functions.len() {
675 out.reserve(module.functions.len() - (out.capacity() - out.len()));
676 }
677 let imports = Rc::new(module.execute_imports()?);
678 for (function_id, (name, function)) in module.functions.iter().enumerate() {
679 if !is_name_valid(name.as_ref()) {
680 return Err(CompilationErrorPayload::BadFunctionName(name.to_string()));
681 }
682 namespace.push(name.as_ref());
683 out.push(function_to_function_ir(
684 out.len(),
685 function_id,
686 function,
687 namespace,
688 Rc::clone(&imports),
689 ));
690 namespace.pop();
691 }
692 for (name, submod) in module.submodules.iter() {
693 namespace.push(name.as_ref());
694 flatten_module(submod, recursion_limit, namespace, out)?;
695 namespace.pop();
696 }
697 Ok(())
698}
699
700fn function_to_function_ir(
701 i: usize,
702 function_id: usize,
703 function: &Function,
704 namespace: &[&str],
705 imports: Rc<ImportsIr>,
706) -> FunctionIr {
707 assert!(
708 !namespace.is_empty(),
709 "Assume that function name is the last entry in namespace"
710 );
711
712 let mut cl = FunctionIr {
713 function_index: function_id,
714 name: namespace.last().unwrap().to_string().into_boxed_str(),
715 arguments: function.arguments.clone().into_boxed_slice(),
716 cards: function.cards.clone().into_boxed_slice(),
717 imports,
718 namespace: Default::default(),
719 handle: Handle::from_u64(i as u64),
720 };
721 cl.namespace.extend(
722 namespace
723 .iter()
724 .take(namespace.len() - 1)
725 .map(|x| x.to_string().into_boxed_str()),
726 );
727 cl
728}
729
/// A name is valid when it is non-empty, is not the reserved word `super`, and
/// consists only of alphanumeric characters (Unicode-aware) and underscores.
fn is_name_valid(name: &str) -> bool {
    let only_word_chars = name.chars().all(|ch| ch == '_' || ch.is_alphanumeric());
    only_word_chars && !name.is_empty() && name != "super"
}