hecate/codegen/
input_schema.rs

1use crate::StdError;
2use crate::codegen::building_block::{ApplyBoundaryConditionConfig, InitialConditionConfig};
3use crate::codegen::input_schema::quantity::{NO_REF_QUANTITY_PATTERN, QuantityEnum};
4use crate::codegen::input_schema::unit::format_unit;
5use derive_more::{Deref, DerefMut, From, FromStr, IntoIterator};
6use indexmap::IndexMap as BaseIndexMap;
7use itertools::Itertools;
8use lazy_static::lazy_static;
9use log::debug;
10use regex::Regex;
11use schemars::{JsonSchema, json_schema};
12use serde::de::Error;
13use serde::de::Visitor;
14use serde::ser::SerializeMap;
15use std::collections::{HashMap, HashSet};
16use std::fmt::Debug;
17use std::fs::{self, File};
18use std::hash::Hash;
19use std::io::Write;
20use std::path::Path;
21use std::str::FromStr;
22use std::sync::LazyLock;
23use symrs::ops::ParseExprError;
24use symrs::system::SystemError;
/// Types that expose a raw string representation of themselves.
pub trait RawRepr {
    /// Returns the raw string representation of `self`.
    fn raw(&self) -> &str;
}
28
29pub mod mesh;
30pub mod quantity;
31pub mod range;
32mod reference;
33mod unit;
34
35use quantity::{Length, RANGE_PATTERN, Time};
36
37use mesh::MeshEnum;
38use range::Range;
39use serde::{Deserialize, Serialize};
40use tera::Tera;
41use thiserror::Error;
42
43use crate::codegen::building_block::deal_ii_factory;
44use symrs::{Equation, Expr, Func, Symbol, System, symbol};
45
46#[derive(Deref, DerefMut, Deserialize, Serialize, Clone, Debug, IntoIterator, From)]
47#[from(forward)]
48pub struct IndexMap<K, V>(#[into_iterator(owned, ref, ref_mut)] BaseIndexMap<K, V>)
49where
50    K: Eq + Hash;
51
52impl<K, V> IndexMap<K, V>
53where
54    K: Eq + Hash,
55{
56    pub fn new() -> Self {
57        IndexMap(BaseIndexMap::new())
58    }
59    pub fn with_capacity(capacity: usize) -> Self {
60        IndexMap(BaseIndexMap::with_capacity(capacity))
61    }
62}
63
64impl<K, V> JsonSchema for IndexMap<K, V>
65where
66    K: Eq + Hash + JsonSchema,
67    V: JsonSchema,
68{
69    fn schema_name() -> std::borrow::Cow<'static, str> {
70        format!("Map<{}, {}>", K::schema_name(), V::schema_name()).into()
71    }
72
73    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
74        HashMap::<K, V>::json_schema(generator)
75    }
76}
77
78use super::building_block::{
79    Block, BlockRes, BuildingBlockFactory, EquationSetupConfig, ShapeMatrix, ShapeMatrixConfig,
80    SolveUnknownConfig,
81};
82use super::{
83    BuildingBlock,
84    building_block::{
85        BuildingBlockError, DofHandlerConfig, MatrixConfig, SparsityPatternConfig, VectorConfig,
86    },
87};
88
89/// # Finite Element
90/// The finite element to use for the mesh.
91#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, Default)]
92pub enum FiniteElement {
93    Q1,
94    #[default]
95    Q2,
96    Q3,
97}
98
99/// # Solve
100/// The equation(s) to solve and the mesh to use.
101#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
102pub struct Solve {
103    /// # Equations
104    /// The equation(s) to solve
105    pub equations: Vec<String>,
106
107    ///# Mesh
108    /// The mesh to use
109    pub mesh: String,
110
111    pub element: FiniteElement,
112
113    /// # Dimension
114    /// The dimension of the problem
115    /// Possible values: 1, 2, 3
116    #[serde(default = "default_dimension")]
117    pub dimension: usize,
118
119    /// # Time
120    /// The time range to solve.
121    #[serde(default = "default_solving_range")]
122    pub time: Range<Time>,
123
124    /// # Time Step
125    /// The time step to use.
126    pub time_step: Time,
127}
128
/// Serde default for `Solve::dimension`: problems are two-dimensional unless stated otherwise.
fn default_dimension() -> usize {
    const DEFAULT_DIMENSION: usize = 2;
    DEFAULT_DIMENSION
}
132
133fn default_solving_range() -> Range<Time> {
134    "0 .. 5s".parse().unwrap()
135}
136
137/// # Generation Configuration
138/// The configuration for the generation of the code.
139#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, Default)]
140pub struct GenConfig {
141    /// # MPI
142    /// Whether to generate MPI code.
143    #[serde(default)]
144    pub mpi: bool,
145
146    /// # Matrix Free
147    /// Whether to generate matrix free code.
148    #[serde(default)]
149    pub matrix_free: bool,
150
151    /// # Debug
152    /// Whether to generate debug code.
153    #[serde(default)]
154    pub debug: bool,
155}
156
157#[derive(Clone, Debug, JsonSchema, Serialize, Deserialize)]
158#[serde(rename_all = "snake_case")]
159pub enum QuantityKind {
160    Length,
161    Time,
162    Speed,
163    DiffusionCoefficient,
164    Custom,
165}
166
167impl QuantityKind {
168    fn default_unit(&self) -> Option<&'static str> {
169        match self {
170            QuantityKind::Length => Some("m"),
171            QuantityKind::Time => Some("s"),
172            QuantityKind::Speed => Some("m/s"),
173            QuantityKind::DiffusionCoefficient => Some("m²/s"),
174            QuantityKind::Custom => None,
175        }
176    }
177}
178
179#[derive(Clone, Debug)]
180pub struct Parameter {
181    r#type: QuantityKind,
182    value: f64,
183    unit: Option<String>,
184}
185
186impl Parameter {
187    pub fn value_string(&self) -> String {
188        if let Some(unit) = &self.unit {
189            format!("{} {unit}", self.value)
190        } else {
191            self.value.to_string()
192        }
193    }
194
195    // pub fn si_value(&self) -> f64 {
196    //     let s = self.value_string();
197    //     match self.r#type {
198    //         QuantityKind::Length => s.parse::<Length>(),
199    //         QuantityKind::Time => self.value,
200    //         QuantityKind::Speed => self.value,
201    //         QuantityKind::DiffusionCoefficient => self.value,
202    //         QuantityKind::Custom => self.value,
203    //
204    //     }
205    //
206    //
207    // }
208}
209
210impl Serialize for Parameter {
211    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
212    where
213        S: serde::Serializer,
214    {
215        let mut map = serializer.serialize_map(Some(2))?;
216        map.serialize_entry("type", &self.r#type)?;
217        map.serialize_entry(
218            "value",
219            &(if let Some(unit) = &self.unit {
220                format!("{} {unit}", self.value)
221            } else {
222                self.value.to_string()
223            }),
224        )?;
225        map.end()
226    }
227}
228
229impl JsonSchema for Parameter {
230    fn schema_name() -> std::borrow::Cow<'static, str> {
231        "Parameter".into()
232    }
233
234    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
235        json_schema!({
236            "type": "object",
237            "title": "Parameter",
238            "required": ["type", "value"],
239            "description": "A parameter is a quantity with a value and a unit. If no unit is specified, the default unit will be used.",
240            "properties": {
241                "type": QuantityKind::json_schema(generator),
242                "value": {
243                    "oneOf": [
244                        {
245                            "type": "string",
246                            "pattern": NO_REF_QUANTITY_PATTERN,
247                        },
248                        {
249                            "type": "number"
250                        }
251                    ]
252                }
253            }
254        })
255    }
256}
257
258struct DeserializeParameterVisitor;
259
260static NO_REF_QUANTITY_RE: LazyLock<Regex> =
261    LazyLock::new(|| Regex::new(NO_REF_QUANTITY_PATTERN).unwrap());
262
263impl<'de> Visitor<'de> for DeserializeParameterVisitor {
264    type Value = Parameter;
265
266    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
267        formatter.write_str("a parameter")
268    }
269
270    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
271    where
272        A: serde::de::MapAccess<'de>,
273    {
274        let mut r#type: Option<QuantityKind> = None;
275        let mut value: Option<&str> = None;
276        while let Some(key) = map.next_key()? {
277            match key {
278                "type" => r#type = Some(map.next_value()?),
279                "value" => value = Some(map.next_value()?),
280                field => Err(A::Error::unknown_field(field, &["type", "value"]))?,
281            }
282        }
283        let r#type = r#type.ok_or_else(|| A::Error::missing_field("type"))?;
284        let value = value.ok_or_else(|| A::Error::missing_field("value"))?;
285        let captures = NO_REF_QUANTITY_RE
286            .captures(value)
287            .ok_or_else(|| A::Error::invalid_value(serde::de::Unexpected::Str(value), &self))?;
288        let value = captures[1].trim().parse().map_err(|_| {
289            A::Error::invalid_value(
290                serde::de::Unexpected::Str(&captures[1]),
291                &"a string that can be parsed as a float",
292            )
293        })?;
294        let unit = captures[2].trim();
295        let unit = if unit.is_empty() {
296            r#type.default_unit().map(|u| u.to_owned())
297        } else {
298            Some(format_unit(&captures[2]).map_err(|e| A::Error::custom(e.to_string()))?)
299        };
300
301        Ok(Parameter {
302            r#type,
303            value,
304            unit,
305        })
306    }
307}
308
309impl<'de> Deserialize<'de> for Parameter {
310    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
311    where
312        D: serde::Deserializer<'de>,
313    {
314        deserializer.deserialize_map(DeserializeParameterVisitor)
315    }
316}
317
318/// # Hecate Input Schema
319/// The input schema for Hecate.
320#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
321pub struct InputSchema {
322    #[serde(rename = "generation")]
323    #[serde(default)]
324    pub gen_conf: GenConfig,
325
326    /// # Meshes
327    /// The available meshes.
328    pub meshes: IndexMap<String, MeshEnum>,
329
330    /// # Equations
331    /// The available equations.
332    pub equations: IndexMap<String, Equation>,
333
334    /// # Parameters
335    /// The available parameters.
336    pub parameters: IndexMap<String, QuantityEnum>,
337
338    /// # Unknowns
339    /// The available unknowns.
340    pub unknowns: IndexMap<String, Unknown>,
341
342    /// # Functions
343    /// The available functions.
344    /// They can either be simple function expression, or a list of function expression with conditions.
345    pub functions: IndexMap<String, FunctionDef>,
346
347    pub solve: Solve,
348}
349
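// TODO: check whether `InputSchema` is `Send`/`Sync` before adding the manual impls below.
// It holds `Box<dyn Expr>` values (through `FunctionDef`/`FunctionExpression`), so such impls
// would only be sound if every `Expr` implementation is thread-safe.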
351// unsafe impl Send for InputSchema {}
352// unsafe impl Sync for InputSchema {}
353
354impl InputSchema {
355    pub fn from_yaml(yaml: &str) -> Result<Self, serde_yaml::Error> {
356        serde_yaml::from_str(yaml)
357    }
358}
359
360#[derive(Clone, Debug, Serialize, JsonSchema)]
361#[serde(untagged)]
362pub enum Condition<T> {
363    Value(T),
364    Range(Range<T>),
365}
366
367struct ConditionVisitor<T> {
368    _marker: std::marker::PhantomData<T>,
369}
370
371impl<'de, T> Visitor<'de> for ConditionVisitor<T>
372where
373    T: FromStr,
374    T::Err: StdError + 'static,
375{
376    type Value = Condition<T>;
377
378    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
379        formatter.write_str("a value or a range")
380    }
381
382    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
383    where
384        E: serde::de::Error,
385    {
386        if v.contains("..") {
387            Ok(Condition::Range(v.parse().map_err(|e| E::custom(e))?))
388        } else {
389            Ok(Condition::Value(v.parse().map_err(|e| E::custom(e))?))
390        }
391    }
392}
393
394impl<'de, T> Deserialize<'de> for Condition<T>
395where
396    T: FromStr,
397    T::Err: StdError + 'static,
398{
399    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
400    where
401        D: serde::Deserializer<'de>,
402    {
403        deserializer.deserialize_str(ConditionVisitor {
404            _marker: std::marker::PhantomData,
405        })
406    }
407}
408
409/// # Function Expression
410/// A function expression.
411/// Available variables are : t, x, y, z.
412/// Math functions such as cosinus or exponentials are available.
413/// They can be called through their cpp names like log for the logarithm.
414#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, FromStr, DerefMut, Deref)]
415pub struct FunctionExpression(Box<dyn Expr>);
416
417#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
418pub struct ConditionedFunctionExpression {
419    pub expr: FunctionExpression,
420    /// # Time Condition
421    /// The time condition for which the function expression is valid.
422    /// It can be none, a value or a range.
423    pub t: Option<Condition<Time>>,
424
425    /// # X Condition
426    /// The x condition for which the function expression is valid.
427    /// It can be none, a value or a range.
428    pub x: Option<Condition<Length>>,
429
430    /// # Y Condition
431    /// The y condition for which the function expression is valid.
432    /// It can be none, a value or a range.
433    pub y: Option<Condition<Length>>,
434
435    /// # Z Condition
436    /// The z condition for which the function expression is valid.
437    /// It can be none, a value or a range.
438    pub z: Option<Condition<Length>>,
439}
440
441/// # Function Definition
442/// The definition of a function.
443/// This can be an expression or a conditioned function.
444#[derive(Clone, Debug, Serialize, JsonSchema)]
445#[serde(untagged)]
446pub enum FunctionDef {
447    Expr(FunctionExpression),
448    /// # Conditioned Function
449    /// A function defined as list of function expression with conditions (time range, space range, etc...).
450    /// The function expressions are checked in order. Therefore, in case of an overlap, the first one will be used.
451    /// If no function expressions without conditions are specified, a default value of 0 will be assumed.
452    Conditioned(Vec<ConditionedFunctionExpression>),
453}
454
455
456struct FunctionDefVisitor;
457
458impl<'de> Visitor<'de> for FunctionDefVisitor {
459    type Value = FunctionDef;
460
461    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
462        formatter.write_str("a function expression or a conditioned function")
463    }
464
465    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
466    where
467        E: serde::de::Error,
468    {
469        Ok(FunctionDef::Expr(v.parse().map_err(|e| E::custom(e))?))
470    }
471
472    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
473    where
474        E: serde::de::Error,
475    {
476        Ok(FunctionDef::Expr(FunctionExpression(
477            symrs::Integer::new_box(v as isize),
478        )))
479    }
480
481    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
482    where
483        E: serde::de::Error,
484    {
485        Ok(FunctionDef::Expr(FunctionExpression(
486            symrs::Integer::new_box(v as isize),
487        )))
488    }
489
490    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
491    where
492        A: serde::de::SeqAccess<'de>,
493    {
494        let mut res = Vec::new();
495
496        while let Some(item) = seq.next_element::<ConditionedFunctionExpression>()? {
497            res.push(item)
498        }
499
500        Ok(FunctionDef::Conditioned(res))
501    }
502}
503
504impl<'de> Deserialize<'de> for FunctionDef {
505    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
506    where
507        D: serde::Deserializer<'de>,
508    {
509        Ok(deserializer.deserialize_any(FunctionDefVisitor)?)
510    }
511}
512
513#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
514#[serde(untagged)]
515pub enum UnknownProperty {
516    Constant(i64),
517    ConstantFloat(f64),
518    /// Function name referencing one of the globally defined functions.
519    FunctionName(String),
520}
521
522impl Default for UnknownProperty {
523    fn default() -> Self {
524        UnknownProperty::Constant(0)
525    }
526}
527impl UnknownProperty {
528    fn to_function_name(&self) -> String {
529        match self {
530            UnknownProperty::Constant(i) => format!("fn_{}", i.to_string().replace("-", "neg")),
531            UnknownProperty::ConstantFloat(f) => {
532                format!(
533                    "fn_{}",
534                    f.to_string().replace("-", "neg").replace(".", "dot")
535                )
536            }
537            UnknownProperty::FunctionName(s) => format!("fn_{s}"),
538        }
539    }
540}
541
542/// # Unknown
543/// Represents an unknown to be solved in the PDE.
544#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
545pub struct Unknown {
546    /// # Initial Condition
547    /// The initial value of the unknown.
548    #[serde(default)]
549    pub initial: UnknownProperty,
550
551    /// # Boundary Condition
552    /// The boundary condition of the unknown.
553    #[serde(default)]
554    pub boundary: UnknownProperty,
555
556    /// # Time Derivative Conditions
557    /// The time derivative's conditions of the unknown.
558    /// The number of derivative specified should match the max time order of the equations - 1.
559    /// (ie. an equation of order 2 in time needs one derivative specified)
560    pub derivative: Option<Box<Unknown>>,
561}
562
/// A numeric constant taken from an `Unknown`'s conditions, paired with the name of the
/// function generated for it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConstantFunction {
    /// Generated function name, e.g. `fn_1` or `fn_neg0dot5`.
    pub name: String,
    /// The constant's value, as text.
    pub value: String,
}
567
568impl Unknown {
569    pub fn visit_symbols<E, F: Fn(&str) -> Result<(), E>>(&self, f: F) -> Result<(), E> {
570        if let UnknownProperty::FunctionName(initial) = &self.initial {
571            f(initial)?;
572        }
573        if let UnknownProperty::FunctionName(boundary) = &self.boundary {
574            f(boundary)?
575        }
576
577        if let Some(derivative) = &self.derivative {
578            derivative.visit_symbols(f)?;
579        }
580        Ok(())
581    }
582
583    pub fn has_symbol(&self, s: &str) -> bool {
584        if let UnknownProperty::FunctionName(initial) = &self.initial
585            && initial == s
586        {
587            true
588        } else if let UnknownProperty::FunctionName(boundary) = &self.boundary
589            && boundary == s
590        {
591            true
592        } else if let Some(derivative) = &self.derivative {
593            derivative.has_symbol(s)
594        } else {
595            false
596        }
597    }
598
599    pub fn visit_props<E: StdError, F: FnMut(&UnknownProperty) -> Result<(), E>>(
600        &self,
601        mut f: F,
602    ) -> Result<(), E> {
603        f(&self.initial)?;
604        // if let Some(boundary) = &self.boundary {
605        f(&self.boundary)?;
606        // }
607        if let Some(derivative) = &self.derivative {
608            derivative.visit_props(f)?;
609        }
610        Ok(())
611    }
612
613    pub fn visit_constants<E: StdError, F: FnMut(ConstantFunction) -> Result<(), E>>(
614        &self,
615        mut f: F,
616    ) -> Result<(), E> {
617        self.visit_props(|p| -> Result<(), E> {
618            match p {
619                UnknownProperty::Constant(i) => f(ConstantFunction {
620                    value: i.to_string(),
621                    name: p.to_function_name(),
622                }),
623                UnknownProperty::ConstantFloat(float) => f(ConstantFunction {
624                    value: float.to_string(),
625                    name: p.to_function_name(),
626                }),
627                _ => Ok(()),
628            }
629        })
630    }
631}
632
633#[derive(Error, Debug)]
634pub enum CodeGenError {
635    #[error("failed to construct a block")]
636    BlockError(#[from] BuildingBlockError),
637    #[error("templating error: {0}")]
638    TemplatingError(#[from] tera::Error),
639    #[error("missing solver for unknown: {0}")]
640    MissingSolver(String),
641    #[error("invalid yaml input schema")]
642    InvalidYaml(#[from] serde_yaml::Error),
643    #[error("schema validation failed")]
644    InvalidSchema(#[from] SchemaValidationError),
645    #[error("unknown missing from the schema '{0}'")]
646    UnknownUnknown(String),
647    #[error("missing boundary condition for unknown: {0}")]
648    MissingBoundary(String),
649    #[error("missing derivative for unknown: {0}")]
650    MissingDerivative(String),
651    #[error("invalid expression for function expression: {0}")]
652    InvalidFunctionExpression(ParseExprError),
653    #[error("failed to simplify system before code generation")]
654    SystemSimplificationFailed(#[from] SystemError),
655}
656
657lazy_static! {
658    pub static ref TEMPLATES: Tera = {
659        let mut tera = Tera::default();
660        match tera.add_raw_template(
661            "cpp_source",
662            &include_str!("./input_schema/cpp_source_template.cpp")
663                .replace("/*", "")
664                .replace("*/", ""),
665        ) {
666            Ok(t) => t,
667            Err(e) => {
668                println!("Parsing error(s): {e}");
669                std::process::exit(1);
670            }
671        };
672        match tera.add_raw_template(
673            "cmakelists",
674            &include_str!("./input_schema/deal.ii/CMakeLists.txt"),
675        ) {
676            Ok(t) => t,
677            Err(e) => {
678                println!("Parsing error(s): {e}");
679                std::process::exit(1);
680            }
681        }
682        tera
683    };
684}
685
686#[derive(Error, Debug)]
687pub enum SchemaValidationError {
688    #[error("mesh {0} not found")]
689    MeshNotFound(String),
690    #[error("equation(s) {0} not found")]
691    EquationNotFound(String),
692    #[error("function name not found: {0}")]
693    FunctionNotFound(String),
694}
695
/// Output of `InputSchema::generate_sources`: the generated project files, held in memory.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CodeGenRes {
    /// Generated source code, written to `file_name`.
    pub code: String,
    /// Text written to `schema.hecate.yaml` next to the sources.
    pub schema: String,
    /// File name, relative to the output directory, of the generated source file.
    pub file_name: String,
    /// Contents of `CMakeLists.txt`, if one should be written.
    pub cmakelists: Option<String>,
}

impl CodeGenRes {
    /// Writes the project files into `dir`, creating it (and any missing parents) first.
    ///
    /// Files written:
    /// - `file_name`, containing `code`;
    /// - `CMakeLists.txt`, only when `cmakelists` is set;
    /// - `schema.hecate.yaml`, containing `schema`.
    ///
    /// Existing files are truncated. A confirmation line is printed to stderr on success.
    ///
    /// # Errors
    /// Returns the first I/O error encountered. Files written before the error are left in place.
    pub fn write_to_dir<P: AsRef<Path>>(&self, dir: P) -> Result<(), std::io::Error> {
        let dir = dir.as_ref();
        let write_file = |name: &str, contents: &str| -> std::io::Result<()> {
            File::create(dir.join(name))?.write_all(contents.as_bytes())
        };

        fs::create_dir_all(dir)?;
        write_file(&self.file_name, &self.code)?;
        if let Some(cmakelists) = &self.cmakelists {
            write_file("CMakeLists.txt", cmakelists)?;
        }
        write_file("schema.hecate.yaml", &self.schema)?;

        eprintln!("Wrote project files to {}.", dir.display());
        Ok(())
    }
}
721
722impl InputSchema {
723    pub fn validate(&self) -> Result<(), SchemaValidationError> {
724        let Solve {
725            equations,
726            mesh,
727            element: _,
728            time: _,
729            time_step: _,
730            dimension: _,
731        } = &self.solve;
732
733        self.meshes
734            .contains_key(mesh)
735            .then_some(())
736            .ok_or(SchemaValidationError::MeshNotFound(mesh.to_string()))?;
737        let missing_eqs = equations
738            .iter()
739            .filter_map(|e| (!self.equations.contains_key(e)).then_some(e))
740            .collect_vec();
741
742        for unknown in self.unknowns.values() {
743            unknown.visit_symbols(|f_name| -> Result<(), SchemaValidationError> {
744                if !self.functions.contains_key(f_name) {
745                    return Err(SchemaValidationError::FunctionNotFound(f_name.to_string()));
746                }
747                Ok(())
748            })?;
749        }
750
751        if missing_eqs.len() > 0 {
752            return Err(SchemaValidationError::EquationNotFound(
753                missing_eqs.into_iter().join(", "),
754            ));
755        }
756
757        // TODO validate all needed symbols by the equations are present
758        // either as parameter of function
759
760        Ok(())
761    }
762
763    pub fn generate_sources(&self) -> Result<CodeGenRes, CodeGenError> {
764        let code = self.generate_cpp_sources()?;
765        let mut context = tera::Context::new();
766        context.insert("debug", &self.gen_conf.debug);
767        let cmakelists = Some(Tera::one_off(
768            include_str!("./input_schema/deal.ii/CMakeLists.txt"),
769            &context,
770            false,
771        )?);
772        let schema = serde_yaml::to_string(self)?;
773
774        Ok(CodeGenRes {
775            schema,
776            code,
777            file_name: "main.cpp".to_string(),
778            cmakelists,
779        })
780    }
781
782    pub fn generate_cpp_sources(&self) -> Result<String, CodeGenError> {
783        self.validate()?;
784        let gen_conf = &self.gen_conf;
785        let factory = deal_ii_factory();
786        let mut blocks = BuildingBlockCollector::new(&factory, gen_conf);
787        // blocks.newline();
788        // blocks.comment("Parameters");
789        // for (name, quantity) in &self.parameters {
790        //     blocks.create(name, Block::Parameter(quantity.si_value()))?;
791        // }
792        blocks.newline();
793
794        let Solve {
795            equations,
796            mesh,
797            element,
798            time,
799            time_step,
800            dimension,
801        } = &self.solve;
802
803        let mesh = &self.meshes[mesh];
804
805        let equations = equations
806            .iter()
807            .map(|e| self.equations[e].simplify_with_dimension(*dimension))
808            .collect_vec();
809
810        let unknowns = self.unknowns.keys().map(|s| &s[..]).collect_vec();
811        let mut knowns = Vec::with_capacity(self.functions.len());
812        let mut substitutions = Vec::with_capacity(self.functions.len() + self.unknowns.len());
813
814        // Identify functions used for the current problem
815        let mut used_functions: IndexMap<&str, &FunctionDef> =
816            IndexMap::with_capacity(self.functions.len());
817
818        for (name, f) in &self.functions {
819            if equations.iter().any(|eq| eq.has(&Symbol::new(name)))
820                || self.unknowns.iter().any(|(_, u)| u.has_symbol(name))
821            {
822                used_functions.insert(name, f);
823            }
824        }
825
826        let mut functions: HashMap<&str, String> = HashMap::with_capacity(used_functions.len());
827        for (name, f) in used_functions.into_iter() {
828            let function = format!("fn_{name}");
829            blocks.create(&function, Block::Function(f))?;
830            functions.insert(name, function);
831        }
832
833        // TODO: evaluate functions in expression setup code
834
835        for (u, _) in &self.unknowns {
836            let symbol = symbol!(u);
837
838            if equations.iter().any(|e| e.get_ref().has(symbol)) {
839                substitutions.push([symbol.clone_box(), Box::new(Func::new(u, []))]);
840            }
841        }
842
843        for (f, _) in &self.functions {
844            let symbol = symbol!(f);
845
846            if equations.iter().any(|e| e.get_ref().has(symbol)) {
847                knowns.push(&f[..]);
848                substitutions.push([symbol.clone_box(), Box::new(Func::new(f, []))]);
849            }
850        }
851
852        let equations = equations
853            .iter()
854            .map(|e| e.subs(&substitutions).as_eq().unwrap())
855            .collect_vec();
856
857        let system = System::new(unknowns, knowns, equations.iter())
858            .to_first_order_in_time()
859            .time_discretized()
860            .simplified()?
861            .matrixify()
862            .to_crank_nikolson()
863            .to_constant_mesh()
864            .simplify();
865
866        // blocks.comment(&format!("/*\n{system}\n*/\n"));
867        debug!("System:\n{system}");
868
869        let mesh = blocks.insert("mesh", factory.mesh("mesh", mesh.get_ref(), gen_conf)?)?;
870
871        let element = blocks.insert(
872            "element",
873            factory.finite_element("element", element, gen_conf)?,
874        )?;
875
876        let dof_handler = blocks.insert(
877            "dof_handler",
878            factory.dof_handler("dof_handler", &DofHandlerConfig { mesh, element }, gen_conf)?,
879        )?;
880
881        let vector_config = VectorConfig {
882            dof_handler,
883            is_unknown: false,
884        };
885        let unknown_config = VectorConfig {
886            dof_handler,
887            is_unknown: true,
888        };
889
890        let mut vectors: HashMap<&dyn Expr, String> = HashMap::with_capacity(system.num_vectors());
891
892        for (vector, is_unknown) in system.vectors() {
893            let vector_cpp = vector.to_cpp();
894            blocks.insert(
895                &vector_cpp,
896                factory.vector(
897                    &vector_cpp,
898                    if is_unknown {
899                        &unknown_config
900                    } else {
901                        &vector_config
902                    },
903                    gen_conf,
904                )?,
905            )?;
906            vectors.insert(vector, vector_cpp.clone());
907        }
908
909        let sparsity_pattern = blocks.insert(
910            "sparsity_pattern",
911            factory.sparsity_pattern(
912                "sparsity_pattern",
913                &SparsityPatternConfig { dof_handler },
914                gen_conf,
915            )?,
916        )?;
917
918        let mut unknowns_matrices: HashMap<&dyn Expr, String> =
919            HashMap::with_capacity(system.unknowns.len());
920
921        let matrix_config = MatrixConfig {
922            sparsity_pattern,
923            dof_handler,
924        };
925        let rhs = blocks.create("rhs", Block::Vector(&vector_config))?;
926        let mut unknown_solvers: HashMap<&dyn Expr, String> = HashMap::new();
927        for unknown in &system.unknowns {
928            let unknown = unknown.get_ref();
929            let unknown_cpp = unknown.to_cpp();
930            let mat_name = format!("matrix_{unknown_cpp}");
931            let unknown_vec = &vectors[&unknown];
932
933            unknowns_matrices.insert(
934                unknown,
935                blocks
936                    .create(&mat_name, Block::Matrix(&matrix_config))?
937                    .to_string(),
938            );
939
940            unknown_solvers.insert(
941                unknown,
942                blocks
943                    .create(
944                        &format!("solve_{unknown_cpp}"),
945                        Block::SolveUnknown(&SolveUnknownConfig {
946                            dof_handler,
947                            rhs,
948                            unknown_vec,
949                            unknown_mat: &mat_name,
950                        }),
951                    )?
952                    .to_string(),
953            );
954        }
955
956        let _laplace_mat = blocks.insert(
957            "laplace_mat",
958            factory.shape_matrix(
959                "laplace_mat",
960                &ShapeMatrixConfig {
961                    kind: ShapeMatrix::Laplace,
962                    dof_handler,
963                    element,
964                    matrix_config: &matrix_config,
965                },
966                gen_conf,
967            )?,
968        )?;
969        let _mass_mat = blocks.insert(
970            "mass_mat",
971            factory.shape_matrix(
972                "mass_mat",
973                &ShapeMatrixConfig {
974                    kind: ShapeMatrix::Mass,
975                    dof_handler,
976                    element,
977                    matrix_config: &matrix_config,
978                },
979                gen_conf,
980            )?,
981        )?;
982
983        let mut solved_unknowns: HashSet<&dyn Expr> = HashSet::new();
984        let vectors: &Vec<_> = &system.vectors().map(|(v, _is_unknown)| v).collect();
985        let matrixes: &Vec<_> = &system.matrixes().collect();
986
987        /* Create a set to create identic constant functions only once */
988        let mut constant_functions: HashSet<String> = HashSet::new();
989
990        // Solve equations
991        for (i, equation) in system.eqs_in_solving_order().enumerate() {
992            // Iterate through unknowns in the equation to find the unknown to solve for (TODO: for now only one unknown per equation is supported)
993            for unknown in system.equation_lhs_unknowns(equation) {
994                // Solve unknowns only once
995                if solved_unknowns.contains(&unknown) {
996                    continue;
997                }
998
999                // Setup equation for solving the unknown
1000                blocks.create(
1001                    &format!("equation_{i}"),
1002                    Block::EquationSetup(&EquationSetupConfig {
1003                        equation,
1004                        unknown,
1005                        vectors,
1006                        matrixes,
1007                    }),
1008                )?;
1009
1010                // Begin boundary condition
1011                // Retrive initial and boundary conditions for the unknown
1012                let unknown_cpp = unknown.to_cpp();
1013                let mat_name = format!("matrix_{unknown_cpp}");
1014                let unknown_config = if let Some(captures) = UNKNOWN_DT_RE.captures(&unknown_cpp) {
1015                    self.unknowns
1016                        .get(&captures[1])
1017                        .ok_or_else(|| CodeGenError::UnknownUnknown(captures[1].to_string()))?
1018                        .derivative
1019                        .as_ref()
1020                        .ok_or_else(|| CodeGenError::MissingDerivative(captures[1].to_string()))?
1021                } else {
1022                    self.unknowns
1023                        .get(&unknown_cpp)
1024                        .or_else(|| self.unknowns.get(&unknown_cpp.to_uppercase()))
1025                        .ok_or_else(|| CodeGenError::UnknownUnknown(unknown_cpp.to_string()))?
1026                };
1027
1028                // Setup initial values
1029                let initial = &unknown_config.initial;
1030                blocks.create(
1031                    &format!("initial_condition_{unknown_cpp}"),
1032                    Block::InitialCondition(&InitialConditionConfig {
1033                        dof_handler,
1034                        element,
1035                        function: &initial.to_function_name(),
1036                        target: &format!("{unknown_cpp}_prev"),
1037                    }),
1038                )?;
1039
1040                // Create constant functions associated needed by the unknown
1041                unknown_config.visit_constants(
1042                    |ConstantFunction { ref name, value }| -> Result<(), CodeGenError> {
1043                        if constant_functions.contains(name) {
1044                            return Ok(());
1045                        }
1046                        constant_functions.insert(name.to_string());
1047                        blocks.create(
1048                            name,
1049                            Block::Function(&FunctionDef::Expr(
1050                                value
1051                                    .parse()
1052                                    .map_err(|e| CodeGenError::InvalidFunctionExpression(e))?,
1053                            )),
1054                        )?;
1055                        Ok(())
1056                    },
1057                )?;
1058
1059                // Retrive boundary condition function
1060                let function = &unknown_config
1061                    .boundary
1062                    // .as_ref()
1063                    // .or_else(|| )
1064                    // .ok_or_else(|| CodeGenError::MissingBoundary(unknown_cpp.to_string()))?
1065                    .to_function_name();
1066
1067                // Apply boundary condition
1068                blocks.newline();
1069                blocks.create(
1070                    &format!("apply_boundary_condition_{unknown_cpp}"),
1071                    Block::AppyBoundaryCondition(&ApplyBoundaryConditionConfig {
1072                        dof_handler,
1073                        function,
1074                        matrix: &mat_name,
1075                        solution: &unknown_cpp,
1076                        rhs,
1077                    }),
1078                )?;
1079
1080                // Solve equation for unknown
1081                blocks.newline();
1082                blocks.call(
1083                    unknown_solvers
1084                        .get(&unknown)
1085                        .ok_or_else(|| CodeGenError::MissingSolver(unknown.str()))?,
1086                    &[],
1087                )?;
1088                blocks.newline();
1089                solved_unknowns.insert(unknown);
1090                // End boundary condition
1091            }
1092        }
1093        // End Solve Equations
1094
1095        // Output results
1096        blocks.call("output_results", &[])?;
1097        for unknown in &system.unknowns {
1098            let unknown = unknown.to_cpp();
1099            blocks.add_vector_output(&unknown)?;
1100        }
1101
1102        // Swap new values with previous values to move on to the next step
1103        blocks.newline();
1104        blocks.comment("Swap new values with previous values for the next step");
1105        for unknown in &system.unknowns {
1106            let unknown = unknown.to_cpp();
1107            let unknown_prev = format!("{unknown}_prev");
1108            blocks.call("swap", &[&unknown, &unknown_prev])?;
1109        }
1110
1111        // Setup context for filling the template
1112        let mut context: tera::Context = blocks.collect(dof_handler, sparsity_pattern)?.into();
1113        context.insert("time_start", &time.start.seconds());
1114        context.insert("time_end", &time.end.seconds());
1115        context.insert("time_step", &time_step.seconds());
1116        context.insert("dimension", &dimension);
1117        context.insert("mpi", &self.gen_conf.mpi);
1118
1119        let parameters: indexmap::IndexMap<&String, String> = self
1120            .parameters
1121            .iter()
1122            .map(|(name, value)| (name, format!("{:e}", value.si_value())))
1123            .collect();
1124        context.insert("parameters", &parameters);
1125
1126        // Fill the template with the context
1127        Ok(TEMPLATES.render("cpp_source", &context)?)
1128    }
1129}
1130
1131static UNKNOWN_DT_RE: LazyLock<Regex> =
1132    LazyLock::new(|| Regex::new("^dt_(.+)").expect("valid regex"));
1133
1134#[derive(Clone)]
1135struct BuildingBlockCollector<'fa> {
1136    blocks: IndexMap<String, BuildingBlock>,
1137    additional_names: HashSet<String>,
1138    factory: &'fa BuildingBlockFactory<'fa>,
1139    gen_conf: &'fa GenConfig,
1140}
1141
1142impl From<BuildingBlock> for tera::Context {
1143    fn from(block: BuildingBlock) -> Self {
1144        let mut context = tera::Context::new();
1145        let BuildingBlock {
1146            includes,
1147            data,
1148            setup,
1149            additional_names: _,
1150            constructor,
1151            methods_defs,
1152            methods_impls,
1153            additional_vectors: _,
1154            additional_matrixes: _,
1155            main,
1156            main_setup,
1157            global,
1158            output,
1159        } = block;
1160        context.insert(
1161            "includes",
1162            &includes
1163                .into_iter()
1164                .map(|s| format!("#include <{s}>"))
1165                .join("\n"),
1166        );
1167
1168        macro_rules! insert_lines {
1169            ($key:expr, $source:ident, $indent:literal) => {
1170                context.insert(
1171                    $key,
1172                    &$source
1173                        .into_iter()
1174                        .map(|s| format!("{}{s}", " ".repeat($indent)))
1175                        .join("\n"),
1176                );
1177            };
1178        }
1179
1180        context.insert("data", &to_cpp_lines("  ", data));
1181        context.insert("setup", &to_cpp_lines("    ", setup));
1182        context.insert("constructors", &constructor);
1183        context.insert("methods_defs", &to_cpp_lines("  ", methods_defs));
1184        context.insert("methods_impls", &methods_impls.into_iter().join("\n"));
1185        insert_lines!("main_setup", main_setup, 2);
1186        insert_lines!("main", main, 4);
1187        insert_lines!("output", output, 2);
1188
1189        context.insert("global", &global.join("\n\n\n"));
1190        context
1191    }
1192}
1193
/// Renders C++ statement lines, each prefixed with `prefix`, joined by `\n`.
///
/// A `;` terminator is appended to every line except:
/// - `//` comment lines (leading whitespace ignored), which are emitted as-is;
/// - lines that already end with `;` (trailing whitespace ignored), to avoid `;;`;
/// - blank or whitespace-only lines, which become empty lines instead of a stray
///   empty statement (`prefix;`).
pub fn to_cpp_lines<Lines: IntoIterator<Item = String>>(prefix: &str, lines: Lines) -> String {
    lines
        .into_iter()
        .map(|line| {
            let trimmed = line.trim();
            if trimmed.is_empty() {
                String::new()
            } else if trimmed.starts_with("//") || trimmed.ends_with(';') {
                format!("{prefix}{line}")
            } else {
                format!("{prefix}{line};")
            }
        })
        .collect::<Vec<_>>()
        .join("\n")
}
1206
1207impl<'fa> BuildingBlockCollector<'fa> {
1208    fn new(factory: &'fa BuildingBlockFactory<'fa>, gen_conf: &'fa GenConfig) -> Self {
1209        BuildingBlockCollector {
1210            blocks: IndexMap::new(),
1211            additional_names: HashSet::new(),
1212            factory,
1213            gen_conf,
1214        }
1215    }
1216
1217    fn insert<'a>(
1218        &mut self,
1219        name: &'a str,
1220        block: BuildingBlock,
1221    ) -> Result<&'a str, BuildingBlockError> {
1222        if self.blocks.contains_key(name) {
1223            dbg!(&block);
1224            Err(BuildingBlockError::BlockAlreadyExists(name.to_string()))?
1225        }
1226        if self.additional_names.contains(name) {
1227            Err(BuildingBlockError::NameAlreadyExists(name.to_string()))?
1228        }
1229        self.additional_names
1230            .extend(block.additional_names.iter().cloned());
1231        self.blocks.insert(name.to_string(), block);
1232        return Ok(name);
1233    }
1234    fn collect(mut self, dof_handler: &str, sparsity_pattern: &str) -> BlockRes {
1235        let mut res = BuildingBlock::new();
1236
1237        let mut additional_blocks: IndexMap<String, BuildingBlock> = IndexMap::new();
1238
1239        let tmp_vector_config = VectorConfig {
1240            dof_handler,
1241            is_unknown: true,
1242        };
1243        let tmp_matrix_config = MatrixConfig {
1244            sparsity_pattern,
1245            dof_handler,
1246        };
1247        let gen_conf = self.gen_conf;
1248        for (
1249            _,
1250            BuildingBlock {
1251                additional_vectors,
1252                additional_matrixes,
1253                ..
1254            },
1255        ) in &self.blocks
1256        {
1257            for vector in additional_vectors {
1258                if additional_blocks.contains_key(vector) {
1259                    continue;
1260                }
1261                let block = self.factory.vector(vector, &tmp_vector_config, gen_conf)?;
1262                additional_blocks.insert(vector.to_string(), block);
1263            }
1264            for matrix in additional_matrixes {
1265                if additional_blocks.contains_key(matrix) {
1266                    continue;
1267                }
1268                let block = self.factory.matrix(matrix, &tmp_matrix_config, gen_conf)?;
1269                additional_blocks.insert(matrix.to_string(), block);
1270            }
1271        }
1272
1273        self.blocks.extend(additional_blocks);
1274
1275        for (
1276            _,
1277            BuildingBlock {
1278                includes,
1279                data,
1280                setup,
1281                additional_names: _,
1282                constructor,
1283                methods_defs,
1284                methods_impls,
1285                main,
1286                main_setup,
1287                additional_vectors: _,
1288                additional_matrixes: _,
1289                global,
1290                output,
1291            },
1292        ) in self.blocks
1293        {
1294            res.includes.extend(includes);
1295            res.setup.extend(setup);
1296            res.data.extend(data);
1297            res.constructor.extend(constructor);
1298            res.methods_defs.extend(methods_defs);
1299            res.methods_impls.extend(methods_impls);
1300            res.main_setup.extend(main_setup);
1301            res.main.extend(main);
1302            res.global.extend(global);
1303            res.output.extend(output);
1304        }
1305
1306        Ok(res)
1307    }
1308
1309    fn create<'na>(
1310        &mut self,
1311        name: &'na str,
1312        block: Block<'_>,
1313    ) -> Result<&'na str, BuildingBlockError> {
1314        let gen_conf = self.gen_conf;
1315        self.insert(
1316            name,
1317            match block {
1318                Block::Matrix(config) => self.factory.matrix(name, config, gen_conf)?,
1319                Block::Vector(vector_config) => {
1320                    self.factory.vector(name, vector_config, gen_conf)?
1321                }
1322                Block::SolveUnknown(solve_unknown_config) => {
1323                    self.factory
1324                        .solve_unknown(name, solve_unknown_config, gen_conf)?
1325                }
1326                Block::EquationSetup(equation_setup_config) => {
1327                    self.factory
1328                        .equation_setup(name, equation_setup_config, gen_conf)?
1329                }
1330                Block::Parameter(value) => self.factory.parameter(name, &value, gen_conf)?,
1331                Block::Function(function) => self.factory.function(name, function, gen_conf)?,
1332                Block::AppyBoundaryCondition(config) => self
1333                    .factory
1334                    .apply_boundary_condition(name, config, gen_conf)?,
1335                Block::InitialCondition(config) => {
1336                    self.factory.initial_condition(name, config, gen_conf)?
1337                }
1338            },
1339        )?;
1340        Ok(name)
1341    }
1342
1343    fn call(&mut self, name: &str, args: &[&str]) -> Result<(), BuildingBlockError> {
1344        let block = self.factory.call(name, args)?;
1345        let name = format!("call#{}", self.blocks.len());
1346        self.blocks.insert(name, block);
1347
1348        Ok(())
1349    }
1350
1351    fn newline(&mut self) {
1352        let name = format!("newline#{}", self.blocks.len());
1353        self.blocks.insert(name, self.factory.newline());
1354    }
1355
1356    fn comment(&mut self, content: &str) {
1357        let name = format!("comment#{}", self.blocks.len());
1358        self.blocks.insert(name, self.factory.comment(content));
1359    }
1360
1361    fn add_vector_output(&mut self, unknown: &str) -> Result<(), BuildingBlockError> {
1362        let name = format!("add_vector_output_{unknown}");
1363        self.insert(
1364            &name,
1365            self.factory
1366                .add_vector_output(&name, unknown, self.gen_conf)?,
1367        )?;
1368        Ok(())
1369    }
1370}