1use std::sync::Arc;
30use std::task::{Context, Poll};
31
32use futures::future::BoxFuture;
33use serde_json::Value;
34use tower::{Layer, Service, ServiceExt};
35
36use entelix_core::error::{Error, Result};
37use entelix_core::ir::ModelResponse;
38use entelix_core::service::{
39 ModelInvocation, ModelStream, StreamingModelInvocation, ToolInvocation,
40};
41
42use crate::error::PolicyError;
43use crate::tenant::PolicyRegistry;
44
45#[derive(Clone)]
48pub struct PolicyLayer {
49 manager: Arc<PolicyRegistry>,
50 rate_tokens_per_request: u32,
53}
54
55impl PolicyLayer {
56 pub const NAME: &'static str = "policy";
61
62 #[must_use]
65 pub fn new(manager: Arc<PolicyRegistry>) -> Self {
66 Self {
67 manager,
68 rate_tokens_per_request: 1,
69 }
70 }
71
72 #[must_use]
74 pub const fn with_rate_tokens(mut self, tokens: u32) -> Self {
75 self.rate_tokens_per_request = tokens;
76 self
77 }
78}
79
80impl entelix_core::NamedLayer for PolicyLayer {
81 fn layer_name(&self) -> &'static str {
82 Self::NAME
83 }
84}
85
86impl std::fmt::Debug for PolicyLayer {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 f.debug_struct("PolicyLayer")
89 .field("tenants", &self.manager.tenant_count())
90 .field("rate_tokens_per_request", &self.rate_tokens_per_request)
91 .finish()
92 }
93}
94
95impl<S> Layer<S> for PolicyLayer {
96 type Service = PolicyService<S>;
97
98 fn layer(&self, inner: S) -> Self::Service {
99 PolicyService {
100 inner,
101 manager: Arc::clone(&self.manager),
102 rate_tokens_per_request: self.rate_tokens_per_request,
103 }
104 }
105}
106
107#[derive(Clone)]
111pub struct PolicyService<S> {
112 inner: S,
113 manager: Arc<PolicyRegistry>,
114 rate_tokens_per_request: u32,
115}
116
117impl<S> Service<ModelInvocation> for PolicyService<S>
118where
119 S: Service<ModelInvocation, Response = ModelResponse, Error = Error> + Clone + Send + 'static,
120 S::Future: Send + 'static,
121{
122 type Response = ModelResponse;
123 type Error = Error;
124 type Future = BoxFuture<'static, Result<ModelResponse>>;
125
126 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
127 self.inner.poll_ready(cx)
128 }
129
130 fn call(&mut self, mut invocation: ModelInvocation) -> Self::Future {
131 let manager = Arc::clone(&self.manager);
132 let inner = self.inner.clone();
133 let tokens = self.rate_tokens_per_request;
134 Box::pin(async move {
135 let policy = manager.policy_for(invocation.ctx.tenant_id());
136
137 if let Some(redactor) = &policy.redactor {
139 redactor
140 .redact_request(&mut invocation.request)
141 .await
142 .map_err(Error::from)?;
143 }
144 if let Some(quota) = &policy.quota {
145 quota
146 .check_pre_request(invocation.ctx.tenant_id(), tokens)
147 .await
148 .map_err(Error::from)?;
149 }
150 if let (Some(meter), Some(budget)) = (&policy.cost_meter, invocation.ctx.run_budget())
156 && let Some(estimate) = entelix_core::BudgetCostEstimator::estimate_pre_call(
157 meter.as_ref(),
158 &invocation.request,
159 &invocation.ctx,
160 )
161 .await
162 {
163 budget.check_pre_request_cost(estimate)?;
164 }
165
166 let tenant = invocation.ctx.tenant_id().to_owned();
167 let ctx_for_post = invocation.ctx.clone();
168 let request_snapshot = invocation.request.clone();
169 let mut response = inner.oneshot(invocation).await?;
170
171 if let Some(redactor) = &policy.redactor {
178 redactor
179 .redact_response(&mut response)
180 .await
181 .map_err(Error::from)?;
182 }
183 if let Some(meter) = &policy.cost_meter {
184 match meter.charge(&tenant, &response.model, &response.usage) {
185 Ok(_) => {}
186 Err(PolicyError::UnknownModel(model)) => {
187 tracing::warn!(
188 target: "entelix_policy::layer",
189 tenant = %tenant,
190 %model,
191 "no pricing configured; skipping cost charge"
192 );
193 }
194 Err(e) => return Err(Error::from(e)),
195 }
196 if let Some(budget) = ctx_for_post.run_budget()
197 && let Some(actual) = entelix_core::BudgetCostEstimator::calculate_actual(
198 meter.as_ref(),
199 &request_snapshot,
200 &response.usage,
201 &ctx_for_post,
202 )
203 .await
204 {
205 budget.observe_cost(actual)?;
206 }
207 }
208 Ok(response)
209 })
210 }
211}
212
213impl<S> Service<StreamingModelInvocation> for PolicyService<S>
214where
215 S: Service<StreamingModelInvocation, Response = ModelStream, Error = Error>
216 + Clone
217 + Send
218 + 'static,
219 S::Future: Send + 'static,
220{
221 type Response = ModelStream;
222 type Error = Error;
223 type Future = BoxFuture<'static, Result<ModelStream>>;
224
225 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
226 self.inner.poll_ready(cx)
227 }
228
229 fn call(&mut self, mut invocation: StreamingModelInvocation) -> Self::Future {
230 let manager = Arc::clone(&self.manager);
231 let inner = self.inner.clone();
232 let tokens = self.rate_tokens_per_request;
233 Box::pin(async move {
234 let policy = manager.policy_for(invocation.ctx().tenant_id());
235
236 if let Some(redactor) = &policy.redactor {
244 redactor
245 .redact_request(&mut invocation.inner.request)
246 .await
247 .map_err(Error::from)?;
248 }
249 if let Some(quota) = &policy.quota {
250 quota
251 .check_pre_request(invocation.ctx().tenant_id(), tokens)
252 .await
253 .map_err(Error::from)?;
254 }
255 if let (Some(meter), Some(budget)) = (&policy.cost_meter, invocation.ctx().run_budget())
260 && let Some(estimate) = entelix_core::BudgetCostEstimator::estimate_pre_call(
261 meter.as_ref(),
262 &invocation.inner.request,
263 invocation.ctx(),
264 )
265 .await
266 {
267 budget.check_pre_request_cost(estimate)?;
268 }
269
270 let tenant = invocation.ctx().tenant_id().clone();
271 let ctx_for_post = invocation.ctx().clone();
272 let request_snapshot = invocation.inner.request.clone();
273 let model_stream = inner.oneshot(invocation).await?;
274 let ModelStream { stream, completion } = model_stream;
275
276 let cost_meter = policy.cost_meter.clone();
282 let user_facing = async move {
283 let result = completion.await;
284 if let Ok(response) = &result
285 && let Some(meter) = &cost_meter
286 {
287 match meter.charge(&tenant, &response.model, &response.usage) {
288 Ok(_) => {}
289 Err(PolicyError::UnknownModel(model)) => {
290 tracing::warn!(
291 target: "entelix_policy::layer",
292 tenant = %tenant,
293 %model,
294 "no pricing configured; skipping cost charge"
295 );
296 }
297 Err(e) => return Err(Error::from(e)),
298 }
299 if let Some(budget) = ctx_for_post.run_budget()
300 && let Some(actual) = entelix_core::BudgetCostEstimator::calculate_actual(
301 meter.as_ref(),
302 &request_snapshot,
303 &response.usage,
304 &ctx_for_post,
305 )
306 .await
307 {
308 budget.observe_cost(actual)?;
309 }
310 }
311 result
312 };
313 Ok(ModelStream {
314 stream,
315 completion: Box::pin(user_facing),
316 })
317 })
318 }
319}
320
321impl<S> Service<ToolInvocation> for PolicyService<S>
322where
323 S: Service<ToolInvocation, Response = Value, Error = Error> + Clone + Send + 'static,
324 S::Future: Send + 'static,
325{
326 type Response = Value;
327 type Error = Error;
328 type Future = BoxFuture<'static, Result<Value>>;
329
330 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
331 self.inner.poll_ready(cx)
332 }
333
334 fn call(&mut self, mut invocation: ToolInvocation) -> Self::Future {
335 let manager = Arc::clone(&self.manager);
336 let inner = self.inner.clone();
337 Box::pin(async move {
338 let policy = manager.policy_for(invocation.ctx.tenant_id());
339 if let Some(redactor) = &policy.redactor {
340 redactor
341 .redact_json(&mut invocation.input)
342 .await
343 .map_err(Error::from)?;
344 }
345 let mut output = inner.oneshot(invocation).await?;
346 if let Some(redactor) = &policy.redactor {
347 redactor
348 .redact_json(&mut output)
349 .await
350 .map_err(Error::from)?;
351 }
352 Ok(output)
353 })
354 }
355}
356
#[cfg(test)]
// Tests operate on fixed fixtures, so a panic on `unwrap` or a missing JSON
// index is exactly the failure signal we want.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use entelix_core::TenantId;
    use std::str::FromStr;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::task::Context as TaskContext;

    use entelix_core::context::ExecutionContext;
    use entelix_core::ir::{ContentPart, Message, ModelRequest, StopReason, Usage};
    use rust_decimal::Decimal;
    use serde_json::json;

    use super::*;
    use crate::cost::{CostMeter, ModelPricing, PricingTable};
    use crate::pii::RegexRedactor;
    use crate::quota::{Budget, QuotaLimiter};
    use crate::rate_limit::TokenBucketLimiter;
    use crate::tenant::TenantPolicy;

    /// Parses a decimal fixture literal.
    fn d(s: &str) -> Decimal {
        Decimal::from_str(s).unwrap()
    }

    /// The tenant every test runs as.
    fn acme() -> TenantId {
        TenantId::new("acme")
    }

    /// Execution context scoped to the `acme` tenant.
    fn acme_ctx() -> ExecutionContext {
        ExecutionContext::new().with_tenant_id(acme())
    }

    /// Model leaf that counts how often it is dispatched and always answers
    /// with the same canned response.
    #[derive(Clone)]
    struct FakeModelService {
        calls: Arc<AtomicU32>,
        canned: ModelResponse,
    }

    impl FakeModelService {
        fn new(canned: ModelResponse) -> Self {
            let calls = Arc::new(AtomicU32::new(0));
            Self { calls, canned }
        }
    }

    impl Service<ModelInvocation> for FakeModelService {
        type Response = ModelResponse;
        type Error = Error;
        type Future = BoxFuture<'static, Result<ModelResponse>>;

        fn poll_ready(&mut self, _: &mut TaskContext<'_>) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _inv: ModelInvocation) -> Self::Future {
            self.calls.fetch_add(1, Ordering::SeqCst);
            let reply = self.canned.clone();
            Box::pin(async move { Ok(reply) })
        }
    }

    /// Request whose user message contains an e-mail address for the
    /// redactor to find.
    fn make_request() -> ModelRequest {
        let messages = vec![Message::user("contact user@acme.io for help")];
        ModelRequest {
            model: "claude-opus-4-7".into(),
            messages,
            ..ModelRequest::default()
        }
    }

    /// Canned provider response reporting 1000 input / 1000 output tokens.
    fn make_response() -> ModelResponse {
        let usage = Usage::new(1000, 1000);
        ModelResponse {
            id: "r1".into(),
            model: "claude-opus-4-7".into(),
            stop_reason: StopReason::EndTurn,
            content: vec![ContentPart::text("ack")],
            usage,
            rate_limit: None,
            warnings: Vec::new(),
            provider_echoes: Vec::new(),
        }
    }

    /// Pricing table with a single entry for the model used above.
    fn pricing() -> PricingTable {
        let opus = ModelPricing::new(d("15"), d("75"), d("1.5"), d("18.75"));
        PricingTable::new().add_model_pricing("claude-opus-4-7", opus)
    }

    #[tokio::test]
    async fn model_layer_redacts_request_then_charges_on_success() {
        let meter = Arc::new(CostMeter::new(pricing()));
        let policy = TenantPolicy::new()
            .with_redactor(Arc::new(RegexRedactor::with_defaults()))
            .with_cost_meter(meter.clone());
        let registry = Arc::new(PolicyRegistry::new().with_tenant(acme(), policy));

        let leaf = FakeModelService::new(make_response());
        let calls = Arc::clone(&leaf.calls);
        let service = PolicyLayer::new(registry).layer(leaf);

        let invocation = ModelInvocation::new(make_request(), acme_ctx());
        let resp = tower::ServiceExt::oneshot(service, invocation)
            .await
            .unwrap();

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        // The pricing fixture turns the canned 1000/1000 usage into 90.
        assert_eq!(meter.spent_by(&acme()), d("90"));
        assert_eq!(resp.id, "r1");
    }

    #[tokio::test]
    async fn rate_refusal_returns_provider_429_and_skips_inner() {
        // The limiter is sized so the first request passes and an immediate
        // second one is refused.
        let quota = QuotaLimiter::new(
            Some(Arc::new(TokenBucketLimiter::new(1, 1.0).unwrap())),
            None,
            Budget::unlimited(),
        );
        let registry = Arc::new(
            PolicyRegistry::new().with_tenant(acme(), TenantPolicy::new().with_quota(Arc::new(quota))),
        );
        let leaf = FakeModelService::new(make_response());
        let calls = Arc::clone(&leaf.calls);
        let layer = PolicyLayer::new(registry);

        // Two separately layered services share the tenant's limiter.
        let first = layer.layer(leaf.clone());
        let _ = tower::ServiceExt::oneshot(first, ModelInvocation::new(make_request(), acme_ctx()))
            .await
            .unwrap();

        let second = layer.layer(leaf);
        let err = tower::ServiceExt::oneshot(second, ModelInvocation::new(make_request(), acme_ctx()))
            .await
            .unwrap_err();
        match err {
            Error::Provider { kind, .. } => {
                assert_eq!(kind, entelix_core::ProviderErrorKind::Http(429));
            }
            other => panic!("expected Provider 429, got {other:?}"),
        }
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "inner must not run on refusal"
        );
    }

    /// Tool leaf that returns its (possibly redacted) input unchanged.
    #[derive(Clone)]
    struct EchoToolService;

    impl Service<ToolInvocation> for EchoToolService {
        type Response = serde_json::Value;
        type Error = Error;
        type Future = BoxFuture<'static, Result<serde_json::Value>>;

        fn poll_ready(&mut self, _: &mut TaskContext<'_>) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, inv: ToolInvocation) -> Self::Future {
            let echoed = inv.input;
            Box::pin(async move { Ok(echoed) })
        }
    }

    #[tokio::test]
    async fn pre_call_cost_gate_blocks_when_projection_exceeds_budget() {
        let meter = Arc::new(CostMeter::new(pricing()));
        let registry = Arc::new(
            PolicyRegistry::new()
                .with_tenant(acme(), TenantPolicy::new().with_cost_meter(meter.clone())),
        );
        let leaf = FakeModelService::new(make_response());
        let calls = Arc::clone(&leaf.calls);
        let service = PolicyLayer::new(registry).layer(leaf);

        let budget = entelix_core::RunBudget::unlimited().with_cost_limit_usd(d("0.10"));
        let ctx = acme_ctx().with_run_budget(budget);
        let err = tower::ServiceExt::oneshot(service, ModelInvocation::new(make_request(), ctx))
            .await
            .unwrap_err();

        assert!(
            matches!(
                err,
                Error::UsageLimitExceeded(entelix_core::UsageLimitBreach::CostUsd { .. })
            ),
            "got: {err:?}"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            0,
            "inner dispatch must not fire when pre-call gate refuses"
        );
        assert_eq!(
            meter.spent_by(&acme()),
            Decimal::ZERO,
            "no ledger charge on refused dispatch"
        );
    }

    #[tokio::test]
    async fn cost_observation_populates_run_budget_after_ok() {
        let meter = Arc::new(CostMeter::new(pricing()));
        let registry = Arc::new(
            PolicyRegistry::new()
                .with_tenant(acme(), TenantPolicy::new().with_cost_meter(meter.clone())),
        );
        let service = PolicyLayer::new(registry).layer(FakeModelService::new(make_response()));

        let budget = entelix_core::RunBudget::unlimited().with_cost_limit_usd(d("1000"));
        // Keep a handle so the observed spend can be read after the call.
        let observed = budget.clone();
        let ctx = acme_ctx().with_run_budget(budget);
        let _ = tower::ServiceExt::oneshot(service, ModelInvocation::new(make_request(), ctx))
            .await
            .unwrap();

        assert_eq!(observed.snapshot().cost_usd, d("90"));
        assert_eq!(meter.spent_by(&acme()), d("90"));
    }

    #[tokio::test]
    async fn tool_layer_redacts_input_and_output() {
        let policy = TenantPolicy::new().with_redactor(Arc::new(RegexRedactor::with_defaults()));
        let registry = Arc::new(PolicyRegistry::new().with_tenant(acme(), policy));
        let svc = PolicyLayer::new(registry).layer(EchoToolService);

        let metadata = Arc::new(entelix_core::tools::ToolMetadata::function(
            "lookup",
            "look up a record",
            json!({"type": "object"}),
        ));
        let inv = ToolInvocation::new(
            "tool_use_1".into(),
            metadata,
            json!({"email": "user@acme.io"}),
            acme_ctx(),
        );
        let out = tower::ServiceExt::oneshot(svc, inv).await.unwrap();

        let txt = out["email"].as_str().unwrap();
        assert!(txt.contains("[REDACTED:email]"), "{txt}");
    }
}