use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use reqwest::Client;
use reqwest_eventsource::EventSource;
use serde::Deserialize;
use crate::config::Config;
use crate::error::{OtariError, Result};
use crate::types::{
Batch, BatchResult, ChatCompletion, CompletionParams, CompletionStream, CreateBatchParams,
ListBatchesOptions, ModerationParams, ModerationResponse, RerankParams, RerankResponse,
};
mod models;
use models::request::GatewayRequest;
use models::response::GatewayResponse;
use models::stream::GatewayStream;
/// Header that carries the gateway API key (sent as `Bearer <key>`) when not in platform mode.
const OTARI_HEADER_NAME: &str = "Otari-Key";
/// Environment variable read for a platform token; also named in the missing-token error.
const OTARI_PLATFORM_TOKEN_ENV: &str = "OTARI_PLATFORM_TOKEN";
/// Environment variable used as the gateway base URL when the config does not set one.
const OTARI_API_BASE_ENV: &str = "OTARI_API_BASE";
/// Client for an Otari gateway.
///
/// Holds an HTTP client whose default headers carry the resolved credentials,
/// the gateway base URL (stored without a trailing `/`), and whether requests
/// authenticate with a platform token.
pub struct Otari {
    /// HTTP client preconfigured with `Content-Type` and authentication headers.
    client: Client,
    /// Gateway base URL; endpoint paths such as `/v1/batches` are appended to it.
    api_base: String,
    /// `true` when authenticating via `Authorization: Bearer <platform token>`.
    platform_mode: bool,
}

// Hand-written rather than derived so that formatting an `Otari` never prints
// the HTTP client, whose default headers contain credentials.
impl std::fmt::Debug for Otari {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Otari")
            .field("api_base", &self.api_base)
            .field("platform_mode", &self.platform_mode)
            .finish_non_exhaustive()
    }
}
impl Otari {
    /// Builds a gateway client from `config`.
    ///
    /// The API base URL is taken from `config.api_base`, falling back to the
    /// `OTARI_API_BASE` environment variable. Blank values are treated as unset,
    /// and trailing `/` characters are stripped.
    ///
    /// Credentials are resolved by `resolve_auth` from these sources:
    /// - `config.api_key`
    /// - the `platform_token` entry of `config.extra`
    /// - the `platform_mode` entry of `config.extra`, where the value `true`
    ///   (compared case-insensitively) forces platform mode and any other value
    ///   disables it
    /// - the `OTARI_PLATFORM_TOKEN` environment variable
    ///
    /// The resolved headers become default headers of the HTTP client.
    ///
    /// # Errors
    ///
    /// Returns a provider error when no API base is available or the HTTP client
    /// cannot be built, and propagates any error from `resolve_auth`.
    pub fn from_config(config: Config) -> Result<Self> {
        let api_base = config
            .api_base
            // A blank value would otherwise yield relative, unusable request URLs.
            .filter(|base| !base.trim().is_empty())
            .or_else(|| {
                std::env::var(OTARI_API_BASE_ENV)
                    .ok()
                    .filter(|base| !base.trim().is_empty())
            })
            .ok_or_else(|| {
                OtariError::provider_error(format!(
                    "api_base is required (set via config or {OTARI_API_BASE_ENV} env var)"
                ))
            })?
            .trim_end_matches('/')
            .to_string();

        let platform_token_env = std::env::var(OTARI_PLATFORM_TOKEN_ENV).ok();
        let explicit_platform_token = config.extra.get("platform_token").cloned();
        let explicit_platform_mode = config
            .extra
            .get("platform_mode")
            .map(|v| v.trim().eq_ignore_ascii_case("true"));
        let (platform_mode, headers) = resolve_auth(
            config.api_key,
            explicit_platform_token,
            explicit_platform_mode,
            platform_token_env,
        )?;

        let client = Client::builder()
            .default_headers(headers)
            .build()
            .map_err(|e| OtariError::provider_error(format!("Failed to build HTTP client: {e}")))?;
        Ok(Self {
            client,
            api_base,
            platform_mode,
        })
    }

    /// Returns `true` when requests authenticate with a platform token
    /// (`Authorization` header) instead of an `Otari-Key` API key.
    pub fn is_platform_mode(&self) -> bool {
        self.platform_mode
    }

    /// Joins an absolute endpoint `path` (starting with `/`) onto the API base.
    fn url_for(&self, path: &str) -> String {
        format!("{}{path}", self.api_base)
    }

    /// Sends a chat completion request to `POST /v1/chat/completions`.
    ///
    /// # Errors
    ///
    /// Fails if `params` cannot be converted into a gateway request, on transport
    /// or decoding errors, and with an error from `convert_error` for any
    /// non-2xx response.
    pub async fn completion(&self, params: CompletionParams) -> Result<ChatCompletion> {
        let body: GatewayRequest = params.try_into()?;
        let response = self
            .client
            .post(self.url_for("/v1/chat/completions"))
            .json(&body)
            .send()
            .await?;
        if !response.status().is_success() {
            return Err(convert_error(response).await);
        }
        Ok(response.json::<GatewayResponse>().await?.into())
    }

    /// Opens a server-sent-events stream for `POST /v1/chat/completions`.
    ///
    /// # Errors
    ///
    /// Fails if `params` cannot be converted into a gateway request, if the
    /// event source cannot be created (reported as a streaming error), or if the
    /// `GatewayStream` cannot be converted into a `CompletionStream`.
    // The body never awaits (`EventSource::new` is synchronous); `async` is
    // kept so the signature matches the other request methods.
    #[allow(clippy::unused_async)]
    pub async fn completion_stream(&self, params: CompletionParams) -> Result<CompletionStream> {
        let model = params.model_id.clone();
        let body = TryInto::<GatewayRequest>::try_into(params)?.stream();
        let request = self
            .client
            .post(self.url_for("/v1/chat/completions"))
            .json(&body);
        let es = EventSource::new(request).map_err(|e| OtariError::Streaming {
            provider: "otari".into(),
            message: e.to_string().into(),
        })?;
        GatewayStream::new(es, model).try_into()
    }

    /// Reranks documents via `POST /v1/rerank`.
    ///
    /// # Errors
    ///
    /// Fails on transport or decoding errors and, for any non-2xx response,
    /// with an error from `convert_error`.
    pub async fn rerank(&self, params: RerankParams) -> Result<RerankResponse> {
        let body = models::rerank::GatewayRerankRequest::from(params);
        let response = self
            .client
            .post(self.url_for("/v1/rerank"))
            .json(&body)
            .send()
            .await?;
        if !response.status().is_success() {
            return Err(convert_error(response).await);
        }
        Ok(response.json::<RerankResponse>().await?)
    }

    /// Creates a batch via `POST /v1/batches`.
    ///
    /// Any 2xx status is accepted as success, including e.g. `201 Created`, so a
    /// batch that was created is never reported as a failure.
    ///
    /// # Errors
    ///
    /// Fails on transport or decoding errors and, for non-2xx responses, with an
    /// error from `convert_batch_error`.
    pub async fn create_batch(&self, params: CreateBatchParams) -> Result<Batch> {
        let path = "/v1/batches";
        let response = self
            .client
            .post(self.url_for(path))
            .json(&params)
            .send()
            .await?;
        if !response.status().is_success() {
            return Err(convert_batch_error(response, path).await);
        }
        Ok(response.json::<Batch>().await?)
    }

    /// Fetches a batch via `GET /v1/batches/{batch_id}?provider=...`.
    ///
    /// # Errors
    ///
    /// Fails on transport or decoding errors and, for non-2xx responses, with an
    /// error from `convert_batch_error`.
    pub async fn retrieve_batch(&self, batch_id: &str, provider: &str) -> Result<Batch> {
        let path = format!("/v1/batches/{batch_id}");
        let response = self
            .client
            .get(self.url_for(&path))
            .query(&[("provider", provider)])
            .send()
            .await?;
        if !response.status().is_success() {
            return Err(convert_batch_error(response, &path).await);
        }
        Ok(response.json::<Batch>().await?)
    }

    /// Cancels a batch via `POST /v1/batches/{batch_id}/cancel?provider=...`.
    ///
    /// # Errors
    ///
    /// Fails on transport or decoding errors and, for non-2xx responses, with an
    /// error from `convert_batch_error`.
    pub async fn cancel_batch(&self, batch_id: &str, provider: &str) -> Result<Batch> {
        let path = format!("/v1/batches/{batch_id}/cancel");
        let response = self
            .client
            .post(self.url_for(&path))
            .query(&[("provider", provider)])
            .send()
            .await?;
        if !response.status().is_success() {
            return Err(convert_batch_error(response, &path).await);
        }
        Ok(response.json::<Batch>().await?)
    }

    /// Lists batches via `GET /v1/batches`.
    ///
    /// Sends `provider` plus the optional `after` cursor and `limit` from
    /// `options` as query parameters, and returns the `data` array of the
    /// response.
    ///
    /// # Errors
    ///
    /// Fails on transport or decoding errors and, for non-2xx responses, with an
    /// error from `convert_batch_error`.
    pub async fn list_batches(
        &self,
        provider: &str,
        options: ListBatchesOptions,
    ) -> Result<Vec<Batch>> {
        /// Response envelope of the list endpoint; only `data` is used.
        #[derive(Deserialize)]
        struct ListResponse {
            data: Vec<Batch>,
        }

        let path = "/v1/batches";
        let mut query: Vec<(&str, String)> = Vec::with_capacity(3);
        query.push(("provider", provider.to_string()));
        if let Some(after) = options.after {
            query.push(("after", after));
        }
        if let Some(limit) = options.limit {
            query.push(("limit", limit.to_string()));
        }
        let response = self
            .client
            .get(self.url_for(path))
            .query(&query)
            .send()
            .await?;
        if !response.status().is_success() {
            return Err(convert_batch_error(response, path).await);
        }
        let list: ListResponse = response.json().await?;
        Ok(list.data)
    }

    /// Fetches batch results via `GET /v1/batches/{batch_id}/results?provider=...`.
    ///
    /// # Errors
    ///
    /// Fails on transport or decoding errors and, for non-2xx responses, with an
    /// error from `convert_batch_error`. A 409 is reported as batch-not-complete
    /// by that function.
    pub async fn retrieve_batch_results(
        &self,
        batch_id: &str,
        provider: &str,
    ) -> Result<BatchResult> {
        let path = format!("/v1/batches/{batch_id}/results");
        let response = self
            .client
            .get(self.url_for(&path))
            .query(&[("provider", provider)])
            .send()
            .await?;
        if !response.status().is_success() {
            return Err(convert_batch_error(response, &path).await);
        }
        Ok(response.json::<BatchResult>().await?)
    }

    /// Runs content moderation via `POST /v1/moderations`.
    ///
    /// When `params.include_raw` is set, `include_raw=true` is added as a query
    /// parameter.
    ///
    /// # Errors
    ///
    /// Fails if `params` cannot be serialized, on transport errors, for non-2xx
    /// responses (via `convert_error`, which maps "does not support moderation"
    /// 400s to an unsupported-operation error), and if the body is not a valid
    /// `ModerationResponse`.
    pub async fn moderation(&self, params: ModerationParams) -> Result<ModerationResponse> {
        let body = serde_json::to_value(&params)?;
        let mut request = self.client.post(self.url_for("/v1/moderations"));
        if params.include_raw {
            request = request.query(&[("include_raw", "true")]);
        }
        let response = request.json(&body).send().await?;
        if !response.status().is_success() {
            return Err(convert_error(response).await);
        }
        // Decoded with serde_json (not `Response::json`) so decode failures
        // surface as JSON errors rather than transport errors.
        let body_bytes = response.bytes().await?;
        serde_json::from_slice::<ModerationResponse>(&body_bytes).map_err(OtariError::from)
    }
}
fn resolve_auth(
api_key: Option<String>,
platform_token: Option<String>,
platform_mode: Option<bool>,
platform_token_env: Option<String>,
) -> Result<(bool, HeaderMap)> {
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
if platform_mode == Some(true) {
let token = platform_token
.or(api_key)
.or(platform_token_env)
.ok_or_else(|| OtariError::MissingApiKey {
provider: "otari".into(),
env_var: OTARI_PLATFORM_TOKEN_ENV.into(),
})?;
let val = format!("Bearer {token}");
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&val)
.map_err(|e| OtariError::provider_error(format!("Invalid platform token: {e}")))?,
);
return Ok((true, headers));
}
if platform_mode.is_none() && api_key.is_none() {
if let Some(token) = platform_token.or(platform_token_env) {
let val = format!("Bearer {token}");
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&val).map_err(|e| {
OtariError::provider_error(format!("Invalid platform token: {e}"))
})?,
);
return Ok((true, headers));
}
}
let key = api_key
.or_else(|| std::env::var("OTARI_API_KEY").ok())
.unwrap_or_default();
if !key.is_empty() {
let val = format!("Bearer {key}");
headers.insert(
OTARI_HEADER_NAME,
HeaderValue::from_str(&val)
.map_err(|e| OtariError::provider_error(format!("Invalid API key: {e}")))?,
);
}
Ok((false, headers))
}
async fn convert_error(response: reqwest::Response) -> OtariError {
let status = response.status().as_u16();
let correlation_id = response
.headers()
.get("x-correlation-id")
.and_then(|v| v.to_str().ok())
.map(String::from);
let retry_after = response
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.map(String::from);
let body = response.text().await.unwrap_or_default();
let message = extract_error_message(&body).unwrap_or_else(|| {
if body.is_empty() {
format!("HTTP {status}")
} else {
body.clone()
}
});
let detail = match &correlation_id {
Some(cid) => format!("{message} (correlation_id={cid})"),
None => message,
};
if status == 400 && detail.contains("does not support") && detail.contains("moderation") {
let provider = parse_unsupported_provider(&detail).unwrap_or_else(|| "unknown".to_string());
let operation = if detail.contains("multimodal") {
"multimodal_moderation"
} else {
"moderation"
};
return OtariError::unsupported_dynamic(provider, operation);
}
let detail_with_retry = match &retry_after {
Some(ra) => format!("{detail} (retry_after={ra})"),
None => detail,
};
match status {
401 | 403 => OtariError::authentication(detail_with_retry),
402 => OtariError::provider_error(format!("Insufficient funds: {detail_with_retry}")),
404 => OtariError::model_not_found(detail_with_retry),
429 => OtariError::rate_limit(detail_with_retry),
502 => OtariError::provider_error(format!("Upstream provider error: {detail_with_retry}")),
504 => OtariError::provider_error(format!("Gateway timeout: {detail_with_retry}")),
_ => OtariError::provider_error(format!("HTTP {status}: {detail_with_retry}")),
}
}
/// Pulls a human-readable error message out of a JSON error body.
///
/// Looks, in order, for:
/// 1. a string at `error.message`
/// 2. `error` itself when it is a string
/// 3. a top-level string `detail`
///
/// Returns `None` when the body is not valid JSON or none of these fields
/// holds a string.
fn extract_error_message(body: &str) -> Option<String> {
    let parsed = serde_json::from_str::<serde_json::Value>(body).ok()?;
    let from_error_field = parsed.get("error").and_then(|err| {
        err.get("message")
            .and_then(serde_json::Value::as_str)
            .or_else(|| err.as_str())
    });
    from_error_field
        .or_else(|| parsed.get("detail").and_then(serde_json::Value::as_str))
        .map(str::to_owned)
}
/// Extracts the provider name from a detail of the form
/// `"Provider <name> does not ..."`.
///
/// Returns `None` in these cases:
/// - the detail does not start with `"Provider "`
/// - the `" does not"` phrase is missing
/// - the name between them is blank
///
/// When the phrase is missing, the original returned the whole remainder of the
/// message as the provider name instead of `None`.
fn parse_unsupported_provider(detail: &str) -> Option<String> {
    let after = detail.strip_prefix("Provider ")?;
    let (name, _) = after.split_once(" does not")?;
    let name = name.trim();
    (!name.is_empty()).then(|| name.to_string())
}
/// Converts a non-success response from a batch endpoint into an [`OtariError`].
///
/// Two statuses get batch-specific handling:
///
/// * `409` becomes `OtariError::BatchNotComplete`.
///   - The batch id is parsed from the error detail (`... atch '<id>' ...`). If
///     the detail lacks it, the id segment of `path` is used (e.g.
///     `/v1/batches/<id>/results`).
///   - The status is parsed from `status: <value>` in the detail, defaulting to
///     `"unknown"`.
/// * `404` on a `/v1/batches` path is reported as a gateway without batch support.
///
/// Every other status is delegated to `convert_error`.
async fn convert_batch_error(response: reqwest::Response, path: &str) -> OtariError {
    /// Extracts `<id>` from `/v1/batches/<id>` or `/v1/batches/<id>/...`.
    fn batch_id_from_path(path: &str) -> Option<String> {
        path.strip_prefix("/v1/batches/")?
            .split('/')
            .next()
            .filter(|id| !id.is_empty())
            .map(str::to_owned)
    }

    let status = response.status().as_u16();
    let batch_specific = status == 409 || (status == 404 && path.contains("/v1/batches"));
    if !batch_specific {
        return convert_error(response).await;
    }

    let correlation_id = response
        .headers()
        .get("x-correlation-id")
        .and_then(|v| v.to_str().ok())
        .map(String::from);
    // Best-effort body read: an unreadable body still yields `HTTP {status}`.
    let body = response.text().await.unwrap_or_default();
    let message = match extract_error_message(&body) {
        Some(msg) => msg,
        None if body.is_empty() => format!("HTTP {status}"),
        None => body,
    };
    let detail = match correlation_id {
        Some(cid) => format!("{message} (correlation_id={cid})"),
        None => message,
    };

    if status == 409 {
        let batch_id = extract_batch_id_from_detail(&detail)
            .or_else(|| batch_id_from_path(path))
            .unwrap_or_default();
        let batch_status =
            extract_batch_status_from_detail(&detail).unwrap_or_else(|| "unknown".to_string());
        OtariError::BatchNotComplete {
            batch_id: batch_id.into(),
            status: batch_status.into(),
            provider: "otari".into(),
        }
    } else {
        // Only a 404 on a batches path reaches here.
        // NOTE(review): a 404 for an unknown batch id on `/v1/batches/<id>` is
        // also reported as missing batch support — confirm the gateway's 404 semantics.
        OtariError::Provider {
            message: format!(
                "This gateway does not support batch operations. Upgrade your gateway. ({detail})"
            )
            .into(),
            provider: "otari".into(),
        }
    }
}
/// Extracts the quoted batch id from an error detail such as
/// `"Batch 'batch_123' is not complete"`.
///
/// The marker omits the leading `B` so both `Batch '` and `batch '` match. The
/// id is the text between the marker and the next `'`. Returns `None` when the
/// marker or closing quote is missing, or when the quoted id is empty.
fn extract_batch_id_from_detail(detail: &str) -> Option<String> {
    const MARKER: &str = "atch '";
    let (_, tail) = detail.split_once(MARKER)?;
    let (id, _) = tail.split_once('\'')?;
    if id.is_empty() {
        return None;
    }
    Some(id.to_owned())
}
/// Extracts the batch status that follows `"status: "` in an error detail.
///
/// The status is the run of alphanumeric or `_` characters directly after the
/// marker, e.g. `in_progress` from `"(status: in_progress)"`. Returns `None`
/// when the marker is absent or no such characters follow it.
fn extract_batch_status_from_detail(detail: &str) -> Option<String> {
    let (_, tail) = detail.split_once("status: ")?;
    let status: String = tail
        .chars()
        .take_while(|c| c.is_alphanumeric() || *c == '_')
        .collect();
    (!status.is_empty()).then_some(status)
}