use async_trait::async_trait;
use crate::client::LlmClient;
use crate::error::LlmError;
use crate::stream::ChatStream;
use crate::traits::{ChatCapability, ModelListingCapability, ProviderCapabilities};
use crate::types::*;
use super::api::GroqModels;
use super::chat::GroqChatCapability;
use super::config::GroqConfig;
use crate::retry_api::RetryOptions;
/// Client for the Groq provider.
///
/// Holds a [`GroqConfig`] and a `reqwest::Client`. It delegates chat requests to
/// [`GroqChatCapability`] and model-listing requests to [`GroqModels`].
pub struct GroqClient {
/// Provider configuration (API key, base URL, HTTP config, common params).
config: GroqConfig,
/// HTTP client; `new` hands a clone of it to each capability.
http_client: reqwest::Client,
/// Performs `chat_with_tools` / `chat_stream` requests.
chat_capability: GroqChatCapability,
/// Performs `list_models` / `get_model` requests.
models_capability: GroqModels,
/// Tracing configuration, if one was set via `set_tracing_config`.
tracing_config: Option<crate::tracing::TracingConfig>,
// NOTE: struct fields are dropped in declaration order, so moving this field
// changes when the guard is dropped relative to the other fields.
/// Guard for the non-blocking tracing writer, held so it lives as long as this
/// client. It is not duplicated by `Clone`; clones get `None`.
_tracing_guard: Option<tracing_appender::non_blocking::WorkerGuard>,
/// Retry policy applied to `chat_with_tools` when `Some`; `None` disables retries.
retry_options: Option<RetryOptions>,
}
impl Clone for GroqClient {
    /// Produces an independent handle with the same configuration and capabilities.
    ///
    /// Every field is cloned except the tracing worker guard. A clone never owns
    /// the guard, so it is reset to `None`. Only the original client keeps the
    /// tracing writer's guard alive.
    fn clone(&self) -> Self {
        let Self {
            config,
            http_client,
            chat_capability,
            models_capability,
            tracing_config,
            _tracing_guard: _,
            retry_options,
        } = self;

        Self {
            config: config.clone(),
            http_client: http_client.clone(),
            chat_capability: chat_capability.clone(),
            models_capability: models_capability.clone(),
            tracing_config: tracing_config.clone(),
            _tracing_guard: None,
            retry_options: retry_options.clone(),
        }
    }
}
impl std::fmt::Debug for GroqClient {
    /// Formats a diagnostic summary of the client's configuration.
    ///
    /// Shows the model, base URL and common sampling parameters. The API key is
    /// not included. Tracing and retry settings appear only as presence flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let params = &self.config.common_params;
        f.debug_struct("GroqClient")
            .field("provider_name", &"groq")
            .field("model", &params.model)
            .field("base_url", &self.config.base_url)
            .field("temperature", &params.temperature)
            .field("max_tokens", &params.max_tokens)
            .field("top_p", &params.top_p)
            .field("seed", &params.seed)
            .field("has_tracing", &self.tracing_config.is_some())
            // A retry policy changes how `chat_with_tools` behaves, so show
            // whether one is configured, just as tracing is shown above.
            .field("has_retry", &self.retry_options.is_some())
            .finish()
    }
}
impl GroqClient {
    /// Builds a Groq client from `config`.
    ///
    /// The chat and model-listing capabilities each receive clones of the API
    /// key, base URL, HTTP config and `http_client`. The chat capability also
    /// gets the common params. Tracing and retry settings start unset.
    pub fn new(config: GroqConfig, http_client: reqwest::Client) -> Self {
        // Borrow the settings shared by both capabilities. The borrows end
        // before `config` is moved into the struct below.
        let api_key = &config.api_key;
        let base_url = &config.base_url;
        let http_config = &config.http_config;

        let chat_capability = GroqChatCapability::new(
            api_key.clone(),
            base_url.clone(),
            http_client.clone(),
            http_config.clone(),
            config.common_params.clone(),
        );
        let models_capability = GroqModels::new(
            api_key.clone(),
            base_url.clone(),
            http_client.clone(),
            http_config.clone(),
        );

        Self {
            config,
            http_client,
            chat_capability,
            models_capability,
            tracing_config: None,
            _tracing_guard: None,
            retry_options: None,
        }
    }

    /// Configuration this client was built with.
    pub fn config(&self) -> &GroqConfig {
        &self.config
    }

    /// HTTP client this client was built with.
    pub fn http_client(&self) -> &reqwest::Client {
        &self.http_client
    }

    /// Chat capability that performs the underlying chat requests.
    pub fn chat_capability(&self) -> &GroqChatCapability {
        &self.chat_capability
    }
}
#[async_trait]
impl LlmClient for GroqClient {
    /// Fixed identifier for this provider.
    fn provider_name(&self) -> &'static str {
        "groq"
    }

    /// Model ids from `GroqConfig::supported_models`, converted to owned strings.
    fn supported_models(&self) -> Vec<String> {
        let known = GroqConfig::supported_models();
        known.iter().copied().map(String::from).collect()
    }

    /// Advertises chat, streaming and tool-calling support.
    fn capabilities(&self) -> ProviderCapabilities {
        let caps = ProviderCapabilities::new();
        caps.with_chat().with_streaming().with_tools()
    }

    /// Exposes `self` as `Any` so callers can downcast to `GroqClient`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Boxes a clone of this client behind the `LlmClient` trait object.
    fn clone_box(&self) -> Box<dyn LlmClient> {
        let copy: GroqClient = self.clone();
        Box::new(copy)
    }
}
#[async_trait]
impl ChatCapability for GroqClient {
    /// Sends a chat request, optionally with tools.
    ///
    /// When a retry policy is set, the request goes through
    /// `crate::retry_api::retry_with`. Each attempt gets fresh clones of
    /// `messages` and `tools`. Without a policy, the request is sent once.
    async fn chat_with_tools(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatResponse, LlmError> {
        match self.retry_options.as_ref() {
            None => self.chat_capability.chat_with_tools(messages, tools).await,
            Some(policy) => {
                let capability = &self.chat_capability;
                crate::retry_api::retry_with(
                    || {
                        // The operation may run several times, so each attempt
                        // owns its own copy of the inputs.
                        let attempt_messages = messages.clone();
                        let attempt_tools = tools.clone();
                        async move {
                            capability
                                .chat_with_tools(attempt_messages, attempt_tools)
                                .await
                        }
                    },
                    policy.clone(),
                )
                .await
            }
        }
    }

    /// Opens a streaming chat request.
    ///
    /// This delegates to the chat capability once; `retry_options` is not
    /// consulted here.
    async fn chat_stream(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatStream, LlmError> {
        let capability = &self.chat_capability;
        capability.chat_stream(messages, tools).await
    }
}
impl GroqClient {
    /// Sets the retry policy used by `chat_with_tools`, or clears it with `None`.
    pub fn set_retry_options(&mut self, options: Option<RetryOptions>) {
        self.retry_options = options;
    }

    /// Stores the tracing configuration. This client's `Debug` output reports
    /// its presence as `has_tracing`.
    pub(crate) fn set_tracing_config(&mut self, config: Option<crate::tracing::TracingConfig>) {
        self.tracing_config = config;
    }

    /// Takes ownership of the tracing worker guard so it lives as long as this
    /// client. Assigning drops any guard that was previously held.
    pub(crate) fn set_tracing_guard(
        &mut self,
        guard: Option<tracing_appender::non_blocking::WorkerGuard>,
    ) {
        self._tracing_guard = guard;
    }
}
#[async_trait]
impl ModelListingCapability for GroqClient {
    /// Lists available models through the Groq models capability.
    ///
    /// This delegates once; `retry_options` is not consulted.
    async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
        let models = &self.models_capability;
        models.list_models().await
    }

    /// Fetches information for the model identified by `model_id`.
    ///
    /// This delegates once; `retry_options` is not consulted.
    async fn get_model(&self, model_id: String) -> Result<ModelInfo, LlmError> {
        let models = &self.models_capability;
        models.get_model(model_id).await
    }
}