use async_trait::async_trait;
use futures::stream::BoxStream;
use tracing::debug;
use crate::error::{LlmError, Result};
use crate::model_config::{
ModelCapabilities, ModelCard, ModelType, ProviderConfig, ProviderType as ConfigProviderType,
};
use crate::providers::openai_compatible::OpenAICompatibleProvider;
use crate::traits::{
ChatMessage, CompletionOptions, EmbeddingProvider, LLMProvider, LLMResponse, StreamChunk,
};
// Legacy per-model inference endpoint base; retained for reference/back-compat.
#[allow(dead_code)]
const HF_BASE_URL_TEMPLATE: &str = "https://api-inference.huggingface.co/models";
/// OpenAI-compatible router endpoint used for all chat/completion requests.
const HF_ROUTER_URL: &str = "https://router.huggingface.co/hf-inference/v1";
/// Model selected when the `HF_MODEL` environment variable is not set.
const HF_DEFAULT_MODEL: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Provider identifier reported by the `name()` trait methods.
const HF_PROVIDER_NAME: &str = "huggingface";
/// Static model catalog as `(model id, display name, context length in tokens)`.
/// Backs `build_config`'s `ModelCard` list and `context_length` lookups.
const HF_MODELS: &[(&str, &str, usize)] = &[
    (
        "meta-llama/Meta-Llama-3.1-70B-Instruct",
        "Llama 3.1 70B Instruct",
        128000,
    ),
    (
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "Llama 3.1 8B Instruct",
        128000,
    ),
    (
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "Llama 3 8B Instruct",
        8192,
    ),
    (
        "meta-llama/Meta-Llama-3-70B-Instruct",
        "Llama 3 70B Instruct",
        8192,
    ),
    (
        "mistralai/Mistral-7B-Instruct-v0.3",
        "Mistral 7B Instruct v0.3",
        32000,
    ),
    (
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "Mixtral 8x7B Instruct",
        32000,
    ),
    ("Qwen/Qwen2.5-72B-Instruct", "Qwen 2.5 72B Instruct", 128000),
    ("Qwen/Qwen2.5-7B-Instruct", "Qwen 2.5 7B Instruct", 128000),
    (
        "Qwen/Qwen2.5-Coder-32B-Instruct",
        "Qwen 2.5 Coder 32B",
        128000,
    ),
    (
        "microsoft/Phi-3-medium-4k-instruct",
        "Phi-3 Medium 4K",
        4096,
    ),
    ("microsoft/Phi-3-mini-4k-instruct", "Phi-3 Mini 4K", 4096),
    ("google/gemma-7b-it", "Gemma 7B IT", 8192),
    ("google/gemma-2b-it", "Gemma 2B IT", 8192),
    (
        "deepseek-ai/DeepSeek-Coder-V2-Instruct",
        "DeepSeek Coder V2",
        128000,
    ),
];
/// LLM provider for the HuggingFace inference router.
///
/// Thin wrapper around [`OpenAICompatibleProvider`]: the HF router speaks an
/// OpenAI-compatible API, so all requests are delegated to `inner`.
#[derive(Debug)]
pub struct HuggingFaceProvider {
    // Underlying OpenAI-compatible client configured for the HF router.
    inner: OpenAICompatibleProvider,
    // Currently selected model id, e.g. "meta-llama/Meta-Llama-3.1-70B-Instruct".
    model: String,
}
impl HuggingFaceProvider {
    /// Build a provider from environment variables.
    ///
    /// Reads `HF_TOKEN` (preferred) or `HUGGINGFACE_TOKEN` for the API key,
    /// `HF_MODEL` for the model (defaults to [`HF_DEFAULT_MODEL`]) and
    /// `HF_BASE_URL` for an optional endpoint override.
    ///
    /// # Errors
    /// Returns [`LlmError::ConfigError`] when neither token variable is set
    /// or the token is empty.
    pub fn from_env() -> Result<Self> {
        let api_key = std::env::var("HF_TOKEN")
            .or_else(|_| std::env::var("HUGGINGFACE_TOKEN"))
            .map_err(|_| {
                LlmError::ConfigError(
                    "HF_TOKEN or HUGGINGFACE_TOKEN environment variable not set. \
                     Get your token from https://huggingface.co/settings/tokens"
                        .to_string(),
                )
            })?;
        // A set-but-empty token would otherwise surface as a confusing auth
        // failure on the first request; reject it up front.
        if api_key.is_empty() {
            return Err(LlmError::ConfigError(
                "HF_TOKEN is empty. Please set a valid token.".to_string(),
            ));
        }
        let model = std::env::var("HF_MODEL").unwrap_or_else(|_| HF_DEFAULT_MODEL.to_string());
        let base_url = std::env::var("HF_BASE_URL").ok();
        Self::new(api_key, model, base_url)
    }

    /// Create a provider with an explicit API key, model, and optional base URL.
    ///
    /// NOTE(review): the key is written back into the `HF_TOKEN` environment
    /// variable because the inner provider resolves credentials through
    /// `api_key_env` in the config. `std::env::set_var` mutates process-global
    /// state and is not thread-safe; confirm providers are constructed before
    /// threads are spawned, or consider passing the key to the inner config
    /// directly instead of round-tripping through the environment.
    ///
    /// # Errors
    /// Propagates any error from [`OpenAICompatibleProvider::from_config`].
    pub fn new(api_key: String, model: String, base_url: Option<String>) -> Result<Self> {
        std::env::set_var("HF_TOKEN", &api_key);
        let config = Self::build_config(&model, base_url.as_deref());
        let inner = OpenAICompatibleProvider::from_config(config)?;
        debug!(
            provider = HF_PROVIDER_NAME,
            model = %model,
            "Created HuggingFace provider"
        );
        Ok(Self { inner, model })
    }

    /// Builder-style model override; keeps the inner provider in sync.
    pub fn with_model(mut self, model: &str) -> Self {
        self.model = model.to_string();
        self.inner = self.inner.with_model(model);
        self
    }

    /// Every model is served through the same OpenAI-compatible router URL.
    #[allow(dead_code)]
    fn model_url(_model: &str) -> String {
        HF_ROUTER_URL.to_string()
    }

    /// Assemble the [`ProviderConfig`] handed to the OpenAI-compatible client,
    /// materializing [`HF_MODELS`] into `ModelCard`s.
    fn build_config(model: &str, base_url: Option<&str>) -> ProviderConfig {
        let models: Vec<ModelCard> = HF_MODELS
            .iter()
            .map(|(name, display, context)| ModelCard {
                name: name.to_string(),
                display_name: display.to_string(),
                model_type: ModelType::Llm,
                capabilities: ModelCapabilities {
                    context_length: *context,
                    supports_function_calling: true,
                    supports_json_mode: true,
                    supports_streaming: true,
                    supports_system_message: true,
                    supports_vision: false,
                    ..Default::default()
                },
                ..Default::default()
            })
            .collect();
        // An explicit override wins; otherwise fall back to the public router.
        let effective_base_url = base_url
            .map(|s| s.to_string())
            .unwrap_or_else(|| HF_ROUTER_URL.to_string());
        ProviderConfig {
            name: HF_PROVIDER_NAME.to_string(),
            display_name: "HuggingFace Hub".to_string(),
            provider_type: ConfigProviderType::OpenAICompatible,
            api_key_env: Some("HF_TOKEN".to_string()),
            base_url: Some(effective_base_url),
            base_url_env: Some("HF_BASE_URL".to_string()),
            default_llm_model: Some(model.to_string()),
            default_embedding_model: None,
            models,
            headers: std::collections::HashMap::new(),
            enabled: true,
            ..Default::default()
        }
    }

    /// Context length in tokens for `model`, falling back to a conservative
    /// default for models not present in the static table.
    pub fn context_length(model: &str) -> usize {
        // Conservative fallback for models missing from HF_MODELS.
        const DEFAULT_CONTEXT_LENGTH: usize = 8192;
        HF_MODELS
            .iter()
            .find(|(name, _, _)| *name == model)
            .map(|(_, _, ctx)| *ctx)
            .unwrap_or(DEFAULT_CONTEXT_LENGTH)
    }

    /// Owned copy of the static catalog: `(model id, display name, context length)`.
    pub fn available_models() -> Vec<(&'static str, &'static str, usize)> {
        HF_MODELS.to_vec()
    }

    /// Heuristic format check for HuggingFace tokens (they start with `hf_`).
    pub fn is_hf_token(token: &str) -> bool {
        token.starts_with("hf_")
    }
}
/// All LLM operations delegate to the inner OpenAI-compatible provider; the
/// HF router accepts the OpenAI chat/completions wire format.
#[async_trait]
impl LLMProvider for HuggingFaceProvider {
    fn name(&self) -> &str {
        HF_PROVIDER_NAME
    }
    fn model(&self) -> &str {
        &self.model
    }
    // Resolved from the static HF_MODELS table (8192-token fallback).
    fn max_context_length(&self) -> usize {
        Self::context_length(&self.model)
    }
    async fn complete(&self, prompt: &str) -> Result<LLMResponse> {
        self.inner.complete(prompt).await
    }
    async fn complete_with_options(
        &self,
        prompt: &str,
        options: &CompletionOptions,
    ) -> Result<LLMResponse> {
        self.inner.complete_with_options(prompt, options).await
    }
    async fn chat(
        &self,
        messages: &[ChatMessage],
        options: Option<&CompletionOptions>,
    ) -> Result<LLMResponse> {
        self.inner.chat(messages, options).await
    }
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: &[crate::traits::ToolDefinition],
        tool_choice: Option<crate::traits::ToolChoice>,
        options: Option<&CompletionOptions>,
    ) -> Result<LLMResponse> {
        self.inner
            .chat_with_tools(messages, tools, tool_choice, options)
            .await
    }
    async fn stream(&self, prompt: &str) -> Result<BoxStream<'static, Result<String>>> {
        self.inner.stream(prompt).await
    }
    // Streaming is always advertised; the router supports SSE streaming.
    fn supports_streaming(&self) -> bool {
        true
    }
    // Function-calling / tool-streaming support is deferred to the inner
    // provider rather than hard-coded here.
    fn supports_function_calling(&self) -> bool {
        self.inner.supports_function_calling()
    }
    fn supports_tool_streaming(&self) -> bool {
        self.inner.supports_tool_streaming()
    }
    async fn chat_with_tools_stream(
        &self,
        messages: &[ChatMessage],
        tools: &[crate::traits::ToolDefinition],
        tool_choice: Option<crate::traits::ToolChoice>,
        options: Option<&CompletionOptions>,
    ) -> Result<BoxStream<'static, Result<StreamChunk>>> {
        self.inner
            .chat_with_tools_stream(messages, tools, tool_choice, options)
            .await
    }
}
/// Embeddings are intentionally unsupported by this provider: the trait is
/// implemented only so the type can be used where an `EmbeddingProvider` is
/// required, with sentinel metadata and an always-error `embed`.
#[async_trait]
impl EmbeddingProvider for HuggingFaceProvider {
    fn name(&self) -> &str {
        HF_PROVIDER_NAME
    }
    // Sentinel: no embedding model is configured.
    fn model(&self) -> &str {
        "none"
    }
    // Sentinel dimension; there is no embedding output.
    fn dimension(&self) -> usize {
        0
    }
    fn max_tokens(&self) -> usize {
        0
    }
    /// Always fails: callers must configure a dedicated embedding provider.
    ///
    /// # Errors
    /// Unconditionally returns [`LlmError::ConfigError`].
    async fn embed(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
        Err(LlmError::ConfigError(
            "HuggingFace embeddings require a separate provider configuration. \
             Use the HuggingFace Inference API directly for embeddings."
                .to_string(),
        ))
    }
}
// Unit tests for the pure/static helpers (no network access required).
#[cfg(test)]
mod tests {
    use super::*;

    // Known models resolve to their table entries.
    #[test]
    fn test_context_length_known_model() {
        assert_eq!(
            HuggingFaceProvider::context_length("meta-llama/Meta-Llama-3.1-70B-Instruct"),
            128000
        );
        assert_eq!(
            HuggingFaceProvider::context_length("mistralai/Mistral-7B-Instruct-v0.3"),
            32000
        );
        assert_eq!(
            HuggingFaceProvider::context_length("microsoft/Phi-3-medium-4k-instruct"),
            4096
        );
    }

    // Unknown models fall back to the conservative 8192-token default.
    #[test]
    fn test_context_length_unknown_model() {
        assert_eq!(HuggingFaceProvider::context_length("unknown/model"), 8192);
    }

    #[test]
    fn test_available_models() {
        let models = HuggingFaceProvider::available_models();
        assert!(!models.is_empty());
        assert!(models
            .iter()
            .any(|(name, _, _)| *name == "meta-llama/Meta-Llama-3.1-70B-Instruct"));
        assert!(models
            .iter()
            .any(|(name, _, _)| *name == "Qwen/Qwen2.5-72B-Instruct"));
    }

    // Default config: router URL, provider name, and requested default model.
    #[test]
    fn test_build_config() {
        let config =
            HuggingFaceProvider::build_config("meta-llama/Meta-Llama-3.1-70B-Instruct", None);
        assert_eq!(config.name, "huggingface");
        assert_eq!(
            config.base_url,
            Some("https://router.huggingface.co/hf-inference/v1".to_string())
        );
        assert_eq!(
            config.default_llm_model,
            Some("meta-llama/Meta-Llama-3.1-70B-Instruct".to_string())
        );
    }

    // An explicit base URL overrides the router default.
    #[test]
    fn test_build_config_custom_url() {
        let config = HuggingFaceProvider::build_config(
            "mistralai/Mistral-7B-Instruct-v0.3",
            Some("https://custom.api"),
        );
        assert_eq!(config.base_url, Some("https://custom.api".to_string()));
    }

    #[test]
    fn test_is_hf_token() {
        assert!(HuggingFaceProvider::is_hf_token("hf_xxxxx"));
        assert!(HuggingFaceProvider::is_hf_token("hf_abc123"));
        assert!(!HuggingFaceProvider::is_hf_token("sk-xxxxx"));
        assert!(!HuggingFaceProvider::is_hf_token("xxxxx"));
    }

    #[test]
    fn test_model_url() {
        let url = HuggingFaceProvider::model_url("meta-llama/Meta-Llama-3.1-70B-Instruct");
        assert_eq!(url, "https://router.huggingface.co/hf-inference/v1");
    }

    #[test]
    fn test_context_length_llama_models() {
        assert_eq!(
            HuggingFaceProvider::context_length("meta-llama/Meta-Llama-3.1-70B-Instruct"),
            128000
        );
        assert_eq!(
            HuggingFaceProvider::context_length("meta-llama/Meta-Llama-3.1-8B-Instruct"),
            128000
        );
        assert_eq!(
            HuggingFaceProvider::context_length("meta-llama/Meta-Llama-3-8B-Instruct"),
            8192
        );
        assert_eq!(
            HuggingFaceProvider::context_length("meta-llama/Meta-Llama-3-70B-Instruct"),
            8192
        );
    }

    #[test]
    fn test_context_length_mistral_models() {
        assert_eq!(
            HuggingFaceProvider::context_length("mistralai/Mistral-7B-Instruct-v0.3"),
            32000
        );
        assert_eq!(
            HuggingFaceProvider::context_length("mistralai/Mixtral-8x7B-Instruct-v0.1"),
            32000
        );
    }

    #[test]
    fn test_context_length_qwen_models() {
        assert_eq!(
            HuggingFaceProvider::context_length("Qwen/Qwen2.5-72B-Instruct"),
            128000
        );
        assert_eq!(
            HuggingFaceProvider::context_length("Qwen/Qwen2.5-7B-Instruct"),
            128000
        );
        assert_eq!(
            HuggingFaceProvider::context_length("Qwen/Qwen2.5-Coder-32B-Instruct"),
            128000
        );
    }

    #[test]
    fn test_context_length_phi_models() {
        assert_eq!(
            HuggingFaceProvider::context_length("microsoft/Phi-3-medium-4k-instruct"),
            4096
        );
        assert_eq!(
            HuggingFaceProvider::context_length("microsoft/Phi-3-mini-4k-instruct"),
            4096
        );
    }

    #[test]
    fn test_context_length_gemma_and_deepseek() {
        assert_eq!(
            HuggingFaceProvider::context_length("google/gemma-7b-it"),
            8192
        );
        assert_eq!(
            HuggingFaceProvider::context_length("google/gemma-2b-it"),
            8192
        );
        assert_eq!(
            HuggingFaceProvider::context_length("deepseek-ai/DeepSeek-Coder-V2-Instruct"),
            128000
        );
    }

    // The catalog should span every supported model family.
    #[test]
    fn test_available_models_contains_all_families() {
        let models = HuggingFaceProvider::available_models();
        assert!(models
            .iter()
            .any(|(name, _, _)| name.contains("meta-llama")));
        assert!(models.iter().any(|(name, _, _)| name.contains("mistralai")));
        assert!(models.iter().any(|(name, _, _)| name.contains("Qwen")));
        assert!(models.iter().any(|(name, _, _)| name.contains("microsoft")));
        assert!(models.iter().any(|(name, _, _)| name.contains("google")));
        assert!(models.iter().any(|(name, _, _)| name.contains("deepseek")));
    }

    #[test]
    fn test_available_models_has_positive_context() {
        let models = HuggingFaceProvider::available_models();
        for (name, _desc, context_len) in models {
            assert!(
                context_len > 0,
                "Model {} should have positive context length",
                name
            );
        }
    }

    #[test]
    fn test_constants() {
        assert_eq!(HF_DEFAULT_MODEL, "meta-llama/Meta-Llama-3.1-70B-Instruct");
        assert_eq!(HF_PROVIDER_NAME, "huggingface");
        assert_eq!(
            HF_ROUTER_URL,
            "https://router.huggingface.co/hf-inference/v1"
        );
    }

    // NOTE(review): this test mutates process-global environment variables.
    // Rust runs tests in parallel by default, so it can race with any other
    // test (or code under test) that reads these variables — consider
    // serializing env-touching tests behind a shared lock. TODO confirm.
    #[test]
    fn test_from_env_missing_token() {
        std::env::remove_var("HF_TOKEN");
        std::env::remove_var("HUGGINGFACE_TOKEN");
        std::env::remove_var("HF_MODEL");
        std::env::remove_var("HF_BASE_URL");
        let result = HuggingFaceProvider::from_env();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("HF_TOKEN") || err.to_string().contains("HUGGINGFACE_TOKEN")
        );
    }

    #[test]
    fn test_build_config_has_models() {
        let config =
            HuggingFaceProvider::build_config("meta-llama/Meta-Llama-3.1-70B-Instruct", None);
        assert!(!config.models.is_empty());
        assert!(config
            .models
            .iter()
            .any(|m| m.name == "meta-llama/Meta-Llama-3.1-70B-Instruct"));
    }

    #[test]
    fn test_build_config_api_key_env() {
        let config =
            HuggingFaceProvider::build_config("meta-llama/Meta-Llama-3.1-70B-Instruct", None);
        assert_eq!(config.api_key_env, Some("HF_TOKEN".to_string()));
    }

    // The check is a pure prefix match: "hf_" alone passes, case matters.
    #[test]
    fn test_is_hf_token_edge_cases() {
        assert!(HuggingFaceProvider::is_hf_token("hf_a"));
        assert!(HuggingFaceProvider::is_hf_token(
            "hf_verylongtokenstring123"
        ));
        assert!(HuggingFaceProvider::is_hf_token("hf_"));
        assert!(!HuggingFaceProvider::is_hf_token("sk_xxxxx"));
        assert!(!HuggingFaceProvider::is_hf_token("api_key"));
        assert!(!HuggingFaceProvider::is_hf_token("HF_token"));
        assert!(!HuggingFaceProvider::is_hf_token(""));
        assert!(!HuggingFaceProvider::is_hf_token("h"));
        assert!(!HuggingFaceProvider::is_hf_token("hf"));
    }

    #[test]
    fn test_model_url_always_returns_router() {
        let models = vec![
            "meta-llama/Meta-Llama-3.1-70B-Instruct",
            "mistralai/Mistral-7B-Instruct-v0.3",
            "Qwen/Qwen2.5-72B-Instruct",
            "unknown/model",
        ];
        for model in models {
            assert_eq!(
                HuggingFaceProvider::model_url(model),
                "https://router.huggingface.co/hf-inference/v1"
            );
        }
    }
}