use rig_core::agent::{HookAction, PromptHook, ToolCallHookAction};
use rig_core::completion::CompletionModel;
use rig_core::message::Message;
/// Composes two hooks into one, consulting `first` before `second`.
///
/// When used as a [`PromptHook`], every callback is forwarded to `first`. `second`
/// is only invoked when `first` answers `Continue`; any other action from `first`
/// is returned unchanged. Use [`HookPair::then`] to append further hooks.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HookPair<A, B> {
    /// Hook consulted first for every event.
    first: A,
    /// Hook consulted only after `first` lets the event continue.
    second: B,
}
impl<A, B> HookPair<A, B> {
    /// Creates a pair that consults `first` before `second`.
    #[must_use]
    pub fn new(first: A, second: B) -> Self {
        Self { first, second }
    }

    /// Appends `third` after this pair.
    ///
    /// The existing pair becomes the first element of a new pair, so chaining
    /// nests to the left: `HookPair::new(a, b).then(c)` has type
    /// `HookPair<HookPair<A, B>, C>`.
    #[must_use]
    pub fn then<C>(self, third: C) -> HookPair<Self, C> {
        HookPair::new(self, third)
    }

    /// Returns a reference to the hook that is consulted first.
    #[must_use]
    pub fn first(&self) -> &A {
        &self.first
    }

    /// Returns a reference to the hook that is consulted second.
    #[must_use]
    pub fn second(&self) -> &B {
        &self.second
    }

    /// Consumes the pair and returns its hooks as `(first, second)`.
    #[must_use]
    pub fn into_parts(self) -> (A, B) {
        (self.first, self.second)
    }
}
/// Sequential composition of two prompt hooks.
///
/// For each callback below, `first` is awaited first. If it answers anything
/// other than `Continue`, that answer is returned immediately and `second` is
/// never called for the event; otherwise the result of `second` is returned.
///
/// NOTE(review): only the callbacks implemented here are forwarded; any other
/// `PromptHook` method falls back to the trait's default for the pair — confirm
/// this list matches the hooks `rig_core` exposes.
impl<A, B, M> PromptHook<M> for HookPair<A, B>
where
    M: CompletionModel,
    A: PromptHook<M>,
    B: PromptHook<M>,
{
    async fn on_completion_call(&self, prompt: &Message, history: &[Message]) -> HookAction {
        let outcome = self.first.on_completion_call(prompt, history).await;
        if !matches!(outcome, HookAction::Continue) {
            return outcome;
        }
        self.second.on_completion_call(prompt, history).await
    }

    async fn on_completion_response(
        &self,
        prompt: &Message,
        response: &rig_core::completion::CompletionResponse<M::Response>,
    ) -> HookAction {
        let outcome = self.first.on_completion_response(prompt, response).await;
        if !matches!(outcome, HookAction::Continue) {
            return outcome;
        }
        self.second.on_completion_response(prompt, response).await
    }

    async fn on_tool_call(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        internal_call_id: &str,
        args: &str,
    ) -> ToolCallHookAction {
        // `first` needs its own copy because ownership of the id is handed to `second`.
        let decision = self
            .first
            .on_tool_call(tool_name, tool_call_id.clone(), internal_call_id, args)
            .await;
        if !matches!(decision, ToolCallHookAction::Continue) {
            return decision;
        }
        self.second
            .on_tool_call(tool_name, tool_call_id, internal_call_id, args)
            .await
    }

    async fn on_tool_result(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        internal_call_id: &str,
        args: &str,
        result: &str,
    ) -> HookAction {
        let outcome = self
            .first
            .on_tool_result(
                tool_name,
                tool_call_id.clone(),
                internal_call_id,
                args,
                result,
            )
            .await;
        if !matches!(outcome, HookAction::Continue) {
            return outcome;
        }
        self.second
            .on_tool_result(tool_name, tool_call_id, internal_call_id, args, result)
            .await
    }

    async fn on_text_delta(&self, text_delta: &str, aggregated_text: &str) -> HookAction {
        let outcome = self.first.on_text_delta(text_delta, aggregated_text).await;
        if !matches!(outcome, HookAction::Continue) {
            return outcome;
        }
        self.second.on_text_delta(text_delta, aggregated_text).await
    }

    async fn on_tool_call_delta(
        &self,
        tool_call_id: &str,
        internal_call_id: &str,
        tool_name: Option<&str>,
        tool_call_delta: &str,
    ) -> HookAction {
        let outcome = self
            .first
            .on_tool_call_delta(tool_call_id, internal_call_id, tool_name, tool_call_delta)
            .await;
        if !matches!(outcome, HookAction::Continue) {
            return outcome;
        }
        self.second
            .on_tool_call_delta(tool_call_id, internal_call_id, tool_name, tool_call_delta)
            .await
    }

    async fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
    ) -> HookAction {
        let outcome = self
            .first
            .on_stream_completion_response_finish(prompt, response)
            .await;
        if !matches!(outcome, HookAction::Continue) {
            return outcome;
        }
        self.second
            .on_stream_completion_response_finish(prompt, response)
            .await
    }
}
// Lint relaxations for test code. NOTE(review): presumably the crate denies these
// lints outside tests — confirm against the crate-level lint configuration.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// `new` keeps hooks in the given order and `then` nests the existing pair
    /// as the first element of the new pair.
    #[test]
    fn hook_pair_constructs_and_chains() {
        let pair = HookPair::new(1_u8, 2_u8);
        assert_eq!(*pair.first(), 1);
        assert_eq!(*pair.second(), 2);

        let chained = pair.then(3_u8);
        assert_eq!(*chained.first().first(), 1);
        assert_eq!(*chained.first().second(), 2);
        assert_eq!(*chained.second(), 3);
    }
}