use crate::{
AppState,
auth::Principal,
handlers::{
RequestOutcome, emit_usage, emit_usage_error, error_response, error_status,
record_duration, record_tokens, with_timeout,
},
};
use axum::{
Extension, Json,
extract::{Path, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use bytes::Bytes;
use crabllm_core::{ApiError, Provider, RequestContext, Storage};
use std::time::Instant;
const ENDPOINT: &str = "gemini.generateContent";
pub async fn generate_content<S, P>(
State(state): State<AppState<S, P>>,
Extension(principal): Extension<Principal>,
Path(model_action): Path<String>,
raw_body: Bytes,
) -> Response
where
S: Storage + 'static,
P: Provider + 'static,
{
let Some((model_raw, action)) = model_action.split_once(':') else {
return (
StatusCode::BAD_REQUEST,
Json(ApiError::new(
format!("invalid path '{model_action}', expected '<model>:<action>'"),
"invalid_request_error",
)),
)
.into_response();
};
let is_stream = match action {
"generateContent" => false,
"streamGenerateContent" => true,
other => {
return (
StatusCode::NOT_FOUND,
Json(ApiError::new(
format!("unknown Gemini action '{other}'"),
"invalid_request_error",
)),
)
.into_response();
}
};
let registry = state.registry();
let model = registry.resolve(model_raw).to_string();
let deployments = match registry.dispatch_list(&model) {
Some(list) => list,
None => {
return (
StatusCode::NOT_FOUND,
Json(ApiError::new(
format!("model '{model}' not found"),
"invalid_request_error",
)),
)
.into_response();
}
};
let provider_name = registry
.provider_name(&model)
.unwrap_or_default()
.to_string();
let ctx = RequestContext {
request_id: uuid::Uuid::new_v4().to_string(),
model: model.clone(),
provider: provider_name,
principal: principal.0,
is_stream,
started_at: Instant::now(),
};
if is_stream {
return stream_path(&state, ctx, &model, &deployments, raw_body).await;
}
unary_path(&state, ctx, &model, &deployments, raw_body).await
}
async fn unary_path<S, P>(
state: &AppState<S, P>,
ctx: RequestContext,
model: &str,
deployments: &[&crabllm_provider::Deployment<P>],
raw_body: Bytes,
) -> Response
where
S: Storage + 'static,
P: Provider + 'static,
{
let mut last_err = None;
for deployment in deployments {
if !deployment.provider.is_gemini_compat() {
continue;
}
match with_timeout(
deployment.timeout,
deployment
.provider
.gemini_generate_content_raw(model, raw_body.clone()),
)
.await
{
Ok(resp_bytes) => {
let usage = crabllm_core::Usage::from(resp_bytes.as_ref());
if usage.prompt_tokens() > 0 || usage.completion_tokens() > 0 {
record_tokens(&ctx, usage.prompt_tokens(), usage.completion_tokens());
}
record_duration(&ctx, "2xx");
emit_usage(state, &ctx, ENDPOINT, RequestOutcome::ok(usage));
return (
[(axum::http::header::CONTENT_TYPE, "application/json")],
resp_bytes,
)
.into_response();
}
Err(e) => {
if !e.is_transient() {
record_duration(&ctx, error_status(&e));
emit_usage_error(state, &ctx, ENDPOINT, &e);
return error_response(e);
}
last_err = Some(e);
}
}
}
let e = last_err.unwrap_or_else(|| {
crabllm_core::Error::Routing("no compatible providers available".into())
});
record_duration(&ctx, error_status(&e));
emit_usage_error(state, &ctx, ENDPOINT, &e);
error_response(e)
}
async fn stream_path<S, P>(
state: &AppState<S, P>,
ctx: RequestContext,
model: &str,
deployments: &[&crabllm_provider::Deployment<P>],
raw_body: Bytes,
) -> Response
where
S: Storage + 'static,
P: Provider + 'static,
{
let mut last_err = None;
for deployment in deployments {
if !deployment.provider.is_gemini_compat() {
continue;
}
match with_timeout(
deployment.timeout,
deployment
.provider
.gemini_generate_content_stream_raw(model, raw_body.clone()),
)
.await
{
Ok(byte_stream) => {
record_duration(&ctx, "2xx");
let body = axum::body::Body::from_stream(byte_stream);
return axum::http::Response::builder()
.status(StatusCode::OK)
.header(axum::http::header::CONTENT_TYPE, "text/event-stream")
.header("cache-control", "no-cache")
.body(body)
.unwrap();
}
Err(e) => {
if !e.is_transient() {
record_duration(&ctx, error_status(&e));
emit_usage_error(state, &ctx, ENDPOINT, &e);
return error_response(e);
}
last_err = Some(e);
continue;
}
}
}
let e = last_err.unwrap_or_else(|| {
crabllm_core::Error::Routing("no compatible providers available".into())
});
record_duration(&ctx, error_status(&e));
emit_usage_error(state, &ctx, ENDPOINT, &e);
error_response(e)
}