use crate::CompileOptions;
use crate::backend::{Backend, ExecutableGraph};
use rlx_ir::Graph;
use std::collections::HashMap;
/// Cache of compiled subgraphs, so a graph that is executed repeatedly
/// (for example a loop body) is compiled only on first use.
///
/// Entries are keyed by the graph's `name`, and every compilation uses the
/// same `CompileOptions` supplied when the cache was created.
pub struct SubgraphCache {
/// Compiled executables, keyed by the `name` of the graph they were built from.
cache: HashMap<String, Box<dyn ExecutableGraph>>,
/// Options handed to the backend each time a graph has to be compiled.
options: CompileOptions,
}
impl SubgraphCache {
    /// Creates an empty cache; every later compilation uses `options`.
    pub fn new(options: CompileOptions) -> Self {
        Self {
            cache: HashMap::new(),
            options,
        }
    }

    /// Returns the executable cached under `graph.name`, compiling it on first use.
    ///
    /// On a miss, `graph` is cloned and passed to `backend.compile` together with
    /// the options this cache was created with; the result is stored and returned.
    /// On a hit, neither `backend` nor the contents of `graph` are consulted.
    ///
    /// The lookup key is the graph's name only. Two different graphs that share a
    /// name, or the same name requested with a different backend, therefore
    /// resolve to whichever executable was compiled first under that name.
    pub fn get_or_compile<'a>(
        &'a mut self,
        backend: &dyn Backend,
        graph: &Graph,
    ) -> &'a mut Box<dyn ExecutableGraph> {
        // Borrow the options field on its own before touching the map, so the
        // closure below never needs access to `self` while `self.cache` is
        // mutably borrowed. This keeps the split borrow explicit instead of
        // relying on edition-2021 disjoint closure captures.
        let options = &self.options;
        self.cache
            .entry(graph.name.clone())
            .or_insert_with(|| backend.compile(graph.clone(), options))
    }

    /// Runs `graph` on `inputs`, compiling and caching it first if necessary.
    ///
    /// `inputs` is a list of `(name, values)` pairs forwarded unchanged to the
    /// executable; the executable's outputs are returned as-is.
    pub fn run(
        &mut self,
        backend: &dyn Backend,
        graph: &Graph,
        inputs: &[(&str, &[f32])],
    ) -> Vec<Vec<f32>> {
        self.get_or_compile(backend, graph).run(inputs)
    }
}
/// Executes one of two subgraphs depending on a scalar predicate.
///
/// `else_branch` runs only when `predicate` compares equal to zero (`0.0` or
/// `-0.0`); every other value, NaN included, selects `then_branch`. The chosen
/// graph is run through `cache` with `inputs`, and its outputs are returned.
pub fn run_if(
    cache: &mut SubgraphCache,
    backend: &dyn Backend,
    predicate: f32,
    then_branch: &Graph,
    else_branch: &Graph,
    inputs: &[(&str, &[f32])],
) -> Vec<Vec<f32>> {
    // NaN never compares equal to zero, so it falls through to `then_branch`.
    let is_zero = predicate == 0.0;
    let branch = if is_zero { else_branch } else { then_branch };
    cache.run(backend, branch, inputs)
}
/// Repeatedly runs `body` for as long as `cond` evaluates to a non-zero value.
///
/// The loop state starts as `initial`. On each iteration the state vectors are
/// bound positionally to `input_names` and passed to `cond`; the first element
/// of `cond`'s first output is the loop predicate. A predicate equal to zero,
/// or a missing/empty first output, ends the loop. Otherwise `body` is run on
/// the same bindings and its outputs become the new state.
///
/// Names and state vectors are paired with `zip`, so when their counts differ
/// the surplus entries on the longer side are simply not bound.
///
/// `max_iterations` caps how many times `cond` is evaluated; `None` means
/// `usize::MAX`. The state at the time the loop stops is returned.
pub fn run_while(
    cache: &mut SubgraphCache,
    backend: &dyn Backend,
    cond: &Graph,
    body: &Graph,
    initial: Vec<Vec<f32>>,
    input_names: &[&str],
    max_iterations: Option<usize>,
) -> Vec<Vec<f32>> {
    let limit = max_iterations.unwrap_or(usize::MAX);
    let mut current = initial;
    let mut completed = 0usize;

    while completed < limit {
        // The bindings borrow `current`, so keep them inside this block and
        // only overwrite the state once they have gone out of scope.
        let next = {
            let named: Vec<(&str, &[f32])> = input_names
                .iter()
                .copied()
                .zip(current.iter().map(Vec::as_slice))
                .collect();

            // A NaN predicate is "not equal to zero" and keeps the loop going.
            let keep_going = cache
                .run(backend, cond, &named)
                .first()
                .and_then(|row| row.first())
                .map_or(false, |&flag| flag != 0.0);
            if !keep_going {
                break;
            }

            cache.run(backend, body, &named)
        };

        current = next;
        completed += 1;
    }

    current
}