// agent_io/llm/groq.rs
1//! Groq Chat Model implementation
2
3use async_trait::async_trait;
4
5use crate::llm::openai_compatible::ChatOpenAICompatible;
6use crate::llm::{
7    BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, ToolChoice, ToolDefinition,
8};
9
/// Base URL of Groq's OpenAI-compatible REST API.
const GROQ_URL: &str = "https://api.groq.com/openai/v1";
11
/// Groq Chat Model
///
/// Fast inference using Groq's LPU.
///
/// # Example
/// ```ignore
/// use agent_io::llm::ChatGroq;
///
/// let llm = ChatGroq::new("llama-3.3-70b-versatile")?;
/// ```
pub struct ChatGroq {
    // All requests are delegated to this OpenAI-compatible client,
    // configured with Groq's base URL in `ChatGroqBuilder::build`.
    inner: ChatOpenAICompatible,
}
25
26impl ChatGroq {
27    /// Create a new Groq chat model
28    pub fn new(model: impl Into<String>) -> Result<Self, LlmError> {
29        Self::builder().model(model).build()
30    }
31
32    /// Create a builder for configuration
33    pub fn builder() -> ChatGroqBuilder {
34        ChatGroqBuilder::default()
35    }
36}
37
/// Builder for [`ChatGroq`]; every field is optional except `model`.
#[derive(Default)]
pub struct ChatGroqBuilder {
    // Required: model identifier, e.g. "llama-3.3-70b-versatile".
    model: Option<String>,
    // Optional explicit key; `build` falls back to the GROQ_API_KEY env var.
    api_key: Option<String>,
    // Sampling temperature; `build` defaults this to 0.2 when unset.
    temperature: Option<f32>,
    // Completion-token cap; passed through as-is (no default applied).
    max_tokens: Option<u64>,
}
45
46impl ChatGroqBuilder {
47    pub fn model(mut self, model: impl Into<String>) -> Self {
48        self.model = Some(model.into());
49        self
50    }
51
52    pub fn api_key(mut self, key: impl Into<String>) -> Self {
53        self.api_key = Some(key.into());
54        self
55    }
56
57    pub fn temperature(mut self, temp: f32) -> Self {
58        self.temperature = Some(temp);
59        self
60    }
61
62    pub fn max_tokens(mut self, tokens: u64) -> Self {
63        self.max_tokens = Some(tokens);
64        self
65    }
66
67    pub fn build(self) -> Result<ChatGroq, LlmError> {
68        let model = self
69            .model
70            .ok_or_else(|| LlmError::Config("model is required".into()))?;
71
72        let api_key = self
73            .api_key
74            .or_else(|| std::env::var("GROQ_API_KEY").ok())
75            .ok_or_else(|| LlmError::Config("GROQ_API_KEY not set".into()))?;
76
77        let inner = ChatOpenAICompatible::builder()
78            .model(&model)
79            .base_url(GROQ_URL)
80            .provider("groq")
81            .api_key(Some(api_key))
82            .temperature(self.temperature.unwrap_or(0.2))
83            .max_completion_tokens(self.max_tokens)
84            .build()?;
85
86        Ok(ChatGroq { inner })
87    }
88}
89
#[async_trait]
impl BaseChatModel for ChatGroq {
    /// Model identifier, as reported by the wrapped client.
    fn model(&self) -> &str {
        self.inner.model()
    }

    /// Provider tag for this backend.
    fn provider(&self) -> &str {
        "groq"
    }

    /// Context window size in tokens.
    // NOTE(review): hard-coded 128k regardless of which model was
    // configured — confirm this holds for every Groq model exposed here.
    fn context_window(&self) -> Option<u64> {
        Some(128_000)
    }

    /// Run a single (non-streaming) chat completion; delegates verbatim
    /// to the inner OpenAI-compatible client.
    async fn invoke(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatCompletion, LlmError> {
        self.inner.invoke(messages, tools, tool_choice).await
    }

    /// Run a streaming chat completion; delegates verbatim to the inner
    /// OpenAI-compatible client.
    async fn invoke_stream(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatStream, LlmError> {
        self.inner.invoke_stream(messages, tools, tool_choice).await
    }
}
121}