hs_predict/pipeline.rs
1//! Main classification pipeline.
2//!
3//! Runs classification in priority order:
4//! 1. User-provided CAS → HS mappings (confidence = 1.0)
5//! 2. Embedded static rule table (CAS + shape + purity)
6//! 3. SMILES-based rule engine (v0.3)
7//! 4. LLM fallback via [`LlmClassifier`] trait hook (v0.4, `llm` feature)
8
9use std::collections::HashMap;
10#[cfg(feature = "llm")]
11use std::sync::Arc;
12
13use crate::error::{HsPredictError, Result};
14use crate::rules::jp_table::{find_jp_rule, JP_TARIFF_YEAR};
15use crate::rules::matcher::find_best_rule;
16use crate::types::{
17 GrayZone, HsPrediction, OrganicInorganic, PhysicalForm, ProductDescription,
18 PredictionSource, RecommendedAction,
19};
20
/// Confidence thresholds that steer what the pipeline recommends for a prediction.
///
/// A confidence at or above `confidence_threshold_direct` is accepted as-is.
/// A confidence at or above `confidence_threshold_llm_required` (but below the
/// direct threshold) is flagged for LLM verification. Anything lower is sent
/// to expert review.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Minimum confidence at which a result is accepted directly, with no
    /// LLM confirmation requested.
    pub confidence_threshold_direct: f32,

    /// Minimum confidence for a rule-based result to be usable at all.
    ///
    /// Results in the half-open band
    /// `[confidence_threshold_llm_required, confidence_threshold_direct)` are
    /// returned with `RecommendedAction::VerifyWithLlm`; below this value the
    /// LLM is required.
    pub confidence_threshold_llm_required: f32,
}
33
34impl Default for PipelineConfig {
35 fn default() -> Self {
36 Self {
37 confidence_threshold_direct: 0.85,
38 confidence_threshold_llm_required: 0.50,
39 }
40 }
41}
42
43/// Main HS code classification pipeline.
44///
45/// # Example — direct (sync)
46/// ```rust,no_run
47/// use hs_predict::pipeline::HsPipeline;
48/// use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
49///
50/// let pipeline = HsPipeline::new();
51///
52/// let product = ProductDescription {
53/// identifier: SubstanceIdentifier::from_cas("1310-73-2"),
54/// physical_form: Some(PhysicalForm::Solid),
55/// purity_pct: None,
56/// purity_type: None,
57/// mixture_components: None,
58/// intended_use: None,
59/// additional_context: None,
60/// };
61///
62/// let prediction = pipeline.classify(&product).unwrap();
63/// assert_eq!(&prediction.hs_code, "281511");
64/// ```
65///
66/// # Example — with PubChem enrichment (async, `pubchem` feature)
67/// ```rust,no_run
68/// # #[cfg(feature = "pubchem")]
69/// # async fn example() -> hs_predict::Result<()> {
70/// use hs_predict::pipeline::HsPipeline;
71/// use hs_predict::pubchem::PubChemClient;
72/// use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
73///
74/// let pipeline = HsPipeline::new().with_pubchem(PubChemClient::new());
75///
76/// let mut product = ProductDescription {
77/// identifier: SubstanceIdentifier::from_cas("1310-73-2"),
78/// physical_form: Some(PhysicalForm::Solid),
79/// purity_pct: None,
80/// purity_type: None,
81/// mixture_components: None,
82/// intended_use: None,
83/// additional_context: None,
84/// };
85///
86/// pipeline.enrich(&mut product).await?; // fills SMILES, InChI, IUPAC name …
87/// let prediction = pipeline.classify(&product)?;
88/// println!("{}", prediction.display()); // "28.15.11"
89/// # Ok(())
90/// # }
91/// ```
92///
93/// # Example — with LLM fallback (async, `llm` feature)
94/// ```rust,no_run
95/// # #[cfg(feature = "llm")]
96/// # async fn example() -> hs_predict::Result<()> {
97/// use hs_predict::pipeline::HsPipeline;
98/// use hs_predict::llm::{LlmClassifier, LlmPrompt, LlmResponse};
99/// use futures::future::BoxFuture;
100///
101/// struct MyClient;
102/// impl LlmClassifier for MyClient {
103/// fn classify<'a>(&'a self, prompt: &'a LlmPrompt) -> BoxFuture<'a, hs_predict::Result<LlmResponse>> {
104/// Box::pin(async move { todo!() })
105/// }
106/// }
107///
108/// let pipeline = HsPipeline::new().with_llm(MyClient);
109/// use hs_predict::types::{ProductDescription, SubstanceIdentifier};
110/// let product = ProductDescription {
111/// identifier: SubstanceIdentifier::from_cas("12-34-5"),
112/// physical_form: None, purity_pct: None, purity_type: None,
113/// mixture_components: None, intended_use: None, additional_context: None,
114/// };
115/// let prediction = pipeline.classify_with_llm(&product).await?;
116/// println!("{}", prediction.display());
117/// # Ok(())
118/// # }
119/// ```
120#[derive(Default)]
121pub struct HsPipeline {
122 /// User-supplied CAS → HS code overrides. Highest priority.
123 user_mappings: HashMap<String, String>,
124
125 config: PipelineConfig,
126
127 /// PubChem client for identifier enrichment (v0.2, `pubchem` feature).
128 #[cfg(feature = "pubchem")]
129 pubchem: Option<std::sync::Arc<crate::pubchem::PubChemClient>>,
130
131 /// LLM classifier hook (v0.4, `llm` feature).
132 #[cfg(feature = "llm")]
133 llm: Option<Arc<dyn crate::llm::LlmClassifier>>,
134}
135
136impl std::fmt::Debug for HsPipeline {
137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138 let mut s = f.debug_struct("HsPipeline");
139 s.field("user_mappings", &self.user_mappings);
140 s.field("config", &self.config);
141 #[cfg(feature = "pubchem")]
142 s.field("pubchem", &self.pubchem.as_ref().map(|_| "<PubChemClient>"));
143 #[cfg(feature = "llm")]
144 s.field("llm", &self.llm.as_ref().map(|_| "<dyn LlmClassifier>"));
145 s.finish()
146 }
147}
148
149impl HsPipeline {
150 /// Create a pipeline with default configuration.
151 pub fn new() -> Self {
152 Self::default()
153 }
154
155 /// Add a user-provided CAS → HS code mapping.
156 ///
157 /// These mappings override the embedded rule table with `confidence = 1.0`.
158 ///
159 /// The `hs_code` must be exactly 6 ASCII digits (e.g. `"281511"`).
160 /// If the code does not satisfy this constraint the mapping is silently
161 /// ignored and the pipeline is returned unchanged.
162 pub fn with_mapping(mut self, cas: impl Into<String>, hs_code: impl Into<String>) -> Self {
163 let hs_code = hs_code.into();
164 let valid = hs_code.len() == 6 && hs_code.chars().all(|c| c.is_ascii_digit());
165 if valid {
166 self.user_mappings.insert(cas.into(), hs_code);
167 }
168 self
169 }
170
171 /// Override the default pipeline configuration.
172 pub fn with_config(mut self, config: PipelineConfig) -> Self {
173 self.config = config;
174 self
175 }
176
/// Install an [`LlmClassifier`](crate::llm::LlmClassifier) so that the LLM
/// fallback (Priority 4) becomes available.
///
/// [`classify_with_llm`](Self::classify_with_llm) consults the client when
/// the rule-based stages either yield a prediction whose
/// `recommended_action` is not `Accept`, or fail with
/// [`LowConfidenceNoLlm`](crate::HsPredictError::LowConfidenceNoLlm).
///
/// A previously attached client, if any, is replaced.
///
/// Requires the **`llm`** Cargo feature.
#[cfg(feature = "llm")]
pub fn with_llm(mut self, client: impl crate::llm::LlmClassifier + 'static) -> Self {
    // Erase the concrete type so any implementation fits the same field.
    let shared: Arc<dyn crate::llm::LlmClassifier> = Arc::new(client);
    self.llm = Some(shared);
    self
}
191
/// Install a [`PubChemClient`](crate::pubchem::PubChemClient) so that
/// `enrich` can fill in missing identifier data before classification.
///
/// A previously attached client, if any, is replaced.
///
/// Requires the **`pubchem`** Cargo feature.
#[cfg(feature = "pubchem")]
pub fn with_pubchem(mut self, client: crate::pubchem::PubChemClient) -> Self {
    let shared = std::sync::Arc::new(client);
    self.pubchem = Some(shared);
    self
}
201
/// Fill in missing identifier data on a [`ProductDescription`] via PubChem.
///
/// The product's own identifier is enriched first, followed by the
/// `substance` of every mixture component, one request after another.
///
/// - Without a configured PubChem client this is a no-op returning `Ok(())`.
/// - The first error returned by the client aborts the run and is
///   propagated; components after the failing one are left untouched.
/// - NOTE(review): the earlier documentation states that "not found" and
///   "no usable identifier" outcomes are ignored rather than reported. That
///   filtering would have to happen inside the client's `enrich` — confirm
///   against `PubChemClient`.
///
/// Requires the **`pubchem`** Cargo feature.
#[cfg(feature = "pubchem")]
pub async fn enrich(&self, product: &mut ProductDescription) -> Result<()> {
    let Some(client) = self.pubchem.as_ref() else {
        return Ok(());
    };

    client.enrich(&mut product.identifier).await?;

    if let Some(components) = product.mixture_components.as_mut() {
        for component in components.iter_mut() {
            client.enrich(&mut component.substance).await?;
        }
    }

    Ok(())
}
229
230 /// Classify a product and return an HS code prediction.
231 ///
232 /// Priority order:
233 /// 0. Mixture branch (v0.5) — GRI 3a/3b/3c via [`crate::mixture`]
234 /// 1. User-provided mapping
235 /// 2. Embedded static rule table
236 /// 3. (v0.3) SMILES rule engine
237 /// 4. (v0.4) LLM fallback
238 pub fn classify(&self, product: &ProductDescription) -> Result<HsPrediction> {
239 // ── Priority 0: Mixture branch (v0.5) ────────────────────────────
240 if product.is_mixture() {
241 return crate::mixture::classify_mixture(product, |comp| self.classify(comp));
242 }
243
244 // ── Priority 1: User-provided mappings ────────────────────────
245 if let Some(ref cas) = product.identifier.cas {
246 if let Some(hs_code) = self.user_mappings.get(cas.as_str()) {
247 let jp = find_jp_rule(hs_code);
248 return Ok(HsPrediction {
249 hs_code: hs_code.clone(),
250 heading_description: String::new(),
251 confidence: 1.0,
252 source: PredictionSource::UserMapping,
253 notes: vec!["From user-provided mapping".to_string()],
254 alternatives: vec![],
255 recommended_action: RecommendedAction::Accept,
256 gray_zone: None,
257 jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
258 jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
259 });
260 }
261 }
262
263 // ── Priority 2: Embedded static rule table ────────────────────
264 if let Some(ref cas) = product.identifier.cas {
265 if let Some(rule) = find_best_rule(
266 cas,
267 product.physical_form.as_ref(),
268 product.purity_pct,
269 ) {
270 let gray_zone = self.detect_gray_zone(product, rule.hs_code, None);
271 let action = self.recommended_action_with_gz(rule.confidence, gray_zone.as_ref());
272 let jp = find_jp_rule(rule.hs_code);
273 return Ok(HsPrediction {
274 hs_code: rule.hs_code.to_string(),
275 heading_description: rule.heading_description.to_string(),
276 confidence: rule.confidence,
277 source: PredictionSource::EmbeddedRule {
278 rule_id: format!("{}:{}", rule.cas, rule.hs_code),
279 },
280 notes: self.build_notes(product),
281 alternatives: vec![],
282 recommended_action: action,
283 gray_zone,
284 jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
285 jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
286 });
287 }
288 }
289
290 // ── Priority 3: SMILES-based rule engine ─────────────────────────
291 if let Some(ref smiles) = product.identifier.smiles {
292 if let Some(classification) = crate::smiles::classify_smiles(smiles) {
293 let hint = &classification.heading_hint;
294
295 // Prefer the 6-digit subheading when the structural engine
296 // resolved it; otherwise pad the 4-digit heading with "00".
297 let maybe_code: Option<(String, bool)> = hint
298 .subheading
299 .as_ref()
300 .map(|sub| (sub.clone(), true))
301 .or_else(|| {
302 hint.heading
303 .map(|heading| (format!("{:04}00", heading), false))
304 });
305
306 if let Some((hs_code, is_6digit)) = maybe_code {
307 if hint.confidence >= self.config.confidence_threshold_llm_required {
308 let jp = find_jp_rule(&hs_code);
309
310 let gray_zone = self.detect_gray_zone(
311 product,
312 &hs_code,
313 Some(&classification.organic_class),
314 );
315 let action =
316 self.recommended_action_with_gz(hint.confidence, gray_zone.as_ref());
317
318 let mut notes = self.build_notes(product);
319 if is_6digit {
320 notes.push(
321 "6-digit subheading resolved from SMILES structural analysis \
322 (carbon count, ring type, functional group). \
323 Verify with product specification before declaration."
324 .to_string(),
325 );
326 } else {
327 notes.push(
328 "Heading derived from SMILES functional-group analysis. \
329 Sub-heading (last two digits) is a placeholder — \
330 verify the exact 6-digit code with the product specification."
331 .to_string(),
332 );
333 }
334
335 let matched_rules: Vec<String> = classification
336 .functional_groups
337 .iter()
338 .map(|g| g.label().to_string())
339 .collect();
340
341 return Ok(HsPrediction {
342 hs_code,
343 heading_description: hint.rationale.to_string(),
344 confidence: hint.confidence,
345 source: PredictionSource::RuleEngine { matched_rules },
346 notes,
347 alternatives: vec![],
348 recommended_action: action,
349 gray_zone,
350 jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
351 jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
352 });
353 }
354 }
355 }
356 }
357
358 // ── Priority 4: LLM fallback ─────────────────────────────────
359 // (async path — use classify_with_llm for LLM support)
360 Err(HsPredictError::LowConfidenceNoLlm {
361 confidence: 0.0,
362 threshold: self.config.confidence_threshold_llm_required,
363 })
364 }
365
366 /// Classify a batch of products concurrently.
367 ///
368 /// Returns one `Result<HsPrediction>` per input, in the same order.
369 /// Uses synchronous [`classify`](Self::classify) internally — for LLM-backed
370 /// batch classification see `classify_batch_with_llm` (future work).
371 pub fn classify_batch(&self, products: &[ProductDescription]) -> Vec<Result<HsPrediction>> {
372 products.iter().map(|p| self.classify(p)).collect()
373 }
374
/// Run [`classify_with_llm`](Self::classify_with_llm) over a whole batch.
///
/// One future is created per product and all of them are driven together
/// through `futures::future::join_all`. The returned vector holds one
/// `Result<HsPrediction>` per input, in input order.
///
/// Requires the **`llm`** Cargo feature.
#[cfg(feature = "llm")]
pub async fn classify_batch_with_llm(
    &self,
    products: &[ProductDescription],
) -> Vec<Result<HsPrediction>> {
    let pending = products
        .iter()
        .map(|product| self.classify_with_llm(product));
    futures::future::join_all(pending).await
}
390
/// Classify a product, consulting the configured LLM whenever the
/// rule-based stages cannot give an answer that is accepted outright.
///
/// # Flow
///
/// 1. [`classify`](Self::classify) runs first. A prediction whose
///    `recommended_action` is `Accept` is returned immediately.
/// 2. Any other prediction, or a `LowConfidenceNoLlm` error, is discarded
///    and the product is handed to the LLM. Every other error from
///    `classify` is returned unchanged.
///
/// # Errors
///
/// - [`HsPredictError::LlmNotConfigured`] when the LLM is needed but no
///   client was attached via [`with_llm`](Self::with_llm).
/// - [`HsPredictError::ValidationFailed`] when the LLM's `hs_code` is not
///   exactly 6 ASCII digits.
/// - Any error returned by the LLM client itself.
///
/// # Chapter consistency
///
/// When the prompt carries a SMILES analysis and the LLM's chapter (first
/// two digits) differs from the analysis' chapter hint, a warning note is
/// appended. This is **not** treated as an error.
///
/// Requires the **`llm`** Cargo feature.
#[cfg(feature = "llm")]
pub async fn classify_with_llm(
    &self,
    product: &ProductDescription,
) -> Result<HsPrediction> {
    use crate::llm::PromptBuilder;
    use crate::types::AlternativePrediction;

    /// Format check shared by the primary code and the alternatives.
    fn is_six_ascii_digits(code: &str) -> bool {
        code.len() == 6 && code.chars().all(|c| c.is_ascii_digit())
    }

    // Rule-based stages first; only an outright `Accept` short-circuits.
    match self.classify(product) {
        Ok(pred) if pred.recommended_action == RecommendedAction::Accept => {
            return Ok(pred);
        }
        // Uncertain prediction or nothing usable → fall through to the LLM.
        Ok(_) | Err(HsPredictError::LowConfidenceNoLlm { .. }) => {}
        Err(other) => return Err(other),
    }

    // From here on an LLM client is mandatory.
    let Some(llm) = self.llm.as_ref() else {
        return Err(HsPredictError::LlmNotConfigured);
    };

    let prompt = PromptBuilder::new().build(product);
    let resp = llm.classify(&prompt).await?;

    // Reject anything that is not a well-formed 6-digit code.
    if !is_six_ascii_digits(&resp.hs_code) {
        return Err(HsPredictError::ValidationFailed { code: resp.hs_code });
    }

    // Chapter consistency check (warning only).
    let mut notes = self.build_notes(product);
    if let Some(analysis) = prompt.smiles_analysis.as_ref() {
        let expected_chapter = format!("{:02}", analysis.heading_hint.chapter);
        // Safe slice: the code was just validated as six ASCII digits.
        let llm_chapter = &resp.hs_code[..2];
        if llm_chapter != expected_chapter {
            notes.push(format!(
                "Chapter mismatch: LLM returned Chapter {} but SMILES engine \
                 suggested Chapter {}. Verify with Chapter Notes.",
                llm_chapter, expected_chapter
            ));
        }
    }
    notes.push(format!("LLM rationale: {}", resp.rationale));

    let recommended_action = self.recommended_action(resp.confidence);
    let jp = find_jp_rule(&resp.hs_code);

    // Alternatives are held to the same format rule as the primary code;
    // malformed ones are dropped rather than failing the whole call.
    let alternatives: Vec<_> = resp
        .alternatives
        .into_iter()
        .filter(|alt| is_six_ascii_digits(&alt.hs_code))
        .map(|alt| AlternativePrediction {
            hs_code: alt.hs_code,
            confidence: alt.confidence,
            reason: alt.reason,
        })
        .collect();

    Ok(HsPrediction {
        hs_code: resp.hs_code,
        heading_description: String::new(),
        confidence: resp.confidence,
        source: PredictionSource::LlmApi { model: String::new() },
        notes,
        alternatives,
        recommended_action,
        gray_zone: None, // LLM response does not carry gray-zone information
        jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
        jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
    })
}
494
495 // ─── Private helpers ──────────────────────────────────────────────
496
497 fn recommended_action(&self, confidence: f32) -> RecommendedAction {
498 if confidence >= self.config.confidence_threshold_direct {
499 RecommendedAction::Accept
500 } else if confidence >= self.config.confidence_threshold_llm_required {
501 RecommendedAction::VerifyWithLlm
502 } else {
503 RecommendedAction::ExpertReview
504 }
505 }
506
507 /// Like `recommended_action` but upgrades to `PriorConsultation` when a
508 /// gray zone is present and the confidence does not reach the "direct" threshold.
509 fn recommended_action_with_gz(
510 &self,
511 confidence: f32,
512 gray_zone: Option<&GrayZone>,
513 ) -> RecommendedAction {
514 let base = self.recommended_action(confidence);
515 if gray_zone.is_some() && base != RecommendedAction::Accept {
516 // Gray zone identified → recommend an advance ruling (事前教示)
517 RecommendedAction::PriorConsultation
518 } else {
519 base
520 }
521 }
522
523 /// Detect whether a prediction falls in a well-known gray zone.
524 ///
525 /// When `organic_class` is `Some`, the supplied classification is used
526 /// (e.g. when the SMILES engine has already analysed the structure);
527 /// otherwise the classification is re-derived from
528 /// `product.identifier.smiles` when available.
529 fn detect_gray_zone(
530 &self,
531 product: &ProductDescription,
532 hs_code: &str,
533 organic_class: Option<&OrganicInorganic>,
534 ) -> Option<GrayZone> {
535 let chapter = &hs_code[..2];
536
537 // Chapter 28 / 29 boundary: organometallic or borderline compound
538 if chapter == "28" && self.is_organometallic(product, organic_class) {
539 return Some(GrayZone::Chapter28vs29);
540 }
541
542 // Chapter 29 result but product is used industrially → Ch.29 vs Ch.38
543 if chapter == "29" {
544 use crate::types::IntendedUse;
545 if let Some(IntendedUse::Industrial) = &product.intended_use {
546 return Some(GrayZone::Chapter29vs38);
547 }
548 }
549
550 None
551 }
552
553 /// Whether the product is an organometallic compound — either via the
554 /// pre-computed `organic_class` (preferred) or by re-deriving from SMILES.
555 fn is_organometallic(
556 &self,
557 product: &ProductDescription,
558 organic_class: Option<&OrganicInorganic>,
559 ) -> bool {
560 match organic_class {
561 Some(oc) => matches!(oc, OrganicInorganic::Organometallic),
562 None => product.identifier.smiles.as_deref().is_some_and(|s| {
563 matches!(
564 crate::smiles::detector::classify_organic(s),
565 OrganicInorganic::Organometallic,
566 )
567 }),
568 }
569 }
570
571 /// Build supplementary notes about shape / purity caveats.
572 fn build_notes(&self, product: &ProductDescription) -> Vec<String> {
573 let mut notes = Vec::new();
574
575 match &product.physical_form {
576 None | Some(PhysicalForm::Unknown) => {
577 notes.push(
578 "Physical form not specified — the HS subheading may differ \
579 (e.g. solid vs. solution).".to_string(),
580 );
581 }
582 Some(PhysicalForm::Solution { concentration_pct_ww: None, .. }) => {
583 notes.push(
584 "Solution concentration not specified — subheading may differ \
585 (e.g. fuming vs. standard grade).".to_string(),
586 );
587 }
588 _ => {}
589 }
590
591 if product.purity_pct.is_none() {
592 notes.push(
593 "Purity not specified — some headings require a minimum purity threshold."
594 .to_string(),
595 );
596 }
597
598 notes
599 }
600}
601
602// ─────────────────────────────────────────────────────────────────────────────
603// Pipeline integration tests
604// ─────────────────────────────────────────────────────────────────────────────
605
606#[cfg(all(test, feature = "mock"))]
607mod tests {
608 use super::*;
609 use crate::llm::MockLlmClassifier;
610 use crate::types::{SubstanceIdentifier};
611
612 /// A product with no static rule and a SMILES → triggers LLM path.
613 fn unknown_organic() -> ProductDescription {
614 // Ethyl propanoate (ester): SMILES engine gives heading 2915 at conf 0.55
615 // (VerifyWithLlm) because esters don't have a 6-digit structural decision
616 // tree yet. No CAS → Priority 2 miss. Suitable for testing LLM paths.
617 ProductDescription {
618 identifier: SubstanceIdentifier {
619 cas: None,
620 smiles: Some("CCC(=O)OCC".to_string()),
621 iupac_name: None,
622 inchi: None,
623 inchi_key: None,
624 cid: None,
625 },
626 physical_form: None,
627 purity_pct: None,
628 purity_type: None,
629 mixture_components: None,
630 intended_use: None,
631 additional_context: None,
632 }
633 }
634
635 #[tokio::test]
636 async fn classify_with_llm_mock_returns_6_digit_code() {
637 let pipeline = HsPipeline::new().with_llm(MockLlmClassifier::new());
638 let product = unknown_organic();
639 let pred = pipeline.classify_with_llm(&product).await.unwrap();
640 assert_eq!(pred.hs_code.len(), 6);
641 assert!(pred.hs_code.chars().all(|c| c.is_ascii_digit()));
642 }
643
644 #[tokio::test]
645 async fn classify_with_llm_mock_chapter_29_for_smiles_acid() {
646 let pipeline = HsPipeline::new().with_llm(MockLlmClassifier::new());
647 let product = unknown_organic();
648 let pred = pipeline.classify_with_llm(&product).await.unwrap();
649 assert!(
650 pred.hs_code.starts_with("29"),
651 "acetic acid SMILES should yield Chapter 29, got {}",
652 pred.hs_code
653 );
654 }
655
656 #[tokio::test]
657 async fn classify_with_llm_no_client_returns_error() {
658 let pipeline = HsPipeline::new(); // no LLM attached
659 let product = unknown_organic();
660 let err = pipeline.classify_with_llm(&product).await.unwrap_err();
661 assert!(
662 matches!(err, HsPredictError::LlmNotConfigured),
663 "expected LlmNotConfigured, got {:?}",
664 err
665 );
666 }
667
668 #[tokio::test]
669 async fn classify_with_llm_skips_llm_for_high_confidence_rule() {
670 // NaOH solid → static rule, confidence = 1.0 → should NOT call LLM
671 let pipeline = HsPipeline::new()
672 .with_llm(MockLlmClassifier::with_default("999999", 0.1));
673 let product = ProductDescription {
674 identifier: SubstanceIdentifier::from_cas("1310-73-2"),
675 physical_form: Some(crate::types::PhysicalForm::Solid),
676 purity_pct: None,
677 purity_type: None,
678 mixture_components: None,
679 intended_use: None,
680 additional_context: None,
681 };
682 let pred = pipeline.classify_with_llm(&product).await.unwrap();
683 // Should be the static rule result, not the mock's "999999"
684 assert_eq!(pred.hs_code, "281511", "static rule should win over LLM");
685 }
686
687 #[tokio::test]
688 async fn classify_with_llm_invalid_code_returns_validation_error() {
689 // Mock returning an invalid code
690 struct BadMock;
691 impl crate::llm::LlmClassifier for BadMock {
692 fn classify<'a>(
693 &'a self,
694 _prompt: &'a crate::llm::LlmPrompt,
695 ) -> futures::future::BoxFuture<'a, crate::Result<crate::llm::LlmResponse>> {
696 Box::pin(async {
697 Ok(crate::llm::LlmResponse {
698 hs_code: "BAD!!".to_string(),
699 confidence: 0.5,
700 rationale: "bad".to_string(),
701 alternatives: vec![],
702 })
703 })
704 }
705 }
706 let pipeline = HsPipeline::new().with_llm(BadMock);
707 let product = unknown_organic();
708 let err = pipeline.classify_with_llm(&product).await.unwrap_err();
709 assert!(
710 matches!(err, HsPredictError::ValidationFailed { .. }),
711 "expected ValidationFailed, got {:?}",
712 err
713 );
714 }
715}