#![allow(dead_code)]
use std::future::Future;
use std::sync::{Arc, Mutex};
use crate::types::Message;
pub use crate::utils::hooks::api_query_hook_helper::{ReplHookContext, SystemPrompt};
/// A hook run by [`execute_post_sampling_hooks`] (presumably after the model has
/// produced a sample — confirm against the query loop).
///
/// Each hook is handed an owned [`ReplHookContext`] and returns a boxed `Send`
/// future; the hooks are awaited one after another in registration order. The
/// `'static` bounds are the defaults Rust would infer anyway and are spelled out
/// only to make the ownership requirement on captured state explicit.
pub type PostSamplingHook = Arc<
    dyn Fn(ReplHookContext) -> std::pin::Pin<Box<dyn Future<Output = ()> + Send + 'static>>
        + Send
        + Sync
        + 'static,
>;
lazy_static::lazy_static! {
static ref POST_SAMPLING_HOOKS: Arc<Mutex<Vec<PostSamplingHook>>> = Arc::new(Mutex::new(Vec::new()));
}
/// Appends `hook` to the process-wide registry; hooks later run in the order
/// they were registered.
///
/// A poisoned lock is recovered instead of unwrapped. The registry is just a
/// `Vec` of `Arc`s, and `Vec` stays valid even if a panic interrupted another
/// holder of the lock. Without recovery, that single earlier panic would make
/// every future registration panic too.
pub fn register_post_sampling_hook(hook: PostSamplingHook) {
    POST_SAMPLING_HOOKS
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .push(hook);
}
pub fn clear_post_sampling_hooks() {
let mut hooks = POST_SAMPLING_HOOKS.lock().unwrap();
hooks.clear();
}
/// Runs every registered post-sampling hook against a context built from the
/// arguments (`query_message_count` is always `None`).
///
/// The registry is snapshotted first and the lock is released before any hook
/// is awaited. Hooks may therefore register or clear hooks without
/// deadlocking; such changes take effect on the next call.
///
/// Hooks are awaited one after another in registration order. Every hook
/// except the last receives a clone of the context. The last hook receives the
/// original, which saves one full clone (including the `messages` vector).
///
/// A panic inside a hook is not caught: it propagates to the caller and any
/// remaining hooks are skipped.
pub async fn execute_post_sampling_hooks(
    messages: Vec<Message>,
    system_prompt: SystemPrompt,
    user_context: std::collections::HashMap<String, String>,
    system_context: std::collections::HashMap<String, String>,
    tool_use_context: Arc<crate::utils::hooks::can_use_tool::ToolUseContext>,
    query_source: Option<String>,
) {
    let context = ReplHookContext {
        messages,
        system_prompt,
        user_context,
        system_context,
        tool_use_context,
        query_source,
        query_message_count: None,
    };

    // Copying the list only bumps `Arc` refcounts. The guard is dropped at the
    // end of this block: a `std::sync::MutexGuard` must never live across an
    // `.await`. A poisoned lock is recovered because the `Vec` is always valid.
    let hooks: Vec<PostSamplingHook> = {
        let registry = POST_SAMPLING_HOOKS
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        registry.clone()
    };

    if let Some((last, rest)) = hooks.split_last() {
        for hook in rest {
            hook(context.clone()).await;
        }
        last(context).await;
    }
}
/// Returns how many post-sampling hooks are currently registered.
///
/// A poisoned lock is recovered rather than unwrapped. The guarded `Vec` is
/// always valid, so an earlier panic elsewhere should not make this read-only
/// query panic.
pub fn get_post_sampling_hook_count() -> usize {
    POST_SAMPLING_HOOKS
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .len()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Return type of the no-op hooks registered by these tests.
    type TestHookFuture = std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>;

    /// Builds a minimal context: empty collections and a `"test"` session.
    fn sample_context() -> ReplHookContext {
        ReplHookContext {
            messages: Vec::new(),
            system_prompt: SystemPrompt::default(),
            user_context: std::collections::HashMap::new(),
            system_context: std::collections::HashMap::new(),
            tool_use_context: Arc::new(crate::utils::hooks::can_use_tool::ToolUseContext {
                session_id: "test".to_string(),
                cwd: None,
                is_non_interactive_session: false,
                options: None,
            }),
            query_source: None,
            query_message_count: None,
        }
    }

    #[test]
    fn test_register_and_clear_hooks() {
        // NOTE(review): the registry is process-global and the test harness runs
        // tests in parallel by default; these exact counts assume no other test
        // in the crate touches the registry concurrently.
        clear_post_sampling_hooks();
        assert_eq!(get_post_sampling_hook_count(), 0);

        let hook: PostSamplingHook =
            Arc::new(|_ctx: ReplHookContext| -> TestHookFuture { Box::pin(async {}) });
        register_post_sampling_hook(Arc::clone(&hook));
        register_post_sampling_hook(hook);
        assert_eq!(get_post_sampling_hook_count(), 2);

        clear_post_sampling_hooks();
        assert_eq!(get_post_sampling_hook_count(), 0);
    }

    #[test]
    fn test_repl_hook_context_clone() {
        let ctx = sample_context();
        let copy = ctx.clone();
        assert!(copy.messages.is_empty());
        assert!(copy.user_context.is_empty());
        assert!(copy.system_context.is_empty());
        assert_eq!(copy.tool_use_context.session_id, "test");
        assert!(copy.query_source.is_none());
        assert!(copy.query_message_count.is_none());
    }
}