use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use axum::Json;
use axum::extract::State;
use axum::http::HeaderMap;
use axum::response::sse::Event;
use axum::response::{IntoResponse, Response, Sse};
use reqwest::Client;
use serde_json::{Value, json};
use tokio::sync::mpsc;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::ReceiverStream;
use crate::auth::RequestAuth;
use crate::config::Config;
use crate::converter::{AnthropicToOpenAiConverter, OpenAiToAnthropicConverter};
use crate::error::{ProxyError, Result};
use crate::provider::LlmProviderBackend;
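/// Shared application state passed to every handler via `State<Arc<AppState>>`.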
pub struct AppState {
pub config: Config,
pub request_auth: RequestAuth,
pub http_client: Client,
pub openai_to_anthropic: OpenAiToAnthropicConverter,
pub anthropic_to_openai: AnthropicToOpenAiConverter,
pub metrics: AppMetrics,
}
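/// Lock-free request counters surfaced by the health endpoint.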
#[derive(Debug, Default)]
pub struct AppMetrics {
pub total_requests: AtomicU64,
pub quota_errors: AtomicU64,
pub retry_attempts: AtomicU64,
pub successful_requests: AtomicU64,
pub failed_requests: AtomicU64,
}
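/// Bundles the per-stream mutable state handed to `process_stream_chunk`.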
struct StreamChunkParams<'a> {
chunk: &'a bytes::Bytes,
buffer: &'a mut String,
state: &'a Arc<AppState>,
model: &'a str,
current_tool_call: &'a mut Option<crate::converter::anthropic_to_openai::StreamingToolCall>,
has_tool_calls: &'a mut bool,
stop_reason_from_delta: &'a mut Option<String>,
tx: &'a mpsc::Sender<Result<Event>>,
}
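// Request tunables (upstream timeout, SSE channel depth, retry backoff base,
// buffered-text flush threshold) and common header names.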
const HTTP_CLIENT_TIMEOUT_SECS: u64 = 300;
const STREAMING_CHANNEL_BUFFER: usize = 100;
const CONTENT_TYPE_JSON: &str = "application/json";
const AUTHORIZATION_HEADER: &str = "Authorization";
const BASE_RETRY_DELAY_SECS: u64 = 1;
const MIN_BUFFER_SIZE: usize = 50;
impl AppState {
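/// Builds the shared state from configuration, resolving provider credentials
/// up front so a misconfigured provider fails at startup rather than per request.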
pub async fn new(config: Config) -> Result<Self> {
let request_auth = match &config.llm_provider {
Some(provider) => RequestAuth::from_strategy(provider.auth_strategy()).await?,
None => return Err(ProxyError::Config("LLM provider not configured".to_string())),
};
let http_client = Self::create_http_client()?;
let openai_to_anthropic = OpenAiToAnthropicConverter::new(config.server.log_level);
let anthropic_to_openai = AnthropicToOpenAiConverter::new(config.server.log_level);
let metrics = AppMetrics::default();
Ok(Self {
config,
request_auth,
http_client,
openai_to_anthropic,
anthropic_to_openai,
metrics,
})
}
fn create_http_client() -> Result<Client> {
Client::builder()
.timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS))
.build()
.map_err(|e| ProxyError::Http(format!("Failed to create HTTP client: {}", e)))
}
}
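/// OpenAI-compatible chat completions handler: counts the request, delegates to
/// `process_chat_completion`, and converts any error into an OpenAI-style error body.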
pub async fn chat_completions(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<Value>,
) -> axum::response::Response {
state.metrics.total_requests.fetch_add(1, Ordering::Relaxed);
match process_chat_completion(state.clone(), request, &headers).await {
Ok(response) => {
state.metrics.successful_requests.fetch_add(1, Ordering::Relaxed);
response
}
Err(e) => {
state.metrics.failed_requests.fetch_add(1, Ordering::Relaxed);
create_error_response(&e)
}
}
}
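/// Core pipeline: detect the client, choose a streaming strategy, convert the
/// OpenAI request to Anthropic format, call Vertex AI, and convert the reply back.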
async fn process_chat_completion(
state: Arc<AppState>,
mut request: Value,
headers: &HeaderMap,
) -> Result<axum::response::Response> {
if let Some(user_agent) = headers.get("user-agent") {
if let Ok(ua_str) = user_agent.to_str() {
tracing::debug!("Client User-Agent: {}", ua_str);
}
}
let is_goose_client = detect_goose_client(headers);
if is_goose_client {
tracing::debug!("Using goose-compatible mode (non-streaming SSE)");
let openai_request = parse_openai_request(request)?;
log_incoming_request(&state, &openai_request);
return handle_goose_request(state, openai_request).await;
}
let (should_force_non_streaming, should_use_buffered_streaming) =
determine_streaming_behavior(&state.config, headers);
if should_force_non_streaming {
if let Some(obj) = request.as_object_mut() {
obj.insert("stream".to_string(), serde_json::Value::Bool(false));
}
tracing::debug!("Using non-streaming mode");
} else if should_use_buffered_streaming {
tracing::debug!("Using buffered streaming mode");
} else {
tracing::debug!("Using standard streaming mode");
}
let openai_request = parse_openai_request(request)?;
log_incoming_request(&state, &openai_request);
let anthropic_request = convert_to_anthropic(state.clone(), openai_request)?;
let auth_header = get_authorization_header(state.clone()).await?;
let vertex_response =
make_vertex_request_with_retry(state.clone(), &anthropic_request, &auth_header).await?;
if anthropic_request.stream {
if should_use_buffered_streaming {
handle_buffered_streaming_response(vertex_response, state).await
} else {
handle_streaming_response(vertex_response, state).await
}
} else {
handle_non_streaming_response(vertex_response, state).await
}
}
fn parse_openai_request(
request: Value,
) -> Result<crate::converter::openai_to_anthropic::OpenAiRequest> {
serde_json::from_value(request)
.map_err(|e| ProxyError::Conversion(format!("Invalid request format: {}", e)))
}
fn log_incoming_request(
state: &Arc<AppState>,
request: &crate::converter::openai_to_anthropic::OpenAiRequest,
) {
state.openai_to_anthropic.debug("=== Incoming OpenAI Request ===");
state.openai_to_anthropic.debug(&format!("Model: {:?}", request.model));
state.openai_to_anthropic.debug(&format!("Stream: {:?}", request.stream));
state.openai_to_anthropic.debug(&format!("Messages: {}", request.messages.len()));
if let Some(ref tools) = request.tools {
state.openai_to_anthropic.debug(&format!("Tools provided: {}", tools.len()));
let tool_names: Vec<String> = tools.iter().map(|t| t.function.name.clone()).collect();
state.openai_to_anthropic.debug(&format!("Tool names: {}", tool_names.join(", ")));
}
}
fn convert_to_anthropic(
state: Arc<AppState>,
request: crate::converter::openai_to_anthropic::OpenAiRequest,
) -> Result<crate::converter::openai_to_anthropic::AnthropicRequest> {
state.openai_to_anthropic.convert(request)
}
async fn get_authorization_header(state: Arc<AppState>) -> Result<String> {
state.request_auth.authorization_header_value().await
}
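/// Sends the request to Vertex AI, retrying quota/rate-limit errors with
/// exponential backoff (1s, 2s, 4s, ...) when retries are enabled.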
async fn make_vertex_request_with_retry(
state: Arc<AppState>,
anthropic_request: &crate::converter::openai_to_anthropic::AnthropicRequest,
auth_header: &str,
) -> Result<reqwest::Response> {
if !state.config.server.enable_retries {
return make_vertex_request(state, anthropic_request, auth_header).await;
}
let mut attempts = 0;
loop {
attempts += 1;
let response = make_vertex_request(state.clone(), anthropic_request, auth_header).await;
match response {
Ok(resp) => return Ok(resp),
Err(ProxyError::Http(msg)) if attempts < state.config.server.max_retry_attempts => {
if msg.contains("Rate limit") || msg.contains("Quota exceeded") {
state.metrics.quota_errors.fetch_add(1, Ordering::Relaxed);
state.metrics.retry_attempts.fetch_add(1, Ordering::Relaxed);
let delay_secs = BASE_RETRY_DELAY_SECS * 2_u64.pow(attempts - 1);
tracing::warn!(
"Quota exceeded, retrying in {} seconds (attempt {}/{}) - Total quota errors: {}, \
Total retries: {}",
delay_secs,
attempts,
state.config.server.max_retry_attempts,
state.metrics.quota_errors.load(Ordering::Relaxed),
state.metrics.retry_attempts.load(Ordering::Relaxed)
);
tokio::time::sleep(tokio::time::Duration::from_secs(delay_secs)).await;
continue;
}
return Err(ProxyError::Http(msg));
}
Err(e) => return Err(e),
}
}
}
async fn make_vertex_request(
state: Arc<AppState>,
anthropic_request: &crate::converter::openai_to_anthropic::AnthropicRequest,
auth_header: &str,
) -> Result<reqwest::Response> {
let url = state.config.build_predict_url(anthropic_request.stream);
let response = state
.http_client
.post(&url)
.header(AUTHORIZATION_HEADER, auth_header)
.header("Content-Type", CONTENT_TYPE_JSON)
.json(anthropic_request)
.send()
.await
.map_err(ProxyError::Request)?;
validate_vertex_response(response).await
}
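/// Maps non-success Vertex AI responses to proxy errors with client-safe
/// messages, logging the raw upstream error text for diagnostics.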
async fn validate_vertex_response(response: reqwest::Response) -> Result<reqwest::Response> {
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
tracing::error!("Vertex AI error: {}", error_text);
let client_error = match status.as_u16() {
429 => {
if error_text.contains("Quota exceeded") {
tracing::error!(
"Quota exceeded for Vertex AI. Consider requesting quota increase: https://cloud.google.com/vertex-ai/docs/generative-ai/quotas-genai"
);
ProxyError::Http(
"Rate limit exceeded. Please try again later or contact support for quota increase."
.to_string(),
)
} else {
ProxyError::Http("Too many requests. Please try again later.".to_string())
}
}
400 => {
if error_text.contains("tools: Input should be a valid list") {
ProxyError::Conversion("Invalid tools configuration in request.".to_string())
} else {
ProxyError::Http("Bad request format.".to_string())
}
}
401 => ProxyError::Auth(
"Authentication failed. Please check your API credentials.".to_string(),
),
403 => ProxyError::Auth("Access forbidden. Please check your permissions.".to_string()),
404 => ProxyError::Http("Model or endpoint not found.".to_string()),
500..=599 => ProxyError::Http(
"Vertex AI service is temporarily unavailable. Please try again later.".to_string(),
),
_ => ProxyError::Http(format!("Vertex AI returned error ({}): {}", status, error_text)),
};
return Err(client_error);
}
Ok(response)
}
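/// Non-streaming path: parse the full Anthropic response and return a single
/// OpenAI-format JSON body.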
async fn handle_non_streaming_response(
response: reqwest::Response,
state: Arc<AppState>,
) -> Result<Response> {
state.anthropic_to_openai.debug("=== Non-streaming response ===");
let anthropic_response: crate::converter::anthropic_to_openai::AnthropicResponse =
response.json().await.map_err(ProxyError::Request)?;
log_anthropic_response(&state, &anthropic_response);
let openai_response =
state.anthropic_to_openai.convert(anthropic_response, state.config.llm_model());
log_openai_response(&state, &openai_response);
Ok(Json(openai_response).into_response())
}
fn log_anthropic_response(
state: &Arc<AppState>,
response: &crate::converter::anthropic_to_openai::AnthropicResponse,
) {
state
.anthropic_to_openai
.debug(&format!("Anthropic response stop_reason: {:?}", response.stop_reason));
let tool_calls_count = response
.content
.iter()
.filter(|c| {
matches!(
c,
crate::converter::anthropic_to_openai::AnthropicContentBlock::ToolUse { .. }
)
})
.count();
if tool_calls_count > 0 {
state
.anthropic_to_openai
.debug(&format!("Anthropic response contains {} tool call(s)", tool_calls_count));
}
}
fn log_openai_response(
state: &Arc<AppState>,
response: &crate::converter::anthropic_to_openai::OpenAiResponse,
) {
state.anthropic_to_openai.debug("=== Outgoing OpenAI Response ===");
if let Some(choice) = response.choices.first() {
state.anthropic_to_openai.debug(&format!("Finish reason: {}", choice.finish_reason));
if let Some(ref tool_calls) = choice.message.tool_calls {
state.anthropic_to_openai.debug(&format!("Tool calls in response: {}", tool_calls.len()));
}
}
}
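/// Pass-through streaming: forwards Anthropic SSE events as OpenAI chunks as they arrive.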
async fn handle_streaming_response(
response: reqwest::Response,
state: Arc<AppState>,
) -> Result<Response> {
state.anthropic_to_openai.debug("=== Streaming response ===");
let (tx, rx) = mpsc::channel::<Result<Event>>(STREAMING_CHANNEL_BUFFER);
let state_clone = state.clone();
let model = state.config.llm_model().to_string();
tokio::spawn(async move {
process_streaming_events(response, state_clone, model, tx).await;
});
Ok(Sse::new(ReceiverStream::new(rx)).into_response())
}
async fn process_streaming_events(
response: reqwest::Response,
state: Arc<AppState>,
model: String,
tx: mpsc::Sender<Result<Event>>,
) {
let mut stream = response.bytes_stream();
let mut current_tool_call: Option<crate::converter::anthropic_to_openai::StreamingToolCall> =
None;
let mut has_tool_calls = false;
let mut stop_reason_from_delta: Option<String> = None;
let mut buffer = String::new();
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
let params = StreamChunkParams {
chunk: &chunk,
buffer: &mut buffer,
state: &state,
model: &model,
current_tool_call: &mut current_tool_call,
has_tool_calls: &mut has_tool_calls,
stop_reason_from_delta: &mut stop_reason_from_delta,
tx: &tx,
};
if let Err(e) = process_stream_chunk(params).await {
tracing::error!("Stream processing error: {}", e);
break;
}
}
Err(e) => {
tracing::error!("Stream chunk error: {}", e);
break;
}
}
}
send_stream_done(&tx).await;
}
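/// Returns `(force_non_streaming, use_buffered_streaming)` for the configured
/// streaming mode, falling back to header sniffing in `Auto` mode.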
fn determine_streaming_behavior(
config: &crate::config::Config,
headers: &HeaderMap,
) -> (bool, bool) {
use crate::config::StreamingMode;
match config.streaming.mode {
StreamingMode::Never => (true, false),
StreamingMode::Standard => (false, false),
StreamingMode::Buffered => (false, true),
StreamingMode::Always => (false, true),
StreamingMode::Auto => {
let should_force_non_streaming = detect_problematic_client(headers);
let should_use_buffered_streaming =
!should_force_non_streaming && detect_buffered_streaming_client(headers);
(should_force_non_streaming, should_use_buffered_streaming)
}
}
}
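/// Detects the goose GUI client from its `openai-organization` / `openai-project` headers.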
fn detect_goose_client(headers: &HeaderMap) -> bool {
if let Some(org) = headers.get("openai-organization") {
if let Ok(org_str) = org.to_str() {
if org_str.to_lowercase().contains("basebox") {
return true;
}
}
}
if let Some(project) = headers.get("openai-project") {
if let Ok(project_str) = project.to_str() {
if project_str.to_lowercase().contains("gui") {
return true;
}
}
}
false
}
fn detect_problematic_client(headers: &HeaderMap) -> bool {
if let Some(user_agent) = headers.get("user-agent") {
if let Ok(user_agent_str) = user_agent.to_str() {
let ua = user_agent_str.to_lowercase();
if ua.contains("goose")
|| ua.contains("curl")
|| ua.contains("wget")
|| ua.contains("httpie")
|| ua.contains("python-requests")
{
return true;
}
if ua.contains("postman") || ua.contains("insomnia") || ua.contains("thunderclient") {
return true;
}
}
}
if let Some(accept) = headers.get("accept") {
if let Ok(accept_str) = accept.to_str() {
if !accept_str.contains("text/event-stream") && !accept_str.contains("*/*") {
return true;
}
}
}
false
}
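/// Heuristic: browsers and JetBrains/VS Code IDEs get buffered streaming, which
/// coalesces small text deltas into larger SSE chunks.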
fn detect_buffered_streaming_client(headers: &HeaderMap) -> bool {
if let Some(user_agent) = headers.get("user-agent") {
if let Ok(user_agent_str) = user_agent.to_str() {
let ua = user_agent_str.to_lowercase();
if ua.contains("chrome")
|| ua.contains("firefox")
|| ua.contains("safari")
|| ua.contains("edge")
|| ua.contains("vscode")
|| ua.contains("visual studio code")
|| ua.contains("intellij")
|| ua.contains("rustrover")
|| ua.contains("jetbrains")
|| ua.contains("pycharm")
|| ua.contains("clion")
|| ua.contains("webstorm")
|| ua.contains("phpstorm")
{
return true;
}
}
}
false
}
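/// Buffered streaming: text deltas are accumulated and flushed on sentence
/// boundaries or once `MIN_BUFFER_SIZE` characters have accumulated.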
async fn handle_buffered_streaming_response(
response: reqwest::Response,
state: Arc<AppState>,
) -> Result<Response> {
state.anthropic_to_openai.debug("=== Buffered streaming response ===");
let (tx, rx) = mpsc::channel::<Result<Event>>(STREAMING_CHANNEL_BUFFER);
let state_clone = state.clone();
let model = state.config.llm_model().to_string();
tokio::spawn(async move {
process_buffered_streaming_events(response, state_clone, model, tx).await;
});
Ok(Sse::new(ReceiverStream::new(rx)).into_response())
}
async fn process_buffered_streaming_events(
response: reqwest::Response,
state: Arc<AppState>,
model: String,
tx: mpsc::Sender<Result<Event>>,
) {
let mut stream = response.bytes_stream();
let mut current_tool_call: Option<crate::converter::anthropic_to_openai::StreamingToolCall> =
None;
let mut has_tool_calls = false;
let mut stop_reason_from_delta: Option<String> = None;
let mut buffer = String::new();
let mut text_accumulator = String::new();
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
if let Err(e) = process_buffered_stream_chunk(
&chunk,
&mut buffer,
&state,
&model,
&mut current_tool_call,
&mut has_tool_calls,
&mut stop_reason_from_delta,
&mut text_accumulator,
&tx,
)
.await
{
tracing::error!("Buffered stream processing error: {}", e);
break;
}
}
Err(e) => {
tracing::error!("Stream chunk error: {}", e);
break;
}
}
}
if !text_accumulator.is_empty() {
send_buffered_text(&text_accumulator, &model, &state, &tx).await;
}
send_stream_done(&tx).await;
}
async fn process_buffered_stream_chunk(
chunk: &bytes::Bytes,
buffer: &mut String,
state: &Arc<AppState>,
model: &str,
current_tool_call: &mut Option<crate::converter::anthropic_to_openai::StreamingToolCall>,
has_tool_calls: &mut bool,
stop_reason_from_delta: &mut Option<String>,
text_accumulator: &mut String,
tx: &mpsc::Sender<Result<Event>>,
) -> Result<()> {
let chunk_str = String::from_utf8_lossy(chunk);
let new_content = format!("{}{}", buffer, chunk_str);
let (lines_to_process, new_buffer) = split_sse_lines(&new_content);
*buffer = new_buffer;
for line in lines_to_process {
if let Some(data) = extract_sse_data(line) {
if data == "[DONE]" {
if !text_accumulator.is_empty() {
send_buffered_text(text_accumulator, model, state, tx).await;
text_accumulator.clear();
}
send_sse_event(tx, "[DONE]").await;
continue;
}
process_buffered_sse_event(
data,
state,
model,
current_tool_call,
has_tool_calls,
stop_reason_from_delta,
text_accumulator,
tx,
)
.await;
}
}
Ok(())
}
async fn process_buffered_sse_event(
data: &str,
state: &Arc<AppState>,
model: &str,
current_tool_call: &mut Option<crate::converter::anthropic_to_openai::StreamingToolCall>,
has_tool_calls: &mut bool,
stop_reason_from_delta: &mut Option<String>,
text_accumulator: &mut String,
tx: &mpsc::Sender<Result<Event>>,
) {
match serde_json::from_str::<crate::converter::anthropic_to_openai::AnthropicStreamEvent>(data)
{
Ok(event) => {
if let Some(chunk) = state.anthropic_to_openai.convert_stream_event(
&event,
model,
current_tool_call,
has_tool_calls,
stop_reason_from_delta,
) {
if let Some(content) =
chunk.choices.first().and_then(|choice| choice.delta.content.as_ref())
{
text_accumulator.push_str(content);
if text_accumulator.len() >= MIN_BUFFER_SIZE
|| content.contains('.')
|| content.contains('!')
|| content.contains('?')
|| content.contains('\n')
{
send_buffered_text(text_accumulator, model, state, tx).await;
text_accumulator.clear();
}
} else {
if !text_accumulator.is_empty() {
send_buffered_text(text_accumulator, model, state, tx).await;
text_accumulator.clear();
}
match serde_json::to_string(&chunk) {
Ok(json) => {
send_sse_event(tx, &json).await;
}
Err(e) => {
tracing::error!("Failed to serialize chunk: {}", e);
}
}
}
}
}
Err(e) => {
tracing::error!("Failed to parse stream event: {} - data: {}", e, data);
}
}
}
async fn send_buffered_text(
text: &str,
model: &str,
state: &Arc<AppState>,
tx: &mpsc::Sender<Result<Event>>,
) {
if let Some(chunk) = state.anthropic_to_openai.create_text_chunk(text, model) {
match serde_json::to_string(&chunk) {
Ok(json) => {
send_sse_event(tx, &json).await;
}
Err(e) => {
tracing::error!("Failed to serialize buffered text chunk: {}", e);
}
}
}
}
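/// Goose-compatible mode: make a non-streaming Vertex AI call, then replay the
/// complete response as a short SSE stream (content, tool calls, finish, `[DONE]`).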
async fn handle_goose_request(
state: Arc<AppState>,
openai_request: crate::converter::openai_to_anthropic::OpenAiRequest,
) -> Result<axum::response::Response> {
let anthropic_request = state.openai_to_anthropic.convert(openai_request)?;
let auth_header = get_authorization_header(state.clone()).await?;
let mut anthropic_request_non_streaming = anthropic_request;
anthropic_request_non_streaming.stream = false;
let vertex_response = make_vertex_request_with_retry(
state.clone(),
&anthropic_request_non_streaming,
&auth_header,
)
.await?;
let anthropic_response: crate::converter::anthropic_to_openai::AnthropicResponse =
vertex_response.json().await.map_err(ProxyError::Request)?;
let openai_response =
state.anthropic_to_openai.convert(anthropic_response, state.config.llm_model());
let (tx, rx) = mpsc::channel::<Result<Event>>(STREAMING_CHANNEL_BUFFER);
tokio::spawn(async move {
if let Some(choice) = openai_response.choices.first() {
if let Some(content) = &choice.message.content {
let chunk = crate::converter::anthropic_to_openai::OpenAiStreamChunk {
id: openai_response.id.clone(),
object: "chat.completion.chunk".to_string(),
created: openai_response.created,
model: openai_response.model.clone(),
choices: vec![crate::converter::anthropic_to_openai::OpenAiStreamChoice {
index: 0,
delta: crate::converter::anthropic_to_openai::OpenAiStreamDelta {
content: Some(content.clone()),
tool_calls: None,
},
finish_reason: None,
}],
};
if let Ok(json) = serde_json::to_string(&chunk) {
let _ = tx.send(Ok(Event::default().data(json))).await;
}
}
if let Some(tool_calls) = &choice.message.tool_calls {
for (index, tool_call) in tool_calls.iter().enumerate() {
let tool_chunk = crate::converter::anthropic_to_openai::OpenAiStreamChunk {
id: openai_response.id.clone(),
object: "chat.completion.chunk".to_string(),
created: openai_response.created,
model: openai_response.model.clone(),
choices: vec![crate::converter::anthropic_to_openai::OpenAiStreamChoice {
index: 0,
delta: crate::converter::anthropic_to_openai::OpenAiStreamDelta {
content: None,
tool_calls: Some(vec![
crate::converter::anthropic_to_openai::OpenAiStreamToolCall {
index: index as u32,
id: Some(tool_call.id.clone()),
call_type: Some(tool_call.call_type.clone()),
function: Some(
crate::converter::anthropic_to_openai::OpenAiStreamFunctionCall {
name: Some(tool_call.function.name.clone()),
arguments: Some(tool_call.function.arguments.clone()),
},
),
},
]),
},
finish_reason: None,
}],
};
if let Ok(json) = serde_json::to_string(&tool_chunk) {
let _ = tx.send(Ok(Event::default().data(json))).await;
}
}
}
let finish_chunk = crate::converter::anthropic_to_openai::OpenAiStreamChunk {
id: openai_response.id,
object: "chat.completion.chunk".to_string(),
created: openai_response.created,
model: openai_response.model,
choices: vec![crate::converter::anthropic_to_openai::OpenAiStreamChoice {
index: 0,
delta: crate::converter::anthropic_to_openai::OpenAiStreamDelta {
content: None,
tool_calls: None,
},
finish_reason: Some(choice.finish_reason.clone()),
}],
};
if let Ok(json) = serde_json::to_string(&finish_chunk) {
let _ = tx.send(Ok(Event::default().data(json))).await;
}
}
let _ = tx.send(Ok(Event::default().data("[DONE]"))).await;
});
Ok(Sse::new(ReceiverStream::new(rx)).into_response())
}
async fn process_stream_chunk(params: StreamChunkParams<'_>) -> Result<()> {
let chunk_str = String::from_utf8_lossy(params.chunk);
let new_content = format!("{}{}", params.buffer, chunk_str);
let (lines_to_process, new_buffer) = split_sse_lines(&new_content);
*params.buffer = new_buffer;
for line in lines_to_process {
if let Some(data) = extract_sse_data(line) {
if data == "[DONE]" {
send_sse_event(params.tx, "[DONE]").await;
continue;
}
process_sse_event(
data,
params.state,
params.model,
params.current_tool_call,
params.has_tool_calls,
params.stop_reason_from_delta,
params.tx,
)
.await;
}
}
Ok(())
}
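/// Splits accumulated SSE text into complete lines, returning any trailing
/// partial line as the new buffer to be completed by the next chunk.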
fn split_sse_lines(content: &str) -> (Vec<&str>, String) {
let mut lines_to_process = Vec::new();
let mut new_buffer = String::new();
let ends_with_newline = content.ends_with('\n');
let all_lines: Vec<&str> = content.lines().collect();
let line_count = all_lines.len();
for (i, line) in all_lines.into_iter().enumerate() {
let is_last = i == line_count - 1;
if is_last && !ends_with_newline {
new_buffer = line.to_string();
} else {
lines_to_process.push(line);
}
}
(lines_to_process, new_buffer)
}
fn extract_sse_data(line: &str) -> Option<&str> {
line.strip_prefix("data: ")
}
async fn process_sse_event(
data: &str,
state: &Arc<AppState>,
model: &str,
current_tool_call: &mut Option<crate::converter::anthropic_to_openai::StreamingToolCall>,
has_tool_calls: &mut bool,
stop_reason_from_delta: &mut Option<String>,
tx: &mpsc::Sender<Result<Event>>,
) {
match serde_json::from_str::<crate::converter::anthropic_to_openai::AnthropicStreamEvent>(data)
{
Ok(event) => {
if let Some(chunk) = state.anthropic_to_openai.convert_stream_event(
&event,
model,
current_tool_call,
has_tool_calls,
stop_reason_from_delta,
) {
match serde_json::to_string(&chunk) {
Ok(json) => {
send_sse_event(tx, &json).await;
}
Err(e) => {
tracing::error!("Failed to serialize chunk: {}", e);
}
}
}
}
Err(e) => {
tracing::error!("Failed to parse stream event: {} - data: {}", e, data);
}
}
}
async fn send_sse_event(tx: &mpsc::Sender<Result<Event>>, data: &str) {
let _ = tx.send(Ok(Event::default().data(data))).await;
}
async fn send_stream_done(tx: &mpsc::Sender<Result<Event>>) {
let _ = tx.send(Ok(Event::default().data("[DONE]"))).await;
}
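/// Converts a `ProxyError` into an OpenAI-style error body with a matching HTTP status code.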
fn create_error_response(error: &ProxyError) -> axum::response::Response {
let (status_code, error_type) = match error {
ProxyError::Config(_) | ProxyError::Conversion(_) => {
(axum::http::StatusCode::BAD_REQUEST, "invalid_request_error")
}
ProxyError::Auth(_) => (axum::http::StatusCode::UNAUTHORIZED, "authentication_error"),
ProxyError::Http(msg) if msg.contains("Rate limit") || msg.contains("Quota exceeded") => {
(axum::http::StatusCode::TOO_MANY_REQUESTS, "rate_limit_error")
}
ProxyError::Http(msg) if msg.contains("temporarily unavailable") => {
(axum::http::StatusCode::SERVICE_UNAVAILABLE, "service_unavailable")
}
_ => (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "internal_error"),
};
let error_response = json!({
"error": {
"message": error.to_string(),
"type": error_type,
"code": status_code.as_u16()
}
});
(status_code, Json(error_response)).into_response()
}
pub async fn models(State(state): State<Arc<AppState>>) -> Json<Value> {
Json(json!({
"object": "list",
"data": [{
"id": state.config.llm_model(),
"object": "model",
"created": chrono::Utc::now().timestamp_millis(),
"owned_by": "anthropic"
}]
}))
}
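/// Health handler: returns status plus request, retry, and success-rate metrics.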
pub async fn health(State(state): State<Arc<AppState>>) -> Json<Value> {
let total_requests = state.metrics.total_requests.load(Ordering::Relaxed);
let quota_errors = state.metrics.quota_errors.load(Ordering::Relaxed);
let retry_attempts = state.metrics.retry_attempts.load(Ordering::Relaxed);
let successful_requests = state.metrics.successful_requests.load(Ordering::Relaxed);
let failed_requests = state.metrics.failed_requests.load(Ordering::Relaxed);
Json(json!({
"status": "ok",
"metrics": {
"total_requests": total_requests,
"successful_requests": successful_requests,
"failed_requests": failed_requests,
"quota_errors": quota_errors,
"retry_attempts": retry_attempts,
"success_rate": if total_requests > 0 {
(successful_requests as f64 / total_requests as f64 * 100.0).round()
} else {
100.0
}
}
}))
}
#[cfg(test)]
mod tests {
use axum::http::HeaderValue;
use super::*;
use crate::provider::{AuthStrategy, LlmProviderConfig, VertexProvider};
#[test]
fn test_detect_buffered_streaming_client_rustrover() {
let mut headers = HeaderMap::new();
headers.insert(
"user-agent",
HeaderValue::from_static("RustRover/2024.1 Build #RR-241.14494.158"),
);
assert!(detect_buffered_streaming_client(&headers));
}
#[test]
fn test_detect_buffered_streaming_client_intellij() {
let mut headers = HeaderMap::new();
headers.insert("user-agent", HeaderValue::from_static("IntelliJ IDEA/2024.1"));
assert!(detect_buffered_streaming_client(&headers));
}
#[test]
fn test_detect_problematic_client_goose() {
let mut headers = HeaderMap::new();
headers.insert("user-agent", HeaderValue::from_static("goose/1.0.0"));
assert!(detect_problematic_client(&headers));
}
#[test]
fn test_detect_problematic_client_curl() {
let mut headers = HeaderMap::new();
headers.insert("user-agent", HeaderValue::from_static("curl/7.68.0"));
assert!(detect_problematic_client(&headers));
}
#[test]
fn test_detect_problematic_client_no_sse_accept() {
let mut headers = HeaderMap::new();
headers.insert("user-agent", HeaderValue::from_static("CustomClient/1.0"));
headers.insert("accept", HeaderValue::from_static("application/json"));
assert!(detect_problematic_client(&headers));
}
#[test]
fn test_detect_buffered_streaming_client_chrome() {
let mut headers = HeaderMap::new();
headers.insert(
"user-agent",
HeaderValue::from_static(
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) \
Chrome/91.0.4472.124 Safari/537.36",
),
);
assert!(detect_buffered_streaming_client(&headers));
}
#[test]
fn test_detect_buffered_streaming_client_vscode() {
let mut headers = HeaderMap::new();
headers.insert("user-agent", HeaderValue::from_static("Visual Studio Code 1.85.0"));
assert!(detect_buffered_streaming_client(&headers));
}
#[test]
fn test_normal_client_not_problematic() {
let mut headers = HeaderMap::new();
headers.insert("user-agent", HeaderValue::from_static("OpenAI-Client/1.0"));
headers.insert("accept", HeaderValue::from_static("text/event-stream, application/json"));
assert!(!detect_problematic_client(&headers));
assert!(!detect_buffered_streaming_client(&headers));
}
#[test]
fn test_determine_streaming_behavior_auto_mode() {
use crate::config::{
AuthConfig, Config, LogLevel, ServerConfig, ServiceAccountKey, StreamingConfig,
StreamingMode,
};
let service_account_key = ServiceAccountKey {
account_type: "service_account".to_string(),
project_id: "test".to_string(),
private_key_id: "test".to_string(),
private_key: "test".to_string(),
client_email: "test".to_string(),
client_id: "test".to_string(),
auth_uri: "test".to_string(),
token_uri: "test".to_string(),
auth_provider_x509_cert_url: "test".to_string(),
client_x509_cert_url: "test".to_string(),
universe_domain: None,
};
let vertex = VertexProvider {
predict_resource_url: "https://test.example.com/v1/test-model".to_string(),
display_model: "test".to_string(),
auth: AuthStrategy::GcpOAuth2(service_account_key),
};
let config = Config {
server: ServerConfig {
port: 3000,
log_level: LogLevel::Info,
enable_retries: true,
max_retry_attempts: 3,
},
auth: AuthConfig::default(),
streaming: StreamingConfig {
mode: StreamingMode::Auto,
buffer_size: 65536,
chunk_timeout_ms: 5000,
},
vertex: None,
llm_provider: Some(LlmProviderConfig::Vertex(vertex)),
};
let mut headers = HeaderMap::new();
headers.insert("user-agent", HeaderValue::from_static("goose/1.0.0"));
let (force_non_streaming, use_buffered) = determine_streaming_behavior(&config, &headers);
assert!(force_non_streaming);
assert!(!use_buffered);
let mut headers = HeaderMap::new();
headers.insert("user-agent", HeaderValue::from_static("Mozilla/5.0 Chrome/91.0"));
headers.insert("accept", HeaderValue::from_static("text/event-stream"));
let (force_non_streaming, use_buffered) = determine_streaming_behavior(&config, &headers);
assert!(!force_non_streaming);
assert!(use_buffered);
let mut headers = HeaderMap::new();
headers.insert("user-agent", HeaderValue::from_static("curl/7.68.0"));
let (force_non_streaming, use_buffered) = determine_streaming_behavior(&config, &headers);
assert!(force_non_streaming);
assert!(!use_buffered);
}
#[test]
fn test_determine_streaming_behavior_non_streaming_mode() {
use crate::config::{
AuthConfig, Config, LogLevel, ServerConfig, ServiceAccountKey, StreamingConfig,
StreamingMode,
};
let service_account_key = ServiceAccountKey {
account_type: "service_account".to_string(),
project_id: "test".to_string(),
private_key_id: "test".to_string(),
private_key: "test".to_string(),
client_email: "test".to_string(),
client_id: "test".to_string(),
auth_uri: "test".to_string(),
token_uri: "test".to_string(),
auth_provider_x509_cert_url: "test".to_string(),
client_x509_cert_url: "test".to_string(),
universe_domain: None,
};
let vertex = VertexProvider {
predict_resource_url: "https://test.example.com/v1/test-model".to_string(),
display_model: "test".to_string(),
auth: AuthStrategy::GcpOAuth2(service_account_key),
};
let config = Config {
server: ServerConfig {
port: 3000,
log_level: LogLevel::Info,
enable_retries: true,
max_retry_attempts: 3,
},
auth: AuthConfig::default(),
streaming: StreamingConfig {
mode: StreamingMode::Never,
buffer_size: 65536,
chunk_timeout_ms: 5000,
},
vertex: None,
llm_provider: Some(LlmProviderConfig::Vertex(vertex)),
};
let headers = HeaderMap::new();
let (force_non_streaming, use_buffered) = determine_streaming_behavior(&config, &headers);
assert!(force_non_streaming);
assert!(!use_buffered);
}
}