use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, RwLock};
use crate::error::Result;
use crate::fieldgraph::compiler::{CompileOptions, FieldGraphCompiler};
use crate::fieldgraph::FieldProgram;
use crate::prelude::{FieldGraphSpec, FieldSemantics, NodeSpec, TextureChannel};
use crate::scatter::{Kind, KindId};
/// Cache of compiled [`FieldProgram`]s keyed by [`KindId`].
///
/// The map sits behind an `RwLock`, so every method works through `&self`: lookups take the
/// shared lock and mutations take the exclusive one.
pub struct FieldProgramCache {
    /// Per-kind cached program plus the fingerprint it was compiled from.
    entries: RwLock<HashMap<KindId, ProgramEntry>>,
}

/// One cache slot: a shared compiled program and the fingerprint of the inputs it came from.
struct ProgramEntry {
    /// Fingerprint of the spec/options the program was built from; `get_or_compile` treats
    /// a mismatch with the current fingerprint as a stale entry.
    fingerprint: u64,
    /// The compiled program, shared with every caller that receives it.
    program: Arc<FieldProgram>,
}
impl FieldProgramCache {
pub fn new() -> Self {
Self {
entries: RwLock::new(HashMap::new()),
}
}
pub fn get_for_kind(&self, kind_id: &KindId) -> Option<Arc<FieldProgram>> {
let entries = self
.entries
.read()
.expect("FieldProgramCache lock poisoned");
entries.get(kind_id).map(|e| e.program.clone())
}
pub fn insert(&self, kind_id: KindId, fingerprint: u64, program: FieldProgram) {
self.insert_arc(kind_id, fingerprint, Arc::new(program));
}
pub fn insert_arc(&self, kind_id: KindId, fingerprint: u64, program: Arc<FieldProgram>) {
let mut entries = self
.entries
.write()
.expect("FieldProgramCache lock poisoned");
entries.insert(
kind_id,
ProgramEntry {
fingerprint,
program,
},
);
}
pub fn remove(&self, kind_id: &KindId) -> Option<Arc<FieldProgram>> {
let mut entries = self
.entries
.write()
.expect("FieldProgramCache lock poisoned");
entries.remove(kind_id).map(|e| e.program)
}
pub fn clear(&self) {
let mut entries = self
.entries
.write()
.expect("FieldProgramCache lock poisoned");
entries.clear();
}
pub fn get_or_compile(&self, kind: &Kind, opts: &CompileOptions) -> Result<Arc<FieldProgram>> {
let key = &kind.id;
let fp = fingerprint(&kind.spec, opts);
{
let entries = self
.entries
.read()
.expect("FieldProgramCache lock poisoned");
if let Some(entry) = entries.get(key) {
if entry.fingerprint == fp {
return Ok(entry.program.clone());
}
}
}
let program = Arc::new(FieldGraphCompiler::compile(&kind.spec, opts)?);
let mut entries = self
.entries
.write()
.expect("FieldProgramCache lock poisoned");
match entries.get(key) {
Some(entry) if entry.fingerprint == fp => Ok(entry.program.clone()),
_ => {
entries.insert(
key.clone(),
ProgramEntry {
program: program.clone(),
fingerprint: fp,
},
);
Ok(program)
}
}
}
}
impl Default for FieldProgramCache {
fn default() -> Self {
Self::new()
}
}
/// Computes an in-process fingerprint of everything in `spec` and `opts` that this cache
/// treats as affecting compilation.
///
/// Nodes are visited in sorted id order, so the map's iteration order does not matter. For
/// each node the function hashes:
/// - the node id and its variant;
/// - its entry in `spec.semantics`, if any;
/// - its inputs and its variant-specific parameters.
///
/// It then hashes the sorted `force_bake` ids from `opts`.
///
/// Every variable-length list is hashed with a length prefix: the nodes, each node's
/// inputs, and the force-bake ids. This keeps the byte stream fed to the hasher
/// unambiguous. Without the prefixes, a trailing input of the last node could not be
/// distinguished from a force-bake id. Two different graph/option combinations could then
/// share a fingerprint, and a stale program would be served.
///
/// Floats are hashed by bit pattern, so `0.0` and `-0.0` fingerprint differently. That can
/// only cause an unnecessary recompile, never a stale hit.
///
/// The result comes from [`DefaultHasher`], whose algorithm is not guaranteed stable across
/// Rust releases. Compare it only within one process and never persist it.
///
/// NOTE(review): only `opts.force_bake` is read from `opts`. If `CompileOptions` has (or
/// gains) other fields that change the compiled output, they must be hashed here too.
fn fingerprint(spec: &FieldGraphSpec, opts: &CompileOptions) -> u64 {
    let mut hasher = DefaultHasher::new();

    let mut nodes: Vec<_> = spec.nodes.iter().collect();
    nodes.sort_unstable_by_key(|&(id, _)| id);
    nodes.len().hash(&mut hasher);

    for (id, node) in nodes {
        id.hash(&mut hasher);

        let kind_tag: u8 = match node {
            NodeSpec::Constant { .. } => 1,
            NodeSpec::Texture { .. } => 2,
            NodeSpec::Add { .. } => 3,
            NodeSpec::Mul { .. } => 4,
            NodeSpec::Min { .. } => 5,
            NodeSpec::Max { .. } => 6,
            NodeSpec::Invert { .. } => 7,
            NodeSpec::Clamp { .. } => 8,
            NodeSpec::SmoothStep { .. } => 9,
            NodeSpec::Pow { .. } => 10,
            NodeSpec::EdtNormalize { .. } => 11,
            NodeSpec::Sub { .. } => 12,
            NodeSpec::Scale { .. } => 13,
        };
        kind_tag.hash(&mut hasher);

        let semantics_tag: u8 = match spec.semantics.get(id) {
            Some(FieldSemantics::Gate) => 0,
            Some(FieldSemantics::Probability) => 1,
            None => 255,
        };
        semantics_tag.hash(&mut hasher);

        // Length-prefix the inputs so they cannot run into the next node or the force-bake list.
        let inputs: Vec<_> = node.inputs().into_iter().collect();
        inputs.len().hash(&mut hasher);
        for input in &inputs {
            input.hash(&mut hasher);
        }

        // Exhaustive on purpose: adding a `NodeSpec` variant must force a decision about
        // which of its parameters belong in the fingerprint.
        match node {
            NodeSpec::Constant { params } => {
                params.value.to_bits().hash(&mut hasher);
            }
            NodeSpec::Texture { params } => {
                params.texture_id.hash(&mut hasher);
                let channel_tag: u8 = match params.channel {
                    TextureChannel::R => 0,
                    TextureChannel::G => 1,
                    TextureChannel::B => 2,
                    TextureChannel::A => 3,
                };
                channel_tag.hash(&mut hasher);
            }
            NodeSpec::Scale { params, .. } => {
                params.factor.to_bits().hash(&mut hasher);
            }
            NodeSpec::Clamp { params, .. } => {
                params.min.to_bits().hash(&mut hasher);
                params.max.to_bits().hash(&mut hasher);
            }
            NodeSpec::SmoothStep { params, .. } => {
                params.edge0.to_bits().hash(&mut hasher);
                params.edge1.to_bits().hash(&mut hasher);
            }
            NodeSpec::Pow { params, .. } => {
                params.exp.to_bits().hash(&mut hasher);
            }
            NodeSpec::EdtNormalize { params, .. } => {
                params.threshold.to_bits().hash(&mut hasher);
                params.d_max.to_bits().hash(&mut hasher);
            }
            NodeSpec::Add { .. }
            | NodeSpec::Mul { .. }
            | NodeSpec::Min { .. }
            | NodeSpec::Max { .. }
            | NodeSpec::Invert { .. }
            | NodeSpec::Sub { .. } => {}
        }
    }

    // Sorted so the fingerprint is independent of the set's iteration order.
    let mut bake_ids: Vec<&str> = opts.force_bake.iter().map(|s| s.as_str()).collect();
    bake_ids.sort_unstable();
    bake_ids.len().hash(&mut hasher);
    for id in bake_ids {
        id.hash(&mut hasher);
    }

    hasher.finish()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a kind whose graph is a single constant node `"prob"` with probability semantics.
    fn kind_with_constant(id: &str, value: f32) -> Kind {
        let mut spec = FieldGraphSpec::default();
        spec.add_with_semantics(
            "prob",
            NodeSpec::constant(value),
            FieldSemantics::Probability,
        );
        Kind::new(id, spec)
    }

    /// Extracts the value of the `"prob"` constant node from a compiled program.
    fn constant_from_program(program: &FieldProgram) -> f32 {
        if let Some(meta) = program.nodes.get("prob") {
            if let NodeSpec::Constant { params } = &meta.spec {
                return params.value;
            }
        }
        panic!("expected constant node");
    }

    #[test]
    fn caches_and_returns_compiled_programs() {
        let cache = FieldProgramCache::new();
        let kind = kind_with_constant("tree", 0.5);
        let program = cache
            .get_or_compile(&kind, &CompileOptions::default())
            .expect("compile succeeds");
        assert_eq!(constant_from_program(program.as_ref()), 0.5);
        assert!(cache.get_for_kind(&kind.id).is_some());
        let removed = cache.remove(&kind.id);
        assert!(removed.is_some());
        assert!(cache.get_for_kind(&kind.id).is_none());
        let opts = CompileOptions::default();
        let program = FieldGraphCompiler::compile(&kind.spec, &opts).unwrap();
        cache.insert(kind.id.clone(), fingerprint(&kind.spec, &opts), program);
        assert!(cache.get_for_kind(&kind.id).is_some());
    }

    #[test]
    fn reuses_cached_program_when_inputs_are_unchanged() {
        let cache = FieldProgramCache::new();
        let kind = kind_with_constant("fern", 0.4);
        let opts = CompileOptions::default();
        let first = cache
            .get_or_compile(&kind, &opts)
            .expect("first compile succeeds");
        let second = cache
            .get_or_compile(&kind, &opts)
            .expect("cache hit succeeds");
        assert!(
            Arc::ptr_eq(&first, &second),
            "unchanged kind and options should be served from the cache"
        );
    }

    #[test]
    fn replaces_entry_inserted_with_stale_fingerprint() {
        let cache = FieldProgramCache::new();
        let kind = kind_with_constant("moss", 0.1);
        let opts = CompileOptions::default();
        let stale = FieldGraphCompiler::compile(&kind.spec, &opts).expect("compile succeeds");
        let stale_fp = fingerprint(&kind.spec, &opts).wrapping_add(1);
        cache.insert(kind.id.clone(), stale_fp, stale);
        let stale_arc = cache.get_for_kind(&kind.id).expect("entry inserted");

        let fresh = cache
            .get_or_compile(&kind, &opts)
            .expect("recompile succeeds");
        assert!(!Arc::ptr_eq(&stale_arc, &fresh));
        let cached = cache.get_for_kind(&kind.id).expect("entry replaced");
        assert!(Arc::ptr_eq(&cached, &fresh));
    }

    #[test]
    fn recompiles_when_spec_fingerprint_changes() {
        let cache = FieldProgramCache::new();
        let kind_v1 = kind_with_constant("rock", 0.3);
        let program_v1 = cache
            .get_or_compile(&kind_v1, &CompileOptions::default())
            .expect("first compile succeeds");
        assert_eq!(constant_from_program(program_v1.as_ref()), 0.3);
        let kind_v2 = kind_with_constant("rock", 0.9);
        let program_v2 = cache
            .get_or_compile(&kind_v2, &CompileOptions::default())
            .expect("second compile succeeds");
        assert_eq!(constant_from_program(program_v2.as_ref()), 0.9);
    }

    #[test]
    fn clear_removes_all_entries() {
        let cache = FieldProgramCache::new();
        let kind = kind_with_constant("bush", 0.2);
        cache
            .get_or_compile(&kind, &CompileOptions::default())
            .expect("compile succeeds");
        assert!(cache.get_for_kind(&kind.id).is_some());
        cache.clear();
        assert!(cache.get_for_kind(&kind.id).is_none());
    }

    #[test]
    fn recompiles_when_compile_options_change() {
        let cache = FieldProgramCache::new();
        let kind = kind_with_constant("grass", 0.5);
        let opts_a = CompileOptions::default();
        let program_a = cache
            .get_or_compile(&kind, &opts_a)
            .expect("initial compile succeeds");
        assert!(
            !program_a
                .as_ref()
                .nodes
                .get("prob")
                .expect("node exists")
                .force_bake
        );
        let mut opts_b = CompileOptions::default();
        opts_b.force_bake.insert("prob".into());
        let program_b = cache
            .get_or_compile(&kind, &opts_b)
            .expect("force bake compile succeeds");
        assert!(
            program_b
                .as_ref()
                .nodes
                .get("prob")
                .expect("node exists")
                .force_bake
        );
    }

    #[test]
    fn fingerprint_ignores_force_bake_insertion_order() {
        let kind = kind_with_constant("vine", 0.7);
        let mut opts_a = CompileOptions::default();
        opts_a.force_bake.insert("prob".into());
        opts_a.force_bake.insert("other".into());
        let mut opts_b = CompileOptions::default();
        opts_b.force_bake.insert("other".into());
        opts_b.force_bake.insert("prob".into());
        assert_eq!(
            fingerprint(&kind.spec, &opts_a),
            fingerprint(&kind.spec, &opts_b)
        );
    }
}