1pub mod flow;
31pub mod messages;
32pub mod question;
33pub mod state;
34
35use serde::{Deserialize, Serialize};
36
37pub use question::{Answer, QAPair, Question, QuestionStep, SessionResult};
38pub use state::{ClassificationState, PartialComponent};
39
40use crate::error::{HsPredictError, Result};
41use crate::session::flow::{
42 choice_index_to_intended_use, choice_index_to_organic_inorganic,
43 choice_index_to_physical_form, multi_choice_indices_to_functional_groups, next_question,
44};
45use crate::types::{Language, MixtureComponent, PhysicalForm, ProductDescription, SubstanceIdentifier};
46
/// Interactive question-and-answer session that incrementally fills in a
/// [`ClassificationState`] and can be converted into a [`ProductDescription`]
/// via `to_product_description`.
///
/// Derives serde `Serialize`/`Deserialize`, so an in-progress session can be
/// persisted (e.g. as JSON) and resumed later.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationSession {
    /// Structured data accumulated from the accepted answers.
    state: ClassificationState,
    /// Every answered question paired with its accepted answer, oldest first.
    history: Vec<QAPair>,
    /// Question currently awaiting an answer; `None` before `start` and after
    /// the flow has completed.
    current_question: Option<Question>,
    /// Flow step the current question belongs to; used to decide which state
    /// field an answer is written into.
    current_step: Option<QuestionStep>,
    /// Language in which questions are generated.
    language: Language,
}
64
65impl ClassificationSession {
66 pub fn new() -> Self {
68 Self {
69 state: ClassificationState::default(),
70 history: Vec::new(),
71 current_question: None,
72 current_step: None,
73 language: Language::En,
74 }
75 }
76
77 pub fn new_ja() -> Self {
79 Self::new().with_language(Language::Ja)
80 }
81
82 pub fn with_language(mut self, language: Language) -> Self {
86 self.language = language;
87 self
88 }
89
90 pub fn start(&mut self) -> Question {
100 let (q, step) = next_question(&self.state, self.language)
101 .unwrap_or_else(|| {
102 let prompt = "Enter a product identifier (CAS number, SMILES, or IUPAC name):"
104 .to_string();
105 (
106 Question::Text { prompt, example: Some("1310-73-2".to_string()) },
107 QuestionStep::Identifier,
108 )
109 });
110 self.current_question = Some(q.clone());
111 self.current_step = Some(step);
112 q
113 }
114
115 pub fn answer(&mut self, answer: Answer) -> Result<SessionResult> {
128 let question = self
129 .current_question
130 .clone()
131 .ok_or(HsPredictError::NoActiveQuestion)?;
132
133 self.validate_and_apply(&question, &answer)?;
135
136 self.history.push(QAPair {
138 question: question.clone(),
139 answer,
140 });
141
142 self.try_resolve_smiles();
144
145 match next_question(&self.state, self.language) {
147 Some((q, step)) => {
148 self.current_question = Some(q.clone());
149 self.current_step = Some(step);
150 Ok(SessionResult::NeedMoreInfo { next_question: q })
151 }
152 None => {
153 self.state.is_complete = true;
154 self.current_question = None;
155 self.current_step = None;
156 if self.state.confidence_estimate() < 0.25 {
157 Ok(SessionResult::RequiresLlm)
158 } else {
159 Ok(SessionResult::Ready)
160 }
161 }
162 }
163 }
164
165 pub fn to_product_description(&self) -> ProductDescription {
169 let mixture_components = if self.state.is_mixture == Some(true) {
170 Some(
171 self.state
172 .components
173 .iter()
174 .map(|c| MixtureComponent {
175 substance: c.identifier.clone(),
176 weight_fraction_pct: c.weight_fraction_pct,
177 volume_fraction_pct: None,
178 is_solvent: c.is_solvent,
179 })
180 .collect(),
181 )
182 } else {
183 None
184 };
185
186 ProductDescription {
187 identifier: self.state.identifier.clone(),
188 physical_form: self.state.physical_form.clone(),
189 purity_pct: self.state.purity_pct,
190 purity_type: None,
191 mixture_components,
192 intended_use: self.state.intended_use.clone(),
193 additional_context: None,
194 }
195 }
196
197 pub fn state(&self) -> &ClassificationState {
199 &self.state
200 }
201
202 pub fn history(&self) -> &[QAPair] {
204 &self.history
205 }
206
207 pub fn question_count(&self) -> usize {
209 self.history.len()
210 }
211
212 pub fn is_complete(&self) -> bool {
214 self.state.is_complete
215 }
216
217 pub fn language(&self) -> Language {
219 self.language
220 }
221
222 pub fn current_step(&self) -> Option<QuestionStep> {
224 self.current_step
225 }
226
227 fn validate_and_apply(&mut self, question: &Question, answer: &Answer) -> Result<()> {
230 match (question, answer) {
231 (Question::Text { .. }, Answer::Text(text)) => {
233 self.apply_identifier_input(text);
234 }
235 (Question::Text { .. }, Answer::Skip) => {
236 if !self.state.has_identifier()
237 && self.current_step != Some(QuestionStep::ComponentIdentifier)
238 {
239 return Err(HsPredictError::MissingIdentifier);
240 }
241 }
242
243 (Question::YesNo { .. }, Answer::YesNo(val)) => {
245 self.apply_yes_no(*val);
246 }
247
248 (Question::Number { min, max, .. }, Answer::Number(val)) => {
250 if *val < *min || *val > *max {
251 return Err(HsPredictError::NumberOutOfRange {
252 value: *val,
253 min: *min,
254 max: *max,
255 });
256 }
257 self.apply_number(*val);
258 }
259
260 (Question::Choice { options, .. }, Answer::Choice(idx)) => {
262 if *idx >= options.len() {
263 return Err(HsPredictError::InvalidChoiceIndex {
264 index: *idx,
265 max: options.len() - 1,
266 });
267 }
268 self.apply_choice(*idx);
269 }
270
271 (Question::MultiChoice { options, .. }, Answer::MultiChoice(indices)) => {
273 for &idx in indices {
274 if idx >= options.len() {
275 return Err(HsPredictError::InvalidChoiceIndex {
276 index: idx,
277 max: options.len() - 1,
278 });
279 }
280 }
281 self.apply_multi_choice(indices);
282 }
283
284 _ => {
286 return Err(HsPredictError::AnswerTypeMismatch {
287 expected: question_kind_name(question),
288 got: answer.kind_name(),
289 });
290 }
291 }
292 Ok(())
293 }
294
295 fn apply_identifier_input(&mut self, input: &str) {
297 let input = input.trim();
298
299 let in_mixture = self.state.is_mixture == Some(true)
300 && self.state.current_component_index < self.state.component_count.unwrap_or(0);
301
302 if in_mixture {
303 let idx = self.state.current_component_index;
304 while self.state.components.len() <= idx {
305 self.state.components.push(PartialComponent::default());
306 }
307 self.state.components[idx].identifier = parse_identifier(input);
308 } else {
309 self.state.identifier = parse_identifier(input);
310 }
311 }
312
313 fn apply_yes_no(&mut self, val: bool) {
314 if let Some(QuestionStep::IsMixture) = self.current_step {
315 self.state.is_mixture = Some(val);
316 }
317 }
318
319 fn apply_number(&mut self, val: f64) {
320 match self.current_step {
321 Some(QuestionStep::ComponentCount) => {
322 self.state.component_count = Some(val as usize);
323 }
324 Some(QuestionStep::ComponentFraction) => {
325 let idx = self.state.current_component_index;
326 if idx < self.state.components.len() {
327 self.state.components[idx].weight_fraction_pct =
328 if val > 0.0 { Some(val) } else { None };
329 self.state.current_component_index += 1;
330 }
331 }
332 Some(QuestionStep::SolutionConcentration) => {
333 if let Some(PhysicalForm::Solution { concentration_pct_ww, .. }) =
334 &mut self.state.physical_form
335 {
336 *concentration_pct_ww = if val > 0.0 { Some(val) } else { None };
337 }
338 }
339 _ => {}
340 }
341 }
342
343 fn apply_choice(&mut self, idx: usize) {
344 match self.current_step {
345 Some(QuestionStep::PhysicalForm) => {
346 self.state.physical_form = Some(choice_index_to_physical_form(idx));
347 }
348 Some(QuestionStep::IntendedUse) => {
349 self.state.intended_use = Some(choice_index_to_intended_use(idx));
350 }
351 Some(QuestionStep::OrganicInorganic) => {
352 self.state.organic_inorganic = Some(choice_index_to_organic_inorganic(idx));
353 }
354 _ => {}
355 }
356 }
357
358 fn apply_multi_choice(&mut self, indices: &[usize]) {
359 self.state.detected_functional_groups = multi_choice_indices_to_functional_groups(indices);
360 }
361
362 fn try_resolve_smiles(&mut self) {
364 if self.state.identifier.smiles.is_none() {
365 if let Some(ref iupac) = self.state.identifier.iupac_name.clone() {
366 if let Some(smiles) = resolve_iupac_to_smiles(iupac) {
367 self.state.identifier.smiles = Some(smiles);
368 }
369 }
370 }
371 for comp in &mut self.state.components {
372 if comp.identifier.smiles.is_none() {
373 if let Some(ref iupac) = comp.identifier.iupac_name.clone() {
374 if let Some(smiles) = resolve_iupac_to_smiles(iupac) {
375 comp.identifier.smiles = Some(smiles);
376 }
377 }
378 }
379 }
380 }
381}
382
383impl Default for ClassificationSession {
384 fn default() -> Self {
385 Self::new()
386 }
387}
388
389fn parse_identifier(input: &str) -> SubstanceIdentifier {
392 let s = input.trim();
393
394 if is_cas_format(s) {
395 return SubstanceIdentifier::from_cas(s);
396 }
397 if is_inchi_key_format(s) {
398 return SubstanceIdentifier {
399 inchi_key: Some(s.to_string()),
400 ..Default::default()
401 };
402 }
403 if s.starts_with("InChI=") {
404 return SubstanceIdentifier {
405 inchi: Some(s.to_string()),
406 ..Default::default()
407 };
408 }
409 if !s.contains(' ')
410 && s.chars()
411 .any(|c| matches!(c, '(' | ')' | '=' | '#' | '[' | ']' | '+' | '-'))
412 {
413 return SubstanceIdentifier::from_smiles(s);
414 }
415 SubstanceIdentifier::from_iupac_name(s)
416}
417
/// Returns `true` when `s` has the shape of a CAS Registry Number:
/// three hyphen-separated groups of ASCII digits with lengths `>= 2`, `2`
/// and `1` (e.g. `1310-73-2`). The check digit is not verified.
fn is_cas_format(s: &str) -> bool {
    let digits_only = |group: &str| group.bytes().all(|b| b.is_ascii_digit());
    let mut groups = s.split('-');
    match (groups.next(), groups.next(), groups.next(), groups.next()) {
        (Some(head), Some(middle), Some(check), None) => {
            head.len() >= 2
                && middle.len() == 2
                && check.len() == 1
                && digits_only(head)
                && digits_only(middle)
                && digits_only(check)
        }
        _ => false,
    }
}
428
/// Returns `true` when `s` has the shape of a standard InChIKey: three
/// hyphen-separated blocks of 14, 10 and 1 uppercase ASCII letters.
fn is_inchi_key_format(s: &str) -> bool {
    let block_lengths: Vec<usize> = s.split('-').map(str::len).collect();
    matches!(block_lengths.as_slice(), [14, 10, 1])
        && s.chars().all(|c| c == '-' || c.is_ascii_uppercase())
}
437
438fn resolve_iupac_to_smiles(iupac_name: &str) -> Option<String> {
439 chem_name_resolver::resolve(iupac_name)
440 .ok()
441 .map(|r| r.smiles)
442}
443
444fn question_kind_name(q: &Question) -> &'static str {
445 match q {
446 Question::Text { .. } => "text",
447 Question::Choice { .. } => "choice",
448 Question::YesNo { .. } => "yes_no",
449 Question::Number { .. } => "number",
450 Question::MultiChoice { .. } => "multi_choice",
451 }
452}
453
// Scenario tests that drive a `ClassificationSession` through complete
// question/answer flows and check the resulting state and error handling.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{IntendedUse, OrganicInorganic, PhysicalForm};

    /// Unwraps the follow-up question from a `NeedMoreInfo` result; any other
    /// `SessionResult` variant fails the test with a panic.
    fn next_q(result: SessionResult) -> Question {
        match result {
            SessionResult::NeedMoreInfo { next_question } => next_question,
            other => panic!("expected NeedMoreInfo, got {:?}", std::mem::discriminant(&other)),
        }
    }

    // `start()` must open with a free-text identifier question.
    #[test]
    fn session_starts_with_identifier_question() {
        let mut session = ClassificationSession::new();
        let q = session.start();
        assert!(matches!(q, Question::Text { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::Identifier));
    }

    // Pure substance given by CAS (1310-73-2, sodium hydroxide): identifier ->
    // is-mixture -> physical form -> intended use -> organic/inorganic -> Ready.
    #[test]
    fn session_pure_cas_inorganic_full_flow() {
        let mut session = ClassificationSession::new();
        session.start();

        let r = session.answer(Answer::Text("1310-73-2".to_string())).unwrap();
        assert!(matches!(next_q(r), Question::YesNo { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::IsMixture));

        let r = session.answer(Answer::YesNo(false)).unwrap();
        assert!(matches!(next_q(r), Question::Choice { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::PhysicalForm));

        // Physical-form option 0 maps to `PhysicalForm::Solid` (asserted below).
        let r = session.answer(Answer::Choice(0)).unwrap();
        assert!(matches!(next_q(r), Question::Choice { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::IntendedUse));

        // Intended-use option 0 maps to `IntendedUse::Industrial` (asserted below).
        let r = session.answer(Answer::Choice(0)).unwrap();
        assert!(matches!(next_q(r), Question::Choice { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::OrganicInorganic));

        // Option 1 is presumably "inorganic", per the test name.
        let r = session.answer(Answer::Choice(1)).unwrap();
        assert!(matches!(r, SessionResult::Ready));

        let product = session.to_product_description();
        assert_eq!(product.identifier.cas.as_deref(), Some("1310-73-2"));
        assert!(matches!(product.physical_form, Some(PhysicalForm::Solid)));
        assert_eq!(product.intended_use, Some(IntendedUse::Industrial));
        assert_eq!(session.question_count(), 5);
        assert!(session.is_complete());
    }

    // A SMILES identifier should finish in 4 answers, i.e. without the
    // organic/inorganic question the CAS flow above needs.
    #[test]
    fn session_smiles_input_skips_organic_inorganic_question() {
        let mut session = ClassificationSession::new();
        session.start();

        let r = session.answer(Answer::Text("[Na+].[OH-]".to_string())).unwrap();
        assert!(matches!(next_q(r), Question::YesNo { .. }));

        let r = session.answer(Answer::YesNo(false)).unwrap();
        // Physical-form question.
        assert!(matches!(next_q(r), Question::Choice { .. }));
        let r = session.answer(Answer::Choice(3)).unwrap();
        // Presumably the intended-use question; answering it completes the flow.
        assert!(matches!(next_q(r), Question::Choice { .. }));
        let r = session.answer(Answer::Choice(0)).unwrap();
        assert!(matches!(r, SessionResult::Ready));
        assert_eq!(session.question_count(), 4);

        let product = session.to_product_description();
        assert!(product.identifier.smiles.is_some());
    }

    // An organic substance (CAS 108-88-3, toluene) gets a functional-group
    // multi-choice after the organic/inorganic question.
    #[test]
    fn session_organic_cas_asks_functional_groups() {
        let mut session = ClassificationSession::new();
        session.start();

        session.answer(Answer::Text("108-88-3".to_string())).unwrap();
        session.answer(Answer::YesNo(false)).unwrap();
        session.answer(Answer::Choice(0)).unwrap();
        session.answer(Answer::Choice(0)).unwrap();
        // Organic/inorganic option 0 maps to `OrganicInorganic::Organic` (asserted below).
        let r = session.answer(Answer::Choice(0)).unwrap();
        let q = next_q(r);
        assert!(matches!(q, Question::MultiChoice { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::FunctionalGroups));

        // Functional-group option 10 yields "aromatic" (asserted below).
        let r = session.answer(Answer::MultiChoice(vec![10])).unwrap();
        assert!(matches!(r, SessionResult::Ready));

        let state = session.state();
        assert_eq!(state.organic_inorganic, Some(OrganicInorganic::Organic));
        assert!(state.detected_functional_groups.contains(&"aromatic".to_string()));
    }

    // Choosing a solution physical form triggers a concentration question, and
    // the answer is stored in `PhysicalForm::Solution::concentration_pct_ww`.
    #[test]
    fn session_solution_asks_concentration() {
        let mut session = ClassificationSession::new();
        session.start();

        session.answer(Answer::Text("7647-01-0".to_string())).unwrap();
        session.answer(Answer::YesNo(false)).unwrap();

        // Physical-form option 4 leads to the solution-concentration question.
        let r = session.answer(Answer::Choice(4)).unwrap();
        let q = next_q(r);
        assert!(matches!(q, Question::Number { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::SolutionConcentration));

        let r = session.answer(Answer::Number(35.0)).unwrap();
        assert!(matches!(next_q(r), Question::Choice { .. }));
        session.answer(Answer::Choice(0)).unwrap();
        let r = session.answer(Answer::Choice(1)).unwrap();
        assert!(matches!(r, SessionResult::Ready));

        let product = session.to_product_description();
        assert_eq!(
            product.physical_form,
            Some(PhysicalForm::Solution {
                solvent: None,
                concentration_pct_ww: Some(35.0),
            })
        );
    }

    // Two-component mixture: count, then identifier/fraction pairs per
    // component, ending with both components in the product description.
    #[test]
    fn session_mixture_two_components() {
        let mut session = ClassificationSession::new();
        session.start();

        session.answer(Answer::Text("7664-93-9".to_string())).unwrap();
        session.answer(Answer::YesNo(true)).unwrap();
        assert_eq!(session.current_step(), Some(QuestionStep::ComponentCount));

        session.answer(Answer::Number(2.0)).unwrap();
        assert_eq!(session.current_step(), Some(QuestionStep::ComponentIdentifier));

        session.answer(Answer::Text("7664-93-9".to_string())).unwrap();
        assert_eq!(session.current_step(), Some(QuestionStep::ComponentFraction));

        // Answering a fraction moves on to the next component.
        session.answer(Answer::Number(70.0)).unwrap();
        assert_eq!(session.current_step(), Some(QuestionStep::ComponentIdentifier));

        session.answer(Answer::Text("7732-18-5".to_string())).unwrap();
        assert_eq!(session.current_step(), Some(QuestionStep::ComponentFraction));

        // Either terminal result is acceptable; it depends on the confidence estimate.
        let r = session.answer(Answer::Number(30.0)).unwrap();
        assert!(matches!(r, SessionResult::Ready | SessionResult::RequiresLlm));

        let product = session.to_product_description();
        let comps = product.mixture_components.unwrap();
        assert_eq!(comps.len(), 2);
        assert_eq!(comps[0].substance.cas.as_deref(), Some("7664-93-9"));
        assert_eq!(comps[0].weight_fraction_pct, Some(70.0));
        assert_eq!(comps[1].substance.cas.as_deref(), Some("7732-18-5"));
        assert_eq!(comps[1].weight_fraction_pct, Some(30.0));
    }

    // Answering before `start()` has no active question to answer.
    #[test]
    fn error_no_active_question_before_start() {
        let mut session = ClassificationSession::new();
        let err = session.answer(Answer::Text("1310-73-2".to_string())).unwrap_err();
        assert!(matches!(err, HsPredictError::NoActiveQuestion));
    }

    // A yes/no answer to the opening text question is a type mismatch.
    #[test]
    fn error_answer_type_mismatch() {
        let mut session = ClassificationSession::new();
        session.start();
        let err = session.answer(Answer::YesNo(true)).unwrap_err();
        assert!(matches!(err, HsPredictError::AnswerTypeMismatch { .. }));
    }

    // An index past the end of the physical-form options is rejected.
    #[test]
    fn error_choice_index_out_of_range() {
        let mut session = ClassificationSession::new();
        session.start();
        session.answer(Answer::Text("1310-73-2".to_string())).unwrap();
        session.answer(Answer::YesNo(false)).unwrap();
        let err = session.answer(Answer::Choice(99)).unwrap_err();
        assert!(matches!(err, HsPredictError::InvalidChoiceIndex { .. }));
    }

    // A component count of 1 falls outside the component-count question's range.
    #[test]
    fn error_number_out_of_range() {
        let mut session = ClassificationSession::new();
        session.start();
        session.answer(Answer::Text("1310-73-2".to_string())).unwrap();
        session.answer(Answer::YesNo(true)).unwrap();
        let err = session.answer(Answer::Number(1.0)).unwrap_err();
        assert!(matches!(err, HsPredictError::NumberOutOfRange { .. }));
    }

    // Proxy check: a Japanese prompt must contain at least one non-ASCII character.
    #[test]
    fn japanese_session_prompts_are_in_japanese() {
        let mut session = ClassificationSession::new_ja();
        let q = session.start();
        assert!(q.prompt().chars().any(|c| c as u32 > 0x7F));
    }

    // Same answers as the English CAS flow must also complete in Japanese.
    #[test]
    fn japanese_session_completes_same_as_english() {
        let mut session = ClassificationSession::new_ja();
        session.start();

        session.answer(Answer::Text("1310-73-2".to_string())).unwrap();
        session.answer(Answer::YesNo(false)).unwrap();
        session.answer(Answer::Choice(0)).unwrap();
        session.answer(Answer::Choice(0)).unwrap();
        let r = session.answer(Answer::Choice(1)).unwrap();
        assert!(matches!(r, SessionResult::Ready));
        let product = session.to_product_description();
        assert_eq!(product.identifier.cas.as_deref(), Some("1310-73-2"));
    }

    // A mid-flow session survives a JSON round trip, including its state,
    // language and current step.
    #[test]
    fn session_serializes_and_deserializes() {
        let mut session = ClassificationSession::new();
        session.start();
        session.answer(Answer::Text("1310-73-2".to_string())).unwrap();

        let json = serde_json::to_string(&session).unwrap();
        let restored: ClassificationSession = serde_json::from_str(&json).unwrap();

        assert_eq!(
            restored.state().identifier.cas.as_deref(),
            Some("1310-73-2")
        );
        assert_eq!(restored.language(), Language::En);
        assert_eq!(restored.current_step(), Some(QuestionStep::IsMixture));
    }
}