1use async_trait::async_trait;
4
5use crate::llm::openai_compatible::ChatOpenAICompatible;
6use crate::llm::{
7 BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, ToolChoice, ToolDefinition,
8};
9
/// Base URL of Groq's OpenAI-compatible REST API endpoint.
const GROQ_URL: &str = "https://api.groq.com/openai/v1";
11
/// Chat model client for the Groq API.
///
/// Thin wrapper that delegates all work to a [`ChatOpenAICompatible`]
/// client configured with Groq's base URL and provider tag.
pub struct ChatGroq {
    // Underlying OpenAI-compatible client, fully configured in
    // `ChatGroqBuilder::build`.
    inner: ChatOpenAICompatible,
}
25
26impl ChatGroq {
27 pub fn new(model: impl Into<String>) -> Result<Self, LlmError> {
29 Self::builder().model(model).build()
30 }
31
32 pub fn builder() -> ChatGroqBuilder {
34 ChatGroqBuilder::default()
35 }
36}
37
/// Builder for [`ChatGroq`]. Every field is optional except `model`.
#[derive(Default)]
pub struct ChatGroqBuilder {
    // Required; `build()` fails with a Config error when unset.
    model: Option<String>,
    // Explicit key; `build()` falls back to the GROQ_API_KEY env var.
    api_key: Option<String>,
    // Sampling temperature; `build()` defaults this to 0.2.
    temperature: Option<f32>,
    // Forwarded as `max_completion_tokens` to the inner client.
    max_tokens: Option<u64>,
}
45
46impl ChatGroqBuilder {
47 pub fn model(mut self, model: impl Into<String>) -> Self {
48 self.model = Some(model.into());
49 self
50 }
51
52 pub fn api_key(mut self, key: impl Into<String>) -> Self {
53 self.api_key = Some(key.into());
54 self
55 }
56
57 pub fn temperature(mut self, temp: f32) -> Self {
58 self.temperature = Some(temp);
59 self
60 }
61
62 pub fn max_tokens(mut self, tokens: u64) -> Self {
63 self.max_tokens = Some(tokens);
64 self
65 }
66
67 pub fn build(self) -> Result<ChatGroq, LlmError> {
68 let model = self
69 .model
70 .ok_or_else(|| LlmError::Config("model is required".into()))?;
71
72 let api_key = self
73 .api_key
74 .or_else(|| std::env::var("GROQ_API_KEY").ok())
75 .ok_or_else(|| LlmError::Config("GROQ_API_KEY not set".into()))?;
76
77 let inner = ChatOpenAICompatible::builder()
78 .model(&model)
79 .base_url(GROQ_URL)
80 .provider("groq")
81 .api_key(Some(api_key))
82 .temperature(self.temperature.unwrap_or(0.2))
83 .max_completion_tokens(self.max_tokens)
84 .build()?;
85
86 Ok(ChatGroq { inner })
87 }
88}
89
#[async_trait]
impl BaseChatModel for ChatGroq {
    /// Model identifier, delegated to the wrapped client.
    fn model(&self) -> &str {
        self.inner.model()
    }

    /// Provider tag; matches the value passed to the inner builder.
    fn provider(&self) -> &str {
        "groq"
    }

    // NOTE(review): a fixed 128k window is reported for every model this
    // wrapper is configured with — confirm this holds across all
    // Groq-hosted models before relying on it for prompt budgeting.
    fn context_window(&self) -> Option<u64> {
        Some(128_000)
    }

    /// Single-shot chat completion; pure delegation to the inner client.
    async fn invoke(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatCompletion, LlmError> {
        self.inner.invoke(messages, tools, tool_choice).await
    }

    /// Streaming chat completion; pure delegation to the inner client.
    async fn invoke_stream(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatStream, LlmError> {
        self.inner.invoke_stream(messages, tools, tool_choice).await
    }
}