use super::{TokenUsage, UsageCounters, shared};
use crate::error::KumoError;
use async_trait::async_trait;
use rig::client::{CompletionClient, Nothing};
use rig::completion::CompletionModel;
use rig::providers::ollama;
use serde_json::Value;
use std::sync::Arc;
/// LLM client that talks to an Ollama server through the `rig` Ollama provider.
///
/// Instances are created with `OllamaClient::new` or `OllamaClient::with_url`
/// and then customised with the chainable, by-value setter methods.
pub struct OllamaClient {
/// Underlying `rig` Ollama client; completion models are obtained from it.
inner: ollama::Client,
/// Name of the model requested from `inner` for every completion.
model: String,
/// Optional system prompt, handed to `shared::build_extract_args` when set.
system_prompt: Option<String>,
/// Optional prompt template, handed to `shared::build_extract_args` when set.
prompt_template: Option<String>,
/// Flag forwarded to `shared::build_extract_args`.
/// NOTE(review): presumably removes `<script>`/`<style>` content from the HTML
/// before prompting — confirm against the helper.
strip_scripts: bool,
/// Value passed as `max_tokens` on every completion request.
max_tokens: u64,
/// Shared counters that accumulate token usage across calls on this client.
usage: Arc<UsageCounters>,
}
impl OllamaClient {
    /// Model name used until [`OllamaClient::model`] overrides it.
    const DEFAULT_MODEL: &'static str = "llama3.2";
    /// Token limit used until [`OllamaClient::max_tokens`] overrides it.
    const DEFAULT_MAX_TOKENS: u64 = 4096;

    /// Wraps an already-built `rig` client with the default settings.
    ///
    /// This is the single place where the defaults are spelled out, so the
    /// public constructors cannot drift apart.
    fn from_inner(inner: ollama::Client) -> Self {
        Self {
            inner,
            model: Self::DEFAULT_MODEL.into(),
            system_prompt: None,
            prompt_template: None,
            strip_scripts: false,
            max_tokens: Self::DEFAULT_MAX_TOKENS,
            usage: UsageCounters::new(),
        }
    }

    /// Creates a client without overriding the base URL, i.e. using whatever
    /// endpoint the `rig` Ollama builder defaults to.
    ///
    /// Defaults: model `"llama3.2"`, no system prompt, no prompt template,
    /// script stripping disabled, `max_tokens` of 4096, fresh usage counters.
    ///
    /// # Panics
    ///
    /// Panics with `"failed to build Ollama client"` if the `rig` builder
    /// returns an error.
    pub fn new() -> Self {
        let inner = ollama::Client::builder()
            .api_key(Nothing)
            .build()
            .expect("failed to build Ollama client");
        Self::from_inner(inner)
    }

    /// Creates a client that sends its requests to `base_url`.
    ///
    /// All other settings start at the same defaults as [`OllamaClient::new`].
    ///
    /// # Panics
    ///
    /// Panics with `"failed to build Ollama client"` if the `rig` builder
    /// returns an error.
    pub fn with_url(base_url: impl AsRef<str>) -> Self {
        let inner = ollama::Client::builder()
            .api_key(Nothing)
            .base_url(base_url.as_ref())
            .build()
            .expect("failed to build Ollama client");
        Self::from_inner(inner)
    }

    /// Sets the name of the model to request and returns the updated client.
    #[must_use]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Sets the system prompt and returns the updated client.
    #[must_use]
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Sets the prompt template and returns the updated client.
    #[must_use]
    pub fn prompt_template(mut self, template: impl Into<String>) -> Self {
        self.prompt_template = Some(template.into());
        self
    }

    /// Sets the flag forwarded to `shared::build_extract_args` and returns
    /// the updated client.
    ///
    /// NOTE(review): by its name this toggles removal of script and style
    /// content from the HTML before it is sent — confirm in the helper.
    #[must_use]
    pub fn strip_scripts_and_styles(mut self, yes: bool) -> Self {
        self.strip_scripts = yes;
        self
    }

    /// Sets the `max_tokens` value sent with each completion request and
    /// returns the updated client.
    #[must_use]
    pub fn max_tokens(mut self, n: u64) -> Self {
        self.max_tokens = n;
        self
    }

    /// Returns a snapshot of the token usage accumulated by this client.
    #[must_use]
    pub fn total_usage(&self) -> TokenUsage {
        self.usage.snapshot()
    }
}
impl Default for OllamaClient {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl super::LlmClient for OllamaClient {
    /// Runs one completion against the configured model and returns the JSON
    /// value taken from the model's tool call, together with the token usage
    /// of this single request.
    ///
    /// The prompt pieces (system text, user content, tool definition) come
    /// from `shared::build_extract_args`, fed with `schema`, `html` and this
    /// client's prompt settings.
    ///
    /// # Errors
    ///
    /// * A failed completion is converted with `shared::llm_err`, prefixed
    ///   with `"Ollama error — "`.
    /// * Any error from `shared::extract_tool_input` is propagated.
    ///
    /// The running total behind `total_usage` is only updated after the tool
    /// input was extracted successfully; a failed call leaves it untouched.
    async fn extract_json(
        &self,
        schema: &Value,
        html: &str,
    ) -> Result<(Value, TokenUsage), KumoError> {
        let system_prompt = self.system_prompt.as_deref();
        let prompt_template = self.prompt_template.as_deref();
        let prepared = shared::build_extract_args(
            schema,
            html,
            system_prompt,
            prompt_template,
            self.strip_scripts,
        );

        let completion_model = self.inner.completion_model(&self.model);
        let completion_request = completion_model
            .completion_request(prepared.user_content)
            .preamble(prepared.system)
            .tool(prepared.tool)
            .max_tokens(self.max_tokens)
            .build();

        let response = completion_model
            .completion(completion_request)
            .await
            .map_err(|e| shared::llm_err(format!("Ollama error — {e}")))?;

        // Read the usage before `response.choice` is moved out below.
        let call_usage = TokenUsage::from_rig(&response.usage);
        let extracted = shared::extract_tool_input(response.choice, "extract")?;

        // Reached only on success, so failed extractions are not counted.
        self.usage.add(&call_usage);
        Ok((extracted, call_usage))
    }
}