// hs_predict/llm/mock.rs

1//! Mock LLM classifier for unit testing.
2//!
3//! Enabled by the **`mock`** Cargo feature (which implies `llm`).
4//!
5//! [`MockLlmClassifier`] is a deterministic stub that derives an HS code
6//! directly from the SMILES pre-analysis embedded in the prompt, or falls back
7//! to a configurable default.  It never makes a network call.
8//!
9//! # Example
10//! ```rust
11//! # #[cfg(all(feature = "llm", feature = "mock"))]
12//! # {
13//! use hs_predict::llm::{MockLlmClassifier, LlmClassifier, LlmPrompt};
14//! use hs_predict::llm::PromptBuilder;
15//! use hs_predict::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};
16//!
17//! # tokio_test::block_on(async {
18//! let product = ProductDescription {
19//!     identifier: SubstanceIdentifier {
20//!         cas: Some("64-19-7".to_string()),
21//!         smiles: Some("CC(O)=O".to_string()),
22//!         ..Default::default()
23//!     },
24//!     physical_form: Some(PhysicalForm::Liquid),
25//!     purity_pct: Some(99.5),
26//!     purity_type: None,
27//!     mixture_components: None,
28//!     intended_use: None,
29//!     additional_context: None,
30//! };
31//!
32//! let prompt = PromptBuilder::new().build(&product);
33//! let mock = MockLlmClassifier::new();
34//! let response = mock.classify(&prompt).await.unwrap();
35//! assert_eq!(response.hs_code.len(), 6);
36//! # });
37//! # }
38//! ```
39
40#[cfg(feature = "mock")]
41pub use inner::MockLlmClassifier;
42
#[cfg(feature = "mock")]
mod inner {
    use futures::future::BoxFuture;
    use crate::llm::{LlmAlternative, LlmClassifier, LlmPrompt, LlmResponse};

    /// Factor applied to the SMILES hint's confidence when the hint carries a
    /// chapter but no heading, so chapter-only answers rank below
    /// heading-level ones.
    const CHAPTER_ONLY_CONFIDENCE_FACTOR: f32 = 0.8;

    /// Deterministic mock LLM classifier for unit tests.
    ///
    /// Resolution order used by [`LlmClassifier::classify`]:
    /// 1. The prompt carries a SMILES analysis whose heading hint has a
    ///    heading: return that heading, zero-padded to four digits and
    ///    suffixed with `00` (`XXXX00`), with the hint's confidence.
    /// 2. The prompt carries a SMILES analysis but the hint has no heading:
    ///    return the hint's chapter, zero-padded to two digits and suffixed
    ///    with `0000` (`CC0000`), with the hint's confidence scaled by `0.8`.
    /// 3. The prompt carries no SMILES analysis: return `default_hs_code`
    ///    (default: `"999999"`) with `default_confidence`, plus a single
    ///    placeholder alternative `"000000"`.
    ///
    /// The mock never makes a network call and is fully `Send + Sync`.
    #[derive(Debug, Clone)]
    pub struct MockLlmClassifier {
        /// HS code returned when the prompt has no SMILES analysis.
        pub default_hs_code: String,
        /// Confidence reported together with `default_hs_code`.
        pub default_confidence: f32,
    }

    impl Default for MockLlmClassifier {
        /// Fallback code `"999999"` with confidence `0.50`.
        fn default() -> Self {
            Self {
                default_hs_code: "999999".to_string(),
                default_confidence: 0.50,
            }
        }
    }

    impl MockLlmClassifier {
        /// Create a mock with the default fallback code `"999999"`.
        #[must_use]
        pub fn new() -> Self {
            Self::default()
        }

        /// Create a mock whose fallback response uses `hs_code` and
        /// `confidence`.
        ///
        /// The fallback is only returned when the prompt has no SMILES
        /// analysis; when an analysis is present, the code derived from its
        /// heading hint is returned instead.
        #[must_use]
        pub fn with_default(hs_code: impl Into<String>, confidence: f32) -> Self {
            Self {
                default_hs_code: hs_code.into(),
                default_confidence: confidence,
            }
        }
    }

    impl LlmClassifier for MockLlmClassifier {
        /// Produce a canned response for `prompt` without any I/O.
        ///
        /// See the type-level docs for the resolution order. Every path in
        /// this implementation returns `Ok`.
        fn classify<'a>(
            &'a self,
            prompt: &'a LlmPrompt,
        ) -> BoxFuture<'a, crate::Result<LlmResponse>> {
            Box::pin(async move {
                // A SMILES analysis, when present, always wins over the
                // configured default.
                if let Some(analysis) = &prompt.smiles_analysis {
                    let hint = &analysis.heading_hint;
                    let response = match hint.heading {
                        // Heading known: pad to XXXX00.
                        // NOTE(review): the result is 6 characters only if the
                        // heading fits in four digits — confirm against the
                        // hint type.
                        Some(heading) => LlmResponse {
                            hs_code: format!("{:04}00", heading),
                            confidence: hint.confidence,
                            rationale: format!(
                                "Mock: derived from SMILES analysis ({}). \
                                 Sub-heading is a placeholder.",
                                hint.rationale
                            ),
                            alternatives: Vec::new(),
                        },
                        // Only the chapter is known: pad to CC0000 and
                        // discount the confidence.
                        None => LlmResponse {
                            hs_code: format!("{:02}0000", hint.chapter),
                            confidence: hint.confidence * CHAPTER_ONLY_CONFIDENCE_FACTOR,
                            rationale: format!(
                                "Mock: chapter-level hint only (Ch.{:02}, {}).",
                                hint.chapter, hint.rationale
                            ),
                            alternatives: Vec::new(),
                        },
                    };
                    return Ok(response);
                }

                // No analysis — return the configured default together with
                // one placeholder alternative.
                Ok(LlmResponse {
                    hs_code: self.default_hs_code.clone(),
                    confidence: self.default_confidence,
                    rationale: "Mock classifier — no SMILES analysis available.".to_string(),
                    alternatives: vec![LlmAlternative {
                        hs_code: "000000".to_string(),
                        confidence: 0.0,
                        reason: "Placeholder alternative from mock.".to_string(),
                    }],
                })
            })
        }
    }
}
139
140// ─────────────────────────────────────────────────────────────────────────────
141// Tests
142// ─────────────────────────────────────────────────────────────────────────────
143
#[cfg(all(test, feature = "mock"))]
mod tests {
    use super::MockLlmClassifier;
    use crate::llm::{LlmClassifier, PromptBuilder};
    use crate::types::{ProductDescription, SubstanceIdentifier, PhysicalForm};

    /// CAS number used by every fixture in this module (acetic acid).
    const ACETIC_ACID_CAS: &str = "64-19-7";

    /// Fixture: acetic acid described by both CAS number and SMILES.
    fn acetic_acid_product() -> ProductDescription {
        let identifier = SubstanceIdentifier {
            cas: Some(ACETIC_ACID_CAS.to_string()),
            smiles: Some("CC(O)=O".to_string()),
            iupac_name: None,
            inchi: None,
            inchi_key: None,
            cid: None,
        };
        ProductDescription {
            identifier,
            physical_form: Some(PhysicalForm::Liquid),
            purity_pct: Some(99.5),
            purity_type: None,
            mixture_components: None,
            intended_use: None,
            additional_context: None,
        }
    }

    /// Fixture: a product identified by CAS number only, with no SMILES and
    /// every optional field left unset.
    fn cas_only_product() -> ProductDescription {
        ProductDescription {
            identifier: SubstanceIdentifier::from_cas(ACETIC_ACID_CAS),
            physical_form: None,
            purity_pct: None,
            purity_type: None,
            mixture_components: None,
            intended_use: None,
            additional_context: None,
        }
    }

    /// With a SMILES string present, the code is six ASCII digits.
    #[tokio::test]
    async fn mock_smiles_based_returns_6_digits() {
        let prompt = PromptBuilder::new().build(&acetic_acid_product());
        let classifier = MockLlmClassifier::new();
        let response = classifier.classify(&prompt).await.unwrap();
        assert_eq!(response.hs_code.len(), 6, "hs_code must be 6 digits");
        assert!(response.hs_code.bytes().all(|b| b.is_ascii_digit()));
    }

    /// Acetic acid → carboxylic acid → heading 2915 → "291500".
    #[tokio::test]
    async fn mock_smiles_based_derives_chapter_29() {
        let prompt = PromptBuilder::new().build(&acetic_acid_product());
        let classifier = MockLlmClassifier::new();
        let response = classifier.classify(&prompt).await.unwrap();
        assert!(
            response.hs_code.starts_with("29"),
            "acetic acid should be Chapter 29, got {}",
            response.hs_code
        );
    }

    /// Without a SMILES string, the built-in fallback code is returned.
    #[tokio::test]
    async fn mock_no_smiles_returns_default() {
        let prompt = PromptBuilder::new().build(&cas_only_product());
        let classifier = MockLlmClassifier::new();
        let response = classifier.classify(&prompt).await.unwrap();
        assert_eq!(response.hs_code, "999999");
    }

    /// Without a SMILES string, a custom fallback code and confidence are
    /// returned as configured.
    #[tokio::test]
    async fn mock_custom_default_returned_when_no_smiles() {
        let prompt = PromptBuilder::new().build(&cas_only_product());
        let classifier = MockLlmClassifier::with_default("291511", 0.85);
        let response = classifier.classify(&prompt).await.unwrap();
        assert_eq!(response.hs_code, "291511");
        assert!((response.confidence - 0.85).abs() < 0.001);
    }

    /// With a SMILES string present, the reported confidence is positive.
    #[tokio::test]
    async fn mock_confidence_nonzero_with_smiles() {
        let prompt = PromptBuilder::new().build(&acetic_acid_product());
        let classifier = MockLlmClassifier::new();
        let response = classifier.classify(&prompt).await.unwrap();
        assert!(response.confidence > 0.0);
    }
}