Skip to main content

agentkit_provider_groq/
lib.rs

1//! Groq model adapter for the agentkit agent loop.
2//!
3//! This crate provides [`GroqAdapter`] and [`GroqConfig`] for connecting
4//! the agent loop to the [Groq](https://groq.com) chat completions API,
5//! which serves open-source models on custom LPU hardware.
6//! It is built on the generic [`agentkit_adapter_completions`] crate.
7//!
8//! # Quick start
9//!
10//! ```rust,ignore
11//! use agentkit_loop::{Agent, SessionConfig};
12//! use agentkit_provider_groq::{GroqAdapter, GroqConfig};
13//!
14//! #[tokio::main]
15//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
16//!     let config = GroqConfig::from_env()?;
17//!     let adapter = GroqAdapter::new(config)?;
18//!
19//!     let agent = Agent::builder()
20//!         .model(adapter)
21//!         .build()?;
22//!
23//!     let mut driver = agent
24//!         .start(SessionConfig::new("demo"))
25//!         .await?;
26//!     Ok(())
27//! }
28//! ```
29
30use agentkit_adapter_completions::{
31    CompletionsAdapter, CompletionsError, CompletionsProvider, CompletionsSession, CompletionsTurn,
32};
33use agentkit_loop::{LoopError, ModelAdapter, SessionConfig};
34use async_trait::async_trait;
35use serde::Serialize;
36use thiserror::Error;
37
/// Chat completions endpoint used when no explicit URL is configured.
const DEFAULT_ENDPOINT: &'static str = "https://api.groq.com/openai/v1/chat/completions";
39
/// Configuration for connecting to the Groq API.
///
/// Build one with [`GroqConfig::new`] for explicit values, or
/// [`GroqConfig::from_env`] to read from environment variables.
///
/// The [`Debug`](std::fmt::Debug) output redacts [`api_key`](Self::api_key)
/// so that logging a configuration never prints the secret.
///
/// # Example
///
/// ```rust,no_run
/// use agentkit_provider_groq::GroqConfig;
///
/// let config = GroqConfig::new("gsk_...", "llama-3.3-70b-versatile")
///     .with_temperature(0.0)
///     .with_max_completion_tokens(4096);
/// ```
#[derive(Clone)]
pub struct GroqConfig {
    /// Groq API key (starts with `gsk_`).
    pub api_key: String,
    /// Model identifier, e.g. `"llama-3.3-70b-versatile"` or `"llama-3.1-8b-instant"`.
    pub model: String,
    /// Full chat completions endpoint URL. Defaults to the Groq production URL.
    pub base_url: String,
    /// Sampling temperature (0.0 = deterministic, higher = more creative).
    pub temperature: Option<f32>,
    /// Maximum number of completion tokens the model may generate.
    pub max_completion_tokens: Option<u32>,
    /// Nucleus sampling parameter.
    pub top_p: Option<f32>,
    /// Whether the model is allowed to emit multiple tool calls in a
    /// single turn. Omitted from the request when `None` so Groq's
    /// per-model default applies.
    pub parallel_tool_calls: Option<bool>,
    /// Request SSE streaming responses. Defaults to `true`.
    pub streaming: bool,
}

impl std::fmt::Debug for GroqConfig {
    /// Formats every field except the API key, which is shown as `<redacted>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GroqConfig")
            // Never print the secret: Debug output routinely ends up in logs.
            .field("api_key", &format_args!("<redacted>"))
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("temperature", &self.temperature)
            .field("max_completion_tokens", &self.max_completion_tokens)
            .field("top_p", &self.top_p)
            .field("parallel_tool_calls", &self.parallel_tool_calls)
            .field("streaming", &self.streaming)
            .finish()
    }
}
75
76impl GroqConfig {
77    /// Creates a new configuration with the given API key and model identifier.
78    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
79        Self {
80            api_key: api_key.into(),
81            model: model.into(),
82            base_url: DEFAULT_ENDPOINT.into(),
83            temperature: None,
84            max_completion_tokens: None,
85            top_p: None,
86            parallel_tool_calls: None,
87            streaming: true,
88        }
89    }
90
91    /// Overrides the default chat completions endpoint URL.
92    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
93        self.base_url = url.into();
94        self
95    }
96
97    /// Sets the sampling temperature (0.0 for deterministic output).
98    pub fn with_temperature(mut self, v: f32) -> Self {
99        self.temperature = Some(v);
100        self
101    }
102
103    /// Sets the maximum number of tokens the model may generate per turn.
104    pub fn with_max_completion_tokens(mut self, v: u32) -> Self {
105        self.max_completion_tokens = Some(v);
106        self
107    }
108
109    /// Sets the nucleus sampling parameter.
110    pub fn with_top_p(mut self, v: f32) -> Self {
111        self.top_p = Some(v);
112        self
113    }
114
115    /// Sets whether the model may emit multiple tool calls in a single turn.
116    pub fn with_parallel_tool_calls(mut self, flag: bool) -> Self {
117        self.parallel_tool_calls = Some(flag);
118        self
119    }
120
121    /// Toggles SSE streaming of model responses. Default: true.
122    pub fn with_streaming(mut self, flag: bool) -> Self {
123        self.streaming = flag;
124        self
125    }
126
127    /// Builds a configuration from environment variables.
128    ///
129    /// | Variable | Required | Default |
130    /// |---|---|---|
131    /// | `GROQ_API_KEY` | yes | -- |
132    /// | `GROQ_MODEL` | no | `llama-3.1-8b-instant` |
133    /// | `GROQ_BASE_URL` | no | `https://api.groq.com/openai/v1/chat/completions` |
134    pub fn from_env() -> Result<Self, GroqError> {
135        let api_key =
136            std::env::var("GROQ_API_KEY").map_err(|_| GroqError::MissingEnv("GROQ_API_KEY"))?;
137        let model = std::env::var("GROQ_MODEL").unwrap_or_else(|_| "llama-3.1-8b-instant".into());
138
139        let mut config = Self::new(api_key, model);
140
141        if let Ok(url) = std::env::var("GROQ_BASE_URL") {
142            config = config.with_base_url(url);
143        }
144
145        Ok(config)
146    }
147}
148
149/// Request parameters serialized into the Groq request body.
150#[derive(Clone, Debug, Serialize)]
151pub struct GroqRequestConfig {
152    pub model: String,
153    #[serde(skip_serializing_if = "Option::is_none")]
154    pub temperature: Option<f32>,
155    #[serde(skip_serializing_if = "Option::is_none")]
156    pub max_completion_tokens: Option<u32>,
157    #[serde(skip_serializing_if = "Option::is_none")]
158    pub top_p: Option<f32>,
159    #[serde(skip_serializing_if = "Option::is_none")]
160    pub parallel_tool_calls: Option<bool>,
161}
162
163/// The Groq provider, implementing [`CompletionsProvider`].
164#[derive(Clone, Debug)]
165pub struct GroqProvider {
166    api_key: String,
167    base_url: String,
168    streaming: bool,
169    request_config: GroqRequestConfig,
170}
171
172impl From<GroqConfig> for GroqProvider {
173    fn from(config: GroqConfig) -> Self {
174        Self {
175            api_key: config.api_key,
176            base_url: config.base_url,
177            streaming: config.streaming,
178            request_config: GroqRequestConfig {
179                model: config.model,
180                temperature: config.temperature,
181                max_completion_tokens: config.max_completion_tokens,
182                top_p: config.top_p,
183                parallel_tool_calls: config.parallel_tool_calls,
184            },
185        }
186    }
187}
188
189impl CompletionsProvider for GroqProvider {
190    type Config = GroqRequestConfig;
191
192    fn provider_name(&self) -> &str {
193        "Groq"
194    }
195    fn endpoint_url(&self) -> &str {
196        &self.base_url
197    }
198    fn config(&self) -> &GroqRequestConfig {
199        &self.request_config
200    }
201
202    fn preprocess_request(
203        &self,
204        builder: agentkit_http::HttpRequestBuilder,
205    ) -> agentkit_http::HttpRequestBuilder {
206        builder.bearer_auth(&self.api_key).header(
207            "User-Agent",
208            concat!("agentkit-provider-groq/", env!("CARGO_PKG_VERSION")),
209        )
210    }
211
212    fn streaming(&self) -> bool {
213        self.streaming
214    }
215}
216
217/// Model adapter that connects the agentkit agent loop to Groq.
218///
219/// # Example
220///
221/// ```rust,no_run
222/// use agentkit_loop::Agent;
223/// use agentkit_provider_groq::{GroqAdapter, GroqConfig};
224///
225/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
226/// let adapter = GroqAdapter::new(GroqConfig::from_env()?)?;
227///
228/// let agent = Agent::builder()
229///     .model(adapter)
230///     .build()?;
231/// # Ok(())
232/// # }
233/// ```
234#[derive(Clone)]
235pub struct GroqAdapter(CompletionsAdapter<GroqProvider>);
236
237/// An active session with the Groq API.
238pub type GroqSession = CompletionsSession<GroqProvider>;
239
240/// A completed turn from the Groq API.
241pub type GroqTurn = CompletionsTurn;
242
243impl GroqAdapter {
244    /// Creates a new adapter from the given configuration.
245    pub fn new(config: GroqConfig) -> Result<Self, GroqError> {
246        let provider = GroqProvider::from(config);
247        Ok(Self(CompletionsAdapter::new(provider)?))
248    }
249}
250
251#[async_trait]
252impl ModelAdapter for GroqAdapter {
253    type Session = GroqSession;
254
255    async fn start_session(&self, config: SessionConfig) -> Result<Self::Session, LoopError> {
256        self.0.start_session(config).await
257    }
258}
259
260/// Errors produced by the Groq adapter.
261#[derive(Debug, Error)]
262pub enum GroqError {
263    /// A required environment variable is not set.
264    #[error("missing environment variable {0}")]
265    MissingEnv(&'static str),
266
267    /// An error from the generic completions adapter.
268    #[error(transparent)]
269    Completions(#[from] CompletionsError),
270}