use crate::generator::GeneratorInfo;
use crate::provider::CostInfo;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Shared, thread-safe (`Send + Sync`) callback used to book the cost of a completion.
///
/// It receives the computed [`CostInfo`] plus a JSON value and returns a boxed,
/// pinned `Send` future that the caller awaits. In this file the JSON argument
/// is always a clone of the completion's metadata.
pub type AsyncCostCallback = Arc<
dyn Fn(CostInfo, serde_json::Value) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync,
>;
/// Free-form JSON metadata attached to a completion and handed back to the
/// cost callback unchanged. `CompletionContext::is_byok` reads its `"isByok"` key.
pub type CompletionMeta = serde_json::Value;
/// Everything needed to run one completion and book its cost afterwards:
/// the generator description, caller-supplied metadata and the cost callback.
pub struct CompletionContext {
/// Generator used for the completion; `new` stores it after applying app attribution.
pub generator: GeneratorInfo,
/// Caller metadata; a clone is passed to the callback with every reported cost.
pub meta: CompletionMeta,
/// Callback invoked with the cost and metadata; kept private so it is only
/// reachable through the reporting methods in this module.
callback: AsyncCostCallback,
}
impl CompletionContext {
pub fn new(
generator: GeneratorInfo,
meta: CompletionMeta,
callback: AsyncCostCallback,
app_url: impl Into<String>,
app_title: impl Into<String>,
) -> Self {
let generator = generator.with_app_attribution(app_url, app_title);
Self {
generator,
meta,
callback,
}
}
pub fn is_byok(&self) -> bool {
self.meta
.get("isByok")
.and_then(|v| v.as_bool())
.unwrap_or(false)
}
pub async fn report_cost(&self, cost_info: CostInfo) {
let fut = (self.callback)(cost_info, self.meta.clone());
fut.await;
}
pub(crate) async fn cost_for_response(
&self,
response: &crate::provider::CompletionResponse,
) -> CostInfo {
cost_for_response(&self.generator, response).await
}
}
/// A streaming completion paired with what is needed to book its cost.
///
/// Cost is settled through `report_cost`, `cancel`, or — as a fallback — the
/// `Drop` implementation; `reject` opts out of booking altogether.
pub struct TrackedStream {
/// The wrapped provider stream.
inner: crate::provider::StreamingCompletion,
/// Cost callback cloned from the originating `CompletionContext`.
callback: AsyncCostCallback,
/// Metadata cloned from the originating `CompletionContext`.
meta: CompletionMeta,
/// Generator cloned from the originating `CompletionContext`; used to price the response.
generator: GeneratorInfo,
/// Set once cost has been handed to the callback (or `cancel` decided not to
/// book after a transport error); makes `Drop` a no-op.
cost_reported: bool,
/// Set by `reject`; makes `Drop` a no-op without booking anything.
rejected: bool,
}
impl TrackedStream {
    /// Wraps `inner`, cloning the callback, metadata and generator out of `ctx`.
    ///
    /// The new stream starts with neither cost reported nor rejection recorded.
    pub(crate) fn new(
        inner: crate::provider::StreamingCompletion,
        ctx: &CompletionContext,
    ) -> Self {
        let callback = ctx.callback.clone();
        let meta = ctx.meta.clone();
        let generator = ctx.generator.clone();
        Self {
            inner,
            callback,
            meta,
            generator,
            cost_reported: false,
            rejected: false,
        }
    }

    /// Forwards to the inner stream's `next_chunk`.
    pub async fn next_chunk(
        &mut self,
    ) -> Option<crate::error::Result<crate::provider::StreamChunk>> {
        self.inner.next_chunk().await
    }

    /// Pulls chunks until the inner stream ends, then returns the assembled response.
    ///
    /// The first chunk error is returned immediately. This method does not
    /// report cost.
    pub async fn collect(&mut self) -> crate::error::Result<crate::provider::CompletionResponse> {
        loop {
            match self.inner.next_chunk().await {
                Some(Ok(_)) => continue,
                Some(Err(err)) => return Err(err),
                None => return Ok(self.inner.to_response()),
            }
        }
    }

    /// Prices `response` and passes the result to the callback, at most once.
    ///
    /// A repeated call only logs a warning. The reported flag is raised after
    /// the callback future has completed.
    pub async fn report_cost(&mut self, response: &crate::provider::CompletionResponse) {
        if !self.cost_reported {
            let priced = cost_for_response(&self.generator, response).await;
            let meta_snapshot = self.meta.clone();
            (self.callback)(priced, meta_snapshot).await;
            self.cost_reported = true;
        } else {
            tracing::warn!("report_cost called more than once; ignoring the repeat");
        }
    }

    /// Consumes the stream and settles its cost before returning.
    ///
    /// Buffered chunks are drained first. If the inner stream reports an error,
    /// nothing is booked and a warning is logged; otherwise the response built
    /// so far is priced and handed to the callback. Either way the reported
    /// flag is raised so the `Drop` fallback stays silent.
    pub async fn cancel(mut self) {
        self.inner.drain_buffered();
        if !self.inner.errored() {
            let partial = self.inner.to_response();
            let priced = cost_for_response(&self.generator, &partial).await;
            (self.callback)(priced, self.meta.clone()).await;
        } else {
            tracing::warn!(
                "TrackedStream for {} cancelled after a transport error; no cost booked.",
                self.inner.id()
            );
        }
        self.cost_reported = true;
    }

    /// Consumes the stream and marks it rejected, so dropping it books no cost.
    pub fn reject(mut self) {
        self.rejected = true;
        // `self` is dropped here; `Drop` sees the flag and returns early.
        drop(self);
    }

    /// Forwards to the inner stream's `is_finished`.
    pub fn is_finished(&self) -> bool {
        self.inner.is_finished()
    }

    /// Forwards to the inner stream's `accumulated` (presumably the text
    /// received so far — confirm against `StreamingCompletion`).
    pub fn accumulated(&self) -> &str {
        self.inner.accumulated()
    }
}
impl Drop for TrackedStream {
    /// Fallback settlement for streams dropped without `report_cost`, `cancel`
    /// or `reject`.
    ///
    /// Nothing happens if cost was already reported or the stream was rejected.
    /// Otherwise buffered chunks are drained; an errored stream books nothing.
    /// A healthy stream has its cost settled on a task spawned on the current
    /// tokio runtime, or logged as lost when no runtime is available.
    fn drop(&mut self) {
        let already_handled = self.cost_reported || self.rejected;
        if already_handled {
            return;
        }

        self.inner.drain_buffered();
        if self.inner.errored() {
            tracing::warn!(
                "TrackedStream for {} ended in a transport error; no cost booked (failed generation).",
                self.inner.id()
            );
            return;
        }

        // Snapshot everything the detached task needs; `self` cannot outlive `drop`.
        let response = self.inner.to_response();
        let callback = self.callback.clone();
        let meta = self.meta.clone();
        let generator = self.generator.clone();

        match tokio::runtime::Handle::try_current() {
            Err(_) => {
                tracing::error!(
                    "TrackedStream for {} dropped un-reported outside a tokio runtime: cost CANNOT be settled and is LOST. Use cancel().await or report_cost().await.",
                    response.id
                );
            }
            Ok(handle) => {
                tracing::debug!(
                    "TrackedStream for {} dropped without report_cost()/cancel()/reject(); settling cost on a detached task",
                    response.id
                );
                handle.spawn(async move {
                    // The guard logs an error if this task is dropped before booking completes.
                    let mut guard = LostCostGuard::new(response.id.clone());
                    let priced = cost_for_response(&generator, &response).await;
                    (callback)(priced, meta).await;
                    guard.settled = true;
                });
            }
        }
    }
}
/// Marker carried by a cost-settlement task: records whether booking finished
/// for a given response.
struct LostCostGuard {
    /// Id of the response whose cost is being settled.
    response_id: String,
    /// `true` once the cost callback has completed.
    settled: bool,
}

impl LostCostGuard {
    /// Creates a guard for `response_id` in the not-yet-settled state.
    fn new(response_id: String) -> Self {
        let settled = false;
        Self {
            settled,
            response_id,
        }
    }
}
impl Drop for LostCostGuard {
    /// Logs an error when the guard is dropped before `settled` was set,
    /// i.e. the settlement did not run to completion.
    fn drop(&mut self) {
        if self.settled {
            return;
        }
        tracing::error!(
            "Cost settle task for {} was cancelled before booking: cost is LOST (likely runtime shutdown). Use cancel().await or report_cost().await for a guarantee.",
            self.response_id
        );
    }
}
/// Computes the [`CostInfo`] for `response` as produced by `generator`.
///
/// When the response carries usage data, the generator's provider prices it
/// directly via `cost_of`. Otherwise the provider's `resolve_post_stream` is
/// awaited with the response id, the generator's auth, the optional token
/// price and the shared HTTP client. The outcome is then converted with the
/// response's model and id.
pub(crate) async fn cost_for_response(
    generator: &GeneratorInfo,
    response: &crate::provider::CompletionResponse,
) -> CostInfo {
    let price = generator.token_price.as_ref();

    let outcome = if let Some(usage) = &response.usage {
        generator.provider.cost_of(usage.clone(), price)
    } else {
        // No usage in the response: let the provider resolve it after the fact.
        let lookup = crate::provider::PostStreamCtx {
            client: reqwest_client(),
            generation_id: &response.id,
            auth: &generator.auth,
            price,
        };
        generator.provider.resolve_post_stream(lookup).await
    };

    let model = response.model.clone();
    let response_id = response.id.clone();
    outcome.into_cost_info(model, response_id)
}
/// Returns the `reqwest::Client` shared by every caller, creating it with
/// `reqwest::Client::new` on first use.
fn reqwest_client() -> &'static reqwest::Client {
    use std::sync::OnceLock;

    static SHARED: OnceLock<reqwest::Client> = OnceLock::new();
    SHARED.get_or_init(reqwest::Client::new)
}
impl std::fmt::Debug for CompletionContext {
    /// Formats the generator name, model, metadata and BYOK flag.
    ///
    /// The callback is left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let byok = self.is_byok();
        let mut out = f.debug_struct("CompletionContext");
        out.field("generator", &self.generator.name);
        out.field("model", &self.generator.model);
        out.field("meta", &self.meta);
        out.field("is_byok", &byok);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::MiniLLMError;
    use crate::provider::{CostResolution, StreamChunk, StreamingCompletion, Usage};
    use std::sync::Mutex;

    /// Every `(cost, meta)` pair the capturing callback has received, in order.
    type CaptureLog = Arc<Mutex<Vec<(CostInfo, serde_json::Value)>>>;

    /// Builds a cost callback that records each invocation, plus the log it writes to.
    fn capturing_callback() -> (AsyncCostCallback, CaptureLog) {
        let log: CaptureLog = Arc::new(Mutex::new(Vec::new()));
        let sink = log.clone();
        let callback: AsyncCostCallback = Arc::new(move |cost, meta| {
            let sink = sink.clone();
            Box::pin(async move {
                sink.lock().unwrap().push((cost, meta));
            })
        });
        (callback, log)
    }

    /// Builds an OpenRouter-backed context whose callback records into the returned log.
    fn capturing_context(meta: serde_json::Value) -> (CompletionContext, CaptureLog) {
        let (callback, log) = capturing_callback();
        let generator = GeneratorInfo::new("Test", "https://example.test/v1", "test-model")
            .with_provider(std::sync::Arc::new(crate::provider::OpenRouterProvider));
        let ctx = CompletionContext::new(generator, meta, callback, "https://app", "App");
        (ctx, log)
    }

    /// Yields a few times so tasks spawned by `Drop` get a chance to run
    /// before the test inspects the log.
    async fn let_detached_tasks_run() {
        for _ in 0..3 {
            tokio::task::yield_now().await;
        }
    }

    #[tokio::test]
    async fn report_cost_passes_cost_and_meta_through() {
        let (ctx, log) = capturing_context(serde_json::json!({"userId": "u1"}));
        let cost = CostInfo {
            cost: 0.001,
            model: "test-model".into(),
            response_id: "gen-1".into(),
            ..Default::default()
        };
        ctx.report_cost(cost).await;
        let captured = log.lock().unwrap();
        assert_eq!(captured.len(), 1);
        assert!((captured[0].0.cost - 0.001).abs() < 1e-9);
        assert_eq!(captured[0].1["userId"], "u1");
    }

    #[test]
    fn is_byok_reads_metadata() {
        let (byok, _) = capturing_context(serde_json::json!({"isByok": true}));
        assert!(byok.is_byok());
        let (not_byok, _) = capturing_context(serde_json::json!({}));
        assert!(!not_byok.is_byok());
    }

    #[tokio::test]
    async fn collect_then_report_uses_typed_usage_and_sums_byok() {
        let (ctx, log) = capturing_context(serde_json::json!({}));
        let (stream, tx) = StreamingCompletion::from_channel("test-model", "gen-1", true);
        let mut tracked = TrackedStream::new(stream, &ctx);
        tx.send(Ok(StreamChunk::content("hi"))).await.unwrap();
        tx.send(Ok(StreamChunk {
            finish_reason: Some("stop".into()),
            usage: Some(Usage {
                cost: Some(0.001),
                upstream_inference_cost: Some(0.009),
                uncached_input_tokens: 5,
                completion_tokens: 2,
                ..Default::default()
            }),
            ..Default::default()
        }))
        .await
        .unwrap();
        drop(tx);
        let resp = tracked.collect().await.unwrap();
        assert_eq!(resp.content, "hi");
        assert!(log.lock().unwrap().is_empty(), "collect must not book cost");
        tracked.report_cost(&resp).await;
        let captured = log.lock().unwrap();
        assert_eq!(captured.len(), 1);
        let cost = &captured[0].0;
        assert!((cost.cost - 0.010).abs() < 1e-9, "cost was {}", cost.cost);
        assert_eq!(cost.total_tokens, 7);
        assert_eq!(cost.resolution, CostResolution::Resolved);
    }

    #[tokio::test]
    async fn report_cost_marks_unknown_when_no_usage_and_no_id() {
        let (ctx, log) = capturing_context(serde_json::json!({}));
        let (stream, tx) = StreamingCompletion::from_channel("test-model", "", true);
        let mut tracked = TrackedStream::new(stream, &ctx);
        tx.send(Ok(StreamChunk::content("hi"))).await.unwrap();
        drop(tx);
        let resp = tracked.collect().await.unwrap();
        tracked.report_cost(&resp).await;
        let captured = log.lock().unwrap();
        assert_eq!(captured.len(), 1);
        assert_eq!(captured[0].0.resolution, CostResolution::Unknown);
        assert_eq!(captured[0].0.cost, 0.0);
    }

    /// Returns a stream that has been fully collected (usage cost 0.5) but not yet reported.
    async fn drained_stream(ctx: &CompletionContext) -> TrackedStream {
        let (stream, tx) = StreamingCompletion::from_channel("test-model", "gen-1", true);
        let mut tracked = TrackedStream::new(stream, ctx);
        tx.send(Ok(StreamChunk::content("hi"))).await.unwrap();
        tx.send(Ok(StreamChunk {
            finish_reason: Some("stop".into()),
            usage: Some(Usage {
                cost: Some(0.5),
                ..Default::default()
            }),
            ..Default::default()
        }))
        .await
        .unwrap();
        drop(tx);
        let _ = tracked.collect().await.unwrap();
        tracked
    }

    #[tokio::test]
    async fn explicit_reject_books_nothing() {
        let (ctx, log) = capturing_context(serde_json::json!({}));
        drained_stream(&ctx).await.reject();
        let_detached_tasks_run().await;
        assert!(
            log.lock().unwrap().is_empty(),
            "an explicitly rejected stream must not book cost"
        );
    }

    #[tokio::test]
    async fn drained_then_dropped_without_report_still_books_cost() {
        let (ctx, log) = capturing_context(serde_json::json!({}));
        {
            let _tracked = drained_stream(&ctx).await;
        }
        let_detached_tasks_run().await;
        let captured = log.lock().unwrap();
        assert_eq!(captured.len(), 1, "a forgotten report must still book cost");
        assert!((captured[0].0.cost - 0.5).abs() < 1e-9);
    }

    #[tokio::test]
    async fn genuine_cancellation_books_cost() {
        let (ctx, log) = capturing_context(serde_json::json!({}));
        // `_tx` stays alive so the channel is still open when the stream is dropped.
        let (stream, _tx) = StreamingCompletion::from_channel("test-model", "", true);
        let tracked = TrackedStream::new(stream, &ctx);
        assert!(!tracked.is_finished(), "precondition: not collected");
        drop(tracked);
        let_detached_tasks_run().await;
        let captured = log.lock().unwrap();
        assert_eq!(captured.len(), 1, "genuine cancel must book cost");
        assert_eq!(captured[0].0.resolution, CostResolution::Unknown);
    }

    #[tokio::test]
    async fn explicit_cancel_settles_cost_synchronously_and_suppresses_drop() {
        let (ctx, log) = capturing_context(serde_json::json!({}));
        let (mut stream_holder, tx) =
            StreamingCompletion::from_channel("test-model", "gen-1", true);
        tx.send(Ok(StreamChunk::content("partial"))).await.unwrap();
        tx.send(Ok(StreamChunk {
            usage: Some(Usage {
                cost: Some(0.02),
                ..Default::default()
            }),
            ..Default::default()
        }))
        .await
        .unwrap();
        let _ = stream_holder.next_chunk().await;
        let _ = stream_holder.next_chunk().await;
        let tracked = TrackedStream::new(stream_holder, &ctx);
        tracked.cancel().await;

        // Synchronous part: the cost is already booked when `cancel` returns.
        // The lock guard is a temporary, so it is released before yielding below.
        assert_eq!(
            log.lock().unwrap().len(),
            1,
            "cancel must settle before returning"
        );

        // Suppression part: give any task spawned by `Drop` a chance to run;
        // a second booking would only become visible after yielding.
        let_detached_tasks_run().await;

        let captured = log.lock().unwrap();
        assert_eq!(captured.len(), 1, "cancel reports exactly once");
        assert_eq!(captured[0].0.resolution, CostResolution::Resolved);
        assert!((captured[0].0.cost - 0.02).abs() < 1e-9);
    }

    #[tokio::test]
    async fn report_cost_then_drop_books_exactly_once() {
        let (ctx, log) = capturing_context(serde_json::json!({}));
        {
            let mut tracked = drained_stream(&ctx).await;
            let resp = tracked.inner.to_response();
            tracked.report_cost(&resp).await;
        }
        let_detached_tasks_run().await;
        assert_eq!(
            log.lock().unwrap().len(),
            1,
            "report_cost then drop must book exactly once"
        );
    }

    #[tokio::test]
    async fn errored_stream_books_nothing_on_drop() {
        let (ctx, log) = capturing_context(serde_json::json!({}));
        let (stream, tx) = StreamingCompletion::from_channel("test-model", "gen-1", true);
        let mut tracked = TrackedStream::new(stream, &ctx);
        tx.send(Ok(StreamChunk::content("partial"))).await.unwrap();
        tx.send(Err(MiniLLMError::Timeout)).await.unwrap();
        let err = tracked.collect().await;
        assert!(err.is_err(), "stream surfaces the transport error");
        drop(tracked);
        let_detached_tasks_run().await;
        assert!(
            log.lock().unwrap().is_empty(),
            "a failed stream must not book cost"
        );
    }

    #[tokio::test]
    async fn buffered_error_left_undrained_then_dropped_books_nothing() {
        let (ctx, log) = capturing_context(serde_json::json!({}));
        let (stream, tx) = StreamingCompletion::from_channel("test-model", "", true);
        let tracked = TrackedStream::new(stream, &ctx);
        tx.send(Ok(StreamChunk::content("partial"))).await.unwrap();
        tx.send(Err(MiniLLMError::Stream("in-band provider error".into())))
            .await
            .unwrap();
        drop(tracked);
        let_detached_tasks_run().await;
        assert!(
            log.lock().unwrap().is_empty(),
            "a buffered terminal error must make Drop book nothing"
        );
    }

    #[tokio::test]
    async fn anthropic_split_usage_books_correct_tokens_end_to_end() {
        let (callback, log) = capturing_callback();
        let generator = GeneratorInfo::new("Test", "https://example.test", "claude-haiku-4-5")
            .with_provider(std::sync::Arc::new(crate::provider::AnthropicProvider))
            .with_token_price(crate::provider::TokenPrice::new(1.0, 5.0));
        let ctx = CompletionContext::new(generator, serde_json::json!({}), callback, "u", "a");
        let (stream, tx) = StreamingCompletion::from_channel("claude-haiku-4-5", "msg_1", true);
        let mut tracked = TrackedStream::new(stream, &ctx);
        // First chunk carries only the input-token usage.
        tx.send(Ok(StreamChunk {
            id: Some("msg_1".into()),
            usage: Some(Usage {
                uncached_input_tokens: 1_000_000,
                ..Default::default()
            }),
            ..Default::default()
        }))
        .await
        .unwrap();
        tx.send(Ok(StreamChunk::content("hi"))).await.unwrap();
        // Final chunk carries only the output-token usage.
        tx.send(Ok(StreamChunk {
            finish_reason: Some("end_turn".into()),
            usage: Some(Usage {
                completion_tokens: 1_000_000,
                ..Default::default()
            }),
            ..Default::default()
        }))
        .await
        .unwrap();
        drop(tx);
        let resp = tracked.collect().await.unwrap();
        tracked.report_cost(&resp).await;
        let captured = log.lock().unwrap();
        assert_eq!(captured.len(), 1);
        let cost = &captured[0].0;
        assert_eq!(
            cost.prompt_tokens, 1_000_000,
            "input merged from message_start"
        );
        assert_eq!(
            cost.completion_tokens, 1_000_000,
            "output from message_delta"
        );
        assert_eq!(cost.resolution, CostResolution::Resolved);
        assert!((cost.cost - 6.0).abs() < 1e-9, "got {}", cost.cost);
    }
}