use std::collections::HashMap;
use std::time::Duration;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE, USER_AGENT};
use crate::chat::ChatService;
use crate::completion::CompletionService;
use crate::embedding::EmbeddingService;
use crate::error::{ApiError, Error};
use crate::model::ModelService;
/// SDK version string.
pub const VERSION: &str = "0.1.1";
/// API base URL used when the builder is not given one.
pub const DEFAULT_BASE_URL: &str = "https://api.openmodex.com/v1";
/// Request timeout applied to the HTTP client when none is configured (30 s).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Number of retries after the initial attempt when none is configured.
const DEFAULT_MAX_RETRIES: u32 = 2;
/// Client for the OpenModex HTTP API.
///
/// Create one with [`OpenModex::new`], [`OpenModex::from_env`] or
/// [`OpenModex::builder`].
///
/// `Debug` is implemented by hand so that the API key and the values of custom
/// default headers are never printed. The inner HTTP client is omitted from the
/// output too, because its default headers contain the `Authorization` header.
#[derive(Clone)]
pub struct OpenModex {
    /// HTTP client configured with the auth/user-agent headers and the timeout.
    pub(crate) http: reqwest::Client,
    /// Base URL that request paths are joined onto.
    pub(crate) base_url: String,
    /// API key sent as a bearer token.
    pub(crate) api_key: String,
    /// Number of retries after the first attempt for transient failures.
    pub(crate) max_retries: u32,
    /// Models tried, in order, when the primary model fails with a
    /// fallback-eligible error.
    pub(crate) fallback_models: Vec<String>,
    /// Default model configured on the builder.
    pub(crate) default_model: Option<String>,
    /// Extra headers added to every request.
    pub(crate) default_headers: HashMap<String, String>,
}

impl std::fmt::Debug for OpenModex {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenModex")
            .field("base_url", &self.base_url)
            .field("api_key", &"<redacted>")
            .field("max_retries", &self.max_retries)
            .field("fallback_models", &self.fallback_models)
            .field("default_model", &self.default_model)
            // Header values may hold credentials; show only the names.
            .field(
                "default_headers",
                &self.default_headers.keys().collect::<Vec<_>>(),
            )
            // `http` is intentionally skipped (its headers include the key).
            .finish_non_exhaustive()
    }
}
impl OpenModex {
pub fn new(api_key: impl Into<String>) -> Result<Self, Error> {
ClientBuilder::new().api_key(api_key).build()
}
pub fn from_env() -> Result<Self, Error> {
ClientBuilder::new().build()
}
pub fn builder() -> ClientBuilder {
ClientBuilder::new()
}
pub fn api_key(&self) -> &str {
&self.api_key
}
pub fn base_url(&self) -> &str {
&self.base_url
}
pub fn chat(&self) -> ChatService<'_> {
ChatService::new(self)
}
pub fn completions(&self) -> CompletionService<'_> {
CompletionService::new(self)
}
pub fn embeddings(&self) -> EmbeddingService<'_> {
EmbeddingService::new(self)
}
pub fn models(&self) -> ModelService<'_> {
ModelService::new(self)
}
pub(crate) async fn post<T: serde::de::DeserializeOwned>(
&self,
path: &str,
body: &impl serde::Serialize,
) -> Result<T, Error> {
let url = format!("{}{}", self.base_url, path);
let body_bytes = serde_json::to_vec(body)?;
let mut last_err: Option<Error> = None;
for attempt in 0..=self.max_retries {
if attempt > 0 {
let backoff = Duration::from_millis(500 * 2u64.pow(attempt - 1));
tokio::time::sleep(backoff).await;
}
let mut req = self
.http
.post(&url)
.header(CONTENT_TYPE, "application/json")
.header("Accept", "application/json")
.body(body_bytes.clone());
for (k, v) in &self.default_headers {
req = req.header(k.as_str(), v.as_str());
}
let resp = match req.send().await {
Ok(r) => r,
Err(e) => {
last_err = Some(Error::Http(e));
continue;
}
};
let status = resp.status().as_u16();
if status == 429 || status == 500 || status == 502 || status == 503 || status == 504 {
let api_err = parse_error_response(resp).await;
last_err = Some(api_err);
continue;
}
if status >= 400 {
return Err(parse_error_response(resp).await);
}
let result: T = resp.json().await.map_err(Error::Http)?;
return Ok(result);
}
Err(last_err.unwrap_or(Error::Timeout))
}
pub(crate) async fn get<T: serde::de::DeserializeOwned>(
&self,
path: &str,
) -> Result<T, Error> {
let url = format!("{}{}", self.base_url, path);
let mut last_err: Option<Error> = None;
for attempt in 0..=self.max_retries {
if attempt > 0 {
let backoff = Duration::from_millis(500 * 2u64.pow(attempt - 1));
tokio::time::sleep(backoff).await;
}
let mut req = self
.http
.get(&url)
.header("Accept", "application/json");
for (k, v) in &self.default_headers {
req = req.header(k.as_str(), v.as_str());
}
let resp = match req.send().await {
Ok(r) => r,
Err(e) => {
last_err = Some(Error::Http(e));
continue;
}
};
let status = resp.status().as_u16();
if status == 429 || status == 500 || status == 502 || status == 503 || status == 504 {
let api_err = parse_error_response(resp).await;
last_err = Some(api_err);
continue;
}
if status >= 400 {
return Err(parse_error_response(resp).await);
}
let result: T = resp.json().await.map_err(Error::Http)?;
return Ok(result);
}
Err(last_err.unwrap_or(Error::Timeout))
}
pub(crate) async fn post_stream(
&self,
path: &str,
body: &impl serde::Serialize,
) -> Result<reqwest::Response, Error> {
let url = format!("{}{}", self.base_url, path);
let body_bytes = serde_json::to_vec(body)?;
let mut last_err: Option<Error> = None;
for attempt in 0..=self.max_retries {
if attempt > 0 {
let backoff = Duration::from_millis(500 * 2u64.pow(attempt - 1));
tokio::time::sleep(backoff).await;
}
let mut req = self
.http
.post(&url)
.header(CONTENT_TYPE, "application/json")
.header("Accept", "text/event-stream")
.body(body_bytes.clone());
for (k, v) in &self.default_headers {
req = req.header(k.as_str(), v.as_str());
}
let resp = match req.send().await {
Ok(r) => r,
Err(e) => {
last_err = Some(Error::Http(e));
continue;
}
};
let status = resp.status().as_u16();
if status == 429 || status == 500 || status == 502 || status == 503 || status == 504 {
let api_err = parse_error_response(resp).await;
last_err = Some(api_err);
continue;
}
if status >= 400 {
return Err(parse_error_response(resp).await);
}
return Ok(resp);
}
Err(last_err.unwrap_or(Error::Timeout))
}
pub(crate) async fn with_fallback<T, F, Fut>(
&self,
primary_model: &str,
mut f: F,
) -> Result<T, Error>
where
F: FnMut(String) -> Fut,
Fut: std::future::Future<Output = Result<T, Error>>,
{
match f(primary_model.to_string()).await {
Ok(v) => return Ok(v),
Err(e) => {
if !e.should_fallback() || self.fallback_models.is_empty() {
return Err(e);
}
let mut _last_err = e;
for model in &self.fallback_models {
if model == primary_model {
continue;
}
match f(model.clone()).await {
Ok(v) => return Ok(v),
Err(e) => {
if !e.should_fallback() {
return Err(e);
}
_last_err = e;
}
}
}
Err(Error::AllFallbacksFailed)
}
}
}
}
/// Builder for [`OpenModex`] clients.
///
/// Every option is optional; unset options are resolved to defaults by
/// `build`. `Debug` is implemented by hand so that the API key and the values
/// of custom default headers are never printed.
#[derive(Clone)]
pub struct ClientBuilder {
    api_key: Option<String>,
    base_url: Option<String>,
    timeout: Option<Duration>,
    max_retries: Option<u32>,
    default_headers: HashMap<String, String>,
    fallback_models: Vec<String>,
    default_model: Option<String>,
}

impl std::fmt::Debug for ClientBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientBuilder")
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("base_url", &self.base_url)
            .field("timeout", &self.timeout)
            .field("max_retries", &self.max_retries)
            // Header values may hold credentials; show only the names.
            .field(
                "default_headers",
                &self.default_headers.keys().collect::<Vec<_>>(),
            )
            .field("fallback_models", &self.fallback_models)
            .field("default_model", &self.default_model)
            .finish()
    }
}
impl ClientBuilder {
pub fn new() -> Self {
Self {
api_key: None,
base_url: None,
timeout: None,
max_retries: None,
default_headers: HashMap::new(),
fallback_models: Vec::new(),
default_model: None,
}
}
pub fn api_key(mut self, key: impl Into<String>) -> Self {
let key = key.into();
if !key.is_empty() {
self.api_key = Some(key);
}
self
}
pub fn base_url(mut self, url: impl Into<String>) -> Self {
self.base_url = Some(url.into());
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
pub fn max_retries(mut self, n: u32) -> Self {
self.max_retries = Some(n);
self
}
pub fn default_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.default_headers.insert(key.into(), value.into());
self
}
pub fn default_headers(mut self, headers: HashMap<String, String>) -> Self {
self.default_headers.extend(headers);
self
}
pub fn fallback_models(mut self, models: impl IntoIterator<Item = impl Into<String>>) -> Self {
self.fallback_models = models.into_iter().map(|m| m.into()).collect();
self
}
pub fn default_model(mut self, model: impl Into<String>) -> Self {
self.default_model = Some(model.into());
self
}
pub fn build(self) -> Result<OpenModex, Error> {
let api_key = self
.api_key
.or_else(|| std::env::var("OPENMODEX_API_KEY").ok())
.ok_or(Error::MissingApiKey)?;
let base_url = self
.base_url
.unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
let timeout = self.timeout.unwrap_or(DEFAULT_TIMEOUT);
let max_retries = self.max_retries.unwrap_or(DEFAULT_MAX_RETRIES);
let mut default_headers = HeaderMap::new();
default_headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {api_key}"))
.expect("invalid API key characters"),
);
default_headers.insert(
USER_AGENT,
HeaderValue::from_static(concat!("openmodex-rust/", "0.1.1")),
);
let http = reqwest::Client::builder()
.timeout(timeout)
.default_headers(default_headers)
.build()
.map_err(Error::Http)?;
Ok(OpenModex {
http,
base_url,
api_key,
max_retries,
fallback_models: self.fallback_models,
default_model: self.default_model,
default_headers: self.default_headers,
})
}
}
impl Default for ClientBuilder {
fn default() -> Self {
Self::new()
}
}
/// Converts an unsuccessful HTTP response into an [`Error::Api`].
///
/// The body is read as text, and a failed read yields an empty string. If the
/// body deserializes as `{"error": {"code", "type", "message"}}` (each inner
/// field optional), those fields populate the [`ApiError`]. Otherwise the raw
/// body becomes the message and `code` / `error_type` are `None`.
async fn parse_error_response(resp: reqwest::Response) -> Error {
    #[derive(serde::Deserialize)]
    struct ErrorEnvelope {
        error: ErrorBody,
    }

    #[derive(serde::Deserialize)]
    struct ErrorBody {
        #[serde(default)]
        code: Option<String>,
        #[serde(default, rename = "type")]
        error_type: Option<String>,
        #[serde(default)]
        message: String,
    }

    let status_code = resp.status().as_u16();
    let raw_body = resp.text().await.unwrap_or_default();
    let parsed = serde_json::from_str::<ErrorEnvelope>(&raw_body);

    let (code, error_type, message) = match parsed {
        Ok(ErrorEnvelope { error }) => (error.code, error.error_type, error.message),
        Err(_) => (None, None, raw_body),
    };

    Error::Api(ApiError {
        status_code,
        code,
        error_type,
        message,
    })
}