Skip to main content

mlua_batteries/policy/
llm_policy.rs

1//! LLM request access policy.
2
3use std::collections::HashSet;
4
5use super::{PolicyError, Unrestricted};
6
7/// Policy that decides whether a given LLM request may be sent.
8///
9/// Called by the `llm` module before dispatching to a
10/// [`LlmProvider`](crate::llm::LlmProvider).
11///
12/// # Relationship to `HttpPolicy`
13///
14/// LLM requests are external HTTP calls, so they pass through **both**
15/// [`HttpPolicy`](super::HttpPolicy) and `LlmPolicy`:
16///
17/// - [`HttpPolicy`](super::HttpPolicy) — network-level: "is this URL reachable?"
18///   Checked first against the resolved base URL.
19/// - `LlmPolicy` — AI-specific: "should data be sent to this provider?"
20///   Addresses concerns that do not apply to general HTTP: data may be
21///   used for model training, subject to provider-specific retention
22///   policies, or expose sensitive context to a third-party AI system.
23///
24/// Both policies must allow the request for it to proceed.
25///
26/// # Built-in implementations
27///
28/// | Type | Behaviour |
29/// |------|-----------|
30/// | [`Unrestricted`] | No checks (default) |
31/// | [`LlmAllowList`] | Allow only listed providers |
32///
33/// # Custom implementations
34///
35/// ```rust,no_run
36/// use mlua_batteries::policy::{LlmPolicy, PolicyError};
37///
38/// struct OnlyLocal;
39///
40/// impl LlmPolicy for OnlyLocal {
41///     fn check_request(&self, _provider: &str, _model: &str, base_url: &str) -> Result<(), PolicyError> {
42///         if base_url.contains("localhost") || base_url.contains("127.0.0.1") {
43///             Ok(())
44///         } else {
45///             Err(PolicyError::new(format!("LLM denied: only local endpoints allowed, got '{base_url}'")))
46///         }
47///     }
48/// }
49/// ```
50pub trait LlmPolicy: Send + Sync + 'static {
51    /// Human-readable name for this policy, used in `Debug` output.
52    ///
53    /// The default implementation returns [`std::any::type_name`] of the
54    /// concrete type, which works correctly even through trait objects
55    /// because the vtable dispatches to the concrete implementation.
56    fn policy_name(&self) -> &'static str {
57        std::any::type_name::<Self>()
58    }
59
60    /// Validate an LLM request before it is sent.
61    ///
62    /// `provider` is the provider name (e.g. `"openai"`), `model` is the
63    /// model identifier, `base_url` is the resolved API base URL.
64    ///
65    /// Return `Ok(())` to allow, `Err(reason)` to deny.
66    fn check_request(&self, provider: &str, model: &str, base_url: &str)
67        -> Result<(), PolicyError>;
68}
69
70impl LlmPolicy for Unrestricted {
71    fn check_request(
72        &self,
73        _provider: &str,
74        _model: &str,
75        _base_url: &str,
76    ) -> Result<(), PolicyError> {
77        Ok(())
78    }
79}
80
/// Allow only requests to listed LLM providers.
///
/// Provider names are stored verbatim as given to [`LlmAllowList::new`];
/// duplicates collapse into a single entry.
///
/// The [`Default`] value is an empty list, i.e. no provider is allowed.
///
/// ```rust,no_run
/// use mlua_batteries::policy::LlmAllowList;
///
/// let policy = LlmAllowList::new(["ollama", "openai"]);
/// ```
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LlmAllowList {
    /// Names of the providers that requests may be sent to.
    allowed_providers: HashSet<String>,
}

impl LlmAllowList {
    /// Create an allow-list from provider names.
    ///
    /// Accepts any iterable whose items convert into `String`
    /// (e.g. `["ollama", "openai"]` or a `Vec<String>`).
    pub fn new<I, S>(providers: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        // The target collection type is inferred from the field.
        let allowed_providers = providers.into_iter().map(Into::into).collect();
        Self { allowed_providers }
    }
}
108
109impl LlmPolicy for LlmAllowList {
110    fn check_request(
111        &self,
112        provider: &str,
113        _model: &str,
114        _base_url: &str,
115    ) -> Result<(), PolicyError> {
116        if self.allowed_providers.contains(provider) {
117            Ok(())
118        } else {
119            Err(PolicyError::new(format!(
120                "LLM denied: provider '{provider}' is not in the allow list"
121            )))
122        }
123    }
124}