use async_trait::async_trait;
use std::sync::Arc;
use super::run_tree::RunTree;
use crate::schema::Message;
/// Receiver for lifecycle events of runs described by a [`RunTree`].
///
/// Implementors must provide the three generic hooks: `on_run_start`, `on_run_end` and
/// `on_run_error`. Every component-specific hook (LLM, chain, tool, retriever) has a
/// default implementation that ignores its extra payload argument and forwards to the
/// matching generic hook. Overriding only the generic hooks therefore already observes
/// every start/end/error. The one exception is `on_llm_new_token`, whose default does
/// nothing.
///
/// The `Send + Sync` bound lets handlers be shared across threads, e.g. stored as
/// `Arc<dyn CallbackHandler>`. `#[async_trait]` allows the hooks to be `async fn` while
/// the trait stays usable as a trait object.
#[async_trait]
pub trait CallbackHandler: Send + Sync {
/// Hook for the start of any run. The default `on_*_start` hooks forward here.
async fn on_run_start(&self, run: &RunTree);
/// Hook for the end of any run. The default `on_*_end` hooks forward here.
async fn on_run_end(&self, run: &RunTree);
/// Hook for a run that failed; `error` is a textual description of the failure.
/// The default `on_*_error` hooks forward here.
async fn on_run_error(&self, run: &RunTree, error: &str);
/// Hook for an LLM run starting with the given messages.
/// Default: ignores `_messages` and calls `on_run_start`.
async fn on_llm_start(&self, run: &RunTree, _messages: &[Message]) {
self.on_run_start(run).await;
}
/// Hook for an LLM run finishing with the given response text.
/// Default: ignores `_response` and calls `on_run_end`.
async fn on_llm_end(&self, run: &RunTree, _response: &str) {
self.on_run_end(run).await;
}
/// Hook for a new token reported for an LLM run.
/// Default: does nothing. Unlike the other hooks, it is not forwarded to a generic hook.
async fn on_llm_new_token(&self, _run: &RunTree, _token: &str) {
}
/// Hook for an LLM run failing. Default: calls `on_run_error` with the same arguments.
async fn on_llm_error(&self, run: &RunTree, error: &str) {
self.on_run_error(run, error).await;
}
/// Hook for a chain run starting with the given JSON inputs.
/// Default: ignores `_inputs` and calls `on_run_start`.
async fn on_chain_start(&self, run: &RunTree, _inputs: &serde_json::Value) {
self.on_run_start(run).await;
}
/// Hook for a chain run finishing with the given JSON outputs.
/// Default: ignores `_outputs` and calls `on_run_end`.
async fn on_chain_end(&self, run: &RunTree, _outputs: &serde_json::Value) {
self.on_run_end(run).await;
}
/// Hook for a chain run failing. Default: calls `on_run_error` with the same arguments.
async fn on_chain_error(&self, run: &RunTree, error: &str) {
self.on_run_error(run, error).await;
}
/// Hook for a tool run starting; receives the tool name and its input.
/// Default: ignores `_tool_name`/`_input` and calls `on_run_start`.
async fn on_tool_start(&self, run: &RunTree, _tool_name: &str, _input: &str) {
self.on_run_start(run).await;
}
/// Hook for a tool run finishing with the given output.
/// Default: ignores `_output` and calls `on_run_end`.
async fn on_tool_end(&self, run: &RunTree, _output: &str) {
self.on_run_end(run).await;
}
/// Hook for a tool run failing. Default: calls `on_run_error` with the same arguments.
async fn on_tool_error(&self, run: &RunTree, error: &str) {
self.on_run_error(run, error).await;
}
/// Hook for a retriever run starting with the given query.
/// Default: ignores `_query` and calls `on_run_start`.
async fn on_retriever_start(&self, run: &RunTree, _query: &str) {
self.on_run_start(run).await;
}
/// Hook for a retriever run finishing with the given documents (as JSON values).
/// Default: ignores `_documents` and calls `on_run_end`.
async fn on_retriever_end(&self, run: &RunTree, _documents: &[serde_json::Value]) {
self.on_run_end(run).await;
}
/// Hook for a retriever run failing. Default: calls `on_run_error` with the same arguments.
async fn on_retriever_error(&self, run: &RunTree, error: &str) {
self.on_run_error(run, error).await;
}
}
/// Ordered list of [`CallbackHandler`]s, built up with [`CallbackManager::add_handler`].
///
/// Handlers are stored as `Arc<dyn CallbackHandler>`. Cloning a manager is therefore
/// cheap: the clone shares the same handler instances (reference-count bumps, no deep
/// copy). `Default` yields an empty manager, identical to [`CallbackManager::new`].
#[derive(Clone, Default)]
pub struct CallbackManager {
    handlers: Vec<Arc<dyn CallbackHandler>>,
}

// `dyn CallbackHandler` does not implement `Debug`, so `Debug` cannot be derived.
// Report only how many handlers are registered.
impl std::fmt::Debug for CallbackManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CallbackManager")
            .field("handlers_count", &self.handlers.len())
            .finish()
    }
}

impl CallbackManager {
    /// Creates a manager with no handlers registered.
    pub fn new() -> Self {
        Self {
            handlers: Vec::new(),
        }
    }

    /// Appends `handler` and returns the updated manager (builder style).
    ///
    /// Handlers are kept in insertion order. This method takes `self` by value, so
    /// ignoring the return value drops the whole manager. Hence `#[must_use]`.
    #[must_use = "`add_handler` consumes the manager and returns the updated one; dropping the result discards it"]
    pub fn add_handler(mut self, handler: Arc<dyn CallbackHandler>) -> Self {
        self.handlers.push(handler);
        self
    }

    /// Returns the registered handlers in insertion order.
    pub fn handlers(&self) -> &[Arc<dyn CallbackHandler>] {
        &self.handlers
    }

    /// Returns the number of registered handlers.
    pub fn len(&self) -> usize {
        self.handlers.len()
    }

    /// Returns `true` if no handlers are registered.
    pub fn is_empty(&self) -> bool {
        self.handlers.is_empty()
    }
}