Skip to main content

panini_lang_engine/
extractor.rs

1use panini_core::component::{AnalysisComponent, ExtractionResult};
2use panini_core::traits::LinguisticDefinition;
3use rig::completion::{CompletionModel, CompletionRequestBuilder};
4use rig::message::Message;
5use std::time::Duration;
6
7use crate::composer::{compose_prompt, compose_schema};
8use crate::llm_utils::clean_llm_json;
9use crate::prompts::{ExtractionRequest, ExtractorPrompts};
10
11// ─── Error types ──────────────────────────────────────────────────────────────
12
13/// Detailed reason for an extraction failure.
14#[derive(Debug, thiserror::Error)]
15pub enum ExtractionFailureReason {
16    /// LLM output could not be parsed as JSON.
17    #[error("Invalid JSON syntax: {0}")]
18    JsonSyntax(String),
19
20    /// JSON output did not match the required schema.
21    #[error("Schema validation failed: {0}")]
22    Schema(String),
23
24    /// A specific component failed its internal validation.
25    #[error("Validation failed for component '{key}': {message}")]
26    ComponentValidation {
27        key: &'static str,
28        message: String,
29    },
30
31    /// A specific component failed its internal post-processing.
32    #[error("Post-processing failed for component '{key}': {message}")]
33    ComponentPostProcess {
34        key: &'static str,
35        message: String,
36    },
37}
38
39/// Error returned when feature extraction parsing fails, carrying the raw LLM output and structured reason.
40#[derive(Debug, thiserror::Error)]
41#[error("{reason}")]
42pub struct ExtractionParseError {
43    pub raw_response: String,
44    pub reason: ExtractionFailureReason,
45}
46
47/// Typed error enum for the extraction pipeline.
48#[derive(Debug, thiserror::Error)]
49pub enum ExtractionError {
50    /// LLM provider errors (rig-core completion failures, network, auth, etc.)
51    #[error("LLM completion failed: {0}")]
52    Llm(#[from] rig::completion::request::CompletionError),
53
54    /// JSON serialization/deserialization errors (schema conversion, response parsing)
55    #[error("JSON error: {0}")]
56    Json(#[from] serde_json::Error),
57
58    /// Prompt composition errors (missing placeholders, I/O, etc.)
59    #[error("prompt composition failed: {0}")]
60    PromptComposition(#[from] crate::prompts::PromptBuilderError),
61
62    /// LLM returned no text content in its response
63    #[error("LLM returned no text content")]
64    EmptyResponse,
65
66    /// Schema validation or component validation/parse failure — carries the raw
67    /// LLM output so callers can retry with `PreviousAttempt`
68    #[error("{0}")]
69    Parse(#[from] ExtractionParseError),
70
71    /// Failed to map raw `ExtractionResult` into a typed consumer struct
72    /// (used by `#[derive(PaniniResult)]` generated code)
73    #[error("failed to map extracted components to result struct")]
74    ResultMapping(#[from] panini_core::component::ExtractionResultError),
75}
76
77// ─── Extraction options ───────────────────────────────────────────────────────
78
/// Context from a rejected attempt, replayed to the model so it can fix its answer.
struct PreviousAttempt {
    /// The rejected text, as stored in `ExtractionParseError::raw_response`.
    pub raw_response: String,
    /// Human-readable explanation of why that text was rejected.
    pub error: String,
}
84
/// Retry policy for self-correction after the model's output is rejected.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Maximum number of retries after the initial attempt.
    pub max_retries: usize,
    /// Initial delay, in seconds, of the exponential backoff between retries.
    pub initial_backoff_secs: u64,
}

impl Default for RetryConfig {
    /// Two retries, starting from a one-second backoff.
    fn default() -> Self {
        RetryConfig {
            initial_backoff_secs: 1,
            max_retries: 2,
        }
    }
}
100
101/// Bundles extraction parameters
102#[derive(Clone)]
103pub struct ExtractionOptions<'a> {
104    pub temperature: f32,
105    pub max_tokens: u32,
106    pub extractor_prompts: &'a ExtractorPrompts,
107    pub retry: RetryConfig,
108    pub timeout: Duration,
109}
110
111impl<'a> ExtractionOptions<'a> {
112    #[must_use]
113    pub fn new(extractor_prompts: &'a ExtractorPrompts) -> Self {
114        Self {
115            temperature: 0.2,
116            max_tokens: 4096,
117            extractor_prompts,
118            retry: RetryConfig::default(),
119            timeout: Duration::from_secs(30),
120        }
121    }
122}
123
124// ─── Composable entry point ───────────────────────────────────────────────────
125
126/// Extracts features using composable `AnalysisComponent`s.
127///
128/// This is the entry-point that supports selecting which analyses to include.
129/// It includes an automatic self-correction loop (Retry) in case of validation errors.
130///
131/// # Errors
132/// Returns an extraction error if the LLM completion fails, or JSON parsing
133/// / validation fails after all retry attempts are exhausted.
134pub async fn extract_with_components<L, M>(
135    language: &L,
136    model: &M,
137    request: &ExtractionRequest,
138    components: &[&dyn AnalysisComponent<L>],
139    options: ExtractionOptions<'_>,
140) -> Result<ExtractionResult, ExtractionError>
141where
142    L: LinguisticDefinition + Send + Sync,
143    M: CompletionModel,
144{
145    let mut prev_attempt: Option<PreviousAttempt> = None;
146    let mut backoff = backoff::ExponentialBackoffBuilder::new()
147        .with_initial_interval(Duration::from_secs(options.retry.initial_backoff_secs))
148        .with_multiplier(2.0)
149        .with_max_elapsed_time(Some(options.timeout))
150        .build();
151
152    loop {
153        let result = perform_single_shot_extraction(
154            language,
155            model,
156            request,
157            components,
158            &options,
159            prev_attempt.as_ref(),
160        )
161        .await;
162
163        match result {
164            Ok(res) => return Ok(res),
165            Err(e) => {
166                // Only retry on parsing/validation errors
167                if let ExtractionError::Parse(pe) = &e
168                    && let Some(wait) = backoff::backoff::Backoff::next_backoff(&mut backoff)
169                {
170                    let err_msg = pe.reason.to_string();
171                    tracing::warn!(
172                        ?wait,
173                        error = %err_msg,
174                        "Extraction validation failed, retrying with self-correction..."
175                    );
176                    prev_attempt = Some(PreviousAttempt {
177                        raw_response: pe.raw_response.clone(),
178                        error: err_msg,
179                    });
180                    tokio::time::sleep(wait).await;
181                    continue;
182                }
183                return Err(e);
184            }
185        }
186    }
187}
188
189/// Internal function to perform a single extraction attempt.
190async fn perform_single_shot_extraction<L, M>(
191    language: &L,
192    model: &M,
193    request: &ExtractionRequest,
194    components: &[&dyn AnalysisComponent<L>],
195    options: &ExtractionOptions<'_>,
196    previous_attempt: Option<&PreviousAttempt>,
197) -> Result<ExtractionResult, ExtractionError>
198where
199    L: LinguisticDefinition + Send + Sync,
200    M: CompletionModel,
201{
202    // 1. Filter to compatible components
203    let compatible: Vec<&dyn AnalysisComponent<L>> = components
204        .iter()
205        .filter(|c| c.is_compatible(language))
206        .copied()
207        .collect();
208
209    let requested_keys: Vec<&'static str> = compatible.iter().map(|c| c.schema_key()).collect();
210
211    // 2. Compose schema
212    let schema_value = compose_schema(language, &compatible);
213    let rig_schema: schemars::Schema = serde_json::from_value(schema_value.clone())?;
214
215    // 3. Compose prompt
216    let system_prompt = compose_prompt(language, request, options.extractor_prompts, &compatible)?;
217
218    let user_message = format!(
219        "Extract features from this card:\n{}\n\nTARGET WORDS: {:?}",
220        request.content, request.targets
221    );
222
223    // 4. Build LLM request
224    let mut builder: CompletionRequestBuilder<M> = model
225        .completion_request(user_message.as_str())
226        .preamble(system_prompt)
227        .temperature(f64::from(options.temperature))
228        .max_tokens(u64::from(options.max_tokens))
229        .output_schema(rig_schema);
230
231    if let Some(prev) = previous_attempt {
232        builder = builder
233            .message(Message::assistant(&prev.raw_response))
234            .message(Message::user(format!(
235                "Your output is not conform to what I'm expecting. \
236                 Please look at the error and correct yourself: {}",
237                prev.error
238            )));
239    }
240
241    let completion_response = builder.send().await?;
242
243    let raw_text = completion_response
244        .choice
245        .into_iter()
246        .find_map(|c| {
247            if let rig::completion::message::AssistantContent::Text(t) = c {
248                Some(t.text)
249            } else {
250                None
251            }
252        })
253        .ok_or(ExtractionError::EmptyResponse)?;
254
255    // 5. Chain pre_process from each component
256    let cleaned = clean_llm_json(&raw_text);
257    let mut processed = cleaned.to_string();
258    for comp in &compatible {
259        processed = comp.pre_process(&processed);
260    }
261
262    // 6. Parse JSON
263    let mut json_value: serde_json::Value = match serde_json::from_str(&processed) {
264        Ok(v) => v,
265        Err(e) => {
266            let err_msg = format!("{e}");
267            tracing::warn!(error = %err_msg, "Failed to parse JSON syntax");
268            return Err(ExtractionParseError {
269                raw_response: processed,
270                reason: ExtractionFailureReason::JsonSyntax(err_msg),
271            }
272            .into());
273        }
274    };
275
276    // 7. Validate composed schema
277    if let Ok(validator) = jsonschema::validator_for(&schema_value) {
278        let schema_errors: Vec<_> = validator.iter_errors(&json_value).collect();
279        if !schema_errors.is_empty() {
280            let mut err_msgs = Vec::new();
281            for err in schema_errors {
282                err_msgs.push(format!("- Path: {}: {}", err.instance_path(), err));
283            }
284            let err_msg = err_msgs.join("\n");
285            tracing::warn!(error = %err_msg, "Schema validation failed — retrying");
286            return Err(ExtractionParseError {
287                raw_response: processed,
288                reason: ExtractionFailureReason::Schema(err_msg),
289            }
290            .into());
291        }
292    }
293
294    // 8. Per-component validate + post_process
295    for comp in &compatible {
296        let key = comp.schema_key();
297        if let Some(section) = json_value.get(key) {
298            comp.validate(language, section)
299                .map_err(|e| ExtractionParseError {
300                    raw_response: processed.clone(),
301                    reason: ExtractionFailureReason::ComponentValidation { key, message: e },
302                })?;
303        }
304    }
305
306    for comp in &compatible {
307        let key = comp.schema_key();
308        if let Some(section) = json_value.get_mut(key) {
309            comp.post_process(language, section)
310                .map_err(|e| ExtractionParseError {
311                    raw_response: processed.clone(),
312                    reason: ExtractionFailureReason::ComponentPostProcess { key, message: e },
313                })?;
314        }
315    }
316
317    // 9. Return ExtractionResult
318    Ok(ExtractionResult::new(json_value, requested_keys))
319}