use super::{TokenUsage, UsageCounters, shared};
use crate::error::KumoError;
use async_trait::async_trait;
use rig::client::CompletionClient;
use rig::completion::CompletionModel;
use rig::providers::gemini;
use serde_json::Value;
use std::sync::Arc;
/// Gemini model identifiers, re-exported from `crate::llm::models::gemini`
/// so callers can write `gemini::models::…` next to the client.
pub mod models {
    pub use crate::llm::models::gemini::*;
}
/// LLM client that sends extraction requests to Google Gemini through `rig`.
///
/// Construct it with `GeminiClient::new`, then adjust it with the chained
/// builder setters (`model`, `system_prompt`, `prompt_template`,
/// `strip_scripts_and_styles`, `max_tokens`).
pub struct GeminiClient {
    /// Underlying `rig` Gemini provider client.
    inner: gemini::Client,
    /// Model identifier handed to `completion_model` on every request.
    model: String,
    /// Optional custom system prompt, forwarded to `shared::build_extract_args`.
    system_prompt: Option<String>,
    /// Optional prompt template, forwarded to `shared::build_extract_args`.
    prompt_template: Option<String>,
    /// Flag set by `strip_scripts_and_styles` and forwarded to
    /// `shared::build_extract_args`.
    strip_scripts: bool,
    /// Value passed as `max_tokens` on each completion request.
    max_tokens: u64,
    /// Running token totals: updated by `extract_json`, read by `total_usage`.
    usage: Arc<UsageCounters>,
}
impl GeminiClient {
pub fn new(api_key: impl Into<String>) -> Self {
let key = api_key.into();
Self {
inner: gemini::Client::new(key).expect("failed to build Gemini client"),
model: models::GEMINI_2_5_FLASH.into(),
system_prompt: None,
prompt_template: None,
strip_scripts: false,
max_tokens: 4096,
usage: UsageCounters::new(),
}
}
pub fn model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.system_prompt = Some(prompt.into());
self
}
pub fn prompt_template(mut self, template: impl Into<String>) -> Self {
self.prompt_template = Some(template.into());
self
}
pub fn strip_scripts_and_styles(mut self, yes: bool) -> Self {
self.strip_scripts = yes;
self
}
pub fn max_tokens(mut self, n: u64) -> Self {
self.max_tokens = n;
self
}
pub fn total_usage(&self) -> TokenUsage {
self.usage.snapshot()
}
}
#[async_trait]
impl super::LlmClient for GeminiClient {
    /// Asks Gemini to extract JSON shaped by `schema` from `html`.
    ///
    /// The prompt, system preamble and tool definition come from
    /// `shared::build_extract_args`. The JSON is read from the response's
    /// `extract` tool call.
    ///
    /// Returns the extracted value together with the token usage of this single
    /// request. That usage is also added to the client's running totals.
    ///
    /// # Errors
    ///
    /// Fails if the Gemini API call errors, or if `shared::extract_tool_input`
    /// cannot obtain the `extract` tool input from the response.
    async fn extract_json(
        &self,
        schema: &Value,
        html: &str,
    ) -> Result<(Value, TokenUsage), KumoError> {
        let args = shared::build_extract_args(
            schema,
            html,
            self.system_prompt.as_deref(),
            self.prompt_template.as_deref(),
            self.strip_scripts,
        );

        let model = self.inner.completion_model(&self.model);
        let request = model
            .completion_request(args.user_content)
            .preamble(args.system)
            .tool(args.tool)
            .max_tokens(self.max_tokens)
            .build();

        let resp = model
            .completion(request)
            .await
            .map_err(|e| shared::llm_err(format!("Gemini API error — {e}")))?;

        // Record usage as soon as the API has answered. The tokens are already
        // spent even if the response turns out not to contain a usable
        // `extract` tool call, so `total_usage` must include them.
        let usage = TokenUsage::from_rig(&resp.usage);
        self.usage.add(&usage);

        let value = shared::extract_tool_input(resp.choice, "extract")?;
        Ok((value, usage))
    }
}