1pub mod flow;
31pub mod messages;
32pub mod question;
33pub mod state;
34
35use serde::{Deserialize, Serialize};
36
37pub use question::{Answer, QAPair, Question, QuestionStep, SessionResult};
38pub use state::{ClassificationState, PartialComponent};
39
40use crate::error::{HsPredictError, Result};
41use crate::session::flow::{
42 choice_index_to_intended_use, choice_index_to_organic_inorganic,
43 choice_index_to_physical_form, multi_choice_indices_to_functional_groups, next_question,
44};
45use crate::types::{Language, MixtureComponent, PhysicalForm, ProductDescription, SubstanceIdentifier};
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct ClassificationSession {
53 state: ClassificationState,
55 history: Vec<QAPair>,
57 current_question: Option<Question>,
59 current_step: Option<QuestionStep>,
61 language: Language,
63}
64
65impl ClassificationSession {
66 pub fn new() -> Self {
68 Self {
69 state: ClassificationState::default(),
70 history: Vec::new(),
71 current_question: None,
72 current_step: None,
73 language: Language::En,
74 }
75 }
76
77 pub fn new_ja() -> Self {
79 Self::new().with_language(Language::Ja)
80 }
81
82 pub fn with_language(mut self, language: Language) -> Self {
86 self.language = language;
87 self
88 }
89
90 pub fn start(&mut self) -> Question {
94 let (q, step) = next_question(&self.state, self.language)
95 .expect("new session should always have a first question");
96 self.current_question = Some(q.clone());
97 self.current_step = Some(step);
98 q
99 }
100
101 pub fn answer(&mut self, answer: Answer) -> Result<SessionResult> {
114 let question = self
115 .current_question
116 .clone()
117 .ok_or(HsPredictError::NoActiveQuestion)?;
118
119 self.validate_and_apply(&question, &answer)?;
121
122 self.history.push(QAPair {
124 question: question.clone(),
125 answer,
126 });
127
128 self.try_resolve_smiles();
130
131 match next_question(&self.state, self.language) {
133 Some((q, step)) => {
134 self.current_question = Some(q.clone());
135 self.current_step = Some(step);
136 Ok(SessionResult::NeedMoreInfo { next_question: q })
137 }
138 None => {
139 self.state.is_complete = true;
140 self.current_question = None;
141 self.current_step = None;
142 if self.state.confidence_estimate() < 0.25 {
143 Ok(SessionResult::RequiresLlm)
144 } else {
145 Ok(SessionResult::Ready)
146 }
147 }
148 }
149 }
150
151 pub fn to_product_description(&self) -> ProductDescription {
155 let mixture_components = if self.state.is_mixture == Some(true) {
156 Some(
157 self.state
158 .components
159 .iter()
160 .map(|c| MixtureComponent {
161 substance: c.identifier.clone(),
162 weight_fraction_pct: c.weight_fraction_pct,
163 volume_fraction_pct: None,
164 is_solvent: c.is_solvent,
165 })
166 .collect(),
167 )
168 } else {
169 None
170 };
171
172 ProductDescription {
173 identifier: self.state.identifier.clone(),
174 physical_form: self.state.physical_form.clone(),
175 purity_pct: self.state.purity_pct,
176 purity_type: None,
177 mixture_components,
178 intended_use: self.state.intended_use.clone(),
179 additional_context: None,
180 }
181 }
182
183 pub fn state(&self) -> &ClassificationState {
185 &self.state
186 }
187
188 pub fn history(&self) -> &[QAPair] {
190 &self.history
191 }
192
193 pub fn question_count(&self) -> usize {
195 self.history.len()
196 }
197
198 pub fn is_complete(&self) -> bool {
200 self.state.is_complete
201 }
202
203 pub fn language(&self) -> Language {
205 self.language
206 }
207
208 pub fn current_step(&self) -> Option<QuestionStep> {
210 self.current_step
211 }
212
213 fn validate_and_apply(&mut self, question: &Question, answer: &Answer) -> Result<()> {
216 match (question, answer) {
217 (Question::Text { .. }, Answer::Text(text)) => {
219 self.apply_identifier_input(text);
220 }
221 (Question::Text { .. }, Answer::Skip) => {
222 if !self.state.has_identifier()
223 && self.current_step != Some(QuestionStep::ComponentIdentifier)
224 {
225 return Err(HsPredictError::MissingIdentifier);
226 }
227 }
228
229 (Question::YesNo { .. }, Answer::YesNo(val)) => {
231 self.apply_yes_no(*val);
232 }
233
234 (Question::Number { min, max, .. }, Answer::Number(val)) => {
236 if *val < *min || *val > *max {
237 return Err(HsPredictError::NumberOutOfRange {
238 value: *val,
239 min: *min,
240 max: *max,
241 });
242 }
243 self.apply_number(*val);
244 }
245
246 (Question::Choice { options, .. }, Answer::Choice(idx)) => {
248 if *idx >= options.len() {
249 return Err(HsPredictError::InvalidChoiceIndex {
250 index: *idx,
251 max: options.len() - 1,
252 });
253 }
254 self.apply_choice(*idx);
255 }
256
257 (Question::MultiChoice { options, .. }, Answer::MultiChoice(indices)) => {
259 for &idx in indices {
260 if idx >= options.len() {
261 return Err(HsPredictError::InvalidChoiceIndex {
262 index: idx,
263 max: options.len() - 1,
264 });
265 }
266 }
267 self.apply_multi_choice(indices);
268 }
269
270 _ => {
272 return Err(HsPredictError::AnswerTypeMismatch {
273 expected: question_kind_name(question),
274 got: answer.kind_name(),
275 });
276 }
277 }
278 Ok(())
279 }
280
281 fn apply_identifier_input(&mut self, input: &str) {
283 let input = input.trim();
284
285 let in_mixture = self.state.is_mixture == Some(true)
286 && self.state.current_component_index < self.state.component_count.unwrap_or(0);
287
288 if in_mixture {
289 let idx = self.state.current_component_index;
290 while self.state.components.len() <= idx {
291 self.state.components.push(PartialComponent::default());
292 }
293 self.state.components[idx].identifier = parse_identifier(input);
294 } else {
295 self.state.identifier = parse_identifier(input);
296 }
297 }
298
299 fn apply_yes_no(&mut self, val: bool) {
300 match self.current_step {
301 Some(QuestionStep::IsMixture) => {
302 self.state.is_mixture = Some(val);
303 }
304 _ => {}
305 }
306 }
307
308 fn apply_number(&mut self, val: f64) {
309 match self.current_step {
310 Some(QuestionStep::ComponentCount) => {
311 self.state.component_count = Some(val as usize);
312 }
313 Some(QuestionStep::ComponentFraction) => {
314 let idx = self.state.current_component_index;
315 if idx < self.state.components.len() {
316 self.state.components[idx].weight_fraction_pct =
317 if val > 0.0 { Some(val) } else { None };
318 self.state.current_component_index += 1;
319 }
320 }
321 Some(QuestionStep::SolutionConcentration) => {
322 if let Some(PhysicalForm::Solution { concentration_pct_ww, .. }) =
323 &mut self.state.physical_form
324 {
325 *concentration_pct_ww = if val > 0.0 { Some(val) } else { None };
326 }
327 }
328 _ => {}
329 }
330 }
331
332 fn apply_choice(&mut self, idx: usize) {
333 match self.current_step {
334 Some(QuestionStep::PhysicalForm) => {
335 self.state.physical_form = Some(choice_index_to_physical_form(idx));
336 }
337 Some(QuestionStep::IntendedUse) => {
338 self.state.intended_use = Some(choice_index_to_intended_use(idx));
339 }
340 Some(QuestionStep::OrganicInorganic) => {
341 self.state.organic_inorganic = Some(choice_index_to_organic_inorganic(idx));
342 }
343 _ => {}
344 }
345 }
346
347 fn apply_multi_choice(&mut self, indices: &[usize]) {
348 self.state.detected_functional_groups = multi_choice_indices_to_functional_groups(indices);
349 }
350
351 fn try_resolve_smiles(&mut self) {
353 if self.state.identifier.smiles.is_none() {
354 if let Some(ref iupac) = self.state.identifier.iupac_name.clone() {
355 if let Some(smiles) = resolve_iupac_to_smiles(iupac) {
356 self.state.identifier.smiles = Some(smiles);
357 }
358 }
359 }
360 for comp in &mut self.state.components {
361 if comp.identifier.smiles.is_none() {
362 if let Some(ref iupac) = comp.identifier.iupac_name.clone() {
363 if let Some(smiles) = resolve_iupac_to_smiles(iupac) {
364 comp.identifier.smiles = Some(smiles);
365 }
366 }
367 }
368 }
369 }
370}
371
372impl Default for ClassificationSession {
373 fn default() -> Self {
374 Self::new()
375 }
376}
377
378fn parse_identifier(input: &str) -> SubstanceIdentifier {
381 let s = input.trim();
382
383 if is_cas_format(s) {
384 return SubstanceIdentifier::from_cas(s);
385 }
386 if is_inchi_key_format(s) {
387 return SubstanceIdentifier {
388 inchi_key: Some(s.to_string()),
389 ..Default::default()
390 };
391 }
392 if s.starts_with("InChI=") {
393 return SubstanceIdentifier {
394 inchi: Some(s.to_string()),
395 ..Default::default()
396 };
397 }
398 if !s.contains(' ')
399 && s.chars()
400 .any(|c| matches!(c, '(' | ')' | '=' | '#' | '[' | ']' | '+' | '-'))
401 {
402 return SubstanceIdentifier::from_smiles(s);
403 }
404 SubstanceIdentifier::from_iupac_name(s)
405}
406
/// Returns `true` when `s` has the shape of a CAS Registry Number:
/// `NNNNNNN-NN-N`, i.e. 2 to 7 digits, two digits and one check digit,
/// separated by hyphens.
///
/// Only the shape is checked; the check digit itself is not verified.
fn is_cas_format(s: &str) -> bool {
    /// `part` is made of ASCII digits only and its length lies in `lengths`.
    fn is_digits(part: &str, lengths: std::ops::RangeInclusive<usize>) -> bool {
        lengths.contains(&part.len()) && part.bytes().all(|b| b.is_ascii_digit())
    }

    let mut parts = s.split('-');
    // The fourth `next()` must yield nothing: exactly three segments are allowed.
    match (parts.next(), parts.next(), parts.next(), parts.next()) {
        (Some(first), Some(second), Some(check), None) => {
            is_digits(first, 2..=7) && is_digits(second, 2..=2) && is_digits(check, 1..=1)
        }
        _ => false,
    }
}
417
/// Returns `true` when `s` has the shape of an InChIKey: three hyphen-separated
/// blocks of 14, 10 and 1 ASCII uppercase letters.
fn is_inchi_key_format(s: &str) -> bool {
    const BLOCK_LENGTHS: [usize; 3] = [14, 10, 1];

    let mut blocks = s.split('-');
    for expected in BLOCK_LENGTHS.iter() {
        match blocks.next() {
            Some(block)
                if block.len() == *expected
                    && block.bytes().all(|b| b.is_ascii_uppercase()) => {}
            _ => return false,
        }
    }
    // Anything after the third block disqualifies the input.
    blocks.next().is_none()
}
426
427fn resolve_iupac_to_smiles(iupac_name: &str) -> Option<String> {
428 chem_name_resolver::resolve(iupac_name)
429 .ok()
430 .map(|r| r.smiles)
431}
432
433fn question_kind_name(q: &Question) -> &'static str {
434 match q {
435 Question::Text { .. } => "text",
436 Question::Choice { .. } => "choice",
437 Question::YesNo { .. } => "yes_no",
438 Question::Number { .. } => "number",
439 Question::MultiChoice { .. } => "multi_choice",
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{IntendedUse, OrganicInorganic, PhysicalForm};

    /// Unwraps the follow-up question, failing the test for any other result.
    fn next_q(result: SessionResult) -> Question {
        match result {
            SessionResult::NeedMoreInfo { next_question: q } => q,
            other => panic!("expected NeedMoreInfo, got {:?}", std::mem::discriminant(&other)),
        }
    }

    /// Shorthand for a free-text answer.
    fn text(value: &str) -> Answer {
        Answer::Text(value.to_string())
    }

    /// Submits every answer in order and requires each one to be accepted.
    fn feed(session: &mut ClassificationSession, answers: Vec<Answer>) {
        for answer in answers {
            session.answer(answer).unwrap();
        }
    }

    #[test]
    fn session_starts_with_identifier_question() {
        let mut session = ClassificationSession::new();
        let first = session.start();
        assert!(matches!(first, Question::Text { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::Identifier));
    }

    #[test]
    fn session_pure_cas_inorganic_full_flow() {
        let mut session = ClassificationSession::new();
        session.start();

        // Identifier -> mixture question.
        let outcome = session.answer(text("1310-73-2")).unwrap();
        assert!(matches!(next_q(outcome), Question::YesNo { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::IsMixture));

        // Not a mixture -> physical form.
        let outcome = session.answer(Answer::YesNo(false)).unwrap();
        assert!(matches!(next_q(outcome), Question::Choice { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::PhysicalForm));

        // Physical form -> intended use.
        let outcome = session.answer(Answer::Choice(0)).unwrap();
        assert!(matches!(next_q(outcome), Question::Choice { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::IntendedUse));

        // Intended use -> organic/inorganic.
        let outcome = session.answer(Answer::Choice(0)).unwrap();
        assert!(matches!(next_q(outcome), Question::Choice { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::OrganicInorganic));

        // Last answer finishes the session.
        let outcome = session.answer(Answer::Choice(1)).unwrap();
        assert!(matches!(outcome, SessionResult::Ready));

        let product = session.to_product_description();
        assert_eq!(product.identifier.cas.as_deref(), Some("1310-73-2"));
        assert!(matches!(product.physical_form, Some(PhysicalForm::Solid)));
        assert_eq!(product.intended_use, Some(IntendedUse::Industrial));
        assert_eq!(session.question_count(), 5);
        assert!(session.is_complete());
    }

    #[test]
    fn session_smiles_input_skips_organic_inorganic_question() {
        let mut session = ClassificationSession::new();
        session.start();

        let outcome = session.answer(text("[Na+].[OH-]")).unwrap();
        assert!(matches!(next_q(outcome), Question::YesNo { .. }));

        let outcome = session.answer(Answer::YesNo(false)).unwrap();
        assert!(matches!(next_q(outcome), Question::Choice { .. }));

        let outcome = session.answer(Answer::Choice(3)).unwrap();
        assert!(matches!(next_q(outcome), Question::Choice { .. }));

        // The session is done after four answers.
        let outcome = session.answer(Answer::Choice(0)).unwrap();
        assert!(matches!(outcome, SessionResult::Ready));
        assert_eq!(session.question_count(), 4);

        let product = session.to_product_description();
        assert!(product.identifier.smiles.is_some());
    }

    #[test]
    fn session_organic_cas_asks_functional_groups() {
        let mut session = ClassificationSession::new();
        session.start();

        feed(
            &mut session,
            vec![
                text("108-88-3"),
                Answer::YesNo(false),
                Answer::Choice(0),
                Answer::Choice(0),
            ],
        );
        let outcome = session.answer(Answer::Choice(0)).unwrap();
        let follow_up = next_q(outcome);
        assert!(matches!(follow_up, Question::MultiChoice { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::FunctionalGroups));

        let outcome = session.answer(Answer::MultiChoice(vec![10])).unwrap();
        assert!(matches!(outcome, SessionResult::Ready));

        let state = session.state();
        assert_eq!(state.organic_inorganic, Some(OrganicInorganic::Organic));
        assert!(state.detected_functional_groups.contains(&"aromatic".to_string()));
    }

    #[test]
    fn session_solution_asks_concentration() {
        let mut session = ClassificationSession::new();
        session.start();

        feed(&mut session, vec![text("7647-01-0"), Answer::YesNo(false)]);

        // Choosing the solution form triggers the concentration question.
        let outcome = session.answer(Answer::Choice(4)).unwrap();
        let follow_up = next_q(outcome);
        assert!(matches!(follow_up, Question::Number { .. }));
        assert_eq!(session.current_step(), Some(QuestionStep::SolutionConcentration));

        let outcome = session.answer(Answer::Number(35.0)).unwrap();
        assert!(matches!(next_q(outcome), Question::Choice { .. }));

        session.answer(Answer::Choice(0)).unwrap();
        let outcome = session.answer(Answer::Choice(1)).unwrap();
        assert!(matches!(outcome, SessionResult::Ready));

        let product = session.to_product_description();
        assert_eq!(
            product.physical_form,
            Some(PhysicalForm::Solution {
                solvent: None,
                concentration_pct_ww: Some(35.0),
            })
        );
    }

    #[test]
    fn session_mixture_two_components() {
        let mut session = ClassificationSession::new();
        session.start();

        feed(&mut session, vec![text("7664-93-9"), Answer::YesNo(true)]);
        assert_eq!(session.current_step(), Some(QuestionStep::ComponentCount));

        session.answer(Answer::Number(2.0)).unwrap();
        assert_eq!(session.current_step(), Some(QuestionStep::ComponentIdentifier));

        // First component.
        session.answer(text("7664-93-9")).unwrap();
        assert_eq!(session.current_step(), Some(QuestionStep::ComponentFraction));

        session.answer(Answer::Number(70.0)).unwrap();
        assert_eq!(session.current_step(), Some(QuestionStep::ComponentIdentifier));

        // Second component.
        session.answer(text("7732-18-5")).unwrap();
        assert_eq!(session.current_step(), Some(QuestionStep::ComponentFraction));

        let outcome = session.answer(Answer::Number(30.0)).unwrap();
        assert!(matches!(outcome, SessionResult::Ready | SessionResult::RequiresLlm));

        let product = session.to_product_description();
        let comps = product.mixture_components.unwrap();
        assert_eq!(comps.len(), 2);
        assert_eq!(comps[0].substance.cas.as_deref(), Some("7664-93-9"));
        assert_eq!(comps[0].weight_fraction_pct, Some(70.0));
        assert_eq!(comps[1].substance.cas.as_deref(), Some("7732-18-5"));
        assert_eq!(comps[1].weight_fraction_pct, Some(30.0));
    }

    #[test]
    fn error_no_active_question_before_start() {
        let mut session = ClassificationSession::new();
        let err = session.answer(text("1310-73-2")).unwrap_err();
        assert!(matches!(err, HsPredictError::NoActiveQuestion));
    }

    #[test]
    fn error_answer_type_mismatch() {
        let mut session = ClassificationSession::new();
        session.start();

        // The first question expects text, not yes/no.
        let err = session.answer(Answer::YesNo(true)).unwrap_err();
        assert!(matches!(err, HsPredictError::AnswerTypeMismatch { .. }));
    }

    #[test]
    fn error_choice_index_out_of_range() {
        let mut session = ClassificationSession::new();
        session.start();
        session.answer(text("1310-73-2")).unwrap();
        session.answer(Answer::YesNo(false)).unwrap();

        let err = session.answer(Answer::Choice(99)).unwrap_err();
        assert!(matches!(err, HsPredictError::InvalidChoiceIndex { .. }));
    }

    #[test]
    fn error_number_out_of_range() {
        let mut session = ClassificationSession::new();
        session.start();
        session.answer(text("1310-73-2")).unwrap();
        session.answer(Answer::YesNo(true)).unwrap();

        let err = session.answer(Answer::Number(1.0)).unwrap_err();
        assert!(matches!(err, HsPredictError::NumberOutOfRange { .. }));
    }

    #[test]
    fn japanese_session_prompts_are_in_japanese() {
        let mut session = ClassificationSession::new_ja();
        let first = session.start();
        // At least one non-ASCII character must appear in the prompt.
        assert!(first.prompt().chars().any(|c| c as u32 > 0x7F));
    }

    #[test]
    fn japanese_session_completes_same_as_english() {
        let mut session = ClassificationSession::new_ja();
        session.start();

        feed(
            &mut session,
            vec![
                text("1310-73-2"),
                Answer::YesNo(false),
                Answer::Choice(0),
                Answer::Choice(0),
            ],
        );
        let outcome = session.answer(Answer::Choice(1)).unwrap();
        assert!(matches!(outcome, SessionResult::Ready));

        let product = session.to_product_description();
        assert_eq!(product.identifier.cas.as_deref(), Some("1310-73-2"));
    }

    #[test]
    fn session_serializes_and_deserializes() {
        let mut session = ClassificationSession::new();
        session.start();
        session.answer(text("1310-73-2")).unwrap();

        let json = serde_json::to_string(&session).unwrap();
        let restored: ClassificationSession = serde_json::from_str(&json).unwrap();

        assert_eq!(
            restored.state().identifier.cas.as_deref(),
            Some("1310-73-2")
        );
        assert_eq!(restored.language(), Language::En);
        assert_eq!(restored.current_step(), Some(QuestionStep::IsMixture));
    }
}