// diffsl 0.11.1
//
// A compiler for a domain-specific language for ordinary differential equations (ODE).
// Documentation
// RUN: %eopt --enzyme %s | FileCheck %s
// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops %s | FileCheck %s --check-prefix=REM
// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s --check-prefix=FIN

// Reverse-mode differentiation fixture: a scalar squaring function plus a
// wrapper that requests its derivative through enzyme.autodiff.
module {
  // Primal function: multiplies its single f64 argument by itself.
  func.func @square(%value: f64) -> f64 {
    %product = arith.mulf %value, %value : f64
    return %product : f64
  }

  // Differentiates @square at %point. %seed is the incoming derivative of the
  // result; the op yields one f64, the derivative accumulated for %point.
  // The argument is tagged enzyme_active and the result enzyme_activenoneed
  // (presumably: the result takes part in differentiation but its primal
  // value is not returned -- confirm against the Enzyme dialect docs).
  func.func @dsquare(%point: f64, %seed: f64) -> f64 {
    %grad = enzyme.autodiff @square(%point, %seed) {
      activity = [#enzyme<activity enzyme_active>],
      ret_activity = [#enzyme<activity enzyme_activenoneed>]
    } : (f64, f64) -> f64
    return %grad : f64
  }
}


// CHECK:  func.func @dsquare(%arg0: f64, %arg1: f64) -> f64 {
// CHECK-NEXT:    %0 = call @diffesquare(%arg0, %arg1) : (f64, f64) -> f64
// CHECK-NEXT:    return %0 : f64
// CHECK-NEXT:  }

// CHECK:  func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 {
// CHECK-NEXT:    %[[dx:.+]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
// CHECK-NEXT:    %[[c0:.+]] = arith.constant 0.000000e+00 : f64
// CHECK-NEXT:    "enzyme.set"(%[[dx]], %[[c0]]) : (!enzyme.Gradient<f64>, f64) -> ()

// CHECK-NEXT:    %[[lhscache:.+]] = "enzyme.init"() : () -> !enzyme.Cache<f64>
// CHECK-NEXT:    %[[rhscache:.+]] = "enzyme.init"() : () -> !enzyme.Cache<f64>

// CHECK-NEXT:    %[[dy:.+]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
// CHECK-NEXT:    %[[c1:.+]] = arith.constant 0.000000e+00 : f64
// CHECK-NEXT:    "enzyme.set"(%[[dy]], %[[c1]]) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT:    "enzyme.push"(%[[rhscache]], %arg0) : (!enzyme.Cache<f64>, f64) -> ()
// CHECK-NEXT:    "enzyme.push"(%[[lhscache]], %arg0) : (!enzyme.Cache<f64>, f64) -> ()
// CHECK-NEXT:    %[[mul:.+]] = arith.mulf %arg0, %arg0 : f64
// CHECK-NEXT:    cf.br ^bb1

// CHECK:  ^bb1:  // pred: ^bb0
// CHECK-NEXT:    %[[prevdret0:.+]] = "enzyme.get"(%[[dy]]) : (!enzyme.Gradient<f64>) -> f64
// CHECK-NEXT:    %[[postdret0:.+]] = arith.addf %[[prevdret0]], %arg1 : f64
// CHECK-NEXT:    "enzyme.set"(%[[dy]], %[[postdret0]]) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT:    %[[prevdret:.+]] = "enzyme.get"(%[[dy]]) : (!enzyme.Gradient<f64>) -> f64
// CHECK-NEXT:    %[[c2:.+]] = arith.constant 0.000000e+00 : f64
// CHECK-NEXT:    "enzyme.set"(%[[dy]], %[[c2]]) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT:    %[[postlhs:.+]] = "enzyme.pop"(%[[rhscache]]) : (!enzyme.Cache<f64>) -> f64
// CHECK-NEXT:    %[[postrhs:.+]] = "enzyme.pop"(%[[lhscache]]) : (!enzyme.Cache<f64>) -> f64
// CHECK-NEXT:    %[[dlhs:.+]] = arith.mulf %[[prevdret]], %[[postrhs]] : f64
// CHECK-NEXT:    %[[prevdx1:.+]] = "enzyme.get"(%[[dx]]) : (!enzyme.Gradient<f64>) -> f64
// CHECK-NEXT:    %[[postdx1:.+]] = arith.addf %[[prevdx1]], %[[dlhs]] : f64
// CHECK-NEXT:    "enzyme.set"(%[[dx]], %[[postdx1]]) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT:    %[[drhs:.+]] = arith.mulf %[[prevdret]], %[[postlhs]] : f64
// CHECK-NEXT:    %[[prevdx2:.+]] = "enzyme.get"(%[[dx]]) : (!enzyme.Gradient<f64>) -> f64
// CHECK-NEXT:    %[[postdx2:.+]] = arith.addf %[[prevdx2]], %[[drhs]] : f64
// CHECK-NEXT:    "enzyme.set"(%[[dx]], %[[postdx2]]) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK-NEXT:    %[[res:.+]] = "enzyme.get"(%[[dx]]) : (!enzyme.Gradient<f64>) -> f64
// CHECK-NEXT:    return %[[res]] : f64
// CHECK-NEXT:  }


// REM:  func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 {
// REM-NEXT:    %[[cst:.+]] = arith.constant 0.000000e+00 : f64
// REM-NEXT:    %[[a1:.+]] = arith.addf %arg1, %[[cst]] : f64
// REM-NEXT:    %[[a2:.+]] = arith.mulf %[[a1]], %arg0 : f64
// REM-NEXT:    %[[a3:.+]] = arith.addf %[[a2]], %[[cst]] : f64
// REM-NEXT:    %[[a4:.+]] = arith.mulf %[[a1]], %arg0 : f64
// REM-NEXT:    %[[a5:.+]] = arith.addf %[[a3]], %[[a4]] : f64
// REM-NEXT:    return %[[a5]] : f64
// REM-NEXT:  }

// FIN:  func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 {
// FIN-NEXT:    %0 = arith.mulf %arg1, %arg0 : f64
// FIN-NEXT:    %1 = arith.addf %0, %0 : f64
// FIN-NEXT:    return %1 : f64
// FIN-NEXT:  }