use crate::{
completion::CompletionModel,
message::{Message, ToolChoice},
wasm_compat::{WasmCompatSend, WasmCompatSync},
};
/// Snapshot of everything surrounding a tool call that was judged invalid, handed to
/// [`PromptHook::on_invalid_tool_call`] so a hook can choose how to react.
///
/// NOTE(review): what makes a call "invalid" (unknown tool, tool outside `allowed_tools`,
/// malformed arguments, ...) is decided by the code that builds this value, which is not in
/// this file — confirm there.
#[derive(Clone, Debug)]
pub struct InvalidToolCallContext {
    /// Name of the tool the model asked to call.
    pub tool_name: String,
    /// Tool call id reported for this call, if any (presumably provider-assigned).
    pub tool_call_id: Option<String>,
    /// Internal identifier for this call, if one was assigned.
    pub internal_call_id: Option<String>,
    /// Raw argument string of the call, if available (format not established here).
    pub args: Option<String>,
    /// Names of the tools that were available for this request.
    pub available_tools: Vec<String>,
    /// Names of the tools the call was permitted to use
    /// (presumably narrowed by `tool_choice` — confirm against the caller).
    pub allowed_tools: Vec<String>,
    /// Tool-choice setting in effect for the request, if one was set.
    pub tool_choice: Option<ToolChoice>,
    /// Conversation history at the time the invalid call was encountered.
    pub chat_history: Vec<Message>,
    /// Whether the invalid call was encountered on a streaming request.
    pub is_streaming: bool,
}
/// Decision returned by [`PromptHook::on_invalid_tool_call`] for a tool call that was judged
/// invalid.
///
/// NOTE(review): how each variant is acted upon lives in the caller, not in this file; the
/// per-variant descriptions below are inferred from the names and should be confirmed there.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum InvalidToolCallHookAction {
    /// Give up on the invalid call. This is what the default `on_invalid_tool_call` returns.
    Fail,
    /// Retry, carrying `feedback` (presumably relayed back to the model).
    Retry { feedback: String },
    /// Repair the call by using `tool_name` instead of the requested tool (presumably).
    Repair { tool_name: String },
    /// Skip the call, recording `reason`.
    Skip { reason: String },
}

impl InvalidToolCallHookAction {
    /// Builds [`InvalidToolCallHookAction::Fail`].
    pub fn fail() -> Self {
        InvalidToolCallHookAction::Fail
    }

    /// Builds [`InvalidToolCallHookAction::Retry`] from anything convertible into a `String`.
    pub fn retry(feedback: impl Into<String>) -> Self {
        let feedback: String = feedback.into();
        InvalidToolCallHookAction::Retry { feedback }
    }

    /// Builds [`InvalidToolCallHookAction::Repair`] from anything convertible into a `String`.
    pub fn repair(tool_name: impl Into<String>) -> Self {
        let tool_name: String = tool_name.into();
        InvalidToolCallHookAction::Repair { tool_name }
    }

    /// Builds [`InvalidToolCallHookAction::Skip`] from anything convertible into a `String`.
    pub fn skip(reason: impl Into<String>) -> Self {
        let reason: String = reason.into();
        InvalidToolCallHookAction::Skip { reason }
    }
}
/// Callbacks invoked during a prompt run driven by a completion model `M`.
///
/// Every method ships with a default body that does no work and resolves immediately, so an
/// implementor overrides only the hooks it cares about. The defaults resolve to:
/// - [`HookAction::Continue`] for most hooks;
/// - [`ToolCallHookAction::Continue`] for `on_tool_call`;
/// - [`InvalidToolCallHookAction::Fail`] for `on_invalid_tool_call`.
///
/// How a returned action is acted upon (e.g. whether `Terminate` stops the run) is decided by
/// the caller, which is not part of this file.
///
/// NOTE(review): the non-streaming hooks bound their futures by `WasmCompatSend`, but
/// `on_text_delta`, `on_tool_call_delta` and `on_stream_completion_response_finish` require
/// plain `Send`. If `WasmCompatSend` is relaxed on wasm targets, those three hooks still demand
/// `Send` futures there. The bounds are left as they are because the streaming call sites (not
/// visible here) may rely on `Send`. Confirm before unifying.
pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
where
    M: CompletionModel,
{
    /// Receives the prompt message and the history that accompany a completion call.
    ///
    /// Default: resolves to [`HookAction::Continue`].
    fn on_completion_call(
        &self,
        _prompt: &Message,
        _history: &[Message],
    ) -> impl Future<Output = HookAction> + WasmCompatSend {
        async { HookAction::Continue }
    }

    /// Receives the prompt message and the model's completion response.
    ///
    /// Default: resolves to [`HookAction::Continue`].
    fn on_completion_response(
        &self,
        _prompt: &Message,
        _response: &crate::completion::CompletionResponse<M::Response>,
    ) -> impl Future<Output = HookAction> + WasmCompatSend {
        async { HookAction::Continue }
    }

    /// Receives the context of a tool call judged invalid and decides how to handle it.
    ///
    /// Default: resolves to [`InvalidToolCallHookAction::Fail`].
    fn on_invalid_tool_call(
        &self,
        _context: &InvalidToolCallContext,
    ) -> impl Future<Output = InvalidToolCallHookAction> + WasmCompatSend {
        async { InvalidToolCallHookAction::Fail }
    }

    /// Receives a tool call: its name, optional tool call id, internal call id and raw args.
    ///
    /// Default: resolves to [`ToolCallHookAction::Continue`].
    fn on_tool_call(
        &self,
        _tool_name: &str,
        _tool_call_id: Option<String>,
        _internal_call_id: &str,
        _args: &str,
    ) -> impl Future<Output = ToolCallHookAction> + WasmCompatSend {
        async { ToolCallHookAction::Continue }
    }

    /// Receives a tool call's identifying data and args together with its result string.
    ///
    /// Default: resolves to [`HookAction::Continue`].
    fn on_tool_result(
        &self,
        _tool_name: &str,
        _tool_call_id: Option<String>,
        _internal_call_id: &str,
        _args: &str,
        _result: &str,
    ) -> impl Future<Output = HookAction> + WasmCompatSend {
        async { HookAction::Continue }
    }

    /// Receives a streamed text delta and the aggregated text (presumably everything so far).
    ///
    /// Default: resolves to [`HookAction::Continue`]. The future is bounded by `Send`.
    fn on_text_delta(
        &self,
        _text_delta: &str,
        _aggregated_text: &str,
    ) -> impl Future<Output = HookAction> + Send {
        async { HookAction::Continue }
    }

    /// Receives a streamed tool-call delta. The tool name may be absent (`None`).
    ///
    /// Default: resolves to [`HookAction::Continue`]. The future is bounded by `Send`.
    fn on_tool_call_delta(
        &self,
        _tool_call_id: &str,
        _internal_call_id: &str,
        _tool_name: Option<&str>,
        _tool_call_delta: &str,
    ) -> impl Future<Output = HookAction> + Send {
        async { HookAction::Continue }
    }

    /// Receives the prompt and the model's streaming response (presumably once the stream ends).
    ///
    /// Default: resolves to [`HookAction::Continue`]. The future is bounded by `Send`.
    fn on_stream_completion_response_finish(
        &self,
        _prompt: &Message,
        _response: &<M as CompletionModel>::StreamingResponse,
    ) -> impl Future<Output = HookAction> + Send {
        async { HookAction::Continue }
    }
}
/// No-op hook: `()` uses every default method of [`PromptHook`], for any completion model.
impl<M: CompletionModel> PromptHook<M> for () {}
/// Decision returned by [`PromptHook::on_tool_call`] for a pending tool call.
///
/// The default `on_tool_call` returns [`ToolCallHookAction::Continue`].
/// NOTE(review): the effect of `Skip` versus `Terminate` is implemented by the caller, not in
/// this file — confirm there.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ToolCallHookAction {
    /// Let the tool call proceed.
    Continue,
    /// Skip this tool call, recording `reason`.
    Skip { reason: String },
    /// Terminate, recording `reason`.
    Terminate { reason: String },
}

impl ToolCallHookAction {
    /// Builds [`ToolCallHookAction::Continue`].
    pub fn cont() -> Self {
        ToolCallHookAction::Continue
    }

    /// Builds [`ToolCallHookAction::Skip`] from anything convertible into a `String`.
    pub fn skip(reason: impl Into<String>) -> Self {
        let reason: String = reason.into();
        ToolCallHookAction::Skip { reason }
    }

    /// Builds [`ToolCallHookAction::Terminate`] from anything convertible into a `String`.
    pub fn terminate(reason: impl Into<String>) -> Self {
        let reason: String = reason.into();
        ToolCallHookAction::Terminate { reason }
    }
}
/// Decision returned by most [`PromptHook`] callbacks.
///
/// Every default hook that returns this type resolves to [`HookAction::Continue`].
/// NOTE(review): what `Terminate` does to the run is implemented by the caller, not in this
/// file — confirm there.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HookAction {
    /// Carry on as normal.
    Continue,
    /// Terminate, recording `reason`.
    Terminate { reason: String },
}

impl HookAction {
    /// Builds [`HookAction::Continue`].
    pub fn cont() -> Self {
        HookAction::Continue
    }

    /// Builds [`HookAction::Terminate`] from anything convertible into a `String`.
    pub fn terminate(reason: impl Into<String>) -> Self {
        let reason: String = reason.into();
        HookAction::Terminate { reason }
    }
}