Skip to main content

pharmsol_dsl/
execution.rs

1use std::collections::BTreeMap;
2use std::fmt;
3use std::sync::Arc;
4
5use crate::{
6    AnalyticalKernel, ConstValue, CovariateInterpolation, Diagnostic, DiagnosticPhase,
7    DiagnosticReport, MathIntrinsic, ModelKind, RouteKind, RoutePropertyKind, Span, Symbol,
8    SymbolId, SymbolKind, SymbolType, TypedAssignTargetKind, TypedBinaryOp, TypedCall, TypedExpr,
9    TypedExprKind, TypedModel, TypedModule, TypedRangeExpr, TypedStatePlace, TypedStatementBlock,
10    TypedStmt, TypedStmtKind, TypedUnaryOp, ValueType, DSL_LOWERING_GENERIC,
11};
12
13pub fn lower_typed_module(module: &TypedModule) -> Result<ExecutionModule, LoweringError> {
14    let mut models = Vec::with_capacity(module.models.len());
15    for model in &module.models {
16        models.push(lower_typed_model(model)?);
17    }
18    Ok(ExecutionModule {
19        models,
20        span: module.span,
21    })
22}
23
24pub fn lower_typed_model(model: &TypedModel) -> Result<ExecutionModel, LoweringError> {
25    ExecutionLowerer::new(model)?.lower()
26}
27
/// Result of lowering every model in a [`TypedModule`].
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionModule {
    /// Lowered models, in the same order as the typed module's models.
    pub models: Vec<ExecutionModel>,
    /// Span copied from the typed module.
    pub span: Span,
}

/// One model lowered into buffer-oriented execution form: slot metadata, the
/// dense-buffer ABI, and the kernels implementing each model block.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionModel {
    /// Model name, copied from the typed model.
    pub name: String,
    /// Model kind, copied from the typed model.
    pub kind: ModelKind,
    /// Slot assignments for parameters, covariates, states, routes, derived
    /// values and outputs.
    pub metadata: ExecutionMetadata,
    /// Layouts of the dense buffers the kernels operate on.
    pub abi: ExecutionAbi,
    /// Lowered kernels; the lowerer emits at most one per [`KernelRole`].
    /// Use [`ExecutionModel::kernel`] to look one up by role.
    pub kernels: Vec<ExecutionKernel>,
    /// Span copied from the typed model.
    pub span: Span,
}
43
44impl ExecutionModel {
45    pub fn kernel(&self, role: KernelRole) -> Option<&ExecutionKernel> {
46        self.kernels.iter().find(|kernel| kernel.role == role)
47    }
48}
49
/// Slot assignments for every symbol a lowered model exposes.
///
/// Each list is in the typed model's declaration order.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionMetadata {
    /// Model constants with their resolved names and values.
    pub constants: Vec<ExecutionConstant>,
    /// One slot per parameter; `index` is the position in the parameter buffer.
    pub parameters: Vec<ExecutionSlot>,
    /// One slot per covariate, with its optional interpolation mode.
    pub covariates: Vec<ExecutionCovariate>,
    /// States, laid out contiguously in the state buffer.
    pub states: Vec<ExecutionState>,
    /// Routes, each with the state it targets.
    pub routes: Vec<ExecutionRoute>,
    /// One slot per derived value.
    pub derived: Vec<ExecutionSlot>,
    /// One slot per output.
    pub outputs: Vec<ExecutionSlot>,
    /// Copied verbatim from the typed model.
    pub particles: Option<usize>,
    /// The analytical structure, when the model declares one.
    pub analytical: Option<AnalyticalKernel>,
}

/// A model constant with its resolved name and value.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionConstant {
    pub symbol: SymbolId,
    /// Name resolved from the model's symbol table.
    pub name: String,
    pub value: ConstValue,
    pub span: Span,
}

/// A one-element buffer slot bound to a symbol (parameters, derived values
/// and outputs).
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionSlot {
    pub symbol: SymbolId,
    /// Name resolved from the model's symbol table.
    pub name: String,
    /// Position of the slot in its buffer (declaration order).
    pub index: usize,
    /// Span taken from the symbol-table entry.
    pub span: Span,
}

/// A covariate's slot in the covariate buffer.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionCovariate {
    pub symbol: SymbolId,
    pub name: String,
    /// Position in the covariate buffer (declaration order).
    pub index: usize,
    /// Interpolation mode, if the covariate declares one.
    pub interpolation: Option<CovariateInterpolation>,
    pub span: Span,
}

/// A state's contiguous range in the state buffer.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionState {
    pub symbol: SymbolId,
    pub name: String,
    /// Index of the state's first element in the state buffer.
    pub offset: usize,
    /// Number of elements: the declared size, or 1 for an unsized state.
    pub len: usize,
    pub span: Span,
}

/// A lowered route declaration.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionRoute {
    pub symbol: SymbolId,
    pub name: String,
    /// Position of the route among the model's route declarations.
    pub declaration_index: usize,
    /// Route buffer slot. When every route declares a kind, bolus and
    /// infusion routes are numbered independently, so a bolus and an
    /// infusion can share an index. Otherwise this equals `declaration_index`.
    pub index: usize,
    pub kind: Option<RouteKind>,
    /// State that the route targets.
    pub destination: RouteDestination,
    /// Whether the route declares a `lag` property.
    pub has_lag: bool,
    /// Whether the route declares a `bioavailability` property.
    pub has_bioavailability: bool,
    pub span: Span,
}

/// The state a route targets.
///
/// NOTE(review): built by `lower_route_destination`, which is not visible
/// here. Presumably `state_offset` is the destination's offset in the state
/// buffer; TODO confirm.
#[derive(Debug, Clone, PartialEq)]
pub struct RouteDestination {
    pub state: SymbolId,
    pub state_name: String,
    pub state_offset: usize,
    pub span: Span,
}
117
/// Buffer layouts shared by all kernels of a lowered model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutionAbi {
    /// Element type of every buffer.
    pub scalar: ScalarAbi,
    pub calling_convention: CallingConvention,
    /// One element per parameter.
    pub parameter_buffer: DenseBufferLayout,
    /// One element per covariate.
    pub covariate_buffer: DenseBufferLayout,
    /// States laid out contiguously; array states span several elements.
    pub state_buffer: DenseBufferLayout,
    /// One element per derived value.
    pub derived_buffer: DenseBufferLayout,
    /// One element per output.
    pub output_buffer: DenseBufferLayout,
    /// Sized to the largest route index plus one. Routes that share an index
    /// share a slot offset.
    pub route_buffer: DenseBufferLayout,
}

/// Scalar element type of the execution buffers. Only `f64` is defined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarAbi {
    F64,
}

/// How kernels receive their data. Only dense `f64` buffers are defined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CallingConvention {
    DenseF64Buffers,
}

/// Layout of one dense buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DenseBufferLayout {
    /// Which buffer this layout describes.
    pub kind: BufferKind,
    /// Total number of elements in the buffer.
    pub len: usize,
    /// Named ranges within the buffer.
    pub slots: Vec<BufferSlot>,
}

/// Identifies which model buffer a [`DenseBufferLayout`] describes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferKind {
    Parameters,
    Covariates,
    States,
    Derived,
    Outputs,
    Routes,
}

/// A named, contiguous range `offset..offset + len` within a buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BufferSlot {
    pub name: String,
    pub offset: usize,
    pub len: usize,
}
163
/// One executable unit of a lowered model.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionKernel {
    /// Model block this kernel implements.
    pub role: KernelRole,
    /// Argument list; the lowerer derives it from `role`.
    pub signature: KernelSignature,
    pub implementation: KernelImplementation,
    pub span: Span,
}

/// Identifies which model block a kernel implements.
///
/// The derived `Ord` follows the variant declaration order below, which is not
/// the order in which kernels appear in [`ExecutionModel::kernels`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum KernelRole {
    Derive,
    Dynamics,
    Outputs,
    Init,
    Drift,
    Diffusion,
    RouteLag,
    RouteBioavailability,
    Analytical,
}

/// Ordered list of the arguments a kernel takes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KernelSignature {
    pub args: Vec<KernelArgument>,
}

/// One kernel argument: what it binds to and its access direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct KernelArgument {
    pub kind: KernelArgumentKind,
    pub access: KernelAccess,
}

/// The value or buffer a kernel argument binds to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelArgumentKind {
    Time,
    Parameters,
    Covariates,
    States,
    RouteInputs,
    Derived,
    Outputs,
    StateDerivatives,
    InitialState,
    StateNoise,
    RouteLag,
    RouteBioavailability,
    AnalyticalState,
}

/// Access direction of a kernel argument (`Input` or `Output`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelAccess {
    Input,
    Output,
}

/// How a kernel is implemented.
#[derive(Debug, Clone, PartialEq)]
pub enum KernelImplementation {
    /// Lowered DSL statements.
    Statements(ExecutionProgram),
    /// A built-in analytical structure with no lowered statement body.
    AnalyticalBuiltin(AnalyticalKernel),
}
224
/// A statement-based kernel body together with the locals it uses.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionProgram {
    pub locals: Vec<ExecutionLocal>,
    pub body: ExecutionBlock,
}

/// A local binding introduced inside a kernel body.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionLocal {
    pub symbol: SymbolId,
    pub name: String,
    /// Local slot number. Presumably this is the value referenced by
    /// `ExecutionLetStmt::local`, `ExecutionForStmt::local` and
    /// `ExecutionLoad::Local`; TODO confirm against `LocalLowering`.
    pub index: usize,
    pub ty: ValueType,
    pub kind: SymbolKind,
    pub span: Span,
}

/// A sequence of statements with the span of its source block.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionBlock {
    pub statements: Vec<ExecutionStmt>,
    pub span: Span,
}

/// A lowered statement.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionStmt {
    pub kind: ExecutionStmtKind,
    pub span: Span,
}

/// The statement forms a kernel body can contain.
#[derive(Debug, Clone, PartialEq)]
pub enum ExecutionStmtKind {
    Let(ExecutionLetStmt),
    Assign(ExecutionAssignStmt),
    If(ExecutionIfStmt),
    For(ExecutionForStmt),
}

/// Binds `value` to local slot `local`.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionLetStmt {
    pub local: usize,
    pub value: ExecutionExpr,
}

/// Writes `value` to `target`.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionAssignStmt {
    pub target: ExecutionTarget,
    pub value: ExecutionExpr,
}

/// Conditional with an optional `else` branch.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionIfStmt {
    pub condition: ExecutionExpr,
    pub then_branch: Vec<ExecutionStmt>,
    pub else_branch: Option<Vec<ExecutionStmt>>,
}

/// Loop that binds local slot `local` to each value of `range`.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionForStmt {
    pub local: usize,
    pub range: ExecutionRange,
    pub body: Vec<ExecutionStmt>,
}

/// Loop bounds. NOTE(review): whether `end` is inclusive is decided by the
/// typed range and the backend; TODO confirm.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionRange {
    pub start: ExecutionExpr,
    pub end: ExecutionExpr,
    pub span: Span,
}

/// Destination of an assignment.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionTarget {
    pub kind: ExecutionTargetKind,
    pub span: Span,
}

/// The locations an assignment can write.
#[derive(Debug, Clone, PartialEq)]
pub enum ExecutionTargetKind {
    /// Index into the derived buffer.
    Derived(usize),
    /// Index into the output buffer.
    Output(usize),
    /// Initial value of a state (or state element).
    StateInit(ExecutionStateRef),
    /// Derivative of a state (or state element).
    StateDerivative(ExecutionStateRef),
    /// Noise term of a state (or state element).
    StateNoise(ExecutionStateRef),
    /// Lag of the route with this route-buffer index.
    RouteLag(usize),
    /// Bioavailability of the route with this route-buffer index.
    RouteBioavailability(usize),
}

/// Reference to a whole state or to one element of an array state.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionStateRef {
    pub symbol: SymbolId,
    /// Offset of the state's first element in the state buffer.
    pub base_offset: usize,
    /// Number of elements in the state.
    pub len: usize,
    /// Element selector: `None` addresses the state directly (as done for
    /// scalar states), `Some` selects an element of an array state.
    pub index: Option<Box<ExecutionExpr>>,
    pub span: Span,
}

/// A typed expression.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecutionExpr {
    pub kind: ExecutionExprKind,
    pub ty: ValueType,
    /// Known constant value of the expression, if any (presumably carried
    /// over from type checking; TODO confirm).
    pub constant: Option<ConstValue>,
    pub span: Span,
}

/// The expression forms.
#[derive(Debug, Clone, PartialEq)]
pub enum ExecutionExprKind {
    Literal(ConstValue),
    Load(ExecutionLoad),
    Unary {
        op: TypedUnaryOp,
        expr: Box<ExecutionExpr>,
    },
    Binary {
        op: TypedBinaryOp,
        lhs: Box<ExecutionExpr>,
        rhs: Box<ExecutionExpr>,
    },
    Call {
        callee: ExecutionCall,
        args: Vec<ExecutionExpr>,
    },
}

/// A read from a buffer or local slot.
#[derive(Debug, Clone, PartialEq)]
pub enum ExecutionLoad {
    /// Index into the parameter buffer.
    Parameter(usize),
    /// Index into the covariate buffer.
    Covariate(usize),
    State(ExecutionStateRef),
    /// Index into the derived buffer.
    Derived(usize),
    /// Local slot number.
    Local(usize),
    /// Route input, identified by route symbol and route-buffer index.
    RouteInput { route: SymbolId, index: usize },
}

/// Callee of a call expression; currently only math intrinsics.
#[derive(Debug, Clone, PartialEq)]
pub enum ExecutionCall {
    Math(MathIntrinsic),
}
361
362#[derive(Clone, PartialEq, Eq)]
363pub struct LoweringError {
364    diagnostic: Box<Diagnostic>,
365    source: Option<Arc<str>>,
366}
367
368impl LoweringError {
369    fn new(message: impl Into<String>, span: Span) -> Self {
370        Self {
371            diagnostic: Box::new(Diagnostic::error(
372                DSL_LOWERING_GENERIC,
373                DiagnosticPhase::Lowering,
374                message,
375                span,
376            )),
377            source: None,
378        }
379    }
380
381    fn with_note(mut self, note: impl Into<String>) -> Self {
382        self.diagnostic.notes.push(note.into());
383        self
384    }
385
386    pub fn diagnostic(&self) -> &Diagnostic {
387        self.diagnostic.as_ref()
388    }
389
390    pub fn into_diagnostic(self) -> Diagnostic {
391        *self.diagnostic
392    }
393
394    pub fn render(&self, src: &str) -> String {
395        self.diagnostic.render(src)
396    }
397
398    pub fn diagnostic_report(&self, source_name: impl Into<String>) -> DiagnosticReport {
399        DiagnosticReport::from_diagnostics(
400            source_name,
401            self.source(),
402            std::slice::from_ref(self.diagnostic.as_ref()),
403        )
404    }
405
406    pub fn with_source(mut self, source: impl Into<Arc<str>>) -> Self {
407        self.source = Some(source.into());
408        self
409    }
410
411    pub fn source(&self) -> Option<&str> {
412        self.source.as_deref()
413    }
414}
415
416impl fmt::Debug for LoweringError {
417    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
418        fmt::Display::fmt(self, f)
419    }
420}
421
422impl fmt::Display for LoweringError {
423    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
424        if let Some(source) = self.source() {
425            return f.write_str(&self.render(source));
426        }
427        let span = self.diagnostic.primary_span();
428        write!(
429            f,
430            "{} at bytes {}..{}",
431            self.diagnostic.message, span.start, span.end
432        )
433    }
434}
435
436impl std::error::Error for LoweringError {}
437
/// Working state for lowering one [`TypedModel`]: the metadata under
/// construction plus symbol-to-slot lookup tables used while lowering
/// statements, targets and expressions.
struct ExecutionLowerer<'a> {
    model: &'a TypedModel,
    metadata: ExecutionMetadata,
    /// All model symbols, keyed by id.
    symbol_map: BTreeMap<SymbolId, &'a Symbol>,
    /// Parameter symbol -> parameter buffer index.
    parameter_slots: BTreeMap<SymbolId, usize>,
    /// Covariate symbol -> covariate buffer index.
    covariate_slots: BTreeMap<SymbolId, usize>,
    /// State symbol -> position and width in the state buffer.
    state_slots: BTreeMap<SymbolId, StateLayout>,
    /// Route symbol -> route buffer index (see `ExecutionRoute::index`).
    route_slots: BTreeMap<SymbolId, usize>,
    /// Derived symbol -> derived buffer index.
    derived_slots: BTreeMap<SymbolId, usize>,
    /// Output symbol -> output buffer index.
    output_slots: BTreeMap<SymbolId, usize>,
}

/// Position (`offset`) and width (`len`) of a state in the dense state buffer.
#[derive(Debug, Clone, Copy)]
struct StateLayout {
    offset: usize,
    len: usize,
}
455
456impl<'a> ExecutionLowerer<'a> {
457    fn new(model: &'a TypedModel) -> Result<Self, LoweringError> {
458        let symbol_map: BTreeMap<SymbolId, &Symbol> = model
459            .symbols
460            .iter()
461            .map(|symbol| (symbol.id, symbol))
462            .collect();
463
464        let constants = model
465            .constants
466            .iter()
467            .map(|constant| {
468                let symbol = lookup_symbol(&symbol_map, constant.symbol, constant.span)?;
469                Ok(ExecutionConstant {
470                    symbol: constant.symbol,
471                    name: symbol.name.clone(),
472                    value: constant.value.clone(),
473                    span: constant.span,
474                })
475            })
476            .collect::<Result<Vec<_>, LoweringError>>()?;
477
478        let mut parameter_slots = BTreeMap::new();
479        let parameters = model
480            .parameters
481            .iter()
482            .enumerate()
483            .map(|(index, symbol_id)| {
484                let symbol = lookup_symbol(&symbol_map, *symbol_id, model.span)?;
485                parameter_slots.insert(*symbol_id, index);
486                Ok(ExecutionSlot {
487                    symbol: *symbol_id,
488                    name: symbol.name.clone(),
489                    index,
490                    span: symbol.span,
491                })
492            })
493            .collect::<Result<Vec<_>, LoweringError>>()?;
494
495        let mut covariate_slots = BTreeMap::new();
496        let covariates = model
497            .covariates
498            .iter()
499            .enumerate()
500            .map(|(index, covariate)| {
501                let symbol = lookup_symbol(&symbol_map, covariate.symbol, covariate.span)?;
502                covariate_slots.insert(covariate.symbol, index);
503                Ok(ExecutionCovariate {
504                    symbol: covariate.symbol,
505                    name: symbol.name.clone(),
506                    index,
507                    interpolation: covariate.interpolation,
508                    span: covariate.span,
509                })
510            })
511            .collect::<Result<Vec<_>, LoweringError>>()?;
512
513        let mut state_slots = BTreeMap::new();
514        let mut states = Vec::with_capacity(model.states.len());
515        let mut next_state_offset = 0usize;
516        for state in &model.states {
517            let symbol = lookup_symbol(&symbol_map, state.symbol, state.span)?;
518            let len = state.size.unwrap_or(1);
519            state_slots.insert(
520                state.symbol,
521                StateLayout {
522                    offset: next_state_offset,
523                    len,
524                },
525            );
526            states.push(ExecutionState {
527                symbol: state.symbol,
528                name: symbol.name.clone(),
529                offset: next_state_offset,
530                len,
531                span: state.span,
532            });
533            next_state_offset += len;
534        }
535
536        let uses_authoring_route_kinds =
537            !model.routes.is_empty() && model.routes.iter().all(|route| route.kind.is_some());
538        let mut route_slots = BTreeMap::new();
539        let mut routes = Vec::with_capacity(model.routes.len());
540        let mut next_bolus_index = 0usize;
541        let mut next_infusion_index = 0usize;
542        for (declaration_index, route) in model.routes.iter().enumerate() {
543            let symbol = lookup_symbol(&symbol_map, route.symbol, route.span)?;
544            if route.kind == Some(RouteKind::Infusion) {
545                if let Some(property) = route.properties.first() {
546                    let label = match property.kind {
547                        RoutePropertyKind::Lag => "lag",
548                        RoutePropertyKind::Bioavailability => "bioavailability",
549                    };
550                    return Err(LoweringError::new(
551                        format!(
552                            "DSL authoring does not allow `{label}` on infusion route `{}`",
553                            symbol.name
554                        ),
555                        property.span,
556                    )
557                    .with_note("lag and bioavailability are bolus-only route properties"));
558                }
559            }
560            let index = if uses_authoring_route_kinds {
561                match route.kind.expect("authoring routes must preserve kind") {
562                    RouteKind::Bolus => {
563                        let index = next_bolus_index;
564                        next_bolus_index += 1;
565                        index
566                    }
567                    RouteKind::Infusion => {
568                        let index = next_infusion_index;
569                        next_infusion_index += 1;
570                        index
571                    }
572                }
573            } else {
574                declaration_index
575            };
576            route_slots.insert(route.symbol, index);
577            let destination =
578                lower_route_destination(&symbol_map, &state_slots, &route.destination)?;
579            routes.push(ExecutionRoute {
580                symbol: route.symbol,
581                name: symbol.name.clone(),
582                declaration_index,
583                index,
584                kind: route.kind,
585                destination,
586                has_lag: route
587                    .properties
588                    .iter()
589                    .any(|property| property.kind == RoutePropertyKind::Lag),
590                has_bioavailability: route
591                    .properties
592                    .iter()
593                    .any(|property| property.kind == RoutePropertyKind::Bioavailability),
594                span: route.span,
595            });
596        }
597
598        let mut derived_slots = BTreeMap::new();
599        let derived = model
600            .derived
601            .iter()
602            .enumerate()
603            .map(|(index, symbol_id)| {
604                let symbol = lookup_symbol(&symbol_map, *symbol_id, model.span)?;
605                derived_slots.insert(*symbol_id, index);
606                Ok(ExecutionSlot {
607                    symbol: *symbol_id,
608                    name: symbol.name.clone(),
609                    index,
610                    span: symbol.span,
611                })
612            })
613            .collect::<Result<Vec<_>, LoweringError>>()?;
614
615        let mut output_slots = BTreeMap::new();
616        let outputs = model
617            .outputs
618            .iter()
619            .enumerate()
620            .map(|(index, symbol_id)| {
621                let symbol = lookup_symbol(&symbol_map, *symbol_id, model.span)?;
622                output_slots.insert(*symbol_id, index);
623                Ok(ExecutionSlot {
624                    symbol: *symbol_id,
625                    name: symbol.name.clone(),
626                    index,
627                    span: symbol.span,
628                })
629            })
630            .collect::<Result<Vec<_>, LoweringError>>()?;
631
632        Ok(Self {
633            model,
634            metadata: ExecutionMetadata {
635                constants,
636                parameters,
637                covariates,
638                states,
639                routes,
640                derived,
641                outputs,
642                particles: model.particles,
643                analytical: model
644                    .analytical
645                    .as_ref()
646                    .map(|analytical| analytical.structure),
647            },
648            symbol_map,
649            parameter_slots,
650            covariate_slots,
651            state_slots,
652            route_slots,
653            derived_slots,
654            output_slots,
655        })
656    }
657
658    fn lower(self) -> Result<ExecutionModel, LoweringError> {
659        let abi = self.build_abi();
660        let mut kernels = Vec::new();
661
662        if let Some(block) = &self.model.derive {
663            kernels.push(self.lower_statement_kernel(KernelRole::Derive, block)?);
664        }
665        if let Some(block) = &self.model.init {
666            kernels.push(self.lower_init_kernel(block)?);
667        }
668        if let Some(block) = &self.model.dynamics {
669            kernels.push(self.lower_statement_kernel(KernelRole::Dynamics, block)?);
670        }
671        if let Some(block) = &self.model.drift {
672            kernels.push(self.lower_statement_kernel(KernelRole::Drift, block)?);
673        }
674        if let Some(block) = &self.model.diffusion {
675            kernels.push(self.lower_statement_kernel(KernelRole::Diffusion, block)?);
676        }
677        if let Some(kernel) =
678            self.lower_route_property_kernel(RoutePropertyKind::Lag, KernelRole::RouteLag)?
679        {
680            kernels.push(kernel);
681        }
682        if let Some(kernel) = self.lower_route_property_kernel(
683            RoutePropertyKind::Bioavailability,
684            KernelRole::RouteBioavailability,
685        )? {
686            kernels.push(kernel);
687        }
688        if let Some(analytical) = &self.model.analytical {
689            kernels.push(ExecutionKernel {
690                role: KernelRole::Analytical,
691                signature: signature_for(KernelRole::Analytical),
692                implementation: KernelImplementation::AnalyticalBuiltin(analytical.structure),
693                span: analytical.span,
694            });
695        }
696        kernels.push(self.lower_statement_kernel(KernelRole::Outputs, &self.model.outputs_block)?);
697
698        Ok(ExecutionModel {
699            name: self.model.name.clone(),
700            kind: self.model.kind,
701            metadata: self.metadata,
702            abi,
703            kernels,
704            span: self.model.span,
705        })
706    }
707
708    fn build_abi(&self) -> ExecutionAbi {
709        ExecutionAbi {
710            scalar: ScalarAbi::F64,
711            calling_convention: CallingConvention::DenseF64Buffers,
712            parameter_buffer: DenseBufferLayout {
713                kind: BufferKind::Parameters,
714                len: self.metadata.parameters.len(),
715                slots: self
716                    .metadata
717                    .parameters
718                    .iter()
719                    .map(|slot| BufferSlot {
720                        name: slot.name.clone(),
721                        offset: slot.index,
722                        len: 1,
723                    })
724                    .collect(),
725            },
726            covariate_buffer: DenseBufferLayout {
727                kind: BufferKind::Covariates,
728                len: self.metadata.covariates.len(),
729                slots: self
730                    .metadata
731                    .covariates
732                    .iter()
733                    .map(|slot| BufferSlot {
734                        name: slot.name.clone(),
735                        offset: slot.index,
736                        len: 1,
737                    })
738                    .collect(),
739            },
740            state_buffer: DenseBufferLayout {
741                kind: BufferKind::States,
742                len: self.metadata.states.iter().map(|state| state.len).sum(),
743                slots: self
744                    .metadata
745                    .states
746                    .iter()
747                    .map(|state| BufferSlot {
748                        name: state.name.clone(),
749                        offset: state.offset,
750                        len: state.len,
751                    })
752                    .collect(),
753            },
754            derived_buffer: DenseBufferLayout {
755                kind: BufferKind::Derived,
756                len: self.metadata.derived.len(),
757                slots: self
758                    .metadata
759                    .derived
760                    .iter()
761                    .map(|slot| BufferSlot {
762                        name: slot.name.clone(),
763                        offset: slot.index,
764                        len: 1,
765                    })
766                    .collect(),
767            },
768            output_buffer: DenseBufferLayout {
769                kind: BufferKind::Outputs,
770                len: self.metadata.outputs.len(),
771                slots: self
772                    .metadata
773                    .outputs
774                    .iter()
775                    .map(|slot| BufferSlot {
776                        name: slot.name.clone(),
777                        offset: slot.index,
778                        len: 1,
779                    })
780                    .collect(),
781            },
782            route_buffer: DenseBufferLayout {
783                kind: BufferKind::Routes,
784                len: self
785                    .metadata
786                    .routes
787                    .iter()
788                    .map(|route| route.index + 1)
789                    .max()
790                    .unwrap_or(0),
791                slots: self
792                    .metadata
793                    .routes
794                    .iter()
795                    .map(|route| BufferSlot {
796                        name: route.name.clone(),
797                        offset: route.index,
798                        len: 1,
799                    })
800                    .collect(),
801            },
802        }
803    }
804
805    fn lower_statement_kernel(
806        &self,
807        role: KernelRole,
808        block: &TypedStatementBlock,
809    ) -> Result<ExecutionKernel, LoweringError> {
810        let mut locals = LocalLowering::default();
811        let statements = block
812            .statements
813            .iter()
814            .map(|stmt| self.lower_stmt(stmt, &mut locals))
815            .collect::<Result<Vec<_>, LoweringError>>()?;
816
817        Ok(ExecutionKernel {
818            role,
819            signature: signature_for(role),
820            implementation: KernelImplementation::Statements(ExecutionProgram {
821                locals: locals.locals,
822                body: ExecutionBlock {
823                    statements,
824                    span: block.span,
825                },
826            }),
827            span: block.span,
828        })
829    }
830
831    fn lower_init_kernel(
832        &self,
833        block: &TypedStatementBlock,
834    ) -> Result<ExecutionKernel, LoweringError> {
835        let mut locals = LocalLowering::default();
836        let mut statements = self
837            .metadata
838            .states
839            .iter()
840            .flat_map(|state| {
841                let base = (0..state.len).map(|component| ExecutionStmt {
842                    kind: ExecutionStmtKind::Assign(ExecutionAssignStmt {
843                        target: ExecutionTarget {
844                            kind: ExecutionTargetKind::StateInit(ExecutionStateRef {
845                                symbol: state.symbol,
846                                base_offset: state.offset,
847                                len: state.len,
848                                index: if state.len == 1 {
849                                    None
850                                } else {
851                                    Some(Box::new(literal_int(component as i64, state.span)))
852                                },
853                                span: state.span,
854                            }),
855                            span: state.span,
856                        },
857                        value: literal_real(0.0, state.span),
858                    }),
859                    span: state.span,
860                });
861                base.collect::<Vec<_>>()
862            })
863            .collect::<Vec<_>>();
864
865        statements.extend(
866            block
867                .statements
868                .iter()
869                .map(|stmt| self.lower_stmt(stmt, &mut locals))
870                .collect::<Result<Vec<_>, LoweringError>>()?,
871        );
872
873        Ok(ExecutionKernel {
874            role: KernelRole::Init,
875            signature: signature_for(KernelRole::Init),
876            implementation: KernelImplementation::Statements(ExecutionProgram {
877                locals: locals.locals,
878                body: ExecutionBlock {
879                    statements,
880                    span: block.span,
881                },
882            }),
883            span: block.span,
884        })
885    }
886
887    fn lower_route_property_kernel(
888        &self,
889        property_kind: RoutePropertyKind,
890        role: KernelRole,
891    ) -> Result<Option<ExecutionKernel>, LoweringError> {
892        if !self.model.routes.iter().any(|route| {
893            route
894                .properties
895                .iter()
896                .any(|property| property.kind == property_kind)
897        }) {
898            return Ok(None);
899        }
900
901        let mut statements = Vec::with_capacity(self.model.routes.len());
902        let mut locals = LocalLowering::default();
903        let default_value = match property_kind {
904            RoutePropertyKind::Lag => literal_real(0.0, self.model.span),
905            RoutePropertyKind::Bioavailability => literal_real(1.0, self.model.span),
906        };
907        let route_len = self
908            .metadata
909            .routes
910            .iter()
911            .map(|route| route.index + 1)
912            .max()
913            .unwrap_or(0);
914        for route_index in 0..route_len {
915            let target_kind = match property_kind {
916                RoutePropertyKind::Lag => ExecutionTargetKind::RouteLag(route_index),
917                RoutePropertyKind::Bioavailability => {
918                    ExecutionTargetKind::RouteBioavailability(route_index)
919                }
920            };
921            statements.push(ExecutionStmt {
922                kind: ExecutionStmtKind::Assign(ExecutionAssignStmt {
923                    target: ExecutionTarget {
924                        kind: target_kind,
925                        span: self.model.span,
926                    },
927                    value: default_value.clone(),
928                }),
929                span: self.model.span,
930            });
931        }
932        for route in &self.model.routes {
933            if route.kind == Some(RouteKind::Infusion) {
934                continue;
935            }
936            let route_name = self.symbol_name(route.symbol)?.to_string();
937            let route_index = *self.route_slots.get(&route.symbol).ok_or_else(|| {
938                LoweringError::new(
939                    format!("route `{}` has no execution slot", route_name),
940                    route.span,
941                )
942            })?;
943            let expression = match route
944                .properties
945                .iter()
946                .find(|property| property.kind == property_kind)
947            {
948                Some(property) => self.lower_expr(&property.value, &mut locals)?,
949                None => continue,
950            };
951            let target_kind = match property_kind {
952                RoutePropertyKind::Lag => ExecutionTargetKind::RouteLag(route_index),
953                RoutePropertyKind::Bioavailability => {
954                    ExecutionTargetKind::RouteBioavailability(route_index)
955                }
956            };
957            statements.push(ExecutionStmt {
958                kind: ExecutionStmtKind::Assign(ExecutionAssignStmt {
959                    target: ExecutionTarget {
960                        kind: target_kind,
961                        span: route.span,
962                    },
963                    value: expression,
964                }),
965                span: route.span,
966            });
967        }
968
969        Ok(Some(ExecutionKernel {
970            role,
971            signature: signature_for(role),
972            implementation: KernelImplementation::Statements(ExecutionProgram {
973                locals: locals.locals,
974                body: ExecutionBlock {
975                    statements,
976                    span: self.model.span,
977                },
978            }),
979            span: self.model.span,
980        }))
981    }
982
983    fn lower_stmt(
984        &self,
985        stmt: &TypedStmt,
986        locals: &mut LocalLowering,
987    ) -> Result<ExecutionStmt, LoweringError> {
988        let kind = match &stmt.kind {
989            TypedStmtKind::Let(let_stmt) => {
990                let local = locals.local_slot(let_stmt.symbol, self)?;
991                ExecutionStmtKind::Let(ExecutionLetStmt {
992                    local,
993                    value: self.lower_expr(&let_stmt.value, locals)?,
994                })
995            }
996            TypedStmtKind::Assign(assign) => ExecutionStmtKind::Assign(ExecutionAssignStmt {
997                target: self.lower_target(&assign.target.kind, assign.target.span, locals)?,
998                value: self.lower_expr(&assign.value, locals)?,
999            }),
1000            TypedStmtKind::If(if_stmt) => ExecutionStmtKind::If(ExecutionIfStmt {
1001                condition: self.lower_expr(&if_stmt.condition, locals)?,
1002                then_branch: if_stmt
1003                    .then_branch
1004                    .iter()
1005                    .map(|stmt| self.lower_stmt(stmt, locals))
1006                    .collect::<Result<Vec<_>, _>>()?,
1007                else_branch: if_stmt
1008                    .else_branch
1009                    .as_ref()
1010                    .map(|branch| {
1011                        branch
1012                            .iter()
1013                            .map(|stmt| self.lower_stmt(stmt, locals))
1014                            .collect::<Result<Vec<_>, LoweringError>>()
1015                    })
1016                    .transpose()?,
1017            }),
1018            TypedStmtKind::For(for_stmt) => {
1019                let local = locals.local_slot(for_stmt.binding, self)?;
1020                ExecutionStmtKind::For(ExecutionForStmt {
1021                    local,
1022                    range: self.lower_range(&for_stmt.range, locals)?,
1023                    body: for_stmt
1024                        .body
1025                        .iter()
1026                        .map(|stmt| self.lower_stmt(stmt, locals))
1027                        .collect::<Result<Vec<_>, _>>()?,
1028                })
1029            }
1030        };
1031
1032        Ok(ExecutionStmt {
1033            kind,
1034            span: stmt.span,
1035        })
1036    }
1037
1038    fn lower_range(
1039        &self,
1040        range: &TypedRangeExpr,
1041        locals: &mut LocalLowering,
1042    ) -> Result<ExecutionRange, LoweringError> {
1043        Ok(ExecutionRange {
1044            start: self.lower_expr(&range.start, locals)?,
1045            end: self.lower_expr(&range.end, locals)?,
1046            span: range.span,
1047        })
1048    }
1049
1050    fn lower_target(
1051        &self,
1052        target: &TypedAssignTargetKind,
1053        span: Span,
1054        locals: &mut LocalLowering,
1055    ) -> Result<ExecutionTarget, LoweringError> {
1056        let kind = match target {
1057            TypedAssignTargetKind::Derived(symbol) => {
1058                ExecutionTargetKind::Derived(self.slot_for_derived(*symbol, span)?)
1059            }
1060            TypedAssignTargetKind::Output(symbol) => {
1061                ExecutionTargetKind::Output(self.slot_for_output(*symbol, span)?)
1062            }
1063            TypedAssignTargetKind::StateInit(place) => {
1064                ExecutionTargetKind::StateInit(self.lower_state_ref(place, locals)?)
1065            }
1066            TypedAssignTargetKind::Derivative(place) => {
1067                ExecutionTargetKind::StateDerivative(self.lower_state_ref(place, locals)?)
1068            }
1069            TypedAssignTargetKind::Noise(place) => {
1070                ExecutionTargetKind::StateNoise(self.lower_state_ref(place, locals)?)
1071            }
1072        };
1073        Ok(ExecutionTarget { kind, span })
1074    }
1075
1076    fn lower_expr(
1077        &self,
1078        expr: &TypedExpr,
1079        locals: &mut LocalLowering,
1080    ) -> Result<ExecutionExpr, LoweringError> {
1081        if let Some(constant) = &expr.constant {
1082            return Ok(ExecutionExpr {
1083                kind: ExecutionExprKind::Literal(constant.clone()),
1084                ty: expr.ty,
1085                constant: Some(constant.clone()),
1086                span: expr.span,
1087            });
1088        }
1089
1090        let kind = match &expr.kind {
1091            TypedExprKind::Literal(constant) => ExecutionExprKind::Literal(constant.clone()),
1092            TypedExprKind::Symbol(symbol) => {
1093                let symbol_info = lookup_symbol(&self.symbol_map, *symbol, expr.span)?;
1094                match symbol_info.kind {
1095                    SymbolKind::Parameter => ExecutionExprKind::Load(ExecutionLoad::Parameter(
1096                        self.slot_for_parameter(*symbol, expr.span)?,
1097                    )),
1098                    SymbolKind::Covariate => ExecutionExprKind::Load(ExecutionLoad::Covariate(
1099                        self.slot_for_covariate(*symbol, expr.span)?,
1100                    )),
1101                    SymbolKind::Derived => ExecutionExprKind::Load(ExecutionLoad::Derived(
1102                        self.slot_for_derived(*symbol, expr.span)?,
1103                    )),
1104                    SymbolKind::Local | SymbolKind::LoopBinding => ExecutionExprKind::Load(
1105                        ExecutionLoad::Local(locals.local_slot(*symbol, self)?),
1106                    ),
1107                    SymbolKind::Constant => {
1108                        return Err(LoweringError::new(
1109                            format!(
1110                                "constant `{}` should have been folded before execution lowering",
1111                                symbol_info.name
1112                            ),
1113                            expr.span,
1114                        ));
1115                    }
1116                    SymbolKind::State => {
1117                        return Err(LoweringError::new(
1118                            format!(
1119                                "state `{}` should lower through a state reference",
1120                                symbol_info.name
1121                            ),
1122                            expr.span,
1123                        ));
1124                    }
1125                    SymbolKind::Route => {
1126                        return Err(LoweringError::new(
1127                            format!(
1128                                "route `{}` is not a scalar execution input",
1129                                symbol_info.name
1130                            ),
1131                            expr.span,
1132                        )
1133                        .with_note("routes must lower through `rate(route)` or route metadata"));
1134                    }
1135                    SymbolKind::Output => {
1136                        return Err(LoweringError::new(
1137                            format!(
1138                                "output `{}` cannot be read inside execution kernels",
1139                                symbol_info.name
1140                            ),
1141                            expr.span,
1142                        ));
1143                    }
1144                }
1145            }
1146            TypedExprKind::StateValue(place) => {
1147                ExecutionExprKind::Load(ExecutionLoad::State(self.lower_state_ref(place, locals)?))
1148            }
1149            TypedExprKind::Unary { op, expr } => ExecutionExprKind::Unary {
1150                op: *op,
1151                expr: Box::new(self.lower_expr(expr, locals)?),
1152            },
1153            TypedExprKind::Binary { op, lhs, rhs } => ExecutionExprKind::Binary {
1154                op: *op,
1155                lhs: Box::new(self.lower_expr(lhs, locals)?),
1156                rhs: Box::new(self.lower_expr(rhs, locals)?),
1157            },
1158            TypedExprKind::Call { callee, args } => match callee {
1159                TypedCall::Math(intrinsic) => ExecutionExprKind::Call {
1160                    callee: ExecutionCall::Math(*intrinsic),
1161                    args: args
1162                        .iter()
1163                        .map(|arg| self.lower_expr(arg, locals))
1164                        .collect::<Result<Vec<_>, _>>()?,
1165                },
1166                TypedCall::Rate(route) => {
1167                    let route_name = self.symbol_name(*route)?.to_string();
1168                    let route_index = *self.route_slots.get(route).ok_or_else(|| {
1169                        LoweringError::new(
1170                            format!("route `{}` has no execution slot", route_name),
1171                            expr.span,
1172                        )
1173                    })?;
1174                    ExecutionExprKind::Load(ExecutionLoad::RouteInput {
1175                        route: *route,
1176                        index: route_index,
1177                    })
1178                }
1179            },
1180        };
1181
1182        Ok(ExecutionExpr {
1183            kind,
1184            ty: expr.ty,
1185            constant: None,
1186            span: expr.span,
1187        })
1188    }
1189
1190    fn lower_state_ref(
1191        &self,
1192        place: &TypedStatePlace,
1193        locals: &mut LocalLowering,
1194    ) -> Result<ExecutionStateRef, LoweringError> {
1195        let state_name = self.symbol_name(place.state)?.to_string();
1196        let layout = self.state_slots.get(&place.state).copied().ok_or_else(|| {
1197            LoweringError::new(
1198                format!("state `{}` has no execution layout", state_name),
1199                place.span,
1200            )
1201        })?;
1202        let index = place
1203            .index
1204            .as_ref()
1205            .map(|index| self.lower_expr(index, locals))
1206            .transpose()?
1207            .map(Box::new);
1208        Ok(ExecutionStateRef {
1209            symbol: place.state,
1210            base_offset: layout.offset,
1211            len: layout.len,
1212            index,
1213            span: place.span,
1214        })
1215    }
1216
1217    fn slot_for_parameter(&self, symbol: SymbolId, span: Span) -> Result<usize, LoweringError> {
1218        self.parameter_slots.get(&symbol).copied().ok_or_else(|| {
1219            LoweringError::new(
1220                format!(
1221                    "parameter `{}` has no ABI slot",
1222                    self.symbol_name(symbol).unwrap_or("<unknown>")
1223                ),
1224                span,
1225            )
1226        })
1227    }
1228
1229    fn slot_for_covariate(&self, symbol: SymbolId, span: Span) -> Result<usize, LoweringError> {
1230        self.covariate_slots.get(&symbol).copied().ok_or_else(|| {
1231            LoweringError::new(
1232                format!(
1233                    "covariate `{}` has no ABI slot",
1234                    self.symbol_name(symbol).unwrap_or("<unknown>")
1235                ),
1236                span,
1237            )
1238        })
1239    }
1240
1241    fn slot_for_derived(&self, symbol: SymbolId, span: Span) -> Result<usize, LoweringError> {
1242        self.derived_slots.get(&symbol).copied().ok_or_else(|| {
1243            LoweringError::new(
1244                format!(
1245                    "derived value `{}` has no ABI slot",
1246                    self.symbol_name(symbol).unwrap_or("<unknown>")
1247                ),
1248                span,
1249            )
1250        })
1251    }
1252
1253    fn slot_for_output(&self, symbol: SymbolId, span: Span) -> Result<usize, LoweringError> {
1254        self.output_slots.get(&symbol).copied().ok_or_else(|| {
1255            LoweringError::new(
1256                format!(
1257                    "output `{}` has no ABI slot",
1258                    self.symbol_name(symbol).unwrap_or("<unknown>")
1259                ),
1260                span,
1261            )
1262        })
1263    }
1264
1265    fn symbol_name(&self, symbol: SymbolId) -> Result<&str, LoweringError> {
1266        Ok(&lookup_symbol(&self.symbol_map, symbol, self.model.span)?.name)
1267    }
1268}
1269
1270#[derive(Default)]
1271struct LocalLowering {
1272    locals: Vec<ExecutionLocal>,
1273    slots: BTreeMap<SymbolId, usize>,
1274}
1275
1276impl LocalLowering {
1277    fn local_slot(
1278        &mut self,
1279        symbol: SymbolId,
1280        lowerer: &ExecutionLowerer<'_>,
1281    ) -> Result<usize, LoweringError> {
1282        if let Some(slot) = self.slots.get(&symbol).copied() {
1283            return Ok(slot);
1284        }
1285        let symbol_info = lookup_symbol(&lowerer.symbol_map, symbol, lowerer.model.span)?;
1286        let ty = match symbol_info.ty {
1287            SymbolType::Scalar(ty) => ty,
1288            SymbolType::Array { .. } => {
1289                return Err(LoweringError::new(
1290                    format!("local `{}` must be scalar", symbol_info.name),
1291                    symbol_info.span,
1292                ));
1293            }
1294            SymbolType::Route => {
1295                return Err(LoweringError::new(
1296                    format!("local `{}` cannot be a route handle", symbol_info.name),
1297                    symbol_info.span,
1298                ));
1299            }
1300        };
1301        let slot = self.locals.len();
1302        self.locals.push(ExecutionLocal {
1303            symbol,
1304            name: symbol_info.name.clone(),
1305            index: slot,
1306            ty,
1307            kind: symbol_info.kind,
1308            span: symbol_info.span,
1309        });
1310        self.slots.insert(symbol, slot);
1311        Ok(slot)
1312    }
1313}
1314
1315fn signature_for(role: KernelRole) -> KernelSignature {
1316    let args = match role {
1317        KernelRole::Derive => vec![
1318            arg(KernelArgumentKind::Time, KernelAccess::Input),
1319            arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1320            arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1321            arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1322            arg(KernelArgumentKind::States, KernelAccess::Input),
1323            arg(KernelArgumentKind::Derived, KernelAccess::Output),
1324        ],
1325        KernelRole::Dynamics => vec![
1326            arg(KernelArgumentKind::Time, KernelAccess::Input),
1327            arg(KernelArgumentKind::States, KernelAccess::Input),
1328            arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1329            arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1330            arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1331            arg(KernelArgumentKind::Derived, KernelAccess::Input),
1332            arg(KernelArgumentKind::StateDerivatives, KernelAccess::Output),
1333        ],
1334        KernelRole::Outputs => vec![
1335            arg(KernelArgumentKind::Time, KernelAccess::Input),
1336            arg(KernelArgumentKind::States, KernelAccess::Input),
1337            arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1338            arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1339            arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1340            arg(KernelArgumentKind::Derived, KernelAccess::Input),
1341            arg(KernelArgumentKind::Outputs, KernelAccess::Output),
1342        ],
1343        KernelRole::Init => vec![
1344            arg(KernelArgumentKind::Time, KernelAccess::Input),
1345            arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1346            arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1347            arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1348            arg(KernelArgumentKind::Derived, KernelAccess::Input),
1349            arg(KernelArgumentKind::InitialState, KernelAccess::Output),
1350        ],
1351        KernelRole::Drift => vec![
1352            arg(KernelArgumentKind::Time, KernelAccess::Input),
1353            arg(KernelArgumentKind::States, KernelAccess::Input),
1354            arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1355            arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1356            arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1357            arg(KernelArgumentKind::Derived, KernelAccess::Input),
1358            arg(KernelArgumentKind::StateDerivatives, KernelAccess::Output),
1359        ],
1360        KernelRole::Diffusion => vec![
1361            arg(KernelArgumentKind::Time, KernelAccess::Input),
1362            arg(KernelArgumentKind::States, KernelAccess::Input),
1363            arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1364            arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1365            arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1366            arg(KernelArgumentKind::Derived, KernelAccess::Input),
1367            arg(KernelArgumentKind::StateNoise, KernelAccess::Output),
1368        ],
1369        KernelRole::RouteLag => vec![
1370            arg(KernelArgumentKind::Time, KernelAccess::Input),
1371            arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1372            arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1373            arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1374            arg(KernelArgumentKind::Derived, KernelAccess::Input),
1375            arg(KernelArgumentKind::RouteLag, KernelAccess::Output),
1376        ],
1377        KernelRole::RouteBioavailability => vec![
1378            arg(KernelArgumentKind::Time, KernelAccess::Input),
1379            arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1380            arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1381            arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1382            arg(KernelArgumentKind::Derived, KernelAccess::Input),
1383            arg(
1384                KernelArgumentKind::RouteBioavailability,
1385                KernelAccess::Output,
1386            ),
1387        ],
1388        KernelRole::Analytical => vec![
1389            arg(KernelArgumentKind::Time, KernelAccess::Input),
1390            arg(KernelArgumentKind::States, KernelAccess::Input),
1391            arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1392            arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1393            arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1394            arg(KernelArgumentKind::Derived, KernelAccess::Input),
1395            arg(KernelArgumentKind::AnalyticalState, KernelAccess::Output),
1396        ],
1397    };
1398    KernelSignature { args }
1399}
1400
1401fn arg(kind: KernelArgumentKind, access: KernelAccess) -> KernelArgument {
1402    KernelArgument { kind, access }
1403}
1404
1405fn lookup_symbol<'a>(
1406    symbols: &'a BTreeMap<SymbolId, &'a Symbol>,
1407    symbol: SymbolId,
1408    span: Span,
1409) -> Result<&'a Symbol, LoweringError> {
1410    symbols.get(&symbol).copied().ok_or_else(|| {
1411        LoweringError::new(
1412            format!("symbol id {symbol} is missing from the typed model symbol table"),
1413            span,
1414        )
1415    })
1416}
1417
1418fn lower_route_destination(
1419    symbols: &BTreeMap<SymbolId, &Symbol>,
1420    state_slots: &BTreeMap<SymbolId, StateLayout>,
1421    destination: &TypedStatePlace,
1422) -> Result<RouteDestination, LoweringError> {
1423    let symbol = lookup_symbol(symbols, destination.state, destination.span)?;
1424    let layout = state_slots
1425        .get(&destination.state)
1426        .copied()
1427        .ok_or_else(|| {
1428            LoweringError::new(
1429                format!("state `{}` has no execution layout", symbol.name),
1430                destination.span,
1431            )
1432        })?;
1433    let element = match &destination.index {
1434        None => 0,
1435        Some(index) => constant_index(index, destination.span)?,
1436    };
1437    if element >= layout.len {
1438        return Err(LoweringError::new(
1439            format!(
1440                "route destination for `{}` indexes element {}, but state length is {}",
1441                symbol.name, element, layout.len
1442            ),
1443            destination.span,
1444        ));
1445    }
1446    Ok(RouteDestination {
1447        state: destination.state,
1448        state_name: symbol.name.clone(),
1449        state_offset: layout.offset + element,
1450        span: destination.span,
1451    })
1452}
1453
1454fn constant_index(expr: &TypedExpr, span: Span) -> Result<usize, LoweringError> {
1455    let value = expr
1456        .constant
1457        .as_ref()
1458        .and_then(ConstValue::as_i64)
1459        .ok_or_else(|| LoweringError::new("expected a compile-time integer index", span))?;
1460    if value < 0 {
1461        return Err(LoweringError::new(
1462            "expected a non-negative compile-time index",
1463            span,
1464        ));
1465    }
1466    Ok(value as usize)
1467}
1468
1469fn literal_real(value: f64, span: Span) -> ExecutionExpr {
1470    ExecutionExpr {
1471        kind: ExecutionExprKind::Literal(ConstValue::Real(value)),
1472        ty: ValueType::Real,
1473        constant: Some(ConstValue::Real(value)),
1474        span,
1475    }
1476}
1477
1478fn literal_int(value: i64, span: Span) -> ExecutionExpr {
1479    ExecutionExpr {
1480        kind: ExecutionExprKind::Literal(ConstValue::Int(value)),
1481        ty: ValueType::Int,
1482        constant: Some(ConstValue::Int(value)),
1483        span,
1484    }
1485}
1486
1487#[cfg(test)]
1488mod tests {
1489    use super::*;
1490    use crate::test_fixtures::STRUCTURED_BLOCK_CORPUS;
1491    use crate::{analyze_module, parse_module};
1492
1493    #[test]
1494    fn lowers_structured_block_corpus_into_execution_models() {
1495        let execution = structured_block_execution();
1496        assert_eq!(execution.models.len(), 4);
1497
1498        let ode = find_model(&execution, "one_cmt_oral_iv");
1499        assert_eq!(ode.abi.parameter_buffer.len, 5);
1500        assert_eq!(ode.abi.covariate_buffer.len, 1);
1501        assert_eq!(ode.abi.state_buffer.len, 2);
1502        assert_eq!(ode.abi.derived_buffer.len, 3);
1503        assert_eq!(ode.abi.output_buffer.len, 1);
1504        assert_eq!(ode.abi.route_buffer.len, 2);
1505        assert_eq!(ode.metadata.routes[0].destination.state_offset, 0);
1506        assert_eq!(
1507            kernel_roles(ode),
1508            vec![
1509                KernelRole::Derive,
1510                KernelRole::Dynamics,
1511                KernelRole::RouteLag,
1512                KernelRole::RouteBioavailability,
1513                KernelRole::Outputs,
1514            ]
1515        );
1516    }
1517
1518    #[test]
1519    fn authoring_routes_share_input_indices_by_kind_local_ordinal() {
1520        let src = r#"name = shared_authoring
1521kind = ode
1522
1523params = ka, ke, v, tlag, f_oral
1524states = depot, central
1525outputs = cp
1526
1527bolus(oral) -> depot
1528infusion(iv) -> central
1529lag(oral) = tlag
1530fa(oral) = f_oral
1531
1532dx(depot) = -ka * depot
1533dx(central) = ka * depot - ke * central
1534
1535out(cp) = central / v ~ continuous()
1536"#;
1537
1538        let model = crate::parse_model(src).expect("authoring model parses");
1539        let typed = crate::analyze_model(&model).expect("authoring model analyzes");
1540        let lowered = crate::lower_typed_model(&typed).expect("authoring model lowers");
1541
1542        assert_eq!(lowered.abi.route_buffer.len, 1);
1543        assert_eq!(lowered.metadata.routes.len(), 2);
1544        assert_eq!(lowered.metadata.routes[0].kind, Some(RouteKind::Bolus));
1545        assert_eq!(lowered.metadata.routes[1].kind, Some(RouteKind::Infusion));
1546        assert_eq!(lowered.metadata.routes[0].declaration_index, 0);
1547        assert_eq!(lowered.metadata.routes[1].declaration_index, 1);
1548        assert_eq!(lowered.metadata.routes[0].index, 0);
1549        assert_eq!(lowered.metadata.routes[1].index, 0);
1550        assert!(lowered.metadata.routes[0].has_lag);
1551        assert!(lowered.metadata.routes[0].has_bioavailability);
1552        assert!(!lowered.metadata.routes[1].has_lag);
1553        assert!(!lowered.metadata.routes[1].has_bioavailability);
1554    }
1555
1556    #[test]
1557    fn canonical_numeric_channel_names_flow_into_execution_metadata_and_abi() {
1558        let src = r#"name = canonical_numeric_channels
1559kind = ode
1560
1561params = ke, v
1562states = depot, central
1563outputs = cp, outeq_2
1564
1565bolus(input_10) -> depot
1566infusion(iv) -> central
1567
1568dx(depot) = -ke * depot
1569dx(central) = rate(input_10) - ke * central
1570
1571out(cp) = central / v
1572out(outeq_2) = depot / v
1573"#;
1574
1575        let model = crate::parse_model(src).expect("authoring model parses");
1576        let typed = crate::analyze_model(&model).expect("authoring model analyzes");
1577        let lowered = crate::lower_typed_model(&typed).expect("authoring model lowers");
1578
1579        assert_eq!(
1580            lowered
1581                .metadata
1582                .routes
1583                .iter()
1584                .map(|route| route.name.as_str())
1585                .collect::<Vec<_>>(),
1586            vec!["input_10", "iv"]
1587        );
1588        assert_eq!(
1589            lowered
1590                .metadata
1591                .outputs
1592                .iter()
1593                .map(|output| output.name.as_str())
1594                .collect::<Vec<_>>(),
1595            vec!["cp", "outeq_2"]
1596        );
1597        assert_eq!(
1598            lowered
1599                .abi
1600                .route_buffer
1601                .slots
1602                .iter()
1603                .map(|slot| slot.name.as_str())
1604                .collect::<Vec<_>>(),
1605            vec!["input_10", "iv"]
1606        );
1607        assert_eq!(
1608            lowered
1609                .abi
1610                .output_buffer
1611                .slots
1612                .iter()
1613                .map(|slot| slot.name.as_str())
1614                .collect::<Vec<_>>(),
1615            vec!["cp", "outeq_2"]
1616        );
1617    }
1618
1619    #[test]
1620    fn authoring_routes_reject_infusion_lag_properties() {
1621        let src = r#"name = invalid_infusion_lag
1622kind = ode
1623
1624params = ke, v, tlag
1625states = central
1626outputs = cp
1627
1628infusion(iv) -> central
1629lag(iv) = tlag
1630
1631dx(central) = -ke * central
1632
1633out(cp) = central / v ~ continuous()
1634"#;
1635
1636        let model = crate::parse_model(src).expect("authoring model parses");
1637        let typed = crate::analyze_model(&model).expect("authoring model analyzes");
1638        let error = crate::lower_typed_model(&typed)
1639            .err()
1640            .expect("infusion lag should fail during lowering");
1641
1642        assert!(error
1643            .to_string()
1644            .contains("DSL authoring does not allow `lag` on infusion route `iv`"));
1645    }
1646
1647    #[test]
1648    fn authoring_routes_reject_infusion_bioavailability_properties() {
1649        let src = r#"name = invalid_infusion_fa
1650kind = ode
1651
1652params = ke, v, f_iv
1653states = central
1654outputs = cp
1655
1656infusion(iv) -> central
1657fa(iv) = f_iv
1658
1659dx(central) = -ke * central
1660
1661out(cp) = central / v ~ continuous()
1662"#;
1663
1664        let model = crate::parse_model(src).expect("authoring model parses");
1665        let typed = crate::analyze_model(&model).expect("authoring model analyzes");
1666        let error = crate::lower_typed_model(&typed)
1667            .err()
1668            .expect("infusion bioavailability should fail during lowering");
1669
1670        assert!(error
1671            .to_string()
1672            .contains("DSL authoring does not allow `bioavailability` on infusion route `iv`"));
1673    }
1674
1675    #[test]
1676    fn flattens_array_states_and_preserves_loop_structure() {
1677        let execution = structured_block_execution();
1678        let transit = find_model(&execution, "transit_absorption");
1679        assert_eq!(transit.abi.state_buffer.len, 5);
1680        assert_eq!(transit.metadata.states[0].name, "transit");
1681        assert_eq!(transit.metadata.states[0].offset, 0);
1682        assert_eq!(transit.metadata.states[0].len, 4);
1683        assert_eq!(transit.metadata.states[1].name, "central");
1684        assert_eq!(transit.metadata.states[1].offset, 4);
1685        assert!(transit.kernel(KernelRole::RouteLag).is_none());
1686        assert!(transit.kernel(KernelRole::RouteBioavailability).is_none());
1687
1688        let dynamics = transit
1689            .kernel(KernelRole::Dynamics)
1690            .expect("dynamics kernel");
1691        let KernelImplementation::Statements(program) = &dynamics.implementation else {
1692            panic!("expected statement-based dynamics kernel");
1693        };
1694        assert!(program
1695            .body
1696            .statements
1697            .iter()
1698            .any(|stmt| matches!(stmt.kind, ExecutionStmtKind::For(_))));
1699    }
1700
1701    #[test]
1702    fn analytical_models_lower_to_builtin_execution_kernels() {
1703        let execution = structured_block_execution();
1704        let analytical = find_model(&execution, "one_cmt_abs");
1705        let kernel = analytical
1706            .kernel(KernelRole::Analytical)
1707            .expect("analytical kernel");
1708        assert_eq!(
1709            kernel.signature.args,
1710            vec![
1711                arg(KernelArgumentKind::Time, KernelAccess::Input),
1712                arg(KernelArgumentKind::States, KernelAccess::Input),
1713                arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1714                arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1715                arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1716                arg(KernelArgumentKind::Derived, KernelAccess::Input),
1717                arg(KernelArgumentKind::AnalyticalState, KernelAccess::Output),
1718            ]
1719        );
1720        assert!(matches!(
1721            kernel.implementation,
1722            KernelImplementation::AnalyticalBuiltin(AnalyticalKernel::OneCompartmentWithAbsorption)
1723        ));
1724    }
1725
1726    #[test]
1727    fn sde_models_emit_runtime_kernels_and_zero_filled_init() {
1728        let execution = structured_block_execution();
1729        let sde = find_model(&execution, "vanco_sde");
1730        assert_eq!(sde.metadata.particles, Some(1000));
1731        assert_eq!(
1732            kernel_roles(sde),
1733            vec![
1734                KernelRole::Init,
1735                KernelRole::Drift,
1736                KernelRole::Diffusion,
1737                KernelRole::Outputs,
1738            ]
1739        );
1740
1741        let init = sde.kernel(KernelRole::Init).expect("init kernel");
1742        let KernelImplementation::Statements(program) = &init.implementation else {
1743            panic!("expected statement init kernel");
1744        };
1745        assert!(program.body.statements.len() > sde.metadata.states.len());
1746        assert!(matches!(
1747            program.body.statements[0].kind,
1748            ExecutionStmtKind::Assign(ExecutionAssignStmt {
1749                target: ExecutionTarget {
1750                    kind: ExecutionTargetKind::StateInit(_),
1751                    ..
1752                },
1753                ..
1754            })
1755        ));
1756    }
1757
1758    #[test]
1759    fn route_property_kernels_fill_defaults_for_unconfigured_routes() {
1760        let execution = structured_block_execution();
1761        let ode = find_model(&execution, "one_cmt_oral_iv");
1762        let lag = ode.kernel(KernelRole::RouteLag).expect("lag kernel");
1763        let bio = ode
1764            .kernel(KernelRole::RouteBioavailability)
1765            .expect("bioavailability kernel");
1766
1767        let KernelImplementation::Statements(lag_program) = &lag.implementation else {
1768            panic!("expected statement lag kernel");
1769        };
1770        let KernelImplementation::Statements(bio_program) = &bio.implementation else {
1771            panic!("expected statement bioavailability kernel");
1772        };
1773
1774        assert_eq!(lag_program.body.statements.len(), 3);
1775        assert_eq!(bio_program.body.statements.len(), 3);
1776        assert!(matches!(
1777            lag_program.body.statements[1].kind,
1778            ExecutionStmtKind::Assign(ExecutionAssignStmt {
1779                value: ExecutionExpr {
1780                    kind: ExecutionExprKind::Literal(ConstValue::Real(value)),
1781                    ..
1782                },
1783                ..
1784            }) if value == 0.0
1785        ));
1786        assert!(matches!(
1787            bio_program.body.statements[1].kind,
1788            ExecutionStmtKind::Assign(ExecutionAssignStmt {
1789                value: ExecutionExpr {
1790                    kind: ExecutionExprKind::Literal(ConstValue::Real(value)),
1791                    ..
1792                },
1793                ..
1794            }) if value == 1.0
1795        ));
1796    }
1797
1798    fn structured_block_execution() -> ExecutionModule {
1799        let src = STRUCTURED_BLOCK_CORPUS;
1800        let module = parse_module(src).expect("structured-block fixture parses");
1801        let typed = analyze_module(&module).expect("structured-block fixture analyzes");
1802        lower_typed_module(&typed).expect("execution lowering succeeds")
1803    }
1804
1805    fn find_model<'a>(module: &'a ExecutionModule, name: &str) -> &'a ExecutionModel {
1806        module
1807            .models
1808            .iter()
1809            .find(|model| model.name == name)
1810            .unwrap_or_else(|| panic!("missing model {name}"))
1811    }
1812
1813    fn kernel_roles(model: &ExecutionModel) -> Vec<KernelRole> {
1814        model.kernels.iter().map(|kernel| kernel.role).collect()
1815    }
1816}