1use std::sync::Arc;
24use std::time::Instant;
25
26use async_trait::async_trait;
27use entelix_core::context::ExecutionContext;
28use entelix_core::cost::CostCalculator;
29use entelix_core::error::Result;
30use entelix_core::ir::Usage;
31
32use crate::traits::{Embedder, Embedding, EmbeddingUsage};
33
34#[async_trait]
47pub trait EmbeddingCostCalculator: Send + Sync + 'static {
48 async fn compute_cost(
53 &self,
54 model: &str,
55 usage: &EmbeddingUsage,
56 ctx: &ExecutionContext,
57 ) -> Option<f64>;
58}
59
60pub struct MeteredEmbedder<E>
67where
68 E: Embedder,
69{
70 inner: Arc<E>,
71 model: Arc<str>,
72 cost_calculator: Option<Arc<dyn EmbeddingCostCalculator>>,
73}
74
75impl<E> MeteredEmbedder<E>
76where
77 E: Embedder,
78{
79 pub fn new(inner: E, model: impl Into<Arc<str>>) -> Self {
84 Self {
85 inner: Arc::new(inner),
86 model: model.into(),
87 cost_calculator: None,
88 }
89 }
90
91 pub fn from_arc(inner: Arc<E>, model: impl Into<Arc<str>>) -> Self {
94 Self {
95 inner,
96 model: model.into(),
97 cost_calculator: None,
98 }
99 }
100
101 #[must_use]
106 pub fn with_cost_calculator(mut self, calculator: Arc<dyn EmbeddingCostCalculator>) -> Self {
107 self.cost_calculator = Some(calculator);
108 self
109 }
110
111 pub fn model(&self) -> &str {
115 &self.model
116 }
117
118 async fn emit_end(
121 &self,
122 ctx: &ExecutionContext,
123 usage: Option<&EmbeddingUsage>,
124 duration_ms: u64,
125 batch_size: usize,
126 ) {
127 let cost = match (usage, &self.cost_calculator) {
128 (Some(u), Some(calc)) => calc.compute_cost(&self.model, u, ctx).await,
129 _ => None,
130 };
131 let input_tokens = usage.map(|u| u.input_tokens);
132 tracing::event!(
133 target: "gen_ai",
134 tracing::Level::INFO,
135 gen_ai.system = "embedder",
136 gen_ai.operation.name = "embed",
137 gen_ai.embedding.model = %self.model,
138 gen_ai.embedding.batch_size = batch_size,
139 gen_ai.usage.input_tokens = input_tokens,
140 gen_ai.embedding.cost = cost,
141 duration_ms,
142 entelix.tenant_id = %ctx.tenant_id(),
143 entelix.run_id = ctx.run_id(),
144 "gen_ai.embedding.end"
145 );
146 }
147}
148
149#[async_trait]
150impl<E> Embedder for MeteredEmbedder<E>
151where
152 E: Embedder,
153{
154 fn dimension(&self) -> usize {
155 self.inner.dimension()
156 }
157
158 async fn embed(&self, text: &str, ctx: &ExecutionContext) -> Result<Embedding> {
159 let started_at = Instant::now();
160 let result = self.inner.embed(text, ctx).await?;
161 let duration_ms = u64::try_from(started_at.elapsed().as_millis()).unwrap_or(u64::MAX);
162 self.emit_end(ctx, result.usage.as_ref(), duration_ms, 1)
163 .await;
164 Ok(result)
165 }
166
167 async fn embed_batch(
168 &self,
169 texts: &[String],
170 ctx: &ExecutionContext,
171 ) -> Result<Vec<Embedding>> {
172 let started_at = Instant::now();
173 let result = self.inner.embed_batch(texts, ctx).await?;
174 let duration_ms = u64::try_from(started_at.elapsed().as_millis()).unwrap_or(u64::MAX);
175 let aggregated = aggregate_usage(&result);
180 self.emit_end(ctx, aggregated.as_ref(), duration_ms, texts.len())
181 .await;
182 Ok(result)
183 }
184}
185
186fn aggregate_usage(embeddings: &[Embedding]) -> Option<EmbeddingUsage> {
187 let mut total: u32 = 0;
188 let mut any = false;
189 for e in embeddings {
190 if let Some(u) = e.usage {
191 total = total.saturating_add(u.input_tokens);
192 any = true;
193 }
194 }
195 any.then_some(EmbeddingUsage::new(total))
196}
197
198pub struct CostCalculatorAdapter {
209 inner: Arc<dyn CostCalculator>,
210}
211
212impl CostCalculatorAdapter {
213 #[must_use]
217 pub const fn new(inner: Arc<dyn CostCalculator>) -> Self {
218 Self { inner }
219 }
220}
221
222#[async_trait]
223impl EmbeddingCostCalculator for CostCalculatorAdapter {
224 async fn compute_cost(
225 &self,
226 model: &str,
227 usage: &EmbeddingUsage,
228 ctx: &ExecutionContext,
229 ) -> Option<f64> {
230 let usage = Usage::new(usage.input_tokens, 0);
231 self.inner.compute_cost(model, &usage, ctx).await
232 }
233}
234
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Embedder double: returns a zero vector of `dim` components and reports
    /// the byte length of the input as its token count.
    struct StubEmbedder {
        dim: usize,
    }

    #[async_trait]
    impl Embedder for StubEmbedder {
        fn dimension(&self) -> usize {
            self.dim
        }

        async fn embed(&self, text: &str, _ctx: &ExecutionContext) -> Result<Embedding> {
            #[allow(clippy::cast_possible_truncation)]
            let tokens = text.len() as u32;
            let usage = EmbeddingUsage::new(tokens);
            Ok(Embedding::new(vec![0.0; self.dim]).with_usage(usage))
        }
    }

    /// Embedder double whose `embed` always fails with a config error.
    struct FailingEmbedder;

    #[async_trait]
    impl Embedder for FailingEmbedder {
        fn dimension(&self) -> usize {
            4
        }

        async fn embed(&self, _text: &str, _ctx: &ExecutionContext) -> Result<Embedding> {
            Err(entelix_core::Error::config("embedder down"))
        }
    }

    /// Calculator double that quotes a fixed cost and counts its invocations.
    struct CountingCalculator {
        cost: f64,
        calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl EmbeddingCostCalculator for CountingCalculator {
        async fn compute_cost(
            &self,
            _model: &str,
            _usage: &EmbeddingUsage,
            _ctx: &ExecutionContext,
        ) -> Option<f64> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Some(self.cost)
        }
    }

    /// Builds a [`CountingCalculator`] quoting `cost`, paired with a handle
    /// on its invocation counter.
    fn counting_calculator(cost: f64) -> (Arc<CountingCalculator>, Arc<AtomicUsize>) {
        let calls = Arc::new(AtomicUsize::new(0));
        let calculator = Arc::new(CountingCalculator {
            cost,
            calls: Arc::clone(&calls),
        });
        (calculator, calls)
    }

    #[tokio::test]
    async fn metered_embed_passes_through_inner_embedding() {
        let ctx = ExecutionContext::new();
        let metered = MeteredEmbedder::new(StubEmbedder { dim: 8 }, "stub-model");
        let embedding = metered.embed("hello", &ctx).await.unwrap();
        assert_eq!(embedding.vector.len(), 8);
        assert_eq!(embedding.usage.unwrap().input_tokens, 5);
    }

    #[tokio::test]
    async fn metered_embed_invokes_calculator_on_success() {
        let (calculator, calls) = counting_calculator(0.0001);
        let metered = MeteredEmbedder::new(StubEmbedder { dim: 4 }, "stub-model")
            .with_cost_calculator(calculator);
        let ctx = ExecutionContext::new();
        let _ = metered.embed("hello", &ctx).await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn metered_embed_skips_calculator_on_failure() {
        let (calculator, calls) = counting_calculator(0.99);
        let metered =
            MeteredEmbedder::new(FailingEmbedder, "stub-model").with_cost_calculator(calculator);
        let ctx = ExecutionContext::new();
        let err = metered.embed("hi", &ctx).await.unwrap_err();
        assert!(matches!(err, entelix_core::Error::Config(_)));
        assert_eq!(
            calls.load(Ordering::SeqCst),
            0,
            "cost calculator must not fire on the error branch"
        );
    }

    #[tokio::test]
    async fn metered_embed_batch_aggregates_usage_into_one_event() {
        let (calculator, calls) = counting_calculator(0.0);
        let metered = MeteredEmbedder::new(StubEmbedder { dim: 2 }, "stub-model")
            .with_cost_calculator(calculator);
        let texts = vec!["a".to_owned(), "bb".to_owned(), "ccc".to_owned()];
        let ctx = ExecutionContext::new();
        let embeddings = metered.embed_batch(&texts, &ctx).await.unwrap();
        assert_eq!(embeddings.len(), 3);
        // One calculator call for the whole batch, not one per text.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// General-purpose calculator double that records every input-token count
    /// it is asked to price and charges a flat per-token rate.
    struct ChatStyleCalculator {
        rate_per_token: f64,
        observed_input_tokens: Arc<std::sync::Mutex<Vec<u32>>>,
    }

    #[async_trait]
    impl entelix_core::CostCalculator for ChatStyleCalculator {
        async fn compute_cost(
            &self,
            _model: &str,
            usage: &entelix_core::ir::Usage,
            _ctx: &ExecutionContext,
        ) -> Option<f64> {
            let tokens = usage.input_tokens;
            self.observed_input_tokens.lock().unwrap().push(tokens);
            Some(self.rate_per_token * f64::from(tokens))
        }
    }

    #[tokio::test]
    async fn cost_calculator_adapter_forwards_embedding_usage_as_synthetic_usage() {
        let observed = Arc::new(std::sync::Mutex::new(Vec::<u32>::new()));
        let chat_calc = Arc::new(ChatStyleCalculator {
            rate_per_token: 0.0001,
            observed_input_tokens: Arc::clone(&observed),
        });
        let adapter = Arc::new(CostCalculatorAdapter::new(chat_calc));
        let metered = MeteredEmbedder::new(StubEmbedder { dim: 4 }, "text-embedding-3-small")
            .with_cost_calculator(adapter);
        let ctx = ExecutionContext::new();
        let _ = metered.embed("hello world", &ctx).await.unwrap();
        let saw = observed.lock().unwrap();
        assert_eq!(saw.len(), 1);
        assert_eq!(saw[0], 11, "stub embedder reports text len as input_tokens");
    }

    #[tokio::test]
    async fn metered_embed_skips_calculator_when_no_usage() {
        /// Embedder double that succeeds without attaching any usage.
        struct NoUsageEmbedder;

        #[async_trait]
        impl Embedder for NoUsageEmbedder {
            fn dimension(&self) -> usize {
                4
            }

            async fn embed(&self, _text: &str, _ctx: &ExecutionContext) -> Result<Embedding> {
                Ok(Embedding::new(vec![0.0; 4]))
            }
        }

        let (calculator, calls) = counting_calculator(1.0);
        let metered = MeteredEmbedder::new(NoUsageEmbedder, "no-usage-model")
            .with_cost_calculator(calculator);
        let ctx = ExecutionContext::new();
        let _ = metered.embed("anything", &ctx).await.unwrap();
        assert_eq!(
            calls.load(Ordering::SeqCst),
            0,
            "no usage → no cost computation (silent zero would mislead dashboards)"
        );
    }
}