use super::Provider;
use crate::error::AppError;
use async_trait::async_trait;
use axum::http::HeaderMap;
use tracing::{debug, error};
use axum::{
body::{Body, to_bytes},
http::{HeaderValue, Response},
};
use serde_json::Value;
use uuid::Uuid;
/// Provider adapter for the Together API (`name()` reports `"together"`).
///
/// Holds only the upstream root URL. Header and response processing live in
/// the `Provider` implementation.
#[derive(Debug, Clone)]
pub struct TogetherProvider {
    /// Root URL of the upstream API, as reported by `Provider::base_url`.
    base_url: String,
}
impl TogetherProvider {
pub fn new() -> Self {
Self {
base_url: "https://api.together.xyz".to_string(),
}
}
}
#[async_trait]
impl Provider for TogetherProvider {
fn base_url(&self) -> String {
self.base_url.clone()
}
fn name(&self) -> &str {
"together"
}
fn process_headers(&self, original_headers: &HeaderMap) -> Result<HeaderMap, AppError> {
debug!("Processing Together request headers");
let mut headers = HeaderMap::new();
headers.insert(
http::header::CONTENT_TYPE,
http::header::HeaderValue::from_static("application/json"),
);
if let Some(auth) = original_headers
.get(http::header::AUTHORIZATION)
.and_then(|h| h.to_str().ok())
{
if !auth.starts_with("Bearer ") {
error!(
"Invalid authorization format for Together request - must start with 'Bearer '"
);
return Err(AppError::InvalidHeader);
}
if auth.len() <= 7 {
error!("Empty Bearer token in Together authorization header");
return Err(AppError::InvalidHeader);
}
debug!("Using provided authorization header for Together");
headers.insert(
http::header::AUTHORIZATION,
http::header::HeaderValue::from_str(auth).map_err(|_| {
error!("Invalid characters in Together authorization header");
AppError::InvalidHeader
})?,
);
} else {
error!("Missing Bearer token in Authorization header for Together request");
return Err(AppError::MissingApiKey);
}
Ok(headers)
}
async fn process_response(&self, response: Response<Body>) -> Result<Response<Body>, AppError> {
let (mut parts, body) = response.into_parts();
let is_streaming = parts.headers.get(http::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map_or(false, |ct| ct.contains("text/event-stream"));
if is_streaming {
let has_request_id = parts.headers.get("x-request-id").is_some();
if !has_request_id {
let generated_id = format!("req_{}", Uuid::new_v4().simple());
debug!("Generated request ID for Together streaming response: {}", generated_id);
if let Ok(header_value) = HeaderValue::from_str(&generated_id) {
parts.headers.insert("x-request-id", header_value);
}
}
return Ok(Response::from_parts(parts, body));
}
let bytes = to_bytes(body, usize::MAX).await?;
let has_request_id = parts.headers.get("x-request-id").is_some();
if !has_request_id {
if let Ok(json) = serde_json::from_slice::<Value>(&bytes) {
let body_request_id = json.get("id").and_then(|v| v.as_str()).map(|id| id.to_string());
if let Some(id) = body_request_id {
debug!("Adding Together request ID from body to response headers: {}", id);
if let Ok(header_value) = HeaderValue::from_str(&id) {
parts.headers.insert("x-request-id", header_value);
}
} else {
let generated_id = format!("req_{}", Uuid::new_v4().simple());
debug!("Generated request ID for Together response: {}", generated_id);
if let Ok(header_value) = HeaderValue::from_str(&generated_id) {
parts.headers.insert("x-request-id", header_value);
}
}
return Ok(Response::from_parts(parts, Body::from(bytes)));
}
}
Ok(Response::from_parts(parts, Body::from(bytes)))
}
}