use std::future::Future;
use std::pin::Pin;
use crate::types::budget::Budget;
use crate::types::error::CruxErr;
use crate::types::recovery::Recovery;
/// Boxed, type-erased future that resolves to the recovery decision produced by a hook.
type RecoveryFuture = Pin<Box<dyn Future<Output = Recovery<serde_json::Value>> + Send>>;

/// Type-erased low-confidence hook; called with the confidence score being checked.
type ConfidenceHandler = Box<dyn Fn(f32) -> RecoveryFuture + Send + Sync>;

/// Type-erased step-failure hook; called with the error being checked.
type FailureHandler = Box<dyn Fn(CruxErr) -> RecoveryFuture + Send + Sync>;

/// Type-erased budget hook; called with the budget being checked.
type BudgetHandler = Box<dyn Fn(Budget) -> RecoveryFuture + Send + Sync>;
/// Registry of optional async recovery hooks for low confidence, step failures, and
/// budget checks.
///
/// Each slot holds at most one handler; registering a hook again (via the `on_*`
/// methods) overwrites the previous one.
pub struct HookRegistry {
/// Score below which the confidence hook fires; set by `on_low_confidence` /
/// `on_low_confidence_boxed`. `None` disables the confidence check.
pub(crate) confidence_threshold: Option<f32>,
/// Hook run by `check_confidence` when the score is below the threshold.
confidence_handler: Option<ConfidenceHandler>,
/// Hook run by `check_failure`.
failure_handler: Option<FailureHandler>,
/// Hook run by `check_budget`.
budget_handler: Option<BudgetHandler>,
}
impl HookRegistry {
    /// Creates a registry with no hooks registered and no confidence threshold.
    pub fn new() -> Self {
        Self {
            confidence_threshold: None,
            confidence_handler: None,
            failure_handler: None,
            budget_handler: None,
        }
    }

    /// Registers `handler` to run when a confidence score falls below `threshold`.
    ///
    /// Replaces any previously registered confidence hook and its threshold. The handler
    /// is called with the score passed to [`check_confidence`](Self::check_confidence).
    pub fn on_low_confidence<F, Fut>(&mut self, threshold: f32, handler: F)
    where
        F: Fn(f32) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Recovery<serde_json::Value>> + Send + 'static,
    {
        self.confidence_threshold = Some(threshold);
        self.confidence_handler = Some(Box::new(move |score| Box::pin(handler(score))));
    }

    /// Registers `handler` to run from [`check_failure`](Self::check_failure).
    ///
    /// Replaces any previously registered failure hook. The handler receives the error
    /// passed to `check_failure`.
    pub fn on_step_failure<F, Fut>(&mut self, handler: F)
    where
        F: Fn(CruxErr) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Recovery<serde_json::Value>> + Send + 'static,
    {
        self.failure_handler = Some(Box::new(move |err| Box::pin(handler(err))));
    }

    /// Registers `handler` to run from [`check_budget`](Self::check_budget).
    ///
    /// Replaces any previously registered budget hook. The handler receives the budget
    /// passed to `check_budget`. This registry does not itself decide whether a budget
    /// was exceeded — presumably callers only invoke `check_budget` when it was (verify
    /// at call sites).
    pub fn on_budget_exceeded<F, Fut>(&mut self, handler: F)
    where
        F: Fn(Budget) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Recovery<serde_json::Value>> + Send + 'static,
    {
        self.budget_handler = Some(Box::new(move |budget| Box::pin(handler(budget))));
    }

    /// Runs the low-confidence hook if `confidence` is below the registered threshold.
    ///
    /// Returns `None` when no threshold or no hook is registered, or when the score is
    /// at or above the threshold (the comparison is strict). A NaN score is treated as
    /// low confidence so an invalid score cannot silently bypass the hook. If the
    /// threshold itself is NaN, only NaN scores trigger the hook.
    pub async fn check_confidence(&self, confidence: f32) -> Option<Recovery<serde_json::Value>> {
        let threshold = self.confidence_threshold?;
        let handler = self.confidence_handler.as_ref()?;
        // `confidence < threshold` is false for NaN, which would let an invalid score pass
        // as "confident"; fail closed instead.
        if confidence.is_nan() || confidence < threshold {
            Some(handler(confidence).await)
        } else {
            None
        }
    }

    /// Runs the step-failure hook with `err`, if one is registered.
    ///
    /// Returns `None` (dropping `err`) when no failure hook is registered.
    pub async fn check_failure(&self, err: CruxErr) -> Option<Recovery<serde_json::Value>> {
        let handler = self.failure_handler.as_ref()?;
        Some(handler(err).await)
    }

    /// Runs the budget hook with `budget`, if one is registered.
    ///
    /// Returns `None` (dropping `budget`) when no budget hook is registered.
    pub async fn check_budget(&self, budget: Budget) -> Option<Recovery<serde_json::Value>> {
        let handler = self.budget_handler.as_ref()?;
        Some(handler(budget).await)
    }

    /// Registers an already type-erased low-confidence hook together with its threshold,
    /// replacing any previous one.
    pub(crate) fn on_low_confidence_boxed(&mut self, threshold: f32, handler: ConfidenceHandler) {
        self.confidence_threshold = Some(threshold);
        self.confidence_handler = Some(handler);
    }

    /// Registers an already type-erased step-failure hook, replacing any previous one.
    pub(crate) fn on_step_failure_boxed(&mut self, handler: FailureHandler) {
        self.failure_handler = Some(handler);
    }

    /// Returns `true` if a low-confidence hook is registered.
    pub fn has_confidence_handler(&self) -> bool {
        self.confidence_handler.is_some()
    }

    /// Returns `true` if a step-failure hook is registered.
    pub fn has_failure_handler(&self) -> bool {
        self.failure_handler.is_some()
    }

    /// Returns `true` if a budget hook is registered.
    pub fn has_budget_handler(&self) -> bool {
        self.budget_handler.is_some()
    }
}
impl Default for HookRegistry {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for HookRegistry {
    /// Formats the threshold plus one presence flag per hook slot; the boxed handlers
    /// themselves are opaque closures and are not printed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let confidence_registered = self.confidence_handler.is_some();
        let failure_registered = self.failure_handler.is_some();
        let budget_registered = self.budget_handler.is_some();

        let mut out = formatter.debug_struct("HookRegistry");
        out.field("confidence_threshold", &self.confidence_threshold);
        out.field("has_confidence_handler", &confidence_registered);
        out.field("has_failure_handler", &failure_registered);
        out.field("has_budget_handler", &budget_registered);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn confidence_fires_below_threshold() {
        let mut hooks = HookRegistry::new();
        hooks.on_low_confidence(0.8, |_| async { Recovery::Continue });
        assert!(matches!(
            hooks.check_confidence(0.5).await,
            Some(Recovery::Continue)
        ));
    }

    #[tokio::test]
    async fn confidence_skips_above_threshold() {
        let mut hooks = HookRegistry::new();
        hooks.on_low_confidence(0.8, |_| async { Recovery::Continue });
        assert!(hooks.check_confidence(0.9).await.is_none());
    }

    #[tokio::test]
    async fn confidence_skips_at_threshold() {
        // The comparison is strict: a score equal to the threshold is not "low".
        let mut hooks = HookRegistry::new();
        hooks.on_low_confidence(0.8, |_| async { Recovery::Continue });
        assert!(hooks.check_confidence(0.8).await.is_none());
    }

    #[tokio::test]
    async fn confidence_returns_none_without_handler() {
        let hooks = HookRegistry::new();
        assert!(hooks.check_confidence(0.0).await.is_none());
    }

    #[tokio::test]
    async fn confidence_handler_receives_score() {
        let mut hooks = HookRegistry::new();
        hooks.on_low_confidence(0.8, |score| async move {
            assert!((score - 0.25).abs() < f32::EPSILON);
            Recovery::Continue
        });
        assert!(hooks.check_confidence(0.25).await.is_some());
    }

    #[tokio::test]
    async fn re_registering_confidence_replaces_threshold() {
        let mut hooks = HookRegistry::new();
        hooks.on_low_confidence(0.8, |_| async { Recovery::Continue });
        hooks.on_low_confidence(0.3, |_| async { Recovery::Continue });
        assert!(hooks.check_confidence(0.5).await.is_none());
    }

    #[tokio::test]
    async fn failure_returns_none_without_handler() {
        let hooks = HookRegistry::new();
        let err = CruxErr::step_failed("x", "y");
        assert!(hooks.check_failure(err).await.is_none());
    }

    #[tokio::test]
    async fn failure_invokes_handler() {
        let mut hooks = HookRegistry::new();
        hooks.on_step_failure(|_| async { Recovery::Propagate });
        let err = CruxErr::step_failed("x", "y");
        assert!(matches!(
            hooks.check_failure(err).await,
            Some(Recovery::Propagate)
        ));
    }

    #[tokio::test]
    async fn re_registering_failure_replaces_handler() {
        let mut hooks = HookRegistry::new();
        hooks.on_step_failure(|_| async { Recovery::Continue });
        hooks.on_step_failure(|_| async { Recovery::Propagate });
        let err = CruxErr::step_failed("x", "y");
        assert!(matches!(
            hooks.check_failure(err).await,
            Some(Recovery::Propagate)
        ));
    }

    #[test]
    fn has_failure_handler_tracks_registration() {
        let mut hooks = HookRegistry::new();
        assert!(!hooks.has_failure_handler());
        hooks.on_step_failure(|_| async { Recovery::Propagate });
        assert!(hooks.has_failure_handler());
    }

    #[tokio::test]
    async fn budget_returns_none_without_handler() {
        let hooks = HookRegistry::new();
        assert!(hooks.check_budget(Budget::tokens(10)).await.is_none());
    }

    #[tokio::test]
    async fn budget_invokes_handler() {
        let mut hooks = HookRegistry::new();
        hooks.on_budget_exceeded(|_| async { Recovery::Continue });
        assert!(matches!(
            hooks.check_budget(Budget::tokens(10)).await,
            Some(Recovery::Continue)
        ));
    }
}