use std::fmt::Display;
use super::completion::CompletionModel;
use crate::agent::AgentBuilder;
use crate::providers::huggingface::transcription::TranscriptionModel;
/// Default base URL of the Hugging Face router, used when no custom base URL is supplied.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co/";
/// Inference backend that requests are routed to.
///
/// The [`Display`] implementation yields the route segment for the provider
/// (for example `together` or `hf-inference/models`), which the client builder
/// appends to the base URL.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum SubProvider {
    /// Default provider; routed through `hf-inference/models`.
    #[default]
    HFInference,
    /// Routed through `together`.
    Together,
    /// Routed through `sambanova`.
    SambaNova,
    /// Routed through `fireworks-ai`.
    Fireworks,
    /// Routed through `hyperbolic`.
    Hyperbolic,
    /// Routed through `nebius`.
    Nebius,
    /// Routed through `novita`.
    Novita,
    /// Any other provider; the contained string is used verbatim as the route.
    Custom(String),
}

impl SubProvider {
    /// Returns the chat-completion path for `model`.
    ///
    /// [`SubProvider::HFInference`] embeds the model in the path
    /// (`/{model}/v1/chat/completions`); every other provider uses the fixed
    /// path `/v1/chat/completions` and ignores `model`.
    pub fn completion_endpoint(&self, model: &str) -> String {
        match self {
            SubProvider::HFInference => format!("/{model}/v1/chat/completions"),
            _ => "/v1/chat/completions".to_string(),
        }
    }

    /// Returns the transcription path for `model`.
    ///
    /// # Panics
    ///
    /// Panics for every provider other than [`SubProvider::HFInference`],
    /// because no transcription endpoint is defined for them yet.
    pub fn transcription_endpoint(&self, model: &str) -> String {
        match self {
            SubProvider::HFInference => format!("hf-inference/models/{model}"),
            other => panic!("transcription endpoint is not supported yet for {}", other),
        }
    }

    /// Returns the model identifier to use for this provider.
    ///
    /// [`SubProvider::Fireworks`] prefixes the model with
    /// `accounts/fireworks/models/`; all other providers return `model` unchanged.
    pub fn model_identifier(&self, model: &str) -> String {
        match self {
            SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
            _ => model.to_string(),
        }
    }
}

/// Wraps an arbitrary route string in [`SubProvider::Custom`].
///
/// The string is never mapped onto one of the named variants.
impl From<&str> for SubProvider {
    fn from(s: &str) -> Self {
        SubProvider::Custom(s.to_string())
    }
}

/// Wraps an owned route string in [`SubProvider::Custom`] without copying it.
impl From<String> for SubProvider {
    fn from(value: String) -> Self {
        SubProvider::Custom(value)
    }
}

/// Formats the provider as its route segment.
impl Display for SubProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Borrow the route instead of allocating a `String` per call.
        let route: &str = match self {
            SubProvider::HFInference => "hf-inference/models",
            SubProvider::Together => "together",
            SubProvider::SambaNova => "sambanova",
            SubProvider::Fireworks => "fireworks-ai",
            SubProvider::Hyperbolic => "hyperbolic",
            SubProvider::Nebius => "nebius",
            SubProvider::Novita => "novita",
            SubProvider::Custom(route) => route,
        };
        // `write_str` emits the route verbatim, ignoring width/fill flags.
        f.write_str(route)
    }
}
/// Builder for a [`Client`] that allows overriding the base URL and the
/// [`SubProvider`] before the client is constructed.
pub struct ClientBuilder {
    api_key: String,
    base_url: String,
    sub_provider: SubProvider,
}

impl ClientBuilder {
    /// Creates a builder with the given API key, the default base URL and the
    /// default [`SubProvider`].
    pub fn new(api_key: &str) -> Self {
        Self {
            api_key: api_key.to_string(),
            base_url: HUGGINGFACE_API_BASE_URL.to_string(),
            sub_provider: SubProvider::default(),
        }
    }

    /// Overrides the base URL that the sub-provider route is appended to.
    pub fn base_url(mut self, base_url: &str) -> Self {
        self.base_url = base_url.to_string();
        self
    }

    /// Selects the sub-provider; strings are converted to [`SubProvider::Custom`].
    pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
        self.sub_provider = provider.into();
        self
    }

    /// Builds the [`Client`], using `<base_url>/<sub-provider route>` as its base URL.
    ///
    /// # Panics
    ///
    /// Panics if [`Client::from_url`] panics (see that function).
    pub fn build(self) -> Client {
        let route = self.sub_provider.to_string();
        // Join with exactly one separator. A blanket `replace("//", "/")`
        // would also collapse the `://` after the scheme (`https:/host/...`).
        let base_url = format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            route.trim_matches('/'),
        );
        Client::from_url(&self.api_key, &base_url, self.sub_provider)
    }
}
/// Hugging Face API client.
///
/// Holds the base URL, an HTTP client with the authorization and content-type
/// headers preconfigured, and the [`SubProvider`] in use.
#[derive(Clone)]
pub struct Client {
    base_url: String,
    http_client: reqwest::Client,
    pub(crate) sub_provider: SubProvider,
}

impl Client {
    /// Creates a client for the default base URL and [`SubProvider::HFInference`].
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Client::from_url`].
    pub fn new(api_key: &str) -> Self {
        // NOTE(review): unlike `ClientBuilder::build`, this does not append the
        // sub-provider route to the base URL; confirm this is intended.
        Self::from_url(api_key, HUGGINGFACE_API_BASE_URL, SubProvider::HFInference)
    }

    /// Creates a client for an explicit base URL and sub-provider.
    ///
    /// Every request carries `Authorization: Bearer <api_key>` and
    /// `Content-Type: application/json` as default headers.
    ///
    /// # Panics
    ///
    /// Panics if `api_key` cannot be turned into a valid header value, or if
    /// the underlying HTTP client cannot be built.
    pub fn from_url(api_key: &str, base_url: &str, sub_provider: SubProvider) -> Self {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::AUTHORIZATION,
            format!("Bearer {api_key}")
                .parse()
                .expect("Failed to parse API key"),
        );
        // A constant, known-valid value needs no fallible parse.
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("application/json"),
        );

        let http_client = reqwest::Client::builder()
            .default_headers(headers)
            .build()
            .expect("Failed to build HTTP client");

        Self {
            base_url: base_url.to_owned(),
            http_client,
            sub_provider,
        }
    }

    /// Creates a client from the `HUGGINGFACE_API_KEY` environment variable.
    ///
    /// # Panics
    ///
    /// Panics if `HUGGINGFACE_API_KEY` is not set, or under the same
    /// conditions as [`Client::new`].
    pub fn from_env() -> Self {
        let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
        Self::new(&api_key)
    }

    /// Starts a POST request to `path` relative to the client's base URL.
    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
        // Join with exactly one separator. A blanket `replace("//", "/")`
        // would also collapse the `://` after the scheme (`https:/host/...`).
        let url = format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            path.trim_start_matches('/'),
        );
        self.http_client.post(url)
    }

    /// Creates a completion model handle for `model`, backed by a clone of this client.
    pub fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }

    /// Creates a transcription model handle for `model`, backed by a clone of this client.
    pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }

    /// Creates an agent builder on top of the completion model for `model`.
    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
        AgentBuilder::new(self.completion_model(model))
    }
}