use axum::{
Json,
body::Body,
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
};
use byokey_provider::{make_executor_for_model, parse_qualified_model};
use byokey_translate::{apply_thinking, parse_model_suffix};
use byokey_types::traits::ByteStream;
use byokey_types::{ChatRequest, ProviderId, traits::ProviderResponse};
use futures_util::{StreamExt as _, TryStreamExt as _, stream::try_unfold};
use std::collections::HashSet;
use std::sync::Arc;
use crate::{AppState, UsageRecorder, error::ApiError};
/// Wraps a streaming chat-completion body so its token usage is recorded when it ends.
///
/// Every chunk from `inner` is passed through unchanged. Alongside, the bytes are split
/// into `\n`-terminated lines. Each SSE `data:` line (other than `[DONE]`) that parses as
/// JSON and carries a `usage` object updates the running `prompt_tokens` /
/// `completion_tokens` totals. Later events overwrite earlier ones.
///
/// * When `inner` is exhausted, any unterminated trailing line is scanned too. Then
///   `usage.record_success(model, provider, input, output)` is called with the totals.
/// * When `inner` yields an error, `usage.record_failure(model, provider)` is called and
///   the error is forwarded. The returned stream ends after it.
///
/// Nothing is recorded if the returned stream is dropped before either of those happens.
pub(crate) fn tap_stream_usage(
    inner: ByteStream,
    usage: Arc<UsageRecorder>,
    model: String,
    provider: String,
) -> ByteStream {
    struct State {
        inner: ByteStream,
        // Bytes received after the last `\n`; holds no newline between `feed` calls.
        scan_buf: Vec<u8>,
        usage: Arc<UsageRecorder>,
        model: String,
        provider: String,
        input_tokens: u64,
        output_tokens: u64,
    }

    impl State {
        /// Appends `chunk` to the buffer and scans every line it completes.
        fn feed(&mut self, chunk: &[u8]) {
            let start = self.scan_buf.len();
            self.scan_buf.extend_from_slice(chunk);
            // The buffered prefix never contains a newline, so only the new bytes are searched.
            let Some(rel) = chunk.iter().rposition(|&b| b == b'\n') else {
                return;
            };
            // Drain all complete lines at once instead of shifting the buffer per line.
            let complete: Vec<u8> = self.scan_buf.drain(..=start + rel).collect();
            for line in complete.split(|&b| b == b'\n') {
                self.scan_line(line);
            }
        }

        /// Scans whatever is left in the buffer once the upstream stream has ended.
        fn flush(&mut self) {
            if !self.scan_buf.is_empty() {
                let rest = std::mem::take(&mut self.scan_buf);
                self.scan_line(&rest);
            }
        }

        /// Updates the token totals from one line.
        ///
        /// Ignores anything that is not a JSON `data:` event with a `usage` object.
        fn scan_line(&mut self, raw: &[u8]) {
            let line = String::from_utf8_lossy(raw);
            // SSE allows both `data: x` and `data:x`; trimming also drops a trailing `\r`.
            let Some(data) = line.trim().strip_prefix("data:").map(str::trim_start) else {
                return;
            };
            if data == "[DONE]" {
                return;
            }
            let Ok(json) = serde_json::from_str::<serde_json::Value>(data) else {
                return;
            };
            let Some(usage) = json.get("usage") else {
                return;
            };
            if let Some(v) = usage.get("prompt_tokens").and_then(serde_json::Value::as_u64) {
                self.input_tokens = v;
            }
            if let Some(v) = usage
                .get("completion_tokens")
                .and_then(serde_json::Value::as_u64)
            {
                self.output_tokens = v;
            }
        }
    }

    Box::pin(try_unfold(
        State {
            inner,
            scan_buf: Vec::new(),
            usage,
            model,
            provider,
            input_tokens: 0,
            output_tokens: 0,
        },
        |mut s| async move {
            match s.inner.next().await {
                Some(Ok(bytes)) => {
                    s.feed(&bytes);
                    Ok(Some((bytes, s)))
                }
                Some(Err(e)) => {
                    s.usage.record_failure(&s.model, &s.provider);
                    Err(e)
                }
                None => {
                    // A final event may arrive without a trailing newline.
                    s.flush();
                    s.usage
                        .record_success(&s.model, &s.provider, s.input_tokens, s.output_tokens);
                    Ok(None)
                }
            }
        },
    ))
}
/// Reads `(prompt_tokens, completion_tokens)` from a response's top-level `usage` object.
///
/// A missing `usage` object, a missing field, or a value that is not a `u64` counts as `0`.
fn extract_usage_tokens(json: &serde_json::Value) -> (u64, u64) {
    let tokens_at = |path: &str| {
        json.pointer(path)
            .and_then(serde_json::Value::as_u64)
            .unwrap_or(0)
    };
    (
        tokens_at("/usage/prompt_tokens"),
        tokens_at("/usage/completion_tokens"),
    )
}
/// Chat-completions handler that forces Copilot routing.
///
/// It runs the same pipeline as [`chat_completions`], but every provider other than
/// Copilot has its config `backend` set to Copilot before the executor is built.
pub async fn copilot_chat_completions(
    State(app_state): State<Arc<AppState>>,
    Json(chat_request): Json<ChatRequest>,
) -> Result<Response, ApiError> {
    let force_copilot = true;
    chat_completions_inner(app_state, chat_request, force_copilot).await
}
/// Chat-completions handler: takes a JSON [`ChatRequest`] and returns either a JSON body or
/// an SSE stream, depending on what the selected executor produces.
///
/// Provider configs are used as-is; no Copilot backend override is applied.
pub async fn chat_completions(
    State(app_state): State<Arc<AppState>>,
    Json(chat_request): Json<ChatRequest>,
) -> Result<Response, ApiError> {
    let force_copilot = false;
    chat_completions_inner(app_state, chat_request, force_copilot).await
}
async fn chat_completions_inner(
state: Arc<AppState>,
mut request: ChatRequest,
force_copilot: bool,
) -> Result<Response, ApiError> {
let config = state.config.load();
let mut oauth_providers = HashSet::new();
for p in ProviderId::all() {
if state.auth.is_authenticated(p).await {
oauth_providers.insert(p.clone());
}
}
let resolved_model = config.resolve_alias(&request.model);
let (provider_hint, bare_model) = parse_qualified_model(&resolved_model);
let suffix = parse_model_suffix(bare_model);
let config_fn = |p: &ProviderId| {
let mut pc = config.providers.get(p).cloned().unwrap_or_default();
if force_copilot && *p != ProviderId::Copilot {
pc.backend = Some(ProviderId::Copilot);
}
Some(pc)
};
let executor = make_executor_for_model(
&suffix.model,
config_fn,
&oauth_providers,
provider_hint.as_ref(),
state.auth.clone(),
state.http.clone(),
Some(state.ratelimits.clone()),
)
.map_err(ApiError::from)?;
let provider = byokey_provider::resolve_provider(&suffix.model)
.map_or_else(|| "unknown".to_string(), |p| p.to_string());
tracing::info!(
model = %suffix.model,
provider = %provider,
stream = request.stream,
"chat completion request"
);
request.model.clone_from(&suffix.model);
if let Some(ref thinking) = suffix.thinking {
let provider =
byokey_provider::resolve_provider(&suffix.model).unwrap_or(ProviderId::Claude);
let capability = byokey_provider::thinking_capability(&suffix.model);
let mut body = request.into_body();
body = apply_thinking(body, &provider, thinking, capability);
request = serde_json::from_value(body)
.map_err(|e| ApiError::from(byokey_types::ByokError::Translation(e.to_string())))?;
}
if !config.payload.default.is_empty()
|| !config.payload.r#override.is_empty()
|| !config.payload.filter.is_empty()
{
let mut body = request.into_body();
body = config.apply_payload_rules(body, &suffix.model);
request = serde_json::from_value(body)
.map_err(|e| ApiError::from(byokey_types::ByokError::Translation(e.to_string())))?;
}
let model_name = suffix.model.clone();
match executor.chat_completion(request).await {
Ok(ProviderResponse::Complete(json)) => {
let (input_tok, output_tok) = extract_usage_tokens(&json);
state
.usage
.record_success(&model_name, &provider, input_tok, output_tok);
tracing::debug!(model = %model_name, "chat completion complete");
Ok(Json(json).into_response())
}
Ok(ProviderResponse::Stream(byte_stream)) => {
tracing::debug!(model = %model_name, "streaming chat completion");
let tapped = tap_stream_usage(
byte_stream,
state.usage.clone(),
model_name,
provider.clone(),
);
let mapped = tapped.map_err(|e| std::io::Error::other(e.to_string()));
let body = Body::from_stream(mapped);
Ok(Response::builder()
.status(StatusCode::OK)
.header("content-type", "text/event-stream")
.header("cache-control", "no-cache")
.header("x-accel-buffering", "no")
.body(body)
.expect("valid response"))
}
Err(e) => {
state.usage.record_failure(&model_name, &provider);
Err(ApiError::from(e))
}
}
}