use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use crate::lorebook::LorebookConfig;
use crate::EvaluatedEntry;
/// Named insertion point in the assembled context that a lorebook entry targets.
///
/// Variants are declared in the same order as `order_index`, which the manual
/// `Ord` impl uses when `ContextAssembler::assemble` sorts its output blocks.
/// With serde's default (externally tagged) enum representation, unit variants
/// serialize as snake_case strings such as `"preamble"`, and `AtDepth(n)` as
/// `{"at_depth": n}`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum Slot {
Preamble,
Foundation,
/// Value returned by `Slot::default()`.
#[default]
Context,
Reference,
Framing,
Guidance,
Emphasis,
Immediate,
Aftermath,
/// Placement at a numeric depth. `resolve_slot` always accepts it, even when it
/// is not in the available set; it sorts after every named slot, then by depth.
// NOTE(review): presumably the depth is counted back from the most recent
// message — confirm against whatever consumes `AssembledBlock`.
// NOTE(review): this `rename` repeats what `rename_all = "snake_case"` already
// yields for `AtDepth`; redundant but harmless.
#[serde(rename = "at_depth")]
AtDepth(usize),
}
impl Slot {
    /// Rank of this slot's variant in assembly order: `0` is emitted first and
    /// `9` last. Every `AtDepth` shares rank `9`; `Ord` breaks that tie by depth.
    fn order_index(&self) -> usize {
        match self {
            Self::AtDepth(_) => 9,
            Self::Aftermath => 8,
            Self::Immediate => 7,
            Self::Emphasis => 6,
            Self::Guidance => 5,
            Self::Framing => 4,
            Self::Reference => 3,
            Self::Context => 2,
            Self::Foundation => 1,
            Self::Preamble => 0,
        }
    }

    /// The nine named slots in assembly order. `AtDepth` is not included.
    pub fn standard_slots() -> Vec<Slot> {
        Vec::from([
            Self::Preamble,
            Self::Foundation,
            Self::Context,
            Self::Reference,
            Self::Framing,
            Self::Guidance,
            Self::Emphasis,
            Self::Immediate,
            Self::Aftermath,
        ])
    }
}
impl Ord for Slot {
    /// Orders slots by assembly rank; two `AtDepth` slots are then ordered by depth.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.order_index()
            .cmp(&other.order_index())
            .then_with(|| match (self, other) {
                (Self::AtDepth(lhs), Self::AtDepth(rhs)) => lhs.cmp(rhs),
                // Equal rank without both being `AtDepth` means the same unit variant.
                _ => std::cmp::Ordering::Equal,
            })
    }
}
impl PartialOrd for Slot {
    /// Total order; always `Some`, delegating to `Ord::cmp`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl std::fmt::Display for Slot {
    /// Writes the slot's snake_case name, e.g. `preamble` or `at_depth(3)`.
    /// Width, fill and alignment flags from the format string are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Preamble => "preamble",
            Self::Foundation => "foundation",
            Self::Context => "context",
            Self::Reference => "reference",
            Self::Framing => "framing",
            Self::Guidance => "guidance",
            Self::Emphasis => "emphasis",
            Self::Immediate => "immediate",
            Self::Aftermath => "aftermath",
            // The only variant with a payload is formatted directly.
            Self::AtDepth(n) => return write!(f, "at_depth({n})"),
        };
        f.write_str(label)
    }
}
/// Estimates how many model tokens a piece of text will occupy.
///
/// The `Send + Sync` bound lets one tokenizer instance be shared across threads.
pub trait Tokenizer: Send + Sync {
    /// Returns the estimated token count for `text`.
    fn estimate_tokens(&self, text: &str) -> usize;
}

/// Dependency-free [`Tokenizer`] that assumes roughly four UTF-8 bytes per token.
pub struct GuesstimationTokenizer;

impl Tokenizer for GuesstimationTokenizer {
    /// Returns `ceil(byte_len / 4)`. Empty text is 0 tokens and any non-empty
    /// text is at least 1. The count is by bytes, not by `char`s.
    fn estimate_tokens(&self, text: &str) -> usize {
        const BYTES_PER_TOKEN: usize = 4;
        let byte_len = text.len();
        byte_len.div_ceil(BYTES_PER_TOKEN)
    }
}
/// Running token accounting against an optional overall cap and optional
/// per-group caps.
#[derive(Debug, Clone)]
pub struct TokenBudget {
    /// Overall cap; `None` means there is no overall limit.
    pub total: Option<usize>,
    /// Per-group caps keyed by group name. A group missing here is bound only by `total`.
    pub groups: HashMap<String, usize>,
    /// Tokens consumed so far across all groups.
    consumed_total: usize,
    /// Tokens consumed so far per group name. This is recorded for every named
    /// group, whether or not the group has a cap.
    consumed_groups: HashMap<String, usize>,
}

impl TokenBudget {
    /// Creates an empty budget from `config.token_budget` and a copy of
    /// `config.group_budgets`.
    pub fn from_config(config: &LorebookConfig) -> Self {
        Self {
            total: config.token_budget,
            groups: config.group_budgets.clone(),
            consumed_total: 0,
            consumed_groups: HashMap::new(),
        }
    }

    /// Returns `true` if `tokens` more can be consumed without exceeding either
    /// of two caps: the overall cap, or the cap of `group` when it names a
    /// capped group.
    ///
    /// Caps are inclusive, so reaching a cap exactly still fits. The budget is
    /// not modified.
    #[must_use]
    pub fn can_fit(&self, tokens: usize, group: Option<&str>) -> bool {
        // `checked_add` treats an overflowing sum as "does not fit". A plain `+`
        // would panic in debug builds, or wrap to a small value in release
        // builds and falsely report a fit.
        let fits = |consumed: usize, limit: usize| {
            consumed.checked_add(tokens).is_some_and(|sum| sum <= limit)
        };
        if let Some(total) = self.total {
            if !fits(self.consumed_total, total) {
                return false;
            }
        }
        if let Some(group_name) = group {
            if let Some(&limit) = self.groups.get(group_name) {
                let consumed = self.consumed_groups.get(group_name).copied().unwrap_or(0);
                if !fits(consumed, limit) {
                    return false;
                }
            }
        }
        true
    }

    /// Records `tokens` as consumed overall and, if `group` is `Some`, for that group.
    ///
    /// Limits are not checked here; call [`TokenBudget::can_fit`] first.
    /// Counters saturate at `usize::MAX` instead of overflowing.
    pub fn consume(&mut self, tokens: usize, group: Option<&str>) {
        self.consumed_total = self.consumed_total.saturating_add(tokens);
        if let Some(group_name) = group {
            let used = self
                .consumed_groups
                .entry(group_name.to_string())
                .or_default();
            *used = used.saturating_add(tokens);
        }
    }
}
/// One evaluated entry that passed filtering, slot resolution and budgeting in
/// `ContextAssembler::assemble`, placed in its resolved slot.
#[derive(Debug, Clone)]
pub struct AssembledBlock {
/// `id` of the source `EvaluatedEntry`.
pub entry_id: String,
/// Slot chosen by `resolve_slot`: the entry's primary slot or one of its fallbacks.
pub slot: Slot,
/// The entry's content, unchanged (not trimmed).
pub content: String,
/// Copied from the entry's `meta.priority`; higher values are considered first for the budget.
pub priority: i32,
/// Copied from the entry's `meta.insertion_order`; lower values win ties.
pub insertion_order: i32,
/// Copied from the entry's `meta.group`; used for per-group budget caps.
pub group: Option<String>,
/// The tokenizer's estimate for `content`, charged against the budget.
pub estimated_tokens: usize,
}
/// Stateless namespace for [`ContextAssembler::assemble`].
pub struct ContextAssembler;
impl ContextAssembler {
    /// Builds the ordered list of context blocks for one prompt.
    ///
    /// Pipeline:
    /// 1. Entries whose content is empty or whitespace-only are dropped.
    /// 2. Each remaining entry's slot is resolved by `resolve_slot` against
    ///    `available_slots`. Entries with no usable slot are dropped.
    /// 3. Survivors are ranked by descending `priority`, then ascending
    ///    `insertion_order`. They are admitted greedily while they fit the
    ///    [`TokenBudget`] built from `config`. An entry that does not fit is
    ///    skipped, but later, lower-ranked entries may still be admitted.
    /// 4. Admitted blocks are returned sorted by slot, then descending
    ///    `priority`, then ascending `insertion_order`.
    pub fn assemble(
        entries: Vec<EvaluatedEntry>,
        config: &LorebookConfig,
        tokenizer: &dyn Tokenizer,
        available_slots: &HashSet<Slot>,
    ) -> Vec<AssembledBlock> {
        let mut ranked: Vec<AssembledBlock> = Vec::new();
        for entry in entries {
            if entry.content.trim().is_empty() {
                continue;
            }
            let Some(slot) =
                resolve_slot(&entry.meta.slot, &entry.meta.fallback, available_slots)
            else {
                continue;
            };
            let estimated_tokens = tokenizer.estimate_tokens(&entry.content);
            ranked.push(AssembledBlock {
                entry_id: entry.id,
                slot,
                content: entry.content,
                priority: entry.meta.priority,
                insertion_order: entry.meta.insertion_order,
                group: entry.meta.group,
                estimated_tokens,
            });
        }

        // `sort_by_key` is stable, so entries tied on both keys keep input order.
        ranked.sort_by_key(|block| (std::cmp::Reverse(block.priority), block.insertion_order));

        let mut budget = TokenBudget::from_config(config);
        let mut admitted: Vec<AssembledBlock> = ranked
            .into_iter()
            .filter(|block| {
                let group = block.group.as_deref();
                let fits = budget.can_fit(block.estimated_tokens, group);
                if fits {
                    budget.consume(block.estimated_tokens, group);
                }
                fits
            })
            .collect();

        admitted.sort_by_key(|block| {
            (
                block.slot.clone(),
                std::cmp::Reverse(block.priority),
                block.insertion_order,
            )
        });
        admitted
    }
}
/// Picks the slot an entry lands in.
///
/// Returns `primary` if it is usable, otherwise the first usable element of
/// `fallback` (in order), otherwise `None`. A slot is usable when it is
/// `Slot::AtDepth` (always accepted, regardless of `available`) or when it is a
/// member of `available`.
fn resolve_slot(primary: &Slot, fallback: &[Slot], available: &HashSet<Slot>) -> Option<Slot> {
    std::iter::once(primary)
        .chain(fallback)
        .find(|candidate| matches!(candidate, Slot::AtDepth(_)) || available.contains(*candidate))
        .cloned()
}
// Unit tests for budget tracking, empty-entry filtering, slot resolution,
// slot ordering and the byte-based fallback tokenizer.
#[cfg(test)]
mod tests {
use super::*;
use crate::entry::EntryMeta;
/// Every named slot (no `AtDepth`) as an availability set.
fn all_slots() -> HashSet<Slot> {
Slot::standard_slots().into_iter().collect()
}
// Total cap 100 and a "combat" group cap 50. After 30 ungrouped tokens and 50
// "combat" tokens, the group is full and only 20 of the total remain.
#[test]
fn test_budget_tracking() {
let mut budget = TokenBudget {
total: Some(100),
groups: HashMap::from([("combat".into(), 50)]),
consumed_total: 0,
consumed_groups: HashMap::new(),
};
assert!(budget.can_fit(30, None));
budget.consume(30, None);
// Caps are inclusive: exactly reaching 100 still fits.
assert!(budget.can_fit(70, None));
assert!(!budget.can_fit(71, None));
assert!(budget.can_fit(50, Some("combat")));
budget.consume(50, Some("combat"));
assert!(!budget.can_fit(1, Some("combat")));
assert!(!budget.can_fit(21, None));
}
// Whitespace-only content is dropped before slot resolution and budgeting.
#[test]
fn test_empty_entries_filtered() {
let config = LorebookConfig::default();
let tokenizer = GuesstimationTokenizer;
let entries = vec![
EvaluatedEntry {
id: "empty".into(),
meta: make_meta("empty", 100),
content: " \n ".into(),
},
EvaluatedEntry {
id: "has_content".into(),
meta: make_meta("has_content", 100),
content: "Hello!".into(),
},
];
let blocks = ContextAssembler::assemble(entries, &config, &tokenizer, &all_slots());
assert_eq!(blocks.len(), 1);
assert_eq!(blocks[0].entry_id, "has_content");
}
// An available primary slot is used as-is.
#[test]
fn test_slot_resolution_primary() {
let available = HashSet::from([Slot::Foundation, Slot::Context]);
assert_eq!(
resolve_slot(&Slot::Foundation, &[], &available),
Some(Slot::Foundation)
);
}
// The first available fallback wins, and fallbacks are tried in order.
#[test]
fn test_slot_resolution_fallback() {
let available = HashSet::from([Slot::Context]);
assert_eq!(
resolve_slot(
&Slot::Foundation,
&[Slot::Context, Slot::Preamble],
&available,
),
Some(Slot::Context)
);
}
// Neither the primary slot nor any fallback is available, so the entry is unplaceable.
#[test]
fn test_slot_resolution_none_available() {
let available = HashSet::from([Slot::Preamble]);
assert_eq!(
resolve_slot(&Slot::Foundation, &[Slot::Context], &available),
None
);
}
// `AtDepth` bypasses the availability check entirely.
#[test]
fn test_at_depth_always_resolves() {
let available = HashSet::new(); assert_eq!(
resolve_slot(&Slot::AtDepth(3), &[], &available),
Some(Slot::AtDepth(3))
);
}
// Named slots sort by declaration/assembly order, not by input order.
#[test]
fn test_slot_ordering() {
let mut slots = vec![
Slot::Aftermath,
Slot::Preamble,
Slot::Emphasis,
Slot::Foundation,
];
slots.sort();
assert_eq!(
slots,
vec![
Slot::Preamble,
Slot::Foundation,
Slot::Emphasis,
Slot::Aftermath
]
);
}
// ceil(bytes / 4): 0 bytes -> 0, 4 bytes -> 1, 13 bytes -> 4.
#[test]
fn test_guesstimation_tokenizer() {
let tok = GuesstimationTokenizer;
assert_eq!(tok.estimate_tokens(""), 0);
assert_eq!(tok.estimate_tokens("abcd"), 1);
assert_eq!(tok.estimate_tokens("Hello, world!"), 4); }
/// Minimal enabled `EntryMeta` with the given id/name and priority. It uses
/// `Slot::default()` (Context), no fallbacks, insertion_order 50 and no group.
fn make_meta(id: &str, priority: i32) -> EntryMeta {
EntryMeta {
id: id.to_string(),
name: id.to_string(),
keywords: vec![],
regex: vec![],
condition: None,
scan_depth: None,
constant: false,
priority,
slot: Slot::default(),
fallback: vec![],
insertion_order: 50,
enabled: true,
sticky_turns: 0,
cooldown: 0,
group: None,
tags: vec![],
extensions: Default::default(),
}
}
}