Skip to main content

neuron_runtime/
guardrail_hook.rs

1//! Adapter that wraps guardrails as an [`ObservabilityHook`].
2//!
3//! [`GuardrailHook`] runs input guardrails on [`HookEvent::PreLlmCall`] and output
4//! guardrails on [`HookEvent::PostLlmCall`], mapping [`GuardrailResult`] variants
5//! to [`HookAction`] values.
6//!
7//! # Example
8//!
9//! ```ignore
10//! use neuron_runtime::{GuardrailHook, InputGuardrail, OutputGuardrail, GuardrailResult};
11//!
12//! struct BlockSecrets;
13//! impl InputGuardrail for BlockSecrets {
14//!     fn check(&self, input: &str) -> impl Future<Output = GuardrailResult> + Send {
15//!         async move {
16//!             if input.contains("API_KEY") {
17//!                 GuardrailResult::Tripwire("secret detected".to_string())
18//!             } else {
19//!                 GuardrailResult::Pass
20//!             }
21//!         }
22//!     }
23//! }
24//!
25//! let hook = GuardrailHook::new().input_guardrail(BlockSecrets);
26//! // Use `hook` as an ObservabilityHook in the agent loop
27//! ```
28
29use std::sync::Arc;
30
31use neuron_types::{
32    ContentBlock, HookAction, HookError, HookEvent, ObservabilityHook, WasmCompatSend,
33};
34
35use crate::guardrail::{ErasedInputGuardrail, ErasedOutputGuardrail, GuardrailResult};
36
37/// An [`ObservabilityHook`] that runs guardrails on LLM input and output.
38///
39/// Built with a builder pattern. Input guardrails fire on [`HookEvent::PreLlmCall`],
40/// output guardrails fire on [`HookEvent::PostLlmCall`]. All other events pass
41/// through with [`HookAction::Continue`].
42///
43/// # Guardrail result mapping
44///
45/// - [`GuardrailResult::Pass`] -> [`HookAction::Continue`]
46/// - [`GuardrailResult::Tripwire`] -> [`HookAction::Terminate`] with the reason
47/// - [`GuardrailResult::Warn`] -> logs a warning via [`tracing::warn!`] and continues
48///
49/// # Example
50///
51/// ```ignore
52/// use neuron_runtime::{GuardrailHook, InputGuardrail, OutputGuardrail, GuardrailResult};
53///
54/// struct NoSecrets;
55/// impl InputGuardrail for NoSecrets {
56///     fn check(&self, input: &str) -> impl Future<Output = GuardrailResult> + Send {
57///         async move {
58///             if input.contains("sk-") {
59///                 GuardrailResult::Tripwire("secret in input".to_string())
60///             } else {
61///                 GuardrailResult::Pass
62///             }
63///         }
64///     }
65/// }
66///
67/// let hook = GuardrailHook::new()
68///     .input_guardrail(NoSecrets);
69/// ```
70pub struct GuardrailHook {
71    input_guardrails: Vec<Arc<dyn ErasedInputGuardrail>>,
72    output_guardrails: Vec<Arc<dyn ErasedOutputGuardrail>>,
73}
74
75impl GuardrailHook {
76    /// Create an empty `GuardrailHook` with no guardrails.
77    #[must_use]
78    pub fn new() -> Self {
79        Self {
80            input_guardrails: Vec::new(),
81            output_guardrails: Vec::new(),
82        }
83    }
84
85    /// Add an input guardrail.
86    ///
87    /// Input guardrails run on [`HookEvent::PreLlmCall`], checking the last
88    /// user message text in the request.
89    #[must_use]
90    pub fn input_guardrail<G>(mut self, guardrail: G) -> Self
91    where
92        G: ErasedInputGuardrail + 'static,
93    {
94        self.input_guardrails.push(Arc::new(guardrail));
95        self
96    }
97
98    /// Add an output guardrail.
99    ///
100    /// Output guardrails run on [`HookEvent::PostLlmCall`], checking the
101    /// assistant response text from the response message.
102    #[must_use]
103    pub fn output_guardrail<G>(mut self, guardrail: G) -> Self
104    where
105        G: ErasedOutputGuardrail + 'static,
106    {
107        self.output_guardrails.push(Arc::new(guardrail));
108        self
109    }
110}
111
112impl Default for GuardrailHook {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
118/// Extract the text content from the last user message in the request's messages.
119///
120/// Returns an empty string if there are no user messages or no text blocks.
121fn extract_last_user_text(messages: &[neuron_types::Message]) -> String {
122    for message in messages.iter().rev() {
123        if message.role == neuron_types::Role::User {
124            let texts: Vec<&str> = message
125                .content
126                .iter()
127                .filter_map(|block| match block {
128                    ContentBlock::Text(t) => Some(t.as_str()),
129                    _ => None,
130                })
131                .collect();
132            if !texts.is_empty() {
133                return texts.join("\n");
134            }
135        }
136    }
137    String::new()
138}
139
140/// Extract text content from the assistant response message.
141///
142/// Returns an empty string if there are no text blocks.
143fn extract_response_text(message: &neuron_types::Message) -> String {
144    let texts: Vec<&str> = message
145        .content
146        .iter()
147        .filter_map(|block| match block {
148            ContentBlock::Text(t) => Some(t.as_str()),
149            _ => None,
150        })
151        .collect();
152    texts.join("\n")
153}
154
155/// Map a [`GuardrailResult`] to a [`HookAction`], logging warnings as needed.
156fn map_guardrail_result(result: GuardrailResult, direction: &str) -> HookAction {
157    match result {
158        GuardrailResult::Pass => HookAction::Continue,
159        GuardrailResult::Tripwire(reason) => HookAction::Terminate { reason },
160        GuardrailResult::Warn(reason) => {
161            tracing::warn!("{direction} guardrail warning: {reason}");
162            HookAction::Continue
163        }
164    }
165}
166
167impl ObservabilityHook for GuardrailHook {
168    fn on_event(
169        &self,
170        event: HookEvent<'_>,
171    ) -> impl Future<Output = Result<HookAction, HookError>> + WasmCompatSend {
172        // Capture references needed for the async block before moving into it.
173        let input_guardrails = &self.input_guardrails;
174        let output_guardrails = &self.output_guardrails;
175
176        async move {
177            match event {
178                HookEvent::PreLlmCall { request } => {
179                    if input_guardrails.is_empty() {
180                        return Ok(HookAction::Continue);
181                    }
182                    let text = extract_last_user_text(&request.messages);
183                    if text.is_empty() {
184                        return Ok(HookAction::Continue);
185                    }
186                    for guardrail in input_guardrails {
187                        let result = guardrail.check_dyn(&text).await;
188                        if !result.is_pass() {
189                            return Ok(map_guardrail_result(result, "input"));
190                        }
191                    }
192                    Ok(HookAction::Continue)
193                }
194                HookEvent::PostLlmCall { response } => {
195                    if output_guardrails.is_empty() {
196                        return Ok(HookAction::Continue);
197                    }
198                    let text = extract_response_text(&response.message);
199                    if text.is_empty() {
200                        return Ok(HookAction::Continue);
201                    }
202                    for guardrail in output_guardrails {
203                        let result = guardrail.check_dyn(&text).await;
204                        if !result.is_pass() {
205                            return Ok(map_guardrail_result(result, "output"));
206                        }
207                    }
208                    Ok(HookAction::Continue)
209                }
210                _ => Ok(HookAction::Continue),
211            }
212        }
213    }
214}