use crate::activation::{ActivationResult, ActivationState};
use crate::assembler::AssembledBlock;
use crate::entry::Entry;
use crate::lorebook::Lorebook;
use crate::ChatMessage;
/// Error returned by a lifecycle hook to report that it failed.
///
/// Carries a human-readable `message` and, optionally, the underlying error
/// that caused the failure.
#[derive(Debug)]
pub struct HookError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Underlying cause, if any. It is exposed through
    /// [`std::error::Error::source`] and is also appended to the `Display`
    /// output after `": "`.
    pub source: Option<Box<dyn std::error::Error + Send + Sync>>,
}
impl HookError {
    /// Builds an error that carries only a message and no underlying cause.
    pub fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self {
            message,
            source: None,
        }
    }

    /// Builds an error with a message and the error that caused it.
    ///
    /// The cause is boxed and stored in [`HookError::source`]. It must be
    /// `Send + Sync + 'static` so the resulting `HookError` can cross threads.
    pub fn with_source<E>(message: impl Into<String>, source: E) -> Self
    where
        E: std::error::Error + Send + Sync + 'static,
    {
        // Reuse `new` for the message and override only the cause.
        Self {
            source: Some(Box::new(source)),
            ..Self::new(message)
        }
    }
}
impl std::fmt::Display for HookError {
    /// Writes `message`, followed by `": <cause>"` when a source error is present.
    ///
    /// Formatter flags such as width and fill are not applied.
    ///
    /// NOTE(review): the `std::error::Error` docs advise that a source should
    /// either be returned from `source()` or shown in `Display`, but not both.
    /// This type does both, so chain-printing reporters show the cause twice.
    /// The current output is pinned by tests; change it deliberately if at all.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match &self.source {
            Some(cause) => write!(f, "{}: {}", self.message, cause),
            None => f.write_str(&self.message),
        }
    }
}
impl std::error::Error for HookError {
    /// Returns the stored underlying cause, or `None` when there is none.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // The typed binding drops the `Send + Sync` auto traits from the
        // trait object so the reference matches the trait's return type.
        let cause: &(dyn std::error::Error + 'static) = self.source.as_deref()?;
        Some(cause)
    }
}
/// Context handed to [`LifecyclePlugin::pre_activation`].
pub struct PreActivationCtx<'a> {
    /// Chat messages; hooks may edit this list in place (e.g. push extra messages).
    pub messages: &'a mut Vec<ChatMessage>,
    /// Turn number supplied by the caller (presumably the current chat turn — confirm).
    pub turn: usize,
}
/// Context handed to [`LifecyclePlugin::post_activation`].
pub struct PostActivationCtx<'a> {
    /// Activation results; hooks may add, remove or modify entries in place.
    pub results: &'a mut Vec<ActivationResult>,
    /// Read-only lorebook supplied by the caller
    /// (presumably the one activation ran against — confirm).
    pub lorebook: &'a Lorebook,
    /// Turn number supplied by the caller (presumably the current chat turn — confirm).
    pub turn: usize,
}
/// Context handed to [`LifecyclePlugin::pre_evaluate`].
pub struct PreEvaluateCtx<'a> {
    /// The entry this hook is invoked for (read-only).
    pub entry: &'a Entry,
    /// Out-flag the hook may set.
    /// NOTE(review): presumably `true` makes the caller skip this entry — confirm.
    pub skip: &'a mut bool,
}
/// Context handed to [`LifecyclePlugin::post_evaluate`].
pub struct PostEvaluateCtx<'a> {
    /// The entry this hook is invoked for (read-only).
    pub entry: &'a Entry,
    /// Text associated with `entry`; hooks may rewrite it in place.
    /// NOTE(review): presumably the evaluated entry content — confirm.
    pub content: &'a mut String,
}
/// Context handed to [`LifecyclePlugin::post_assemble`].
pub struct PostAssembleCtx<'a> {
    /// Assembled blocks; hooks may reorder, edit or remove them in place.
    pub blocks: &'a mut Vec<AssembledBlock>,
    /// Read-only lorebook supplied by the caller.
    pub lorebook: &'a Lorebook,
}
/// Context handed to [`LifecyclePlugin::on_turn_advance`].
pub struct TurnAdvanceCtx<'a> {
    /// Activation state; hooks may modify it.
    pub state: &'a mut ActivationState,
}
/// Context handed to [`LifecyclePlugin::on_trigger_fired`].
pub struct TriggerCtx<'a> {
    /// IDs of what was triggered; hooks may edit the list.
    /// NOTE(review): presumably lorebook entry IDs — confirm.
    pub triggered_ids: &'a mut Vec<String>,
    /// Pass counter supplied by the caller
    /// (presumably which trigger/recursion pass this is — confirm).
    pub pass_number: usize,
}
/// A plugin that can observe or mutate state at named points of the
/// activation and assembly pipeline.
///
/// Every hook has a default implementation that does nothing and returns
/// `Ok(())`, so implementors override only the hooks they need. Hooks take
/// `&mut self`, so a plugin may keep state between calls. Implementors must
/// be `Send + Sync`.
///
/// # Errors
///
/// Each hook may return a [`HookError`]; how a failure is handled (abort,
/// log, continue) is up to the code that invokes the hooks.
pub trait LifecyclePlugin: Send + Sync {
    /// Name identifying this plugin.
    fn name(&self) -> &str;
    /// Pre-activation hook; receives mutable chat messages and the turn number.
    fn pre_activation(&mut self, _ctx: &mut PreActivationCtx<'_>) -> Result<(), HookError> {
        Ok(())
    }
    /// Post-activation hook; receives mutable activation results and the lorebook.
    fn post_activation(&mut self, _ctx: &mut PostActivationCtx<'_>) -> Result<(), HookError> {
        Ok(())
    }
    /// Pre-evaluate hook; receives an entry and a `skip` flag it may set.
    fn pre_evaluate(&mut self, _ctx: &mut PreEvaluateCtx<'_>) -> Result<(), HookError> {
        Ok(())
    }
    /// Post-evaluate hook; receives an entry and its mutable content string.
    fn post_evaluate(&mut self, _ctx: &mut PostEvaluateCtx<'_>) -> Result<(), HookError> {
        Ok(())
    }
    /// Post-assemble hook; receives the mutable list of assembled blocks.
    fn post_assemble(&mut self, _ctx: &mut PostAssembleCtx<'_>) -> Result<(), HookError> {
        Ok(())
    }
    /// Turn-advance hook; receives the mutable activation state.
    fn on_turn_advance(&mut self, _ctx: &mut TurnAdvanceCtx<'_>) -> Result<(), HookError> {
        Ok(())
    }
    /// Trigger hook; receives the mutable list of triggered IDs and the pass number.
    fn on_trigger_fired(&mut self, _ctx: &mut TriggerCtx<'_>) -> Result<(), HookError> {
        Ok(())
    }
}
// Boxed, type-erased closures stored by `FnLifecycle`, one alias per hook.
// Each alias is spelled out separately because the context types carry a
// higher-ranked lifetime (`'_`) that a single generic alias cannot express.
// `Send + Sync` is required so `FnLifecycle` satisfies the `Send + Sync`
// supertraits of `LifecyclePlugin`.
type PreActivationFn =
    Box<dyn FnMut(&mut PreActivationCtx<'_>) -> Result<(), HookError> + Send + Sync>;
type PostActivationFn =
    Box<dyn FnMut(&mut PostActivationCtx<'_>) -> Result<(), HookError> + Send + Sync>;
type PreEvaluateFn = Box<dyn FnMut(&mut PreEvaluateCtx<'_>) -> Result<(), HookError> + Send + Sync>;
type PostEvaluateFn =
    Box<dyn FnMut(&mut PostEvaluateCtx<'_>) -> Result<(), HookError> + Send + Sync>;
type PostAssembleFn =
    Box<dyn FnMut(&mut PostAssembleCtx<'_>) -> Result<(), HookError> + Send + Sync>;
type TurnAdvanceFn = Box<dyn FnMut(&mut TurnAdvanceCtx<'_>) -> Result<(), HookError> + Send + Sync>;
type TriggerFiredFn = Box<dyn FnMut(&mut TriggerCtx<'_>) -> Result<(), HookError> + Send + Sync>;
/// A [`LifecyclePlugin`] assembled from closures instead of a dedicated type.
///
/// Create one with [`FnLifecycle::new`] and chain the `on_*` setters. Any
/// hook left unset behaves as a no-op that returns `Ok(())`.
pub struct FnLifecycle {
    /// Name returned by [`LifecyclePlugin::name`].
    name: String,
    // One optional closure per hook; `None` means "not registered".
    pre_activation_fn: Option<PreActivationFn>,
    post_activation_fn: Option<PostActivationFn>,
    pre_evaluate_fn: Option<PreEvaluateFn>,
    post_evaluate_fn: Option<PostEvaluateFn>,
    post_assemble_fn: Option<PostAssembleFn>,
    turn_advance_fn: Option<TurnAdvanceFn>,
    trigger_fired_fn: Option<TriggerFiredFn>,
}
/// Generates a builder-style setter that stores a hook closure in `self.$field`.
///
/// * `$method` – name of the generated public setter.
/// * `$field` – the `Option<Box<dyn FnMut ...>>` field to fill.
/// * `$ctx` – the hook's context type (used with an elided lifetime).
///
/// Calling a setter again replaces the closure registered earlier for that
/// hook. The setters consume and return `self`, so they are `#[must_use]`.
/// Otherwise a statement like `plugin.on_x(f);` would silently drop the
/// configured plugin.
macro_rules! fn_lifecycle_setter {
    ($method:ident, $field:ident, $ctx:ident) => {
        #[doc = concat!(
            "Registers the closure invoked with a [`",
            stringify!($ctx),
            "`], replacing any closure previously set for this hook."
        )]
        #[must_use = "setters consume the builder and return it; dropping the result discards the plugin"]
        pub fn $method<F>(mut self, f: F) -> Self
        where
            F: FnMut(&mut $ctx<'_>) -> Result<(), HookError> + Send + Sync + 'static,
        {
            self.$field = Some(Box::new(f));
            self
        }
    };
}
impl FnLifecycle {
    /// Creates a plugin named `name` with no hooks registered.
    ///
    /// Until a setter is called, every hook is a no-op returning `Ok(())`.
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Self {
            name,
            pre_activation_fn: None,
            post_activation_fn: None,
            pre_evaluate_fn: None,
            post_evaluate_fn: None,
            post_assemble_fn: None,
            turn_advance_fn: None,
            trigger_fired_fn: None,
        }
    }

    // Builder-style setters, one per hook. Each consumes `self`, stores the
    // closure (replacing any earlier one) and returns the builder.
    fn_lifecycle_setter!(on_pre_activation, pre_activation_fn, PreActivationCtx);
    fn_lifecycle_setter!(on_post_activation, post_activation_fn, PostActivationCtx);
    fn_lifecycle_setter!(on_pre_evaluate, pre_evaluate_fn, PreEvaluateCtx);
    fn_lifecycle_setter!(on_post_evaluate, post_evaluate_fn, PostEvaluateCtx);
    fn_lifecycle_setter!(on_post_assemble, post_assemble_fn, PostAssembleCtx);
    fn_lifecycle_setter!(on_turn_advance, turn_advance_fn, TurnAdvanceCtx);
    fn_lifecycle_setter!(on_trigger_fired, trigger_fired_fn, TriggerCtx);
}
impl LifecyclePlugin for FnLifecycle {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    // Every hook below follows one rule: if a closure is registered, run it
    // and forward its result; otherwise succeed without doing anything.

    fn pre_activation(&mut self, ctx: &mut PreActivationCtx<'_>) -> Result<(), HookError> {
        self.pre_activation_fn.as_mut().map_or(Ok(()), |hook| hook(ctx))
    }

    fn post_activation(&mut self, ctx: &mut PostActivationCtx<'_>) -> Result<(), HookError> {
        self.post_activation_fn.as_mut().map_or(Ok(()), |hook| hook(ctx))
    }

    fn pre_evaluate(&mut self, ctx: &mut PreEvaluateCtx<'_>) -> Result<(), HookError> {
        self.pre_evaluate_fn.as_mut().map_or(Ok(()), |hook| hook(ctx))
    }

    fn post_evaluate(&mut self, ctx: &mut PostEvaluateCtx<'_>) -> Result<(), HookError> {
        self.post_evaluate_fn.as_mut().map_or(Ok(()), |hook| hook(ctx))
    }

    fn post_assemble(&mut self, ctx: &mut PostAssembleCtx<'_>) -> Result<(), HookError> {
        self.post_assemble_fn.as_mut().map_or(Ok(()), |hook| hook(ctx))
    }

    fn on_turn_advance(&mut self, ctx: &mut TurnAdvanceCtx<'_>) -> Result<(), HookError> {
        self.turn_advance_fn.as_mut().map_or(Ok(()), |hook| hook(ctx))
    }

    fn on_trigger_fired(&mut self, ctx: &mut TriggerCtx<'_>) -> Result<(), HookError> {
        self.trigger_fired_fn.as_mut().map_or(Ok(()), |hook| hook(ctx))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// A registered hook runs once per invocation.
    #[test]
    fn fn_lifecycle_invokes_set_hook() {
        let count = Arc::new(AtomicUsize::new(0));
        let count_clone = count.clone();
        let mut plugin = FnLifecycle::new("counter").on_pre_activation(move |_ctx| {
            count_clone.fetch_add(1, Ordering::SeqCst);
            Ok(())
        });
        let mut messages: Vec<ChatMessage> = vec![];
        let mut ctx = PreActivationCtx {
            messages: &mut messages,
            turn: 0,
        };
        plugin.pre_activation(&mut ctx).unwrap();
        plugin.pre_activation(&mut ctx).unwrap();
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    /// Hooks that were never registered succeed without side effects.
    #[test]
    fn fn_lifecycle_unset_hooks_are_noop() {
        let mut plugin = FnLifecycle::new("empty");
        let mut messages: Vec<ChatMessage> = vec![];
        let mut ctx = PreActivationCtx {
            messages: &mut messages,
            turn: 0,
        };
        assert!(plugin.pre_activation(&mut ctx).is_ok());
        let mut blocks: Vec<AssembledBlock> = vec![];
        let lorebook = Lorebook::new();
        let mut ctx = PostAssembleCtx {
            blocks: &mut blocks,
            lorebook: &lorebook,
        };
        assert!(plugin.post_assemble(&mut ctx).is_ok());
    }

    /// `Display` appends the cause after the message, separated by ": ".
    #[test]
    fn hook_error_display_with_source() {
        let inner = std::io::Error::new(std::io::ErrorKind::Other, "inner cause");
        let err = HookError::with_source("hook failed", inner);
        assert_eq!(err.to_string(), "hook failed: inner cause");
    }

    /// Without a cause, `Display` is just the message.
    #[test]
    fn hook_error_display_without_source() {
        let err = HookError::new("simple failure");
        assert_eq!(err.to_string(), "simple failure");
    }

    /// The cause is reachable (and downcastable) through `Error::source`.
    #[test]
    fn hook_error_exposes_source_through_error_trait() {
        let inner = std::io::Error::new(std::io::ErrorKind::Other, "inner cause");
        let err = HookError::with_source("hook failed", inner);
        let source =
            std::error::Error::source(&err).expect("with_source should expose a source");
        assert_eq!(source.to_string(), "inner cause");
        assert!(source.downcast_ref::<std::io::Error>().is_some());
        assert!(std::error::Error::source(&HookError::new("plain")).is_none());
    }

    /// Pre-activation hooks can edit the message list they are given.
    #[test]
    fn fn_lifecycle_pre_activation_can_mutate_messages() {
        let mut plugin = FnLifecycle::new("injector").on_pre_activation(|ctx| {
            ctx.messages.push(ChatMessage::system("[injected]"));
            Ok(())
        });
        let mut messages = vec![ChatMessage::user("hello")];
        let mut ctx = PreActivationCtx {
            messages: &mut messages,
            turn: 0,
        };
        plugin.pre_activation(&mut ctx).unwrap();
        assert_eq!(messages.len(), 2);
    }

    /// An error returned by a hook closure reaches the caller unchanged.
    #[test]
    fn fn_lifecycle_error_propagates() {
        let mut plugin =
            FnLifecycle::new("failer").on_pre_activation(|_ctx| Err(HookError::new("nope")));
        let mut messages: Vec<ChatMessage> = vec![];
        let mut ctx = PreActivationCtx {
            messages: &mut messages,
            turn: 0,
        };
        let err = plugin.pre_activation(&mut ctx).unwrap_err();
        assert_eq!(err.message, "nope");
    }
}