Skip to main content

entelix_policy/
layer.rs

1//! `PolicyLayer` — `tower::Layer<S>` middleware that fires the
2//! per-tenant policy stack (PII redactor + quota gate + cost meter)
3//! around every model and tool call.
4//!
//! One layer struct, three `Service` impls — one each for
//! [`ModelInvocation`], [`StreamingModelInvocation`] and
//! [`ToolInvocation`]. Compose via
//! `ChatModel::layer(PolicyLayer::new(manager))` for model calls
//! and `ToolRegistry::layer(PolicyLayer::new(manager))` for tool
//! dispatch. The same `PolicyRegistry` backs both.
10//!
11//! ## Lifecycle (model calls)
12//!
13//! - **before inner.call**: PII `redact_request`, then quota
14//!   pre-check (rate + budget). A pre-check refusal short-circuits
15//!   before encode, surfacing as `Error::Provider { status: 429 |
16//!   402, ... }` per [`PolicyError`]'s `From` impl.
17//! - **after inner.call**: PII `redact_response`, then transactional
18//!   cost `charge`. F4 — charge fires only when the inner call
19//!   succeeded.
20//!
21//! ## Lifecycle (tool calls)
22//!
//! - **before inner.call**: PII `redact_json` walks the JSON
//!   `input` and scrubs string leaves. Quota / cost meter don't
//!   apply to tool calls — those are model-call concepts.
//! - **after inner.call**: PII `redact_json` on the JSON
//!   response.
28
29use std::sync::Arc;
30use std::task::{Context, Poll};
31
32use futures::future::BoxFuture;
33use serde_json::Value;
34use tower::{Layer, Service, ServiceExt};
35
36use entelix_core::error::{Error, Result};
37use entelix_core::ir::ModelResponse;
38use entelix_core::service::{
39    ModelInvocation, ModelStream, StreamingModelInvocation, ToolInvocation,
40};
41
42use crate::error::PolicyError;
43use crate::tenant::PolicyRegistry;
44
45/// Layer that wraps an inner service with per-tenant policy
46/// enforcement.
47#[derive(Clone)]
48pub struct PolicyLayer {
49    manager: Arc<PolicyRegistry>,
50    /// Rate-limit tokens to acquire on each request. Most callers
51    /// use `1` — one model call = one bucket draw.
52    rate_tokens_per_request: u32,
53}
54
55impl PolicyLayer {
56    /// Patch-version-stable identifier surfaced through
57    /// [`entelix_core::ChatModel::layer_names`] /
58    /// `ToolRegistry::layer_names`. Renaming this constant is a
59    /// breaking change for dashboards keyed off the value.
60    pub const NAME: &'static str = "policy";
61
62    /// Build with the supplied manager; one rate-limit token per
63    /// request.
64    #[must_use]
65    pub fn new(manager: Arc<PolicyRegistry>) -> Self {
66        Self {
67            manager,
68            rate_tokens_per_request: 1,
69        }
70    }
71
72    /// Override how many rate-limit tokens each request costs.
73    #[must_use]
74    pub const fn with_rate_tokens(mut self, tokens: u32) -> Self {
75        self.rate_tokens_per_request = tokens;
76        self
77    }
78}
79
80impl entelix_core::NamedLayer for PolicyLayer {
81    fn layer_name(&self) -> &'static str {
82        Self::NAME
83    }
84}
85
86impl std::fmt::Debug for PolicyLayer {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        f.debug_struct("PolicyLayer")
89            .field("tenants", &self.manager.tenant_count())
90            .field("rate_tokens_per_request", &self.rate_tokens_per_request)
91            .finish()
92    }
93}
94
95impl<S> Layer<S> for PolicyLayer {
96    type Service = PolicyService<S>;
97
98    fn layer(&self, inner: S) -> Self::Service {
99        PolicyService {
100            inner,
101            manager: Arc::clone(&self.manager),
102            rate_tokens_per_request: self.rate_tokens_per_request,
103        }
104    }
105}
106
107/// `Service` produced by [`PolicyLayer`]. Generic over the inner
108/// service type; specialised `Service<ModelInvocation>` and
109/// `Service<ToolInvocation>` impls below.
110#[derive(Clone)]
111pub struct PolicyService<S> {
112    inner: S,
113    manager: Arc<PolicyRegistry>,
114    rate_tokens_per_request: u32,
115}
116
117impl<S> Service<ModelInvocation> for PolicyService<S>
118where
119    S: Service<ModelInvocation, Response = ModelResponse, Error = Error> + Clone + Send + 'static,
120    S::Future: Send + 'static,
121{
122    type Response = ModelResponse;
123    type Error = Error;
124    type Future = BoxFuture<'static, Result<ModelResponse>>;
125
126    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
127        self.inner.poll_ready(cx)
128    }
129
130    fn call(&mut self, mut invocation: ModelInvocation) -> Self::Future {
131        let manager = Arc::clone(&self.manager);
132        let inner = self.inner.clone();
133        let tokens = self.rate_tokens_per_request;
134        Box::pin(async move {
135            let policy = manager.policy_for(invocation.ctx.tenant_id());
136
137            // Pre — redact then quota.
138            if let Some(redactor) = &policy.redactor {
139                redactor
140                    .redact_request(&mut invocation.request)
141                    .await
142                    .map_err(Error::from)?;
143            }
144            if let Some(quota) = &policy.quota {
145                quota
146                    .check_pre_request(invocation.ctx.tenant_id(), tokens)
147                    .await
148                    .map_err(Error::from)?;
149            }
150            // Pre-call cost gate — projects the worst-case charge
151            // against `RunBudget::cost_usd_limit` before the wire
152            // roundtrip fires. Skips silently when no budget is
153            // attached, no cost meter is configured, or the meter
154            // has no tariff for `request.model`.
155            if let (Some(meter), Some(budget)) = (&policy.cost_meter, invocation.ctx.run_budget())
156                && let Some(estimate) = entelix_core::BudgetCostEstimator::estimate_pre_call(
157                    meter.as_ref(),
158                    &invocation.request,
159                    &invocation.ctx,
160                )
161                .await
162            {
163                budget.check_pre_request_cost(estimate)?;
164            }
165
166            let tenant = invocation.ctx.tenant_id().to_owned();
167            let ctx_for_post = invocation.ctx.clone();
168            let request_snapshot = invocation.request.clone();
169            let mut response = inner.oneshot(invocation).await?;
170
171            // Post — redact then charge (tenant ledger) and observe
172            // (RunBudget axis). Both fire only on the `Ok` branch
173            // (invariant 12). The two paths share `PricingTable` via
174            // `CostMeter`'s `BudgetCostEstimator::calculate_actual`
175            // and `charge` — no rounding drift between ledger and
176            // budget axis because the source decimal is the same.
177            if let Some(redactor) = &policy.redactor {
178                redactor
179                    .redact_response(&mut response)
180                    .await
181                    .map_err(Error::from)?;
182            }
183            if let Some(meter) = &policy.cost_meter {
184                match meter.charge(&tenant, &response.model, &response.usage) {
185                    Ok(_) => {}
186                    Err(PolicyError::UnknownModel(model)) => {
187                        tracing::warn!(
188                            target: "entelix_policy::layer",
189                            tenant = %tenant,
190                            %model,
191                            "no pricing configured; skipping cost charge"
192                        );
193                    }
194                    Err(e) => return Err(Error::from(e)),
195                }
196                if let Some(budget) = ctx_for_post.run_budget()
197                    && let Some(actual) = entelix_core::BudgetCostEstimator::calculate_actual(
198                        meter.as_ref(),
199                        &request_snapshot,
200                        &response.usage,
201                        &ctx_for_post,
202                    )
203                    .await
204                {
205                    budget.observe_cost(actual)?;
206                }
207            }
208            Ok(response)
209        })
210    }
211}
212
213impl<S> Service<StreamingModelInvocation> for PolicyService<S>
214where
215    S: Service<StreamingModelInvocation, Response = ModelStream, Error = Error>
216        + Clone
217        + Send
218        + 'static,
219    S::Future: Send + 'static,
220{
221    type Response = ModelStream;
222    type Error = Error;
223    type Future = BoxFuture<'static, Result<ModelStream>>;
224
225    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
226        self.inner.poll_ready(cx)
227    }
228
229    fn call(&mut self, mut invocation: StreamingModelInvocation) -> Self::Future {
230        let manager = Arc::clone(&self.manager);
231        let inner = self.inner.clone();
232        let tokens = self.rate_tokens_per_request;
233        Box::pin(async move {
234            let policy = manager.policy_for(invocation.ctx().tenant_id());
235
236            // Pre — request-side redactor + quota gate. Streaming
237            // PII redaction on individual deltas is intentionally
238            // left out at 1.0 (chunk boundaries can split a PII
239            // pattern; streaming-aware redactors are post-1.0
240            // work). The request-side redactor still applies —
241            // operator-supplied input is fully formed before the
242            // first byte streams.
243            if let Some(redactor) = &policy.redactor {
244                redactor
245                    .redact_request(&mut invocation.inner.request)
246                    .await
247                    .map_err(Error::from)?;
248            }
249            if let Some(quota) = &policy.quota {
250                quota
251                    .check_pre_request(invocation.ctx().tenant_id(), tokens)
252                    .await
253                    .map_err(Error::from)?;
254            }
255            // Pre-call cost gate — same shape as `ModelInvocation`.
256            // The estimate operates on the request as the consumer
257            // submitted it; streaming-side accumulated `Usage`
258            // post-completion drives the post-call `observe_cost`.
259            if let (Some(meter), Some(budget)) = (&policy.cost_meter, invocation.ctx().run_budget())
260                && let Some(estimate) = entelix_core::BudgetCostEstimator::estimate_pre_call(
261                    meter.as_ref(),
262                    &invocation.inner.request,
263                    invocation.ctx(),
264                )
265                .await
266            {
267                budget.check_pre_request_cost(estimate)?;
268            }
269
270            let tenant = invocation.ctx().tenant_id().clone();
271            let ctx_for_post = invocation.ctx().clone();
272            let request_snapshot = invocation.inner.request.clone();
273            let model_stream = inner.oneshot(invocation).await?;
274            let ModelStream { stream, completion } = model_stream;
275
276            // Wrap completion: charge cost on `Ok` branch only —
277            // mirrors the one-shot `ModelInvocation` post-call
278            // path. A stream that errors mid-flight resolves
279            // `completion` to `Err` and skips the charge entirely
280            // (invariant 12 — no phantom cost on partial streams).
281            let cost_meter = policy.cost_meter.clone();
282            let user_facing = async move {
283                let result = completion.await;
284                if let Ok(response) = &result
285                    && let Some(meter) = &cost_meter
286                {
287                    match meter.charge(&tenant, &response.model, &response.usage) {
288                        Ok(_) => {}
289                        Err(PolicyError::UnknownModel(model)) => {
290                            tracing::warn!(
291                                target: "entelix_policy::layer",
292                                tenant = %tenant,
293                                %model,
294                                "no pricing configured; skipping cost charge"
295                            );
296                        }
297                        Err(e) => return Err(Error::from(e)),
298                    }
299                    if let Some(budget) = ctx_for_post.run_budget()
300                        && let Some(actual) = entelix_core::BudgetCostEstimator::calculate_actual(
301                            meter.as_ref(),
302                            &request_snapshot,
303                            &response.usage,
304                            &ctx_for_post,
305                        )
306                        .await
307                    {
308                        budget.observe_cost(actual)?;
309                    }
310                }
311                result
312            };
313            Ok(ModelStream {
314                stream,
315                completion: Box::pin(user_facing),
316            })
317        })
318    }
319}
320
321impl<S> Service<ToolInvocation> for PolicyService<S>
322where
323    S: Service<ToolInvocation, Response = Value, Error = Error> + Clone + Send + 'static,
324    S::Future: Send + 'static,
325{
326    type Response = Value;
327    type Error = Error;
328    type Future = BoxFuture<'static, Result<Value>>;
329
330    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
331        self.inner.poll_ready(cx)
332    }
333
334    fn call(&mut self, mut invocation: ToolInvocation) -> Self::Future {
335        let manager = Arc::clone(&self.manager);
336        let inner = self.inner.clone();
337        Box::pin(async move {
338            let policy = manager.policy_for(invocation.ctx.tenant_id());
339            if let Some(redactor) = &policy.redactor {
340                redactor
341                    .redact_json(&mut invocation.input)
342                    .await
343                    .map_err(Error::from)?;
344            }
345            let mut output = inner.oneshot(invocation).await?;
346            if let Some(redactor) = &policy.redactor {
347                redactor
348                    .redact_json(&mut output)
349                    .await
350                    .map_err(Error::from)?;
351            }
352            Ok(output)
353        })
354    }
355}
356
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::task::Context as TaskContext;

    use entelix_core::TenantId;
    use entelix_core::context::ExecutionContext;
    use entelix_core::ir::{ContentPart, Message, ModelRequest, StopReason, Usage};
    use rust_decimal::Decimal;
    use serde_json::json;

    use super::*;
    use crate::cost::{CostMeter, ModelPricing, PricingTable};
    use crate::pii::RegexRedactor;
    use crate::quota::{Budget, QuotaLimiter};
    use crate::rate_limit::TokenBucketLimiter;
    use crate::tenant::TenantPolicy;

    /// Model id shared by the canned request, canned response and
    /// pricing table.
    const MODEL: &str = "claude-opus-4-7";
    /// The single tenant every test runs as.
    const TENANT: &str = "acme";

    /// Parse a decimal literal, panicking on malformed input.
    fn d(s: &str) -> Decimal {
        s.parse::<Decimal>().unwrap()
    }

    /// Id of the test tenant.
    fn acme() -> TenantId {
        TenantId::new(TENANT)
    }

    /// Execution context scoped to the test tenant, no run budget.
    fn acme_ctx() -> ExecutionContext {
        ExecutionContext::new().with_tenant_id(acme())
    }

    /// Registry holding exactly one entry: `policy` for the test tenant.
    fn registry_with(policy: TenantPolicy) -> Arc<PolicyRegistry> {
        Arc::new(PolicyRegistry::new().with_tenant(acme(), policy))
    }

    /// Leaf model service: answers every call with the same canned
    /// response and counts how often it was invoked.
    #[derive(Clone)]
    struct FakeModelService {
        hits: Arc<AtomicU32>,
        reply: ModelResponse,
    }

    impl FakeModelService {
        fn new(canned: ModelResponse) -> Self {
            Self {
                hits: Arc::new(AtomicU32::new(0)),
                reply: canned,
            }
        }
    }

    impl Service<ModelInvocation> for FakeModelService {
        type Response = ModelResponse;
        type Error = Error;
        type Future = BoxFuture<'static, Result<ModelResponse>>;

        fn poll_ready(&mut self, _: &mut TaskContext<'_>) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _inv: ModelInvocation) -> Self::Future {
            self.hits.fetch_add(1, Ordering::SeqCst);
            let reply = self.reply.clone();
            Box::pin(async move { Ok(reply) })
        }
    }

    /// Request whose user message carries an e-mail address.
    fn make_request() -> ModelRequest {
        ModelRequest {
            model: MODEL.into(),
            messages: vec![Message::user("contact user@acme.io for help")],
            ..ModelRequest::default()
        }
    }

    /// Canned response reporting 1000 input and 1000 output tokens.
    fn make_response() -> ModelResponse {
        ModelResponse {
            id: "r1".into(),
            model: MODEL.into(),
            stop_reason: StopReason::EndTurn,
            content: vec![ContentPart::text("ack")],
            usage: Usage::new(1000, 1000),
            rate_limit: None,
            warnings: Vec::new(),
            provider_echoes: Vec::new(),
        }
    }

    /// Pricing table with a single tariff for [`MODEL`].
    fn pricing() -> PricingTable {
        let tariff = ModelPricing::new(d("15"), d("75"), d("1.5"), d("18.75"));
        PricingTable::new().add_model_pricing(MODEL, tariff)
    }

    #[tokio::test]
    async fn model_layer_redacts_request_then_charges_on_success() {
        let meter = Arc::new(CostMeter::new(pricing()));
        let registry = registry_with(
            TenantPolicy::new()
                .with_redactor(Arc::new(RegexRedactor::with_defaults()))
                .with_cost_meter(meter.clone()),
        );
        let leaf = FakeModelService::new(make_response());
        let hits = leaf.hits.clone();
        let service = PolicyLayer::new(registry).layer(leaf);

        let invocation = ModelInvocation::new(make_request(), acme_ctx());
        let resp = tower::ServiceExt::oneshot(service, invocation)
            .await
            .unwrap();

        assert_eq!(hits.load(Ordering::SeqCst), 1);
        // 1000*15/1000 + 1000*75/1000 = 90
        assert_eq!(meter.spent_by(&acme()), d("90"));
        assert_eq!(resp.id, "r1");
    }

    #[tokio::test]
    async fn rate_refusal_returns_provider_429_and_skips_inner() {
        let registry = registry_with(TenantPolicy::new().with_quota(Arc::new(
            QuotaLimiter::new(
                Some(Arc::new(TokenBucketLimiter::new(1, 1.0).unwrap())),
                None,
                Budget::unlimited(),
            ),
        )));
        let leaf = FakeModelService::new(make_response());
        let hits = leaf.hits.clone();
        let layer = PolicyLayer::new(registry);

        // First call drains the single token.
        let first = layer.layer(leaf.clone());
        let _ = tower::ServiceExt::oneshot(
            first,
            ModelInvocation::new(make_request(), acme_ctx()),
        )
        .await
        .unwrap();

        // Second call refused.
        let second = layer.layer(leaf);
        let err = tower::ServiceExt::oneshot(
            second,
            ModelInvocation::new(make_request(), acme_ctx()),
        )
        .await
        .unwrap_err();

        match err {
            Error::Provider { kind, .. } => {
                assert_eq!(kind, entelix_core::ProviderErrorKind::Http(429));
            }
            other => panic!("expected Provider 429, got {other:?}"),
        }
        assert_eq!(
            hits.load(Ordering::SeqCst),
            1,
            "inner must not run on refusal"
        );
    }

    /// Tool-side leaf service: hands its input straight back.
    #[derive(Clone)]
    struct EchoToolService;

    impl Service<ToolInvocation> for EchoToolService {
        type Response = serde_json::Value;
        type Error = Error;
        type Future = BoxFuture<'static, Result<serde_json::Value>>;

        fn poll_ready(&mut self, _: &mut TaskContext<'_>) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, inv: ToolInvocation) -> Self::Future {
            let echoed = inv.input;
            Box::pin(async move { Ok(echoed) })
        }
    }

    #[tokio::test]
    async fn pre_call_cost_gate_blocks_when_projection_exceeds_budget() {
        // Worst-case projection from the `make_request` body + the
        // unbounded-output fallback (8192 tokens) under the pricing
        // table dominates a $0.10 ceiling; the pre-call gate must
        // refuse the dispatch before the inner service runs.
        let meter = Arc::new(CostMeter::new(pricing()));
        let registry = registry_with(TenantPolicy::new().with_cost_meter(meter.clone()));
        let leaf = FakeModelService::new(make_response());
        let hits = leaf.hits.clone();
        let service = PolicyLayer::new(registry).layer(leaf);

        let budget = entelix_core::RunBudget::unlimited().with_cost_limit_usd(d("0.10"));
        let ctx = acme_ctx().with_run_budget(budget);
        let err = tower::ServiceExt::oneshot(service, ModelInvocation::new(make_request(), ctx))
            .await
            .unwrap_err();

        assert!(
            matches!(
                err,
                Error::UsageLimitExceeded(entelix_core::UsageLimitBreach::CostUsd { .. })
            ),
            "got: {err:?}"
        );
        assert_eq!(
            hits.load(Ordering::SeqCst),
            0,
            "inner dispatch must not fire when pre-call gate refuses"
        );
        assert_eq!(
            meter.spent_by(&acme()),
            Decimal::ZERO,
            "no ledger charge on refused dispatch"
        );
    }

    #[tokio::test]
    async fn cost_observation_populates_run_budget_after_ok() {
        let meter = Arc::new(CostMeter::new(pricing()));
        let registry = registry_with(TenantPolicy::new().with_cost_meter(meter.clone()));
        let service = PolicyLayer::new(registry).layer(FakeModelService::new(make_response()));

        // Budget is high enough that the worst-case pre-call estimate
        // passes; the post-call `observe_cost` then records the actual
        // charge against the RunBudget axis.
        let budget = entelix_core::RunBudget::unlimited().with_cost_limit_usd(d("1000"));
        let budget_for_assertion = budget.clone();
        let ctx = acme_ctx().with_run_budget(budget);
        let _ = tower::ServiceExt::oneshot(service, ModelInvocation::new(make_request(), ctx))
            .await
            .unwrap();

        // 1000 in / 1000 out @ $15 / $75 per 1k = $90.
        assert_eq!(budget_for_assertion.snapshot().cost_usd, d("90"));
        assert_eq!(meter.spent_by(&acme()), d("90"));
    }

    #[tokio::test]
    async fn tool_layer_redacts_input_and_output() {
        let registry = registry_with(
            TenantPolicy::new().with_redactor(Arc::new(RegexRedactor::with_defaults())),
        );
        let svc = PolicyLayer::new(registry).layer(EchoToolService);
        let inv = ToolInvocation::new(
            "tool_use_1".into(),
            Arc::new(entelix_core::tools::ToolMetadata::function(
                "lookup",
                "look up a record",
                json!({"type": "object"}),
            )),
            json!({"email": "user@acme.io"}),
            acme_ctx(),
        );
        let out = tower::ServiceExt::oneshot(svc, inv).await.unwrap();

        // Both directions redacted: input redacted before echo;
        // echoed output redacted again on the way back.
        let txt = out["email"].as_str().unwrap();
        assert!(txt.contains("[REDACTED:email]"), "{txt}");
    }
}