agentkit_provider_groq/
lib.rs1use agentkit_adapter_completions::{
31 CompletionsAdapter, CompletionsError, CompletionsProvider, CompletionsSession, CompletionsTurn,
32};
33use agentkit_loop::{LoopError, ModelAdapter, SessionConfig};
34use async_trait::async_trait;
35use serde::Serialize;
36use thiserror::Error;
37
/// Default request URL: Groq's OpenAI-compatible base path joined with the
/// chat-completions route. This is a complete endpoint, not a base URL to append to.
const DEFAULT_ENDPOINT: &str = concat!("https://api.groq.com/openai/v1", "/chat/completions");
39
/// User-facing configuration for the Groq provider.
///
/// Build one with [`GroqConfig::new`] or [`GroqConfig::from_env`] and refine it with the
/// `with_*` methods. `Debug` is implemented by hand (below) so that the API key is never
/// written to logs, panic messages or `{:?}` output.
#[derive(Clone)]
pub struct GroqConfig {
    /// Groq API key; attached to every request as a bearer token.
    pub api_key: String,
    /// Model identifier placed in the request body as `model`.
    pub model: String,
    /// Complete request URL used verbatim as the provider endpoint (e.g.
    /// `https://api.groq.com/openai/v1/chat/completions`), not just a host or base path.
    pub base_url: String,
    /// `temperature` request parameter; omitted from the request when `None`.
    pub temperature: Option<f32>,
    /// `max_completion_tokens` request parameter; omitted from the request when `None`.
    pub max_completion_tokens: Option<u32>,
    /// `top_p` request parameter; omitted from the request when `None`.
    pub top_p: Option<f32>,
    /// `parallel_tool_calls` request parameter; omitted from the request when `None`.
    pub parallel_tool_calls: Option<bool>,
    /// Whether the provider reports streaming as enabled; [`GroqConfig::new`] sets `true`.
    pub streaming: bool,
}

impl std::fmt::Debug for GroqConfig {
    /// Formats every field except `api_key`, which is replaced by a placeholder so the
    /// secret cannot leak through debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GroqConfig")
            .field("api_key", &format_args!("<redacted>"))
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("temperature", &self.temperature)
            .field("max_completion_tokens", &self.max_completion_tokens)
            .field("top_p", &self.top_p)
            .field("parallel_tool_calls", &self.parallel_tool_calls)
            .field("streaming", &self.streaming)
            .finish()
    }
}
75
76impl GroqConfig {
77 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
79 Self {
80 api_key: api_key.into(),
81 model: model.into(),
82 base_url: DEFAULT_ENDPOINT.into(),
83 temperature: None,
84 max_completion_tokens: None,
85 top_p: None,
86 parallel_tool_calls: None,
87 streaming: true,
88 }
89 }
90
91 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
93 self.base_url = url.into();
94 self
95 }
96
97 pub fn with_temperature(mut self, v: f32) -> Self {
99 self.temperature = Some(v);
100 self
101 }
102
103 pub fn with_max_completion_tokens(mut self, v: u32) -> Self {
105 self.max_completion_tokens = Some(v);
106 self
107 }
108
109 pub fn with_top_p(mut self, v: f32) -> Self {
111 self.top_p = Some(v);
112 self
113 }
114
115 pub fn with_parallel_tool_calls(mut self, flag: bool) -> Self {
117 self.parallel_tool_calls = Some(flag);
118 self
119 }
120
121 pub fn with_streaming(mut self, flag: bool) -> Self {
123 self.streaming = flag;
124 self
125 }
126
127 pub fn from_env() -> Result<Self, GroqError> {
135 let api_key =
136 std::env::var("GROQ_API_KEY").map_err(|_| GroqError::MissingEnv("GROQ_API_KEY"))?;
137 let model = std::env::var("GROQ_MODEL").unwrap_or_else(|_| "llama-3.1-8b-instant".into());
138
139 let mut config = Self::new(api_key, model);
140
141 if let Ok(url) = std::env::var("GROQ_BASE_URL") {
142 config = config.with_base_url(url);
143 }
144
145 Ok(config)
146 }
147}
148
149#[derive(Clone, Debug, Serialize)]
151pub struct GroqRequestConfig {
152 pub model: String,
153 #[serde(skip_serializing_if = "Option::is_none")]
154 pub temperature: Option<f32>,
155 #[serde(skip_serializing_if = "Option::is_none")]
156 pub max_completion_tokens: Option<u32>,
157 #[serde(skip_serializing_if = "Option::is_none")]
158 pub top_p: Option<f32>,
159 #[serde(skip_serializing_if = "Option::is_none")]
160 pub parallel_tool_calls: Option<bool>,
161}
162
163#[derive(Clone, Debug)]
165pub struct GroqProvider {
166 api_key: String,
167 base_url: String,
168 streaming: bool,
169 request_config: GroqRequestConfig,
170}
171
172impl From<GroqConfig> for GroqProvider {
173 fn from(config: GroqConfig) -> Self {
174 Self {
175 api_key: config.api_key,
176 base_url: config.base_url,
177 streaming: config.streaming,
178 request_config: GroqRequestConfig {
179 model: config.model,
180 temperature: config.temperature,
181 max_completion_tokens: config.max_completion_tokens,
182 top_p: config.top_p,
183 parallel_tool_calls: config.parallel_tool_calls,
184 },
185 }
186 }
187}
188
189impl CompletionsProvider for GroqProvider {
190 type Config = GroqRequestConfig;
191
192 fn provider_name(&self) -> &str {
193 "Groq"
194 }
195 fn endpoint_url(&self) -> &str {
196 &self.base_url
197 }
198 fn config(&self) -> &GroqRequestConfig {
199 &self.request_config
200 }
201
202 fn preprocess_request(
203 &self,
204 builder: agentkit_http::HttpRequestBuilder,
205 ) -> agentkit_http::HttpRequestBuilder {
206 builder.bearer_auth(&self.api_key).header(
207 "User-Agent",
208 concat!("agentkit-provider-groq/", env!("CARGO_PKG_VERSION")),
209 )
210 }
211
212 fn streaming(&self) -> bool {
213 self.streaming
214 }
215}
216
217#[derive(Clone)]
235pub struct GroqAdapter(CompletionsAdapter<GroqProvider>);
236
237pub type GroqSession = CompletionsSession<GroqProvider>;
239
240pub type GroqTurn = CompletionsTurn;
242
243impl GroqAdapter {
244 pub fn new(config: GroqConfig) -> Result<Self, GroqError> {
246 let provider = GroqProvider::from(config);
247 Ok(Self(CompletionsAdapter::new(provider)?))
248 }
249}
250
251#[async_trait]
252impl ModelAdapter for GroqAdapter {
253 type Session = GroqSession;
254
255 async fn start_session(&self, config: SessionConfig) -> Result<Self::Session, LoopError> {
256 self.0.start_session(config).await
257 }
258}
259
260#[derive(Debug, Error)]
262pub enum GroqError {
263 #[error("missing environment variable {0}")]
265 MissingEnv(&'static str),
266
267 #[error(transparent)]
269 Completions(#[from] CompletionsError),
270}