neuron_runtime/guardrail_hook.rs
1//! Adapter that wraps guardrails as an [`ObservabilityHook`].
2//!
3//! [`GuardrailHook`] runs input guardrails on [`HookEvent::PreLlmCall`] and output
4//! guardrails on [`HookEvent::PostLlmCall`], mapping [`GuardrailResult`] variants
5//! to [`HookAction`] values.
6//!
7//! # Example
8//!
9//! ```ignore
10//! use neuron_runtime::{GuardrailHook, InputGuardrail, OutputGuardrail, GuardrailResult};
11//!
12//! struct BlockSecrets;
13//! impl InputGuardrail for BlockSecrets {
14//! fn check(&self, input: &str) -> impl Future<Output = GuardrailResult> + Send {
15//! async move {
16//! if input.contains("API_KEY") {
17//! GuardrailResult::Tripwire("secret detected".to_string())
18//! } else {
19//! GuardrailResult::Pass
20//! }
21//! }
22//! }
23//! }
24//!
25//! let hook = GuardrailHook::new().input_guardrail(BlockSecrets);
26//! // Use `hook` as an ObservabilityHook in the agent loop
27//! ```
28
29use std::sync::Arc;
30
31use neuron_types::{
32 ContentBlock, HookAction, HookError, HookEvent, ObservabilityHook, WasmCompatSend,
33};
34
35use crate::guardrail::{ErasedInputGuardrail, ErasedOutputGuardrail, GuardrailResult};
36
37/// An [`ObservabilityHook`] that runs guardrails on LLM input and output.
38///
39/// Built with a builder pattern. Input guardrails fire on [`HookEvent::PreLlmCall`],
40/// output guardrails fire on [`HookEvent::PostLlmCall`]. All other events pass
41/// through with [`HookAction::Continue`].
42///
43/// # Guardrail result mapping
44///
45/// - [`GuardrailResult::Pass`] -> [`HookAction::Continue`]
46/// - [`GuardrailResult::Tripwire`] -> [`HookAction::Terminate`] with the reason
47/// - [`GuardrailResult::Warn`] -> logs a warning via [`tracing::warn!`] and continues
48///
49/// # Example
50///
51/// ```ignore
52/// use neuron_runtime::{GuardrailHook, InputGuardrail, OutputGuardrail, GuardrailResult};
53///
54/// struct NoSecrets;
55/// impl InputGuardrail for NoSecrets {
56/// fn check(&self, input: &str) -> impl Future<Output = GuardrailResult> + Send {
57/// async move {
58/// if input.contains("sk-") {
59/// GuardrailResult::Tripwire("secret in input".to_string())
60/// } else {
61/// GuardrailResult::Pass
62/// }
63/// }
64/// }
65/// }
66///
67/// let hook = GuardrailHook::new()
68/// .input_guardrail(NoSecrets);
69/// ```
70pub struct GuardrailHook {
71 input_guardrails: Vec<Arc<dyn ErasedInputGuardrail>>,
72 output_guardrails: Vec<Arc<dyn ErasedOutputGuardrail>>,
73}
74
75impl GuardrailHook {
76 /// Create an empty `GuardrailHook` with no guardrails.
77 #[must_use]
78 pub fn new() -> Self {
79 Self {
80 input_guardrails: Vec::new(),
81 output_guardrails: Vec::new(),
82 }
83 }
84
85 /// Add an input guardrail.
86 ///
87 /// Input guardrails run on [`HookEvent::PreLlmCall`], checking the last
88 /// user message text in the request.
89 #[must_use]
90 pub fn input_guardrail<G>(mut self, guardrail: G) -> Self
91 where
92 G: ErasedInputGuardrail + 'static,
93 {
94 self.input_guardrails.push(Arc::new(guardrail));
95 self
96 }
97
98 /// Add an output guardrail.
99 ///
100 /// Output guardrails run on [`HookEvent::PostLlmCall`], checking the
101 /// assistant response text from the response message.
102 #[must_use]
103 pub fn output_guardrail<G>(mut self, guardrail: G) -> Self
104 where
105 G: ErasedOutputGuardrail + 'static,
106 {
107 self.output_guardrails.push(Arc::new(guardrail));
108 self
109 }
110}
111
112impl Default for GuardrailHook {
113 fn default() -> Self {
114 Self::new()
115 }
116}
117
118/// Extract the text content from the last user message in the request's messages.
119///
120/// Returns an empty string if there are no user messages or no text blocks.
121fn extract_last_user_text(messages: &[neuron_types::Message]) -> String {
122 for message in messages.iter().rev() {
123 if message.role == neuron_types::Role::User {
124 let texts: Vec<&str> = message
125 .content
126 .iter()
127 .filter_map(|block| match block {
128 ContentBlock::Text(t) => Some(t.as_str()),
129 _ => None,
130 })
131 .collect();
132 if !texts.is_empty() {
133 return texts.join("\n");
134 }
135 }
136 }
137 String::new()
138}
139
140/// Extract text content from the assistant response message.
141///
142/// Returns an empty string if there are no text blocks.
143fn extract_response_text(message: &neuron_types::Message) -> String {
144 let texts: Vec<&str> = message
145 .content
146 .iter()
147 .filter_map(|block| match block {
148 ContentBlock::Text(t) => Some(t.as_str()),
149 _ => None,
150 })
151 .collect();
152 texts.join("\n")
153}
154
155/// Map a [`GuardrailResult`] to a [`HookAction`], logging warnings as needed.
156fn map_guardrail_result(result: GuardrailResult, direction: &str) -> HookAction {
157 match result {
158 GuardrailResult::Pass => HookAction::Continue,
159 GuardrailResult::Tripwire(reason) => HookAction::Terminate { reason },
160 GuardrailResult::Warn(reason) => {
161 tracing::warn!("{direction} guardrail warning: {reason}");
162 HookAction::Continue
163 }
164 }
165}
166
167impl ObservabilityHook for GuardrailHook {
168 fn on_event(
169 &self,
170 event: HookEvent<'_>,
171 ) -> impl Future<Output = Result<HookAction, HookError>> + WasmCompatSend {
172 // Capture references needed for the async block before moving into it.
173 let input_guardrails = &self.input_guardrails;
174 let output_guardrails = &self.output_guardrails;
175
176 async move {
177 match event {
178 HookEvent::PreLlmCall { request } => {
179 if input_guardrails.is_empty() {
180 return Ok(HookAction::Continue);
181 }
182 let text = extract_last_user_text(&request.messages);
183 if text.is_empty() {
184 return Ok(HookAction::Continue);
185 }
186 for guardrail in input_guardrails {
187 let result = guardrail.check_dyn(&text).await;
188 if !result.is_pass() {
189 return Ok(map_guardrail_result(result, "input"));
190 }
191 }
192 Ok(HookAction::Continue)
193 }
194 HookEvent::PostLlmCall { response } => {
195 if output_guardrails.is_empty() {
196 return Ok(HookAction::Continue);
197 }
198 let text = extract_response_text(&response.message);
199 if text.is_empty() {
200 return Ok(HookAction::Continue);
201 }
202 for guardrail in output_guardrails {
203 let result = guardrail.check_dyn(&text).await;
204 if !result.is_pass() {
205 return Ok(map_guardrail_result(result, "output"));
206 }
207 }
208 Ok(HookAction::Continue)
209 }
210 _ => Ok(HookAction::Continue),
211 }
212 }
213 }
214}