use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use weaver_lang::Value;
use weaver_lang::{CompiledTemplate, EvalContext, EvalError, EvalErrorKind, Registry};
use crate::lorebook::LorebookConfig;
/// Whether template code may write into a namespace.
///
/// Serialized in lowercase, i.e. as `"readonly"` / `"readwrite"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NamespaceAccess {
    /// Writes from templates are rejected.
    ReadOnly,
    /// Templates may set variables in this namespace.
    ReadWrite,
}
/// Per-namespace settings as declared in the lorebook configuration.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NamespaceConfig {
    /// Access rule applied to template writes in this namespace.
    pub access: NamespaceAccess,
    /// Free-form human-readable description; empty when omitted from the config.
    #[serde(default)]
    pub description: String,
}
/// Host-side evaluation context: variable storage, access control, entry
/// triggering and nested document resolution for weaver templates.
pub struct WeaverHost {
    // --- variables ---
    /// Access rules per namespace name; namespaces not listed here are writable.
    namespace_access: HashMap<String, NamespaceAccess>,
    /// Live variable values, keyed by scope and then by variable name.
    variables: HashMap<String, HashMap<String, Value>>,
    /// Copy of the `state` namespace exposed to the host for saving/restoring.
    persistent_state: HashMap<String, Value>,

    // --- entries ---
    /// Entry ids currently considered active (mirrored into the `_active` namespace).
    active_entries: HashSet<String>,
    /// Entry ids queued by `fire_trigger`, deduplicated, in first-trigger order.
    triggered_entries: Vec<String>,
    /// Compiled templates that `resolve_document` can evaluate, keyed by id.
    entry_templates: HashMap<String, Arc<CompiledTemplate>>,

    // --- evaluation stack ---
    /// Ids of entries/documents being evaluated, innermost last; drives cycle
    /// detection and the recursion limit.
    eval_stack: Vec<String>,
    /// Entry most recently begun via `begin_entry` that has not yet ended.
    current_entry: Option<String>,
    /// Maximum number of `eval_stack` frames before resolution fails.
    max_recursion_depth: usize,
}
impl WeaverHost {
    /// Recursion limit in effect until [`WeaverHost::set_max_recursion_depth`] is called.
    const DEFAULT_MAX_RECURSION_DEPTH: usize = 10;

    /// Creates a host whose per-namespace access rules are taken from
    /// `config.namespaces`.
    ///
    /// All variable stores, the evaluation stack, the active/triggered entry
    /// collections and the template table start empty.
    pub fn from_lorebook_config(config: &LorebookConfig) -> Self {
        let namespace_access: HashMap<String, NamespaceAccess> = config
            .namespaces
            .iter()
            .map(|(k, v)| (k.clone(), v.access))
            .collect();
        Self {
            namespace_access,
            variables: HashMap::new(),
            persistent_state: HashMap::new(),
            eval_stack: Vec::new(),
            max_recursion_depth: Self::DEFAULT_MAX_RECURSION_DEPTH,
            active_entries: HashSet::new(),
            current_entry: None,
            entry_templates: HashMap::new(),
            triggered_entries: Vec::new(),
        }
    }

    /// Sets `scope:name` to `value` on behalf of the host application.
    ///
    /// No access check is applied, so this can populate read-only namespaces.
    /// Writes to `state` are mirrored into the persistent store.
    pub fn set_host_variable(&mut self, scope: &str, name: &str, value: Value) {
        if scope == "state" {
            self.persistent_state.insert(name.to_string(), value.clone());
        }
        self.variables
            .entry(scope.to_string())
            .or_default()
            .insert(name.to_string(), value);
    }

    /// Replaces the entire `scope` namespace with `vars` (no access check).
    ///
    /// Replacing `state` also replaces the persistent store, so both copies stay
    /// identical and keys absent from `vars` no longer resolve.
    pub fn set_namespace(&mut self, scope: &str, vars: HashMap<String, Value>) {
        // `resolve_variable` falls back to `persistent_state` for `state`; without
        // this, dropped keys would keep resolving and `persistent_state()` would
        // report stale data.
        if scope == "state" {
            self.persistent_state = vars.clone();
        }
        self.variables.insert(scope.to_string(), vars);
    }

    /// Sets the active entry ids and exposes each one as `_active:<id> = true`.
    ///
    /// Any previous `_active` namespace is replaced wholesale.
    pub fn set_active_entries(&mut self, ids: HashSet<String>) {
        let active_ns: HashMap<String, Value> = ids
            .iter()
            .map(|id| (id.clone(), Value::Bool(true)))
            .collect();
        // `insert` overwrites any existing `_active` map, so no prior removal is needed.
        self.variables.insert("_active".to_string(), active_ns);
        self.active_entries = ids;
    }

    /// Returns `true` if `id` is in the current active-entry set.
    pub fn is_entry_active(&self, id: &str) -> bool {
        self.active_entries.contains(id)
    }

    /// Replaces the table of templates that `resolve_document` can evaluate.
    pub fn set_entry_templates(&mut self, templates: HashMap<String, Arc<CompiledTemplate>>) {
        self.entry_templates = templates;
    }

    /// Returns the entry ids queued by `fire_trigger` and empties the queue.
    pub fn drain_triggered_entries(&mut self) -> Vec<String> {
        std::mem::take(&mut self.triggered_entries)
    }

    /// Marks `entry_id` as the current entry and pushes it onto the evaluation
    /// stack; while it is there, a document reference back to it is reported
    /// as a cycle and it counts toward the recursion limit.
    pub fn begin_entry(&mut self, entry_id: &str) {
        self.current_entry = Some(entry_id.to_string());
        self.eval_stack.push(entry_id.to_string());
    }

    /// Pops the innermost evaluation frame; the current entry becomes whatever is
    /// now on top of the stack (or `None` if it is empty).
    pub fn end_entry(&mut self) {
        self.eval_stack.pop();
        self.current_entry = self.eval_stack.last().cloned();
    }

    /// The persistent (`state`) variables, for the host to save.
    pub fn persistent_state(&self) -> &HashMap<String, Value> {
        &self.persistent_state
    }

    /// Restores previously saved persistent variables and makes them visible as
    /// the `state` namespace, replacing whatever was there.
    pub fn restore_persistent_state(&mut self, state: HashMap<String, Value>) {
        self.persistent_state = state;
        self.variables
            .insert("state".to_string(), self.persistent_state.clone());
    }

    /// Discards per-evaluation data: the `local` and `_active` namespaces, the
    /// active-entry set and any pending triggers. `state` and all other
    /// namespaces are kept.
    pub fn clear_transient(&mut self) {
        self.variables.remove("local");
        self.variables.remove("_active");
        // Keep `active_entries` in sync with the `_active` namespace; otherwise
        // `is_entry_active`/`fire_trigger` would still see entries that templates
        // can no longer observe.
        self.active_entries.clear();
        self.triggered_entries.clear();
    }

    /// Sets how many evaluation frames (entries plus nested documents) may be on
    /// the stack before `resolve_document` fails with `RecursionLimit`.
    pub fn set_max_recursion_depth(&mut self, depth: usize) {
        self.max_recursion_depth = depth;
    }
}
impl EvalContext for WeaverHost {
    /// Resolves `scope:name`.
    ///
    /// The live namespace is consulted first; for `state`, the persistent store
    /// is used as a fallback. Unknown variables yield `Ok(None)`.
    fn resolve_variable(&self, scope: &str, name: &str) -> Result<Option<Value>, EvalError> {
        let found = self
            .variables
            .get(scope)
            .and_then(|ns| ns.get(name))
            .or_else(|| {
                if scope == "state" {
                    self.persistent_state.get(name)
                } else {
                    None
                }
            });
        Ok(found.cloned())
    }

    /// Writes `scope:name = value` from template code.
    ///
    /// Namespaces without an access rule are writable. Writes to `state` are
    /// mirrored into the persistent store.
    ///
    /// # Errors
    /// Returns a `HostError` if `scope` is configured as read-only.
    fn set_variable(&mut self, scope: &str, name: &str, value: Value) -> Result<(), EvalError> {
        if matches!(
            self.namespace_access.get(scope),
            Some(&NamespaceAccess::ReadOnly)
        ) {
            return Err(EvalError::new(
                EvalErrorKind::HostError,
                format!("namespace '{scope}' is read-only (cannot set {scope}:{name})"),
            ));
        }
        if scope == "state" {
            self.persistent_state.insert(name.to_string(), value.clone());
        }
        self.variables
            .entry(scope.to_string())
            .or_default()
            .insert(name.to_string(), value);
        Ok(())
    }

    /// Queues `entry_id` for activation unless it is already active or already
    /// queued. Triggers produce no inline output, so this returns an empty string.
    fn fire_trigger(&mut self, entry_id: &str, _registry: &Registry) -> Result<String, EvalError> {
        // Compare as `&str` to avoid allocating a `String` per check.
        let already_known = self.active_entries.contains(entry_id)
            || self
                .triggered_entries
                .iter()
                .any(|queued| queued.as_str() == entry_id);
        if !already_known {
            self.triggered_entries.push(entry_id.to_string());
        }
        Ok(String::new())
    }

    /// Evaluates the template registered under `document_id` with this host as
    /// the context, returning its rendered output.
    ///
    /// # Errors
    /// - `HostError` if `document_id` is already on the evaluation stack (cycle).
    /// - `RecursionLimit` if the stack already holds `max_recursion_depth` frames.
    /// - `DocumentNotFound` if no template is registered under `document_id`.
    /// - Any error produced while evaluating the template itself.
    fn resolve_document(
        &mut self,
        document_id: &str,
        registry: &Registry,
    ) -> Result<String, EvalError> {
        if self.eval_stack.iter().any(|frame| frame.as_str() == document_id) {
            return Err(EvalError::new(
                EvalErrorKind::HostError,
                format!(
                    "document cycle detected: {} → {document_id}",
                    self.eval_stack.join(" → ")
                ),
            ));
        }
        if self.eval_stack.len() >= self.max_recursion_depth {
            return Err(EvalError::new(
                EvalErrorKind::RecursionLimit,
                format!(
                    "recursion limit ({}) reached resolving [[{document_id}]]",
                    self.max_recursion_depth
                ),
            ));
        }
        // Clone the `Arc` (cheap) so the template outlives the borrow of
        // `self.entry_templates`; `evaluate` needs `self` mutably.
        let template = self
            .entry_templates
            .get(document_id)
            .cloned()
            .ok_or_else(|| {
                EvalError::new(
                    EvalErrorKind::DocumentNotFound,
                    format!("unknown document: {document_id}"),
                )
            })?;
        self.eval_stack.push(document_id.to_string());
        let result = weaver_lang::evaluate(template.ast(), self, registry);
        // Pop unconditionally so an evaluation error does not leave a stale frame.
        self.eval_stack.pop();
        result
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `WeaverHost`. Every `&registry` argument below had been
    // mis-encoded as `®istry` (an HTML `&reg` entity), which did not compile.
    use super::*;
    use crate::lorebook::LorebookConfig;

    /// Host with default config and a few read-only-style host variables set.
    fn make_host() -> WeaverHost {
        let config = LorebookConfig::default();
        let mut host = WeaverHost::from_lorebook_config(&config);
        host.set_host_variable("char", "name", Value::String("Aria".into()));
        host.set_host_variable("char", "class", Value::String("Mage".into()));
        host.set_host_variable("user", "name", Value::String("Player".into()));
        host
    }

    #[test]
    fn test_read_host_variable() {
        let host = make_host();
        let val = host.resolve_variable("char", "name").unwrap();
        assert_eq!(val, Some(Value::String("Aria".into())));
    }

    #[test]
    fn test_readonly_namespace_blocks_writes() {
        let mut host = make_host();
        let result = host.set_variable("char", "name", Value::String("Hacked".into()));
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message.contains("read-only"));
    }

    #[test]
    fn test_writable_namespace_allows_writes() {
        let mut host = make_host();
        let result = host.set_variable("state", "visited", Value::Bool(true));
        assert!(result.is_ok());
        let val = host.resolve_variable("state", "visited").unwrap();
        assert_eq!(val, Some(Value::Bool(true)));
    }

    #[test]
    fn test_state_persists() {
        let mut host = make_host();
        host.set_variable("state", "counter", Value::Number(42.0))
            .unwrap();
        assert_eq!(
            host.persistent_state().get("counter"),
            Some(&Value::Number(42.0))
        );
    }

    #[test]
    fn test_set_host_variable_state_persists() {
        let mut host = make_host();
        host.set_host_variable("state", "weapon", Value::String("longbow".into()));
        let val = host.resolve_variable("state", "weapon").unwrap();
        assert_eq!(val, Some(Value::String("longbow".into())));
        assert_eq!(
            host.persistent_state().get("weapon"),
            Some(&Value::String("longbow".into()))
        );
    }

    #[test]
    fn test_clear_transient_preserves_state() {
        let mut host = make_host();
        host.set_variable("local", "temp", Value::String("gone".into()))
            .unwrap();
        host.set_variable("state", "kept", Value::Bool(true))
            .unwrap();
        host.clear_transient();
        assert_eq!(host.resolve_variable("local", "temp").unwrap(), None);
        assert_eq!(
            host.resolve_variable("state", "kept").unwrap(),
            Some(Value::Bool(true))
        );
    }

    #[test]
    fn test_unknown_namespace_is_writable() {
        let mut host = make_host();
        let result = host.set_variable("custom", "foo", Value::String("bar".into()));
        assert!(result.is_ok());
    }

    #[test]
    fn test_active_entries_populate_namespace() {
        let mut host = make_host();
        host.set_active_entries(HashSet::from([
            "entry_a".to_string(),
            "entry_b".to_string(),
        ]));
        let val = host.resolve_variable("_active", "entry_a").unwrap();
        assert_eq!(val, Some(Value::Bool(true)));
        let val = host.resolve_variable("_active", "entry_c").unwrap();
        assert_eq!(val, None);
    }

    #[test]
    fn test_trigger_produces_no_output() {
        let mut host = make_host();
        let registry = Registry::new();
        let result = host.fire_trigger("some_entry", &registry).unwrap();
        assert_eq!(result, "");
    }

    #[test]
    fn test_trigger_records_entry_id() {
        let mut host = make_host();
        let registry = Registry::new();
        host.fire_trigger("entry_a", &registry).unwrap();
        host.fire_trigger("entry_b", &registry).unwrap();
        let triggered = host.drain_triggered_entries();
        assert_eq!(triggered, vec!["entry_a", "entry_b"]);
    }

    #[test]
    fn test_trigger_deduplicates() {
        let mut host = make_host();
        let registry = Registry::new();
        host.fire_trigger("entry_a", &registry).unwrap();
        host.fire_trigger("entry_a", &registry).unwrap();
        let triggered = host.drain_triggered_entries();
        assert_eq!(triggered, vec!["entry_a"]);
    }

    #[test]
    fn test_trigger_skips_already_active() {
        let mut host = make_host();
        host.set_active_entries(HashSet::from(["entry_a".to_string()]));
        let registry = Registry::new();
        host.fire_trigger("entry_a", &registry).unwrap();
        let triggered = host.drain_triggered_entries();
        assert!(triggered.is_empty());
    }

    #[test]
    fn test_drain_clears_triggered() {
        let mut host = make_host();
        let registry = Registry::new();
        host.fire_trigger("entry_a", &registry).unwrap();
        let first = host.drain_triggered_entries();
        assert_eq!(first.len(), 1);
        let second = host.drain_triggered_entries();
        assert!(second.is_empty());
    }

    #[test]
    fn test_document_resolves_template() {
        let mut host = make_host();
        let registry = Registry::new();
        let template = Arc::new(CompiledTemplate::compile("Hello from document!").unwrap());
        host.set_entry_templates(HashMap::from([("my_doc".to_string(), template)]));
        let result = host.resolve_document("my_doc", &registry).unwrap();
        assert_eq!(result, "Hello from document!");
    }

    #[test]
    fn test_document_resolves_variables() {
        let mut host = make_host();
        let registry = Registry::new();
        let template = Arc::new(CompiledTemplate::compile("Name: {{char:name}}").unwrap());
        host.set_entry_templates(HashMap::from([("char_doc".to_string(), template)]));
        let result = host.resolve_document("char_doc", &registry).unwrap();
        assert_eq!(result, "Name: Aria");
    }

    #[test]
    fn test_document_not_found() {
        let mut host = make_host();
        let registry = Registry::new();
        let result = host.resolve_document("nonexistent", &registry);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind, EvalErrorKind::DocumentNotFound);
    }

    #[test]
    fn test_document_cycle_detection() {
        let mut host = make_host();
        let registry = Registry::new();
        let template = Arc::new(CompiledTemplate::compile("self-reference").unwrap());
        host.set_entry_templates(HashMap::from([("entry_a".to_string(), template)]));
        host.eval_stack.push("entry_a".to_string());
        let result = host.resolve_document("entry_a", &registry);
        assert!(result.is_err());
        assert!(result.unwrap_err().message.contains("cycle"));
    }

    #[test]
    fn test_document_depth_limit() {
        let mut host = make_host();
        host.set_max_recursion_depth(2);
        let registry = Registry::new();
        host.eval_stack.push("a".to_string());
        host.eval_stack.push("b".to_string());
        let template = Arc::new(CompiledTemplate::compile("deep").unwrap());
        host.set_entry_templates(HashMap::from([("c".to_string(), template)]));
        let result = host.resolve_document("c", &registry);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind, EvalErrorKind::RecursionLimit);
    }

    #[test]
    fn test_document_chains() {
        let mut host = make_host();
        let registry = Registry::new();
        let template_b = Arc::new(CompiledTemplate::compile("world").unwrap());
        let template_a = Arc::new(CompiledTemplate::compile("Hello, [[doc_b]]!").unwrap());
        host.set_entry_templates(HashMap::from([
            ("doc_a".to_string(), template_a),
            ("doc_b".to_string(), template_b),
        ]));
        let result = host.resolve_document("doc_a", &registry).unwrap();
        assert_eq!(result, "Hello, world!");
    }
}