1use std::collections::BTreeMap;
2use std::fmt;
3use std::sync::Arc;
4
5use crate::{
6 AnalyticalKernel, ConstValue, CovariateInterpolation, Diagnostic, DiagnosticPhase,
7 DiagnosticReport, MathIntrinsic, ModelKind, RouteKind, RoutePropertyKind, Span, Symbol,
8 SymbolId, SymbolKind, SymbolType, TypedAssignTargetKind, TypedBinaryOp, TypedCall, TypedExpr,
9 TypedExprKind, TypedModel, TypedModule, TypedRangeExpr, TypedStatePlace, TypedStatementBlock,
10 TypedStmt, TypedStmtKind, TypedUnaryOp, ValueType, DSL_LOWERING_GENERIC,
11};
12
13pub fn lower_typed_module(module: &TypedModule) -> Result<ExecutionModule, LoweringError> {
14 let mut models = Vec::with_capacity(module.models.len());
15 for model in &module.models {
16 models.push(lower_typed_model(model)?);
17 }
18 Ok(ExecutionModule {
19 models,
20 span: module.span,
21 })
22}
23
24pub fn lower_typed_model(model: &TypedModel) -> Result<ExecutionModel, LoweringError> {
25 ExecutionLowerer::new(model)?.lower()
26}
27
28#[derive(Debug, Clone, PartialEq)]
29pub struct ExecutionModule {
30 pub models: Vec<ExecutionModel>,
31 pub span: Span,
32}
33
34#[derive(Debug, Clone, PartialEq)]
35pub struct ExecutionModel {
36 pub name: String,
37 pub kind: ModelKind,
38 pub metadata: ExecutionMetadata,
39 pub abi: ExecutionAbi,
40 pub kernels: Vec<ExecutionKernel>,
41 pub span: Span,
42}
43
44impl ExecutionModel {
45 pub fn kernel(&self, role: KernelRole) -> Option<&ExecutionKernel> {
46 self.kernels.iter().find(|kernel| kernel.role == role)
47 }
48}
49
50#[derive(Debug, Clone, PartialEq)]
51pub struct ExecutionMetadata {
52 pub constants: Vec<ExecutionConstant>,
53 pub parameters: Vec<ExecutionSlot>,
54 pub covariates: Vec<ExecutionCovariate>,
55 pub states: Vec<ExecutionState>,
56 pub routes: Vec<ExecutionRoute>,
57 pub derived: Vec<ExecutionSlot>,
58 pub outputs: Vec<ExecutionSlot>,
59 pub particles: Option<usize>,
60 pub analytical: Option<AnalyticalKernel>,
61}
62
63#[derive(Debug, Clone, PartialEq)]
64pub struct ExecutionConstant {
65 pub symbol: SymbolId,
66 pub name: String,
67 pub value: ConstValue,
68 pub span: Span,
69}
70
71#[derive(Debug, Clone, PartialEq)]
72pub struct ExecutionSlot {
73 pub symbol: SymbolId,
74 pub name: String,
75 pub index: usize,
76 pub span: Span,
77}
78
79#[derive(Debug, Clone, PartialEq)]
80pub struct ExecutionCovariate {
81 pub symbol: SymbolId,
82 pub name: String,
83 pub index: usize,
84 pub interpolation: Option<CovariateInterpolation>,
85 pub span: Span,
86}
87
88#[derive(Debug, Clone, PartialEq)]
89pub struct ExecutionState {
90 pub symbol: SymbolId,
91 pub name: String,
92 pub offset: usize,
93 pub len: usize,
94 pub span: Span,
95}
96
97#[derive(Debug, Clone, PartialEq)]
98pub struct ExecutionRoute {
99 pub symbol: SymbolId,
100 pub name: String,
101 pub declaration_index: usize,
102 pub index: usize,
103 pub kind: Option<RouteKind>,
104 pub destination: RouteDestination,
105 pub has_lag: bool,
106 pub has_bioavailability: bool,
107 pub span: Span,
108}
109
110#[derive(Debug, Clone, PartialEq)]
111pub struct RouteDestination {
112 pub state: SymbolId,
113 pub state_name: String,
114 pub state_offset: usize,
115 pub span: Span,
116}
117
118#[derive(Debug, Clone, PartialEq, Eq)]
119pub struct ExecutionAbi {
120 pub scalar: ScalarAbi,
121 pub calling_convention: CallingConvention,
122 pub parameter_buffer: DenseBufferLayout,
123 pub covariate_buffer: DenseBufferLayout,
124 pub state_buffer: DenseBufferLayout,
125 pub derived_buffer: DenseBufferLayout,
126 pub output_buffer: DenseBufferLayout,
127 pub route_buffer: DenseBufferLayout,
128}
129
/// Scalar representation named by an [`ExecutionAbi`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarAbi {
    /// The only representation currently defined.
    F64,
}
134
/// Calling convention named by an [`ExecutionAbi`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CallingConvention {
    /// The only convention currently defined.
    DenseF64Buffers,
}
139
140#[derive(Debug, Clone, PartialEq, Eq)]
141pub struct DenseBufferLayout {
142 pub kind: BufferKind,
143 pub len: usize,
144 pub slots: Vec<BufferSlot>,
145}
146
/// Identifies which buffer a [`DenseBufferLayout`] describes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferKind {
    /// Parameter buffer.
    Parameters,
    /// Covariate buffer.
    Covariates,
    /// State buffer.
    States,
    /// Derived-value buffer.
    Derived,
    /// Output buffer.
    Outputs,
    /// Route buffer.
    Routes,
}
156
/// A named region inside a buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BufferSlot {
    /// Name of the symbol stored in this region.
    pub name: String,
    /// Position of the region's first element within the buffer.
    pub offset: usize,
    /// Number of elements in the region.
    pub len: usize,
}
163
164#[derive(Debug, Clone, PartialEq)]
165pub struct ExecutionKernel {
166 pub role: KernelRole,
167 pub signature: KernelSignature,
168 pub implementation: KernelImplementation,
169 pub span: Span,
170}
171
/// Role of a kernel within a lowered model.
///
/// The derived ordering follows the declaration order of the variants below,
/// so the variants must not be reordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum KernelRole {
    /// Lowered from the model's `derive` block.
    Derive,
    /// Lowered from the model's `dynamics` block.
    Dynamics,
    /// Lowered from the model's outputs block.
    Outputs,
    /// Lowered from the model's `init` block.
    Init,
    /// Lowered from the model's `drift` block.
    Drift,
    /// Lowered from the model's `diffusion` block.
    Diffusion,
    /// Generated from the lag properties of the routes.
    RouteLag,
    /// Generated from the bioavailability properties of the routes.
    RouteBioavailability,
    /// Built-in analytical kernel.
    Analytical,
}
184
185#[derive(Debug, Clone, PartialEq, Eq)]
186pub struct KernelSignature {
187 pub args: Vec<KernelArgument>,
188}
189
190#[derive(Debug, Clone, Copy, PartialEq, Eq)]
191pub struct KernelArgument {
192 pub kind: KernelArgumentKind,
193 pub access: KernelAccess,
194}
195
/// What a kernel argument carries.
///
/// NOTE(review): which kinds appear in which kernel signature is decided by
/// `signature_for`, which is defined outside this part of the file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelArgumentKind {
    Time,
    Parameters,
    Covariates,
    States,
    RouteInputs,
    Derived,
    Outputs,
    StateDerivatives,
    InitialState,
    StateNoise,
    RouteLag,
    RouteBioavailability,
    AnalyticalState,
}
212
/// Direction of a kernel argument.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelAccess {
    /// Argument is marked as an input.
    Input,
    /// Argument is marked as an output.
    Output,
}
218
219#[derive(Debug, Clone, PartialEq)]
220pub enum KernelImplementation {
221 Statements(ExecutionProgram),
222 AnalyticalBuiltin(AnalyticalKernel),
223}
224
225#[derive(Debug, Clone, PartialEq)]
226pub struct ExecutionProgram {
227 pub locals: Vec<ExecutionLocal>,
228 pub body: ExecutionBlock,
229}
230
231#[derive(Debug, Clone, PartialEq)]
232pub struct ExecutionLocal {
233 pub symbol: SymbolId,
234 pub name: String,
235 pub index: usize,
236 pub ty: ValueType,
237 pub kind: SymbolKind,
238 pub span: Span,
239}
240
241#[derive(Debug, Clone, PartialEq)]
242pub struct ExecutionBlock {
243 pub statements: Vec<ExecutionStmt>,
244 pub span: Span,
245}
246
247#[derive(Debug, Clone, PartialEq)]
248pub struct ExecutionStmt {
249 pub kind: ExecutionStmtKind,
250 pub span: Span,
251}
252
253#[derive(Debug, Clone, PartialEq)]
254pub enum ExecutionStmtKind {
255 Let(ExecutionLetStmt),
256 Assign(ExecutionAssignStmt),
257 If(ExecutionIfStmt),
258 For(ExecutionForStmt),
259}
260
261#[derive(Debug, Clone, PartialEq)]
262pub struct ExecutionLetStmt {
263 pub local: usize,
264 pub value: ExecutionExpr,
265}
266
267#[derive(Debug, Clone, PartialEq)]
268pub struct ExecutionAssignStmt {
269 pub target: ExecutionTarget,
270 pub value: ExecutionExpr,
271}
272
273#[derive(Debug, Clone, PartialEq)]
274pub struct ExecutionIfStmt {
275 pub condition: ExecutionExpr,
276 pub then_branch: Vec<ExecutionStmt>,
277 pub else_branch: Option<Vec<ExecutionStmt>>,
278}
279
280#[derive(Debug, Clone, PartialEq)]
281pub struct ExecutionForStmt {
282 pub local: usize,
283 pub range: ExecutionRange,
284 pub body: Vec<ExecutionStmt>,
285}
286
287#[derive(Debug, Clone, PartialEq)]
288pub struct ExecutionRange {
289 pub start: ExecutionExpr,
290 pub end: ExecutionExpr,
291 pub span: Span,
292}
293
294#[derive(Debug, Clone, PartialEq)]
295pub struct ExecutionTarget {
296 pub kind: ExecutionTargetKind,
297 pub span: Span,
298}
299
300#[derive(Debug, Clone, PartialEq)]
301pub enum ExecutionTargetKind {
302 Derived(usize),
303 Output(usize),
304 StateInit(ExecutionStateRef),
305 StateDerivative(ExecutionStateRef),
306 StateNoise(ExecutionStateRef),
307 RouteLag(usize),
308 RouteBioavailability(usize),
309}
310
311#[derive(Debug, Clone, PartialEq)]
312pub struct ExecutionStateRef {
313 pub symbol: SymbolId,
314 pub base_offset: usize,
315 pub len: usize,
316 pub index: Option<Box<ExecutionExpr>>,
317 pub span: Span,
318}
319
320#[derive(Debug, Clone, PartialEq)]
321pub struct ExecutionExpr {
322 pub kind: ExecutionExprKind,
323 pub ty: ValueType,
324 pub constant: Option<ConstValue>,
325 pub span: Span,
326}
327
328#[derive(Debug, Clone, PartialEq)]
329pub enum ExecutionExprKind {
330 Literal(ConstValue),
331 Load(ExecutionLoad),
332 Unary {
333 op: TypedUnaryOp,
334 expr: Box<ExecutionExpr>,
335 },
336 Binary {
337 op: TypedBinaryOp,
338 lhs: Box<ExecutionExpr>,
339 rhs: Box<ExecutionExpr>,
340 },
341 Call {
342 callee: ExecutionCall,
343 args: Vec<ExecutionExpr>,
344 },
345}
346
347#[derive(Debug, Clone, PartialEq)]
348pub enum ExecutionLoad {
349 Parameter(usize),
350 Covariate(usize),
351 State(ExecutionStateRef),
352 Derived(usize),
353 Local(usize),
354 RouteInput { route: SymbolId, index: usize },
355}
356
357#[derive(Debug, Clone, PartialEq)]
358pub enum ExecutionCall {
359 Math(MathIntrinsic),
360}
361
362#[derive(Clone, PartialEq, Eq)]
363pub struct LoweringError {
364 diagnostic: Box<Diagnostic>,
365 source: Option<Arc<str>>,
366}
367
368impl LoweringError {
369 fn new(message: impl Into<String>, span: Span) -> Self {
370 Self {
371 diagnostic: Box::new(Diagnostic::error(
372 DSL_LOWERING_GENERIC,
373 DiagnosticPhase::Lowering,
374 message,
375 span,
376 )),
377 source: None,
378 }
379 }
380
381 fn with_note(mut self, note: impl Into<String>) -> Self {
382 self.diagnostic.notes.push(note.into());
383 self
384 }
385
386 pub fn diagnostic(&self) -> &Diagnostic {
387 self.diagnostic.as_ref()
388 }
389
390 pub fn into_diagnostic(self) -> Diagnostic {
391 *self.diagnostic
392 }
393
394 pub fn render(&self, src: &str) -> String {
395 self.diagnostic.render(src)
396 }
397
398 pub fn diagnostic_report(&self, source_name: impl Into<String>) -> DiagnosticReport {
399 DiagnosticReport::from_diagnostics(
400 source_name,
401 self.source(),
402 std::slice::from_ref(self.diagnostic.as_ref()),
403 )
404 }
405
406 pub fn with_source(mut self, source: impl Into<Arc<str>>) -> Self {
407 self.source = Some(source.into());
408 self
409 }
410
411 pub fn source(&self) -> Option<&str> {
412 self.source.as_deref()
413 }
414}
415
416impl fmt::Debug for LoweringError {
417 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
418 fmt::Display::fmt(self, f)
419 }
420}
421
422impl fmt::Display for LoweringError {
423 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
424 if let Some(source) = self.source() {
425 return f.write_str(&self.render(source));
426 }
427 let span = self.diagnostic.primary_span();
428 write!(
429 f,
430 "{} at bytes {}..{}",
431 self.diagnostic.message, span.start, span.end
432 )
433 }
434}
435
436impl std::error::Error for LoweringError {}
437
438struct ExecutionLowerer<'a> {
439 model: &'a TypedModel,
440 metadata: ExecutionMetadata,
441 symbol_map: BTreeMap<SymbolId, &'a Symbol>,
442 parameter_slots: BTreeMap<SymbolId, usize>,
443 covariate_slots: BTreeMap<SymbolId, usize>,
444 state_slots: BTreeMap<SymbolId, StateLayout>,
445 route_slots: BTreeMap<SymbolId, usize>,
446 derived_slots: BTreeMap<SymbolId, usize>,
447 output_slots: BTreeMap<SymbolId, usize>,
448}
449
/// Offset and length assigned to one state.
#[derive(Debug, Clone, Copy)]
struct StateLayout {
    /// Sum of the lengths of the states declared before this one.
    offset: usize,
    /// Number of elements of the state.
    len: usize,
}
455
456impl<'a> ExecutionLowerer<'a> {
457 fn new(model: &'a TypedModel) -> Result<Self, LoweringError> {
458 let symbol_map: BTreeMap<SymbolId, &Symbol> = model
459 .symbols
460 .iter()
461 .map(|symbol| (symbol.id, symbol))
462 .collect();
463
464 let constants = model
465 .constants
466 .iter()
467 .map(|constant| {
468 let symbol = lookup_symbol(&symbol_map, constant.symbol, constant.span)?;
469 Ok(ExecutionConstant {
470 symbol: constant.symbol,
471 name: symbol.name.clone(),
472 value: constant.value.clone(),
473 span: constant.span,
474 })
475 })
476 .collect::<Result<Vec<_>, LoweringError>>()?;
477
478 let mut parameter_slots = BTreeMap::new();
479 let parameters = model
480 .parameters
481 .iter()
482 .enumerate()
483 .map(|(index, symbol_id)| {
484 let symbol = lookup_symbol(&symbol_map, *symbol_id, model.span)?;
485 parameter_slots.insert(*symbol_id, index);
486 Ok(ExecutionSlot {
487 symbol: *symbol_id,
488 name: symbol.name.clone(),
489 index,
490 span: symbol.span,
491 })
492 })
493 .collect::<Result<Vec<_>, LoweringError>>()?;
494
495 let mut covariate_slots = BTreeMap::new();
496 let covariates = model
497 .covariates
498 .iter()
499 .enumerate()
500 .map(|(index, covariate)| {
501 let symbol = lookup_symbol(&symbol_map, covariate.symbol, covariate.span)?;
502 covariate_slots.insert(covariate.symbol, index);
503 Ok(ExecutionCovariate {
504 symbol: covariate.symbol,
505 name: symbol.name.clone(),
506 index,
507 interpolation: covariate.interpolation,
508 span: covariate.span,
509 })
510 })
511 .collect::<Result<Vec<_>, LoweringError>>()?;
512
513 let mut state_slots = BTreeMap::new();
514 let mut states = Vec::with_capacity(model.states.len());
515 let mut next_state_offset = 0usize;
516 for state in &model.states {
517 let symbol = lookup_symbol(&symbol_map, state.symbol, state.span)?;
518 let len = state.size.unwrap_or(1);
519 state_slots.insert(
520 state.symbol,
521 StateLayout {
522 offset: next_state_offset,
523 len,
524 },
525 );
526 states.push(ExecutionState {
527 symbol: state.symbol,
528 name: symbol.name.clone(),
529 offset: next_state_offset,
530 len,
531 span: state.span,
532 });
533 next_state_offset += len;
534 }
535
536 let uses_authoring_route_kinds =
537 !model.routes.is_empty() && model.routes.iter().all(|route| route.kind.is_some());
538 let mut route_slots = BTreeMap::new();
539 let mut routes = Vec::with_capacity(model.routes.len());
540 let mut next_bolus_index = 0usize;
541 let mut next_infusion_index = 0usize;
542 for (declaration_index, route) in model.routes.iter().enumerate() {
543 let symbol = lookup_symbol(&symbol_map, route.symbol, route.span)?;
544 if route.kind == Some(RouteKind::Infusion) {
545 if let Some(property) = route.properties.first() {
546 let label = match property.kind {
547 RoutePropertyKind::Lag => "lag",
548 RoutePropertyKind::Bioavailability => "bioavailability",
549 };
550 return Err(LoweringError::new(
551 format!(
552 "DSL authoring does not allow `{label}` on infusion route `{}`",
553 symbol.name
554 ),
555 property.span,
556 )
557 .with_note("lag and bioavailability are bolus-only route properties"));
558 }
559 }
560 let index = if uses_authoring_route_kinds {
561 match route.kind.expect("authoring routes must preserve kind") {
562 RouteKind::Bolus => {
563 let index = next_bolus_index;
564 next_bolus_index += 1;
565 index
566 }
567 RouteKind::Infusion => {
568 let index = next_infusion_index;
569 next_infusion_index += 1;
570 index
571 }
572 }
573 } else {
574 declaration_index
575 };
576 route_slots.insert(route.symbol, index);
577 let destination =
578 lower_route_destination(&symbol_map, &state_slots, &route.destination)?;
579 routes.push(ExecutionRoute {
580 symbol: route.symbol,
581 name: symbol.name.clone(),
582 declaration_index,
583 index,
584 kind: route.kind,
585 destination,
586 has_lag: route
587 .properties
588 .iter()
589 .any(|property| property.kind == RoutePropertyKind::Lag),
590 has_bioavailability: route
591 .properties
592 .iter()
593 .any(|property| property.kind == RoutePropertyKind::Bioavailability),
594 span: route.span,
595 });
596 }
597
598 let mut derived_slots = BTreeMap::new();
599 let derived = model
600 .derived
601 .iter()
602 .enumerate()
603 .map(|(index, symbol_id)| {
604 let symbol = lookup_symbol(&symbol_map, *symbol_id, model.span)?;
605 derived_slots.insert(*symbol_id, index);
606 Ok(ExecutionSlot {
607 symbol: *symbol_id,
608 name: symbol.name.clone(),
609 index,
610 span: symbol.span,
611 })
612 })
613 .collect::<Result<Vec<_>, LoweringError>>()?;
614
615 let mut output_slots = BTreeMap::new();
616 let outputs = model
617 .outputs
618 .iter()
619 .enumerate()
620 .map(|(index, symbol_id)| {
621 let symbol = lookup_symbol(&symbol_map, *symbol_id, model.span)?;
622 output_slots.insert(*symbol_id, index);
623 Ok(ExecutionSlot {
624 symbol: *symbol_id,
625 name: symbol.name.clone(),
626 index,
627 span: symbol.span,
628 })
629 })
630 .collect::<Result<Vec<_>, LoweringError>>()?;
631
632 Ok(Self {
633 model,
634 metadata: ExecutionMetadata {
635 constants,
636 parameters,
637 covariates,
638 states,
639 routes,
640 derived,
641 outputs,
642 particles: model.particles,
643 analytical: model
644 .analytical
645 .as_ref()
646 .map(|analytical| analytical.structure),
647 },
648 symbol_map,
649 parameter_slots,
650 covariate_slots,
651 state_slots,
652 route_slots,
653 derived_slots,
654 output_slots,
655 })
656 }
657
658 fn lower(self) -> Result<ExecutionModel, LoweringError> {
659 let abi = self.build_abi();
660 let mut kernels = Vec::new();
661
662 if let Some(block) = &self.model.derive {
663 kernels.push(self.lower_statement_kernel(KernelRole::Derive, block)?);
664 }
665 if let Some(block) = &self.model.init {
666 kernels.push(self.lower_init_kernel(block)?);
667 }
668 if let Some(block) = &self.model.dynamics {
669 kernels.push(self.lower_statement_kernel(KernelRole::Dynamics, block)?);
670 }
671 if let Some(block) = &self.model.drift {
672 kernels.push(self.lower_statement_kernel(KernelRole::Drift, block)?);
673 }
674 if let Some(block) = &self.model.diffusion {
675 kernels.push(self.lower_statement_kernel(KernelRole::Diffusion, block)?);
676 }
677 if let Some(kernel) =
678 self.lower_route_property_kernel(RoutePropertyKind::Lag, KernelRole::RouteLag)?
679 {
680 kernels.push(kernel);
681 }
682 if let Some(kernel) = self.lower_route_property_kernel(
683 RoutePropertyKind::Bioavailability,
684 KernelRole::RouteBioavailability,
685 )? {
686 kernels.push(kernel);
687 }
688 if let Some(analytical) = &self.model.analytical {
689 kernels.push(ExecutionKernel {
690 role: KernelRole::Analytical,
691 signature: signature_for(KernelRole::Analytical),
692 implementation: KernelImplementation::AnalyticalBuiltin(analytical.structure),
693 span: analytical.span,
694 });
695 }
696 kernels.push(self.lower_statement_kernel(KernelRole::Outputs, &self.model.outputs_block)?);
697
698 Ok(ExecutionModel {
699 name: self.model.name.clone(),
700 kind: self.model.kind,
701 metadata: self.metadata,
702 abi,
703 kernels,
704 span: self.model.span,
705 })
706 }
707
708 fn build_abi(&self) -> ExecutionAbi {
709 ExecutionAbi {
710 scalar: ScalarAbi::F64,
711 calling_convention: CallingConvention::DenseF64Buffers,
712 parameter_buffer: DenseBufferLayout {
713 kind: BufferKind::Parameters,
714 len: self.metadata.parameters.len(),
715 slots: self
716 .metadata
717 .parameters
718 .iter()
719 .map(|slot| BufferSlot {
720 name: slot.name.clone(),
721 offset: slot.index,
722 len: 1,
723 })
724 .collect(),
725 },
726 covariate_buffer: DenseBufferLayout {
727 kind: BufferKind::Covariates,
728 len: self.metadata.covariates.len(),
729 slots: self
730 .metadata
731 .covariates
732 .iter()
733 .map(|slot| BufferSlot {
734 name: slot.name.clone(),
735 offset: slot.index,
736 len: 1,
737 })
738 .collect(),
739 },
740 state_buffer: DenseBufferLayout {
741 kind: BufferKind::States,
742 len: self.metadata.states.iter().map(|state| state.len).sum(),
743 slots: self
744 .metadata
745 .states
746 .iter()
747 .map(|state| BufferSlot {
748 name: state.name.clone(),
749 offset: state.offset,
750 len: state.len,
751 })
752 .collect(),
753 },
754 derived_buffer: DenseBufferLayout {
755 kind: BufferKind::Derived,
756 len: self.metadata.derived.len(),
757 slots: self
758 .metadata
759 .derived
760 .iter()
761 .map(|slot| BufferSlot {
762 name: slot.name.clone(),
763 offset: slot.index,
764 len: 1,
765 })
766 .collect(),
767 },
768 output_buffer: DenseBufferLayout {
769 kind: BufferKind::Outputs,
770 len: self.metadata.outputs.len(),
771 slots: self
772 .metadata
773 .outputs
774 .iter()
775 .map(|slot| BufferSlot {
776 name: slot.name.clone(),
777 offset: slot.index,
778 len: 1,
779 })
780 .collect(),
781 },
782 route_buffer: DenseBufferLayout {
783 kind: BufferKind::Routes,
784 len: self
785 .metadata
786 .routes
787 .iter()
788 .map(|route| route.index + 1)
789 .max()
790 .unwrap_or(0),
791 slots: self
792 .metadata
793 .routes
794 .iter()
795 .map(|route| BufferSlot {
796 name: route.name.clone(),
797 offset: route.index,
798 len: 1,
799 })
800 .collect(),
801 },
802 }
803 }
804
805 fn lower_statement_kernel(
806 &self,
807 role: KernelRole,
808 block: &TypedStatementBlock,
809 ) -> Result<ExecutionKernel, LoweringError> {
810 let mut locals = LocalLowering::default();
811 let statements = block
812 .statements
813 .iter()
814 .map(|stmt| self.lower_stmt(stmt, &mut locals))
815 .collect::<Result<Vec<_>, LoweringError>>()?;
816
817 Ok(ExecutionKernel {
818 role,
819 signature: signature_for(role),
820 implementation: KernelImplementation::Statements(ExecutionProgram {
821 locals: locals.locals,
822 body: ExecutionBlock {
823 statements,
824 span: block.span,
825 },
826 }),
827 span: block.span,
828 })
829 }
830
831 fn lower_init_kernel(
832 &self,
833 block: &TypedStatementBlock,
834 ) -> Result<ExecutionKernel, LoweringError> {
835 let mut locals = LocalLowering::default();
836 let mut statements = self
837 .metadata
838 .states
839 .iter()
840 .flat_map(|state| {
841 let base = (0..state.len).map(|component| ExecutionStmt {
842 kind: ExecutionStmtKind::Assign(ExecutionAssignStmt {
843 target: ExecutionTarget {
844 kind: ExecutionTargetKind::StateInit(ExecutionStateRef {
845 symbol: state.symbol,
846 base_offset: state.offset,
847 len: state.len,
848 index: if state.len == 1 {
849 None
850 } else {
851 Some(Box::new(literal_int(component as i64, state.span)))
852 },
853 span: state.span,
854 }),
855 span: state.span,
856 },
857 value: literal_real(0.0, state.span),
858 }),
859 span: state.span,
860 });
861 base.collect::<Vec<_>>()
862 })
863 .collect::<Vec<_>>();
864
865 statements.extend(
866 block
867 .statements
868 .iter()
869 .map(|stmt| self.lower_stmt(stmt, &mut locals))
870 .collect::<Result<Vec<_>, LoweringError>>()?,
871 );
872
873 Ok(ExecutionKernel {
874 role: KernelRole::Init,
875 signature: signature_for(KernelRole::Init),
876 implementation: KernelImplementation::Statements(ExecutionProgram {
877 locals: locals.locals,
878 body: ExecutionBlock {
879 statements,
880 span: block.span,
881 },
882 }),
883 span: block.span,
884 })
885 }
886
887 fn lower_route_property_kernel(
888 &self,
889 property_kind: RoutePropertyKind,
890 role: KernelRole,
891 ) -> Result<Option<ExecutionKernel>, LoweringError> {
892 if !self.model.routes.iter().any(|route| {
893 route
894 .properties
895 .iter()
896 .any(|property| property.kind == property_kind)
897 }) {
898 return Ok(None);
899 }
900
901 let mut statements = Vec::with_capacity(self.model.routes.len());
902 let mut locals = LocalLowering::default();
903 let default_value = match property_kind {
904 RoutePropertyKind::Lag => literal_real(0.0, self.model.span),
905 RoutePropertyKind::Bioavailability => literal_real(1.0, self.model.span),
906 };
907 let route_len = self
908 .metadata
909 .routes
910 .iter()
911 .map(|route| route.index + 1)
912 .max()
913 .unwrap_or(0);
914 for route_index in 0..route_len {
915 let target_kind = match property_kind {
916 RoutePropertyKind::Lag => ExecutionTargetKind::RouteLag(route_index),
917 RoutePropertyKind::Bioavailability => {
918 ExecutionTargetKind::RouteBioavailability(route_index)
919 }
920 };
921 statements.push(ExecutionStmt {
922 kind: ExecutionStmtKind::Assign(ExecutionAssignStmt {
923 target: ExecutionTarget {
924 kind: target_kind,
925 span: self.model.span,
926 },
927 value: default_value.clone(),
928 }),
929 span: self.model.span,
930 });
931 }
932 for route in &self.model.routes {
933 if route.kind == Some(RouteKind::Infusion) {
934 continue;
935 }
936 let route_name = self.symbol_name(route.symbol)?.to_string();
937 let route_index = *self.route_slots.get(&route.symbol).ok_or_else(|| {
938 LoweringError::new(
939 format!("route `{}` has no execution slot", route_name),
940 route.span,
941 )
942 })?;
943 let expression = match route
944 .properties
945 .iter()
946 .find(|property| property.kind == property_kind)
947 {
948 Some(property) => self.lower_expr(&property.value, &mut locals)?,
949 None => continue,
950 };
951 let target_kind = match property_kind {
952 RoutePropertyKind::Lag => ExecutionTargetKind::RouteLag(route_index),
953 RoutePropertyKind::Bioavailability => {
954 ExecutionTargetKind::RouteBioavailability(route_index)
955 }
956 };
957 statements.push(ExecutionStmt {
958 kind: ExecutionStmtKind::Assign(ExecutionAssignStmt {
959 target: ExecutionTarget {
960 kind: target_kind,
961 span: route.span,
962 },
963 value: expression,
964 }),
965 span: route.span,
966 });
967 }
968
969 Ok(Some(ExecutionKernel {
970 role,
971 signature: signature_for(role),
972 implementation: KernelImplementation::Statements(ExecutionProgram {
973 locals: locals.locals,
974 body: ExecutionBlock {
975 statements,
976 span: self.model.span,
977 },
978 }),
979 span: self.model.span,
980 }))
981 }
982
983 fn lower_stmt(
984 &self,
985 stmt: &TypedStmt,
986 locals: &mut LocalLowering,
987 ) -> Result<ExecutionStmt, LoweringError> {
988 let kind = match &stmt.kind {
989 TypedStmtKind::Let(let_stmt) => {
990 let local = locals.local_slot(let_stmt.symbol, self)?;
991 ExecutionStmtKind::Let(ExecutionLetStmt {
992 local,
993 value: self.lower_expr(&let_stmt.value, locals)?,
994 })
995 }
996 TypedStmtKind::Assign(assign) => ExecutionStmtKind::Assign(ExecutionAssignStmt {
997 target: self.lower_target(&assign.target.kind, assign.target.span, locals)?,
998 value: self.lower_expr(&assign.value, locals)?,
999 }),
1000 TypedStmtKind::If(if_stmt) => ExecutionStmtKind::If(ExecutionIfStmt {
1001 condition: self.lower_expr(&if_stmt.condition, locals)?,
1002 then_branch: if_stmt
1003 .then_branch
1004 .iter()
1005 .map(|stmt| self.lower_stmt(stmt, locals))
1006 .collect::<Result<Vec<_>, _>>()?,
1007 else_branch: if_stmt
1008 .else_branch
1009 .as_ref()
1010 .map(|branch| {
1011 branch
1012 .iter()
1013 .map(|stmt| self.lower_stmt(stmt, locals))
1014 .collect::<Result<Vec<_>, LoweringError>>()
1015 })
1016 .transpose()?,
1017 }),
1018 TypedStmtKind::For(for_stmt) => {
1019 let local = locals.local_slot(for_stmt.binding, self)?;
1020 ExecutionStmtKind::For(ExecutionForStmt {
1021 local,
1022 range: self.lower_range(&for_stmt.range, locals)?,
1023 body: for_stmt
1024 .body
1025 .iter()
1026 .map(|stmt| self.lower_stmt(stmt, locals))
1027 .collect::<Result<Vec<_>, _>>()?,
1028 })
1029 }
1030 };
1031
1032 Ok(ExecutionStmt {
1033 kind,
1034 span: stmt.span,
1035 })
1036 }
1037
1038 fn lower_range(
1039 &self,
1040 range: &TypedRangeExpr,
1041 locals: &mut LocalLowering,
1042 ) -> Result<ExecutionRange, LoweringError> {
1043 Ok(ExecutionRange {
1044 start: self.lower_expr(&range.start, locals)?,
1045 end: self.lower_expr(&range.end, locals)?,
1046 span: range.span,
1047 })
1048 }
1049
1050 fn lower_target(
1051 &self,
1052 target: &TypedAssignTargetKind,
1053 span: Span,
1054 locals: &mut LocalLowering,
1055 ) -> Result<ExecutionTarget, LoweringError> {
1056 let kind = match target {
1057 TypedAssignTargetKind::Derived(symbol) => {
1058 ExecutionTargetKind::Derived(self.slot_for_derived(*symbol, span)?)
1059 }
1060 TypedAssignTargetKind::Output(symbol) => {
1061 ExecutionTargetKind::Output(self.slot_for_output(*symbol, span)?)
1062 }
1063 TypedAssignTargetKind::StateInit(place) => {
1064 ExecutionTargetKind::StateInit(self.lower_state_ref(place, locals)?)
1065 }
1066 TypedAssignTargetKind::Derivative(place) => {
1067 ExecutionTargetKind::StateDerivative(self.lower_state_ref(place, locals)?)
1068 }
1069 TypedAssignTargetKind::Noise(place) => {
1070 ExecutionTargetKind::StateNoise(self.lower_state_ref(place, locals)?)
1071 }
1072 };
1073 Ok(ExecutionTarget { kind, span })
1074 }
1075
1076 fn lower_expr(
1077 &self,
1078 expr: &TypedExpr,
1079 locals: &mut LocalLowering,
1080 ) -> Result<ExecutionExpr, LoweringError> {
1081 if let Some(constant) = &expr.constant {
1082 return Ok(ExecutionExpr {
1083 kind: ExecutionExprKind::Literal(constant.clone()),
1084 ty: expr.ty,
1085 constant: Some(constant.clone()),
1086 span: expr.span,
1087 });
1088 }
1089
1090 let kind = match &expr.kind {
1091 TypedExprKind::Literal(constant) => ExecutionExprKind::Literal(constant.clone()),
1092 TypedExprKind::Symbol(symbol) => {
1093 let symbol_info = lookup_symbol(&self.symbol_map, *symbol, expr.span)?;
1094 match symbol_info.kind {
1095 SymbolKind::Parameter => ExecutionExprKind::Load(ExecutionLoad::Parameter(
1096 self.slot_for_parameter(*symbol, expr.span)?,
1097 )),
1098 SymbolKind::Covariate => ExecutionExprKind::Load(ExecutionLoad::Covariate(
1099 self.slot_for_covariate(*symbol, expr.span)?,
1100 )),
1101 SymbolKind::Derived => ExecutionExprKind::Load(ExecutionLoad::Derived(
1102 self.slot_for_derived(*symbol, expr.span)?,
1103 )),
1104 SymbolKind::Local | SymbolKind::LoopBinding => ExecutionExprKind::Load(
1105 ExecutionLoad::Local(locals.local_slot(*symbol, self)?),
1106 ),
1107 SymbolKind::Constant => {
1108 return Err(LoweringError::new(
1109 format!(
1110 "constant `{}` should have been folded before execution lowering",
1111 symbol_info.name
1112 ),
1113 expr.span,
1114 ));
1115 }
1116 SymbolKind::State => {
1117 return Err(LoweringError::new(
1118 format!(
1119 "state `{}` should lower through a state reference",
1120 symbol_info.name
1121 ),
1122 expr.span,
1123 ));
1124 }
1125 SymbolKind::Route => {
1126 return Err(LoweringError::new(
1127 format!(
1128 "route `{}` is not a scalar execution input",
1129 symbol_info.name
1130 ),
1131 expr.span,
1132 )
1133 .with_note("routes must lower through `rate(route)` or route metadata"));
1134 }
1135 SymbolKind::Output => {
1136 return Err(LoweringError::new(
1137 format!(
1138 "output `{}` cannot be read inside execution kernels",
1139 symbol_info.name
1140 ),
1141 expr.span,
1142 ));
1143 }
1144 }
1145 }
1146 TypedExprKind::StateValue(place) => {
1147 ExecutionExprKind::Load(ExecutionLoad::State(self.lower_state_ref(place, locals)?))
1148 }
1149 TypedExprKind::Unary { op, expr } => ExecutionExprKind::Unary {
1150 op: *op,
1151 expr: Box::new(self.lower_expr(expr, locals)?),
1152 },
1153 TypedExprKind::Binary { op, lhs, rhs } => ExecutionExprKind::Binary {
1154 op: *op,
1155 lhs: Box::new(self.lower_expr(lhs, locals)?),
1156 rhs: Box::new(self.lower_expr(rhs, locals)?),
1157 },
1158 TypedExprKind::Call { callee, args } => match callee {
1159 TypedCall::Math(intrinsic) => ExecutionExprKind::Call {
1160 callee: ExecutionCall::Math(*intrinsic),
1161 args: args
1162 .iter()
1163 .map(|arg| self.lower_expr(arg, locals))
1164 .collect::<Result<Vec<_>, _>>()?,
1165 },
1166 TypedCall::Rate(route) => {
1167 let route_name = self.symbol_name(*route)?.to_string();
1168 let route_index = *self.route_slots.get(route).ok_or_else(|| {
1169 LoweringError::new(
1170 format!("route `{}` has no execution slot", route_name),
1171 expr.span,
1172 )
1173 })?;
1174 ExecutionExprKind::Load(ExecutionLoad::RouteInput {
1175 route: *route,
1176 index: route_index,
1177 })
1178 }
1179 },
1180 };
1181
1182 Ok(ExecutionExpr {
1183 kind,
1184 ty: expr.ty,
1185 constant: None,
1186 span: expr.span,
1187 })
1188 }
1189
1190 fn lower_state_ref(
1191 &self,
1192 place: &TypedStatePlace,
1193 locals: &mut LocalLowering,
1194 ) -> Result<ExecutionStateRef, LoweringError> {
1195 let state_name = self.symbol_name(place.state)?.to_string();
1196 let layout = self.state_slots.get(&place.state).copied().ok_or_else(|| {
1197 LoweringError::new(
1198 format!("state `{}` has no execution layout", state_name),
1199 place.span,
1200 )
1201 })?;
1202 let index = place
1203 .index
1204 .as_ref()
1205 .map(|index| self.lower_expr(index, locals))
1206 .transpose()?
1207 .map(Box::new);
1208 Ok(ExecutionStateRef {
1209 symbol: place.state,
1210 base_offset: layout.offset,
1211 len: layout.len,
1212 index,
1213 span: place.span,
1214 })
1215 }
1216
1217 fn slot_for_parameter(&self, symbol: SymbolId, span: Span) -> Result<usize, LoweringError> {
1218 self.parameter_slots.get(&symbol).copied().ok_or_else(|| {
1219 LoweringError::new(
1220 format!(
1221 "parameter `{}` has no ABI slot",
1222 self.symbol_name(symbol).unwrap_or("<unknown>")
1223 ),
1224 span,
1225 )
1226 })
1227 }
1228
1229 fn slot_for_covariate(&self, symbol: SymbolId, span: Span) -> Result<usize, LoweringError> {
1230 self.covariate_slots.get(&symbol).copied().ok_or_else(|| {
1231 LoweringError::new(
1232 format!(
1233 "covariate `{}` has no ABI slot",
1234 self.symbol_name(symbol).unwrap_or("<unknown>")
1235 ),
1236 span,
1237 )
1238 })
1239 }
1240
1241 fn slot_for_derived(&self, symbol: SymbolId, span: Span) -> Result<usize, LoweringError> {
1242 self.derived_slots.get(&symbol).copied().ok_or_else(|| {
1243 LoweringError::new(
1244 format!(
1245 "derived value `{}` has no ABI slot",
1246 self.symbol_name(symbol).unwrap_or("<unknown>")
1247 ),
1248 span,
1249 )
1250 })
1251 }
1252
1253 fn slot_for_output(&self, symbol: SymbolId, span: Span) -> Result<usize, LoweringError> {
1254 self.output_slots.get(&symbol).copied().ok_or_else(|| {
1255 LoweringError::new(
1256 format!(
1257 "output `{}` has no ABI slot",
1258 self.symbol_name(symbol).unwrap_or("<unknown>")
1259 ),
1260 span,
1261 )
1262 })
1263 }
1264
1265 fn symbol_name(&self, symbol: SymbolId) -> Result<&str, LoweringError> {
1266 Ok(&lookup_symbol(&self.symbol_map, symbol, self.model.span)?.name)
1267 }
1268}
1269
/// Registry of the local slots handed out by `local_slot`.
#[derive(Default)]
struct LocalLowering {
    /// Lowered locals in slot order: the entry at position `i` was created with `index == i`.
    locals: Vec<ExecutionLocal>,
    /// Maps every symbol that already owns a local to its position in `locals`.
    slots: BTreeMap<SymbolId, usize>,
}
1275
1276impl LocalLowering {
1277 fn local_slot(
1278 &mut self,
1279 symbol: SymbolId,
1280 lowerer: &ExecutionLowerer<'_>,
1281 ) -> Result<usize, LoweringError> {
1282 if let Some(slot) = self.slots.get(&symbol).copied() {
1283 return Ok(slot);
1284 }
1285 let symbol_info = lookup_symbol(&lowerer.symbol_map, symbol, lowerer.model.span)?;
1286 let ty = match symbol_info.ty {
1287 SymbolType::Scalar(ty) => ty,
1288 SymbolType::Array { .. } => {
1289 return Err(LoweringError::new(
1290 format!("local `{}` must be scalar", symbol_info.name),
1291 symbol_info.span,
1292 ));
1293 }
1294 SymbolType::Route => {
1295 return Err(LoweringError::new(
1296 format!("local `{}` cannot be a route handle", symbol_info.name),
1297 symbol_info.span,
1298 ));
1299 }
1300 };
1301 let slot = self.locals.len();
1302 self.locals.push(ExecutionLocal {
1303 symbol,
1304 name: symbol_info.name.clone(),
1305 index: slot,
1306 ty,
1307 kind: symbol_info.kind,
1308 span: symbol_info.span,
1309 });
1310 self.slots.insert(symbol, slot);
1311 Ok(slot)
1312 }
1313}
1314
1315fn signature_for(role: KernelRole) -> KernelSignature {
1316 let args = match role {
1317 KernelRole::Derive => vec![
1318 arg(KernelArgumentKind::Time, KernelAccess::Input),
1319 arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1320 arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1321 arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1322 arg(KernelArgumentKind::States, KernelAccess::Input),
1323 arg(KernelArgumentKind::Derived, KernelAccess::Output),
1324 ],
1325 KernelRole::Dynamics => vec![
1326 arg(KernelArgumentKind::Time, KernelAccess::Input),
1327 arg(KernelArgumentKind::States, KernelAccess::Input),
1328 arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1329 arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1330 arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1331 arg(KernelArgumentKind::Derived, KernelAccess::Input),
1332 arg(KernelArgumentKind::StateDerivatives, KernelAccess::Output),
1333 ],
1334 KernelRole::Outputs => vec![
1335 arg(KernelArgumentKind::Time, KernelAccess::Input),
1336 arg(KernelArgumentKind::States, KernelAccess::Input),
1337 arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1338 arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1339 arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1340 arg(KernelArgumentKind::Derived, KernelAccess::Input),
1341 arg(KernelArgumentKind::Outputs, KernelAccess::Output),
1342 ],
1343 KernelRole::Init => vec![
1344 arg(KernelArgumentKind::Time, KernelAccess::Input),
1345 arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1346 arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1347 arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1348 arg(KernelArgumentKind::Derived, KernelAccess::Input),
1349 arg(KernelArgumentKind::InitialState, KernelAccess::Output),
1350 ],
1351 KernelRole::Drift => vec![
1352 arg(KernelArgumentKind::Time, KernelAccess::Input),
1353 arg(KernelArgumentKind::States, KernelAccess::Input),
1354 arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1355 arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1356 arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1357 arg(KernelArgumentKind::Derived, KernelAccess::Input),
1358 arg(KernelArgumentKind::StateDerivatives, KernelAccess::Output),
1359 ],
1360 KernelRole::Diffusion => vec![
1361 arg(KernelArgumentKind::Time, KernelAccess::Input),
1362 arg(KernelArgumentKind::States, KernelAccess::Input),
1363 arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1364 arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1365 arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1366 arg(KernelArgumentKind::Derived, KernelAccess::Input),
1367 arg(KernelArgumentKind::StateNoise, KernelAccess::Output),
1368 ],
1369 KernelRole::RouteLag => vec![
1370 arg(KernelArgumentKind::Time, KernelAccess::Input),
1371 arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1372 arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1373 arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1374 arg(KernelArgumentKind::Derived, KernelAccess::Input),
1375 arg(KernelArgumentKind::RouteLag, KernelAccess::Output),
1376 ],
1377 KernelRole::RouteBioavailability => vec![
1378 arg(KernelArgumentKind::Time, KernelAccess::Input),
1379 arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1380 arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1381 arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1382 arg(KernelArgumentKind::Derived, KernelAccess::Input),
1383 arg(
1384 KernelArgumentKind::RouteBioavailability,
1385 KernelAccess::Output,
1386 ),
1387 ],
1388 KernelRole::Analytical => vec![
1389 arg(KernelArgumentKind::Time, KernelAccess::Input),
1390 arg(KernelArgumentKind::States, KernelAccess::Input),
1391 arg(KernelArgumentKind::Parameters, KernelAccess::Input),
1392 arg(KernelArgumentKind::Covariates, KernelAccess::Input),
1393 arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
1394 arg(KernelArgumentKind::Derived, KernelAccess::Input),
1395 arg(KernelArgumentKind::AnalyticalState, KernelAccess::Output),
1396 ],
1397 };
1398 KernelSignature { args }
1399}
1400
/// Shorthand constructor that pairs an argument kind with its access mode.
fn arg(kind: KernelArgumentKind, access: KernelAccess) -> KernelArgument {
    KernelArgument { kind, access }
}
1404
1405fn lookup_symbol<'a>(
1406 symbols: &'a BTreeMap<SymbolId, &'a Symbol>,
1407 symbol: SymbolId,
1408 span: Span,
1409) -> Result<&'a Symbol, LoweringError> {
1410 symbols.get(&symbol).copied().ok_or_else(|| {
1411 LoweringError::new(
1412 format!("symbol id {symbol} is missing from the typed model symbol table"),
1413 span,
1414 )
1415 })
1416}
1417
1418fn lower_route_destination(
1419 symbols: &BTreeMap<SymbolId, &Symbol>,
1420 state_slots: &BTreeMap<SymbolId, StateLayout>,
1421 destination: &TypedStatePlace,
1422) -> Result<RouteDestination, LoweringError> {
1423 let symbol = lookup_symbol(symbols, destination.state, destination.span)?;
1424 let layout = state_slots
1425 .get(&destination.state)
1426 .copied()
1427 .ok_or_else(|| {
1428 LoweringError::new(
1429 format!("state `{}` has no execution layout", symbol.name),
1430 destination.span,
1431 )
1432 })?;
1433 let element = match &destination.index {
1434 None => 0,
1435 Some(index) => constant_index(index, destination.span)?,
1436 };
1437 if element >= layout.len {
1438 return Err(LoweringError::new(
1439 format!(
1440 "route destination for `{}` indexes element {}, but state length is {}",
1441 symbol.name, element, layout.len
1442 ),
1443 destination.span,
1444 ));
1445 }
1446 Ok(RouteDestination {
1447 state: destination.state,
1448 state_name: symbol.name.clone(),
1449 state_offset: layout.offset + element,
1450 span: destination.span,
1451 })
1452}
1453
1454fn constant_index(expr: &TypedExpr, span: Span) -> Result<usize, LoweringError> {
1455 let value = expr
1456 .constant
1457 .as_ref()
1458 .and_then(ConstValue::as_i64)
1459 .ok_or_else(|| LoweringError::new("expected a compile-time integer index", span))?;
1460 if value < 0 {
1461 return Err(LoweringError::new(
1462 "expected a non-negative compile-time index",
1463 span,
1464 ));
1465 }
1466 Ok(value as usize)
1467}
1468
1469fn literal_real(value: f64, span: Span) -> ExecutionExpr {
1470 ExecutionExpr {
1471 kind: ExecutionExprKind::Literal(ConstValue::Real(value)),
1472 ty: ValueType::Real,
1473 constant: Some(ConstValue::Real(value)),
1474 span,
1475 }
1476}
1477
1478fn literal_int(value: i64, span: Span) -> ExecutionExpr {
1479 ExecutionExpr {
1480 kind: ExecutionExprKind::Literal(ConstValue::Int(value)),
1481 ty: ValueType::Int,
1482 constant: Some(ConstValue::Int(value)),
1483 span,
1484 }
1485}
1486
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_fixtures::STRUCTURED_BLOCK_CORPUS;
    use crate::{analyze_module, parse_module};

    /// Parses, analyzes and lowers the shared structured-block fixture.
    fn structured_block_execution() -> ExecutionModule {
        let parsed =
            parse_module(STRUCTURED_BLOCK_CORPUS).expect("structured-block fixture parses");
        let analyzed = analyze_module(&parsed).expect("structured-block fixture analyzes");
        lower_typed_module(&analyzed).expect("execution lowering succeeds")
    }

    /// Returns the lowered model called `name`, panicking when the module has none.
    fn find_model<'a>(module: &'a ExecutionModule, name: &str) -> &'a ExecutionModel {
        for candidate in &module.models {
            if candidate.name == name {
                return candidate;
            }
        }
        panic!("missing model {name}")
    }

    /// Lists the roles of a model's kernels in emission order.
    fn kernel_roles(model: &ExecutionModel) -> Vec<KernelRole> {
        let mut roles = Vec::with_capacity(model.kernels.len());
        for kernel in &model.kernels {
            roles.push(kernel.role);
        }
        roles
    }

    /// The fixture lowers to four models; the ODE model gets the expected buffers and kernels.
    #[test]
    fn lowers_structured_block_corpus_into_execution_models() {
        let execution = structured_block_execution();
        assert_eq!(execution.models.len(), 4);

        let ode = find_model(&execution, "one_cmt_oral_iv");
        let abi = &ode.abi;
        assert_eq!(abi.parameter_buffer.len, 5);
        assert_eq!(abi.covariate_buffer.len, 1);
        assert_eq!(abi.state_buffer.len, 2);
        assert_eq!(abi.derived_buffer.len, 3);
        assert_eq!(abi.output_buffer.len, 1);
        assert_eq!(abi.route_buffer.len, 2);
        assert_eq!(ode.metadata.routes[0].destination.state_offset, 0);

        let expected_roles = vec![
            KernelRole::Derive,
            KernelRole::Dynamics,
            KernelRole::RouteLag,
            KernelRole::RouteBioavailability,
            KernelRole::Outputs,
        ];
        assert_eq!(kernel_roles(ode), expected_roles);
    }

    /// Routes of different kinds are numbered per kind, so both land on input index 0.
    #[test]
    fn authoring_routes_share_input_indices_by_kind_local_ordinal() {
        let src = r#"name = shared_authoring
kind = ode

params = ka, ke, v, tlag, f_oral
states = depot, central
outputs = cp

bolus(oral) -> depot
infusion(iv) -> central
lag(oral) = tlag
fa(oral) = f_oral

dx(depot) = -ka * depot
dx(central) = ka * depot - ke * central

out(cp) = central / v ~ continuous()
"#;

        let parsed = crate::parse_model(src).expect("authoring model parses");
        let analyzed = crate::analyze_model(&parsed).expect("authoring model analyzes");
        let lowered = crate::lower_typed_model(&analyzed).expect("authoring model lowers");

        let routes = &lowered.metadata.routes;
        assert_eq!(lowered.abi.route_buffer.len, 1);
        assert_eq!(routes.len(), 2);

        let (bolus, infusion) = (&routes[0], &routes[1]);
        assert_eq!(bolus.kind, Some(RouteKind::Bolus));
        assert_eq!(infusion.kind, Some(RouteKind::Infusion));
        assert_eq!(bolus.declaration_index, 0);
        assert_eq!(infusion.declaration_index, 1);
        assert_eq!(bolus.index, 0);
        assert_eq!(infusion.index, 0);
        assert!(bolus.has_lag);
        assert!(bolus.has_bioavailability);
        assert!(!infusion.has_lag);
        assert!(!infusion.has_bioavailability);
    }

    /// Channel names written in the source reach both the metadata and the ABI slots unchanged.
    #[test]
    fn canonical_numeric_channel_names_flow_into_execution_metadata_and_abi() {
        let src = r#"name = canonical_numeric_channels
kind = ode

params = ke, v
states = depot, central
outputs = cp, outeq_2

bolus(input_10) -> depot
infusion(iv) -> central

dx(depot) = -ke * depot
dx(central) = rate(input_10) - ke * central

out(cp) = central / v
out(outeq_2) = depot / v
"#;

        let parsed = crate::parse_model(src).expect("authoring model parses");
        let analyzed = crate::analyze_model(&parsed).expect("authoring model analyzes");
        let lowered = crate::lower_typed_model(&analyzed).expect("authoring model lowers");

        let metadata_routes: Vec<_> = lowered
            .metadata
            .routes
            .iter()
            .map(|route| route.name.as_str())
            .collect();
        assert_eq!(metadata_routes, vec!["input_10", "iv"]);

        let metadata_outputs: Vec<_> = lowered
            .metadata
            .outputs
            .iter()
            .map(|output| output.name.as_str())
            .collect();
        assert_eq!(metadata_outputs, vec!["cp", "outeq_2"]);

        let abi_routes: Vec<_> = lowered
            .abi
            .route_buffer
            .slots
            .iter()
            .map(|slot| slot.name.as_str())
            .collect();
        assert_eq!(abi_routes, vec!["input_10", "iv"]);

        let abi_outputs: Vec<_> = lowered
            .abi
            .output_buffer
            .slots
            .iter()
            .map(|slot| slot.name.as_str())
            .collect();
        assert_eq!(abi_outputs, vec!["cp", "outeq_2"]);
    }

    /// A `lag` property on an infusion route is rejected during lowering.
    #[test]
    fn authoring_routes_reject_infusion_lag_properties() {
        let src = r#"name = invalid_infusion_lag
kind = ode

params = ke, v, tlag
states = central
outputs = cp

infusion(iv) -> central
lag(iv) = tlag

dx(central) = -ke * central

out(cp) = central / v ~ continuous()
"#;

        let parsed = crate::parse_model(src).expect("authoring model parses");
        let analyzed = crate::analyze_model(&parsed).expect("authoring model analyzes");
        let Err(error) = crate::lower_typed_model(&analyzed) else {
            panic!("infusion lag should fail during lowering");
        };

        let message = error.to_string();
        assert!(message.contains("DSL authoring does not allow `lag` on infusion route `iv`"));
    }

    /// A bioavailability property on an infusion route is rejected during lowering.
    #[test]
    fn authoring_routes_reject_infusion_bioavailability_properties() {
        let src = r#"name = invalid_infusion_fa
kind = ode

params = ke, v, f_iv
states = central
outputs = cp

infusion(iv) -> central
fa(iv) = f_iv

dx(central) = -ke * central

out(cp) = central / v ~ continuous()
"#;

        let parsed = crate::parse_model(src).expect("authoring model parses");
        let analyzed = crate::analyze_model(&parsed).expect("authoring model analyzes");
        let Err(error) = crate::lower_typed_model(&analyzed) else {
            panic!("infusion bioavailability should fail during lowering");
        };

        let message = error.to_string();
        assert!(message
            .contains("DSL authoring does not allow `bioavailability` on infusion route `iv`"));
    }

    /// Array states occupy consecutive offsets and `for` loops survive lowering.
    #[test]
    fn flattens_array_states_and_preserves_loop_structure() {
        let execution = structured_block_execution();
        let transit = find_model(&execution, "transit_absorption");
        assert_eq!(transit.abi.state_buffer.len, 5);

        let states = &transit.metadata.states;
        assert_eq!(states[0].name, "transit");
        assert_eq!(states[0].offset, 0);
        assert_eq!(states[0].len, 4);
        assert_eq!(states[1].name, "central");
        assert_eq!(states[1].offset, 4);
        assert!(transit.kernel(KernelRole::RouteLag).is_none());
        assert!(transit.kernel(KernelRole::RouteBioavailability).is_none());

        let dynamics = transit
            .kernel(KernelRole::Dynamics)
            .expect("dynamics kernel");
        let KernelImplementation::Statements(program) = &dynamics.implementation else {
            panic!("expected statement-based dynamics kernel");
        };
        let contains_loop = program
            .body
            .statements
            .iter()
            .any(|stmt| matches!(stmt.kind, ExecutionStmtKind::For(_)));
        assert!(contains_loop);
    }

    /// Analytical models get a builtin kernel with the analytical signature.
    #[test]
    fn analytical_models_lower_to_builtin_execution_kernels() {
        let execution = structured_block_execution();
        let analytical = find_model(&execution, "one_cmt_abs");
        let kernel = analytical
            .kernel(KernelRole::Analytical)
            .expect("analytical kernel");

        let expected_args = vec![
            arg(KernelArgumentKind::Time, KernelAccess::Input),
            arg(KernelArgumentKind::States, KernelAccess::Input),
            arg(KernelArgumentKind::Parameters, KernelAccess::Input),
            arg(KernelArgumentKind::Covariates, KernelAccess::Input),
            arg(KernelArgumentKind::RouteInputs, KernelAccess::Input),
            arg(KernelArgumentKind::Derived, KernelAccess::Input),
            arg(KernelArgumentKind::AnalyticalState, KernelAccess::Output),
        ];
        assert_eq!(kernel.signature.args, expected_args);
        assert!(matches!(
            kernel.implementation,
            KernelImplementation::AnalyticalBuiltin(AnalyticalKernel::OneCompartmentWithAbsorption)
        ));
    }

    /// SDE models emit init/drift/diffusion/outputs kernels; init starts with a state write.
    #[test]
    fn sde_models_emit_runtime_kernels_and_zero_filled_init() {
        let execution = structured_block_execution();
        let sde = find_model(&execution, "vanco_sde");
        assert_eq!(sde.metadata.particles, Some(1000));

        let expected_roles = vec![
            KernelRole::Init,
            KernelRole::Drift,
            KernelRole::Diffusion,
            KernelRole::Outputs,
        ];
        assert_eq!(kernel_roles(sde), expected_roles);

        let init = sde.kernel(KernelRole::Init).expect("init kernel");
        let KernelImplementation::Statements(program) = &init.implementation else {
            panic!("expected statement init kernel");
        };
        let statements = &program.body.statements;
        assert!(statements.len() > sde.metadata.states.len());
        assert!(matches!(
            statements[0].kind,
            ExecutionStmtKind::Assign(ExecutionAssignStmt {
                target: ExecutionTarget {
                    kind: ExecutionTargetKind::StateInit(_),
                    ..
                },
                ..
            })
        ));
    }

    /// Routes without a configured property get the default: lag 0.0, bioavailability 1.0.
    #[test]
    fn route_property_kernels_fill_defaults_for_unconfigured_routes() {
        let execution = structured_block_execution();
        let ode = find_model(&execution, "one_cmt_oral_iv");
        let lag_kernel = ode.kernel(KernelRole::RouteLag).expect("lag kernel");
        let bio_kernel = ode
            .kernel(KernelRole::RouteBioavailability)
            .expect("bioavailability kernel");

        let KernelImplementation::Statements(lag_program) = &lag_kernel.implementation else {
            panic!("expected statement lag kernel");
        };
        let KernelImplementation::Statements(bio_program) = &bio_kernel.implementation else {
            panic!("expected statement bioavailability kernel");
        };

        let lag_statements = &lag_program.body.statements;
        let bio_statements = &bio_program.body.statements;
        assert_eq!(lag_statements.len(), 3);
        assert_eq!(bio_statements.len(), 3);
        assert!(matches!(
            lag_statements[1].kind,
            ExecutionStmtKind::Assign(ExecutionAssignStmt {
                value: ExecutionExpr {
                    kind: ExecutionExprKind::Literal(ConstValue::Real(value)),
                    ..
                },
                ..
            }) if value == 0.0
        ));
        assert!(matches!(
            bio_statements[1].kind,
            ExecutionStmtKind::Assign(ExecutionAssignStmt {
                value: ExecutionExpr {
                    kind: ExecutionExprKind::Literal(ConstValue::Real(value)),
                    ..
                },
                ..
            }) if value == 1.0
        ));
    }
}