use std::pin::Pin;
use bytes::Bytes;
use futures::Stream;
use crate::error::Result;
use crate::ir::{Capabilities, ModelRequest, ModelResponse, ModelWarning, OutputStrategy};
use crate::rate_limit::RateLimitSnapshot;
use crate::stream::StreamDelta;
/// Pinned, boxed, `Send` stream of fallible byte chunks that may borrow for `'a`.
///
/// This is the input of `Codec::decode_stream`; its default implementation
/// concatenates every chunk into a single body before decoding.
pub type BoxByteStream<'a> = Pin<Box<dyn Stream<Item = Result<Bytes>> + Send + 'a>>;
/// Pinned, boxed, `Send` stream of fallible [`StreamDelta`]s that may borrow for `'a`;
/// this is the output of `Codec::decode_stream`.
pub type BoxDeltaStream<'a> = Pin<Box<dyn Stream<Item = Result<StreamDelta>> + Send + 'a>>;
/// HTTP request description produced by `Codec::encode` / `Codec::encode_streaming`.
#[derive(Clone, Debug)]
pub struct EncodedRequest {
/// HTTP method (`POST` when built with `EncodedRequest::post_json`).
pub method: http::Method,
/// Request path; presumably relative to the provider's base URL — TODO confirm with the transport.
pub path: String,
/// Request headers; `post_json` sets only `Content-Type: application/json`.
pub headers: http::HeaderMap,
/// Encoded request body.
pub body: Bytes,
/// Streaming flag: `false` from `post_json`, set to `true` by `into_streaming`
/// (which the default `Codec::encode_streaming` applies).
pub streaming: bool,
/// Warnings attached while encoding (empty from `post_json`); presumably surfaced
/// to the caller with the response — confirm.
pub warnings: Vec<ModelWarning>,
}
impl EncodedRequest {
    /// Builds a non-streaming `POST` request to `path` whose body is already-serialized JSON.
    ///
    /// The only header set is `Content-Type: application/json`, and `warnings` starts empty.
    pub fn post_json(path: impl Into<String>, body: Bytes) -> Self {
        let json_content_type = http::HeaderValue::from_static("application/json");
        let headers = {
            let mut map = http::HeaderMap::new();
            map.insert(http::header::CONTENT_TYPE, json_content_type);
            map
        };
        Self {
            path: path.into(),
            body,
            headers,
            method: http::Method::POST,
            warnings: Vec::new(),
            streaming: false,
        }
    }

    /// Consumes the request and returns it with `streaming` set to `true`;
    /// every other field is left untouched.
    #[must_use]
    pub const fn into_streaming(mut self) -> Self {
        self.streaming = true;
        self
    }
}
pub trait Codec: Send + Sync + 'static {
fn name(&self) -> &'static str;
fn capabilities(&self, model: &str) -> Capabilities;
fn auto_output_strategy(&self, _model: &str) -> OutputStrategy {
OutputStrategy::Native
}
fn encode(&self, request: &ModelRequest) -> Result<EncodedRequest>;
fn decode(&self, body: &[u8], warnings_in: Vec<ModelWarning>) -> Result<ModelResponse>;
fn extract_rate_limit(&self, _headers: &http::HeaderMap) -> Option<RateLimitSnapshot> {
None
}
fn encode_streaming(&self, request: &ModelRequest) -> Result<EncodedRequest> {
Ok(self.encode(request)?.into_streaming())
}
#[allow(tail_expr_drop_order)]
fn decode_stream<'a>(
&'a self,
bytes: BoxByteStream<'a>,
warnings_in: Vec<ModelWarning>,
) -> BoxDeltaStream<'a> {
Box::pin(async_stream::stream! {
let mut buf: Vec<u8> = Vec::new();
let mut bytes = bytes;
while let Some(chunk) = futures::StreamExt::next(&mut bytes).await {
let chunk = match chunk {
Ok(b) => b,
Err(e) => {
yield Err(e);
return;
}
};
buf.extend_from_slice(&chunk);
}
let response = match self.decode(&buf, warnings_in) {
Ok(r) => r,
Err(e) => {
yield Err(e);
return;
}
};
for delta in deltas_from_response(&response) {
yield Ok(delta);
}
})
}
}
pub fn service_tier_str(tier: crate::ir::ServiceTier) -> &'static str {
match tier {
crate::ir::ServiceTier::Auto => "auto",
crate::ir::ServiceTier::Default => "default",
crate::ir::ServiceTier::Flex => "flex",
crate::ir::ServiceTier::Priority => "priority",
crate::ir::ServiceTier::Scale => "scale",
}
}
/// Collects OpenAI-style `x-ratelimit-*` response headers into a [`RateLimitSnapshot`].
///
/// - `x-ratelimit-remaining-requests` / `-tokens` are recorded (typed field plus the raw
///   text) only when they parse as `u64`.
/// - `x-ratelimit-reset-requests` / `-tokens` are recorded as raw text only.
///
/// Header values that are not valid `&str` are ignored. Returns `None` when nothing
/// was recorded.
pub fn extract_openai_rate_limit(headers: &http::HeaderMap) -> Option<RateLimitSnapshot> {
    /// Value of header `name` as text, or `None` when absent or not convertible.
    fn text_of<'h>(headers: &'h http::HeaderMap, name: &str) -> Option<&'h str> {
        headers.get(name)?.to_str().ok()
    }

    let mut snapshot = RateLimitSnapshot::default();
    let mut any_recorded = false;

    let counters = [
        ("x-ratelimit-remaining-requests", &mut snapshot.requests_remaining),
        ("x-ratelimit-remaining-tokens", &mut snapshot.tokens_remaining),
    ];
    for (name, slot) in counters {
        let Some(text) = text_of(headers, name) else {
            continue;
        };
        let Ok(value) = text.parse::<u64>() else {
            continue;
        };
        *slot = Some(value);
        snapshot.raw.insert(name.to_owned(), text.to_owned());
        any_recorded = true;
    }

    for name in ["x-ratelimit-reset-requests", "x-ratelimit-reset-tokens"] {
        if let Some(text) = text_of(headers, name) {
            snapshot.raw.insert(name.to_owned(), text.to_owned());
            any_recorded = true;
        }
    }

    any_recorded.then_some(snapshot)
}
/// Replays a fully decoded response as the delta sequence a streaming decoder emits.
///
/// The sequence is, in order:
/// - `Start` (with empty `provider_echoes`);
/// - per content part: `TextDelta` for text, `ThinkingDelta` for thinking, and
///   `ToolUseStart` / `ToolUseInputDelta` (input serialized with `to_string`) /
///   `ToolUseStop` for tool use;
/// - `Usage`;
/// - one `Warning` per response warning;
/// - `Stop`.
///
/// Parts with no delta form here are skipped: images, audio, video, documents,
/// redacted thinking, citations, tool results, and image/audio output.
fn deltas_from_response(response: &ModelResponse) -> Vec<StreamDelta> {
    use crate::ir::ContentPart as Part;

    let mut out = vec![StreamDelta::Start {
        id: response.id.clone(),
        model: response.model.clone(),
        provider_echoes: Vec::new(),
    }];

    for part in &response.content {
        match part {
            Part::Text {
                text,
                provider_echoes,
                ..
            } => out.push(StreamDelta::TextDelta {
                text: text.clone(),
                provider_echoes: provider_echoes.clone(),
            }),
            Part::ToolUse {
                id,
                name,
                input,
                provider_echoes,
            } => out.extend([
                StreamDelta::ToolUseStart {
                    id: id.clone(),
                    name: name.clone(),
                    provider_echoes: provider_echoes.clone(),
                },
                StreamDelta::ToolUseInputDelta {
                    partial_json: input.to_string(),
                },
                StreamDelta::ToolUseStop,
            ]),
            Part::Thinking {
                text,
                provider_echoes,
                ..
            } => out.push(StreamDelta::ThinkingDelta {
                text: text.clone(),
                provider_echoes: provider_echoes.clone(),
            }),
            // Listed explicitly (no `_`) so a new variant forces a decision here.
            Part::Image { .. }
            | Part::Audio { .. }
            | Part::Video { .. }
            | Part::Document { .. }
            | Part::RedactedThinking { .. }
            | Part::Citation { .. }
            | Part::ToolResult { .. }
            | Part::ImageOutput { .. }
            | Part::AudioOutput { .. } => {}
        }
    }

    out.push(StreamDelta::Usage(response.usage.clone()));
    out.extend(response.warnings.iter().cloned().map(StreamDelta::Warning));
    out.push(StreamDelta::Stop {
        stop_reason: response.stop_reason.clone(),
    });
    out
}
/// Parses a provider response `body` as JSON on behalf of the codec named `codec_name`.
///
/// # Errors
///
/// Returns `Error::provider_network` when `body` is not valid JSON. The message includes:
/// - the serde error;
/// - the body length;
/// - a preview of at most the first 200 bytes, cut so no UTF-8 character is split and
///   decoded lossily;
/// - a trailing `…` whenever that preview does not cover the whole body.
pub(super) fn parse_response_body(
    body: &[u8],
    codec_name: &'static str,
) -> Result<serde_json::Value> {
    serde_json::from_slice(body).map_err(|e| {
        const PEEK_BYTES: usize = 200;
        let peek_end = peek_at_char_boundary(body, PEEK_BYTES);
        let peek = body.get(..peek_end).map_or_else(String::new, |slice| {
            String::from_utf8_lossy(slice).into_owned()
        });
        // Mark truncation relative to what is actually shown: the preview can stop
        // short of PEEK_BYTES (and even of a short body) when it is pulled back to a
        // character boundary, and the reader must still see that bytes were omitted.
        let suffix = if peek_end < body.len() { "…" } else { "" };
        crate::error::Error::provider_network(format!(
            "{codec_name} codec failed to decode response: {e}; \
             body was {} bytes; first {peek_end} bytes: {peek:?}{suffix} \
             — the response did not parse as JSON; the upstream may have \
             returned an HTML error page, a truncated body, or a wire \
             format the codec does not yet understand",
            body.len(),
        ))
    })
}
/// Returns the largest cut `<= max` (clamped to `body.len()`) that does not split a
/// UTF-8 encoded character at the cut, so `body[..cut]` can be previewed without
/// mangling its last character.
///
/// Only the bytes at the boundary are inspected. A UTF-8 character is at most four
/// bytes, so the cut moves back over at most three continuation bytes (`10xx_xxxx`).
/// Invalid UTF-8 elsewhere in the prefix is left in place for `String::from_utf8_lossy`
/// to replace. That way, binary or non-UTF-8 bodies (compressed payloads, Latin-1 pages)
/// still yield a useful preview instead of collapsing to nothing.
///
/// If no character start is found in that window, the data there is not UTF-8, and the
/// clamped `max` is returned unchanged.
fn peek_at_char_boundary(body: &[u8], max: usize) -> usize {
    /// UTF-8 continuation bytes have the bit pattern `10xx_xxxx`.
    const fn is_continuation(byte: u8) -> bool {
        (byte & 0b1100_0000) == 0b1000_0000
    }

    let limit = max.min(body.len());
    // A cut at `i` is safe when `body[i]` (the first excluded byte) starts a new
    // character, or when `i` is the end of the body.
    (limit.saturating_sub(3)..=limit)
        .rev()
        .find(|&cut| !body.get(cut).is_some_and(|&b| is_continuation(b)))
        .unwrap_or(limit)
}