use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use super::provider::ModelProvider;
use super::types::*;
/// [`ModelProvider`] backed by Groq's OpenAI-compatible chat-completions API.
pub struct GroqProvider {
    /// Root URL of the API; `/v1/chat/completions` is appended for each request.
    base_url: String,
    /// Groq API key, sent as an `Authorization: Bearer` header.
    api_key: String,
    /// Model identifier placed in the request body's `model` field.
    model: String,
    /// HTTP client reused for every request issued by this provider.
    client: Client,
}
impl GroqProvider {
    /// Creates a Groq provider.
    ///
    /// * `api_key` — Groq API key, sent as a bearer token on each request.
    /// * `model` — model id; defaults to `llama-3.3-70b-versatile`.
    /// * `base_url` — API root; defaults to `https://api.groq.com/openai`.
    ///   Trailing `/` characters are stripped, because request paths such as
    ///   `/v1/chat/completions` are appended directly and would otherwise
    ///   produce a `//` in the URL.
    pub fn new(api_key: String, model: Option<String>, base_url: Option<String>) -> Self {
        let base_url = base_url.map_or_else(
            || "https://api.groq.com/openai".to_string(),
            |url| url.trim_end_matches('/').to_string(),
        );
        Self {
            client: Client::new(),
            api_key,
            model: model.unwrap_or_else(|| "llama-3.3-70b-versatile".to_string()),
            base_url,
        }
    }
}
#[async_trait]
impl ModelProvider for GroqProvider {
    /// Provider identifier: always `"groq"`.
    fn name(&self) -> &str {
        "groq"
    }

    /// The model id this provider was configured with.
    fn model_id(&self) -> &str {
        &self.model
    }

    /// Groq's chat-completions endpoint accepts OpenAI-style `tools`.
    fn supports_tools(&self) -> bool {
        true
    }

    /// Sends `request` to Groq's `/v1/chat/completions` endpoint and converts
    /// the reply with [`parse_openai_response`].
    ///
    /// The `tools` field is only included when the request declares tools.
    ///
    /// # Errors
    /// Returns an error if the request cannot be sent, if Groq answers with a
    /// non-success status (status and response body are included), or if the
    /// reply cannot be decoded into an OpenAI-style response.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let endpoint = format!("{}/v1/chat/completions", self.base_url);

        let mut payload = serde_json::json!({
            "model": self.model,
            "messages": build_openai_messages(&request),
            "max_tokens": request.max_tokens,
        });
        let tool_defs = build_openai_tools(&request);
        if !tool_defs.is_empty() {
            payload["tools"] = serde_json::Value::Array(tool_defs);
        }

        let reply = self
            .client
            .post(&endpoint)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&payload)
            .send()
            .await
            .context("Failed to send request to Groq API")?;

        let status = reply.status();
        if !status.is_success() {
            let error_text = reply.text().await.unwrap_or_default();
            anyhow::bail!("Groq API error ({}): {}", status, error_text);
        }

        let decoded = reply
            .json::<OpenAIResponse>()
            .await
            .context("Failed to parse Groq API response")?;
        parse_openai_response(decoded)
    }
}
/// [`ModelProvider`] backed by the Google AI Studio (Gemini) `generateContent` API.
pub struct GoogleProvider {
    /// Root URL of the API; the `/v1beta/models/...` path is appended per request.
    base_url: String,
    /// Google AI Studio API key.
    api_key: String,
    /// Gemini model name, interpolated into the request path.
    model: String,
    /// HTTP client reused for every request issued by this provider.
    client: Client,
}
impl GoogleProvider {
    /// Creates a Google AI Studio (Gemini) provider.
    ///
    /// * `api_key` — Google AI Studio API key.
    /// * `model` — model name; defaults to `gemini-2.5-flash`.
    /// * `base_url` — API root; defaults to
    ///   `https://generativelanguage.googleapis.com`. Trailing `/` characters
    ///   are stripped, because the `/v1beta/...` request path is appended
    ///   directly and would otherwise produce a `//` in the URL.
    pub fn new(api_key: String, model: Option<String>, base_url: Option<String>) -> Self {
        let base_url = base_url.map_or_else(
            || "https://generativelanguage.googleapis.com".to_string(),
            |url| url.trim_end_matches('/').to_string(),
        );
        Self {
            client: Client::new(),
            api_key,
            model: model.unwrap_or_else(|| "gemini-2.5-flash".to_string()),
            base_url,
        }
    }
}
#[async_trait]
impl ModelProvider for GoogleProvider {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
let url = format!(
"{}/v1beta/models/{}:generateContent?key={}",
self.base_url, self.model, self.api_key
);
let mut parts: Vec<serde_json::Value> = Vec::new();
let system_instruction = if !request.system.is_empty() {
Some(serde_json::json!({
"parts": [{ "text": request.system }]
}))
} else {
None
};
let contents: Vec<serde_json::Value> = request
.messages
.iter()
.filter_map(|msg| {
let role = match msg.role {
Role::User => "user",
Role::Assistant => "model",
Role::System => return None, };
let parts: Vec<serde_json::Value> = msg
.content
.iter()
.filter_map(|block| match block {
ContentBlock::Text { text } => {
Some(serde_json::json!({ "text": text }))
}
ContentBlock::ToolUse { id, name, input } => {
Some(serde_json::json!({
"functionCall": {
"name": name,
"args": input
}
}))
}
ContentBlock::ToolResult {
tool_use_id: _,
content,
..
} => Some(serde_json::json!({
"functionResponse": {
"name": "tool",
"response": { "result": content }
}
})),
})
.collect();
Some(serde_json::json!({
"role": role,
"parts": parts
}))
})
.collect();
let mut body = serde_json::json!({
"contents": contents,
"generationConfig": {
"maxOutputTokens": request.max_tokens,
}
});
if let Some(si) = system_instruction {
body["systemInstruction"] = si;
}
if !request.tools.is_empty() {
let function_declarations: Vec<serde_json::Value> = request
.tools
.iter()
.map(|t| {
serde_json::json!({
"name": t.name,
"description": t.description,
"parameters": t.input_schema
})
})
.collect();
body["tools"] = serde_json::json!([{
"functionDeclarations": function_declarations
}]);
}
let response = self
.client
.post(&url)
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.context("Failed to send request to Google AI Studio")?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
anyhow::bail!("Google AI Studio error ({}): {}", status, error_text);
}
let api_response: serde_json::Value = response
.json()
.await
.context("Failed to parse Google AI Studio response")?;
let candidates = api_response["candidates"]
.as_array()
.context("No candidates in Google response")?;
let first_candidate = candidates.first().context("Empty candidates array")?;
let resp_parts = first_candidate["content"]["parts"]
.as_array()
.context("No parts in candidate")?;
let mut content_blocks: Vec<ContentBlock> = Vec::new();
let mut has_tool_use = false;
for part in resp_parts {
if let Some(text) = part["text"].as_str() {
content_blocks.push(ContentBlock::Text {
text: text.to_string(),
});
}
if let Some(fc) = part.get("functionCall") {
has_tool_use = true;
let name = fc["name"].as_str().unwrap_or("unknown").to_string();
let args = fc.get("args").cloned().unwrap_or(serde_json::json!({}));
content_blocks.push(ContentBlock::ToolUse {
id: format!("toolu_{}", uuid::Uuid::new_v4()),
name,
input: args,
});
}
}
let stop_reason = if has_tool_use {
StopReason::ToolUse
} else {
match first_candidate["finishReason"].as_str() {
Some("MAX_TOKENS") => StopReason::MaxTokens,
_ => StopReason::EndTurn,
}
};
Ok(CompletionResponse {
content: content_blocks,
stop_reason,
usage: None, })
}
fn name(&self) -> &str {
"google"
}
fn model_id(&self) -> &str {
&self.model
}
fn supports_tools(&self) -> bool {
true
}
}
pub fn build_openai_messages(request: &CompletionRequest) -> Vec<serde_json::Value> {
let mut messages = Vec::new();
if !request.system.is_empty() {
messages.push(serde_json::json!({
"role": "system",
"content": request.system,
}));
}
for msg in &request.messages {
let role = match msg.role {
Role::User => "user",
Role::Assistant => "assistant",
Role::System => "system",
};
let has_tool_use = msg.content.iter().any(|b| matches!(b, ContentBlock::ToolUse { .. }));
let tool_results: Vec<_> = msg
.content
.iter()
.filter_map(|b| match b {
ContentBlock::ToolResult { tool_use_id, content, is_error } => {
Some((tool_use_id, content, is_error))
}
_ => None,
})
.collect();
if has_tool_use {
let text_content: String = msg
.content
.iter()
.filter_map(|b| match b {
ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("");
let tool_calls: Vec<serde_json::Value> = msg
.content
.iter()
.filter_map(|b| match b {
ContentBlock::ToolUse { id, name, input } => Some(serde_json::json!({
"id": id,
"type": "function",
"function": {
"name": name,
"arguments": input.to_string(),
}
})),
_ => None,
})
.collect();
let mut msg_json = serde_json::json!({
"role": "assistant",
"tool_calls": tool_calls,
});
if !text_content.is_empty() {
msg_json["content"] = serde_json::json!(text_content);
}
messages.push(msg_json);
} else if !tool_results.is_empty() {
for (tool_use_id, content, _is_error) in tool_results {
messages.push(serde_json::json!({
"role": "tool",
"tool_call_id": tool_use_id,
"content": content,
}));
}
} else {
let text: String = msg
.content
.iter()
.filter_map(|b| match b {
ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("");
messages.push(serde_json::json!({
"role": role,
"content": text,
}));
}
}
messages
}
/// Converts the request's tool definitions into OpenAI chat-completions `tools`
/// entries of the form
/// `{"type": "function", "function": {"name", "description", "parameters"}}`.
///
/// Returns an empty vector when the request declares no tools.
pub fn build_openai_tools(request: &CompletionRequest) -> Vec<serde_json::Value> {
    let mut entries = Vec::new();
    for tool in request.tools.iter() {
        let entry = serde_json::json!({
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.input_schema,
            }
        });
        entries.push(entry);
    }
    entries
}
/// Converts an OpenAI-style chat-completions response into a [`CompletionResponse`],
/// using only the first choice.
///
/// * Non-empty message text becomes a `Text` block.
/// * Each tool call becomes a `ToolUse` block. Arguments that are not valid
///   JSON are tolerated and replaced by `{}` (best-effort).
/// * The stop reason is `ToolUse` whenever tool calls are present. Otherwise
///   it is derived from `finish_reason`: `"length"` → `MaxTokens`,
///   `"tool_calls"` → `ToolUse`, anything else → `EndTurn`.
///
/// # Errors
/// Returns an error if the response contains no choices.
pub fn parse_openai_response(response: OpenAIResponse) -> Result<CompletionResponse> {
    // `response` is owned, so its strings are moved into the result instead of cloned.
    let choice = response
        .choices
        .into_iter()
        .next()
        .context("No choices in API response")?;

    let mut content_blocks: Vec<ContentBlock> = Vec::new();
    if let Some(text) = choice.message.content.filter(|t| !t.is_empty()) {
        content_blocks.push(ContentBlock::Text { text });
    }

    let tool_calls = choice.message.tool_calls.unwrap_or_default();
    let has_tool_calls = !tool_calls.is_empty();
    content_blocks.extend(tool_calls.into_iter().map(|tc| {
        let input: serde_json::Value = serde_json::from_str(&tc.function.arguments)
            .unwrap_or_else(|_| serde_json::json!({}));
        ContentBlock::ToolUse {
            id: tc.id,
            name: tc.function.name,
            input,
        }
    }));

    let stop_reason = if has_tool_calls {
        StopReason::ToolUse
    } else {
        match choice.finish_reason.as_deref() {
            Some("length") => StopReason::MaxTokens,
            Some("tool_calls") => StopReason::ToolUse,
            // "stop", a missing reason, or anything unrecognised.
            _ => StopReason::EndTurn,
        }
    };

    let usage = response.usage.map(|u| Usage {
        input_tokens: u.prompt_tokens,
        output_tokens: u.completion_tokens,
    });

    Ok(CompletionResponse {
        content: content_blocks,
        stop_reason,
        usage,
    })
}
/// Subset of an OpenAI-compatible chat-completions response body read by this
/// module. Other fields in the payload are ignored during deserialization.
#[derive(Debug, Deserialize)]
pub struct OpenAIResponse {
    /// Completion alternatives; [`parse_openai_response`] uses only the first.
    pub choices: Vec<OpenAIChoice>,
    /// Token accounting, when the server reports it.
    pub usage: Option<OpenAIUsage>,
}

/// One completion alternative.
#[derive(Debug, Deserialize)]
pub struct OpenAIChoice {
    /// The generated assistant message.
    pub message: OpenAIMessage,
    /// Why generation stopped (e.g. `"stop"`, `"length"`, `"tool_calls"`).
    pub finish_reason: Option<String>,
}

/// Assistant message inside a choice.
#[derive(Debug, Deserialize)]
pub struct OpenAIMessage {
    /// Text content; may be absent or null when only tool calls were produced.
    pub content: Option<String>,
    /// Tool invocations requested by the model, if any.
    pub tool_calls: Option<Vec<OpenAIToolCall>>,
}

/// A single tool invocation requested by the model.
#[derive(Debug, Deserialize)]
pub struct OpenAIToolCall {
    /// Call id that the matching `tool` message must echo as `tool_call_id`.
    pub id: String,
    /// Function name and arguments.
    pub function: OpenAIFunction,
}

/// Function name and arguments of a tool call.
#[derive(Debug, Deserialize)]
pub struct OpenAIFunction {
    /// Name of the function to invoke.
    pub name: String,
    /// Arguments as a JSON-encoded string (decoded by [`parse_openai_response`]).
    pub arguments: String,
}

/// Token counts reported by the server.
#[derive(Debug, Deserialize)]
pub struct OpenAIUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
}