Skip to main content

neuron_runtime/
guardrail.rs

1//! Input and output guardrails with tripwire support.
2//!
3//! Guardrails check input before it reaches the LLM and output before it
4//! reaches the user. A [`GuardrailResult::Tripwire`] halts execution
5//! immediately, while [`GuardrailResult::Warn`] allows execution to continue
6//! with a logged warning.
7
8use std::future::Future;
9
10use neuron_types::{WasmCompatSend, WasmCompatSync};
11
/// Outcome of a single guardrail check.
///
/// Two results are equal when they are the same variant and, for
/// [`GuardrailResult::Tripwire`] and [`GuardrailResult::Warn`], carry the same
/// message. This makes results directly usable with `==` and `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GuardrailResult {
    /// Input/output is acceptable.
    Pass,
    /// Immediately halt execution. The string explains why.
    Tripwire(String),
    /// Allow execution but log a warning. The string is the warning message.
    Warn(String),
}

impl GuardrailResult {
    /// Returns `true` if the result is [`GuardrailResult::Pass`].
    #[must_use]
    pub fn is_pass(&self) -> bool {
        matches!(self, Self::Pass)
    }

    /// Returns `true` if the result is [`GuardrailResult::Tripwire`],
    /// regardless of the message it carries.
    #[must_use]
    pub fn is_tripwire(&self) -> bool {
        matches!(self, Self::Tripwire(_))
    }

    /// Returns `true` if the result is [`GuardrailResult::Warn`],
    /// regardless of the message it carries.
    #[must_use]
    pub fn is_warn(&self) -> bool {
        matches!(self, Self::Warn(_))
    }
}
42
/// Guardrail that checks input before it reaches the LLM.
///
/// Implementors inspect the input text and report a [`GuardrailResult`].
/// The future returned by [`InputGuardrail::check`] must be
/// [`WasmCompatSend`]. Because `check` returns `impl Future`, this trait
/// cannot be used as `dyn InputGuardrail`; use [`ErasedInputGuardrail`]
/// where dynamic dispatch is needed.
///
/// # Example
///
/// ```ignore
/// use neuron_runtime::*;
///
/// struct NoSecrets;
/// impl InputGuardrail for NoSecrets {
///     fn check(&self, input: &str) -> impl Future<Output = GuardrailResult> + Send {
///         async move {
///             if input.contains("API_KEY") {
///                 GuardrailResult::Tripwire("Input contains API key".to_string())
///             } else {
///                 GuardrailResult::Pass
///             }
///         }
///     }
/// }
/// ```
pub trait InputGuardrail: WasmCompatSend + WasmCompatSync {
    /// Check the input text and return a guardrail result.
    ///
    /// `input` is the text to inspect. The returned future resolves to the
    /// verdict for that text.
    fn check(&self, input: &str) -> impl Future<Output = GuardrailResult> + WasmCompatSend;
}
67
/// Guardrail that checks output before it reaches the user.
///
/// Implementors inspect the output text and report a [`GuardrailResult`].
/// The future returned by [`OutputGuardrail::check`] must be
/// [`WasmCompatSend`]. Because `check` returns `impl Future`, this trait
/// cannot be used as `dyn OutputGuardrail`; use [`ErasedOutputGuardrail`]
/// where dynamic dispatch is needed.
///
/// # Example
///
/// ```ignore
/// use neuron_runtime::*;
///
/// struct NoLeakedSecrets;
/// impl OutputGuardrail for NoLeakedSecrets {
///     fn check(&self, output: &str) -> impl Future<Output = GuardrailResult> + Send {
///         async move {
///             if output.contains("sk-") {
///                 GuardrailResult::Tripwire("Output contains secret key".to_string())
///             } else {
///                 GuardrailResult::Pass
///             }
///         }
///     }
/// }
/// ```
pub trait OutputGuardrail: WasmCompatSend + WasmCompatSync {
    /// Check the output text and return a guardrail result.
    ///
    /// `output` is the text to inspect. The returned future resolves to the
    /// verdict for that text.
    fn check(&self, output: &str) -> impl Future<Output = GuardrailResult> + WasmCompatSend;
}
92
93/// Run a sequence of input guardrails, returning the first non-Pass result.
94///
95/// Returns [`GuardrailResult::Pass`] if all guardrails pass.
96pub async fn run_input_guardrails(
97    guardrails: &[&dyn ErasedInputGuardrail],
98    input: &str,
99) -> GuardrailResult {
100    for guardrail in guardrails {
101        let result = guardrail.check_dyn(input).await;
102        if !result.is_pass() {
103            return result;
104        }
105    }
106    GuardrailResult::Pass
107}
108
109/// Run a sequence of output guardrails, returning the first non-Pass result.
110///
111/// Returns [`GuardrailResult::Pass`] if all guardrails pass.
112pub async fn run_output_guardrails(
113    guardrails: &[&dyn ErasedOutputGuardrail],
114    output: &str,
115) -> GuardrailResult {
116    for guardrail in guardrails {
117        let result = guardrail.check_dyn(output).await;
118        if !result.is_pass() {
119            return result;
120        }
121    }
122    GuardrailResult::Pass
123}
124
125// --- Type erasure for guardrails (RPITIT is not dyn-compatible) ---
126
/// Dyn-compatible wrapper for [`InputGuardrail`].
///
/// [`InputGuardrail::check`] returns `impl Future`, which prevents the trait
/// from being used as a trait object. This trait exposes the same check
/// through a boxed future so guardrails can be passed as
/// `&dyn ErasedInputGuardrail`.
pub trait ErasedInputGuardrail: WasmCompatSend + WasmCompatSync {
    /// Check input, returning a boxed future.
    ///
    /// The returned future borrows both `self` and `input` for `'a` and
    /// resolves to the guardrail's verdict.
    fn check_dyn<'a>(
        &'a self,
        input: &'a str,
    ) -> std::pin::Pin<Box<dyn Future<Output = GuardrailResult> + Send + 'a>>;
}
135
136impl<T: InputGuardrail> ErasedInputGuardrail for T {
137    fn check_dyn<'a>(
138        &'a self,
139        input: &'a str,
140    ) -> std::pin::Pin<Box<dyn Future<Output = GuardrailResult> + Send + 'a>> {
141        Box::pin(self.check(input))
142    }
143}
144
/// Dyn-compatible wrapper for [`OutputGuardrail`].
///
/// [`OutputGuardrail::check`] returns `impl Future`, which prevents the trait
/// from being used as a trait object. This trait exposes the same check
/// through a boxed future so guardrails can be passed as
/// `&dyn ErasedOutputGuardrail`.
pub trait ErasedOutputGuardrail: WasmCompatSend + WasmCompatSync {
    /// Check output, returning a boxed future.
    ///
    /// The returned future borrows both `self` and `output` for `'a` and
    /// resolves to the guardrail's verdict.
    fn check_dyn<'a>(
        &'a self,
        output: &'a str,
    ) -> std::pin::Pin<Box<dyn Future<Output = GuardrailResult> + Send + 'a>>;
}
153
154impl<T: OutputGuardrail> ErasedOutputGuardrail for T {
155    fn check_dyn<'a>(
156        &'a self,
157        output: &'a str,
158    ) -> std::pin::Pin<Box<dyn Future<Output = GuardrailResult> + Send + 'a>> {
159        Box::pin(self.check(output))
160    }
161}