use crate::core::types::{
GenerateOptions, GenerateResult, Prompt, ProviderSettings, SpeechOptions, SpeechResult,
StreamPart, TranscriptionOptions, TranscriptionResult,
};
use crate::openai::speech::OpenAISpeechModel;
use crate::openai::transcription::OpenAITranscriptionModel;
use crate::openai::OpenAIModel;
use async_trait::async_trait;
use futures::stream::BoxStream;
use reqwest::Client;
/// Root URL that every Groq model in this module sends its requests to.
///
/// It is the Groq host followed by the `/openai/v1` path, which is why the OpenAI model
/// implementations can be reused unchanged with only their `base_url` swapped.
const GROQ_BASE_URL: &str = concat!("https://api.groq.com", "/openai/v1");
/// Chat / text-generation model backed by Groq.
///
/// This is a thin wrapper around [`OpenAIModel`] whose `base_url` is [`GROQ_BASE_URL`]; every
/// [`crate::core::LanguageModel`] call is forwarded to `inner` unchanged.
pub struct GroqCloudModel {
    /// The OpenAI model implementation that actually performs the HTTP requests.
    pub inner: OpenAIModel,
}

impl GroqCloudModel {
    /// Creates a model that talks to [`GROQ_BASE_URL`] and authenticates with `api_key`.
    ///
    /// Each call builds its own [`Client`]; instances do not share a connection pool.
    ///
    /// # Panics
    ///
    /// Panics if [`Client::new`] panics (per reqwest's docs: when a TLS backend cannot be
    /// initialized or the resolver cannot load the system configuration).
    #[must_use]
    pub fn new(api_key: String) -> Self {
        Self {
            inner: OpenAIModel {
                api_key,
                base_url: GROQ_BASE_URL.to_string(),
                client: Client::new(),
            },
        }
    }
}

#[async_trait]
impl crate::core::LanguageModel for GroqCloudModel {
    /// Produces a complete response by delegating to the wrapped OpenAI model.
    #[tracing::instrument(skip(self, prompt), fields(model = options.model_id))]
    async fn generate(
        &self,
        prompt: Prompt,
        options: GenerateOptions,
    ) -> crate::core::Result<GenerateResult> {
        self.inner.generate(prompt, options).await
    }

    /// Starts a streaming generation by delegating to the wrapped OpenAI model.
    ///
    /// Instrumented exactly like `generate`, so streaming calls get the same span name pattern
    /// and `model` field. Note that the span only covers the future that sets up the stream;
    /// the returned stream is polled outside of it.
    #[tracing::instrument(skip(self, prompt), fields(model = options.model_id))]
    async fn generate_stream(
        &self,
        prompt: Prompt,
        options: GenerateOptions,
    ) -> crate::core::Result<BoxStream<'static, StreamPart>> {
        self.inner.generate_stream(prompt, options).await
    }
}
/// Speech-to-text model backed by Groq.
///
/// Holds an [`OpenAITranscriptionModel`] configured with [`GROQ_BASE_URL`] and hands every
/// transcription request straight to it.
pub struct GroqCloudTranscriptionModel {
    /// OpenAI transcription implementation that issues the actual requests.
    pub inner: OpenAITranscriptionModel,
}

impl GroqCloudTranscriptionModel {
    /// Builds a transcription model that targets [`GROQ_BASE_URL`] using `api_key`.
    ///
    /// A dedicated [`Client`] is constructed for this instance.
    #[must_use]
    pub fn new(api_key: String) -> Self {
        let inner = OpenAITranscriptionModel {
            base_url: GROQ_BASE_URL.to_owned(),
            client: Client::new(),
            api_key,
        };
        Self { inner }
    }
}

#[async_trait]
impl crate::core::TranscriptionModel for GroqCloudTranscriptionModel {
    /// Passes `options` to the wrapped OpenAI transcription model and returns its result as-is.
    async fn transcribe(
        &self,
        options: TranscriptionOptions,
    ) -> crate::core::Result<TranscriptionResult> {
        let delegate = &self.inner;
        delegate.transcribe(options).await
    }
}
/// Text-to-speech model backed by Groq.
///
/// Delegates all synthesis work to an [`OpenAISpeechModel`] whose base URL is
/// [`GROQ_BASE_URL`].
pub struct GroqCloudSpeechModel {
    /// OpenAI speech implementation that performs the HTTP calls.
    pub inner: OpenAISpeechModel,
}

impl GroqCloudSpeechModel {
    /// Returns a speech model pointed at [`GROQ_BASE_URL`] and authenticated with `api_key`.
    ///
    /// The model owns a freshly constructed [`Client`].
    #[must_use]
    pub fn new(api_key: String) -> Self {
        let base_url = String::from(GROQ_BASE_URL);
        let client = Client::new();
        Self {
            inner: OpenAISpeechModel {
                api_key,
                base_url,
                client,
            },
        }
    }
}

#[async_trait]
impl crate::core::SpeechModel for GroqCloudSpeechModel {
    /// Forwards `options` to the wrapped OpenAI speech model and returns whatever it yields.
    async fn synthesize(&self, options: SpeechOptions) -> crate::core::Result<SpeechResult> {
        let delegate = &self.inner;
        delegate.synthesize(options).await
    }
}
pub struct GroqCloudProvider {
settings: ProviderSettings,
}
impl GroqCloudProvider {
#[must_use]
pub fn chat(&self, _model_id: &str) -> GroqCloudModel {
let api_key = self.get_api_key();
GroqCloudModel::new(api_key)
}
#[must_use]
pub fn language_model(&self, model_id: &str) -> GroqCloudModel {
self.chat(model_id)
}
#[must_use]
pub fn transcription(&self, _model_id: &str) -> GroqCloudTranscriptionModel {
let api_key = self.get_api_key();
GroqCloudTranscriptionModel::new(api_key)
}
#[must_use]
pub fn speech(&self, _model_id: &str) -> GroqCloudSpeechModel {
let api_key = self.get_api_key();
GroqCloudSpeechModel::new(api_key)
}
fn get_api_key(&self) -> String {
self.settings
.api_key
.clone()
.unwrap_or_else(|| std::env::var("GROQ_API_KEY").unwrap_or_default())
}
}
#[must_use]
pub fn create_groqcloud(settings: ProviderSettings) -> GroqCloudProvider {
GroqCloudProvider { settings }
}
/// Registry integration: Groq supplies all three model kinds, so every lookup returns `Some`.
impl crate::core::registry::Provider for GroqCloudProvider {
    /// Boxes the chat model produced by [`GroqCloudProvider::chat`].
    fn language_model(&self, model_id: &str) -> Option<Box<dyn crate::core::LanguageModel>> {
        let model: Box<dyn crate::core::LanguageModel> = Box::new(self.chat(model_id));
        Some(model)
    }

    /// Boxes the model produced by [`GroqCloudProvider::transcription`].
    fn transcription_model(
        &self,
        model_id: &str,
    ) -> Option<Box<dyn crate::core::TranscriptionModel>> {
        let model: Box<dyn crate::core::TranscriptionModel> =
            Box::new(self.transcription(model_id));
        Some(model)
    }

    /// Boxes the model produced by [`GroqCloudProvider::speech`].
    fn speech_model(&self, model_id: &str) -> Option<Box<dyn crate::core::SpeechModel>> {
        let model: Box<dyn crate::core::SpeechModel> = Box::new(self.speech(model_id));
        Some(model)
    }
}