hs-predict 0.3.0

HS code prediction for chemical products — Akinator-style interactive classification with rule-based and LLM hybrid engine
Documentation
//! Main classification pipeline.
//!
//! Runs classification in priority order:
//! 1. User-provided CAS → HS mappings (confidence = 1.0)
//! 2. Embedded static rule table (CAS + shape + purity)
//! 3. SMILES-based rule engine (4-digit heading inferred from functional groups) — v0.3
//! 4. *(placeholder)* LLM API — v0.4

use std::collections::HashMap;

use crate::error::{HsPredictError, Result};
use crate::rules::jp_table::{find_jp_rule, JP_TARIFF_YEAR};
use crate::rules::matcher::find_best_rule;
use crate::types::{HsPrediction, PhysicalForm, ProductDescription, PredictionSource, RecommendedAction};

/// Tunable confidence thresholds for the classification pipeline.
///
/// A prediction's confidence is compared against the two thresholds to pick
/// its `RecommendedAction`:
///
/// | confidence                                   | action          |
/// |----------------------------------------------|-----------------|
/// | `>= confidence_threshold_direct`             | `Accept`        |
/// | `>= confidence_threshold_llm_required`       | `VerifyWithLlm` |
/// | below `confidence_threshold_llm_required`    | `ExpertReview`  |
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Lower bound (inclusive) of the "accept as-is" band: results at or
    /// above this confidence are returned without requesting LLM
    /// confirmation.
    pub confidence_threshold_direct: f32,

    /// Lower bound (inclusive) of the "verify with LLM" band. Results between
    /// this value and `confidence_threshold_direct` are returned with
    /// `RecommendedAction::VerifyWithLlm`; anything lower is not trusted
    /// without an LLM.
    pub confidence_threshold_llm_required: f32,
}

impl Default for PipelineConfig {
    /// Default thresholds: `0.85` for direct acceptance and `0.50` for the
    /// LLM-required boundary.
    fn default() -> Self {
        let confidence_threshold_direct = 0.85;
        let confidence_threshold_llm_required = 0.50;

        Self {
            confidence_threshold_direct,
            confidence_threshold_llm_required,
        }
    }
}

/// Main HS code classification pipeline.
///
/// Holds the user-supplied CAS → HS overrides, the confidence thresholds, and
/// (with the `pubchem` feature) an optional PubChem client. Build one with
/// `HsPipeline::new()` and customise it with the `with_mapping`,
/// `with_config` and `with_pubchem` builder methods, then call `classify`.
///
/// # Example — direct (sync)
/// ```rust,no_run
/// use hs_predict::pipeline::HsPipeline;
/// use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
///
/// let pipeline = HsPipeline::new();
///
/// let product = ProductDescription {
///     identifier: SubstanceIdentifier::from_cas("1310-73-2"),
///     physical_form: Some(PhysicalForm::Solid),
///     purity_pct: None,
///     purity_type: None,
///     mixture_components: None,
///     intended_use: None,
///     additional_context: None,
/// };
///
/// let prediction = pipeline.classify(&product).unwrap();
/// assert_eq!(&prediction.hs_code, "281511");
/// ```
///
/// # Example — with PubChem enrichment (async, `pubchem` feature)
/// ```rust,no_run
/// # #[cfg(feature = "pubchem")]
/// # async fn example() -> hs_predict::Result<()> {
/// use hs_predict::pipeline::HsPipeline;
/// use hs_predict::pubchem::PubChemClient;
/// use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
///
/// let pipeline = HsPipeline::new().with_pubchem(PubChemClient::new());
///
/// let mut product = ProductDescription {
///     identifier: SubstanceIdentifier::from_cas("1310-73-2"),
///     physical_form: Some(PhysicalForm::Solid),
///     purity_pct: None,
///     purity_type: None,
///     mixture_components: None,
///     intended_use: None,
///     additional_context: None,
/// };
///
/// pipeline.enrich(&mut product).await?;   // fills SMILES, InChI, IUPAC name …
/// let prediction = pipeline.classify(&product)?;
/// println!("{}", prediction.display());   // "28.15.11"
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Default)]
pub struct HsPipeline {
    /// User-supplied CAS → HS code overrides. Highest priority.
    ///
    /// Keyed by the CAS string exactly as it was passed to `with_mapping`;
    /// lookups in `classify` use the product's CAS string verbatim (no
    /// normalisation is applied on either side).
    user_mappings: HashMap<String, String>,

    /// Confidence thresholds used to pick the `RecommendedAction` for a
    /// prediction and to decide whether a SMILES-derived result is emitted.
    config: PipelineConfig,

    /// PubChem client for identifier enrichment (v0.2, `pubchem` feature).
    ///
    /// `None` until a client is attached with `with_pubchem`; while it is
    /// `None`, `enrich` returns `Ok(())` without doing anything.
    #[cfg(feature = "pubchem")]
    pubchem: Option<std::sync::Arc<crate::pubchem::PubChemClient>>,
}

impl HsPipeline {
    /// Create a pipeline with default configuration, no user mappings and
    /// (with the `pubchem` feature) no PubChem client.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Add a user-provided CAS → HS code mapping.
    ///
    /// These mappings override the embedded rule table with `confidence = 1.0`.
    /// Registering the same CAS twice keeps only the most recent HS code.
    /// The CAS string is stored verbatim and matched verbatim by
    /// [`classify`](Self::classify).
    #[must_use]
    pub fn with_mapping(mut self, cas: impl Into<String>, hs_code: impl Into<String>) -> Self {
        self.user_mappings.insert(cas.into(), hs_code.into());
        self
    }

    /// Override the default pipeline configuration.
    #[must_use]
    pub fn with_config(mut self, config: PipelineConfig) -> Self {
        self.config = config;
        self
    }

    /// Attach a [`PubChemClient`](crate::pubchem::PubChemClient) to enable
    /// automatic identifier enrichment before classification.
    ///
    /// Requires the **`pubchem`** Cargo feature.
    #[cfg(feature = "pubchem")]
    #[must_use]
    pub fn with_pubchem(mut self, client: crate::pubchem::PubChemClient) -> Self {
        self.pubchem = Some(std::sync::Arc::new(client));
        self
    }

    /// Enrich a [`ProductDescription`] with PubChem data.
    ///
    /// Calls the client's `enrich` on the main identifier first and then on
    /// each mixture component's `substance`, one after another.
    ///
    /// If no PubChem client is configured, returns `Ok(())` immediately and
    /// leaves `product` untouched.
    ///
    /// Requires the **`pubchem`** Cargo feature.
    ///
    /// # Errors
    ///
    /// Every error returned by the client is propagated unchanged, and
    /// enrichment stops at the first failure (later components are not
    /// visited).
    ///
    /// NOTE(review): "not found" / "no usable identifier" outcomes are
    /// presumably swallowed inside `PubChemClient::enrich` itself — this
    /// method does no filtering of its own; confirm against the client.
    #[cfg(feature = "pubchem")]
    pub async fn enrich(&self, product: &mut ProductDescription) -> Result<()> {
        let Some(ref client) = self.pubchem else {
            return Ok(());
        };

        client.enrich(&mut product.identifier).await?;

        if let Some(ref mut components) = product.mixture_components {
            for component in components.iter_mut() {
                client.enrich(&mut component.substance).await?;
            }
        }

        Ok(())
    }

    /// Classify a product and return an HS code prediction.
    ///
    /// Priority order (the first tier that produces a result wins):
    /// 1. User-provided mapping — exact CAS match, `confidence = 1.0`,
    ///    empty `heading_description`.
    /// 2. Embedded static rule table — looked up by CAS, physical form and
    ///    purity.
    /// 3. SMILES rule engine — only when it yields a heading whose confidence
    ///    is at least `confidence_threshold_llm_required`; the heading is
    ///    padded with a `"00"` placeholder sub-heading.
    /// 4. (v0.4) LLM fallback — not implemented yet.
    ///
    /// # Errors
    ///
    /// Returns [`HsPredictError::LowConfidenceNoLlm`] when no tier produces an
    /// acceptable result. Its `confidence` field carries the best confidence
    /// of a SMILES-derived heading that was rejected for being below the
    /// threshold, or `0.0` when there was no candidate at all.
    pub fn classify(&self, product: &ProductDescription) -> Result<HsPrediction> {
        if let Some(ref cas) = product.identifier.cas {
            // ── Priority 1: User-provided mappings ────────────────────
            if let Some(hs_code) = self.user_mappings.get(cas.as_str()) {
                let jp = find_jp_rule(hs_code);
                return Ok(HsPrediction {
                    hs_code: hs_code.clone(),
                    heading_description: String::new(),
                    confidence: 1.0,
                    source: PredictionSource::UserMapping,
                    notes: vec!["From user-provided mapping".to_string()],
                    alternatives: vec![],
                    recommended_action: RecommendedAction::Accept,
                    jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
                    jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
                });
            }

            // ── Priority 2: Embedded static rule table ────────────────
            if let Some(rule) = find_best_rule(
                cas,
                product.physical_form.as_ref(),
                product.purity_pct,
            ) {
                let action = self.recommended_action(rule.confidence);
                let jp = find_jp_rule(rule.hs_code);
                return Ok(HsPrediction {
                    hs_code: rule.hs_code.to_string(),
                    heading_description: rule.heading_description.to_string(),
                    confidence: rule.confidence,
                    source: PredictionSource::EmbeddedRule {
                        rule_id: format!("{}:{}", rule.cas, rule.hs_code),
                    },
                    notes: self.build_notes(product),
                    alternatives: vec![],
                    recommended_action: action,
                    jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
                    jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
                });
            }
        }

        // Best confidence of a candidate that was found but rejected; it is
        // reported in the error so callers can tell "nothing matched" apart
        // from "something matched, but too weakly".
        let mut best_rejected_confidence: f32 = 0.0;

        // ── Priority 3: SMILES-based rule engine ─────────────────────────
        if let Some(ref smiles) = product.identifier.smiles {
            if let Some(classification) = crate::smiles::classify_smiles(smiles) {
                let hint = &classification.heading_hint;
                // Only emit a result when we have at least a 4-digit heading
                // and confidence meets the LLM-required threshold.
                if let Some(heading) = hint.heading {
                    if hint.confidence >= self.config.confidence_threshold_llm_required {
                        // Pad to 6 digits with "00" sub-heading (best guess)
                        let hs_code = format!("{:04}00", heading);
                        let jp = find_jp_rule(&hs_code);
                        let action = self.recommended_action(hint.confidence);

                        let mut notes = self.build_notes(product);
                        notes.push(
                            "Heading is derived from SMILES functional-group analysis. \
                             Sub-heading (last two digits) is a placeholder — \
                             verify the exact 6-digit code with the product specification."
                                .to_string(),
                        );

                        let matched_rules: Vec<String> = classification
                            .functional_groups
                            .iter()
                            .map(|g| g.label().to_string())
                            .collect();

                        return Ok(HsPrediction {
                            hs_code,
                            heading_description: hint.rationale.to_string(),
                            confidence: hint.confidence,
                            source: PredictionSource::RuleEngine { matched_rules },
                            notes,
                            alternatives: vec![],
                            recommended_action: action,
                            jp_tariff_code: jp.map(|r| r.jp_code.to_string()),
                            jp_tariff_year: jp.map(|_| JP_TARIFF_YEAR),
                        });
                    }

                    // A heading was found but it is too uncertain to return.
                    // `f32::max` ignores NaN and keeps the value at >= 0.0.
                    best_rejected_confidence = best_rejected_confidence.max(hint.confidence);
                }
            }
        }

        // ── Priority 4: LLM fallback (v0.4 placeholder) ───────────────
        // TODO: implement LLM client in v0.4.

        // No tier produced an acceptable result — report it as an error,
        // together with the threshold that would have had to be met.
        Err(HsPredictError::LowConfidenceNoLlm {
            confidence: best_rejected_confidence,
            threshold: self.config.confidence_threshold_llm_required,
        })
    }

    // ─── Private helpers ──────────────────────────────────────────────

    /// Map a confidence value onto the action recommended to the caller.
    ///
    /// Both comparisons are inclusive: `>= confidence_threshold_direct` gives
    /// `Accept`, otherwise `>= confidence_threshold_llm_required` gives
    /// `VerifyWithLlm`, and everything else (including NaN) gives
    /// `ExpertReview`.
    fn recommended_action(&self, confidence: f32) -> RecommendedAction {
        if confidence >= self.config.confidence_threshold_direct {
            RecommendedAction::Accept
        } else if confidence >= self.config.confidence_threshold_llm_required {
            RecommendedAction::VerifyWithLlm
        } else {
            RecommendedAction::ExpertReview
        }
    }

    /// Build supplementary notes about shape / purity caveats.
    ///
    /// Emits at most two notes: one when the physical form is missing,
    /// `Unknown`, or a solution without a w/w concentration, and one when the
    /// purity is not given. Returns an empty vector when neither applies.
    fn build_notes(&self, product: &ProductDescription) -> Vec<String> {
        let mut notes = Vec::new();

        match &product.physical_form {
            None | Some(PhysicalForm::Unknown) => {
                notes.push(
                    "Physical form not specified — the HS subheading may differ \
                     (e.g. solid vs. solution).".to_string(),
                );
            }
            Some(PhysicalForm::Solution { concentration_pct_ww: None, .. }) => {
                notes.push(
                    "Solution concentration not specified — subheading may differ \
                     (e.g. fuming vs. standard grade).".to_string(),
                );
            }
            // Any other form is considered sufficiently specified.
            _ => {}
        }

        if product.purity_pct.is_none() {
            notes.push(
                "Purity not specified — some headings require a minimum purity threshold."
                    .to_string(),
            );
        }

        notes
    }
}