1use pharmsol_dsl::{
7 AnalyticalKernel as ResolverAnalyticalKernel, AnalyticalStructureInputKind,
8 AnalyticalStructureInputPlan, AnalyticalStructureInputSource,
9};
10use proc_macro::TokenStream;
11use proc_macro2::{Span, TokenStream as TokenStream2};
12use quote::{quote, ToTokens};
13use std::collections::{HashMap, HashSet};
14use syn::{
15 parse::{Parse, ParseStream, Parser},
16 punctuated::Punctuated,
17 token,
18 visit::Visit,
19 visit_mut::VisitMut,
20 Expr, ExprClosure, Ident, Lit, LitInt, LitStr, Pat, Stmt, Token,
21};
22
/// Parsed body of a declaration-first `ode!` invocation.
///
/// Built by `impl Parse for OdeInput`, which rejects unknown/duplicate fields and runs the
/// declaration validators before constructing the value.
struct OdeInput {
    /// Model name (`name: "..."`).
    name: LitStr,
    /// Primary parameter names (`params: [...]`).
    params: Vec<Ident>,
    /// Covariate names (`covariates: [...]`); empty when the field is omitted.
    covariates: Vec<Ident>,
    /// State names (`states: [...]`).
    states: Vec<Ident>,
    /// Output labels (`outputs: [...]`); each is an identifier or a numeric label.
    outputs: Vec<SymbolicIndex>,
    /// Route declarations (`routes: [...]`).
    routes: Vec<OdeRouteDecl>,
    /// Required `diffeq` closure.
    diffeq: ExprClosure,
    /// Optional `lag` closure.
    lag: Option<ExprClosure>,
    /// Optional `fa` closure.
    fa: Option<ExprClosure>,
    /// Optional `init` closure.
    init: Option<ExprClosure>,
    /// Required `out` closure.
    out: ExprClosure,
}
40
/// Parsed body of a built-in `analytical!` invocation.
///
/// Built by `impl Parse for AnalyticalInput` after the structure, parameter, route and
/// closure declarations have been validated.
struct AnalyticalInput {
    /// Model name (`name: "..."`).
    name: LitStr,
    /// Primary parameter names (`params: [...]`).
    params: Vec<Ident>,
    /// Derived parameter names (`derived: [...]`); empty when omitted.
    derived: Vec<Ident>,
    /// Covariate names (`covariates: [...]`); empty when omitted.
    covariates: Vec<Ident>,
    /// State names; the count must match the analytical structure's state count.
    states: Vec<Ident>,
    /// Output labels (`outputs: [...]`); identifiers or numeric labels.
    outputs: Vec<SymbolicIndex>,
    /// Route declarations (`routes: [...]`).
    routes: Vec<OdeRouteDecl>,
    /// Analytical structure identifier, resolved by `resolve_analytical_structure`.
    structure: Ident,
    /// Optional `derive` closure; only names declared in `derived` are valid targets.
    derive: Option<ExprClosure>,
    /// Optional `lag` closure.
    lag: Option<ExprClosure>,
    /// Optional `fa` closure.
    fa: Option<ExprClosure>,
    /// Optional `init` closure.
    init: Option<ExprClosure>,
    /// Required `out` closure.
    out: ExprClosure,
}
56
/// Parsed body of a declaration-first `sde!` invocation.
///
/// Built by `impl Parse for SdeInput` after the declarations and closures have been validated.
struct SdeInput {
    /// Model name (`name: "..."`).
    name: LitStr,
    /// Primary parameter names (`params: [...]`).
    params: Vec<Ident>,
    /// Covariate names (`covariates: [...]`); empty when omitted.
    covariates: Vec<Ident>,
    /// State names (`states: [...]`).
    states: Vec<Ident>,
    /// Output labels (`outputs: [...]`); identifiers or numeric labels.
    outputs: Vec<SymbolicIndex>,
    /// Route declarations (`routes: [...]`).
    routes: Vec<OdeRouteDecl>,
    /// Required `particles:` expression (NOTE(review): presumably the particle count — confirm).
    particles: Expr,
    /// Required `drift` closure.
    drift: ExprClosure,
    /// Required `diffusion` closure.
    diffusion: ExprClosure,
    /// Optional `lag` closure.
    lag: Option<ExprClosure>,
    /// Optional `fa` closure.
    fa: Option<ExprClosure>,
    /// Optional `init` closure.
    init: Option<ExprClosure>,
    /// Required `out` closure.
    out: ExprClosure,
}
72
/// One route declaration of the form `kind(input) -> destination`.
struct OdeRouteDecl {
    /// Route kind keyword (`bolus` or `infusion`).
    kind: OdeRouteKind,
    /// Route input label written inside the parentheses.
    input: SymbolicIndex,
    /// Destination state named after `->`.
    destination: Ident,
}
78
/// Route kind keyword accepted in a route declaration.
#[derive(Clone, Copy)]
enum OdeRouteKind {
    /// Parsed from the `bolus` keyword.
    Bolus,
    /// Parsed from the `infusion` keyword.
    Infusion,
}
84
/// Resolved description of an analytical `structure:` identifier, as returned by
/// `resolve_analytical_structure`.
struct AnalyticalKernelSpec {
    /// Resolver kernel, passed to the structure-input and derive-contract validators.
    kernel: ResolverAnalyticalKernel,
    /// Token path for the kernel; presumably emitted during code generation — TODO confirm.
    runtime_path: TokenStream2,
    /// Kernel tokens for metadata; NOTE(review): usage is not visible here, confirm.
    metadata_kernel: TokenStream2,
    /// Number of states the kernel expects; `states` must declare exactly this many.
    state_count: usize,
}
91
/// One `route => value` entry, as written inside a `lag!`/`fa!` macro body.
struct RoutePropertyEntry {
    /// Route label on the left of `=>`.
    route: SymbolicIndex,
    /// Expression on the right of `=>`.
    value: Expr,
}
96
/// A user-facing label: either an identifier or a base-10 integer literal that fits `usize`.
#[derive(Clone)]
enum SymbolicIndex {
    /// Named label such as `cp`.
    Ident(Ident),
    /// Numeric label such as `1`; validated to fit `usize` when parsed.
    Int(LitInt),
}
102
103impl SymbolicIndex {
104 fn name(&self) -> String {
105 match self {
106 Self::Ident(ident) => ident.to_string(),
107 Self::Int(lit) => lit.base10_digits().to_string(),
108 }
109 }
110
111 fn ident(&self) -> Option<&Ident> {
112 match self {
113 Self::Ident(ident) => Some(ident),
114 Self::Int(_) => None,
115 }
116 }
117
118 fn numeric_value(&self) -> Option<usize> {
119 match self {
120 Self::Ident(_) => None,
121 Self::Int(lit) => Some(
122 lit.base10_parse::<usize>()
123 .expect("validated numeric label should fit usize"),
124 ),
125 }
126 }
127
128 fn numeric(value: usize) -> Self {
129 Self::Int(LitInt::new(&value.to_string(), Span::call_site()))
130 }
131}
132
133impl Parse for SymbolicIndex {
134 fn parse(input: ParseStream) -> syn::Result<Self> {
135 if input.peek(LitInt) {
136 let lit: LitInt = input.parse()?;
137 lit.base10_parse::<usize>().map_err(|_| {
138 syn::Error::new_spanned(
139 &lit,
140 "numeric declaration-first labels must be non-negative base-10 integers that fit in usize",
141 )
142 })?;
143 Ok(Self::Int(lit))
144 } else {
145 Ok(Self::Ident(input.parse()?))
146 }
147 }
148}
149
150impl ToTokens for SymbolicIndex {
151 fn to_tokens(&self, tokens: &mut TokenStream2) {
152 match self {
153 Self::Ident(ident) => ident.to_tokens(tokens),
154 Self::Int(lit) => lit.to_tokens(tokens),
155 }
156 }
157}
158
159impl std::fmt::Display for SymbolicIndex {
160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161 f.write_str(&self.name())
162 }
163}
164
165impl Parse for OdeRouteDecl {
166 fn parse(input: ParseStream) -> syn::Result<Self> {
167 let kind_ident: Ident = input.parse()?;
168 let kind = match kind_ident.to_string().as_str() {
169 "bolus" => OdeRouteKind::Bolus,
170 "infusion" => OdeRouteKind::Infusion,
171 other => {
172 return Err(syn::Error::new_spanned(
173 &kind_ident,
174 format!("unknown route kind `{other}`, expected `bolus` or `infusion`"),
175 ));
176 }
177 };
178
179 let content;
180 syn::parenthesized!(content in input);
181 let route_input: SymbolicIndex = content.parse()?;
182 if !content.is_empty() {
183 return Err(content.error("expected a single route input name inside `(...)`"));
184 }
185
186 if !input.peek(Token![->]) {
187 return Err(
188 input.error("expected `->` followed by a destination state in route declaration")
189 );
190 }
191 input.parse::<Token![->]>()?;
192 let destination: Ident = input.parse()?;
193
194 if input.peek(token::Brace) {
195 return Err(
196 input.error("route properties are not supported in declaration-first `ode!` yet")
197 );
198 }
199
200 Ok(Self {
201 kind,
202 input: route_input,
203 destination,
204 })
205 }
206}
207
impl Parse for OdeInput {
    /// Parses the `key: value` fields of a declaration-first `ode!` invocation.
    ///
    /// Fields may appear in any order, each at most once, separated by commas. A trailing
    /// comma is accepted.
    /// - Required: `name`, `params`, `states`, `outputs`, `routes`, `diffeq`, `out`.
    /// - `covariates` defaults to an empty list.
    /// - `lag`, `fa`, and `init` are optional.
    ///
    /// # Errors
    ///
    /// Returns a `syn::Error` when:
    /// - a field is unknown, duplicated, or missing;
    /// - a field value is malformed;
    /// - a validation helper called after parsing reports a failure.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        // One slot per field; `set_once_ode` rejects a second occurrence of a key.
        let mut name = None;
        let mut params = None;
        let mut covariates = None;
        let mut states = None;
        let mut outputs = None;
        let mut routes = None;
        let mut diffeq = None;
        let mut lag = None;
        let mut fa = None;
        let mut init = None;
        let mut out = None;

        while !input.is_empty() {
            let key: Ident = input.parse()?;
            input.parse::<Token![:]>()?;

            // Dispatch on the field name; list fields use the bracketed-list parsers.
            match key.to_string().as_str() {
                "name" => set_once_ode(&mut name, input.parse()?, &key, "name")?,
                "params" => set_once_ode(&mut params, parse_ident_list(input)?, &key, "params")?,
                "covariates" => set_once_ode(
                    &mut covariates,
                    parse_ident_list(input)?,
                    &key,
                    "covariates",
                )?,
                "states" => set_once_ode(&mut states, parse_ident_list(input)?, &key, "states")?,
                "outputs" => set_once_ode(
                    &mut outputs,
                    parse_symbolic_index_list(input)?,
                    &key,
                    "outputs",
                )?,
                "routes" => set_once_ode(&mut routes, parse_route_list(input)?, &key, "routes")?,
                "diffeq" => set_once_ode(&mut diffeq, input.parse()?, &key, "diffeq")?,
                "lag" => set_once_ode(&mut lag, input.parse()?, &key, "lag")?,
                "fa" => set_once_ode(&mut fa, input.parse()?, &key, "fa")?,
                "init" => set_once_ode(&mut init, input.parse()?, &key, "init")?,
                "out" => set_once_ode(&mut out, input.parse()?, &key, "out")?,
                other => {
                    return Err(syn::Error::new_spanned(
                        &key,
                        format!(
                            "unknown field `{other}`, expected one of: name, params, covariates, states, outputs, routes, diffeq, lag, fa, init, out"
                        ),
                    ));
                }
            }

            // A comma is required between fields but optional after the last one.
            if !input.is_empty() {
                input.parse::<Token![,]>()?;
            }
        }

        // A missing `name` gets a dedicated message. It points users of the removed
        // inferred-dimensions form at the required declarations.
        let name = name.ok_or_else(|| {
            syn::Error::new(
                Span::call_site(),
                "declaration-first `ode!` requires `name`, `params`, `states`, `outputs`, and `routes`; the old inferred-dimensions form has been removed",
            )
        })?;
        let params = params.ok_or_else(|| missing_required_ode_field("params"))?;
        let covariates = covariates.unwrap_or_default();
        let states = states.ok_or_else(|| missing_required_ode_field("states"))?;
        let outputs = outputs.ok_or_else(|| missing_required_ode_field("outputs"))?;
        let routes = routes.ok_or_else(|| missing_required_ode_field("routes"))?;
        let diffeq = diffeq.ok_or_else(|| missing_required_ode_field("diffeq"))?;
        let out = out.ok_or_else(|| missing_required_ode_field("out"))?;
        // Helper defined outside this view; checks `diffeq` against the declared routes.
        validate_ode_diffeq_uses_automatic_injection(&diffeq, &routes)?;

        // The checks below run in a fixed order.
        // The first failing check decides which diagnostic the user sees.
        validate_unique_idents("parameter", &params, "ode!")?;
        validate_unique_idents("covariate", &covariates, "ode!")?;
        validate_unique_idents("state", &states, "ode!")?;
        // Only identifier-named outputs take part in the named-binding checks.
        let output_idents = symbolic_index_idents(&outputs);

        validate_unique_symbolic_indices("output", &outputs, "ode!")?;
        validate_routes(&routes, &states, "ode!")?;
        // `ode!` has no `derived` declarations, so an empty slice is passed.
        validate_named_binding_compatibility(
            NamedBindingSets {
                params: &params,
                derived: &[],
                covariates: &covariates,
                states: &states,
                outputs: &output_idents,
                routes: &routes,
            },
            OdeBindingClosures {
                diffeq: &diffeq,
                common: CommonBindingClosures {
                    lag: lag.as_ref(),
                    fa: fa.as_ref(),
                    init: init.as_ref(),
                    out: &out,
                },
            },
        )?;

        Ok(Self {
            name,
            params,
            covariates,
            states,
            outputs,
            routes,
            diffeq,
            lag,
            fa,
            init,
            out,
        })
    }
}
320
321impl Parse for RoutePropertyEntry {
322 fn parse(input: ParseStream) -> syn::Result<Self> {
323 let route: SymbolicIndex = input.parse()?;
324 input.parse::<Token![=>]>()?;
325 let value: Expr = input.parse()?;
326 Ok(Self { route, value })
327 }
328}
329
impl Parse for AnalyticalInput {
    /// Parses the `key: value` fields of a built-in `analytical!` invocation.
    ///
    /// Fields may appear in any order, each at most once, separated by commas. A trailing
    /// comma is accepted.
    /// - Required: `name`, `params`, `states`, `outputs`, `routes`, `structure`, `out`.
    /// - `derived` and `covariates` default to empty lists.
    /// - `derive`, `lag`, `fa`, and `init` are optional.
    /// - The legacy `sec` field is rejected with a migration hint.
    ///
    /// # Errors
    ///
    /// Returns a `syn::Error` when:
    /// - a field is unknown, duplicated, or missing;
    /// - a field value is malformed;
    /// - a structure/state mismatch is found;
    /// - a validation helper called after parsing reports a failure.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        // One slot per field; `set_once_analytical` rejects a second occurrence of a key.
        let mut name = None;
        let mut params = None;
        let mut derived = None;
        let mut covariates = None;
        let mut states = None;
        let mut outputs = None;
        let mut routes = None;
        let mut structure = None;
        let mut derive = None;
        let mut lag = None;
        let mut fa = None;
        let mut init = None;
        let mut out = None;

        while !input.is_empty() {
            let key: Ident = input.parse()?;
            input.parse::<Token![:]>()?;

            match key.to_string().as_str() {
                "name" => set_once_analytical(&mut name, input.parse()?, &key, "name")?,
                "params" => {
                    set_once_analytical(&mut params, parse_ident_list(input)?, &key, "params")?
                }
                "derived" => {
                    set_once_analytical(&mut derived, parse_ident_list(input)?, &key, "derived")?
                }
                "covariates" => set_once_analytical(
                    &mut covariates,
                    parse_ident_list(input)?,
                    &key,
                    "covariates",
                )?,
                "states" => {
                    set_once_analytical(&mut states, parse_ident_list(input)?, &key, "states")?
                }
                "outputs" => set_once_analytical(
                    &mut outputs,
                    parse_symbolic_index_list(input)?,
                    &key,
                    "outputs",
                )?,
                "routes" => {
                    set_once_analytical(&mut routes, parse_route_list(input)?, &key, "routes")?
                }
                "structure" => {
                    set_once_analytical(&mut structure, input.parse()?, &key, "structure")?
                }
                "derive" => set_once_analytical(&mut derive, input.parse()?, &key, "derive")?,
                // Removed field: reject it with a pointer to its replacement.
                "sec" => {
                    return Err(syn::Error::new_spanned(
                        &key,
                        "built-in `analytical!` no longer supports `sec`; use `derived: [...]` plus `derive: ...`",
                    ));
                }
                "lag" => set_once_analytical(&mut lag, input.parse()?, &key, "lag")?,
                "fa" => set_once_analytical(&mut fa, input.parse()?, &key, "fa")?,
                "init" => set_once_analytical(&mut init, input.parse()?, &key, "init")?,
                "out" => set_once_analytical(&mut out, input.parse()?, &key, "out")?,
                other => {
                    return Err(syn::Error::new_spanned(
                        &key,
                        format!(
                            "unknown field `{other}`, expected one of: name, params, derived, covariates, states, outputs, routes, structure, derive, lag, fa, init, out"
                        ),
                    ));
                }
            }

            // A comma is required between fields but optional after the last one.
            if !input.is_empty() {
                input.parse::<Token![,]>()?;
            }
        }

        let name = name.ok_or_else(|| missing_required_analytical_field("name"))?;
        let params = params.ok_or_else(|| missing_required_analytical_field("params"))?;
        let derived = derived.unwrap_or_default();
        let covariates = covariates.unwrap_or_default();
        let states = states.ok_or_else(|| missing_required_analytical_field("states"))?;
        let outputs = outputs.ok_or_else(|| missing_required_analytical_field("outputs"))?;
        let routes = routes.ok_or_else(|| missing_required_analytical_field("routes"))?;
        let structure = structure.ok_or_else(|| missing_required_analytical_field("structure"))?;
        let out = out.ok_or_else(|| missing_required_analytical_field("out"))?;

        // The checks below run in a fixed order.
        // The first failing check decides which diagnostic the user sees.
        // Parameter/derived duplicates are not checked here: they are reported by
        // `validate_analytical_structure_inputs` further down.
        validate_unique_idents("covariate", &covariates, "analytical!")?;
        validate_unique_idents("state", &states, "analytical!")?;
        // Only identifier-named outputs take part in the named-binding checks.
        let output_idents = symbolic_index_idents(&outputs);

        validate_unique_symbolic_indices("output", &outputs, "analytical!")?;
        validate_routes(&routes, &states, "analytical!")?;

        // Resolve the structure, confirm its required names are declared, and check that
        // the declared state count matches the kernel.
        let kernel_spec = resolve_analytical_structure(&structure)?;
        validate_analytical_structure_inputs(&structure, kernel_spec.kernel, &params, &derived)?;
        if states.len() != kernel_spec.state_count {
            return Err(syn::Error::new_spanned(
                &structure,
                format!(
                    "analytical structure `{}` expects {} state value(s), but `states` declares {}",
                    structure,
                    kernel_spec.state_count,
                    states.len()
                ),
            ));
        }

        validate_analytical_named_binding_compatibility(
            NamedBindingSets {
                params: &params,
                derived: &derived,
                covariates: &covariates,
                states: &states,
                outputs: &output_idents,
                routes: &routes,
            },
            AnalyticalBindingClosures {
                derive: derive.as_ref(),
                common: CommonBindingClosures {
                    lag: lag.as_ref(),
                    fa: fa.as_ref(),
                    init: init.as_ref(),
                    out: &out,
                },
            },
        )?;

        validate_analytical_derive_contract(
            kernel_spec.kernel,
            &params,
            &derived,
            &covariates,
            derive.as_ref(),
        )?;

        // When present, `lag`/`fa` route properties are checked against the declared routes.
        if let Some(lag) = lag.as_ref() {
            let lag_routes =
                extract_route_property_routes("built-in `analytical!`", "lag", lag, &routes)?;
            validate_route_property_kinds("built-in `analytical!`", "lag", &routes, &lag_routes)?;
        }

        if let Some(fa) = fa.as_ref() {
            let fa_routes =
                extract_route_property_routes("built-in `analytical!`", "fa", fa, &routes)?;
            validate_route_property_kinds("built-in `analytical!`", "fa", &routes, &fa_routes)?;
        }

        Ok(Self {
            name,
            params,
            derived,
            covariates,
            states,
            outputs,
            routes,
            structure,
            derive,
            lag,
            fa,
            init,
            out,
        })
    }
}
493
impl Parse for SdeInput {
    /// Parses the `key: value` fields of a declaration-first `sde!` invocation.
    ///
    /// Fields may appear in any order, each at most once, separated by commas. A trailing
    /// comma is accepted.
    /// - Required: `name`, `params`, `states`, `outputs`, `routes`, `particles`, `drift`,
    ///   `diffusion`, `out`.
    /// - `covariates` defaults to an empty list.
    /// - `lag`, `fa`, and `init` are optional.
    ///
    /// # Errors
    ///
    /// Returns a `syn::Error` when:
    /// - a field is unknown, duplicated, or missing;
    /// - a field value is malformed;
    /// - a validation helper called after parsing reports a failure.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        // One slot per field; `set_once_sde` rejects a second occurrence of a key.
        let mut name = None;
        let mut params = None;
        let mut covariates = None;
        let mut states = None;
        let mut outputs = None;
        let mut routes = None;
        let mut particles = None;
        let mut drift = None;
        let mut diffusion = None;
        let mut lag = None;
        let mut fa = None;
        let mut init = None;
        let mut out = None;

        while !input.is_empty() {
            let key: Ident = input.parse()?;
            input.parse::<Token![:]>()?;

            match key.to_string().as_str() {
                "name" => set_once_sde(&mut name, input.parse()?, &key, "name")?,
                "params" => set_once_sde(&mut params, parse_ident_list(input)?, &key, "params")?,
                "covariates" => set_once_sde(
                    &mut covariates,
                    parse_ident_list(input)?,
                    &key,
                    "covariates",
                )?,
                "states" => set_once_sde(&mut states, parse_ident_list(input)?, &key, "states")?,
                "outputs" => set_once_sde(
                    &mut outputs,
                    parse_symbolic_index_list(input)?,
                    &key,
                    "outputs",
                )?,
                "routes" => set_once_sde(&mut routes, parse_route_list(input)?, &key, "routes")?,
                "particles" => set_once_sde(&mut particles, input.parse()?, &key, "particles")?,
                "drift" => set_once_sde(&mut drift, input.parse()?, &key, "drift")?,
                "diffusion" => set_once_sde(&mut diffusion, input.parse()?, &key, "diffusion")?,
                "lag" => set_once_sde(&mut lag, input.parse()?, &key, "lag")?,
                "fa" => set_once_sde(&mut fa, input.parse()?, &key, "fa")?,
                "init" => set_once_sde(&mut init, input.parse()?, &key, "init")?,
                "out" => set_once_sde(&mut out, input.parse()?, &key, "out")?,
                other => {
                    return Err(syn::Error::new_spanned(
                        &key,
                        format!(
                            "unknown field `{other}`, expected one of: name, params, covariates, states, outputs, routes, particles, drift, diffusion, lag, fa, init, out"
                        ),
                    ));
                }
            }

            // A comma is required between fields but optional after the last one.
            if !input.is_empty() {
                input.parse::<Token![,]>()?;
            }
        }

        let name = name.ok_or_else(|| missing_required_sde_field("name"))?;
        let params = params.ok_or_else(|| missing_required_sde_field("params"))?;
        let covariates = covariates.unwrap_or_default();
        let states = states.ok_or_else(|| missing_required_sde_field("states"))?;
        let outputs = outputs.ok_or_else(|| missing_required_sde_field("outputs"))?;
        let routes = routes.ok_or_else(|| missing_required_sde_field("routes"))?;
        let particles = particles.ok_or_else(|| missing_required_sde_field("particles"))?;
        let drift = drift.ok_or_else(|| missing_required_sde_field("drift"))?;
        let diffusion = diffusion.ok_or_else(|| missing_required_sde_field("diffusion"))?;
        let out = out.ok_or_else(|| missing_required_sde_field("out"))?;

        // The checks below run in a fixed order.
        // The first failing check decides which diagnostic the user sees.
        validate_unique_idents("parameter", &params, "sde!")?;
        validate_unique_idents("covariate", &covariates, "sde!")?;
        validate_unique_idents("state", &states, "sde!")?;
        // Only identifier-named outputs take part in the named-binding checks.
        let output_idents = symbolic_index_idents(&outputs);

        validate_unique_symbolic_indices("output", &outputs, "sde!")?;
        validate_routes(&routes, &states, "sde!")?;
        // `sde!` has no `derived` declarations, so an empty slice is passed.
        validate_sde_named_binding_compatibility(
            NamedBindingSets {
                params: &params,
                derived: &[],
                covariates: &covariates,
                states: &states,
                outputs: &output_idents,
                routes: &routes,
            },
            SdeBindingClosures {
                drift: &drift,
                diffusion: &diffusion,
                common: CommonBindingClosures {
                    lag: lag.as_ref(),
                    fa: fa.as_ref(),
                    init: init.as_ref(),
                    out: &out,
                },
            },
        )?;

        // When present, `lag`/`fa` route properties are checked against the declared routes.
        if let Some(lag) = lag.as_ref() {
            let lag_routes =
                extract_route_property_routes("declaration-first `sde!`", "lag", lag, &routes)?;
            validate_route_property_kinds("declaration-first `sde!`", "lag", &routes, &lag_routes)?;
        }

        if let Some(fa) = fa.as_ref() {
            let fa_routes =
                extract_route_property_routes("declaration-first `sde!`", "fa", fa, &routes)?;
            validate_route_property_kinds("declaration-first `sde!`", "fa", &routes, &fa_routes)?;
        }

        Ok(Self {
            name,
            params,
            covariates,
            states,
            outputs,
            routes,
            particles,
            drift,
            diffusion,
            lag,
            fa,
            init,
            out,
        })
    }
}
621
622fn missing_required_ode_field(name: &str) -> syn::Error {
627 syn::Error::new(
628 Span::call_site(),
629 format!("missing required field `{name}` in declaration-first `ode!`"),
630 )
631}
632
633fn missing_required_analytical_field(name: &str) -> syn::Error {
634 syn::Error::new(
635 Span::call_site(),
636 format!("missing required field `{name}` in built-in `analytical!`"),
637 )
638}
639
640fn missing_required_sde_field(name: &str) -> syn::Error {
641 syn::Error::new(
642 Span::call_site(),
643 format!("missing required field `{name}` in declaration-first `sde!`"),
644 )
645}
646
647fn set_once_ode<T>(slot: &mut Option<T>, value: T, key: &Ident, name: &str) -> syn::Result<()> {
648 if slot.is_some() {
649 Err(syn::Error::new_spanned(
650 key,
651 format!("duplicate field `{name}` in `ode!`"),
652 ))
653 } else {
654 *slot = Some(value);
655 Ok(())
656 }
657}
658
659fn set_once_analytical<T>(
660 slot: &mut Option<T>,
661 value: T,
662 key: &Ident,
663 name: &str,
664) -> syn::Result<()> {
665 if slot.is_some() {
666 Err(syn::Error::new_spanned(
667 key,
668 format!("duplicate field `{name}` in `analytical!`"),
669 ))
670 } else {
671 *slot = Some(value);
672 Ok(())
673 }
674}
675
676fn set_once_sde<T>(slot: &mut Option<T>, value: T, key: &Ident, name: &str) -> syn::Result<()> {
677 if slot.is_some() {
678 Err(syn::Error::new_spanned(
679 key,
680 format!("duplicate field `{name}` in `sde!`"),
681 ))
682 } else {
683 *slot = Some(value);
684 Ok(())
685 }
686}
687
688fn parse_ident_list(input: ParseStream) -> syn::Result<Vec<Ident>> {
689 let content;
690 syn::bracketed!(content in input);
691 Ok(Punctuated::<Ident, Token![,]>::parse_terminated(&content)?
692 .into_iter()
693 .collect())
694}
695
696fn parse_symbolic_index_list(input: ParseStream) -> syn::Result<Vec<SymbolicIndex>> {
697 let content;
698 syn::bracketed!(content in input);
699 Ok(
700 Punctuated::<SymbolicIndex, Token![,]>::parse_terminated(&content)?
701 .into_iter()
702 .collect(),
703 )
704}
705
706fn parse_route_list(input: ParseStream) -> syn::Result<Vec<OdeRouteDecl>> {
707 if input.peek(token::Brace) {
708 return Err(input.error("declaration-first macro `routes` must use `[...]`, not `{...}`"));
709 }
710
711 if !input.peek(token::Bracket) {
712 return Err(
713 input.error("expected a bracketed route list like `routes: [infusion(iv) -> central]`")
714 );
715 }
716
717 let content;
718 syn::bracketed!(content in input);
719 Ok(
720 Punctuated::<OdeRouteDecl, Token![,]>::parse_terminated(&content)?
721 .into_iter()
722 .collect(),
723 )
724}
725
726fn param_name(pat: &Pat) -> String {
727 match pat {
728 Pat::Ident(p) => p.ident.to_string(),
729 _ => String::new(),
730 }
731}
732
733fn closure_param_names(c: &ExprClosure) -> Vec<String> {
734 c.inputs.iter().map(param_name).collect()
735}
736
737fn closure_param_ident(c: &ExprClosure, index: usize) -> Option<Ident> {
738 c.inputs.get(index).and_then(|pat| match pat {
739 Pat::Ident(pat_ident) => Some(pat_ident.ident.clone()),
740 _ => None,
741 })
742}
743
744fn generated_ident(name: &str) -> Ident {
745 Ident::new(name, Span::call_site())
746}
747
748fn symbolic_index_idents(labels: &[SymbolicIndex]) -> Vec<Ident> {
749 labels
750 .iter()
751 .filter_map(|label| label.ident().cloned())
752 .collect()
753}
754
755fn symbolic_index_bindings(labels: &[SymbolicIndex]) -> Vec<(SymbolicIndex, usize)> {
756 labels
757 .iter()
758 .cloned()
759 .enumerate()
760 .map(|(index, label)| (label, index))
761 .collect()
762}
763
764fn symbolic_numeric_binding_map(bindings: &[(SymbolicIndex, usize)]) -> HashMap<usize, usize> {
765 bindings
766 .iter()
767 .filter_map(|(label, index)| label.numeric_value().map(|value| (value, *index)))
768 .collect()
769}
770
/// Names observed while walking a closure body with `syn::visit`.
#[derive(Default)]
struct ClosureBodyUsage {
    /// Unqualified single-segment path expressions (e.g. `ka`, `x`).
    idents: HashSet<String>,
    /// Single-segment names used as the base of an index expression (`x[..]`).
    indexed_idents: HashSet<String>,
    /// Single-segment names indexed on the left of a plain `=` assignment (`x[..] = ..`).
    assigned_indexed_idents: HashSet<String>,
    /// Set when the body contains a macro. syn does not parse macro tokens, so names inside
    /// them are not recorded.
    contains_macro: bool,
}
778
779impl ClosureBodyUsage {
780 fn analyze(expr: &Expr) -> Self {
781 let mut usage = Self::default();
782 usage.visit_expr(expr);
783 usage
784 }
785
786 fn uses(&self, ident: &Ident) -> bool {
787 self.contains_macro || self.idents.contains(&ident.to_string())
788 }
789
790 fn mentions(&self, ident: &Ident) -> bool {
791 self.idents.contains(&ident.to_string())
792 }
793
794 fn indexes(&self, ident: &Ident) -> bool {
795 self.indexed_idents.contains(&ident.to_string())
796 }
797
798 fn assigns_index(&self, ident: &Ident) -> bool {
799 self.assigned_indexed_idents.contains(&ident.to_string())
800 }
801}
802
803impl<'ast> Visit<'ast> for ClosureBodyUsage {
804 fn visit_expr_path(&mut self, expr_path: &'ast syn::ExprPath) {
805 if expr_path.qself.is_none()
806 && expr_path.path.leading_colon.is_none()
807 && expr_path.path.segments.len() == 1
808 {
809 self.idents
810 .insert(expr_path.path.segments[0].ident.to_string());
811 }
812
813 syn::visit::visit_expr_path(self, expr_path);
814 }
815
816 fn visit_expr_macro(&mut self, expr_macro: &'ast syn::ExprMacro) {
817 self.contains_macro = true;
818 syn::visit::visit_expr_macro(self, expr_macro);
819 }
820
821 fn visit_stmt_macro(&mut self, stmt_macro: &'ast syn::StmtMacro) {
822 self.contains_macro = true;
823 syn::visit::visit_stmt_macro(self, stmt_macro);
824 }
825
826 fn visit_expr_index(&mut self, expr_index: &'ast syn::ExprIndex) {
827 if let Expr::Path(expr_path) = expr_index.expr.as_ref() {
828 if expr_path.qself.is_none()
829 && expr_path.path.leading_colon.is_none()
830 && expr_path.path.segments.len() == 1
831 {
832 self.indexed_idents
833 .insert(expr_path.path.segments[0].ident.to_string());
834 }
835 }
836
837 syn::visit::visit_expr_index(self, expr_index);
838 }
839
840 fn visit_expr_assign(&mut self, expr_assign: &'ast syn::ExprAssign) {
841 if let Expr::Index(expr_index) = expr_assign.left.as_ref() {
842 if let Expr::Path(expr_path) = expr_index.expr.as_ref() {
843 if expr_path.qself.is_none()
844 && expr_path.path.leading_colon.is_none()
845 && expr_path.path.segments.len() == 1
846 {
847 self.assigned_indexed_idents
848 .insert(expr_path.path.segments[0].ident.to_string());
849 }
850 }
851 }
852
853 syn::visit::visit_expr_assign(self, expr_assign);
854 }
855}
856
/// A container whose literal integer subscripts should be remapped by `NumericLabelRewriter`.
struct IndexRewriteTarget {
    /// Name of the indexed variable, matched against single-segment paths.
    container: Ident,
    /// Map from the user-facing numeric label to the internal index.
    labels: HashMap<usize, usize>,
}
861
862impl IndexRewriteTarget {
863 fn new(container: Ident, labels: HashMap<usize, usize>) -> Self {
864 Self { container, labels }
865 }
866}
867
/// `VisitMut` pass that rewrites user-facing numeric labels to internal indices.
struct NumericLabelRewriter {
    /// Containers whose literal integer subscripts (`x[1]`) are remapped.
    index_targets: Vec<IndexRewriteTarget>,
    /// When set, remaps integer route keys inside `lag!`/`fa!` macro invocations the visitor reaches.
    route_labels: Option<HashMap<usize, usize>>,
}
872
873impl NumericLabelRewriter {
874 fn rewrite(
875 expr: &Expr,
876 index_targets: Vec<IndexRewriteTarget>,
877 route_labels: Option<HashMap<usize, usize>>,
878 ) -> Expr {
879 let mut rewritten = expr.clone();
880 let mut rewriter = Self {
881 index_targets,
882 route_labels,
883 };
884 rewriter.visit_expr_mut(&mut rewritten);
885 rewritten
886 }
887
888 fn target_labels(&self, path: &syn::ExprPath) -> Option<&HashMap<usize, usize>> {
889 if path.qself.is_some()
890 || path.path.leading_colon.is_some()
891 || path.path.segments.len() != 1
892 {
893 return None;
894 }
895
896 let ident = &path.path.segments[0].ident;
897 self.index_targets
898 .iter()
899 .find(|target| target.container == *ident)
900 .map(|target| &target.labels)
901 }
902
903 fn rewrite_route_macro(&self, mac: &mut syn::Macro) {
904 let Some(route_labels) = self.route_labels.as_ref() else {
905 return;
906 };
907 if !(mac.path.is_ident("lag") || mac.path.is_ident("fa")) {
908 return;
909 }
910
911 let Ok(entries) = Punctuated::<RoutePropertyEntry, Token![,]>::parse_terminated
912 .parse2(mac.tokens.clone())
913 else {
914 return;
915 };
916
917 let entries = entries.into_iter().map(|mut entry| {
918 if let Some(value) = entry.route.numeric_value() {
919 if let Some(internal_index) = route_labels.get(&value) {
920 entry.route = SymbolicIndex::numeric(*internal_index);
921 }
922 }
923 entry
924 });
925
926 let tokens = entries.map(|entry| {
927 let route = entry.route;
928 let value = entry.value;
929 quote! { #route => #value }
930 });
931 mac.tokens = quote! { #(#tokens),* };
932 }
933}
934
935impl VisitMut for NumericLabelRewriter {
936 fn visit_expr_index_mut(&mut self, expr_index: &mut syn::ExprIndex) {
937 syn::visit_mut::visit_expr_index_mut(self, expr_index);
938
939 let Expr::Path(expr_path) = expr_index.expr.as_ref() else {
940 return;
941 };
942 let Some(labels) = self.target_labels(expr_path) else {
943 return;
944 };
945 let Expr::Lit(expr_lit) = expr_index.index.as_ref() else {
946 return;
947 };
948 let Lit::Int(lit) = &expr_lit.lit else {
949 return;
950 };
951 let Ok(external_index) = lit.base10_parse::<usize>() else {
952 return;
953 };
954 let Some(internal_index) = labels.get(&external_index) else {
955 return;
956 };
957
958 *expr_index.index = Expr::Lit(syn::ExprLit {
959 attrs: Vec::new(),
960 lit: Lit::Int(LitInt::new(&internal_index.to_string(), lit.span())),
961 });
962 }
963
964 fn visit_expr_macro_mut(&mut self, expr_macro: &mut syn::ExprMacro) {
965 self.rewrite_route_macro(&mut expr_macro.mac);
966 syn::visit_mut::visit_expr_macro_mut(self, expr_macro);
967 }
968
969 fn visit_stmt_macro_mut(&mut self, stmt_macro: &mut syn::StmtMacro) {
970 self.rewrite_route_macro(&mut stmt_macro.mac);
971 syn::visit_mut::visit_stmt_macro_mut(self, stmt_macro);
972 }
973}
974
975fn generate_closure_input_aliases(
976 closure: &ExprClosure,
977 internal_names: &[Ident],
978) -> syn::Result<TokenStream2> {
979 if closure.inputs.len() != internal_names.len() {
980 return Err(syn::Error::new_spanned(
981 closure,
982 "internal named binding generation error: closure arity mismatch",
983 ));
984 }
985
986 let aliases =
987 closure
988 .inputs
989 .iter()
990 .zip(internal_names.iter())
991 .map(|(pattern, internal_name)| {
992 quote! {
993 let #pattern = #internal_name;
994 }
995 });
996
997 Ok(quote! {
998 #(#aliases)*
999 })
1000}
1001
1002fn generate_supported_input_aliases(
1003 closure: &ExprClosure,
1004 supported_internal_names: &[&[Ident]],
1005 error_message: &str,
1006) -> syn::Result<TokenStream2> {
1007 for internal_names in supported_internal_names {
1008 if closure.inputs.len() == internal_names.len() {
1009 return generate_closure_input_aliases(closure, internal_names);
1010 }
1011 }
1012
1013 Err(syn::Error::new_spanned(closure, error_message))
1014}
1015
1016fn generate_parameter_bindings(
1017 params: &[Ident],
1018 closure: &ExprClosure,
1019 parameter_vector: &Ident,
1020) -> TokenStream2 {
1021 let usage = ClosureBodyUsage::analyze(closure.body.as_ref());
1022 let bindings = params
1023 .iter()
1024 .enumerate()
1025 .filter(|(_, ident)| usage.uses(ident))
1026 .map(|(index, ident)| {
1027 quote! {
1028 #[allow(unused_variables)]
1029 let #ident = #parameter_vector[#index];
1030 }
1031 });
1032
1033 quote! {
1034 #(#bindings)*
1035 }
1036}
1037
1038fn generate_derived_bindings(
1039 derived: &[Ident],
1040 closure: &ExprClosure,
1041 derived_values: &Ident,
1042) -> TokenStream2 {
1043 let usage = ClosureBodyUsage::analyze(closure.body.as_ref());
1044 let bindings = derived
1045 .iter()
1046 .enumerate()
1047 .filter(|(_, ident)| usage.uses(ident))
1048 .map(|(index, ident)| {
1049 quote! {
1050 #[allow(unused_variables)]
1051 let #ident = #derived_values[#index];
1052 }
1053 });
1054
1055 quote! {
1056 #(#bindings)*
1057 }
1058}
1059
1060fn generate_covariate_bindings(
1061 covariates: &[Ident],
1062 closure: &ExprClosure,
1063 covariate_map: &Ident,
1064 time: &Ident,
1065) -> TokenStream2 {
1066 let usage = ClosureBodyUsage::analyze(closure.body.as_ref());
1067 let used_covariates = covariates
1068 .iter()
1069 .filter(|ident| usage.uses(ident))
1070 .collect::<Vec<_>>();
1071
1072 if used_covariates.is_empty() {
1073 quote! {}
1074 } else {
1075 quote! {
1076 ::pharmsol::fetch_cov!(#covariate_map, #time, #(#used_covariates),*);
1077 }
1078 }
1079}
1080
1081fn analytical_error_span<'a>(names: &'a [Ident], target: &str) -> Option<&'a Ident> {
1082 names.iter().find(|ident| *ident == target)
1083}
1084
1085fn validate_analytical_structure_inputs(
1086 structure: &Ident,
1087 kernel: ResolverAnalyticalKernel,
1088 params: &[Ident],
1089 derived: &[Ident],
1090) -> syn::Result<AnalyticalStructureInputPlan> {
1091 let primary_names = params.iter().map(Ident::to_string).collect::<Vec<_>>();
1092 let derived_names = derived.iter().map(Ident::to_string).collect::<Vec<_>>();
1093 AnalyticalStructureInputPlan::for_kernel(kernel, &primary_names, &derived_names).map_err(
1094 |error| match error {
1095 pharmsol_dsl::AnalyticalStructureInputError::DuplicatePrimary { name } => {
1096 let span = analytical_error_span(params, &name).unwrap_or(structure);
1097 syn::Error::new_spanned(span, format!("duplicate primary parameter `{name}`"))
1098 }
1099 pharmsol_dsl::AnalyticalStructureInputError::DuplicateDerived { name } => {
1100 let span = analytical_error_span(derived, &name).unwrap_or(structure);
1101 syn::Error::new_spanned(span, format!("duplicate derived parameter `{name}`"))
1102 }
1103 pharmsol_dsl::AnalyticalStructureInputError::ConflictingName { name } => {
1104 let span = analytical_error_span(derived, &name)
1105 .or_else(|| analytical_error_span(params, &name))
1106 .unwrap_or(structure);
1107 syn::Error::new_spanned(
1108 span,
1109 format!("`{name}` is declared in both `params` and `derived`"),
1110 )
1111 }
1112 pharmsol_dsl::AnalyticalStructureInputError::MissingRequiredName {
1113 structure,
1114 name,
1115 suggestion,
1116 } => {
1117 let message = if let Some(candidate) = suggestion {
1118 format!(
1119 "analytical structure `{structure}` requires `{name}`; did you mean `{candidate}`? declare it in `params: [...]` or `derived: [...]`"
1120 )
1121 } else {
1122 format!(
1123 "analytical structure `{structure}` requires `{name}`; declare it in `params: [...]` or `derived: [...]`"
1124 )
1125 };
1126 syn::Error::new_spanned(structure, message)
1127 }
1128 },
1129 )
1130}
1131
/// Declared name sets used to classify invalid `derive` assignment targets.
#[derive(Clone)]
struct DeriveValidationContext {
    /// Declared primary parameter names.
    params: HashSet<String>,
    /// Declared covariate names.
    covariates: HashSet<String>,
    /// Declared `derived` names; per the error messages, the only valid derive targets.
    derived: HashSet<String>,
}
1138
1139impl DeriveValidationContext {
1140 fn new(params: &[Ident], covariates: &[Ident], derived: &[Ident]) -> Self {
1141 Self {
1142 params: params.iter().map(Ident::to_string).collect(),
1143 covariates: covariates.iter().map(Ident::to_string).collect(),
1144 derived: derived.iter().map(Ident::to_string).collect(),
1145 }
1146 }
1147
1148 fn invalid_target_error(&self, ident: &Ident) -> syn::Error {
1149 let name = ident.to_string();
1150 let message = if self.params.contains(&name) {
1151 format!(
1152 "`derive` cannot assign to `{name}`; only names declared in `derived: [...]` are valid derive targets"
1153 )
1154 } else if self.covariates.contains(&name) {
1155 format!(
1156 "`derive` cannot assign to covariate `{name}`; only names declared in `derived: [...]` are valid derive targets"
1157 )
1158 } else {
1159 format!(
1160 "`derive` cannot assign to `{name}`; declare it in `derived: [...]` before assigning to it"
1161 )
1162 };
1163 syn::Error::new_spanned(ident, message)
1164 }
1165}
1166
1167fn bound_local_names(pat: &Pat) -> Vec<String> {
1168 struct BoundNames {
1169 names: Vec<String>,
1170 }
1171
1172 impl<'ast> Visit<'ast> for BoundNames {
1173 fn visit_pat_ident(&mut self, pat_ident: &'ast syn::PatIdent) {
1174 self.names.push(pat_ident.ident.to_string());
1175 }
1176 }
1177
1178 let mut bound = BoundNames { names: Vec::new() };
1179 bound.visit_pat(pat);
1180 bound.names
1181}
1182
1183fn analyze_derive_block(
1184 block: &syn::Block,
1185 context: &DeriveValidationContext,
1186 locals: &mut HashSet<String>,
1187 assigned: &HashSet<String>,
1188) -> syn::Result<HashSet<String>> {
1189 let mut assigned_now = assigned.clone();
1190 for stmt in &block.stmts {
1191 assigned_now = analyze_derive_stmt(stmt, context, locals, &assigned_now)?;
1192 }
1193 Ok(assigned_now)
1194}
1195
1196fn analyze_derive_stmt(
1197 stmt: &Stmt,
1198 context: &DeriveValidationContext,
1199 locals: &mut HashSet<String>,
1200 assigned: &HashSet<String>,
1201) -> syn::Result<HashSet<String>> {
1202 match stmt {
1203 Stmt::Local(local) => {
1204 if let Some(init) = &local.init {
1205 let _ = analyze_derive_expr(&init.expr, context, &mut locals.clone(), assigned)?;
1206 }
1207 for name in bound_local_names(&local.pat) {
1208 locals.insert(name);
1209 }
1210 Ok(assigned.clone())
1211 }
1212 Stmt::Expr(expr, _) => analyze_derive_expr(expr, context, locals, assigned),
1213 Stmt::Macro(stmt_macro) => Err(syn::Error::new_spanned(
1214 stmt_macro,
1215 "`derive` only supports assignments, `if`, `if` / `else`, `for`, and local `let` bindings",
1216 )),
1217 _ => Ok(assigned.clone()),
1218 }
1219}
1220
1221fn analyze_derive_expr(
1222 expr: &Expr,
1223 context: &DeriveValidationContext,
1224 locals: &mut HashSet<String>,
1225 assigned: &HashSet<String>,
1226) -> syn::Result<HashSet<String>> {
1227 match expr {
1228 Expr::Assign(assign) => {
1229 if let Expr::Path(path) = assign.left.as_ref() {
1230 if path.qself.is_none()
1231 && path.path.leading_colon.is_none()
1232 && path.path.segments.len() == 1
1233 {
1234 let ident = &path.path.segments[0].ident;
1235 let name = ident.to_string();
1236 if context.derived.contains(&name) {
1237 let mut next = assigned.clone();
1238 next.insert(name);
1239 return Ok(next);
1240 }
1241 if locals.contains(&name) {
1242 return Ok(assigned.clone());
1243 }
1244 return Err(context.invalid_target_error(ident));
1245 }
1246 }
1247 Err(syn::Error::new_spanned(
1248 &assign.left,
1249 "`derive` assignments must target a name declared in `derived: [...]`",
1250 ))
1251 }
1252 Expr::If(expr_if) => {
1253 let mut then_locals = locals.clone();
1254 let then_assigned = analyze_derive_block(
1255 &expr_if.then_branch,
1256 context,
1257 &mut then_locals,
1258 assigned,
1259 )?;
1260
1261 if let Some((_, else_branch)) = &expr_if.else_branch {
1262 let mut else_locals = locals.clone();
1263 let else_assigned = analyze_derive_expr(
1264 else_branch,
1265 context,
1266 &mut else_locals,
1267 assigned,
1268 )?;
1269 Ok(then_assigned
1270 .intersection(&else_assigned)
1271 .cloned()
1272 .collect::<HashSet<_>>())
1273 } else {
1274 Ok(assigned.clone())
1275 }
1276 }
1277 Expr::ForLoop(expr_for) => {
1278 let mut loop_locals = locals.clone();
1279 for name in bound_local_names(&expr_for.pat) {
1280 loop_locals.insert(name);
1281 }
1282 let _ = analyze_derive_block(&expr_for.body, context, &mut loop_locals, assigned)?;
1283 Ok(assigned.clone())
1284 }
1285 Expr::Block(expr_block) => analyze_derive_block(&expr_block.block, context, locals, assigned),
1286 Expr::While(expr_while) => Err(syn::Error::new_spanned(
1287 expr_while,
1288 "`derive` does not support `while`; use straight-line code, `if`, `if` / `else`, or `for`",
1289 )),
1290 Expr::Loop(expr_loop) => Err(syn::Error::new_spanned(
1291 expr_loop,
1292 "`derive` does not support `loop`; use straight-line code, `if`, `if` / `else`, or `for`",
1293 )),
1294 Expr::Match(expr_match) => Err(syn::Error::new_spanned(
1295 expr_match,
1296 "`derive` does not support `match`; use straight-line code, `if`, `if` / `else`, or `for`",
1297 )),
1298 _ => Ok(assigned.clone()),
1299 }
1300}
1301
1302fn validate_analytical_derive_contract(
1303 kernel: ResolverAnalyticalKernel,
1304 params: &[Ident],
1305 derived: &[Ident],
1306 covariates: &[Ident],
1307 derive: Option<&ExprClosure>,
1308) -> syn::Result<()> {
1309 if derived.is_empty() {
1310 if let Some(derive) = derive {
1311 return Err(syn::Error::new_spanned(
1312 derive,
1313 "built-in `analytical!` `derive` requires `derived: [...]`",
1314 ));
1315 }
1316 return Ok(());
1317 }
1318
1319 let derive = derive.ok_or_else(|| {
1320 syn::Error::new_spanned(
1321 &derived[0],
1322 "built-in `analytical!` declares `derived: [...]` but is missing `derive: ...`",
1323 )
1324 })?;
1325
1326 let p = generated_ident("__pharmsol_p");
1327 let t = generated_ident("__pharmsol_t");
1328 let cov = generated_ident("__pharmsol_cov");
1329 let full_inputs = [p, t.clone(), cov];
1330 let reduced_inputs = [t];
1331 generate_supported_input_aliases(
1332 derive,
1333 &[&full_inputs, &reduced_inputs],
1334 "built-in `analytical!` requires `derive` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|",
1335 )?;
1336
1337 let context = DeriveValidationContext::new(params, covariates, derived);
1338 let mut locals = HashSet::new();
1339 let assigned = match derive.body.as_ref() {
1340 Expr::Block(expr_block) => {
1341 analyze_derive_block(&expr_block.block, &context, &mut locals, &HashSet::new())?
1342 }
1343 expr => analyze_derive_expr(expr, &context, &mut locals, &HashSet::new())?,
1344 };
1345
1346 let required_derived = match validate_analytical_structure_inputs(
1347 &Ident::new(kernel.name(), Span::call_site()),
1348 kernel,
1349 params,
1350 derived,
1351 ) {
1352 Ok(plan) => match plan.kind() {
1353 AnalyticalStructureInputKind::AllPrimary { .. } => HashSet::new(),
1354 AnalyticalStructureInputKind::AllDerived { indices, .. } => indices
1355 .iter()
1356 .map(|index| derived[*index].to_string())
1357 .collect::<HashSet<_>>(),
1358 AnalyticalStructureInputKind::Mixed { bindings } => bindings
1359 .iter()
1360 .filter_map(|binding| match binding.source {
1361 AnalyticalStructureInputSource::Primary => None,
1362 AnalyticalStructureInputSource::Derived => {
1363 Some(derived[binding.index].to_string())
1364 }
1365 })
1366 .collect::<HashSet<_>>(),
1367 },
1368 Err(_) => HashSet::new(),
1369 };
1370
1371 for ident in derived {
1372 let name = ident.to_string();
1373 if !assigned.contains(&name) {
1374 let message = if required_derived.contains(&name) {
1375 format!(
1376 "derived parameter `{name}` is not definitely assigned on every path before analytical structure `{}` uses it",
1377 kernel.name()
1378 )
1379 } else {
1380 format!(
1381 "derived parameter `{name}` is declared in `derived: [...]` but is not definitely assigned in `derive`"
1382 )
1383 };
1384 return Err(syn::Error::new_spanned(ident, message));
1385 }
1386 }
1387
1388 Ok(())
1389}
1390
1391fn validate_ode_diffeq_uses_automatic_injection(
1392 diffeq: &ExprClosure,
1393 routes: &[OdeRouteDecl],
1394) -> syn::Result<()> {
1395 match closure_param_names(diffeq).len() {
1396 3 => Ok(()),
1397 5 => {
1398 let usage = ClosureBodyUsage::analyze(diffeq.body.as_ref());
1399 let route_inputs = route_input_idents(routes);
1400 let fourth_param = closure_param_ident(diffeq, 3);
1401 let fifth_param = closure_param_ident(diffeq, 4);
1402 let mentions_route_inputs = route_inputs.iter().any(|route| usage.mentions(route));
1403 let indexes_fifth_param = fifth_param.as_ref().is_some_and(|ident| usage.indexes(ident));
1404 let reads_fourth_param_as_input = fourth_param
1405 .as_ref()
1406 .is_some_and(|ident| usage.indexes(ident) && !usage.assigns_index(ident));
1407
1408 if mentions_route_inputs || indexes_fifth_param || reads_fourth_param_as_input {
1409 Err(syn::Error::new_spanned(
1410 diffeq,
1411 "declaration-first `ode!` only supports automatic route injection in `diffeq`; use either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx| and remove manual `bolus[...]` / `rateiv[...]` terms",
1412 ))
1413 } else {
1414 Ok(())
1415 }
1416 }
1417 _ => Err(syn::Error::new_spanned(
1418 diffeq,
1419 "declaration-first `ode!` only supports automatic route injection in `diffeq`; use either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|",
1420 )),
1421 }
1422}
1423
1424fn route_input_idents(routes: &[OdeRouteDecl]) -> Vec<Ident> {
1425 routes
1426 .iter()
1427 .filter_map(|route| route.input.ident().cloned())
1428 .collect()
1429}
1430
1431fn route_input_names(routes: &[OdeRouteDecl]) -> Vec<String> {
1432 routes.iter().map(|route| route.input.name()).collect()
1433}
1434
1435fn ode_route_input_bindings(routes: &[OdeRouteDecl]) -> Vec<(SymbolicIndex, usize)> {
1436 let mut next_bolus_index = 0usize;
1437 let mut next_infusion_index = 0usize;
1438
1439 routes
1440 .iter()
1441 .map(|route| {
1442 let index = match route.kind {
1443 OdeRouteKind::Bolus => {
1444 let index = next_bolus_index;
1445 next_bolus_index += 1;
1446 index
1447 }
1448 OdeRouteKind::Infusion => {
1449 let index = next_infusion_index;
1450 next_infusion_index += 1;
1451 index
1452 }
1453 };
1454 (route.input.clone(), index)
1455 })
1456 .collect()
1457}
1458
1459fn dense_index_len(bindings: &[(SymbolicIndex, usize)]) -> usize {
1460 bindings
1461 .iter()
1462 .map(|(_, index)| index + 1)
1463 .max()
1464 .unwrap_or(0)
1465}
1466
1467fn validate_binding_conflicts(
1468 left_label: &str,
1469 left: &[Ident],
1470 right_label: &str,
1471 right: &[Ident],
1472 context: &str,
1473) -> syn::Result<()> {
1474 let right_names = right.iter().map(Ident::to_string).collect::<HashSet<_>>();
1475
1476 for ident in left {
1477 let name = ident.to_string();
1478 if right_names.contains(&name) {
1479 return Err(syn::Error::new_spanned(
1480 ident,
1481 format!(
1482 "named {left_label} binding `{name}` conflicts with named {right_label} binding in {context}"
1483 ),
1484 ));
1485 }
1486 }
1487
1488 Ok(())
1489}
1490
1491fn validate_closure_param_conflicts(
1492 closure_label: &str,
1493 closure: &ExprClosure,
1494 bindings: &[Ident],
1495 binding_label: &str,
1496) -> syn::Result<()> {
1497 let parameter_names = closure_param_names(closure)
1498 .into_iter()
1499 .filter(|name| !name.is_empty())
1500 .collect::<HashSet<_>>();
1501
1502 for ident in bindings {
1503 let name = ident.to_string();
1504 if parameter_names.contains(&name) {
1505 return Err(syn::Error::new_spanned(
1506 ident,
1507 format!(
1508 "named {binding_label} binding `{name}` conflicts with `{closure_label}` closure parameter `{name}`"
1509 ),
1510 ));
1511 }
1512 }
1513
1514 Ok(())
1515}
1516
1517#[derive(Clone, Copy)]
1518struct NamedBindingSets<'a> {
1519 params: &'a [Ident],
1520 derived: &'a [Ident],
1521 covariates: &'a [Ident],
1522 states: &'a [Ident],
1523 outputs: &'a [Ident],
1524 routes: &'a [OdeRouteDecl],
1525}
1526
1527#[derive(Clone, Copy)]
1528struct CommonBindingClosures<'a> {
1529 lag: Option<&'a ExprClosure>,
1530 fa: Option<&'a ExprClosure>,
1531 init: Option<&'a ExprClosure>,
1532 out: &'a ExprClosure,
1533}
1534
1535#[derive(Clone, Copy)]
1536struct AnalyticalBindingClosures<'a> {
1537 derive: Option<&'a ExprClosure>,
1538 common: CommonBindingClosures<'a>,
1539}
1540
1541#[derive(Clone, Copy)]
1542struct OdeBindingClosures<'a> {
1543 diffeq: &'a ExprClosure,
1544 common: CommonBindingClosures<'a>,
1545}
1546
1547#[derive(Clone, Copy)]
1548struct SdeBindingClosures<'a> {
1549 drift: &'a ExprClosure,
1550 diffusion: &'a ExprClosure,
1551 common: CommonBindingClosures<'a>,
1552}
1553
1554fn validate_named_binding_compatibility(
1555 bindings: NamedBindingSets<'_>,
1556 closures: OdeBindingClosures<'_>,
1557) -> syn::Result<()> {
1558 let NamedBindingSets {
1559 params,
1560 derived: _,
1561 covariates,
1562 states,
1563 outputs,
1564 routes,
1565 } = bindings;
1566 let OdeBindingClosures {
1567 diffeq,
1568 common: CommonBindingClosures { lag, fa, init, out },
1569 } = closures;
1570 let route_inputs = route_input_idents(routes);
1571
1572 validate_binding_conflicts(
1573 "parameter",
1574 params,
1575 "covariate",
1576 covariates,
1577 "declaration-first `ode!` named binding generation",
1578 )?;
1579 validate_binding_conflicts(
1580 "parameter",
1581 params,
1582 "state",
1583 states,
1584 "`diffeq` and `out` named binding generation",
1585 )?;
1586 validate_binding_conflicts(
1587 "parameter",
1588 params,
1589 "output",
1590 outputs,
1591 "`out` named binding generation",
1592 )?;
1593 validate_binding_conflicts(
1594 "state",
1595 states,
1596 "output",
1597 outputs,
1598 "`out` named binding generation",
1599 )?;
1600 validate_binding_conflicts(
1601 "covariate",
1602 covariates,
1603 "state",
1604 states,
1605 "declaration-first `ode!` named binding generation",
1606 )?;
1607 validate_binding_conflicts(
1608 "covariate",
1609 covariates,
1610 "output",
1611 outputs,
1612 "declaration-first `ode!` named binding generation",
1613 )?;
1614
1615 validate_closure_param_conflicts("diffeq", diffeq, params, "parameter")?;
1616 validate_closure_param_conflicts("diffeq", diffeq, covariates, "covariate")?;
1617 validate_closure_param_conflicts("diffeq", diffeq, states, "state")?;
1618
1619 if let Some(lag) = lag {
1620 validate_binding_conflicts(
1621 "covariate",
1622 covariates,
1623 "route",
1624 &route_inputs,
1625 "`lag` named binding generation",
1626 )?;
1627 validate_closure_param_conflicts("lag", lag, params, "parameter")?;
1628 validate_closure_param_conflicts("lag", lag, covariates, "covariate")?;
1629 validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?;
1630 }
1631
1632 if let Some(fa) = fa {
1633 validate_binding_conflicts(
1634 "covariate",
1635 covariates,
1636 "route",
1637 &route_inputs,
1638 "`fa` named binding generation",
1639 )?;
1640 validate_closure_param_conflicts("fa", fa, params, "parameter")?;
1641 validate_closure_param_conflicts("fa", fa, covariates, "covariate")?;
1642 validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?;
1643 }
1644
1645 if let Some(init) = init {
1646 validate_closure_param_conflicts("init", init, params, "parameter")?;
1647 validate_closure_param_conflicts("init", init, covariates, "covariate")?;
1648 validate_closure_param_conflicts("init", init, states, "state")?;
1649 }
1650
1651 validate_closure_param_conflicts("out", out, params, "parameter")?;
1652 validate_closure_param_conflicts("out", out, covariates, "covariate")?;
1653 validate_closure_param_conflicts("out", out, states, "state")?;
1654 validate_closure_param_conflicts("out", out, outputs, "output")?;
1655
1656 Ok(())
1657}
1658
1659fn validate_analytical_named_binding_compatibility(
1660 bindings: NamedBindingSets<'_>,
1661 closures: AnalyticalBindingClosures<'_>,
1662) -> syn::Result<()> {
1663 let NamedBindingSets {
1664 params,
1665 derived,
1666 covariates,
1667 states,
1668 outputs,
1669 routes,
1670 } = bindings;
1671 let AnalyticalBindingClosures {
1672 derive,
1673 common: CommonBindingClosures { lag, fa, init, out },
1674 } = closures;
1675 let route_inputs = route_input_idents(routes);
1676
1677 validate_binding_conflicts(
1678 "parameter",
1679 params,
1680 "covariate",
1681 covariates,
1682 "`analytical!` named binding generation",
1683 )?;
1684 validate_binding_conflicts(
1685 "derived parameter",
1686 derived,
1687 "covariate",
1688 covariates,
1689 "`analytical!` named binding generation",
1690 )?;
1691 validate_binding_conflicts(
1692 "parameter",
1693 params,
1694 "state",
1695 states,
1696 "`analytical!` named binding generation",
1697 )?;
1698 validate_binding_conflicts(
1699 "derived parameter",
1700 derived,
1701 "state",
1702 states,
1703 "`analytical!` named binding generation",
1704 )?;
1705 validate_binding_conflicts(
1706 "parameter",
1707 params,
1708 "output",
1709 outputs,
1710 "`analytical!` named binding generation",
1711 )?;
1712 validate_binding_conflicts(
1713 "derived parameter",
1714 derived,
1715 "output",
1716 outputs,
1717 "`analytical!` named binding generation",
1718 )?;
1719 validate_binding_conflicts(
1720 "covariate",
1721 covariates,
1722 "state",
1723 states,
1724 "`analytical!` named binding generation",
1725 )?;
1726 validate_binding_conflicts(
1727 "covariate",
1728 covariates,
1729 "output",
1730 outputs,
1731 "`analytical!` named binding generation",
1732 )?;
1733 validate_binding_conflicts(
1734 "covariate",
1735 covariates,
1736 "route",
1737 &route_inputs,
1738 "`analytical!` named binding generation",
1739 )?;
1740 validate_binding_conflicts(
1741 "parameter",
1742 params,
1743 "route",
1744 &route_inputs,
1745 "`analytical!` named binding generation",
1746 )?;
1747 validate_binding_conflicts(
1748 "derived parameter",
1749 derived,
1750 "route",
1751 &route_inputs,
1752 "`analytical!` named binding generation",
1753 )?;
1754 validate_binding_conflicts(
1755 "state",
1756 states,
1757 "output",
1758 outputs,
1759 "`analytical!` named binding generation",
1760 )?;
1761 validate_binding_conflicts(
1762 "state",
1763 states,
1764 "route",
1765 &route_inputs,
1766 "`analytical!` named binding generation",
1767 )?;
1768 validate_binding_conflicts(
1769 "output",
1770 outputs,
1771 "route",
1772 &route_inputs,
1773 "`analytical!` named binding generation",
1774 )?;
1775
1776 if let Some(derive) = derive {
1777 validate_closure_param_conflicts("derive", derive, params, "parameter")?;
1778 validate_closure_param_conflicts("derive", derive, derived, "derived parameter")?;
1779 validate_closure_param_conflicts("derive", derive, covariates, "covariate")?;
1780 }
1781
1782 if let Some(lag) = lag {
1783 validate_closure_param_conflicts("lag", lag, params, "parameter")?;
1784 validate_closure_param_conflicts("lag", lag, derived, "derived parameter")?;
1785 validate_closure_param_conflicts("lag", lag, covariates, "covariate")?;
1786 validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?;
1787 }
1788
1789 if let Some(fa) = fa {
1790 validate_closure_param_conflicts("fa", fa, params, "parameter")?;
1791 validate_closure_param_conflicts("fa", fa, derived, "derived parameter")?;
1792 validate_closure_param_conflicts("fa", fa, covariates, "covariate")?;
1793 validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?;
1794 }
1795
1796 if let Some(init) = init {
1797 validate_closure_param_conflicts("init", init, params, "parameter")?;
1798 validate_closure_param_conflicts("init", init, derived, "derived parameter")?;
1799 validate_closure_param_conflicts("init", init, covariates, "covariate")?;
1800 validate_closure_param_conflicts("init", init, states, "state")?;
1801 }
1802
1803 validate_closure_param_conflicts("out", out, params, "parameter")?;
1804 validate_closure_param_conflicts("out", out, derived, "derived parameter")?;
1805 validate_closure_param_conflicts("out", out, covariates, "covariate")?;
1806 validate_closure_param_conflicts("out", out, states, "state")?;
1807 validate_closure_param_conflicts("out", out, outputs, "output")?;
1808
1809 Ok(())
1810}
1811
1812fn validate_sde_named_binding_compatibility(
1813 bindings: NamedBindingSets<'_>,
1814 closures: SdeBindingClosures<'_>,
1815) -> syn::Result<()> {
1816 let NamedBindingSets {
1817 params,
1818 derived: _,
1819 covariates,
1820 states,
1821 outputs,
1822 routes,
1823 } = bindings;
1824 let SdeBindingClosures {
1825 drift,
1826 diffusion,
1827 common: CommonBindingClosures { lag, fa, init, out },
1828 } = closures;
1829 let route_inputs = route_input_idents(routes);
1830
1831 validate_binding_conflicts(
1832 "parameter",
1833 params,
1834 "covariate",
1835 covariates,
1836 "`sde!` named binding generation",
1837 )?;
1838 validate_binding_conflicts(
1839 "parameter",
1840 params,
1841 "state",
1842 states,
1843 "`sde!` named binding generation",
1844 )?;
1845 validate_binding_conflicts(
1846 "parameter",
1847 params,
1848 "output",
1849 outputs,
1850 "`sde!` named binding generation",
1851 )?;
1852 validate_binding_conflicts(
1853 "covariate",
1854 covariates,
1855 "state",
1856 states,
1857 "`sde!` named binding generation",
1858 )?;
1859 validate_binding_conflicts(
1860 "covariate",
1861 covariates,
1862 "output",
1863 outputs,
1864 "`sde!` named binding generation",
1865 )?;
1866 validate_binding_conflicts(
1867 "covariate",
1868 covariates,
1869 "route",
1870 &route_inputs,
1871 "`sde!` named binding generation",
1872 )?;
1873 validate_binding_conflicts(
1874 "parameter",
1875 params,
1876 "route",
1877 &route_inputs,
1878 "`sde!` named binding generation",
1879 )?;
1880 validate_binding_conflicts(
1881 "state",
1882 states,
1883 "output",
1884 outputs,
1885 "`sde!` named binding generation",
1886 )?;
1887 validate_binding_conflicts(
1888 "state",
1889 states,
1890 "route",
1891 &route_inputs,
1892 "`sde!` named binding generation",
1893 )?;
1894 validate_binding_conflicts(
1895 "output",
1896 outputs,
1897 "route",
1898 &route_inputs,
1899 "`sde!` named binding generation",
1900 )?;
1901
1902 validate_closure_param_conflicts("drift", drift, params, "parameter")?;
1903 validate_closure_param_conflicts("drift", drift, covariates, "covariate")?;
1904 validate_closure_param_conflicts("drift", drift, states, "state")?;
1905 validate_closure_param_conflicts("diffusion", diffusion, params, "parameter")?;
1906 validate_closure_param_conflicts("diffusion", diffusion, states, "state")?;
1907
1908 if let Some(lag) = lag {
1909 validate_closure_param_conflicts("lag", lag, params, "parameter")?;
1910 validate_closure_param_conflicts("lag", lag, covariates, "covariate")?;
1911 validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?;
1912 }
1913
1914 if let Some(fa) = fa {
1915 validate_closure_param_conflicts("fa", fa, params, "parameter")?;
1916 validate_closure_param_conflicts("fa", fa, covariates, "covariate")?;
1917 validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?;
1918 }
1919
1920 if let Some(init) = init {
1921 validate_closure_param_conflicts("init", init, params, "parameter")?;
1922 validate_closure_param_conflicts("init", init, covariates, "covariate")?;
1923 validate_closure_param_conflicts("init", init, states, "state")?;
1924 }
1925
1926 validate_closure_param_conflicts("out", out, params, "parameter")?;
1927 validate_closure_param_conflicts("out", out, covariates, "covariate")?;
1928 validate_closure_param_conflicts("out", out, states, "state")?;
1929 validate_closure_param_conflicts("out", out, outputs, "output")?;
1930
1931 Ok(())
1932}
1933
1934fn generate_index_consts(idents: &[Ident]) -> TokenStream2 {
1935 let bindings = idents.iter().enumerate().map(|(index, ident)| {
1936 quote! {
1937 #[allow(non_upper_case_globals, dead_code)]
1938 const #ident: usize = #index;
1939 }
1940 });
1941
1942 quote! {
1943 #(#bindings)*
1944 }
1945}
1946
1947fn generate_mapped_index_consts(bindings: &[(SymbolicIndex, usize)]) -> TokenStream2 {
1948 let bindings = bindings.iter().filter_map(|(label, index)| {
1949 label.ident().map(|ident| {
1950 quote! {
1951 #[allow(non_upper_case_globals, dead_code)]
1952 const #ident: usize = #index;
1953 }
1954 })
1955 });
1956
1957 quote! {
1958 #(#bindings)*
1959 }
1960}
1961
1962fn expand_out(
1963 out: &ExprClosure,
1964 params: &[Ident],
1965 covariates: &[Ident],
1966 states: &[Ident],
1967 outputs: &[SymbolicIndex],
1968) -> syn::Result<TokenStream2> {
1969 let state_consts = generate_index_consts(states);
1970 let output_bindings = symbolic_index_bindings(outputs);
1971 let output_consts = generate_mapped_index_consts(&output_bindings);
1972 let x = generated_ident("__pharmsol_x");
1973 let p = generated_ident("__pharmsol_p");
1974 let t = generated_ident("__pharmsol_t");
1975 let cov = generated_ident("__pharmsol_cov");
1976 let y = generated_ident("__pharmsol_y");
1977 let full_inputs = [x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()];
1978 let reduced_inputs = [x.clone(), t.clone(), y.clone()];
1979 let input_aliases = generate_supported_input_aliases(
1980 out,
1981 &[&full_inputs, &reduced_inputs],
1982 "declaration-first `ode!` requires `out` to have either 5 parameters: |x, p, t, cov, y| or 3 parameters: |x, t, y|",
1983 )?;
1984 let parameter_bindings = generate_parameter_bindings(params, out, &p);
1985 let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t);
1986 let y_binding = if out.inputs.len() == full_inputs.len() {
1987 closure_param_ident(out, 4).unwrap_or_else(|| y.clone())
1988 } else {
1989 closure_param_ident(out, 2).unwrap_or_else(|| y.clone())
1990 };
1991 let body = NumericLabelRewriter::rewrite(
1992 out.body.as_ref(),
1993 vec![IndexRewriteTarget::new(
1994 y_binding,
1995 symbolic_numeric_binding_map(&output_bindings),
1996 )],
1997 None,
1998 );
1999
2000 Ok(quote! {{
2001 let __pharmsol_out: fn(
2002 &::pharmsol::simulator::V,
2003 &::pharmsol::simulator::V,
2004 f64,
2005 &::pharmsol::data::Covariates,
2006 &mut ::pharmsol::simulator::V,
2007 ) = |#x: &::pharmsol::simulator::V,
2008 #p: &::pharmsol::simulator::V,
2009 #t: f64,
2010 #cov: &::pharmsol::data::Covariates,
2011 #y: &mut ::pharmsol::simulator::V| {
2012 #input_aliases
2013 #state_consts
2014 #output_consts
2015 #parameter_bindings
2016 #covariate_bindings
2017 #body
2018 };
2019 __pharmsol_out
2020 }})
2021}
2022
2023fn route_property_error<T: ToTokens>(macro_name: &str, label: &str, node: T) -> syn::Error {
2024 syn::Error::new_spanned(
2025 node,
2026 format!(
2027 "{macro_name} requires `{label}` to return `{label}! {{ ... }}` so route-property metadata can be synthesized"
2028 ),
2029 )
2030}
2031
2032fn find_terminal_macro_invocation(
2033 macro_name: &str,
2034 label: &str,
2035 closure: &ExprClosure,
2036) -> syn::Result<syn::Macro> {
2037 match closure.body.as_ref() {
2038 Expr::Macro(expr_macro) if expr_macro.mac.path.is_ident(label) => {
2039 Ok(expr_macro.mac.clone())
2040 }
2041 Expr::Macro(expr_macro) => Err(route_property_error(macro_name, label, expr_macro)),
2042 Expr::Block(expr_block) => {
2043 for stmt in expr_block.block.stmts.iter().rev() {
2044 match stmt {
2045 Stmt::Expr(Expr::Macro(expr_macro), _)
2046 if expr_macro.mac.path.is_ident(label) =>
2047 {
2048 return Ok(expr_macro.mac.clone());
2049 }
2050 Stmt::Expr(Expr::Macro(expr_macro), _) => {
2051 return Err(route_property_error(macro_name, label, expr_macro));
2052 }
2053 Stmt::Expr(other, _) => {
2054 return Err(route_property_error(macro_name, label, other));
2055 }
2056 Stmt::Macro(stmt_macro) if stmt_macro.mac.path.is_ident(label) => {
2057 return Ok(stmt_macro.mac.clone());
2058 }
2059 Stmt::Macro(stmt_macro) => {
2060 return Err(route_property_error(macro_name, label, stmt_macro));
2061 }
2062 _ => continue,
2063 }
2064 }
2065
2066 Err(route_property_error(macro_name, label, expr_block))
2067 }
2068 other => Err(route_property_error(macro_name, label, other)),
2069 }
2070}
2071
2072fn extract_route_property_routes(
2073 macro_name: &str,
2074 label: &str,
2075 closure: &ExprClosure,
2076 routes: &[OdeRouteDecl],
2077) -> syn::Result<HashSet<String>> {
2078 let macro_expr = find_terminal_macro_invocation(macro_name, label, closure)?;
2079 let entries = Punctuated::<RoutePropertyEntry, Token![,]>::parse_terminated
2080 .parse2(macro_expr.tokens.clone())?;
2081 let known_routes = route_input_names(routes)
2082 .into_iter()
2083 .collect::<HashSet<_>>();
2084 let mut seen = HashSet::new();
2085
2086 for entry in entries {
2087 let route_name = entry.route.name();
2088 if !known_routes.contains(&route_name) {
2089 return Err(syn::Error::new_spanned(
2090 &entry.route,
2091 format!(
2092 "route `{route_name}` in `{label}!` is not declared in the `routes` section"
2093 ),
2094 ));
2095 }
2096 if !seen.insert(route_name.clone()) {
2097 return Err(syn::Error::new_spanned(
2098 &entry.route,
2099 format!("duplicate route `{route_name}` in `{label}!`"),
2100 ));
2101 }
2102 let _ = entry.value;
2103 }
2104
2105 Ok(seen)
2106}
2107
2108fn validate_route_property_kinds(
2109 macro_name: &str,
2110 label: &str,
2111 routes: &[OdeRouteDecl],
2112 property_routes: &HashSet<String>,
2113) -> syn::Result<()> {
2114 for route in routes {
2115 if property_routes.contains(&route.input.name())
2116 && matches!(route.kind, OdeRouteKind::Infusion)
2117 {
2118 return Err(syn::Error::new_spanned(
2119 &route.input,
2120 format!(
2121 "{macro_name} does not allow `{label}` on infusion route `{}`",
2122 route.input
2123 ),
2124 ));
2125 }
2126 }
2127
2128 Ok(())
2129}
2130
2131fn expand_ode_route_map(
2132 label: &str,
2133 closure: &ExprClosure,
2134 params: &[Ident],
2135 covariates: &[Ident],
2136 route_bindings: &[(SymbolicIndex, usize)],
2137) -> syn::Result<TokenStream2> {
2138 let route_consts = generate_mapped_index_consts(route_bindings);
2139 let p = generated_ident("__pharmsol_p");
2140 let t = generated_ident("__pharmsol_t");
2141 let cov = generated_ident("__pharmsol_cov");
2142 let full_inputs = [p.clone(), t.clone(), cov.clone()];
2143 let reduced_inputs = [t.clone()];
2144 let input_aliases = generate_supported_input_aliases(
2145 closure,
2146 &[&full_inputs, &reduced_inputs],
2147 &format!(
2148 "declaration-first `ode!` requires `{label}` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|"
2149 ),
2150 )?;
2151 let parameter_bindings = generate_parameter_bindings(params, closure, &p);
2152 let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t);
2153 let body = NumericLabelRewriter::rewrite(
2154 closure.body.as_ref(),
2155 Vec::new(),
2156 Some(symbolic_numeric_binding_map(route_bindings)),
2157 );
2158
2159 Ok(quote! {{
2160 let __pharmsol_route_map: fn(
2161 &::pharmsol::simulator::V,
2162 f64,
2163 &::pharmsol::data::Covariates,
2164 ) -> ::std::collections::HashMap<usize, f64> = |#p: &::pharmsol::simulator::V,
2165 #t: f64,
2166 #cov: &::pharmsol::data::Covariates| {
2167 #input_aliases
2168 #route_consts
2169 #parameter_bindings
2170 #covariate_bindings
2171 #body
2172 };
2173 __pharmsol_route_map
2174 }})
2175}
2176
2177fn expand_ode_init(
2178 init: &ExprClosure,
2179 params: &[Ident],
2180 covariates: &[Ident],
2181 states: &[Ident],
2182) -> syn::Result<TokenStream2> {
2183 let state_consts = generate_index_consts(states);
2184 let p = generated_ident("__pharmsol_p");
2185 let t = generated_ident("__pharmsol_t");
2186 let cov = generated_ident("__pharmsol_cov");
2187 let x = generated_ident("__pharmsol_x");
2188 let full_inputs = [p.clone(), t.clone(), cov.clone(), x.clone()];
2189 let reduced_inputs = [t.clone(), x.clone()];
2190 let input_aliases = generate_supported_input_aliases(
2191 init,
2192 &[&full_inputs, &reduced_inputs],
2193 "declaration-first `ode!` requires `init` to have either 4 parameters: |p, t, cov, x| or 2 parameters: |t, x|",
2194 )?;
2195 let parameter_bindings = generate_parameter_bindings(params, init, &p);
2196 let covariate_bindings = generate_covariate_bindings(covariates, init, &cov, &t);
2197 let body = &init.body;
2198
2199 Ok(quote! {{
2200 let __pharmsol_init: fn(
2201 &::pharmsol::simulator::V,
2202 f64,
2203 &::pharmsol::data::Covariates,
2204 &mut ::pharmsol::simulator::V,
2205 ) = |#p: &::pharmsol::simulator::V,
2206 #t: f64,
2207 #cov: &::pharmsol::data::Covariates,
2208 #x: &mut ::pharmsol::simulator::V| {
2209 #input_aliases
2210 #state_consts
2211 #parameter_bindings
2212 #covariate_bindings
2213 #body
2214 };
2215 __pharmsol_init
2216 }})
2217}
2218
2219fn expand_route_metadata(
2220 routes: &[OdeRouteDecl],
2221 lag_routes: &HashSet<String>,
2222 fa_routes: &HashSet<String>,
2223) -> Vec<TokenStream2> {
2224 routes
2225 .iter()
2226 .map(|route| {
2227 let input = &route.input;
2228 let destination = &route.destination;
2229 let route_name = route.input.name();
2230 let route_builder = match route.kind {
2231 OdeRouteKind::Bolus => {
2232 quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) }
2233 }
2234 OdeRouteKind::Infusion => {
2235 quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) }
2236 }
2237 };
2238 let lag_flag = if lag_routes.contains(&route_name) {
2239 quote! { .with_lag() }
2240 } else {
2241 quote! {}
2242 };
2243 let fa_flag = if fa_routes.contains(&route_name) {
2244 quote! { .with_bioavailability() }
2245 } else {
2246 quote! {}
2247 };
2248
2249 quote! {
2250 #route_builder
2251 .to_state(stringify!(#destination))
2252 #lag_flag
2253 #fa_flag
2254 .inject_input_to_destination()
2255 }
2256 })
2257 .collect()
2258}
2259
2260fn expand_analytical_route_metadata(
2261 routes: &[OdeRouteDecl],
2262 lag_routes: &HashSet<String>,
2263 fa_routes: &HashSet<String>,
2264) -> Vec<TokenStream2> {
2265 routes
2266 .iter()
2267 .map(|route| {
2268 let input = &route.input;
2269 let destination = &route.destination;
2270 let route_name = route.input.name();
2271 let route_builder = match route.kind {
2272 OdeRouteKind::Bolus => {
2273 quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) }
2274 }
2275 OdeRouteKind::Infusion => {
2276 quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) }
2277 }
2278 };
2279 let lag_flag = if lag_routes.contains(&route_name) {
2280 quote! { .with_lag() }
2281 } else {
2282 quote! {}
2283 };
2284 let fa_flag = if fa_routes.contains(&route_name) {
2285 quote! { .with_bioavailability() }
2286 } else {
2287 quote! {}
2288 };
2289
2290 quote! {
2291 #route_builder
2292 .to_state(stringify!(#destination))
2293 #lag_flag
2294 #fa_flag
2295 }
2296 })
2297 .collect()
2298}
2299
2300fn expand_sde_route_metadata(
2301 routes: &[OdeRouteDecl],
2302 lag_routes: &HashSet<String>,
2303 fa_routes: &HashSet<String>,
2304) -> Vec<TokenStream2> {
2305 routes
2306 .iter()
2307 .map(|route| {
2308 let input = &route.input;
2309 let destination = &route.destination;
2310 let route_name = route.input.name();
2311 let route_builder = match route.kind {
2312 OdeRouteKind::Bolus => {
2313 quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) }
2314 }
2315 OdeRouteKind::Infusion => {
2316 quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) }
2317 }
2318 };
2319 let lag_flag = if lag_routes.contains(&route_name) {
2320 quote! { .with_lag() }
2321 } else {
2322 quote! {}
2323 };
2324 let fa_flag = if fa_routes.contains(&route_name) {
2325 quote! { .with_bioavailability() }
2326 } else {
2327 quote! {}
2328 };
2329
2330 quote! {
2331 #route_builder
2332 .to_state(stringify!(#destination))
2333 .inject_input_to_destination()
2334 #lag_flag
2335 #fa_flag
2336 }
2337 })
2338 .collect()
2339}
2340
2341fn route_destination_index(route: &OdeRouteDecl, states: &[Ident]) -> usize {
2342 states
2343 .iter()
2344 .position(|state| state == &route.destination)
2345 .expect("validated route destination should exist")
2346}
2347
2348fn expand_injected_ode_route_terms(
2349 routes: &[OdeRouteDecl],
2350 states: &[Ident],
2351 route_bindings: &[(SymbolicIndex, usize)],
2352 dx: &Ident,
2353 bolus: &Ident,
2354 rateiv: &Ident,
2355) -> TokenStream2 {
2356 let terms = routes
2357 .iter()
2358 .zip(route_bindings.iter())
2359 .map(|(route, (_, input_index))| {
2360 let destination = route_destination_index(route, states);
2361 match route.kind {
2362 OdeRouteKind::Bolus => quote! {
2363 #dx[#destination] += #bolus[#input_index];
2364 },
2365 OdeRouteKind::Infusion => quote! {
2366 #dx[#destination] += #rateiv[#input_index];
2367 },
2368 }
2369 });
2370
2371 quote! {
2372 #(#terms)*
2373 }
2374}
2375
2376fn expand_injected_sde_rate_terms(
2377 routes: &[OdeRouteDecl],
2378 states: &[Ident],
2379 route_bindings: &[(SymbolicIndex, usize)],
2380 dx: &Ident,
2381 rateiv: &Ident,
2382) -> TokenStream2 {
2383 let terms = routes
2384 .iter()
2385 .zip(route_bindings.iter())
2386 .filter_map(|(route, (_, input_index))| match route.kind {
2387 OdeRouteKind::Bolus => None,
2388 OdeRouteKind::Infusion => {
2389 let destination = route_destination_index(route, states);
2390 Some(quote! {
2391 #dx[#destination] += #rateiv[#input_index];
2392 })
2393 }
2394 });
2395
2396 quote! {
2397 #(#terms)*
2398 }
2399}
2400
2401fn expand_injected_sde_bolus_mappings(
2402 routes: &[OdeRouteDecl],
2403 states: &[Ident],
2404 route_bindings: &[(SymbolicIndex, usize)],
2405) -> TokenStream2 {
2406 let mut destinations = vec![quote! { None }; dense_index_len(route_bindings)];
2407
2408 for (route, (_, input_index)) in routes.iter().zip(route_bindings.iter()) {
2409 if let OdeRouteKind::Bolus = route.kind {
2410 let destination = route_destination_index(route, states);
2411 destinations[*input_index] = quote! { Some(#destination) };
2412 }
2413 }
2414
2415 quote! {
2416 .with_injected_bolus_inputs(&[#(#destinations),*])
2417 }
2418}
2419
2420fn validate_unique_idents(kind: &str, idents: &[Ident], macro_name: &str) -> syn::Result<()> {
2421 let mut seen = HashSet::new();
2422 for ident in idents {
2423 let name = ident.to_string();
2424 if !seen.insert(name.clone()) {
2425 return Err(syn::Error::new_spanned(
2426 ident,
2427 format!("duplicate {kind} `{name}` in declaration-first `{macro_name}`"),
2428 ));
2429 }
2430 }
2431 Ok(())
2432}
2433
2434fn validate_unique_symbolic_indices(
2435 kind: &str,
2436 labels: &[SymbolicIndex],
2437 macro_name: &str,
2438) -> syn::Result<()> {
2439 let mut seen = HashSet::new();
2440 for label in labels {
2441 let name = label.name();
2442 if !seen.insert(name.clone()) {
2443 return Err(syn::Error::new_spanned(
2444 label,
2445 format!("duplicate {kind} `{name}` in declaration-first `{macro_name}`"),
2446 ));
2447 }
2448 }
2449 Ok(())
2450}
2451
2452fn validate_routes(routes: &[OdeRouteDecl], states: &[Ident], macro_name: &str) -> syn::Result<()> {
2453 let known_states = states.iter().map(Ident::to_string).collect::<HashSet<_>>();
2454 let mut seen_routes = HashSet::new();
2455
2456 for route in routes {
2457 let route_name = route.input.name();
2458 if !seen_routes.insert(route_name.clone()) {
2459 return Err(syn::Error::new_spanned(
2460 &route.input,
2461 format!("duplicate route `{route_name}` in declaration-first `{macro_name}`"),
2462 ));
2463 }
2464
2465 if !known_states.contains(&route.destination.to_string()) {
2466 return Err(syn::Error::new_spanned(
2467 &route.destination,
2468 format!(
2469 "route destination `{}` is not declared in the `states` section",
2470 route.destination
2471 ),
2472 ));
2473 }
2474 }
2475
2476 Ok(())
2477}
2478
2479fn expand_diffeq(
2480 diffeq: &ExprClosure,
2481 params: &[Ident],
2482 covariates: &[Ident],
2483 states: &[Ident],
2484 routes: &[OdeRouteDecl],
2485 route_bindings: &[(SymbolicIndex, usize)],
2486) -> syn::Result<TokenStream2> {
2487 let state_consts = generate_index_consts(states);
2488 let x = generated_ident("__pharmsol_x");
2489 let p = generated_ident("__pharmsol_p");
2490 let t = generated_ident("__pharmsol_t");
2491 let dx = generated_ident("__pharmsol_dx");
2492 let bolus = generated_ident("__pharmsol_bolus");
2493 let rateiv = generated_ident("__pharmsol_rateiv");
2494 let cov = generated_ident("__pharmsol_cov");
2495 let full_inputs = [x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()];
2496 let reduced_inputs = [x.clone(), t.clone(), dx.clone()];
2497 let input_aliases = generate_supported_input_aliases(
2498 diffeq,
2499 &[&full_inputs, &reduced_inputs],
2500 "declaration-first `ode!` injected-route `diffeq` requires either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|",
2501 )?;
2502 let parameter_bindings = generate_parameter_bindings(params, diffeq, &p);
2503 let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t);
2504 let body = &diffeq.body;
2505 let dx_binding = if diffeq.inputs.len() == full_inputs.len() {
2506 closure_param_ident(diffeq, 3).unwrap_or_else(|| dx.clone())
2507 } else {
2508 closure_param_ident(diffeq, 2).unwrap_or_else(|| dx.clone())
2509 };
2510 let route_terms = expand_injected_ode_route_terms(
2511 routes,
2512 states,
2513 route_bindings,
2514 &dx_binding,
2515 &bolus,
2516 &rateiv,
2517 );
2518
2519 Ok(quote! {{
2520 let __pharmsol_diffeq: fn(
2521 &::pharmsol::simulator::V,
2522 &::pharmsol::simulator::V,
2523 f64,
2524 &mut ::pharmsol::simulator::V,
2525 &::pharmsol::simulator::V,
2526 &::pharmsol::simulator::V,
2527 &::pharmsol::data::Covariates,
2528 ) = |#x: &::pharmsol::simulator::V,
2529 #p: &::pharmsol::simulator::V,
2530 #t: f64,
2531 #dx: &mut ::pharmsol::simulator::V,
2532 #bolus: &::pharmsol::simulator::V,
2533 #rateiv: &::pharmsol::simulator::V,
2534 #cov: &::pharmsol::data::Covariates| {
2535 #input_aliases
2536 #state_consts
2537 #parameter_bindings
2538 #covariate_bindings
2539 #body
2540 #route_terms
2541 };
2542 __pharmsol_diffeq
2543 }})
2544}
2545
2546fn resolve_analytical_structure(structure: &Ident) -> syn::Result<AnalyticalKernelSpec> {
2547 let structure_name = structure.to_string();
2548 let (kernel, runtime_path, metadata_kernel, state_count) = match structure_name.as_str() {
2549 "one_compartment" => (
2550 ResolverAnalyticalKernel::OneCompartment,
2551 quote! { ::pharmsol::equation::one_compartment },
2552 quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartment },
2553 1,
2554 ),
2555 "one_compartment_cl" => (
2556 ResolverAnalyticalKernel::OneCompartmentCl,
2557 quote! { ::pharmsol::equation::one_compartment_cl },
2558 quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentCl },
2559 1,
2560 ),
2561 "one_compartment_cl_with_absorption" => (
2562 ResolverAnalyticalKernel::OneCompartmentClWithAbsorption,
2563 quote! { ::pharmsol::equation::one_compartment_cl_with_absorption },
2564 quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentClWithAbsorption },
2565 2,
2566 ),
2567 "one_compartment_with_absorption" => (
2568 ResolverAnalyticalKernel::OneCompartmentWithAbsorption,
2569 quote! { ::pharmsol::equation::one_compartment_with_absorption },
2570 quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentWithAbsorption },
2571 2,
2572 ),
2573 "two_compartments" => (
2574 ResolverAnalyticalKernel::TwoCompartments,
2575 quote! { ::pharmsol::equation::two_compartments },
2576 quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartments },
2577 2,
2578 ),
2579 "two_compartments_cl" => (
2580 ResolverAnalyticalKernel::TwoCompartmentsCl,
2581 quote! { ::pharmsol::equation::two_compartments_cl },
2582 quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsCl },
2583 2,
2584 ),
2585 "two_compartments_cl_with_absorption" => (
2586 ResolverAnalyticalKernel::TwoCompartmentsClWithAbsorption,
2587 quote! { ::pharmsol::equation::two_compartments_cl_with_absorption },
2588 quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsClWithAbsorption },
2589 3,
2590 ),
2591 "two_compartments_with_absorption" => (
2592 ResolverAnalyticalKernel::TwoCompartmentsWithAbsorption,
2593 quote! { ::pharmsol::equation::two_compartments_with_absorption },
2594 quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsWithAbsorption },
2595 3,
2596 ),
2597 "three_compartments" => (
2598 ResolverAnalyticalKernel::ThreeCompartments,
2599 quote! { ::pharmsol::equation::three_compartments },
2600 quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartments },
2601 3,
2602 ),
2603 "three_compartments_cl" => (
2604 ResolverAnalyticalKernel::ThreeCompartmentsCl,
2605 quote! { ::pharmsol::equation::three_compartments_cl },
2606 quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsCl },
2607 3,
2608 ),
2609 "three_compartments_cl_with_absorption" => (
2610 ResolverAnalyticalKernel::ThreeCompartmentsClWithAbsorption,
2611 quote! { ::pharmsol::equation::three_compartments_cl_with_absorption },
2612 quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsClWithAbsorption },
2613 4,
2614 ),
2615 "three_compartments_with_absorption" => (
2616 ResolverAnalyticalKernel::ThreeCompartmentsWithAbsorption,
2617 quote! { ::pharmsol::equation::three_compartments_with_absorption },
2618 quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsWithAbsorption },
2619 4,
2620 ),
2621 _ => {
2622 return Err(syn::Error::new_spanned(
2623 structure,
2624 format!("unknown analytical structure `{structure_name}`"),
2625 ));
2626 }
2627 };
2628
2629 Ok(AnalyticalKernelSpec {
2630 kernel,
2631 runtime_path,
2632 metadata_kernel,
2633 state_count,
2634 })
2635}
2636
2637fn expand_analytical_route_map(
2638 label: &str,
2639 closure: &ExprClosure,
2640 params: &[Ident],
2641 derived: &[Ident],
2642 covariates: &[Ident],
2643 route_bindings: &[(SymbolicIndex, usize)],
2644) -> syn::Result<TokenStream2> {
2645 let route_consts = generate_mapped_index_consts(route_bindings);
2646 let p = generated_ident("__pharmsol_p");
2647 let t = generated_ident("__pharmsol_t");
2648 let cov = generated_ident("__pharmsol_cov");
2649 let full_inputs = [p.clone(), t.clone(), cov.clone()];
2650 let reduced_inputs = [t.clone()];
2651 let input_aliases = generate_supported_input_aliases(
2652 closure,
2653 &[&full_inputs, &reduced_inputs],
2654 &format!(
2655 "built-in `analytical!` requires `{label}` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|"
2656 ),
2657 )?;
2658 let parameter_bindings = generate_parameter_bindings(params, closure, &p);
2659 let derived_values = generated_ident("__pharmsol_derived");
2660 let derived_bindings = generate_derived_bindings(derived, closure, &derived_values);
2661 let derive_values = if derived_bindings.is_empty() {
2662 quote! {}
2663 } else {
2664 quote! {
2665 let #derived_values = __pharmsol_derive(#p, #t, #cov);
2666 }
2667 };
2668 let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t);
2669 let body = NumericLabelRewriter::rewrite(
2670 closure.body.as_ref(),
2671 Vec::new(),
2672 Some(symbolic_numeric_binding_map(route_bindings)),
2673 );
2674
2675 Ok(quote! {{
2676 let __pharmsol_route_map: fn(
2677 &::pharmsol::simulator::V,
2678 f64,
2679 &::pharmsol::data::Covariates,
2680 ) -> ::std::collections::HashMap<usize, f64> = |#p: &::pharmsol::simulator::V,
2681 #t: f64,
2682 #cov: &::pharmsol::data::Covariates| {
2683 #input_aliases
2684 #route_consts
2685 #parameter_bindings
2686 #derive_values
2687 #derived_bindings
2688 #covariate_bindings
2689 #body
2690 };
2691 __pharmsol_route_map
2692 }})
2693}
2694
2695fn expand_analytical_derive(
2696 derive: Option<&ExprClosure>,
2697 params: &[Ident],
2698 covariates: &[Ident],
2699 derived: &[Ident],
2700) -> syn::Result<TokenStream2> {
2701 let p = generated_ident("__pharmsol_p");
2702 let t = generated_ident("__pharmsol_t");
2703 let cov = generated_ident("__pharmsol_cov");
2704 let derived_len = syn::LitInt::new(&derived.len().to_string(), Span::call_site());
2705
2706 if let Some(derive) = derive {
2707 let full_inputs = [p.clone(), t.clone(), cov.clone()];
2708 let reduced_inputs = [t.clone()];
2709 let input_aliases = generate_supported_input_aliases(
2710 derive,
2711 &[&full_inputs, &reduced_inputs],
2712 "built-in `analytical!` requires `derive` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|",
2713 )?;
2714 let parameter_bindings = generate_parameter_bindings(params, derive, &p);
2715 let covariate_bindings = generate_covariate_bindings(covariates, derive, &cov, &t);
2716 let derived_decls = derived.iter().map(|ident| {
2717 quote! {
2718 #[allow(unused_mut)]
2719 let mut #ident: f64;
2720 }
2721 });
2722 let body = &derive.body;
2723
2724 Ok(quote! {
2725 fn __pharmsol_derive(
2726 #p: &::pharmsol::simulator::V,
2727 #t: f64,
2728 #cov: &::pharmsol::data::Covariates,
2729 ) -> [f64; #derived_len] {
2730 #input_aliases
2731 #parameter_bindings
2732 #covariate_bindings
2733 #(#derived_decls)*
2734 #body
2735 [#(#derived),*]
2736 }
2737 })
2738 } else {
2739 let zeros = derived.iter().map(|_| quote! { 0.0 });
2740 Ok(quote! {
2741 fn __pharmsol_derive(
2742 _: &::pharmsol::simulator::V,
2743 _: f64,
2744 _: &::pharmsol::data::Covariates,
2745 ) -> [f64; #derived_len] {
2746 [#(#zeros),*]
2747 }
2748 })
2749 }
2750}
2751
2752fn expand_analytical_runtime(
2753 runtime_path: &TokenStream2,
2754 projection: &AnalyticalStructureInputKind,
2755) -> TokenStream2 {
2756 match projection {
2757 AnalyticalStructureInputKind::AllPrimary { identity: true, .. } => runtime_path.clone(),
2758 AnalyticalStructureInputKind::AllPrimary { indices, .. } => {
2759 let projected = indices.iter().map(|index| quote! { __pharmsol_p[#index] });
2760 quote! {{
2761 let __pharmsol_eq: fn(
2762 &::pharmsol::simulator::V,
2763 &::pharmsol::simulator::V,
2764 f64,
2765 &::pharmsol::simulator::V,
2766 &::pharmsol::data::Covariates,
2767 ) -> ::pharmsol::simulator::V = |
2768 __pharmsol_x: &::pharmsol::simulator::V,
2769 __pharmsol_p: &::pharmsol::simulator::V,
2770 __pharmsol_t: f64,
2771 __pharmsol_rateiv: &::pharmsol::simulator::V,
2772 __pharmsol_cov: &::pharmsol::data::Covariates,
2773 | {
2774 let __pharmsol_projected = ::pharmsol::__macro_support::vector_from_values(vec![#(#projected),*]);
2775 #runtime_path(__pharmsol_x, &__pharmsol_projected, __pharmsol_t, __pharmsol_rateiv, __pharmsol_cov)
2776 };
2777 __pharmsol_eq
2778 }}
2779 }
2780 AnalyticalStructureInputKind::AllDerived { indices, .. } => {
2781 let projected = indices
2782 .iter()
2783 .map(|index| quote! { __pharmsol_derived[#index] });
2784 quote! {{
2785 let __pharmsol_eq: fn(
2786 &::pharmsol::simulator::V,
2787 &::pharmsol::simulator::V,
2788 f64,
2789 &::pharmsol::simulator::V,
2790 &::pharmsol::data::Covariates,
2791 ) -> ::pharmsol::simulator::V = |
2792 __pharmsol_x: &::pharmsol::simulator::V,
2793 __pharmsol_p: &::pharmsol::simulator::V,
2794 __pharmsol_t: f64,
2795 __pharmsol_rateiv: &::pharmsol::simulator::V,
2796 __pharmsol_cov: &::pharmsol::data::Covariates,
2797 | {
2798 let __pharmsol_derived = __pharmsol_derive(__pharmsol_p, __pharmsol_t, __pharmsol_cov);
2799 let __pharmsol_projected = ::pharmsol::__macro_support::vector_from_values(vec![#(#projected),*]);
2800 #runtime_path(__pharmsol_x, &__pharmsol_projected, __pharmsol_t, __pharmsol_rateiv, __pharmsol_cov)
2801 };
2802 __pharmsol_eq
2803 }}
2804 }
2805 AnalyticalStructureInputKind::Mixed { bindings } => {
2806 let projected = bindings.iter().map(|binding| match binding.source {
2807 AnalyticalStructureInputSource::Primary => {
2808 let index = binding.index;
2809 quote! { __pharmsol_p[#index] }
2810 }
2811 AnalyticalStructureInputSource::Derived => {
2812 let index = binding.index;
2813 quote! { __pharmsol_derived[#index] }
2814 }
2815 });
2816 quote! {{
2817 let __pharmsol_eq: fn(
2818 &::pharmsol::simulator::V,
2819 &::pharmsol::simulator::V,
2820 f64,
2821 &::pharmsol::simulator::V,
2822 &::pharmsol::data::Covariates,
2823 ) -> ::pharmsol::simulator::V = |
2824 __pharmsol_x: &::pharmsol::simulator::V,
2825 __pharmsol_p: &::pharmsol::simulator::V,
2826 __pharmsol_t: f64,
2827 __pharmsol_rateiv: &::pharmsol::simulator::V,
2828 __pharmsol_cov: &::pharmsol::data::Covariates,
2829 | {
2830 let __pharmsol_derived = __pharmsol_derive(__pharmsol_p, __pharmsol_t, __pharmsol_cov);
2831 let __pharmsol_projected = ::pharmsol::__macro_support::vector_from_values(vec![#(#projected),*]);
2832 #runtime_path(__pharmsol_x, &__pharmsol_projected, __pharmsol_t, __pharmsol_rateiv, __pharmsol_cov)
2833 };
2834 __pharmsol_eq
2835 }}
2836 }
2837 }
2838}
2839
2840fn expand_analytical_init(
2841 init: &ExprClosure,
2842 params: &[Ident],
2843 derived: &[Ident],
2844 covariates: &[Ident],
2845 states: &[Ident],
2846) -> syn::Result<TokenStream2> {
2847 let state_consts = generate_index_consts(states);
2848 let p = generated_ident("__pharmsol_p");
2849 let t = generated_ident("__pharmsol_t");
2850 let cov = generated_ident("__pharmsol_cov");
2851 let x = generated_ident("__pharmsol_x");
2852 let full_inputs = [p.clone(), t.clone(), cov.clone(), x.clone()];
2853 let reduced_inputs = [t.clone(), x.clone()];
2854 let input_aliases = generate_supported_input_aliases(
2855 init,
2856 &[&full_inputs, &reduced_inputs],
2857 "built-in `analytical!` requires `init` to have either 4 parameters: |p, t, cov, x| or 2 parameters: |t, x|",
2858 )?;
2859 let parameter_bindings = generate_parameter_bindings(params, init, &p);
2860 let derived_values = generated_ident("__pharmsol_derived");
2861 let derived_bindings = generate_derived_bindings(derived, init, &derived_values);
2862 let derive_values = if derived_bindings.is_empty() {
2863 quote! {}
2864 } else {
2865 quote! {
2866 let #derived_values = __pharmsol_derive(#p, #t, #cov);
2867 }
2868 };
2869 let covariate_bindings = generate_covariate_bindings(covariates, init, &cov, &t);
2870 let body = &init.body;
2871
2872 Ok(quote! {{
2873 let __pharmsol_init: fn(
2874 &::pharmsol::simulator::V,
2875 f64,
2876 &::pharmsol::data::Covariates,
2877 &mut ::pharmsol::simulator::V,
2878 ) = |#p: &::pharmsol::simulator::V,
2879 #t: f64,
2880 #cov: &::pharmsol::data::Covariates,
2881 #x: &mut ::pharmsol::simulator::V| {
2882 #input_aliases
2883 #state_consts
2884 #parameter_bindings
2885 #derive_values
2886 #derived_bindings
2887 #covariate_bindings
2888 #body
2889 };
2890 __pharmsol_init
2891 }})
2892}
2893
2894fn expand_analytical_out(
2895 out: &ExprClosure,
2896 params: &[Ident],
2897 derived: &[Ident],
2898 covariates: &[Ident],
2899 states: &[Ident],
2900 outputs: &[SymbolicIndex],
2901) -> syn::Result<TokenStream2> {
2902 let state_consts = generate_index_consts(states);
2903 let output_bindings = symbolic_index_bindings(outputs);
2904 let output_consts = generate_mapped_index_consts(&output_bindings);
2905 let x = generated_ident("__pharmsol_x");
2906 let p = generated_ident("__pharmsol_p");
2907 let t = generated_ident("__pharmsol_t");
2908 let cov = generated_ident("__pharmsol_cov");
2909 let y = generated_ident("__pharmsol_y");
2910 let full_inputs = [x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()];
2911 let reduced_inputs = [x.clone(), t.clone(), y.clone()];
2912 let input_aliases = generate_supported_input_aliases(
2913 out,
2914 &[&full_inputs, &reduced_inputs],
2915 "built-in `analytical!` requires `out` to have either 5 parameters: |x, p, t, cov, y| or 3 parameters: |x, t, y|",
2916 )?;
2917 let parameter_bindings = generate_parameter_bindings(params, out, &p);
2918 let derived_values = generated_ident("__pharmsol_derived");
2919 let derived_bindings = generate_derived_bindings(derived, out, &derived_values);
2920 let derive_values = if derived_bindings.is_empty() {
2921 quote! {}
2922 } else {
2923 quote! {
2924 let #derived_values = __pharmsol_derive(#p, #t, #cov);
2925 }
2926 };
2927 let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t);
2928 let y_binding = if out.inputs.len() == full_inputs.len() {
2929 closure_param_ident(out, 4).unwrap_or_else(|| y.clone())
2930 } else {
2931 closure_param_ident(out, 2).unwrap_or_else(|| y.clone())
2932 };
2933 let body = NumericLabelRewriter::rewrite(
2934 out.body.as_ref(),
2935 vec![IndexRewriteTarget::new(
2936 y_binding,
2937 symbolic_numeric_binding_map(&output_bindings),
2938 )],
2939 None,
2940 );
2941
2942 Ok(quote! {{
2943 let __pharmsol_out: fn(
2944 &::pharmsol::simulator::V,
2945 &::pharmsol::simulator::V,
2946 f64,
2947 &::pharmsol::data::Covariates,
2948 &mut ::pharmsol::simulator::V,
2949 ) = |#x: &::pharmsol::simulator::V,
2950 #p: &::pharmsol::simulator::V,
2951 #t: f64,
2952 #cov: &::pharmsol::data::Covariates,
2953 #y: &mut ::pharmsol::simulator::V| {
2954 #input_aliases
2955 #state_consts
2956 #output_consts
2957 #parameter_bindings
2958 #derive_values
2959 #derived_bindings
2960 #covariate_bindings
2961 #body
2962 };
2963 __pharmsol_out
2964 }})
2965}
2966
2967fn expand_sde_drift(
2968 drift: &ExprClosure,
2969 params: &[Ident],
2970 covariates: &[Ident],
2971 states: &[Ident],
2972 routes: &[OdeRouteDecl],
2973 route_bindings: &[(SymbolicIndex, usize)],
2974) -> syn::Result<TokenStream2> {
2975 let state_consts = generate_index_consts(states);
2976 let x = generated_ident("__pharmsol_x");
2977 let p = generated_ident("__pharmsol_p");
2978 let t = generated_ident("__pharmsol_t");
2979 let dx = generated_ident("__pharmsol_dx");
2980 let rateiv = generated_ident("__pharmsol_rateiv");
2981 let cov = generated_ident("__pharmsol_cov");
2982 let full_inputs = [x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()];
2983 let reduced_inputs = [x.clone(), t.clone(), dx.clone()];
2984 let input_aliases = generate_supported_input_aliases(
2985 drift,
2986 &[&full_inputs, &reduced_inputs],
2987 "declaration-first `sde!` requires `drift` to have either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|",
2988 )?;
2989 let parameter_bindings = generate_parameter_bindings(params, drift, &p);
2990 let covariate_bindings = generate_covariate_bindings(covariates, drift, &cov, &t);
2991 let body = &drift.body;
2992 let dx_binding = if drift.inputs.len() == full_inputs.len() {
2993 closure_param_ident(drift, 3).unwrap_or_else(|| dx.clone())
2994 } else {
2995 closure_param_ident(drift, 2).unwrap_or_else(|| dx.clone())
2996 };
2997 let rate_terms =
2998 expand_injected_sde_rate_terms(routes, states, route_bindings, &dx_binding, &rateiv);
2999
3000 Ok(quote! {{
3001 let __pharmsol_drift: fn(
3002 &::pharmsol::simulator::V,
3003 &::pharmsol::simulator::V,
3004 f64,
3005 &mut ::pharmsol::simulator::V,
3006 &::pharmsol::simulator::V,
3007 &::pharmsol::data::Covariates,
3008 ) = |#x: &::pharmsol::simulator::V,
3009 #p: &::pharmsol::simulator::V,
3010 #t: f64,
3011 #dx: &mut ::pharmsol::simulator::V,
3012 #rateiv: &::pharmsol::simulator::V,
3013 #cov: &::pharmsol::data::Covariates| {
3014 #input_aliases
3015 #state_consts
3016 #parameter_bindings
3017 #covariate_bindings
3018 #body
3019 #rate_terms
3020 };
3021 __pharmsol_drift
3022 }})
3023}
3024
3025fn expand_sde_diffusion(
3026 diffusion: &ExprClosure,
3027 params: &[Ident],
3028 states: &[Ident],
3029) -> syn::Result<TokenStream2> {
3030 let state_consts = generate_index_consts(states);
3031 let p = generated_ident("__pharmsol_p");
3032 let sigma = generated_ident("__pharmsol_sigma");
3033 let full_inputs = [p.clone(), sigma.clone()];
3034 let reduced_inputs = [sigma.clone()];
3035 let input_aliases = generate_supported_input_aliases(
3036 diffusion,
3037 &[&full_inputs, &reduced_inputs],
3038 "declaration-first `sde!` requires `diffusion` to have either 2 parameters: |p, sigma| or 1 parameter: |sigma|",
3039 )?;
3040 let parameter_bindings = generate_parameter_bindings(params, diffusion, &p);
3041 let body = &diffusion.body;
3042
3043 Ok(quote! {{
3044 let __pharmsol_diffusion: fn(
3045 &::pharmsol::simulator::V,
3046 &mut ::pharmsol::simulator::V,
3047 ) = |#p: &::pharmsol::simulator::V,
3048 #sigma: &mut ::pharmsol::simulator::V| {
3049 #input_aliases
3050 #state_consts
3051 #parameter_bindings
3052 #body
3053 };
3054 __pharmsol_diffusion
3055 }})
3056}
3057
3058fn expand_sde_route_map(
3059 label: &str,
3060 closure: &ExprClosure,
3061 params: &[Ident],
3062 covariates: &[Ident],
3063 route_bindings: &[(SymbolicIndex, usize)],
3064) -> syn::Result<TokenStream2> {
3065 let route_consts = generate_mapped_index_consts(route_bindings);
3066 let p = generated_ident("__pharmsol_p");
3067 let t = generated_ident("__pharmsol_t");
3068 let cov = generated_ident("__pharmsol_cov");
3069 let full_inputs = [p.clone(), t.clone(), cov.clone()];
3070 let reduced_inputs = [t.clone()];
3071 let input_aliases = generate_supported_input_aliases(
3072 closure,
3073 &[&full_inputs, &reduced_inputs],
3074 &format!(
3075 "declaration-first `sde!` requires `{label}` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|"
3076 ),
3077 )?;
3078 let parameter_bindings = generate_parameter_bindings(params, closure, &p);
3079 let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t);
3080 let body = NumericLabelRewriter::rewrite(
3081 closure.body.as_ref(),
3082 Vec::new(),
3083 Some(symbolic_numeric_binding_map(route_bindings)),
3084 );
3085
3086 Ok(quote! {{
3087 let __pharmsol_route_map: fn(
3088 &::pharmsol::simulator::V,
3089 f64,
3090 &::pharmsol::data::Covariates,
3091 ) -> ::std::collections::HashMap<usize, f64> = |#p: &::pharmsol::simulator::V,
3092 #t: f64,
3093 #cov: &::pharmsol::data::Covariates| {
3094 #input_aliases
3095 #route_consts
3096 #parameter_bindings
3097 #covariate_bindings
3098 #body
3099 };
3100 __pharmsol_route_map
3101 }})
3102}
3103
3104fn expand_sde_init(
3105 init: &ExprClosure,
3106 params: &[Ident],
3107 covariates: &[Ident],
3108 states: &[Ident],
3109) -> syn::Result<TokenStream2> {
3110 let state_consts = generate_index_consts(states);
3111 let p = generated_ident("__pharmsol_p");
3112 let t = generated_ident("__pharmsol_t");
3113 let cov = generated_ident("__pharmsol_cov");
3114 let x = generated_ident("__pharmsol_x");
3115 let full_inputs = [p.clone(), t.clone(), cov.clone(), x.clone()];
3116 let reduced_inputs = [t.clone(), x.clone()];
3117 let input_aliases = generate_supported_input_aliases(
3118 init,
3119 &[&full_inputs, &reduced_inputs],
3120 "declaration-first `sde!` requires `init` to have either 4 parameters: |p, t, cov, x| or 2 parameters: |t, x|",
3121 )?;
3122 let parameter_bindings = generate_parameter_bindings(params, init, &p);
3123 let covariate_bindings = generate_covariate_bindings(covariates, init, &cov, &t);
3124 let body = &init.body;
3125
3126 Ok(quote! {{
3127 let __pharmsol_init: fn(
3128 &::pharmsol::simulator::V,
3129 f64,
3130 &::pharmsol::data::Covariates,
3131 &mut ::pharmsol::simulator::V,
3132 ) = |#p: &::pharmsol::simulator::V,
3133 #t: f64,
3134 #cov: &::pharmsol::data::Covariates,
3135 #x: &mut ::pharmsol::simulator::V| {
3136 #input_aliases
3137 #state_consts
3138 #parameter_bindings
3139 #covariate_bindings
3140 #body
3141 };
3142 __pharmsol_init
3143 }})
3144}
3145
3146fn expand_sde_out(
3147 out: &ExprClosure,
3148 params: &[Ident],
3149 covariates: &[Ident],
3150 states: &[Ident],
3151 outputs: &[SymbolicIndex],
3152) -> syn::Result<TokenStream2> {
3153 let state_consts = generate_index_consts(states);
3154 let output_bindings = symbolic_index_bindings(outputs);
3155 let output_consts = generate_mapped_index_consts(&output_bindings);
3156 let x = generated_ident("__pharmsol_x");
3157 let p = generated_ident("__pharmsol_p");
3158 let t = generated_ident("__pharmsol_t");
3159 let cov = generated_ident("__pharmsol_cov");
3160 let y = generated_ident("__pharmsol_y");
3161 let full_inputs = [x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()];
3162 let reduced_inputs = [x.clone(), t.clone(), y.clone()];
3163 let input_aliases = generate_supported_input_aliases(
3164 out,
3165 &[&full_inputs, &reduced_inputs],
3166 "declaration-first `sde!` requires `out` to have either 5 parameters: |x, p, t, cov, y| or 3 parameters: |x, t, y|",
3167 )?;
3168 let parameter_bindings = generate_parameter_bindings(params, out, &p);
3169 let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t);
3170 let y_binding = if out.inputs.len() == full_inputs.len() {
3171 closure_param_ident(out, 4).unwrap_or_else(|| y.clone())
3172 } else {
3173 closure_param_ident(out, 2).unwrap_or_else(|| y.clone())
3174 };
3175 let body = NumericLabelRewriter::rewrite(
3176 out.body.as_ref(),
3177 vec![IndexRewriteTarget::new(
3178 y_binding,
3179 symbolic_numeric_binding_map(&output_bindings),
3180 )],
3181 None,
3182 );
3183
3184 Ok(quote! {{
3185 let __pharmsol_out: fn(
3186 &::pharmsol::simulator::V,
3187 &::pharmsol::simulator::V,
3188 f64,
3189 &::pharmsol::data::Covariates,
3190 &mut ::pharmsol::simulator::V,
3191 ) = |#x: &::pharmsol::simulator::V,
3192 #p: &::pharmsol::simulator::V,
3193 #t: f64,
3194 #cov: &::pharmsol::data::Covariates,
3195 #y: &mut ::pharmsol::simulator::V| {
3196 #input_aliases
3197 #state_consts
3198 #output_consts
3199 #parameter_bindings
3200 #covariate_bindings
3201 #body
3202 };
3203 __pharmsol_out
3204 }})
3205}
3206
3207#[proc_macro]
3212pub fn ode(input: TokenStream) -> TokenStream {
3213 let input = syn::parse_macro_input!(input as OdeInput);
3214
3215 let route_bindings = ode_route_input_bindings(&input.routes);
3216
3217 let lag_routes = match input.lag.as_ref() {
3218 Some(closure) => match extract_route_property_routes(
3219 "declaration-first `ode!`",
3220 "lag",
3221 closure,
3222 &input.routes,
3223 ) {
3224 Ok(routes) => {
3225 if let Err(error) = validate_route_property_kinds(
3226 "declaration-first `ode!`",
3227 "lag",
3228 &input.routes,
3229 &routes,
3230 ) {
3231 return error.to_compile_error().into();
3232 }
3233 routes
3234 }
3235 Err(error) => return error.to_compile_error().into(),
3236 },
3237 None => HashSet::new(),
3238 };
3239
3240 let fa_routes = match input.fa.as_ref() {
3241 Some(closure) => match extract_route_property_routes(
3242 "declaration-first `ode!`",
3243 "fa",
3244 closure,
3245 &input.routes,
3246 ) {
3247 Ok(routes) => {
3248 if let Err(error) = validate_route_property_kinds(
3249 "declaration-first `ode!`",
3250 "fa",
3251 &input.routes,
3252 &routes,
3253 ) {
3254 return error.to_compile_error().into();
3255 }
3256 routes
3257 }
3258 Err(error) => return error.to_compile_error().into(),
3259 },
3260 None => HashSet::new(),
3261 };
3262
3263 let diffeq = match expand_diffeq(
3264 &input.diffeq,
3265 &input.params,
3266 &input.covariates,
3267 &input.states,
3268 &input.routes,
3269 &route_bindings,
3270 ) {
3271 Ok(diffeq) => diffeq,
3272 Err(error) => return error.to_compile_error().into(),
3273 };
3274
3275 let out = match expand_out(
3276 &input.out,
3277 &input.params,
3278 &input.covariates,
3279 &input.states,
3280 &input.outputs,
3281 ) {
3282 Ok(out) => out,
3283 Err(error) => return error.to_compile_error().into(),
3284 };
3285
3286 let nstates = input.states.len();
3287 let ndrugs = dense_index_len(&route_bindings);
3288 let nout = input.outputs.len();
3289
3290 let name = &input.name;
3291 let params = &input.params;
3292 let covariates = &input.covariates;
3293 let states = &input.states;
3294 let outputs = &input.outputs;
3295 let routes = expand_route_metadata(&input.routes, &lag_routes, &fa_routes);
3296 let covariate_metadata = if covariates.is_empty() {
3297 quote! {}
3298 } else {
3299 quote! {
3300 .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*])
3301 }
3302 };
3303
3304 let lag = match input.lag.as_ref() {
3305 Some(closure) => match expand_ode_route_map(
3306 "lag",
3307 closure,
3308 &input.params,
3309 &input.covariates,
3310 &route_bindings,
3311 ) {
3312 Ok(lag) => lag,
3313 Err(error) => return error.to_compile_error().into(),
3314 },
3315 None => quote! { |_, _, _| ::std::collections::HashMap::new() },
3316 };
3317
3318 let fa = match input.fa.as_ref() {
3319 Some(closure) => {
3320 match expand_ode_route_map(
3321 "fa",
3322 closure,
3323 &input.params,
3324 &input.covariates,
3325 &route_bindings,
3326 ) {
3327 Ok(fa) => fa,
3328 Err(error) => return error.to_compile_error().into(),
3329 }
3330 }
3331 None => quote! { |_, _, _| ::std::collections::HashMap::new() },
3332 };
3333
3334 let init = match input.init.as_ref() {
3335 Some(closure) => {
3336 match expand_ode_init(closure, &input.params, &input.covariates, &input.states) {
3337 Ok(init) => init,
3338 Err(error) => return error.to_compile_error().into(),
3339 }
3340 }
3341 None => quote! { |_, _, _, _| {} },
3342 };
3343
3344 quote! {{
3345 let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name)
3346 .parameters([#(stringify!(#params)),*])
3347 #covariate_metadata
3348 .states([#(stringify!(#states)),*])
3349 .outputs([#(stringify!(#outputs)),*])
3350 #(.route(#routes))*;
3351
3352 ::pharmsol::equation::ODE::new(
3353 #diffeq,
3354 #lag,
3355 #fa,
3356 #init,
3357 #out,
3358 )
3359 .with_nstates(#nstates)
3360 .with_ndrugs(#ndrugs)
3361 .with_nout(#nout)
3362 .with_metadata(__pharmsol_metadata)
3363 .expect("declaration-first `ode!` generated invalid metadata")
3364 }}
3365 .into()
3366}
3367
3368#[proc_macro]
3369pub fn analytical(input: TokenStream) -> TokenStream {
3370 let input = syn::parse_macro_input!(input as AnalyticalInput);
3371 let route_bindings = ode_route_input_bindings(&input.routes);
3372
3373 let kernel_spec = match resolve_analytical_structure(&input.structure) {
3374 Ok(spec) => spec,
3375 Err(error) => return error.to_compile_error().into(),
3376 };
3377 let projection = match validate_analytical_structure_inputs(
3378 &input.structure,
3379 kernel_spec.kernel,
3380 &input.params,
3381 &input.derived,
3382 ) {
3383 Ok(plan) => plan,
3384 Err(error) => return error.to_compile_error().into(),
3385 };
3386
3387 let lag_routes = match input.lag.as_ref() {
3388 Some(closure) => match extract_route_property_routes(
3389 "built-in `analytical!`",
3390 "lag",
3391 closure,
3392 &input.routes,
3393 ) {
3394 Ok(routes) => {
3395 if let Err(error) = validate_route_property_kinds(
3396 "built-in `analytical!`",
3397 "lag",
3398 &input.routes,
3399 &routes,
3400 ) {
3401 return error.to_compile_error().into();
3402 }
3403 routes
3404 }
3405 Err(error) => return error.to_compile_error().into(),
3406 },
3407 None => HashSet::new(),
3408 };
3409
3410 let fa_routes = match input.fa.as_ref() {
3411 Some(closure) => match extract_route_property_routes(
3412 "built-in `analytical!`",
3413 "fa",
3414 closure,
3415 &input.routes,
3416 ) {
3417 Ok(routes) => {
3418 if let Err(error) = validate_route_property_kinds(
3419 "built-in `analytical!`",
3420 "fa",
3421 &input.routes,
3422 &routes,
3423 ) {
3424 return error.to_compile_error().into();
3425 }
3426 routes
3427 }
3428 Err(error) => return error.to_compile_error().into(),
3429 },
3430 None => HashSet::new(),
3431 };
3432
3433 let derive = match expand_analytical_derive(
3434 input.derive.as_ref(),
3435 &input.params,
3436 &input.covariates,
3437 &input.derived,
3438 ) {
3439 Ok(derive) => derive,
3440 Err(error) => return error.to_compile_error().into(),
3441 };
3442 let eq = expand_analytical_runtime(&kernel_spec.runtime_path, projection.kind());
3443
3444 let out = match expand_analytical_out(
3445 &input.out,
3446 &input.params,
3447 &input.derived,
3448 &input.covariates,
3449 &input.states,
3450 &input.outputs,
3451 ) {
3452 Ok(out) => out,
3453 Err(error) => return error.to_compile_error().into(),
3454 };
3455
3456 let lag = match input.lag.as_ref() {
3457 Some(closure) => {
3458 match expand_analytical_route_map(
3459 "lag",
3460 closure,
3461 &input.params,
3462 &input.derived,
3463 &input.covariates,
3464 &route_bindings,
3465 ) {
3466 Ok(lag) => lag,
3467 Err(error) => return error.to_compile_error().into(),
3468 }
3469 }
3470 None => quote! { |_, _, _| ::std::collections::HashMap::new() },
3471 };
3472
3473 let fa = match input.fa.as_ref() {
3474 Some(closure) => {
3475 match expand_analytical_route_map(
3476 "fa",
3477 closure,
3478 &input.params,
3479 &input.derived,
3480 &input.covariates,
3481 &route_bindings,
3482 ) {
3483 Ok(fa) => fa,
3484 Err(error) => return error.to_compile_error().into(),
3485 }
3486 }
3487 None => quote! { |_, _, _| ::std::collections::HashMap::new() },
3488 };
3489
3490 let init = match input.init.as_ref() {
3491 Some(closure) => {
3492 match expand_analytical_init(
3493 closure,
3494 &input.params,
3495 &input.derived,
3496 &input.covariates,
3497 &input.states,
3498 ) {
3499 Ok(init) => init,
3500 Err(error) => return error.to_compile_error().into(),
3501 }
3502 }
3503 None => quote! { |_, _, _, _| {} },
3504 };
3505
3506 let nstates = input.states.len();
3507 let ndrugs = dense_index_len(&route_bindings);
3508 let nout = input.outputs.len();
3509
3510 let name = &input.name;
3511 let params = &input.params;
3512 let covariates = &input.covariates;
3513 let states = &input.states;
3514 let outputs = &input.outputs;
3515 let routes = expand_analytical_route_metadata(&input.routes, &lag_routes, &fa_routes);
3516 let metadata_kernel = kernel_spec.metadata_kernel;
3517 let covariate_metadata = if covariates.is_empty() {
3518 quote! {}
3519 } else {
3520 quote! {
3521 .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*])
3522 }
3523 };
3524
3525 quote! {{
3526 #derive
3527 let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name)
3528 .kind(::pharmsol::equation::ModelKind::Analytical)
3529 .parameters([#(stringify!(#params)),*])
3530 #covariate_metadata
3531 .states([#(stringify!(#states)),*])
3532 .outputs([#(stringify!(#outputs)),*])
3533 #(.route(#routes))*
3534 .analytical_kernel(#metadata_kernel);
3535
3536 ::pharmsol::equation::Analytical::new(
3537 #eq,
3538 |_, _, _| {},
3539 #lag,
3540 #fa,
3541 #init,
3542 #out,
3543 )
3544 .with_nstates(#nstates)
3545 .with_ndrugs(#ndrugs)
3546 .with_nout(#nout)
3547 .with_metadata(__pharmsol_metadata)
3548 .expect("built-in `analytical!` generated invalid metadata")
3549 }}
3550 .into()
3551}
3552
3553#[proc_macro]
3554pub fn sde(input: TokenStream) -> TokenStream {
3555 let input = syn::parse_macro_input!(input as SdeInput);
3556 let route_bindings = ode_route_input_bindings(&input.routes);
3557
3558 let lag_routes = match input.lag.as_ref() {
3559 Some(closure) => match extract_route_property_routes(
3560 "declaration-first `sde!`",
3561 "lag",
3562 closure,
3563 &input.routes,
3564 ) {
3565 Ok(routes) => {
3566 if let Err(error) = validate_route_property_kinds(
3567 "declaration-first `sde!`",
3568 "lag",
3569 &input.routes,
3570 &routes,
3571 ) {
3572 return error.to_compile_error().into();
3573 }
3574 routes
3575 }
3576 Err(error) => return error.to_compile_error().into(),
3577 },
3578 None => HashSet::new(),
3579 };
3580
3581 let fa_routes = match input.fa.as_ref() {
3582 Some(closure) => match extract_route_property_routes(
3583 "declaration-first `sde!`",
3584 "fa",
3585 closure,
3586 &input.routes,
3587 ) {
3588 Ok(routes) => {
3589 if let Err(error) = validate_route_property_kinds(
3590 "declaration-first `sde!`",
3591 "fa",
3592 &input.routes,
3593 &routes,
3594 ) {
3595 return error.to_compile_error().into();
3596 }
3597 routes
3598 }
3599 Err(error) => return error.to_compile_error().into(),
3600 },
3601 None => HashSet::new(),
3602 };
3603
3604 let drift = match expand_sde_drift(
3605 &input.drift,
3606 &input.params,
3607 &input.covariates,
3608 &input.states,
3609 &input.routes,
3610 &route_bindings,
3611 ) {
3612 Ok(drift) => drift,
3613 Err(error) => return error.to_compile_error().into(),
3614 };
3615
3616 let diffusion = match expand_sde_diffusion(&input.diffusion, &input.params, &input.states) {
3617 Ok(diffusion) => diffusion,
3618 Err(error) => return error.to_compile_error().into(),
3619 };
3620
3621 let lag = match input.lag.as_ref() {
3622 Some(closure) => match expand_sde_route_map(
3623 "lag",
3624 closure,
3625 &input.params,
3626 &input.covariates,
3627 &route_bindings,
3628 ) {
3629 Ok(lag) => lag,
3630 Err(error) => return error.to_compile_error().into(),
3631 },
3632 None => quote! { |_, _, _| ::std::collections::HashMap::new() },
3633 };
3634
3635 let fa = match input.fa.as_ref() {
3636 Some(closure) => {
3637 match expand_sde_route_map(
3638 "fa",
3639 closure,
3640 &input.params,
3641 &input.covariates,
3642 &route_bindings,
3643 ) {
3644 Ok(fa) => fa,
3645 Err(error) => return error.to_compile_error().into(),
3646 }
3647 }
3648 None => quote! { |_, _, _| ::std::collections::HashMap::new() },
3649 };
3650
3651 let init = match input.init.as_ref() {
3652 Some(closure) => {
3653 match expand_sde_init(closure, &input.params, &input.covariates, &input.states) {
3654 Ok(init) => init,
3655 Err(error) => return error.to_compile_error().into(),
3656 }
3657 }
3658 None => quote! { |_, _, _, _| {} },
3659 };
3660
3661 let out = match expand_sde_out(
3662 &input.out,
3663 &input.params,
3664 &input.covariates,
3665 &input.states,
3666 &input.outputs,
3667 ) {
3668 Ok(out) => out,
3669 Err(error) => return error.to_compile_error().into(),
3670 };
3671
3672 let nstates = input.states.len();
3673 let ndrugs = dense_index_len(&route_bindings);
3674 let nout = input.outputs.len();
3675
3676 let name = &input.name;
3677 let params = &input.params;
3678 let covariates = &input.covariates;
3679 let states = &input.states;
3680 let outputs = &input.outputs;
3681 let particles = &input.particles;
3682 let routes = expand_sde_route_metadata(&input.routes, &lag_routes, &fa_routes);
3683 let bolus_mappings =
3684 expand_injected_sde_bolus_mappings(&input.routes, &input.states, &route_bindings);
3685 let covariate_metadata = if covariates.is_empty() {
3686 quote! {}
3687 } else {
3688 quote! {
3689 .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*])
3690 }
3691 };
3692
3693 quote! {{
3694 let __pharmsol_particles: usize = #particles;
3695 let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name)
3696 .kind(::pharmsol::equation::ModelKind::Sde)
3697 .parameters([#(stringify!(#params)),*])
3698 #covariate_metadata
3699 .states([#(stringify!(#states)),*])
3700 .outputs([#(stringify!(#outputs)),*])
3701 #(.route(#routes))*
3702 .particles(__pharmsol_particles);
3703
3704 ::pharmsol::equation::SDE::new(
3705 #drift,
3706 #diffusion,
3707 #lag,
3708 #fa,
3709 #init,
3710 #out,
3711 __pharmsol_particles,
3712 )
3713 .with_nstates(#nstates)
3714 .with_ndrugs(#ndrugs)
3715 .with_nout(#nout)
3716 #bolus_mappings
3717 .with_metadata(__pharmsol_metadata)
3718 .expect("declaration-first `sde!` generated invalid metadata")
3719 }}
3720 .into()
3721}
3722
#[cfg(test)]
mod tests {
    // Unit tests for the declaration-first macro parsers (`OdeInput`,
    // `AnalyticalInput`, `SdeInput`) and a few of the expansion helpers. Each
    // test parses a literal macro body with `syn::parse_str` and checks either
    // the parsed shape or the text of the resulting `syn::Error`.
    use super::*;

    #[test]
    fn rejects_removed_legacy_form() {
        let error = syn::parse_str::<OdeInput>(
            "diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}",
        )
        .err()
        .expect("legacy macro form must fail");

        assert!(error
            .to_string()
            .contains("requires `name`, `params`, `states`, `outputs`, and `routes`"));
        assert!(error
            .to_string()
            .contains("old inferred-dimensions form has been removed"));
    }

    #[test]
    fn validates_route_destinations() {
        let error = syn::parse_str::<OdeInput>(
            "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: [infusion(iv) -> peripheral], diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}",
        )
        .err()
        .expect("unknown route destination must fail");

        assert!(error
            .to_string()
            .contains("route destination `peripheral` is not declared in the `states` section"));
    }

    #[test]
    fn rejects_named_binding_collisions() {
        let error = syn::parse_str::<OdeInput>(
            "name: \"demo\", params: [central, v], states: [central], outputs: [cp], routes: [infusion(iv) -> central], diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}",
        )
        .err()
        .expect("parameter/state binding collisions must fail");

        assert!(error
            .to_string()
            .contains("named parameter binding `central` conflicts with named state binding"));
    }

    #[test]
    fn ode_route_bindings_share_inputs_by_kind_local_ordinal() {
        let input = syn::parse_str::<OdeInput>(
            "name: \"demo\", params: [ka, ke, v], states: [depot, central], outputs: [cp], routes: [bolus(oral) -> depot, infusion(iv) -> central, bolus(sc) -> depot], diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}",
        )
        .expect("declaration-first ode input should parse");

        let bindings = ode_route_input_bindings(&input.routes);

        assert_eq!(dense_index_len(&bindings), 2);
        assert_eq!(bindings[0].0.name(), "oral");
        assert_eq!(bindings[0].1, 0);
        assert_eq!(bindings[1].0.name(), "iv");
        assert_eq!(bindings[1].1, 0);
        assert_eq!(bindings[2].0.name(), "sc");
        assert_eq!(bindings[2].1, 1);
    }

    #[test]
    fn generated_parameter_bindings_only_include_referenced_locals_in_hot_closures() {
        let params = vec![generated_ident("ke"), generated_ident("v")];
        let closure = syn::parse_str::<ExprClosure>(
            "|x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }",
        )
        .expect("closure should parse");

        // `params` must be passed as a slice borrow (`&params`).
        let bindings =
            generate_parameter_bindings(&params, &closure, &generated_ident("__pharmsol_p"))
                .to_string();

        assert!(
            bindings.contains("let ke = __pharmsol_p [0usize] ;")
                || bindings.contains("let ke = __pharmsol_p [ 0 ] ;")
        );
        assert!(!bindings.contains("let v ="));
    }

    #[test]
    fn generated_parameter_bindings_fall_back_to_all_params_for_stmt_macros() {
        let params = vec![generated_ident("ka"), generated_ident("tlag")];
        let closure = syn::parse_str::<ExprClosure>("|_p, _t, _cov| { lag! { oral => tlag } }")
            .expect("closure should parse");

        // `params` must be passed as a slice borrow (`&params`).
        let bindings =
            generate_parameter_bindings(&params, &closure, &generated_ident("__pharmsol_p"))
                .to_string();

        assert!(bindings.contains("let ka ="));
        assert!(bindings.contains("let tlag ="));
    }

    #[test]
    fn analytical_accepts_extra_parameters_beyond_kernel_arity() {
        let input = syn::parse_str::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke0, v, tlag, tvke], derived: [ke], covariates: [wt, renal], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, derive: |_t| { ke = tvke; }, out: |x, p, t, cov, y| {}",
        )
        .expect("extra declared parameters should be allowed");

        assert_eq!(input.params.len(), 5);
        assert_eq!(input.derived.len(), 1);
        assert_eq!(input.covariates.len(), 2);
        assert!(input.derive.is_some());
        assert_eq!(input.states.len(), 2);
    }

    #[test]
    fn analytical_rejects_legacy_sec_with_migration_message() {
        let error = syn::parse_str::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke, v], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, sec: |_t| { ke = 1.0; }, out: |x, p, t, cov, y| {}",
        )
        .err()
        .expect("legacy sec must fail");

        assert!(error
            .to_string()
            .contains("no longer supports `sec`; use `derived: [...]` plus `derive: ...`"));
    }

    #[test]
    fn analytical_rejects_unknown_structure() {
        let error = syn::parse_str::<AnalyticalInput>(
            "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: [infusion(iv) -> central], structure: mystery, out: |x, p, t, cov, y| {}",
        )
        .err()
        .expect("unknown analytical structure must fail");

        assert!(error
            .to_string()
            .contains("unknown analytical structure `mystery`"));
    }

    #[test]
    fn analytical_rejects_missing_required_structure_name() {
        let error = syn::parse_str::<AnalyticalInput>(
            "name: \"demo\", params: [ke], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, out: |x, p, t, cov, y| {}",
        )
        .err()
        .expect("missing required structure name must fail");

        assert!(error.to_string().contains("requires `ka`"));
    }

    #[test]
    fn analytical_rejects_overlap_between_params_and_derived() {
        let error = syn::parse_str::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke, v], derived: [ke], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, derive: |_t| { ke = 1.0; }, out: |x, p, t, cov, y| {}",
        )
        .err()
        .expect("overlap must fail");

        assert!(error
            .to_string()
            .contains("`ke` is declared in both `params` and `derived`"));
    }

    #[test]
    fn analytical_rejects_invalid_derive_target() {
        let error = syn::parse_str::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke0, v], derived: [ke], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, derive: |_t| { ke0 = 1.0; ke = 0.1; }, out: |x, p, t, cov, y| {}",
        )
        .err()
        .expect("invalid derive target must fail");

        assert!(error
            .to_string()
            .contains("`derive` cannot assign to `ke0`"));
    }

    #[test]
    fn analytical_rejects_if_only_assignment_for_required_derived_name() {
        let error = syn::parse_str::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke0, v], derived: [ke], covariates: [wt], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, derive: |_t| { if wt > 70.0 { ke = ke0; } }, out: |x, p, t, cov, y| {}",
        )
        .err()
        .expect("bare if must fail");

        assert!(error
            .to_string()
            .contains("not definitely assigned on every path"));
    }

    #[test]
    fn analytical_accepts_if_else_assignment_for_required_derived_name() {
        syn::parse_str::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke0, v], derived: [ke], covariates: [wt], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, derive: |_t| { if wt > 70.0 { ke = ke0; } else { ke = ke0 * 0.5; } }, out: |x, p, t, cov, y| {}",
        )
        .expect("if / else should establish derived assignment");
    }

    #[test]
    fn analytical_rejects_loop_only_assignment_for_required_derived_name() {
        let error = syn::parse_str::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke0, v], derived: [ke], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, derive: |_t| { for i in 0..1 { let _ = i; ke = ke0; } }, out: |x, p, t, cov, y| {}",
        )
        .err()
        .expect("loop-only assignment must fail");

        assert!(error
            .to_string()
            .contains("not definitely assigned on every path"));
    }

    #[test]
    fn analytical_accepts_initial_assignment_followed_by_loop_updates() {
        syn::parse_str::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke0, v], derived: [ke], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, derive: |_t| { ke = ke0; for i in 0..2 { let _ = i; ke = ke + 1.0; } }, out: |x, p, t, cov, y| {}",
        )
        .expect("initial assignment plus loop updates should pass");
    }

    #[test]
    fn analytical_rejects_unknown_route_property_binding() {
        let error = syn::parse_str::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke, v], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, lag: |_p, _t, _cov| { lag! { iv => 1.0 } }, out: |x, p, t, cov, y| {}",
        )
        .err()
        .expect("unknown lag route must fail");

        assert!(error
            .to_string()
            .contains("route `iv` in `lag!` is not declared in the `routes` section"));
    }

    #[test]
    fn analytical_rejects_infusion_lag_binding() {
        let error = syn::parse_str::<AnalyticalInput>(
            "name: \"demo\", params: [ke, v, tlag], states: [central], outputs: [cp], routes: [infusion(iv) -> central], structure: one_compartment, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}",
        )
        .err()
        .expect("infusion lag must fail");

        assert!(error
            .to_string()
            .contains("built-in `analytical!` does not allow `lag` on infusion route `iv`"));
    }

    #[test]
    fn sde_requires_particles() {
        let error = syn::parse_str::<SdeInput>(
            "name: \"demo\", params: [ke, theta], states: [central], outputs: [cp], routes: [infusion(iv) -> central], drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, out: |x, p, t, cov, y| {}",
        )
        .err()
        .expect("missing particles must fail");

        assert!(error
            .to_string()
            .contains("missing required field `particles` in declaration-first `sde!`"));
    }

    #[test]
    fn sde_rejects_unknown_route_property_binding() {
        let error = syn::parse_str::<SdeInput>(
            "name: \"demo\", params: [ke, sigma_ke], states: [central], outputs: [cp], routes: [infusion(iv) -> central], particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { oral => 1.0 } }, out: |x, p, t, cov, y| {}",
        )
        .err()
        .expect("unknown lag route must fail");

        assert!(error
            .to_string()
            .contains("route `oral` in `lag!` is not declared in the `routes` section"));
    }

    #[test]
    fn sde_rejects_infusion_lag_binding() {
        let error = syn::parse_str::<SdeInput>(
            "name: \"demo\", params: [ke, sigma_ke, tlag], states: [central], outputs: [cp], routes: [infusion(iv) -> central], particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}",
        )
        .err()
        .expect("infusion lag must fail");

        assert!(error
            .to_string()
            .contains("declaration-first `sde!` does not allow `lag` on infusion route `iv`"));
    }

    #[test]
    fn rejects_braced_route_lists() {
        let error = syn::parse_str::<OdeInput>(
            "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}",
        )
        .err()
        .expect("braced route lists must fail");

        assert!(error
            .to_string()
            .contains("declaration-first macro `routes` must use `[...]`, not `{...}`"));
    }
}