use std::collections::{HashMap, HashSet};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use typed_builder::TypedBuilder;
use validator::Validate;
/// Identifies one of the API formats this module distinguishes between.
///
/// Serialized and deserialized using snake_case variant names
/// (`anthropic_messages`, `openai_chat`, `openai_responses`).
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ApiFormat {
/// NOTE(review): presumably the Anthropic Messages API format — confirm.
AnthropicMessages,
/// NOTE(review): presumably the OpenAI Chat Completions format — confirm.
OpenaiChat,
/// NOTE(review): presumably the OpenAI Responses format — confirm.
OpenaiResponses,
}
/// Where the bytes of an image come from: carried inline or referenced by URL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImageSource {
/// Image payload carried directly in the value.
Inline {
/// Media type label for `data`.
/// NOTE(review): presumably a MIME type such as `image/png` — confirm.
media_type: String,
/// The image payload.
/// NOTE(review): whether this holds raw bytes or an encoded form (e.g. base64)
/// is not visible here — confirm against the producers/consumers.
data: Bytes,
},
/// Image referenced by location rather than embedded.
Url {
/// Location of the image.
url: String,
},
}
/// One unit of message content.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ContentBlock {
/// Plain text content.
Text { text: String },
/// Image content; see [`ImageSource`] for how the image is supplied.
Image { source: ImageSource },
/// A tool invocation.
ToolUse {
/// Identifier of this invocation.
/// NOTE(review): presumably what `ToolResult::tool_use_id` refers back to — confirm.
id: String,
/// Name of the tool being invoked.
name: String,
/// Arguments for the tool as an arbitrary JSON value.
input: serde_json::Value,
},
/// The outcome of a tool invocation.
ToolResult {
/// Identifier of the invocation this result belongs to.
tool_use_id: String,
/// Result payload, itself expressed as nested content blocks.
content: Vec<ContentBlock>,
},
/// Model "thinking" content.
Thinking {
/// The thinking text.
text: String,
/// Optional numeric usage figure attached to the block.
/// NOTE(review): presumably a token count — confirm.
usage: Option<u64>,
},
}
/// Why generation of a message stopped.
///
/// Serialized and deserialized using snake_case variant names
/// (`end_turn`, `max_tokens`, `tool_use`, `stop_sequence`, `content_filter`).
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum StopReason {
/// NOTE(review): presumably the model finished its turn normally — confirm.
EndTurn,
/// NOTE(review): presumably the output token limit was reached — confirm.
MaxTokens,
/// NOTE(review): presumably the model stopped to request a tool call — confirm.
ToolUse,
/// NOTE(review): presumably a configured stop sequence was produced — confirm.
StopSequence,
/// NOTE(review): presumably output was halted by a content filter — confirm.
ContentFilter,
}
/// Errors produced by this module.
///
/// Every variant carries a free-form `String` detail that is interpolated into
/// the `Display` text declared by its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum TransformError {
/// Displayed as `invalid format: <detail>`.
#[error("invalid format: {0}")]
InvalidFormat(String),
/// Displayed as `missing required field: <detail>`.
#[error("missing required field: {0}")]
MissingRequiredField(String),
/// Displayed as `buffer limit exceeded: <detail>`.
#[error("buffer limit exceeded: {0}")]
BufferLimitExceeded(String),
/// Displayed as `stream interrupted: <detail>`.
#[error("stream interrupted: {0}")]
StreamInterrupted(String),
/// Displayed as `upstream error: <detail>`.
#[error("upstream error: {0}")]
UpstreamError(String),
/// Displayed as `lossy downgrade: <detail>`.
#[error("lossy downgrade: {0}")]
LossyDowngrade(String),
}
impl TransformError {
#[must_use]
pub fn with_source(
self,
source: impl std::error::Error + Send + Sync + 'static,
) -> anyhow::Error {
anyhow::Error::new(self).context(source.to_string())
}
#[must_use]
pub fn sanitized_message(&self) -> String {
match self {
Self::InvalidFormat(_) => "invalid request format".to_string(),
Self::MissingRequiredField(field) => {
format!("missing required field: {field}")
}
Self::BufferLimitExceeded(_) => "request too large".to_string(),
Self::StreamInterrupted(_) => "stream was interrupted".to_string(),
Self::UpstreamError(_) => "upstream provider error".to_string(),
Self::LossyDowngrade(_) => "feature not supported".to_string(),
}
}
}
/// An incremental piece of content carried by [`StreamEvent::ContentBlockDelta`].
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum StreamDelta {
/// A fragment of text content.
Text {
text: String,
},
/// A fragment of thinking content.
Thinking {
thinking: String,
},
/// A signature value.
/// NOTE(review): presumably the signature attached to a thinking block — confirm.
Signature {
signature: String,
},
/// A fragment of JSON text.
/// NOTE(review): presumably a partial tool-input document to be concatenated
/// with the other fragments for the same block — confirm.
InputJson {
partial_json: String,
},
}
/// One event in a streamed message.
///
/// NOTE(review): the variant names mirror the Anthropic Messages streaming
/// event types; confirm that this is the intended canonical representation.
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum StreamEvent {
/// Start of a message, carrying its identifying metadata and a usage snapshot.
MessageStart {
/// Role attributed to the message.
role: String,
/// Identifier of the message.
message_id: String,
/// Name of the model.
model: String,
/// Usage counters reported with this event.
usage: Usage,
},
/// Start of the content block at position `index`.
ContentBlockStart {
index: usize,
/// The block being opened.
content_block: ContentBlock,
},
/// An incremental update for the content block at position `index`.
ContentBlockDelta {
index: usize,
delta: StreamDelta,
},
/// End of the content block at position `index`.
ContentBlockStop {
index: usize,
},
/// Message-level update carrying stop information and usage counters.
MessageDelta {
/// Why generation stopped, if reported.
stop_reason: Option<StopReason>,
/// The stop sequence string, if reported.
stop_sequence: Option<String>,
/// Usage counters reported with this event.
usage: Usage,
},
/// End of the message.
MessageStop,
/// An error reported within the stream.
Error {
/// Category label for the error.
error_type: String,
/// Description of the error.
message: String,
},
}
/// Usage counters attached to stream events and stream state.
///
/// Every field is marked `#[serde(default)]`, so a field that is absent during
/// deserialization is read as `0`. `Usage::default()` is all zeros.
/// NOTE(review): the fields look like token counts gathered from more than one
/// provider's naming scheme (e.g. `cache_read_input_tokens` vs `cached_tokens`)
/// — confirm which fields are populated for which format.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct Usage {
/// Input token count.
#[serde(default)]
pub input_tokens: u64,
/// Output token count.
#[serde(default)]
pub output_tokens: u64,
/// NOTE(review): presumably input tokens served from a cache — confirm.
#[serde(default)]
pub cache_read_input_tokens: u64,
/// NOTE(review): presumably input tokens written to a cache — confirm.
#[serde(default)]
pub cache_creation_input_tokens: u64,
/// NOTE(review): presumably a provider-specific cached-token count; its
/// relationship to `cache_read_input_tokens` is not visible here — confirm.
#[serde(default)]
pub cached_tokens: u64,
/// NOTE(review): presumably tokens spent on reasoning/thinking — confirm.
#[serde(default)]
pub reasoning_tokens: u64,
}
/// A request expressed as headers, a path, and an opaque body.
#[derive(Debug, Clone)]
pub struct TransformRequest {
/// Header name/value pairs. A `HashMap` holds one value per name.
/// NOTE(review): name casing/normalization is not enforced by this type — confirm
/// what callers expect.
pub headers: HashMap<String, String>,
/// Request path.
pub path: String,
/// Request body bytes.
pub body: Bytes,
}
/// A response expressed as headers, a path, and an opaque body.
///
/// Has the same field layout as [`TransformRequest`].
#[derive(Debug, Clone)]
pub struct TransformResponse {
/// Header name/value pairs. A `HashMap` holds one value per name.
pub headers: HashMap<String, String>,
/// Path associated with the response.
/// NOTE(review): presumably the path of the request that produced it — confirm.
pub path: String,
/// Response body bytes.
pub body: Bytes,
}
/// The kind of a content block being tracked in [`StreamState`].
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum StreamContentBlockKind {
/// A text block.
Text,
/// A thinking block.
Thinking,
/// A tool-use block.
ToolUse,
}
/// Bookkeeping held in [`StreamState::responses`].
///
/// NOTE(review): presumably the extra state needed when streaming in the
/// `ApiFormat::OpenaiResponses` format — confirm. All maps are keyed by `usize`;
/// presumably a content block or output item index — confirm against the code
/// that fills them. `Default` yields zero, `None`, and empty collections.
#[derive(Debug, Default)]
pub struct ResponsesStreamState {
/// Running sequence counter.
pub sequence_number: u64,
/// Creation timestamp, once known.
/// NOTE(review): unit (seconds vs milliseconds) is not visible here — confirm.
pub created_at: Option<u64>,
/// Item identifier per index.
pub item_ids: HashMap<usize, String>,
/// Call identifier per index.
pub call_ids: HashMap<usize, String>,
/// Tool name per index.
pub tool_names: HashMap<usize, String>,
/// Text collected per index.
pub text_fragments: HashMap<usize, String>,
/// Reasoning text collected per index.
pub reasoning_fragments: HashMap<usize, String>,
/// Function-call argument text collected per index.
pub function_arguments: HashMap<usize, String>,
/// Stop reason recorded for the stream, once known.
pub final_stop_reason: Option<StopReason>,
/// Indices already observed as tool blocks.
pub seen_tool_indices: HashSet<usize>,
}
/// Mutable state tracked across the events of a single stream.
///
/// Derives `Default`, `TypedBuilder`, and `Validate`. No `#[builder(...)]` or
/// `#[validate(...)]` field attributes are declared in this definition.
#[derive(Debug, Default, TypedBuilder, Validate)]
pub struct StreamState {
/// Whether the stream has started.
pub started: bool,
/// Whether the stream has finished.
pub finished: bool,
/// Identifier of the message being streamed, once known.
pub message_id: Option<String>,
/// Model name, once known.
pub model_name: Option<String>,
/// Running byte total.
/// NOTE(review): presumably compared against `MAX_SSE_STREAM_BYTES` — confirm.
pub total_buffer_bytes: usize,
/// Content block index counter.
/// NOTE(review): whether this is the next index to assign or the last one
/// assigned is not visible here — confirm.
pub content_block_index: usize,
/// Index of the content block that is currently open, if any.
pub active_content_block_index: Option<usize>,
/// Kind of the content block that is currently open, if any.
pub active_content_block_kind: Option<StreamContentBlockKind>,
/// Most recent usage counters.
pub last_usage: Usage,
/// String-to-string mapping used to correlate tools.
/// NOTE(review): key and value meaning are not visible here — confirm.
pub tool_correlation: HashMap<String, String>,
/// Index-to-index mapping for tool blocks.
/// NOTE(review): presumably source tool index to emitted content block index — confirm.
pub tool_block_indices: HashMap<usize, usize>,
/// Kind recorded for each content block index.
pub content_block_kinds: HashMap<usize, StreamContentBlockKind>,
/// Additional bookkeeping; see [`ResponsesStreamState`].
pub responses: ResponsesStreamState,
}
/// Maximum object/array nesting depth accepted by `validate_json_depth`.
pub const MAX_JSON_DEPTH: usize = 64;
/// Upper bound on a message count.
/// NOTE(review): the enforcing code is not visible here — confirm where this is applied.
pub const MAX_MESSAGES_COUNT: usize = 10_000;
/// Byte limit for an SSE stream: 1024 * 1024 bytes (1 MiB).
/// NOTE(review): the enforcing code is not visible here — confirm where this is applied.
pub const MAX_SSE_STREAM_BYTES: usize = 1024 * 1024;
pub fn validate_json_depth(value: &serde_json::Value) -> Result<(), TransformError> {
fn depth(v: &serde_json::Value) -> usize {
match v {
serde_json::Value::Object(map) => 1 + map.values().map(depth).max().unwrap_or(0),
serde_json::Value::Array(arr) => 1 + arr.iter().map(depth).max().unwrap_or(0),
_ => 0,
}
}
let d = depth(value);
if d > MAX_JSON_DEPTH {
Err(TransformError::InvalidFormat(
"JSON nesting depth exceeds maximum allowed".to_string(),
))
} else {
Ok(())
}
}