pub mod chat;
pub mod models;
pub mod text_to_image;
pub mod text_to_speech;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::error::Error;
use std::fs::File;
use std::io::Read;
use std::str::FromStr;
pub(crate) fn request_headers(key: &Key) -> Result<HeaderMap, Box<dyn Error + Send + Sync>> {
let mut headers = HeaderMap::new();
headers.insert(
"Authorization",
HeaderValue::from_str(&format!("Bearer {}", key.key))?,
);
headers.insert("Content-Type", HeaderValue::from_str("application/json")?);
Ok(headers)
}
pub(crate) fn openai_base_url(provider: &Provider) -> String {
match provider {
Provider::Google => format!("{}/v1beta/openai", provider.domain()),
Provider::Groq => format!("{}/openai/v1", provider.domain()),
Provider::Hyperbolic => format!("{}/v1", provider.domain()),
Provider::Mistral => format!("{}/v1", provider.domain()),
Provider::OpenAI => format!("{}/v1", provider.domain()),
Provider::OpenAICompatible(domain) => domain.clone(),
Provider::SambaNova => format!("{}/v1", provider.domain()),
Provider::TogetherAI => format!("{}/v1", provider.domain()),
_ => format!("{}/v1/openai", provider.domain()),
}
}
#[allow(rustdoc::bare_urls)]
#[derive(Clone, Debug, Serialize, PartialEq)]
pub enum Provider {
Amazon,
Azure,
Cerebras,
DeepInfra,
ElevenLabs,
Fireworks,
FriendliAI,
Google,
Groq,
Hyperbolic,
Mistral,
Nebius,
Novita,
OpenAI,
OpenAICompatible(String),
SambaNova,
TogetherAI,
}
impl std::fmt::Display for Provider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl Provider {
pub fn domain(&self) -> String {
match self {
Provider::Amazon => "https://api.amazon.com",
Provider::Azure => "https://api.azure.com",
Provider::Cerebras => "https://api.cerebras.ai",
Provider::DeepInfra => "https://api.deepinfra.com",
Provider::ElevenLabs => "https://api.elevenlabs.io",
Provider::Fireworks => "https://api.fireworks.ai",
Provider::FriendliAI => "https://api.friendli.ai",
Provider::Google => "https://generativelanguage.googleapis.com",
Provider::Groq => "https://api.groq.com",
Provider::Hyperbolic => "https://api.hyperbolic.xyz",
Provider::Mistral => "https://api.mistral.ai",
Provider::Nebius => "https://api.nebi.us",
Provider::Novita => "https://api.novita.ai",
Provider::OpenAI => "https://api.openai.com",
Provider::OpenAICompatible(base_url) => base_url,
Provider::SambaNova => "https://api.sambanova.ai",
Provider::TogetherAI => "https://api.together.xyz",
}
.to_string()
}
pub fn key_name(&self) -> String {
match self {
Provider::OpenAICompatible(_) => "OPENAI_COMPATIBLE_KEY".to_string(),
_ => self.to_string().to_uppercase() + "_KEY",
}
}
}
impl FromStr for Provider {
type Err = Box<dyn Error + Send + Sync>;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let s = s.to_lowercase();
if s.starts_with("openai-compatible(") {
let s = s.strip_prefix("openai-compatible(").unwrap();
let s = s.strip_suffix(")").unwrap();
let mut domain = s.to_string();
if !domain.starts_with("https") {
if domain.contains("localhost") {
domain = format!("http://{}", domain);
} else {
domain = format!("https://{}", domain);
}
}
return Ok(Provider::OpenAICompatible(domain));
}
match s.as_str() {
"amazon" => Ok(Provider::Amazon),
"azure" => Ok(Provider::Azure),
"cerebras" => Ok(Provider::Cerebras),
"deepinfra" => Ok(Provider::DeepInfra),
"elevenlabs" => Ok(Provider::ElevenLabs),
"fireworks" => Ok(Provider::Fireworks),
"friendliai" => Ok(Provider::FriendliAI),
"google" => Ok(Provider::Google),
"groq" => Ok(Provider::Groq),
"hyperbolic" => Ok(Provider::Hyperbolic),
"mistral" => Ok(Provider::Mistral),
"nebi" => Ok(Provider::Nebius),
"novita" => Ok(Provider::Novita),
"openai" => Ok(Provider::OpenAI),
"sambanova" => Ok(Provider::SambaNova),
"togetherai" => Ok(Provider::TogetherAI),
_ => Err(format!("Unsupported provider: {s}.").into()),
}
}
}
#[derive(Clone, Debug, Deserialize)]
pub enum SubContent {
TextContent { text: String },
ImageUrlContent { image_url: String },
}
impl SubContent {
pub fn new(r#type: &str, text: &str) -> Self {
match r#type {
"text" => Self::TextContent {
text: text.to_string(),
},
"image_url" => Self::ImageUrlContent {
image_url: text.to_string(),
},
_ => panic!("Invalid subcontent type: {}", r#type),
}
}
}
impl Serialize for SubContent {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
SubContent::TextContent { text } => serializer.serialize_str(text),
SubContent::ImageUrlContent { image_url } => {
let json = serde_json::json!({
"type": "image_url",
"image_url": {
"url": image_url
}
});
json.serialize(serializer)
}
}
}
}
#[derive(Clone, Debug)]
pub enum Content {
Text(String),
Collection(Vec<SubContent>),
}
impl std::fmt::Display for Content {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Content::Text(text) => write!(f, "{}", text),
Content::Collection(items) => {
write!(f, "{items:?}")
}
}
}
}
impl Serialize for Content {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
Content::Text(text) => serializer.serialize_str(text),
Content::Collection(items) => items.serialize(serializer),
}
}
}
impl<'de> Deserialize<'de> for Content {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = serde_json::Value::deserialize(deserializer)?;
if let serde_json::Value::String(text) = value {
Ok(Content::Text(text))
} else if let serde_json::Value::Array(items) = value {
let subcontent = items
.into_iter()
.map(SubContent::deserialize)
.collect::<Result<Vec<_>, _>>()
.unwrap();
Ok(Content::Collection(subcontent))
} else {
Err(serde::de::Error::custom("Invalid content format"))
}
}
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Message {
pub role: String,
pub content: Content,
}
impl Message {
pub fn from_str(role: &str, text: &str) -> Self {
Self {
role: role.to_string(),
content: Content::Text(text.to_string()),
}
}
pub fn from_image_url(role: &str, image_url: &str) -> Self {
Self {
role: role.to_string(),
content: Content::Collection(vec![SubContent::ImageUrlContent {
image_url: image_url.to_string(),
}]),
}
}
pub fn from_image_bytes(role: &str, image_type: &str, image: &[u8]) -> Self {
let base64 = BASE64_STANDARD.encode(image);
let image_url = format!("data:image/{image_type};base64,{base64}");
Self::from_image_url(role, &image_url)
}
}
#[derive(Clone, Debug)]
pub struct Key {
pub provider: Provider,
pub key: String,
}
#[derive(Clone, Debug)]
pub struct Keys {
pub keys: Vec<Key>,
}
impl Keys {
pub fn for_provider(&self, provider: &Provider) -> Option<Key> {
fn finder(provider: &Provider, key: &Key) -> bool {
match provider {
Provider::OpenAICompatible(_) => {
matches!(&key.provider, Provider::OpenAICompatible(_))
}
_ => key.provider == *provider,
}
}
self.keys.iter().find(|key| finder(provider, key)).cloned()
}
}
fn load_env_file(path: &str) -> HashMap<String, String> {
let mut env_content = String::new();
if let Ok(mut file) = File::open(path) {
file.read_to_string(&mut env_content)
.expect("Failed to read .env file");
}
env_content
.lines()
.filter_map(|line| {
let mut parts = line.split('=');
if let (Some(key), Some(value)) = (parts.next(), parts.next()) {
Some((key.to_string(), value.to_string()))
} else {
None
}
})
.collect()
}
pub fn load_keys(path: &str) -> Keys {
let env_map = load_env_file(path);
let mut keys = vec![];
let providers = [
Provider::Amazon,
Provider::Azure,
Provider::Cerebras,
Provider::DeepInfra,
Provider::ElevenLabs,
Provider::Fireworks,
Provider::FriendliAI,
Provider::Google,
Provider::Groq,
Provider::Hyperbolic,
Provider::Mistral,
Provider::Nebius,
Provider::Novita,
Provider::OpenAI,
Provider::OpenAICompatible("".to_string()),
Provider::SambaNova,
Provider::TogetherAI,
];
for provider in providers {
if let Ok(key_value) = std::env::var(provider.key_name()) {
keys.push(Key {
provider: provider.clone(),
key: key_value,
});
} else if let Some(key_value) = env_map.get(&provider.key_name()) {
keys.push(Key {
provider: provider.clone(),
key: key_value.to_string(),
});
}
}
Keys { keys }
}