use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use futures::{Stream, StreamExt};
use rig::client::{CompletionClient, ProviderClient};
use rig::completion::{CompletionError, CompletionModel, GetTokenUsage, Usage};
use rig::providers::gemini;
use rig::providers::gemini::completion::gemini_api_types::{
AdditionalParameters, GenerationConfig, ThinkingConfig,
};
use rig::streaming::{StreamedAssistantContent, StreamingCompletionResponse};
/// Gemini model identifier, used both for the streaming completion request and for the
/// `countTokens` endpoint URL.
const MODEL: &str = "gemini-2.5-flash";
/// Number of visible (text + reasoning) characters a `Disrupt` stream forwards before the
/// configured disruption is injected.
const DISRUPT_AFTER_CHARS: usize = 150;
/// Maximum time to wait for the next stream item before the read is treated as a stall.
const READ_TIMEOUT: Duration = Duration::from_secs(8);
/// The kind of mid-stream failure a `Disrupt` wrapper injects once its character
/// threshold has been reached.
#[derive(Clone, Copy, Debug)]
enum Disruption {
/// No injected failure; the inner stream is forwarded untouched.
None,
/// Cancel the inner stream and end the wrapper stream (`Poll::Ready(None)`).
ManualKill,
/// Yield a `CompletionError::ProviderError` as if the transport had dropped.
TransportError,
/// Stop producing items and report `Poll::Pending` from then on.
Stall,
}
/// Stream adapter around a `StreamingCompletionResponse` that counts the visible characters
/// it forwards and injects a [`Disruption`] once a threshold is reached.
struct Disrupt<R>
where
R: Clone + Unpin + GetTokenUsage,
{
/// The wrapped provider stream whose items are forwarded.
inner: StreamingCompletionResponse<R>,
/// Which failure to inject when the threshold is reached.
mode: Disruption,
/// Visible-character threshold at which the disruption fires.
after_chars: usize,
/// Running total of visible characters forwarded so far.
seen_chars: usize,
/// Set once the disruption has been injected, so it is injected at most once.
fired: bool,
}
impl<R> Disrupt<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    /// Wraps `inner` so that `mode` is injected after `after_chars` visible characters.
    ///
    /// For [`Disruption::None`] the supplied threshold is ignored and replaced with
    /// `usize::MAX`, so the threshold check is effectively never satisfied.
    fn new(inner: StreamingCompletionResponse<R>, mode: Disruption, after_chars: usize) -> Self {
        let threshold = if matches!(mode, Disruption::None) {
            usize::MAX
        } else {
            after_chars
        };

        Self {
            inner,
            mode,
            after_chars: threshold,
            seen_chars: 0,
            fired: false,
        }
    }
}
/// Forwards items from the wrapped response until `after_chars` visible characters have been
/// seen, then injects the configured [`Disruption`] exactly once.
///
/// After the disruption has fired the wrapper stays in a terminal state instead of resuming
/// the inner stream:
/// * `Stall` keeps returning `Poll::Pending`,
/// * `ManualKill` and `TransportError` keep returning `Poll::Ready(None)`.
impl<R> Stream for Disrupt<R>
where
    R: Clone + Unpin + GetTokenUsage,
{
    type Item = Result<StreamedAssistantContent<R>, CompletionError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        if this.fired {
            match this.mode {
                // Intentionally no waker registration: this simulates a half-open connection,
                // so the consumer must rely on its own read timeout to make progress.
                Disruption::Stall => return Poll::Pending,
                // A killed or dropped stream must not come back to life on a later poll.
                Disruption::ManualKill | Disruption::TransportError => {
                    return Poll::Ready(None);
                }
                // Nothing was injected; keep forwarding.
                Disruption::None => {}
            }
        } else if this.seen_chars >= this.after_chars {
            this.fired = true;
            match this.mode {
                Disruption::ManualKill => {
                    this.inner.cancel();
                    return Poll::Ready(None);
                }
                Disruption::TransportError => {
                    return Poll::Ready(Some(Err(CompletionError::ProviderError(
                        "injected mid-stream transport drop".to_string(),
                    ))));
                }
                // See the note above: no waker is registered on purpose.
                Disruption::Stall => return Poll::Pending,
                Disruption::None => {}
            }
        }

        match Pin::new(&mut this.inner).poll_next(cx) {
            Poll::Ready(Some(Ok(item))) => {
                // Saturate rather than overflow; the counter is only compared to a threshold.
                this.seen_chars = this.seen_chars.saturating_add(visible_len(&item));
                Poll::Ready(Some(Ok(item)))
            }
            other => other,
        }
    }
}
/// Returns how many Unicode scalar values of text or reasoning `item` carries.
///
/// Every other kind of streamed content counts as zero.
fn visible_len<R>(item: &StreamedAssistantContent<R>) -> usize {
    fn char_count(s: &str) -> usize {
        s.chars().count()
    }

    match item {
        StreamedAssistantContent::Text(chunk) => char_count(&chunk.text),
        StreamedAssistantContent::ReasoningDelta { reasoning, .. } => char_count(reasoning),
        StreamedAssistantContent::Reasoning(block) => char_count(&block.display_text()),
        _ => 0,
    }
}
/// How the token usage in a `Report` was obtained.
enum Outcome {
/// Usage reported by the stream's `Final` item.
Clean(Usage),
/// Usage reconstructed with `count_tokens` because no authoritative usage arrived;
/// `reason` describes why the stream stopped.
Estimated { usage: Usage, reason: String },
}
/// Result of draining one scenario's stream.
struct Report {
/// Human-readable scenario name.
label: &'static str,
/// Number of characters of text and reasoning collected before the stream stopped.
output_chars: usize,
/// Token usage and whether it is authoritative or estimated.
outcome: Outcome,
}
/// Drains `stream`, collecting text and reasoning, and returns a [`Report`] with token usage.
///
/// Each read is bounded by `READ_TIMEOUT`. Reading stops on a timeout, on a stream error, or
/// when the stream ends. If a `Final` item carrying usage values was seen, that usage is
/// reported as [`Outcome::Clean`]. Otherwise usage is estimated by calling `count_tokens` on
/// `prompt_text` and on the collected output, and reported as [`Outcome::Estimated`]
/// together with the reason the stream stopped.
///
/// # Errors
///
/// Returns an error only if one of the `count_tokens` calls fails.
async fn drain_with_accounting<S, R>(
    label: &'static str,
    mut stream: S,
    http: &reqwest::Client,
    api_key: &str,
    prompt_text: &str,
) -> anyhow::Result<Report>
where
    S: Stream<Item = Result<StreamedAssistantContent<R>, CompletionError>> + Unpin,
    R: Clone + Unpin + GetTokenUsage,
{
    let mut collected = String::new();
    let mut final_usage: Option<Usage> = None;

    // The loop evaluates to the reason reading stopped (None for a clean finish).
    let stop_reason: Option<String> = loop {
        let next = match tokio::time::timeout(READ_TIMEOUT, stream.next()).await {
            Ok(next) => next,
            Err(_elapsed) => break Some(format!("stall: no data within {:?}", READ_TIMEOUT)),
        };

        let item = match next {
            Some(Ok(item)) => item,
            Some(Err(err)) => break Some(format!("stream error: {err}")),
            None => {
                break final_usage
                    .is_none()
                    .then(|| "stream closed without authoritative usage".to_string());
            }
        };

        match item {
            StreamedAssistantContent::Text(chunk) => collected.push_str(&chunk.text),
            StreamedAssistantContent::ReasoningDelta { reasoning, .. } => {
                collected.push_str(&reasoning)
            }
            StreamedAssistantContent::Reasoning(block) => {
                collected.push_str(&block.display_text())
            }
            StreamedAssistantContent::Final(resp) => {
                let reported = resp.token_usage();
                if reported.has_values() {
                    final_usage = Some(reported);
                }
            }
            _ => {}
        }
    };

    let outcome = if let Some(usage) = final_usage {
        Outcome::Clean(usage)
    } else {
        // No authoritative usage: rebuild it from the prompt and whatever output arrived.
        let prompt_count = count_tokens(http, api_key, prompt_text).await?;
        let output_count = count_tokens(http, api_key, &collected).await?;

        let mut usage = Usage::new();
        usage.input_tokens = prompt_count;
        usage.output_tokens = output_count;
        usage.total_tokens = prompt_count + output_count;

        Outcome::Estimated {
            usage,
            reason: stop_reason.unwrap_or_else(|| "unknown disruption".to_string()),
        }
    };

    Ok(Report {
        label,
        output_chars: collected.chars().count(),
        outcome,
    })
}
/// Asks the Gemini `countTokens` endpoint how many tokens `text` occupies for `MODEL`.
///
/// Empty input short-circuits to `0` without a network call.
///
/// The API key is sent in the `x-goog-api-key` header rather than in the URL query string,
/// because `reqwest` errors include the request URL in their message and would otherwise
/// expose the key wherever the error is printed or logged.
///
/// # Errors
///
/// Fails if the request cannot be sent, the server answers with a non-success status, the
/// body is not valid JSON, or `totalTokens` is present but is not an unsigned integer.
async fn count_tokens(http: &reqwest::Client, api_key: &str, text: &str) -> anyhow::Result<u64> {
    if text.is_empty() {
        return Ok(0);
    }

    let url = format!(
        "https://generativelanguage.googleapis.com/v1beta/models/{MODEL}:countTokens"
    );
    let body = serde_json::json!({
        "contents": [{ "parts": [{ "text": text }] }]
    });

    let value: serde_json::Value = http
        .post(url)
        .header("x-goog-api-key", api_key)
        .json(&body)
        .send()
        .await?
        .error_for_status()?
        .json()
        .await?;

    match value.get("totalTokens") {
        // An absent field is treated as zero, as before.
        // NOTE(review): presumably the API may omit zero-valued fields — confirm.
        None => Ok(0),
        // A field of the wrong type is a malformed response, not "zero tokens".
        Some(total) => total.as_u64().ok_or_else(|| {
            anyhow::anyhow!("countTokens returned a non-integer `totalTokens`: {total}")
        }),
    }
}
/// Builds the provider-specific request parameters that turn model "thinking" off.
///
/// The generation config requests a thinking budget of `0` and no included thoughts; every
/// other generation setting keeps its default. The result is serialized to a JSON value.
///
/// # Errors
///
/// Fails if the parameters cannot be serialized to JSON.
fn no_thinking_params() -> anyhow::Result<serde_json::Value> {
    let params = AdditionalParameters {
        generation_config: Some(GenerationConfig {
            thinking_config: Some(ThinkingConfig {
                include_thoughts: Some(false),
                thinking_budget: Some(0),
                thinking_level: None,
            }),
            ..Default::default()
        }),
        additional_params: None,
    };
    Ok(serde_json::to_value(&params)?)
}
/// Runs one scenario: streams a completion for `prompt`, wraps the stream in a [`Disrupt`]
/// configured with `mode`, and drains it through `drain_with_accounting`.
///
/// # Errors
///
/// Fails if the Gemini client cannot be built from the environment, the extra parameters
/// cannot be serialized, the streaming request cannot be started, or the accounting step
/// fails.
async fn run_scenario(
    label: &'static str,
    mode: Disruption,
    prompt: &str,
    http: &reqwest::Client,
    api_key: &str,
) -> anyhow::Result<Report> {
    let model = gemini::Client::from_env()?.completion_model(MODEL);

    let request = model
        .completion_request(prompt)
        .temperature(0.7)
        .max_tokens(2000);
    let extra_params = no_thinking_params()?;
    let response = request.additional_params(extra_params).stream().await?;

    let wrapped = Disrupt::new(response, mode, DISRUPT_AFTER_CHARS);
    drain_with_accounting(label, wrapped, http, api_key, prompt).await
}
fn print_report(report: &Report) {
println!("\n=== {} ===", report.label);
println!("partial output: {} chars", report.output_chars);
match &report.outcome {
Outcome::Clean(usage) => {
println!("result: CLEAN — authoritative usage from final chunk");
println!(
" input={} output={} reasoning={} total={}",
usage.input_tokens, usage.output_tokens, usage.reasoning_tokens, usage.total_tokens
);
}
Outcome::Estimated { usage, reason } => {
println!("result: ESTIMATED via countTokens (cut: {reason})");
println!(
" input={} output={} total={} (output is a lower bound; hidden thoughts uncounted)",
usage.input_tokens, usage.output_tokens, usage.total_tokens
);
}
}
}
/// Runs the baseline and each disruption scenario in turn and prints a report per scenario.
///
/// A failing scenario is reported and does not stop the remaining scenarios. The closing
/// summary claims success only when every scenario actually produced a report.
///
/// # Errors
///
/// Fails only if `GEMINI_API_KEY` is not set.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let api_key = std::env::var("GEMINI_API_KEY")
        .map_err(|_| anyhow::anyhow!("GEMINI_API_KEY must be set"))?;
    let http = reqwest::Client::new();

    let short_prompt = "Reply with a single short sentence greeting.";
    let long_prompt = "Write a detailed, multi-paragraph essay (about 600 words) \
        on the history and design philosophy of the Rust programming language.";

    println!(
        "Demonstrating one token-accounting path across every mid-stream disruption.\n\
        Model: {MODEL} | disrupt after ~{DISRUPT_AFTER_CHARS} chars | read timeout {READ_TIMEOUT:?}"
    );

    let scenarios = [
        (
            "clean completion (baseline)",
            Disruption::None,
            short_prompt,
        ),
        ("manual kill (cancel)", Disruption::ManualKill, long_prompt),
        ("transport error", Disruption::TransportError, long_prompt),
        ("stall / half-open", Disruption::Stall, long_prompt),
    ];

    let total = scenarios.len();
    let mut failed = 0usize;
    for (label, mode, prompt) in scenarios {
        match run_scenario(label, mode, prompt, &http, &api_key).await {
            Ok(report) => print_report(&report),
            Err(err) => {
                failed += 1;
                println!("\n=== {label} ===\nFAILED: {err}");
            }
        }
    }

    // Only claim that every run was accounted for when that is actually what happened.
    if failed == 0 {
        println!(
            "\nEvery disrupted run produced a token count without ever receiving the \
            provider's final usage chunk — keyed only on \"no authoritative usage\", \
            so it is agnostic to why the stream stopped."
        );
    } else {
        println!(
            "\n{failed} of {total} scenarios failed before a token count could be produced; \
            see the FAILED entries above."
        );
    }
    Ok(())
}