Skip to main content

entelix_memory/
metered.rs

1//! `MeteredEmbedder<E>` — wraps any `E: Embedder` and emits
2//! `gen_ai.embedding.cost` (plus `usage`, `duration_ms`) per call.
3//!
4//! Cost calculation flows through the [`EmbeddingCostCalculator`]
5//! trait so deployments can plug in any pricing source — typically
6//! `entelix_policy::CostMeter` for unified billing alongside model
7//! and tool costs.
8//!
9//! ## F4 transactional discipline
10//!
11//! Cost is computed and emitted **only after** `inner.embed` /
12//! `embed_batch` returns Ok. A failed embedder call never produces a
13//! phantom charge in telemetry — same rule the model and tool paths
14//! enforce for `gen_ai.usage.cost` and `gen_ai.tool.cost`.
15//!
16//! ## Provider-supplied usage required
17//!
//! When the inner embedder returns `Embedding::usage = None` (stub
//! embedders, hash-based encoders) the wrapper still emits the
//! `gen_ai.embedding.end` event for visibility but skips the cost
//! attribute — without a token count there is nothing to charge.
22
23use std::sync::Arc;
24use std::time::Instant;
25
26use async_trait::async_trait;
27use entelix_core::context::ExecutionContext;
28use entelix_core::cost::CostCalculator;
29use entelix_core::error::Result;
30use entelix_core::ir::Usage;
31
32use crate::traits::{Embedder, Embedding, EmbeddingUsage};
33
34/// Compute a monetary cost for one embedder call.
35///
36/// Implementors are pure with respect to the caller's request — they
37/// may consult internal caches (a pricing table) but must not mutate
38/// caller state. Implementations are typically shared across many
39/// calls, so they must be `Send + Sync + 'static`.
40///
41/// `ctx` lets multi-tenant calculators select per-tenant pricing
42/// rows via [`ExecutionContext::tenant_id`]. Single-tenant
43/// calculators ignore it. Returns `None` when no pricing applies —
44/// telemetry consumers omit the cost attribute (silent zero would
45/// hide a missing-pricing-row deployment bug).
46#[async_trait]
47pub trait EmbeddingCostCalculator: Send + Sync + 'static {
48    /// Compute the cost of one embedder call given the request
49    /// context, the embedder model name (operator-supplied at
50    /// `MeteredEmbedder` construction), and the embedder's
51    /// reported usage record.
52    async fn compute_cost(
53        &self,
54        model: &str,
55        usage: &EmbeddingUsage,
56        ctx: &ExecutionContext,
57    ) -> Option<f64>;
58}
59
60/// `Embedder` decorator that emits OTel-compatible telemetry per
61/// call (and optional cost via [`EmbeddingCostCalculator`]).
62///
63/// Wraps any inner `E: Embedder` and itself implements `Embedder`,
64/// so the wrapper drops in transparently anywhere the bare type
65/// was used.
66pub struct MeteredEmbedder<E>
67where
68    E: Embedder,
69{
70    inner: Arc<E>,
71    model: Arc<str>,
72    cost_calculator: Option<Arc<dyn EmbeddingCostCalculator>>,
73}
74
75impl<E> MeteredEmbedder<E>
76where
77    E: Embedder,
78{
79    /// Wrap `inner` with a metered surface. `model` is the wire-name
80    /// the operator wants surfaced in telemetry (`gen_ai.embedding.model`)
81    /// and used as the lookup key in the cost calculator's pricing
82    /// table.
83    pub fn new(inner: E, model: impl Into<Arc<str>>) -> Self {
84        Self {
85            inner: Arc::new(inner),
86            model: model.into(),
87            cost_calculator: None,
88        }
89    }
90
91    /// Variant for callers that already hold an `Arc<E>` (typical
92    /// when the embedder is shared across multiple memory backends).
93    pub fn from_arc(inner: Arc<E>, model: impl Into<Arc<str>>) -> Self {
94        Self {
95            inner,
96            model: model.into(),
97            cost_calculator: None,
98        }
99    }
100
101    /// Attach an [`EmbeddingCostCalculator`]. When set, the wrapper
102    /// emits `gen_ai.embedding.cost` on the success branch of every
103    /// embed / embed_batch call whose `(tenant, model)` resolves to
104    /// a pricing row.
105    #[must_use]
106    pub fn with_cost_calculator(mut self, calculator: Arc<dyn EmbeddingCostCalculator>) -> Self {
107        self.cost_calculator = Some(calculator);
108        self
109    }
110
111    /// Borrow the operator-supplied model name surfaced in
112    /// telemetry — useful for tests and for dashboards that
113    /// label rows by model.
114    pub fn model(&self) -> &str {
115        &self.model
116    }
117
118    /// Helper: emit a single `gen_ai.embedding.end` event with the
119    /// computed cost (when calculator + usage both present).
120    async fn emit_end(
121        &self,
122        ctx: &ExecutionContext,
123        usage: Option<&EmbeddingUsage>,
124        duration_ms: u64,
125        batch_size: usize,
126    ) {
127        let cost = match (usage, &self.cost_calculator) {
128            (Some(u), Some(calc)) => calc.compute_cost(&self.model, u, ctx).await,
129            _ => None,
130        };
131        let input_tokens = usage.map(|u| u.input_tokens);
132        tracing::event!(
133            target: "gen_ai",
134            tracing::Level::INFO,
135            gen_ai.system = "embedder",
136            gen_ai.operation.name = "embed",
137            gen_ai.embedding.model = %self.model,
138            gen_ai.embedding.batch_size = batch_size,
139            gen_ai.usage.input_tokens = input_tokens,
140            gen_ai.embedding.cost = cost,
141            duration_ms,
142            entelix.tenant_id = %ctx.tenant_id(),
143            entelix.run_id = ctx.run_id(),
144            "gen_ai.embedding.end"
145        );
146    }
147}
148
149#[async_trait]
150impl<E> Embedder for MeteredEmbedder<E>
151where
152    E: Embedder,
153{
154    fn dimension(&self) -> usize {
155        self.inner.dimension()
156    }
157
158    async fn embed(&self, text: &str, ctx: &ExecutionContext) -> Result<Embedding> {
159        let started_at = Instant::now();
160        let result = self.inner.embed(text, ctx).await?;
161        let duration_ms = u64::try_from(started_at.elapsed().as_millis()).unwrap_or(u64::MAX);
162        self.emit_end(ctx, result.usage.as_ref(), duration_ms, 1)
163            .await;
164        Ok(result)
165    }
166
167    async fn embed_batch(
168        &self,
169        texts: &[String],
170        ctx: &ExecutionContext,
171    ) -> Result<Vec<Embedding>> {
172        let started_at = Instant::now();
173        let result = self.inner.embed_batch(texts, ctx).await?;
174        let duration_ms = u64::try_from(started_at.elapsed().as_millis()).unwrap_or(u64::MAX);
175        // Sum input_tokens across the batch for one combined event.
176        // Per-element emission would flood telemetry on large
177        // batches; aggregated count gives dashboards the same total
178        // at one event per call.
179        let aggregated = aggregate_usage(&result);
180        self.emit_end(ctx, aggregated.as_ref(), duration_ms, texts.len())
181            .await;
182        Ok(result)
183    }
184}
185
186fn aggregate_usage(embeddings: &[Embedding]) -> Option<EmbeddingUsage> {
187    let mut total: u32 = 0;
188    let mut any = false;
189    for e in embeddings {
190        if let Some(u) = e.usage {
191            total = total.saturating_add(u.input_tokens);
192            any = true;
193        }
194    }
195    any.then_some(EmbeddingUsage::new(total))
196}
197
198/// Adapter that bridges any [`CostCalculator`] (`ChatModel` pricing
199/// source) into the [`EmbeddingCostCalculator`] surface.
200///
201/// Embeddings only consume input tokens, so the adapter constructs
202/// a synthetic [`Usage`] with `input_tokens` populated from the
203/// embedder's [`EmbeddingUsage`] and delegates to the wrapped
204/// calculator. Operators with a single shared `entelix_policy::CostMeter`
205/// pricing table use this to charge embedding calls from the same
206/// source as model and tool calls — one pricing source, three cost
207/// surfaces, no drift.
208pub struct CostCalculatorAdapter {
209    inner: Arc<dyn CostCalculator>,
210}
211
212impl CostCalculatorAdapter {
213    /// Wrap the supplied calculator. The adapter forwards every
214    /// embed call as a synthetic `Usage` and lets the inner
215    /// calculator's pricing table do the lookup.
216    #[must_use]
217    pub const fn new(inner: Arc<dyn CostCalculator>) -> Self {
218        Self { inner }
219    }
220}
221
222#[async_trait]
223impl EmbeddingCostCalculator for CostCalculatorAdapter {
224    async fn compute_cost(
225        &self,
226        model: &str,
227        usage: &EmbeddingUsage,
228        ctx: &ExecutionContext,
229    ) -> Option<f64> {
230        let usage = Usage::new(usage.input_tokens, 0);
231        self.inner.compute_cost(model, &usage, ctx).await
232    }
233}
234
// Unit tests for the metered wrapper, the usage aggregation helper,
// and the chat-calculator adapter. `unwrap` and direct indexing are
// allowed here because a panic is the desired failure mode in tests.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Stub embedder that returns deterministic vectors and reports
    /// usage proportional to input length.
    struct StubEmbedder {
        dim: usize,
    }

    #[async_trait]
    impl Embedder for StubEmbedder {
        fn dimension(&self) -> usize {
            self.dim
        }
        async fn embed(&self, text: &str, _ctx: &ExecutionContext) -> Result<Embedding> {
            // Test inputs are tiny, so the byte length always fits.
            #[allow(clippy::cast_possible_truncation)]
            let tokens = text.len() as u32;
            Ok(Embedding::new(vec![0.0; self.dim]).with_usage(EmbeddingUsage::new(tokens)))
        }
    }

    /// Stub embedder that fails — used to verify the metered
    /// wrapper does NOT emit cost on the error branch.
    struct FailingEmbedder;

    #[async_trait]
    impl Embedder for FailingEmbedder {
        fn dimension(&self) -> usize {
            4
        }
        async fn embed(&self, _text: &str, _ctx: &ExecutionContext) -> Result<Embedding> {
            Err(entelix_core::Error::config("embedder down"))
        }
    }

    /// Counting calculator: returns a fixed cost and tracks calls.
    struct CountingCalculator {
        cost: f64,
        calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl EmbeddingCostCalculator for CountingCalculator {
        async fn compute_cost(
            &self,
            _model: &str,
            _usage: &EmbeddingUsage,
            _ctx: &ExecutionContext,
        ) -> Option<f64> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Some(self.cost)
        }
    }

    /// The wrapper must hand back the inner embedding untouched.
    #[tokio::test]
    async fn metered_embed_passes_through_inner_embedding() {
        let metered = MeteredEmbedder::new(StubEmbedder { dim: 8 }, "stub-model");
        let ctx = ExecutionContext::new();
        let out = metered.embed("hello", &ctx).await.unwrap();
        assert_eq!(out.vector.len(), 8);
        assert_eq!(out.usage.unwrap().input_tokens, 5);
    }

    /// A successful call with usage consults the calculator once.
    #[tokio::test]
    async fn metered_embed_invokes_calculator_on_success() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calc = Arc::new(CountingCalculator {
            cost: 0.0001,
            calls: Arc::clone(&calls),
        });
        let metered =
            MeteredEmbedder::new(StubEmbedder { dim: 4 }, "stub-model").with_cost_calculator(calc);
        let _ = metered
            .embed("hello", &ExecutionContext::new())
            .await
            .unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn metered_embed_skips_calculator_on_failure() {
        // F4: a failed inner call must NEVER trigger cost
        // computation — phantom charges break billing audits.
        let calls = Arc::new(AtomicUsize::new(0));
        let calc = Arc::new(CountingCalculator {
            cost: 0.99,
            calls: Arc::clone(&calls),
        });
        let metered =
            MeteredEmbedder::new(FailingEmbedder, "stub-model").with_cost_calculator(calc);
        let err = metered
            .embed("hi", &ExecutionContext::new())
            .await
            .unwrap_err();
        assert!(matches!(err, entelix_core::Error::Config(_)));
        assert_eq!(
            calls.load(Ordering::SeqCst),
            0,
            "cost calculator must not fire on the error branch"
        );
    }

    #[tokio::test]
    async fn metered_embed_batch_aggregates_usage_into_one_event() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calc = Arc::new(CountingCalculator {
            cost: 0.0,
            calls: Arc::clone(&calls),
        });
        let metered =
            MeteredEmbedder::new(StubEmbedder { dim: 2 }, "stub-model").with_cost_calculator(calc);
        let texts = vec!["a".to_owned(), "bb".to_owned(), "ccc".to_owned()];
        let out = metered
            .embed_batch(&texts, &ExecutionContext::new())
            .await
            .unwrap();
        assert_eq!(out.len(), 3);
        // Calculator should be called exactly once for the aggregate.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Direct coverage of the aggregation helper: the batch test
    /// above only counts calculator calls, so the sum itself, the
    /// "no usage anywhere" case, and saturation are pinned here.
    #[test]
    fn aggregate_usage_sums_reported_tokens_and_skips_missing_usage() {
        // Empty batch and batches without any usage yield `None`.
        assert!(aggregate_usage(&[]).is_none());
        assert!(aggregate_usage(&[Embedding::new(vec![0.0; 2])]).is_none());

        // Elements without usage are skipped, the rest are summed.
        let mixed = [
            Embedding::new(vec![0.0; 2]).with_usage(EmbeddingUsage::new(3)),
            Embedding::new(vec![0.0; 2]),
            Embedding::new(vec![0.0; 2]).with_usage(EmbeddingUsage::new(4)),
        ];
        assert_eq!(aggregate_usage(&mixed).map(|u| u.input_tokens), Some(7));

        // The sum saturates instead of wrapping.
        let huge = [
            Embedding::new(vec![0.0; 2]).with_usage(EmbeddingUsage::new(u32::MAX)),
            Embedding::new(vec![0.0; 2]).with_usage(EmbeddingUsage::new(1)),
        ];
        assert_eq!(
            aggregate_usage(&huge).map(|u| u.input_tokens),
            Some(u32::MAX)
        );
    }

    /// Stub `CostCalculator` (`ChatModel` surface) that records the
    /// `input_tokens` it sees so the adapter test can confirm the
    /// embedding usage round-tripped into the synthetic `Usage`.
    struct ChatStyleCalculator {
        rate_per_token: f64,
        observed_input_tokens: Arc<std::sync::Mutex<Vec<u32>>>,
    }

    #[async_trait]
    impl entelix_core::CostCalculator for ChatStyleCalculator {
        async fn compute_cost(
            &self,
            _model: &str,
            usage: &entelix_core::ir::Usage,
            _ctx: &ExecutionContext,
        ) -> Option<f64> {
            self.observed_input_tokens
                .lock()
                .unwrap()
                .push(usage.input_tokens);
            Some(self.rate_per_token * f64::from(usage.input_tokens))
        }
    }

    #[tokio::test]
    async fn cost_calculator_adapter_forwards_embedding_usage_as_synthetic_usage() {
        // The adapter is the bridge that lets one shared
        // PricingTable charge model, tool, AND embedding surfaces
        // without per-surface duplication. Verify the
        // `EmbeddingUsage::input_tokens` arrives intact at the
        // wrapped `ChatModel` calculator.
        let observed = Arc::new(std::sync::Mutex::new(Vec::<u32>::new()));
        let chat_calc = Arc::new(ChatStyleCalculator {
            rate_per_token: 0.0001,
            observed_input_tokens: Arc::clone(&observed),
        });
        let adapter = Arc::new(CostCalculatorAdapter::new(chat_calc));
        let metered = MeteredEmbedder::new(StubEmbedder { dim: 4 }, "text-embedding-3-small")
            .with_cost_calculator(adapter);
        let _ = metered
            .embed("hello world", &ExecutionContext::new())
            .await
            .unwrap();
        let saw = observed.lock().unwrap();
        assert_eq!(saw.len(), 1);
        assert_eq!(saw[0], 11, "stub embedder reports text len as input_tokens");
    }

    #[tokio::test]
    async fn metered_embed_skips_calculator_when_no_usage() {
        struct NoUsageEmbedder;
        #[async_trait]
        impl Embedder for NoUsageEmbedder {
            fn dimension(&self) -> usize {
                4
            }
            async fn embed(&self, _text: &str, _ctx: &ExecutionContext) -> Result<Embedding> {
                // No usage attached — local stub embedders.
                Ok(Embedding::new(vec![0.0; 4]))
            }
        }
        let calls = Arc::new(AtomicUsize::new(0));
        let calc = Arc::new(CountingCalculator {
            cost: 1.0,
            calls: Arc::clone(&calls),
        });
        let metered =
            MeteredEmbedder::new(NoUsageEmbedder, "no-usage-model").with_cost_calculator(calc);
        let _ = metered
            .embed("anything", &ExecutionContext::new())
            .await
            .unwrap();
        assert_eq!(
            calls.load(Ordering::SeqCst),
            0,
            "no usage → no cost computation (silent zero would mislead dashboards)"
        );
    }
}