// mlua_batteries/policy/llm_policy.rs
1//! LLM request access policy.
2
3use std::collections::HashSet;
4
5use super::{PolicyError, Unrestricted};
6
7/// Policy that decides whether a given LLM request may be sent.
8///
9/// Called by the `llm` module before dispatching to a
10/// [`LlmProvider`](crate::llm::LlmProvider).
11///
12/// # Relationship to `HttpPolicy`
13///
14/// LLM requests are external HTTP calls, so they pass through **both**
15/// [`HttpPolicy`](super::HttpPolicy) and `LlmPolicy`:
16///
17/// - [`HttpPolicy`](super::HttpPolicy) — network-level: "is this URL reachable?"
18/// Checked first against the resolved base URL.
19/// - `LlmPolicy` — AI-specific: "should data be sent to this provider?"
20/// Addresses concerns that do not apply to general HTTP: data may be
21/// used for model training, subject to provider-specific retention
22/// policies, or expose sensitive context to a third-party AI system.
23///
24/// Both policies must allow the request for it to proceed.
25///
26/// # Built-in implementations
27///
28/// | Type | Behaviour |
29/// |------|-----------|
30/// | [`Unrestricted`] | No checks (default) |
31/// | [`LlmAllowList`] | Allow only listed providers |
32///
33/// # Custom implementations
34///
35/// ```rust,no_run
36/// use mlua_batteries::policy::{LlmPolicy, PolicyError};
37///
38/// struct OnlyLocal;
39///
40/// impl LlmPolicy for OnlyLocal {
41/// fn check_request(&self, _provider: &str, _model: &str, base_url: &str) -> Result<(), PolicyError> {
42/// if base_url.contains("localhost") || base_url.contains("127.0.0.1") {
43/// Ok(())
44/// } else {
45/// Err(PolicyError::new(format!("LLM denied: only local endpoints allowed, got '{base_url}'")))
46/// }
47/// }
48/// }
49/// ```
50pub trait LlmPolicy: Send + Sync + 'static {
51 /// Human-readable name for this policy, used in `Debug` output.
52 ///
53 /// The default implementation returns [`std::any::type_name`] of the
54 /// concrete type, which works correctly even through trait objects
55 /// because the vtable dispatches to the concrete implementation.
56 fn policy_name(&self) -> &'static str {
57 std::any::type_name::<Self>()
58 }
59
60 /// Validate an LLM request before it is sent.
61 ///
62 /// `provider` is the provider name (e.g. `"openai"`), `model` is the
63 /// model identifier, `base_url` is the resolved API base URL.
64 ///
65 /// Return `Ok(())` to allow, `Err(reason)` to deny.
66 fn check_request(&self, provider: &str, model: &str, base_url: &str)
67 -> Result<(), PolicyError>;
68}
69
70impl LlmPolicy for Unrestricted {
71 fn check_request(
72 &self,
73 _provider: &str,
74 _model: &str,
75 _base_url: &str,
76 ) -> Result<(), PolicyError> {
77 Ok(())
78 }
79}
80
/// Allow only requests to listed LLM providers.
///
/// Matching is by exact provider name; the model and base URL are not
/// inspected by this policy.
///
/// ```rust,no_run
/// use mlua_batteries::policy::LlmAllowList;
///
/// let policy = LlmAllowList::new(["ollama", "openai"]);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LlmAllowList {
    /// Provider names that are permitted; every other provider is denied.
    allowed_providers: HashSet<String>,
}

impl LlmAllowList {
    /// Create an allow-list from provider names.
    ///
    /// Accepts anything iterable over string-like items (e.g. an array of
    /// `&str` or a `Vec<String>`). Duplicate names are collapsed by the set.
    pub fn new<I, S>(providers: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            // The field type drives inference here, so no turbofish is needed.
            allowed_providers: providers.into_iter().map(Into::into).collect(),
        }
    }
}
108
109impl LlmPolicy for LlmAllowList {
110 fn check_request(
111 &self,
112 provider: &str,
113 _model: &str,
114 _base_url: &str,
115 ) -> Result<(), PolicyError> {
116 if self.allowed_providers.contains(provider) {
117 Ok(())
118 } else {
119 Err(PolicyError::new(format!(
120 "LLM denied: provider '{provider}' is not in the allow list"
121 )))
122 }
123 }
124}