rig/agent/prompt_request/
hooks.rs1use crate::{
6 completion::CompletionModel,
7 message::Message,
8 wasm_compat::{WasmCompatSend, WasmCompatSync},
9};
10
11pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
13where
14 M: CompletionModel,
15{
16 fn on_completion_call(
18 &self,
19 _prompt: &Message,
20 _history: &[Message],
21 ) -> impl Future<Output = HookAction> + WasmCompatSend {
22 async { HookAction::cont() }
23 }
24
25 fn on_completion_response(
27 &self,
28 _prompt: &Message,
29 _response: &crate::completion::CompletionResponse<M::Response>,
30 ) -> impl Future<Output = HookAction> + WasmCompatSend {
31 async { HookAction::cont() }
32 }
33
34 fn on_tool_call(
40 &self,
41 _tool_name: &str,
42 _tool_call_id: Option<String>,
43 _internal_call_id: &str,
44 _args: &str,
45 ) -> impl Future<Output = ToolCallHookAction> + WasmCompatSend {
46 async { ToolCallHookAction::cont() }
47 }
48
49 fn on_tool_result(
51 &self,
52 _tool_name: &str,
53 _tool_call_id: Option<String>,
54 _internal_call_id: &str,
55 _args: &str,
56 _result: &str,
57 ) -> impl Future<Output = HookAction> + WasmCompatSend {
58 async { HookAction::cont() }
59 }
60
61 fn on_text_delta(
63 &self,
64 _text_delta: &str,
65 _aggregated_text: &str,
66 ) -> impl Future<Output = HookAction> + Send {
67 async { HookAction::cont() }
68 }
69
70 fn on_tool_call_delta(
73 &self,
74 _tool_call_id: &str,
75 _internal_call_id: &str,
76 _tool_name: Option<&str>,
77 _tool_call_delta: &str,
78 ) -> impl Future<Output = HookAction> + Send {
79 async { HookAction::cont() }
80 }
81
82 fn on_stream_completion_response_finish(
84 &self,
85 _prompt: &Message,
86 _response: &<M as CompletionModel>::StreamingResponse,
87 ) -> impl Future<Output = HookAction> + Send {
88 async { HookAction::cont() }
89 }
90}
91
92impl<M> PromptHook<M> for () where M: CompletionModel {}
93
/// Decision returned from a tool-call hook.
///
/// Compared with `HookAction`, this adds a `Skip` option for an individual
/// tool call. How `Skip` and `Terminate` are acted on is up to the code
/// driving the hooks.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ToolCallHookAction {
    /// Let the tool call proceed.
    Continue,
    /// Request that this tool call be skipped, with a `reason` string.
    Skip { reason: String },
    /// Request that the run be terminated, with a `reason` string.
    Terminate { reason: String },
}

impl ToolCallHookAction {
    /// Shorthand for [`ToolCallHookAction::Continue`].
    pub fn cont() -> Self {
        ToolCallHookAction::Continue
    }

    /// Builds [`ToolCallHookAction::Skip`] from anything convertible into a `String`.
    pub fn skip(reason: impl Into<String>) -> Self {
        let reason: String = reason.into();
        ToolCallHookAction::Skip { reason }
    }

    /// Builds [`ToolCallHookAction::Terminate`] from anything convertible into a `String`.
    pub fn terminate(reason: impl Into<String>) -> Self {
        let reason: String = reason.into();
        ToolCallHookAction::Terminate { reason }
    }
}
/// Decision returned from the non-tool-call hooks: keep going, or ask the
/// driver to stop.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HookAction {
    /// Keep going.
    Continue,
    /// Request that the run be terminated, with a `reason` string.
    Terminate { reason: String },
}

impl HookAction {
    /// Shorthand for [`HookAction::Continue`].
    pub fn cont() -> Self {
        HookAction::Continue
    }

    /// Builds [`HookAction::Terminate`] from anything convertible into a `String`.
    pub fn terminate(reason: impl Into<String>) -> Self {
        let reason: String = reason.into();
        HookAction::Terminate { reason }
    }
}