1use crate::StdError;
2use crate::codegen::building_block::{ApplyBoundaryConditionConfig, InitialConditionConfig};
3use crate::codegen::input_schema::quantity::{NO_REF_QUANTITY_PATTERN, QuantityEnum};
4use crate::codegen::input_schema::unit::format_unit;
5use derive_more::{Deref, DerefMut, From, FromStr, IntoIterator};
6use indexmap::IndexMap as BaseIndexMap;
7use itertools::Itertools;
8use lazy_static::lazy_static;
9use log::debug;
10use regex::Regex;
11use schemars::{JsonSchema, json_schema};
12use serde::de::Error;
13use serde::de::Visitor;
14use serde::ser::SerializeMap;
15use std::collections::{HashMap, HashSet};
16use std::fmt::Debug;
17use std::fs::{self, File};
18use std::hash::Hash;
19use std::io::Write;
20use std::path::Path;
21use std::str::FromStr;
22use std::sync::LazyLock;
23use symrs::ops::ParseExprError;
24use symrs::system::SystemError;
/// Implemented by types that can expose a borrowed string form of themselves.
///
/// What exactly "raw" means is decided by each implementor; no implementation
/// is visible in this file.
pub trait RawRepr {
    /// Returns the raw string representation, borrowed from `self`.
    fn raw(&self) -> &str;
}
28
29pub mod mesh;
30pub mod quantity;
31pub mod range;
32mod reference;
33mod unit;
34
35use quantity::{Length, RANGE_PATTERN, Time};
36
37use mesh::MeshEnum;
38use range::Range;
39use serde::{Deserialize, Serialize};
40use tera::Tera;
41use thiserror::Error;
42
43use crate::codegen::building_block::deal_ii_factory;
44use symrs::{Equation, Expr, Func, Symbol, System, symbol};
45
/// Newtype around [`indexmap::IndexMap`] (imported here as `BaseIndexMap`).
///
/// The wrapper lets this module implement [`JsonSchema`] for the map while the
/// derives keep it usable like the wrapped map: it dereferences to the inner
/// map, iterates by value, by reference and by mutable reference, and forwards
/// `From` conversions to the inner map.
#[derive(Deref, DerefMut, Deserialize, Serialize, Clone, Debug, IntoIterator, From)]
#[from(forward)]
pub struct IndexMap<K, V>(#[into_iterator(owned, ref, ref_mut)] BaseIndexMap<K, V>)
where
    K: Eq + Hash;
51
52impl<K, V> IndexMap<K, V>
53where
54 K: Eq + Hash,
55{
56 pub fn new() -> Self {
57 IndexMap(BaseIndexMap::new())
58 }
59 pub fn with_capacity(capacity: usize) -> Self {
60 IndexMap(BaseIndexMap::with_capacity(capacity))
61 }
62}
63
64impl<K, V> JsonSchema for IndexMap<K, V>
65where
66 K: Eq + Hash + JsonSchema,
67 V: JsonSchema,
68{
69 fn schema_name() -> std::borrow::Cow<'static, str> {
70 format!("Map<{}, {}>", K::schema_name(), V::schema_name()).into()
71 }
72
73 fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
74 HashMap::<K, V>::json_schema(generator)
75 }
76}
77
78use super::building_block::{
79 Block, BlockRes, BuildingBlockFactory, EquationSetupConfig, ShapeMatrix, ShapeMatrixConfig,
80 SolveUnknownConfig,
81};
82use super::{
83 BuildingBlock,
84 building_block::{
85 BuildingBlockError, DofHandlerConfig, MatrixConfig, SparsityPatternConfig, VectorConfig,
86 },
87};
88
/// Finite element selected in the `solve` section of the input schema.
///
/// `Q2` is the default variant. The value is handed to the building-block
/// factory (`finite_element`) during code generation.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, Default)]
pub enum FiniteElement {
    Q1,
    #[default]
    Q2,
    Q3,
}
98
/// The `solve` section of the input schema: what to solve and how.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct Solve {
    /// Names of the equations to solve; each must be a key of
    /// `InputSchema::equations` (checked by `InputSchema::validate`).
    pub equations: Vec<String>,

    /// Name of the mesh to use; must be a key of `InputSchema::meshes`.
    pub mesh: String,

    /// Finite element passed to the building-block factory.
    pub element: FiniteElement,

    /// Dimension used when simplifying the equations and exposed to the C++
    /// template. Defaults to `2` when omitted.
    #[serde(default = "default_dimension")]
    pub dimension: usize,

    /// Time range of the simulation. Defaults to `"0 .. 5s"` when omitted.
    #[serde(default = "default_solving_range")]
    pub time: Range<Time>,

    /// Time step; exposed to the C++ template in seconds.
    pub time_step: Time,
}
128
/// Serde default for [`Solve::dimension`].
fn default_dimension() -> usize {
    const DEFAULT_DIMENSION: usize = 2;
    DEFAULT_DIMENSION
}
132
133fn default_solving_range() -> Range<Time> {
134 "0 .. 5s".parse().unwrap()
135}
136
/// Code generation switches (the `generation` section of the input schema).
///
/// Every flag defaults to `false` when omitted.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, Default)]
pub struct GenConfig {
    /// Exposed to the C++ template as `mpi`.
    #[serde(default)]
    pub mpi: bool,

    /// Not read in this file; presumably consumed by the building-block
    /// factory — TODO confirm.
    #[serde(default)]
    pub matrix_free: bool,

    /// Exposed to the CMakeLists template as `debug`.
    #[serde(default)]
    pub debug: bool,
}
156
/// Kind of physical quantity a [`Parameter`] represents.
///
/// Serialized in `snake_case` (e.g. `diffusion_coefficient`). Every kind except
/// `Custom` has a default unit, see `QuantityKind::default_unit`.
#[derive(Clone, Debug, JsonSchema, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum QuantityKind {
    Length,
    Time,
    Speed,
    DiffusionCoefficient,
    Custom,
}
166
167impl QuantityKind {
168 fn default_unit(&self) -> Option<&'static str> {
169 match self {
170 QuantityKind::Length => Some("m"),
171 QuantityKind::Time => Some("s"),
172 QuantityKind::Speed => Some("m/s"),
173 QuantityKind::DiffusionCoefficient => Some("m²/s"),
174 QuantityKind::Custom => None,
175 }
176 }
177}
178
/// A typed numeric parameter with an optional unit.
///
/// Serialization and deserialization are implemented by hand: the value and
/// unit travel together in a single `value` entry.
#[derive(Clone, Debug)]
pub struct Parameter {
    // Kind of quantity; also decides the default unit when none is given.
    r#type: QuantityKind,
    // Numeric value as parsed from the input.
    value: f64,
    // Unit string, or `None` when neither an explicit nor a default unit exists.
    unit: Option<String>,
}
185
186impl Parameter {
187 pub fn value_string(&self) -> String {
188 if let Some(unit) = &self.unit {
189 format!("{} {unit}", self.value)
190 } else {
191 self.value.to_string()
192 }
193 }
194
195 }
209
210impl Serialize for Parameter {
211 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
212 where
213 S: serde::Serializer,
214 {
215 let mut map = serializer.serialize_map(Some(2))?;
216 map.serialize_entry("type", &self.r#type)?;
217 map.serialize_entry(
218 "value",
219 &(if let Some(unit) = &self.unit {
220 format!("{} {unit}", self.value)
221 } else {
222 self.value.to_string()
223 }),
224 )?;
225 map.end()
226 }
227}
228
229impl JsonSchema for Parameter {
230 fn schema_name() -> std::borrow::Cow<'static, str> {
231 "Parameter".into()
232 }
233
234 fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
235 json_schema!({
236 "type": "object",
237 "title": "Parameter",
238 "required": ["type", "value"],
239 "description": "A parameter is a quantity with a value and a unit. If no unit is specified, the default unit will be used.",
240 "properties": {
241 "type": QuantityKind::json_schema(generator),
242 "value": {
243 "oneOf": [
244 {
245 "type": "string",
246 "pattern": NO_REF_QUANTITY_PATTERN,
247 },
248 {
249 "type": "number"
250 }
251 ]
252 }
253 }
254 })
255 }
256}
257
/// Serde visitor that builds a [`Parameter`] from a map with `type` and
/// `value` entries.
struct DeserializeParameterVisitor;
259
/// `NO_REF_QUANTITY_PATTERN` compiled once on first use. Parameter
/// deserialization reads capture group 1 as the number and group 2 as the unit.
static NO_REF_QUANTITY_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(NO_REF_QUANTITY_PATTERN).unwrap());
262
263impl<'de> Visitor<'de> for DeserializeParameterVisitor {
264 type Value = Parameter;
265
266 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
267 formatter.write_str("a parameter")
268 }
269
270 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
271 where
272 A: serde::de::MapAccess<'de>,
273 {
274 let mut r#type: Option<QuantityKind> = None;
275 let mut value: Option<&str> = None;
276 while let Some(key) = map.next_key()? {
277 match key {
278 "type" => r#type = Some(map.next_value()?),
279 "value" => value = Some(map.next_value()?),
280 field => Err(A::Error::unknown_field(field, &["type", "value"]))?,
281 }
282 }
283 let r#type = r#type.ok_or_else(|| A::Error::missing_field("type"))?;
284 let value = value.ok_or_else(|| A::Error::missing_field("value"))?;
285 let captures = NO_REF_QUANTITY_RE
286 .captures(value)
287 .ok_or_else(|| A::Error::invalid_value(serde::de::Unexpected::Str(value), &self))?;
288 let value = captures[1].trim().parse().map_err(|_| {
289 A::Error::invalid_value(
290 serde::de::Unexpected::Str(&captures[1]),
291 &"a string that can be parsed as a float",
292 )
293 })?;
294 let unit = captures[2].trim();
295 let unit = if unit.is_empty() {
296 r#type.default_unit().map(|u| u.to_owned())
297 } else {
298 Some(format_unit(&captures[2]).map_err(|e| A::Error::custom(e.to_string()))?)
299 };
300
301 Ok(Parameter {
302 r#type,
303 value,
304 unit,
305 })
306 }
307}
308
309impl<'de> Deserialize<'de> for Parameter {
310 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
311 where
312 D: serde::Deserializer<'de>,
313 {
314 deserializer.deserialize_map(DeserializeParameterVisitor)
315 }
316}
317
/// Root of the input schema: everything needed to generate the solver sources.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct InputSchema {
    /// Code generation switches; serialized under the key `generation` and
    /// optional in the input.
    #[serde(rename = "generation")]
    #[serde(default)]
    pub gen_conf: GenConfig,

    /// Available meshes, by name. `solve.mesh` refers to one of these keys.
    pub meshes: IndexMap<String, MeshEnum>,

    /// Available equations, by name. `solve.equations` refers to these keys.
    pub equations: IndexMap<String, Equation>,

    /// Named parameters; exposed to the C++ template with their SI values.
    pub parameters: IndexMap<String, QuantityEnum>,

    /// Unknowns of the system, by name.
    pub unknowns: IndexMap<String, Unknown>,

    /// Named functions that equations and unknowns may refer to.
    pub functions: IndexMap<String, FunctionDef>,

    /// What to solve and how.
    pub solve: Solve,
}
349
350impl InputSchema {
355 pub fn from_yaml(yaml: &str) -> Result<Self, serde_yaml::Error> {
356 serde_yaml::from_str(yaml)
357 }
358}
359
/// A condition on a quantity: either a single value or a range of values.
///
/// Serialized untagged. Deserialization is implemented by hand from a string:
/// input containing `..` is parsed as a range, anything else as a value.
#[derive(Clone, Debug, Serialize, JsonSchema)]
#[serde(untagged)]
pub enum Condition<T> {
    Value(T),
    Range(Range<T>),
}
366
/// Serde visitor producing a [`Condition<T>`] from a string.
struct ConditionVisitor<T> {
    // Ties the visitor to `T` without storing a value of that type.
    _marker: std::marker::PhantomData<T>,
}
370
371impl<'de, T> Visitor<'de> for ConditionVisitor<T>
372where
373 T: FromStr,
374 T::Err: StdError + 'static,
375{
376 type Value = Condition<T>;
377
378 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
379 formatter.write_str("a value or a range")
380 }
381
382 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
383 where
384 E: serde::de::Error,
385 {
386 if v.contains("..") {
387 Ok(Condition::Range(v.parse().map_err(|e| E::custom(e))?))
388 } else {
389 Ok(Condition::Value(v.parse().map_err(|e| E::custom(e))?))
390 }
391 }
392}
393
394impl<'de, T> Deserialize<'de> for Condition<T>
395where
396 T: FromStr,
397 T::Err: StdError + 'static,
398{
399 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
400 where
401 D: serde::Deserializer<'de>,
402 {
403 deserializer.deserialize_str(ConditionVisitor {
404 _marker: std::marker::PhantomData,
405 })
406 }
407}
408
/// A symbolic expression defining a function.
///
/// Thin wrapper around a boxed `symrs` expression; it dereferences to the inner
/// expression and is parsed from a string through the derived `FromStr`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, FromStr, DerefMut, Deref)]
pub struct FunctionExpression(Box<dyn Expr>);
416
/// A function expression together with optional conditions on time and space.
///
/// NOTE(review): presumably `expr` applies only where all given conditions
/// hold; the evaluation logic is not in this file — confirm in the factory.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct ConditionedFunctionExpression {
    /// The expression itself.
    pub expr: FunctionExpression,
    /// Optional condition on time.
    pub t: Option<Condition<Time>>,

    /// Optional condition on the `x` coordinate.
    pub x: Option<Condition<Length>>,

    /// Optional condition on the `y` coordinate.
    pub y: Option<Condition<Length>>,

    /// Optional condition on the `z` coordinate.
    pub z: Option<Condition<Length>>,
}
440
/// Definition of a named function in the input schema.
///
/// Serialized untagged. Deserialization is implemented by hand: a string or an
/// integer yields `Expr`, a sequence yields `Conditioned`.
#[derive(Clone, Debug, Serialize, JsonSchema)]
#[serde(untagged)]
pub enum FunctionDef {
    /// A single expression.
    Expr(FunctionExpression),
    /// A list of expressions, each with its own conditions.
    Conditioned(Vec<ConditionedFunctionExpression>),
}
454
455
/// Serde visitor producing a [`FunctionDef`] from a string, an integer or a
/// sequence.
struct FunctionDefVisitor;
457
458impl<'de> Visitor<'de> for FunctionDefVisitor {
459 type Value = FunctionDef;
460
461 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
462 formatter.write_str("a function expression or a conditioned function")
463 }
464
465 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
466 where
467 E: serde::de::Error,
468 {
469 Ok(FunctionDef::Expr(v.parse().map_err(|e| E::custom(e))?))
470 }
471
472 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
473 where
474 E: serde::de::Error,
475 {
476 Ok(FunctionDef::Expr(FunctionExpression(
477 symrs::Integer::new_box(v as isize),
478 )))
479 }
480
481 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
482 where
483 E: serde::de::Error,
484 {
485 Ok(FunctionDef::Expr(FunctionExpression(
486 symrs::Integer::new_box(v as isize),
487 )))
488 }
489
490 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
491 where
492 A: serde::de::SeqAccess<'de>,
493 {
494 let mut res = Vec::new();
495
496 while let Some(item) = seq.next_element::<ConditionedFunctionExpression>()? {
497 res.push(item)
498 }
499
500 Ok(FunctionDef::Conditioned(res))
501 }
502}
503
504impl<'de> Deserialize<'de> for FunctionDef {
505 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
506 where
507 D: serde::Deserializer<'de>,
508 {
509 Ok(deserializer.deserialize_any(FunctionDefVisitor)?)
510 }
511}
512
/// Value of an initial or boundary condition of an unknown.
///
/// Untagged: variants are tried in declaration order when deserializing, so an
/// integer becomes `Constant`, another number `ConstantFloat`, and a string
/// `FunctionName`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum UnknownProperty {
    /// Integer constant.
    Constant(i64),
    /// Floating-point constant.
    ConstantFloat(f64),
    /// Name of a function; schema validation requires it to be a key of
    /// `InputSchema::functions`.
    FunctionName(String),
}
521
522impl Default for UnknownProperty {
523 fn default() -> Self {
524 UnknownProperty::Constant(0)
525 }
526}
527impl UnknownProperty {
528 fn to_function_name(&self) -> String {
529 match self {
530 UnknownProperty::Constant(i) => format!("fn_{}", i.to_string().replace("-", "neg")),
531 UnknownProperty::ConstantFloat(f) => {
532 format!(
533 "fn_{}",
534 f.to_string().replace("-", "neg").replace(".", "dot")
535 )
536 }
537 UnknownProperty::FunctionName(s) => format!("fn_{s}"),
538 }
539 }
540}
541
/// An unknown of the system with its initial and boundary conditions.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct Unknown {
    /// Initial condition; defaults to the constant `0`.
    #[serde(default)]
    pub initial: UnknownProperty,

    /// Boundary condition; defaults to the constant `0`.
    #[serde(default)]
    pub boundary: UnknownProperty,

    /// Conditions for the time derivative of this unknown. Code generation
    /// uses it for unknowns whose generated name starts with `dt_`.
    pub derivative: Option<Box<Unknown>>,
}
562
/// A constant initial/boundary value presented as a function to generate.
pub struct ConstantFunction {
    /// Name of the generated function (`fn_…`).
    pub name: String,
    /// The constant, as text.
    pub value: String,
}
567
568impl Unknown {
569 pub fn visit_symbols<E, F: Fn(&str) -> Result<(), E>>(&self, f: F) -> Result<(), E> {
570 if let UnknownProperty::FunctionName(initial) = &self.initial {
571 f(initial)?;
572 }
573 if let UnknownProperty::FunctionName(boundary) = &self.boundary {
574 f(boundary)?
575 }
576
577 if let Some(derivative) = &self.derivative {
578 derivative.visit_symbols(f)?;
579 }
580 Ok(())
581 }
582
583 pub fn has_symbol(&self, s: &str) -> bool {
584 if let UnknownProperty::FunctionName(initial) = &self.initial
585 && initial == s
586 {
587 true
588 } else if let UnknownProperty::FunctionName(boundary) = &self.boundary
589 && boundary == s
590 {
591 true
592 } else if let Some(derivative) = &self.derivative {
593 derivative.has_symbol(s)
594 } else {
595 false
596 }
597 }
598
599 pub fn visit_props<E: StdError, F: FnMut(&UnknownProperty) -> Result<(), E>>(
600 &self,
601 mut f: F,
602 ) -> Result<(), E> {
603 f(&self.initial)?;
604 f(&self.boundary)?;
606 if let Some(derivative) = &self.derivative {
608 derivative.visit_props(f)?;
609 }
610 Ok(())
611 }
612
613 pub fn visit_constants<E: StdError, F: FnMut(ConstantFunction) -> Result<(), E>>(
614 &self,
615 mut f: F,
616 ) -> Result<(), E> {
617 self.visit_props(|p| -> Result<(), E> {
618 match p {
619 UnknownProperty::Constant(i) => f(ConstantFunction {
620 value: i.to_string(),
621 name: p.to_function_name(),
622 }),
623 UnknownProperty::ConstantFloat(float) => f(ConstantFunction {
624 value: float.to_string(),
625 name: p.to_function_name(),
626 }),
627 _ => Ok(()),
628 }
629 })
630 }
631}
632
/// Errors that can occur while generating sources from an [`InputSchema`].
#[derive(Error, Debug)]
pub enum CodeGenError {
    /// A building block could not be constructed or registered.
    #[error("failed to construct a block")]
    BlockError(#[from] BuildingBlockError),
    /// Rendering a template failed.
    #[error("templating error: {0}")]
    TemplatingError(#[from] tera::Error),
    /// No solver block was registered for the given unknown.
    #[error("missing solver for unknown: {0}")]
    MissingSolver(String),
    /// YAML (de)serialization of the schema failed.
    #[error("invalid yaml input schema")]
    InvalidYaml(#[from] serde_yaml::Error),
    /// The schema did not pass `InputSchema::validate`.
    #[error("schema validation failed")]
    InvalidSchema(#[from] SchemaValidationError),
    /// An unknown of the system has no entry in `InputSchema::unknowns`.
    #[error("unknown missing from the schema '{0}'")]
    UnknownUnknown(String),
    /// Not constructed in this file.
    #[error("missing boundary condition for unknown: {0}")]
    MissingBoundary(String),
    /// A `dt_<name>` unknown was found but `<name>` has no `derivative` entry.
    #[error("missing derivative for unknown: {0}")]
    MissingDerivative(String),
    /// A constant could not be parsed as a function expression.
    #[error("invalid expression for function expression: {0}")]
    InvalidFunctionExpression(ParseExprError),
    /// The equation system could not be simplified.
    #[error("failed to simplify system before code generation")]
    SystemSimplificationFailed(#[from] SystemError),
}
656
657lazy_static! {
658 pub static ref TEMPLATES: Tera = {
659 let mut tera = Tera::default();
660 match tera.add_raw_template(
661 "cpp_source",
662 &include_str!("./input_schema/cpp_source_template.cpp")
663 .replace("/*", "")
664 .replace("*/", ""),
665 ) {
666 Ok(t) => t,
667 Err(e) => {
668 println!("Parsing error(s): {e}");
669 std::process::exit(1);
670 }
671 };
672 match tera.add_raw_template(
673 "cmakelists",
674 &include_str!("./input_schema/deal.ii/CMakeLists.txt"),
675 ) {
676 Ok(t) => t,
677 Err(e) => {
678 println!("Parsing error(s): {e}");
679 std::process::exit(1);
680 }
681 }
682 tera
683 };
684}
685
/// Errors reported by `InputSchema::validate`.
#[derive(Error, Debug)]
pub enum SchemaValidationError {
    /// `solve.mesh` is not a key of `meshes`.
    #[error("mesh {0} not found")]
    MeshNotFound(String),
    /// One or more names in `solve.equations` are not keys of `equations`;
    /// carries the missing names joined with `", "`.
    #[error("equation(s) {0} not found")]
    EquationNotFound(String),
    /// An unknown refers to a function that is not a key of `functions`.
    #[error("function name not found: {0}")]
    FunctionNotFound(String),
}
695
/// Result of code generation: the files that make up the generated project.
pub struct CodeGenRes {
    /// Generated source code.
    pub code: String,
    /// The input schema, serialized back to YAML.
    pub schema: String,
    /// File name under which `code` is written.
    pub file_name: String,
    /// Content of `CMakeLists.txt`, if one was generated.
    pub cmakelists: Option<String>,
}
702
703impl CodeGenRes {
704 pub fn write_to_dir<P: AsRef<Path>>(&self, dir: P) -> Result<(), std::io::Error> {
705 let dir = dir.as_ref();
706
707 fs::create_dir_all(dir)?;
708 let mut code_file = File::create(dir.join(&self.file_name))?;
709 code_file.write_all(self.code.as_bytes())?;
710 if let Some(cmakelists) = &self.cmakelists {
711 let mut cmakelists_file = File::create(dir.join("CMakeLists.txt"))?;
712 cmakelists_file.write_all(cmakelists.as_bytes())?;
713 }
714 let mut schema_file = File::create(dir.join("schema.hecate.yaml"))?;
715 schema_file.write_all(self.schema.as_bytes())?;
716
717 eprintln!("Wrote project files to {}.", dir.display());
718 Ok(())
719 }
720}
721
722impl InputSchema {
723 pub fn validate(&self) -> Result<(), SchemaValidationError> {
724 let Solve {
725 equations,
726 mesh,
727 element: _,
728 time: _,
729 time_step: _,
730 dimension: _,
731 } = &self.solve;
732
733 self.meshes
734 .contains_key(mesh)
735 .then_some(())
736 .ok_or(SchemaValidationError::MeshNotFound(mesh.to_string()))?;
737 let missing_eqs = equations
738 .iter()
739 .filter_map(|e| (!self.equations.contains_key(e)).then_some(e))
740 .collect_vec();
741
742 for unknown in self.unknowns.values() {
743 unknown.visit_symbols(|f_name| -> Result<(), SchemaValidationError> {
744 if !self.functions.contains_key(f_name) {
745 return Err(SchemaValidationError::FunctionNotFound(f_name.to_string()));
746 }
747 Ok(())
748 })?;
749 }
750
751 if missing_eqs.len() > 0 {
752 return Err(SchemaValidationError::EquationNotFound(
753 missing_eqs.into_iter().join(", "),
754 ));
755 }
756
757 Ok(())
761 }
762
763 pub fn generate_sources(&self) -> Result<CodeGenRes, CodeGenError> {
764 let code = self.generate_cpp_sources()?;
765 let mut context = tera::Context::new();
766 context.insert("debug", &self.gen_conf.debug);
767 let cmakelists = Some(Tera::one_off(
768 include_str!("./input_schema/deal.ii/CMakeLists.txt"),
769 &context,
770 false,
771 )?);
772 let schema = serde_yaml::to_string(self)?;
773
774 Ok(CodeGenRes {
775 schema,
776 code,
777 file_name: "main.cpp".to_string(),
778 cmakelists,
779 })
780 }
781
782 pub fn generate_cpp_sources(&self) -> Result<String, CodeGenError> {
783 self.validate()?;
784 let gen_conf = &self.gen_conf;
785 let factory = deal_ii_factory();
786 let mut blocks = BuildingBlockCollector::new(&factory, gen_conf);
787 blocks.newline();
793
794 let Solve {
795 equations,
796 mesh,
797 element,
798 time,
799 time_step,
800 dimension,
801 } = &self.solve;
802
803 let mesh = &self.meshes[mesh];
804
805 let equations = equations
806 .iter()
807 .map(|e| self.equations[e].simplify_with_dimension(*dimension))
808 .collect_vec();
809
810 let unknowns = self.unknowns.keys().map(|s| &s[..]).collect_vec();
811 let mut knowns = Vec::with_capacity(self.functions.len());
812 let mut substitutions = Vec::with_capacity(self.functions.len() + self.unknowns.len());
813
814 let mut used_functions: IndexMap<&str, &FunctionDef> =
816 IndexMap::with_capacity(self.functions.len());
817
818 for (name, f) in &self.functions {
819 if equations.iter().any(|eq| eq.has(&Symbol::new(name)))
820 || self.unknowns.iter().any(|(_, u)| u.has_symbol(name))
821 {
822 used_functions.insert(name, f);
823 }
824 }
825
826 let mut functions: HashMap<&str, String> = HashMap::with_capacity(used_functions.len());
827 for (name, f) in used_functions.into_iter() {
828 let function = format!("fn_{name}");
829 blocks.create(&function, Block::Function(f))?;
830 functions.insert(name, function);
831 }
832
833 for (u, _) in &self.unknowns {
836 let symbol = symbol!(u);
837
838 if equations.iter().any(|e| e.get_ref().has(symbol)) {
839 substitutions.push([symbol.clone_box(), Box::new(Func::new(u, []))]);
840 }
841 }
842
843 for (f, _) in &self.functions {
844 let symbol = symbol!(f);
845
846 if equations.iter().any(|e| e.get_ref().has(symbol)) {
847 knowns.push(&f[..]);
848 substitutions.push([symbol.clone_box(), Box::new(Func::new(f, []))]);
849 }
850 }
851
852 let equations = equations
853 .iter()
854 .map(|e| e.subs(&substitutions).as_eq().unwrap())
855 .collect_vec();
856
857 let system = System::new(unknowns, knowns, equations.iter())
858 .to_first_order_in_time()
859 .time_discretized()
860 .simplified()?
861 .matrixify()
862 .to_crank_nikolson()
863 .to_constant_mesh()
864 .simplify();
865
866 debug!("System:\n{system}");
868
869 let mesh = blocks.insert("mesh", factory.mesh("mesh", mesh.get_ref(), gen_conf)?)?;
870
871 let element = blocks.insert(
872 "element",
873 factory.finite_element("element", element, gen_conf)?,
874 )?;
875
876 let dof_handler = blocks.insert(
877 "dof_handler",
878 factory.dof_handler("dof_handler", &DofHandlerConfig { mesh, element }, gen_conf)?,
879 )?;
880
881 let vector_config = VectorConfig {
882 dof_handler,
883 is_unknown: false,
884 };
885 let unknown_config = VectorConfig {
886 dof_handler,
887 is_unknown: true,
888 };
889
890 let mut vectors: HashMap<&dyn Expr, String> = HashMap::with_capacity(system.num_vectors());
891
892 for (vector, is_unknown) in system.vectors() {
893 let vector_cpp = vector.to_cpp();
894 blocks.insert(
895 &vector_cpp,
896 factory.vector(
897 &vector_cpp,
898 if is_unknown {
899 &unknown_config
900 } else {
901 &vector_config
902 },
903 gen_conf,
904 )?,
905 )?;
906 vectors.insert(vector, vector_cpp.clone());
907 }
908
909 let sparsity_pattern = blocks.insert(
910 "sparsity_pattern",
911 factory.sparsity_pattern(
912 "sparsity_pattern",
913 &SparsityPatternConfig { dof_handler },
914 gen_conf,
915 )?,
916 )?;
917
918 let mut unknowns_matrices: HashMap<&dyn Expr, String> =
919 HashMap::with_capacity(system.unknowns.len());
920
921 let matrix_config = MatrixConfig {
922 sparsity_pattern,
923 dof_handler,
924 };
925 let rhs = blocks.create("rhs", Block::Vector(&vector_config))?;
926 let mut unknown_solvers: HashMap<&dyn Expr, String> = HashMap::new();
927 for unknown in &system.unknowns {
928 let unknown = unknown.get_ref();
929 let unknown_cpp = unknown.to_cpp();
930 let mat_name = format!("matrix_{unknown_cpp}");
931 let unknown_vec = &vectors[&unknown];
932
933 unknowns_matrices.insert(
934 unknown,
935 blocks
936 .create(&mat_name, Block::Matrix(&matrix_config))?
937 .to_string(),
938 );
939
940 unknown_solvers.insert(
941 unknown,
942 blocks
943 .create(
944 &format!("solve_{unknown_cpp}"),
945 Block::SolveUnknown(&SolveUnknownConfig {
946 dof_handler,
947 rhs,
948 unknown_vec,
949 unknown_mat: &mat_name,
950 }),
951 )?
952 .to_string(),
953 );
954 }
955
956 let _laplace_mat = blocks.insert(
957 "laplace_mat",
958 factory.shape_matrix(
959 "laplace_mat",
960 &ShapeMatrixConfig {
961 kind: ShapeMatrix::Laplace,
962 dof_handler,
963 element,
964 matrix_config: &matrix_config,
965 },
966 gen_conf,
967 )?,
968 )?;
969 let _mass_mat = blocks.insert(
970 "mass_mat",
971 factory.shape_matrix(
972 "mass_mat",
973 &ShapeMatrixConfig {
974 kind: ShapeMatrix::Mass,
975 dof_handler,
976 element,
977 matrix_config: &matrix_config,
978 },
979 gen_conf,
980 )?,
981 )?;
982
983 let mut solved_unknowns: HashSet<&dyn Expr> = HashSet::new();
984 let vectors: &Vec<_> = &system.vectors().map(|(v, _is_unknown)| v).collect();
985 let matrixes: &Vec<_> = &system.matrixes().collect();
986
987 let mut constant_functions: HashSet<String> = HashSet::new();
989
990 for (i, equation) in system.eqs_in_solving_order().enumerate() {
992 for unknown in system.equation_lhs_unknowns(equation) {
994 if solved_unknowns.contains(&unknown) {
996 continue;
997 }
998
999 blocks.create(
1001 &format!("equation_{i}"),
1002 Block::EquationSetup(&EquationSetupConfig {
1003 equation,
1004 unknown,
1005 vectors,
1006 matrixes,
1007 }),
1008 )?;
1009
1010 let unknown_cpp = unknown.to_cpp();
1013 let mat_name = format!("matrix_{unknown_cpp}");
1014 let unknown_config = if let Some(captures) = UNKNOWN_DT_RE.captures(&unknown_cpp) {
1015 self.unknowns
1016 .get(&captures[1])
1017 .ok_or_else(|| CodeGenError::UnknownUnknown(captures[1].to_string()))?
1018 .derivative
1019 .as_ref()
1020 .ok_or_else(|| CodeGenError::MissingDerivative(captures[1].to_string()))?
1021 } else {
1022 self.unknowns
1023 .get(&unknown_cpp)
1024 .or_else(|| self.unknowns.get(&unknown_cpp.to_uppercase()))
1025 .ok_or_else(|| CodeGenError::UnknownUnknown(unknown_cpp.to_string()))?
1026 };
1027
1028 let initial = &unknown_config.initial;
1030 blocks.create(
1031 &format!("initial_condition_{unknown_cpp}"),
1032 Block::InitialCondition(&InitialConditionConfig {
1033 dof_handler,
1034 element,
1035 function: &initial.to_function_name(),
1036 target: &format!("{unknown_cpp}_prev"),
1037 }),
1038 )?;
1039
1040 unknown_config.visit_constants(
1042 |ConstantFunction { ref name, value }| -> Result<(), CodeGenError> {
1043 if constant_functions.contains(name) {
1044 return Ok(());
1045 }
1046 constant_functions.insert(name.to_string());
1047 blocks.create(
1048 name,
1049 Block::Function(&FunctionDef::Expr(
1050 value
1051 .parse()
1052 .map_err(|e| CodeGenError::InvalidFunctionExpression(e))?,
1053 )),
1054 )?;
1055 Ok(())
1056 },
1057 )?;
1058
1059 let function = &unknown_config
1061 .boundary
1062 .to_function_name();
1066
1067 blocks.newline();
1069 blocks.create(
1070 &format!("apply_boundary_condition_{unknown_cpp}"),
1071 Block::AppyBoundaryCondition(&ApplyBoundaryConditionConfig {
1072 dof_handler,
1073 function,
1074 matrix: &mat_name,
1075 solution: &unknown_cpp,
1076 rhs,
1077 }),
1078 )?;
1079
1080 blocks.newline();
1082 blocks.call(
1083 unknown_solvers
1084 .get(&unknown)
1085 .ok_or_else(|| CodeGenError::MissingSolver(unknown.str()))?,
1086 &[],
1087 )?;
1088 blocks.newline();
1089 solved_unknowns.insert(unknown);
1090 }
1092 }
1093 blocks.call("output_results", &[])?;
1097 for unknown in &system.unknowns {
1098 let unknown = unknown.to_cpp();
1099 blocks.add_vector_output(&unknown)?;
1100 }
1101
1102 blocks.newline();
1104 blocks.comment("Swap new values with previous values for the next step");
1105 for unknown in &system.unknowns {
1106 let unknown = unknown.to_cpp();
1107 let unknown_prev = format!("{unknown}_prev");
1108 blocks.call("swap", &[&unknown, &unknown_prev])?;
1109 }
1110
1111 let mut context: tera::Context = blocks.collect(dof_handler, sparsity_pattern)?.into();
1113 context.insert("time_start", &time.start.seconds());
1114 context.insert("time_end", &time.end.seconds());
1115 context.insert("time_step", &time_step.seconds());
1116 context.insert("dimension", &dimension);
1117 context.insert("mpi", &self.gen_conf.mpi);
1118
1119 let parameters: indexmap::IndexMap<&String, String> = self
1120 .parameters
1121 .iter()
1122 .map(|(name, value)| (name, format!("{:e}", value.si_value())))
1123 .collect();
1124 context.insert("parameters", ¶meters);
1125
1126 Ok(TEMPLATES.render("cpp_source", &context)?)
1128 }
1129}
1130
/// Matches generated names starting with `dt_`; capture group 1 is the rest of
/// the name. Code generation uses it to find the schema unknown whose
/// `derivative` entry applies.
static UNKNOWN_DT_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new("^dt_(.+)").expect("valid regex"));
1133
/// Accumulates named building blocks in insertion order and merges them into a
/// single block at the end (`collect`).
#[derive(Clone)]
struct BuildingBlockCollector<'fa> {
    // Registered blocks, keyed by name, in insertion order.
    blocks: IndexMap<String, BuildingBlock>,
    // Extra names reserved by registered blocks; a new block may not reuse them.
    additional_names: HashSet<String>,
    // Factory used to build blocks.
    factory: &'fa BuildingBlockFactory<'fa>,
    // Generation switches handed to the factory.
    gen_conf: &'fa GenConfig,
}
1141
1142impl From<BuildingBlock> for tera::Context {
1143 fn from(block: BuildingBlock) -> Self {
1144 let mut context = tera::Context::new();
1145 let BuildingBlock {
1146 includes,
1147 data,
1148 setup,
1149 additional_names: _,
1150 constructor,
1151 methods_defs,
1152 methods_impls,
1153 additional_vectors: _,
1154 additional_matrixes: _,
1155 main,
1156 main_setup,
1157 global,
1158 output,
1159 } = block;
1160 context.insert(
1161 "includes",
1162 &includes
1163 .into_iter()
1164 .map(|s| format!("#include <{s}>"))
1165 .join("\n"),
1166 );
1167
1168 macro_rules! insert_lines {
1169 ($key:expr, $source:ident, $indent:literal) => {
1170 context.insert(
1171 $key,
1172 &$source
1173 .into_iter()
1174 .map(|s| format!("{}{s}", " ".repeat($indent)))
1175 .join("\n"),
1176 );
1177 };
1178 }
1179
1180 context.insert("data", &to_cpp_lines(" ", data));
1181 context.insert("setup", &to_cpp_lines(" ", setup));
1182 context.insert("constructors", &constructor);
1183 context.insert("methods_defs", &to_cpp_lines(" ", methods_defs));
1184 context.insert("methods_impls", &methods_impls.into_iter().join("\n"));
1185 insert_lines!("main_setup", main_setup, 2);
1186 insert_lines!("main", main, 4);
1187 insert_lines!("output", output, 2);
1188
1189 context.insert("global", &global.join("\n\n\n"));
1190 context
1191 }
1192}
1193
/// Renders `lines` as C++ statements, one per line, joined with newlines.
///
/// Every line is prefixed with `prefix`. A `;` is appended unless the line,
/// once trimmed, starts with `//` (a comment).
pub fn to_cpp_lines<Lines: IntoIterator<Item = String>>(prefix: &str, lines: Lines) -> String {
    let rendered: Vec<String> = lines
        .into_iter()
        .map(|line| {
            let terminator = if line.trim().starts_with("//") {
                ""
            } else {
                ";"
            };
            format!("{prefix}{line}{terminator}")
        })
        .collect();
    rendered.join("\n")
}
1206
1207impl<'fa> BuildingBlockCollector<'fa> {
1208 fn new(factory: &'fa BuildingBlockFactory<'fa>, gen_conf: &'fa GenConfig) -> Self {
1209 BuildingBlockCollector {
1210 blocks: IndexMap::new(),
1211 additional_names: HashSet::new(),
1212 factory,
1213 gen_conf,
1214 }
1215 }
1216
1217 fn insert<'a>(
1218 &mut self,
1219 name: &'a str,
1220 block: BuildingBlock,
1221 ) -> Result<&'a str, BuildingBlockError> {
1222 if self.blocks.contains_key(name) {
1223 dbg!(&block);
1224 Err(BuildingBlockError::BlockAlreadyExists(name.to_string()))?
1225 }
1226 if self.additional_names.contains(name) {
1227 Err(BuildingBlockError::NameAlreadyExists(name.to_string()))?
1228 }
1229 self.additional_names
1230 .extend(block.additional_names.iter().cloned());
1231 self.blocks.insert(name.to_string(), block);
1232 return Ok(name);
1233 }
1234 fn collect(mut self, dof_handler: &str, sparsity_pattern: &str) -> BlockRes {
1235 let mut res = BuildingBlock::new();
1236
1237 let mut additional_blocks: IndexMap<String, BuildingBlock> = IndexMap::new();
1238
1239 let tmp_vector_config = VectorConfig {
1240 dof_handler,
1241 is_unknown: true,
1242 };
1243 let tmp_matrix_config = MatrixConfig {
1244 sparsity_pattern,
1245 dof_handler,
1246 };
1247 let gen_conf = self.gen_conf;
1248 for (
1249 _,
1250 BuildingBlock {
1251 additional_vectors,
1252 additional_matrixes,
1253 ..
1254 },
1255 ) in &self.blocks
1256 {
1257 for vector in additional_vectors {
1258 if additional_blocks.contains_key(vector) {
1259 continue;
1260 }
1261 let block = self.factory.vector(vector, &tmp_vector_config, gen_conf)?;
1262 additional_blocks.insert(vector.to_string(), block);
1263 }
1264 for matrix in additional_matrixes {
1265 if additional_blocks.contains_key(matrix) {
1266 continue;
1267 }
1268 let block = self.factory.matrix(matrix, &tmp_matrix_config, gen_conf)?;
1269 additional_blocks.insert(matrix.to_string(), block);
1270 }
1271 }
1272
1273 self.blocks.extend(additional_blocks);
1274
1275 for (
1276 _,
1277 BuildingBlock {
1278 includes,
1279 data,
1280 setup,
1281 additional_names: _,
1282 constructor,
1283 methods_defs,
1284 methods_impls,
1285 main,
1286 main_setup,
1287 additional_vectors: _,
1288 additional_matrixes: _,
1289 global,
1290 output,
1291 },
1292 ) in self.blocks
1293 {
1294 res.includes.extend(includes);
1295 res.setup.extend(setup);
1296 res.data.extend(data);
1297 res.constructor.extend(constructor);
1298 res.methods_defs.extend(methods_defs);
1299 res.methods_impls.extend(methods_impls);
1300 res.main_setup.extend(main_setup);
1301 res.main.extend(main);
1302 res.global.extend(global);
1303 res.output.extend(output);
1304 }
1305
1306 Ok(res)
1307 }
1308
1309 fn create<'na>(
1310 &mut self,
1311 name: &'na str,
1312 block: Block<'_>,
1313 ) -> Result<&'na str, BuildingBlockError> {
1314 let gen_conf = self.gen_conf;
1315 self.insert(
1316 name,
1317 match block {
1318 Block::Matrix(config) => self.factory.matrix(name, config, gen_conf)?,
1319 Block::Vector(vector_config) => {
1320 self.factory.vector(name, vector_config, gen_conf)?
1321 }
1322 Block::SolveUnknown(solve_unknown_config) => {
1323 self.factory
1324 .solve_unknown(name, solve_unknown_config, gen_conf)?
1325 }
1326 Block::EquationSetup(equation_setup_config) => {
1327 self.factory
1328 .equation_setup(name, equation_setup_config, gen_conf)?
1329 }
1330 Block::Parameter(value) => self.factory.parameter(name, &value, gen_conf)?,
1331 Block::Function(function) => self.factory.function(name, function, gen_conf)?,
1332 Block::AppyBoundaryCondition(config) => self
1333 .factory
1334 .apply_boundary_condition(name, config, gen_conf)?,
1335 Block::InitialCondition(config) => {
1336 self.factory.initial_condition(name, config, gen_conf)?
1337 }
1338 },
1339 )?;
1340 Ok(name)
1341 }
1342
1343 fn call(&mut self, name: &str, args: &[&str]) -> Result<(), BuildingBlockError> {
1344 let block = self.factory.call(name, args)?;
1345 let name = format!("call#{}", self.blocks.len());
1346 self.blocks.insert(name, block);
1347
1348 Ok(())
1349 }
1350
1351 fn newline(&mut self) {
1352 let name = format!("newline#{}", self.blocks.len());
1353 self.blocks.insert(name, self.factory.newline());
1354 }
1355
1356 fn comment(&mut self, content: &str) {
1357 let name = format!("comment#{}", self.blocks.len());
1358 self.blocks.insert(name, self.factory.comment(content));
1359 }
1360
1361 fn add_vector_output(&mut self, unknown: &str) -> Result<(), BuildingBlockError> {
1362 let name = format!("add_vector_output_{unknown}");
1363 self.insert(
1364 &name,
1365 self.factory
1366 .add_vector_output(&name, unknown, self.gen_conf)?,
1367 )?;
1368 Ok(())
1369 }
1370}