1use panini_core::component::{AnalysisComponent, ExtractionResult};
2use panini_core::traits::LinguisticDefinition;
3use rig::completion::{CompletionModel, CompletionRequestBuilder};
4use rig::message::Message;
5use std::time::Duration;
6
7use crate::composer::{compose_prompt, compose_schema};
8use crate::llm_utils::clean_llm_json;
9use crate::prompts::{ExtractionRequest, ExtractorPrompts};
10
11#[derive(Debug, thiserror::Error)]
15pub enum ExtractionFailureReason {
16 #[error("Invalid JSON syntax: {0}")]
18 JsonSyntax(String),
19
20 #[error("Schema validation failed: {0}")]
22 Schema(String),
23
24 #[error("Validation failed for component '{key}': {message}")]
26 ComponentValidation {
27 key: &'static str,
28 message: String,
29 },
30
31 #[error("Post-processing failed for component '{key}': {message}")]
33 ComponentPostProcess {
34 key: &'static str,
35 message: String,
36 },
37}
38
39#[derive(Debug, thiserror::Error)]
41#[error("{reason}")]
42pub struct ExtractionParseError {
43 pub raw_response: String,
44 pub reason: ExtractionFailureReason,
45}
46
47#[derive(Debug, thiserror::Error)]
49pub enum ExtractionError {
50 #[error("LLM completion failed: {0}")]
52 Llm(#[from] rig::completion::request::CompletionError),
53
54 #[error("JSON error: {0}")]
56 Json(#[from] serde_json::Error),
57
58 #[error("prompt composition failed: {0}")]
60 PromptComposition(#[from] crate::prompts::PromptBuilderError),
61
62 #[error("LLM returned no text content")]
64 EmptyResponse,
65
66 #[error("{0}")]
69 Parse(#[from] ExtractionParseError),
70
71 #[error("failed to map extracted components to result struct")]
74 ResultMapping(#[from] panini_core::component::ExtractionResultError),
75}
76
/// Record of a failed attempt, replayed to the model so it can correct itself.
struct PreviousAttempt {
    /// The response text that was rejected.
    pub raw_response: String,
    /// Human-readable description of why it was rejected.
    pub error: String,
}
84
/// Retry policy for extractions whose response is rejected.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Intended cap on how many times a rejected extraction is retried.
    pub max_retries: usize,
    /// First delay of the backoff schedule, in whole seconds.
    pub initial_backoff_secs: u64,
}

impl Default for RetryConfig {
    /// Two retries, starting with a one-second backoff.
    fn default() -> Self {
        let max_retries = 2;
        let initial_backoff_secs = 1;
        Self {
            max_retries,
            initial_backoff_secs,
        }
    }
}
100
101#[derive(Clone)]
103pub struct ExtractionOptions<'a> {
104 pub temperature: f32,
105 pub max_tokens: u32,
106 pub extractor_prompts: &'a ExtractorPrompts,
107 pub retry: RetryConfig,
108 pub timeout: Duration,
109}
110
111impl<'a> ExtractionOptions<'a> {
112 #[must_use]
113 pub fn new(extractor_prompts: &'a ExtractorPrompts) -> Self {
114 Self {
115 temperature: 0.2,
116 max_tokens: 4096,
117 extractor_prompts,
118 retry: RetryConfig::default(),
119 timeout: Duration::from_secs(30),
120 }
121 }
122}
123
124pub async fn extract_with_components<L, M>(
135 language: &L,
136 model: &M,
137 request: &ExtractionRequest,
138 components: &[&dyn AnalysisComponent<L>],
139 options: ExtractionOptions<'_>,
140) -> Result<ExtractionResult, ExtractionError>
141where
142 L: LinguisticDefinition + Send + Sync,
143 M: CompletionModel,
144{
145 let mut prev_attempt: Option<PreviousAttempt> = None;
146 let mut backoff = backoff::ExponentialBackoffBuilder::new()
147 .with_initial_interval(Duration::from_secs(options.retry.initial_backoff_secs))
148 .with_multiplier(2.0)
149 .with_max_elapsed_time(Some(options.timeout))
150 .build();
151
152 loop {
153 let result = perform_single_shot_extraction(
154 language,
155 model,
156 request,
157 components,
158 &options,
159 prev_attempt.as_ref(),
160 )
161 .await;
162
163 match result {
164 Ok(res) => return Ok(res),
165 Err(e) => {
166 if let ExtractionError::Parse(pe) = &e
168 && let Some(wait) = backoff::backoff::Backoff::next_backoff(&mut backoff)
169 {
170 let err_msg = pe.reason.to_string();
171 tracing::warn!(
172 ?wait,
173 error = %err_msg,
174 "Extraction validation failed, retrying with self-correction..."
175 );
176 prev_attempt = Some(PreviousAttempt {
177 raw_response: pe.raw_response.clone(),
178 error: err_msg,
179 });
180 tokio::time::sleep(wait).await;
181 continue;
182 }
183 return Err(e);
184 }
185 }
186 }
187}
188
189async fn perform_single_shot_extraction<L, M>(
191 language: &L,
192 model: &M,
193 request: &ExtractionRequest,
194 components: &[&dyn AnalysisComponent<L>],
195 options: &ExtractionOptions<'_>,
196 previous_attempt: Option<&PreviousAttempt>,
197) -> Result<ExtractionResult, ExtractionError>
198where
199 L: LinguisticDefinition + Send + Sync,
200 M: CompletionModel,
201{
202 let compatible: Vec<&dyn AnalysisComponent<L>> = components
204 .iter()
205 .filter(|c| c.is_compatible(language))
206 .copied()
207 .collect();
208
209 let requested_keys: Vec<&'static str> = compatible.iter().map(|c| c.schema_key()).collect();
210
211 let schema_value = compose_schema(language, &compatible);
213 let rig_schema: schemars::Schema = serde_json::from_value(schema_value.clone())?;
214
215 let system_prompt = compose_prompt(language, request, options.extractor_prompts, &compatible)?;
217
218 let user_message = format!(
219 "Extract features from this card:\n{}\n\nTARGET WORDS: {:?}",
220 request.content, request.targets
221 );
222
223 let mut builder: CompletionRequestBuilder<M> = model
225 .completion_request(user_message.as_str())
226 .preamble(system_prompt)
227 .temperature(f64::from(options.temperature))
228 .max_tokens(u64::from(options.max_tokens))
229 .output_schema(rig_schema);
230
231 if let Some(prev) = previous_attempt {
232 builder = builder
233 .message(Message::assistant(&prev.raw_response))
234 .message(Message::user(format!(
235 "Your output is not conform to what I'm expecting. \
236 Please look at the error and correct yourself: {}",
237 prev.error
238 )));
239 }
240
241 let completion_response = builder.send().await?;
242
243 let raw_text = completion_response
244 .choice
245 .into_iter()
246 .find_map(|c| {
247 if let rig::completion::message::AssistantContent::Text(t) = c {
248 Some(t.text)
249 } else {
250 None
251 }
252 })
253 .ok_or(ExtractionError::EmptyResponse)?;
254
255 let cleaned = clean_llm_json(&raw_text);
257 let mut processed = cleaned.to_string();
258 for comp in &compatible {
259 processed = comp.pre_process(&processed);
260 }
261
262 let mut json_value: serde_json::Value = match serde_json::from_str(&processed) {
264 Ok(v) => v,
265 Err(e) => {
266 let err_msg = format!("{e}");
267 tracing::warn!(error = %err_msg, "Failed to parse JSON syntax");
268 return Err(ExtractionParseError {
269 raw_response: processed,
270 reason: ExtractionFailureReason::JsonSyntax(err_msg),
271 }
272 .into());
273 }
274 };
275
276 if let Ok(validator) = jsonschema::validator_for(&schema_value) {
278 let schema_errors: Vec<_> = validator.iter_errors(&json_value).collect();
279 if !schema_errors.is_empty() {
280 let mut err_msgs = Vec::new();
281 for err in schema_errors {
282 err_msgs.push(format!("- Path: {}: {}", err.instance_path(), err));
283 }
284 let err_msg = err_msgs.join("\n");
285 tracing::warn!(error = %err_msg, "Schema validation failed — retrying");
286 return Err(ExtractionParseError {
287 raw_response: processed,
288 reason: ExtractionFailureReason::Schema(err_msg),
289 }
290 .into());
291 }
292 }
293
294 for comp in &compatible {
296 let key = comp.schema_key();
297 if let Some(section) = json_value.get(key) {
298 comp.validate(language, section)
299 .map_err(|e| ExtractionParseError {
300 raw_response: processed.clone(),
301 reason: ExtractionFailureReason::ComponentValidation { key, message: e },
302 })?;
303 }
304 }
305
306 for comp in &compatible {
307 let key = comp.schema_key();
308 if let Some(section) = json_value.get_mut(key) {
309 comp.post_process(language, section)
310 .map_err(|e| ExtractionParseError {
311 raw_response: processed.clone(),
312 reason: ExtractionFailureReason::ComponentPostProcess { key, message: e },
313 })?;
314 }
315 }
316
317 Ok(ExtractionResult::new(json_value, requested_keys))
319}