use std::time::Duration;
use panini_core::component::{AnalysisComponent, ExtractionResult};
use panini_core::traits::LinguisticDefinition;
use rig::completion::{CompletionModel, CompletionRequestBuilder};
use rig::message::Message;
use crate::composer::{compose_schema, compose_prompt};
use crate::llm_utils::clean_llm_json;
use crate::prompts::{ExtractorPrompts, ExtractionRequest};
/// Recoverable failure: the model answered, but its output could not be
/// accepted (invalid JSON syntax, schema violation, or a component's
/// validation / post-processing step failing).
///
/// The rejected output is kept alongside the reason so it can be shown back
/// to the model when asking it to correct itself.
#[derive(Debug)]
pub struct ExtractionParseError {
    /// The model output that was rejected.
    pub raw_response: String,
    /// Why the output was rejected; this is also the error's `Display` text.
    pub error_message: String,
}

impl std::fmt::Display for ExtractionParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.error_message)
    }
}

impl std::error::Error for ExtractionParseError {}
/// Errors returned by [`extract_with_components`].
#[derive(Debug, thiserror::Error)]
pub enum ExtractionError {
/// The LLM completion request failed.
#[error("LLM completion failed: {0}")]
Llm(#[from] rig::completion::request::CompletionError),
/// JSON conversion failed, e.g. turning the composed schema value into a
/// `schemars::Schema`.
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
/// Composing the system prompt failed.
#[error("prompt composition failed: {0}")]
PromptComposition(#[from] crate::prompts::PromptBuilderError),
/// The completion response contained no text content.
#[error("LLM returned no text content")]
EmptyResponse,
/// The model's output was rejected (bad JSON, schema or component
/// validation). This is the only variant `extract_with_components` may
/// retry with self-correction.
#[error("{0}")]
Parse(#[from] ExtractionParseError),
/// Mapping extracted components into a result struct failed.
/// NOTE(review): not constructed in the code visible here; presumably
/// produced via `?` on `ExtractionResultError` by callers. Confirm.
#[error("failed to map extracted components to result struct")]
ResultMapping(#[from] panini_core::component::ExtractionResultError),
}
/// A failed extraction attempt, replayed to the model on the next try so it
/// can correct itself.
struct PreviousAttempt {
    /// The rejected output (after JSON cleaning and component pre-processing).
    raw_response: String,
    /// Description of why `raw_response` was rejected.
    error: String,
}
/// Retry policy for recoverable (output-validation) extraction failures.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Intended maximum number of retries after the initial attempt.
    pub max_retries: usize,
    /// Delay in seconds before the first retry; later delays grow exponentially.
    pub initial_backoff_secs: u64,
}

impl Default for RetryConfig {
    /// Two retries, starting from a one-second backoff.
    fn default() -> Self {
        let max_retries = 2;
        let initial_backoff_secs = 1;
        Self {
            max_retries,
            initial_backoff_secs,
        }
    }
}
/// Settings for one extraction run.
///
/// The prompt templates are borrowed for `'a`; every other field is owned.
#[derive(Clone)]
pub struct ExtractionOptions<'a> {
    /// Sampling temperature forwarded to the model request (widened to `f64`).
    pub temperature: f32,
    /// Token limit forwarded to the model request (widened to `u64`).
    pub max_tokens: u32,
    /// Prompt templates used to compose the system prompt.
    pub extractor_prompts: &'a ExtractorPrompts,
    /// Retry policy for output-validation failures.
    pub retry: RetryConfig,
    /// Time budget for retrying; once it has elapsed no further retry is scheduled.
    pub timeout: Duration,
}

impl<'a> ExtractionOptions<'a> {
    /// Creates options with the defaults: temperature `0.2`, `4096` max
    /// tokens, [`RetryConfig::default`] and a 30-second timeout.
    #[must_use]
    pub fn new(extractor_prompts: &'a ExtractorPrompts) -> Self {
        let timeout = Duration::from_secs(30);
        let retry = RetryConfig::default();
        Self {
            extractor_prompts,
            temperature: 0.2,
            max_tokens: 4096,
            timeout,
            retry,
        }
    }
}
/// Extracts structured linguistic features for `request`, retrying with LLM
/// self-correction when the model's output fails validation.
///
/// Each attempt delegates to `perform_single_shot_extraction`. Only
/// [`ExtractionError::Parse`] failures are retried. The rejected output and
/// its error message are fed back to the model on the next attempt.
///
/// Retrying stops as soon as either limit is reached:
/// - `options.retry.max_retries` retries have been made, or
/// - the exponential backoff reports that `options.timeout` has elapsed.
///
/// The backoff starts at `options.retry.initial_backoff_secs`, uses multiplier
/// 2.0 and the `backoff` crate's default jitter. Its elapsed-time clock starts
/// just before the first attempt.
///
/// `options.timeout` is only checked between attempts. It does not cancel an
/// in-flight LLM call.
///
/// # Errors
///
/// Returns the error of the last attempt if it is not retryable or if the
/// retry budget (count or time) is exhausted.
pub async fn extract_with_components<L, M>(
    language: &L,
    model: &M,
    request: &ExtractionRequest,
    components: &[&dyn AnalysisComponent<L>],
    options: ExtractionOptions<'_>,
) -> Result<ExtractionResult, ExtractionError>
where
    L: LinguisticDefinition + Send + Sync,
    M: CompletionModel,
{
    let max_retries = options.retry.max_retries;
    let mut retries_used: usize = 0;
    let mut prev_attempt: Option<PreviousAttempt> = None;
    // `max_elapsed_time` is measured from `build()`, so the timeout covers LLM call
    // time as well as sleeps.
    let mut backoff = backoff::ExponentialBackoffBuilder::new()
        .with_initial_interval(Duration::from_secs(options.retry.initial_backoff_secs))
        .with_multiplier(2.0)
        .with_max_elapsed_time(Some(options.timeout))
        .build();
    loop {
        let result = perform_single_shot_extraction(
            language,
            model,
            request,
            components,
            &options,
            prev_attempt.as_ref(),
        )
        .await;
        match result {
            Ok(res) => return Ok(res),
            Err(e) => {
                // Check the retry budget before `next_backoff`, so an exhausted budget
                // doesn't advance the backoff state.
                if let ExtractionError::Parse(pe) = &e
                    && retries_used < max_retries
                    && let Some(wait) = backoff::backoff::Backoff::next_backoff(&mut backoff)
                {
                    retries_used += 1;
                    tracing::warn!(
                        ?wait,
                        retry = retries_used,
                        max_retries,
                        error = %pe.error_message,
                        "Extraction validation failed, retrying with self-correction..."
                    );
                    prev_attempt = Some(PreviousAttempt {
                        raw_response: pe.raw_response.clone(),
                        error: pe.error_message.clone(),
                    });
                    tokio::time::sleep(wait).await;
                    continue;
                }
                return Err(e);
            }
        }
    }
}
/// Runs one extraction round-trip against `model` and validates the result.
///
/// 1. Keeps only the `components` compatible with `language`, then composes
///    the JSON schema and the system prompt from them.
/// 2. Sends the card content and target words to the model. When
///    `previous_attempt` is set, the conversation is sent in order: original
///    request, the previously rejected answer, then a correction request
///    containing the error.
/// 3. Takes the first text choice, cleans it, runs each component's
///    `pre_process`, and parses the result as JSON. It then validates the JSON
///    against the composed schema and runs each component's `validate` and
///    `post_process` on its own section. Components whose key is missing from
///    the output are skipped.
///
/// # Errors
///
/// - [`ExtractionError::Parse`] for rejected model output: invalid JSON, schema
///   violations, or component validation/post-processing failures. The stored
///   `raw_response` is the cleaned, pre-processed text.
/// - [`ExtractionError::EmptyResponse`] when no text content is returned.
/// - [`ExtractionError::Llm`], [`ExtractionError::Json`] and
///   [`ExtractionError::PromptComposition`] for failures from the model call,
///   the schema conversion, and prompt composition respectively.
async fn perform_single_shot_extraction<L, M>(
    language: &L,
    model: &M,
    request: &ExtractionRequest,
    components: &[&dyn AnalysisComponent<L>],
    options: &ExtractionOptions<'_>,
    previous_attempt: Option<&PreviousAttempt>,
) -> Result<ExtractionResult, ExtractionError>
where
    L: LinguisticDefinition + Send + Sync,
    M: CompletionModel,
{
    let compatible: Vec<&dyn AnalysisComponent<L>> = components
        .iter()
        .filter(|c| c.is_compatible(language))
        .copied()
        .collect();
    let requested_keys: Vec<&'static str> = compatible.iter().map(|c| c.schema_key()).collect();
    let schema_value = compose_schema(language, &compatible);
    let rig_schema: schemars::Schema = serde_json::from_value(schema_value.clone())?;
    let system_prompt = compose_prompt(language, request, options.extractor_prompts, &compatible)?;
    let user_message = format!(
        "Extract features from this card:\n{}\n\nTARGET WORDS: {:?}",
        request.content, request.targets
    );

    // rig appends the builder's prompt AFTER any `.message(..)` history. On a retry
    // the correction must therefore be the prompt, with the original request and the
    // rejected answer placed in the history. Otherwise the model would see its old
    // answer before the question it was answering.
    let builder: CompletionRequestBuilder<M> = match previous_attempt {
        None => model.completion_request(user_message.as_str()),
        Some(prev) => model
            .completion_request(Message::user(format!(
                "Your output is not conform to what I'm expecting. \
                 Please look at the error and correct yourself: {}",
                prev.error
            )))
            .message(Message::user(user_message))
            .message(Message::assistant(&prev.raw_response)),
    };
    let completion_response = builder
        .preamble(system_prompt)
        .temperature(f64::from(options.temperature))
        .max_tokens(u64::from(options.max_tokens))
        .output_schema(rig_schema)
        .send()
        .await?;

    let raw_text = completion_response
        .choice
        .into_iter()
        .find_map(|c| {
            if let rig::completion::message::AssistantContent::Text(t) = c {
                Some(t.text)
            } else {
                None
            }
        })
        .ok_or(ExtractionError::EmptyResponse)?;

    let cleaned = clean_llm_json(&raw_text);
    let mut processed = cleaned.to_string();
    for comp in &compatible {
        processed = comp.pre_process(&processed);
    }

    let mut json_value: serde_json::Value = match serde_json::from_str(&processed) {
        Ok(v) => v,
        Err(e) => {
            let err_msg = format!("Invalid JSON syntax: {e}");
            tracing::warn!(error = %err_msg, "Failed to parse JSON syntax");
            return Err(ExtractionParseError {
                raw_response: processed,
                error_message: err_msg,
            }
            .into());
        }
    };

    // A schema that fails to compile is a composition bug, not a model error, so it
    // is not turned into a retryable Parse error. It is logged rather than ignored.
    match jsonschema::validator_for(&schema_value) {
        Ok(validator) => {
            let err_msgs: Vec<String> = validator
                .iter_errors(&json_value)
                .map(|err| format!("- Path: {}: {}", err.instance_path(), err))
                .collect();
            if !err_msgs.is_empty() {
                let err_msg = format!(
                    "Schema validation failed with {} errors:\n{}",
                    err_msgs.len(),
                    err_msgs.join("\n")
                );
                tracing::warn!(error = %err_msg, "Schema validation failed — retrying");
                return Err(ExtractionParseError {
                    raw_response: processed,
                    error_message: err_msg,
                }
                .into());
            }
        }
        Err(e) => {
            tracing::warn!(
                error = %e,
                "Composed JSON schema could not be compiled; skipping schema validation"
            );
        }
    }

    for comp in &compatible {
        let key = comp.schema_key();
        if let Some(section) = json_value.get(key) {
            comp.validate(language, section).map_err(|e| ExtractionParseError {
                raw_response: processed.clone(),
                error_message: format!("Validation failed for component '{key}': {e}"),
            })?;
        }
    }

    for comp in &compatible {
        let key = comp.schema_key();
        if let Some(section) = json_value.get_mut(key) {
            comp.post_process(language, section).map_err(|e| ExtractionParseError {
                raw_response: processed.clone(),
                error_message: format!("Post-processing failed for component '{key}': {e}"),
            })?;
        }
    }

    Ok(ExtractionResult::new(json_value, requested_keys))
}