use anyhow::{Context, Result};
use regex::Regex;
use rig::{agent::Agent, completion::Prompt};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::LazyLock;
/// Lazily compiled pattern for a Markdown fenced code block that wraps a JSON object.
///
/// The fence may carry an optional `json` tag. Capture group 1 holds the
/// brace-delimited payload: it starts at the first `{` after the opening fence
/// and, because the inner quantifier is lazy, ends at the first `}` that is
/// followed only by optional whitespace and a closing fence.
static JSON_CODE_BLOCK_REGEX: LazyLock<Regex> = LazyLock::new(|| {
    let pattern = r"```(?:json)?\s*(\{[\s\S]*?\})\s*```";
    // The pattern is a fixed literal, so compilation is not expected to fail.
    Regex::new(pattern).unwrap()
});
/// Wraps a rig Ollama agent to extract a structured value of type `T` from
/// the model's free-form text output.
///
/// `T` is never stored; it only selects the target type the JSON reply is
/// turned into.
pub struct OllamaExtractorWrapper<T> {
    /// rig agent used for the primary completion path.
    agent: Agent<rig::providers::ollama::CompletionModel>,
    /// Attempt budget for an extraction call.
    max_retries: u32,
    /// Base URL of the Ollama server; `/api/generate` is appended to it when
    /// the plain-HTTP path is used.
    base_url: String,
    /// Model name placed in the `model` field of the plain-HTTP request body.
    model: String,
    /// Zero-sized marker binding the wrapper to `T`.
    _phantom: std::marker::PhantomData<T>,
}
impl<T> OllamaExtractorWrapper<T>
where
T: JsonSchema + Serialize + for<'de> Deserialize<'de>,
{
pub fn with_config(
agent: Agent<rig::providers::ollama::CompletionModel>,
max_retries: u32,
base_url: String,
model: String,
) -> Self {
Self {
agent,
max_retries,
base_url,
model,
_phantom: std::marker::PhantomData,
}
}
pub async fn extract(&self, prompt: &str) -> Result<T> {
let mut last_error = None;
for attempt in 1..=self.max_retries {
let enhanced_prompt = self.build_prompt(prompt, last_error.as_deref());
match self.try_extract_via_rig(&enhanced_prompt, attempt as usize).await {
Ok(result) => return Ok(result),
Err(e) => {
let error_msg = format!("{:#}", e);
if error_msg.contains("ApiResponse") || error_msg.contains("untagged enum") {
match self.try_extract_via_http(&enhanced_prompt, attempt as usize).await {
Ok(result) => return Ok(result),
Err(http_err) => {
last_error = Some(format!("rig: {}, http: {}", error_msg, http_err));
}
}
} else {
last_error = Some(error_msg);
}
if attempt < self.max_retries {
tokio::time::sleep(tokio::time::Duration::from_millis(3000)).await;
}
}
}
}
Err(anyhow::anyhow!(
"Failed after {} attempts. Last error: {}",
self.max_retries,
last_error.unwrap_or_else(|| "Unknown error".to_string())
))
}
async fn try_extract_via_rig(&self, prompt: &str, attempt: usize) -> Result<T> {
let response = self
.agent
.prompt(prompt)
.await
.context("Failed to get response from Ollama via rig")?;
self.parse_and_validate(&response, attempt)
}
async fn try_extract_via_http(&self, prompt: &str, attempt: usize) -> Result<T> {
let client = reqwest::Client::new();
let request_body = serde_json::json!({
"model": self.model,
"prompt": prompt,
"stream": false
});
let response = client
.post(format!("{}/api/generate", self.base_url))
.json(&request_body)
.timeout(std::time::Duration::from_secs(120))
.send()
.await
.context("Failed to send HTTP request to Ollama")?;
if !response.status().is_success() {
anyhow::bail!("Ollama HTTP error: {}", response.status());
}
let json: Value = response
.json()
.await
.context("Failed to parse Ollama HTTP response")?;
let response_text = json
.get("response")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("No 'response' field in Ollama HTTP response"))?;
self.parse_and_validate(response_text, attempt)
}
fn parse_and_validate(&self, response: &str, attempt: usize) -> Result<T> {
let parsed = self
.parse_json_response(response, attempt)
.context("Failed to parse JSON from Ollama response")?;
self.validate_json(&parsed)?;
let result: T = serde_json::from_value(parsed.clone()).with_context(|| {
let json_str = serde_json::to_string_pretty(&parsed)
.unwrap_or_else(|_| "invalid".to_string());
format!(
"Failed to deserialize JSON to target type on attempt {}. JSON structure: {}",
attempt, json_str
)
})?;
Ok(result)
}
fn build_prompt(&self, base_prompt: &str, previous_error: Option<&str>) -> String {
let schema = schemars::schema_for!(T);
let schema_json = serde_json::to_string_pretty(&schema)
.unwrap_or_else(|_| "{}".to_string());
let mut prompt = format!(
"{}\n\n**CRITICAL: YOU MUST RETURN VALID JSON**\n\nYou MUST return the result as a valid JSON object that strictly follows this schema:\n\n```json\n{}\n```\n\n",
base_prompt, schema_json
);
prompt.push_str("Requirements:\n");
prompt.push_str("1. Return pure JSON object, do not add any extra text\n");
prompt.push_str("2. All required fields must be present\n");
prompt.push_str("3. Field types must match schema exactly\n");
prompt.push_str("4. Arrays and nested objects must be correctly formatted\n\n");
if let Some(error) = previous_error {
prompt.push_str(&format!(
"**Previous attempt failed with error: {}**\nPlease fix these issues and regenerate.\n\n",
error
));
}
prompt
}
fn parse_json_response(&self, response: &str, attempt: usize) -> Result<Value> {
if let Ok(json) = serde_json::from_str::<Value>(response) {
return Ok(json);
}
if let Some(json_str) = self.extract_from_code_block(response) {
if let Ok(parsed) = serde_json::from_str::<Value>(&json_str) {
return Ok(parsed);
}
}
if let Some(json_str) = self.extract_first_json_object(response) {
if let Ok(parsed) = serde_json::from_str::<Value>(&json_str) {
return Ok(parsed);
}
}
let cleaned = self.clean_response(response);
serde_json::from_str::<Value>(&cleaned).with_context(|| {
let preview = response.chars().take(200).collect::<String>();
format!(
"Failed to parse JSON from Ollama response (attempt {}). Response preview: {}",
attempt, preview
)
})
}
fn extract_from_code_block(&self, text: &str) -> Option<String> {
JSON_CODE_BLOCK_REGEX
.captures(text)
.and_then(|cap| cap.get(1))
.map(|m| m.as_str().to_string())
}
fn extract_first_json_object(&self, text: &str) -> Option<String> {
let start = text.find('{')?;
let mut depth = 0;
let mut end = start;
for (i, c) in text[start..].char_indices() {
match c {
'{' => depth += 1,
'}' => {
depth -= 1;
if depth == 0 {
end = start + i + 1;
break;
}
}
_ => {}
}
}
if depth == 0 && end > start {
Some(text[start..end].to_string())
} else {
None
}
}
fn clean_response(&self, text: &str) -> String {
text.trim()
.trim_start_matches("```json")
.trim_start_matches("```")
.trim_end_matches("```")
.trim()
.to_string()
}
fn validate_json(&self, json: &Value) -> Result<()> {
if !json.is_object() {
anyhow::bail!("Expected JSON object, got: {}", json);
}
Ok(())
}
}