diffsl 0.9.1

A compiler for a domain-specific language for ordinary differential equations (ODEs).
Documentation
// RUN: %eopt %s | FileCheck %s

// Parse/print fixture for enzyme.sample, enzyme.simulate and enzyme.generate.
// The FileCheck directives placed above each function spell out the IR that
// %eopt is expected to print back for that function:
//   - each op must appear on consecutive lines (the -NEXT form);
//   - attribute dictionaries are expected in normalized `{key = value}` spacing,
//     even though the input below uses looser spacing;
//   - all argument and SSA names are bound to FileCheck regex variables, so the
//     names the printer picks do not matter.
// NOTE(review): no pass is visible on the run line at the top of the file, so
// this presumably only exercises the ops' custom parsers/printers. Confirm
// against the %eopt substitution in the lit config.
module {
  // External callees with no bodies; only their signatures matter in this file.
  // @normal is the callee invoked by enzyme.sample in @sample, and @logpdf is
  // referenced from that op's `logpdf` attribute. The names suggest a normal
  // distribution sampler and its log-density, but nothing here checks that.
  func.func private @normal(%seed : tensor<2xui64>, %mean : tensor<f64>, %stddev : tensor<f64>) -> (tensor<2xui64>, tensor<f64>)
  func.func private @logpdf(%x : tensor<f64>, %mean : tensor<f64>, %stddev : tensor<f64>) -> tensor<f64>

  // @sample: a single enzyme.sample op that calls @normal with the three
  // function arguments. It carries name "r" and symbol #enzyme.symbol<3>, and
  // both of its results are returned. The first result has the seed's type
  // (tensor<2xui64>); presumably it is the advanced RNG state (TODO confirm).
  // CHECK:   func.func @sample(%[[seed:.+]]: tensor<2xui64>, %[[mean:.+]]: tensor<f64>, %[[stddev:.+]]: tensor<f64>) -> (tensor<2xui64>, tensor<f64>) {
  // CHECK-NEXT:    %[[res:.+]]:2 = enzyme.sample @normal(%[[seed]], %[[mean]], %[[stddev]]) {logpdf = @logpdf, name = "r", symbol = #enzyme.symbol<3>} : (tensor<2xui64>, tensor<f64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>)
  // CHECK-NEXT:    return %[[res]]#0, %[[res]]#1 : tensor<2xui64>, tensor<f64>
  // CHECK-NEXT:   }
  func.func @sample(%seed : tensor<2xui64>, %mean : tensor<f64>, %stddev : tensor<f64>) -> (tensor<2xui64>, tensor<f64>) {
    %r:2 = enzyme.sample @normal(%seed, %mean, %stddev) { logpdf = @logpdf, name="r", symbol = #enzyme.symbol<3> } : (tensor<2xui64>, tensor<f64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>)
    return %r#0, %r#1 : tensor<2xui64>, tensor<f64>
  }

  // @simulate: runs enzyme.simulate over @sample and returns all four results.
  // The results are an !enzyme.Trace, a tensor<f64>, and two values with
  // @sample's result types. The input names them as one group `%res:4`. The
  // expected print instead splits them into three named groups: `trace`,
  // `weight` and `outputs:2`.
  // CHECK:   func.func @simulate(%[[seed:.+]]: tensor<2xui64>, %[[mean:.+]]: tensor<f64>, %[[stddev:.+]]: tensor<f64>) -> (!enzyme.Trace, tensor<f64>, tensor<2xui64>, tensor<f64>) {
  // CHECK-NEXT:    %[[trace:.+]], %[[weight:.+]], %[[outputs:.+]]:2 = enzyme.simulate @sample(%[[seed]], %[[mean]], %[[stddev]]) {name = "test"} : (tensor<2xui64>, tensor<f64>, tensor<f64>) -> (!enzyme.Trace, tensor<f64>, tensor<2xui64>, tensor<f64>)
  // CHECK-NEXT:    return %[[trace]], %[[weight]], %[[outputs]]#0, %[[outputs]]#1 : !enzyme.Trace, tensor<f64>, tensor<2xui64>, tensor<f64>
  // CHECK-NEXT:   }
  func.func @simulate(%seed : tensor<2xui64>, %mean : tensor<f64>, %stddev : tensor<f64>) -> (!enzyme.Trace, tensor<f64>, tensor<2xui64>, tensor<f64>) {
    %res:4 = enzyme.simulate @sample(%seed, %mean, %stddev) { name = "test" } : (tensor<2xui64>, tensor<f64>, tensor<f64>) -> (!enzyme.Trace, tensor<f64>, tensor<2xui64>, tensor<f64>)
    return %res#0, %res#1, %res#2, %res#3 : !enzyme.Trace, tensor<f64>, tensor<2xui64>, tensor<f64>
  }

  // @generate: like @simulate, but through enzyme.generate. That op also takes
  // an !enzyme.Constraint operand (after `given`) and a nested
  // `constrained_addresses` array attribute; each inner list appears to be one
  // address path (TODO confirm).
  // The constraint value is fabricated from an arbitrary constant (dense<42>)
  // with builtin.unrealized_conversion_cast, presumably because only the op's
  // syntax is under test here.
  // In the expected line, the outer brackets of the nested array are matched
  // with the regex escapes {{\[}} and {{\]}}, because a literal `[[` would open
  // a FileCheck variable.
  // NOTE(review): symbol<2> in constrained_addresses does not correspond to
  // any sample op in this file (@sample uses only symbol<3>). That is harmless
  // for a parse/print test, but confirm it is intentional.
  // CHECK:   func.func @generate(%[[seed:.+]]: tensor<2xui64>, %[[mean:.+]]: tensor<f64>, %[[stddev:.+]]: tensor<f64>) -> (!enzyme.Trace, tensor<f64>, tensor<2xui64>, tensor<f64>) {
  // CHECK-NEXT:    %[[cst:.+]] = arith.constant dense<42> : tensor<ui64>
  // CHECK-NEXT:    %[[constraint:.+]] = builtin.unrealized_conversion_cast %[[cst]] : tensor<ui64> to !enzyme.Constraint
  // CHECK-NEXT:    %[[trace:.+]], %[[weight:.+]], %[[outputs:.+]]:2 = enzyme.generate @sample(%[[seed]], %[[mean]], %[[stddev]]) given %[[constraint]] {constrained_addresses = {{\[}}[#enzyme.symbol<2>], [#enzyme.symbol<3>]{{\]}}, name = "test"} : (tensor<2xui64>, tensor<f64>, tensor<f64>) -> (!enzyme.Trace, tensor<f64>, tensor<2xui64>, tensor<f64>)
  // CHECK-NEXT:    return %[[trace]], %[[weight]], %[[outputs]]#0, %[[outputs]]#1 : !enzyme.Trace, tensor<f64>, tensor<2xui64>, tensor<f64>
  // CHECK-NEXT:   }
  func.func @generate(%seed : tensor<2xui64>, %mean : tensor<f64>, %stddev : tensor<f64>) -> (!enzyme.Trace, tensor<f64>, tensor<2xui64>, tensor<f64>) {
    %cst = arith.constant dense<42> : tensor<ui64>
    %0 = builtin.unrealized_conversion_cast %cst : tensor<ui64> to !enzyme.Constraint
    %res:4 = enzyme.generate @sample(%seed, %mean, %stddev) given %0 { constrained_addresses = [[#enzyme.symbol<2>], [#enzyme.symbol<3>]], name = "test" } : (tensor<2xui64>, tensor<f64>, tensor<f64>) -> (!enzyme.Trace, tensor<f64>, tensor<2xui64>, tensor<f64>)
    return %res#0, %res#1, %res#2, %res#3 : !enzyme.Trace, tensor<f64>, tensor<2xui64>, tensor<f64>
  }
}