use crate::io::molblock::{self, SdfFormat};
use crate::io::sdf::{SdfCoordinateMode, read_sdf_from_str_with_coordinate_mode};
use crate::{Molecule, PreparedDrawMolecule, SmilesWriteParams};
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle};
use rayon::prelude::*;
use std::collections::HashSet;
use std::fs::{self, File};
use std::io::{BufWriter, Write};
use std::path::{Component, Path, PathBuf};
use thiserror::Error;
11
/// Policy deciding what a batch operation does with records that fail.
///
/// Operations that accept a mode treat the variants as follows.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BatchErrorMode {
    /// Keep failed records while processing, then return an aggregated error if any failed.
    Raise,
    /// Keep failed records as invalid entries and return successfully.
    Keep,
    /// Drop failed records from the output and return successfully.
    Skip,
}
18
/// Describes why one record of a batch failed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BatchRecordError {
    /// Position of the failing record in the list being processed; `0` is also used for
    /// failures that concern a whole operation (e.g. creating an output directory).
    pub index: usize,
    /// The offending input (such as the SMILES text) or output path, when one is known.
    pub input: Option<String>,
    /// Name of the stage that failed, e.g. `"parse_smiles"` or `"to_sdf"`.
    pub stage: String,
    /// Short error category, e.g. `"SmilesParseError"` or `"IoError"`.
    pub error_type: String,
    /// Human-readable description of the failure.
    pub message: String,
}

impl BatchRecordError {
    /// Builds an error from its parts, converting the text fields into owned strings.
    #[must_use]
    pub fn new(
        index: usize,
        input: Option<String>,
        stage: impl Into<String>,
        error_type: impl Into<String>,
        message: impl Into<String>,
    ) -> Self {
        let stage = stage.into();
        let error_type = error_type.into();
        let message = message.into();
        Self {
            index,
            input,
            stage,
            error_type,
            message,
        }
    }
}
46
47#[derive(Debug, Clone, PartialEq)]
48pub enum BatchRecord {
49 Valid(Molecule),
50 Invalid(BatchRecordError),
51}
52
53#[derive(Debug, Clone, PartialEq, Default)]
54pub struct MoleculeBatch {
55 pub records: Vec<BatchRecord>,
56}
57
58#[derive(Debug, Clone, PartialEq, Eq)]
59pub struct BatchExportReport {
60 pub total: usize,
61 pub success: usize,
62 pub failed: usize,
63 pub errors: Vec<BatchRecordError>,
64}
65
66#[derive(Debug, Clone, PartialEq, Eq, Error)]
67#[error("batch validation failed with {} error(s)", .errors.len())]
68pub struct BatchValidationError {
69 pub errors: Vec<BatchRecordError>,
70}
71
/// Optional callback that batch operations invoke once per processed record.
///
/// Records are processed in parallel, so the callback must be `Sync` and may be called from
/// several threads at once. (The explicit `+ 'a` is the default object lifetime for `&'a dyn`.)
pub type BatchProgress<'a> = Option<&'a (dyn Fn() + Sync + 'a)>;
74
75pub struct BatchProgressBar {
77 inner: ProgressBar,
78}
79
80impl BatchProgressBar {
81 #[must_use]
83 pub fn new(total: usize, message: impl Into<String>) -> Self {
84 let progress_bar = ProgressBar::with_draw_target(
85 Some(total as u64),
86 ProgressDrawTarget::term_like_with_hz(Box::new(console::Term::buffered_stderr()), 20),
87 );
88 let style = ProgressStyle::with_template(
89 "{spinner:.green} {msg} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len}",
90 )
91 .unwrap_or_else(|_| ProgressStyle::default_bar());
92 progress_bar.set_style(style);
93 progress_bar.set_message(message.into());
94 Self {
95 inner: progress_bar,
96 }
97 }
98
99 #[must_use]
101 pub fn callback(&self) -> Box<dyn Fn() + Sync + '_> {
102 Box::new(|| self.inner.inc(1))
103 }
104
105 pub fn finish(&self) {
107 self.inner.finish();
108 eprintln!();
109 }
110}
111
112#[must_use]
114pub fn batch_progress_bar(
115 enabled: bool,
116 total: usize,
117 message: impl Into<String>,
118) -> Option<BatchProgressBar> {
119 if !enabled || total == 0 {
120 return None;
121 }
122 Some(BatchProgressBar::new(total, message))
123}
124
125fn tick_progress(progress: BatchProgress<'_>) {
126 if let Some(progress) = progress {
127 progress();
128 }
129}
130
131impl MoleculeBatch {
132 #[must_use]
133 pub fn new(records: Vec<BatchRecord>) -> Self {
134 Self { records }
135 }
136
137 pub fn from_smiles_list(
138 smiles: &[String],
139 errors: BatchErrorMode,
140 ) -> Result<Self, BatchValidationError> {
141 Self::from_smiles_list_with_sanitize(smiles, true, errors)
142 }
143
144 pub fn from_smiles_list_with_sanitize(
145 smiles: &[String],
146 sanitize: bool,
147 errors: BatchErrorMode,
148 ) -> Result<Self, BatchValidationError> {
149 let records: Vec<BatchRecord> = smiles
150 .par_iter()
151 .enumerate()
152 .filter_map(|(index, smiles)| {
153 match Molecule::from_smiles_with_sanitize(smiles, sanitize) {
154 Ok(molecule) => Some(BatchRecord::Valid(molecule)),
155 Err(error) => {
156 let record_error = BatchRecordError::new(
157 index,
158 Some(smiles.clone()),
159 "parse_smiles",
160 "SmilesParseError",
161 error.to_string(),
162 );
163 match errors {
164 BatchErrorMode::Raise | BatchErrorMode::Keep => {
165 Some(BatchRecord::Invalid(record_error))
166 }
167 BatchErrorMode::Skip => None,
168 }
169 }
170 }
171 })
172 .collect();
173 Self::from_records_with_mode(records, errors)
174 }
175
176 fn from_sdf_record_strings(
177 records: &[String],
178 coordinate_mode: SdfCoordinateMode,
179 errors: BatchErrorMode,
180 ) -> Result<Self, BatchValidationError> {
181 let batch_records: Vec<BatchRecord> = records
182 .par_iter()
183 .enumerate()
184 .filter_map(|(index, sdf)| {
185 match read_sdf_from_str_with_coordinate_mode(sdf, coordinate_mode) {
186 Ok(record) => Some(BatchRecord::Valid(record.molecule)),
187 Err(error) => {
188 let record_error = BatchRecordError::new(
189 index,
190 None,
191 "read_sdf",
192 "SdfReadError",
193 error.to_string(),
194 );
195 match errors {
196 BatchErrorMode::Raise | BatchErrorMode::Keep => {
197 Some(BatchRecord::Invalid(record_error))
198 }
199 BatchErrorMode::Skip => None,
200 }
201 }
202 }
203 })
204 .collect();
205 Self::from_records_with_mode(batch_records, errors)
206 }
207
208 pub fn read_sdf_records_from_str(
209 sdf_text: &str,
210 coordinate_mode: SdfCoordinateMode,
211 errors: BatchErrorMode,
212 ) -> Result<Self, BatchValidationError> {
213 let records = split_sdf_record_strings(sdf_text);
214 Self::from_sdf_record_strings(&records, coordinate_mode, errors)
215 }
216
217 fn from_records_with_mode(
218 records: Vec<BatchRecord>,
219 errors: BatchErrorMode,
220 ) -> Result<Self, BatchValidationError> {
221 if matches!(errors, BatchErrorMode::Raise) {
222 let collected = records
223 .iter()
224 .filter_map(|record| match record {
225 BatchRecord::Valid(_) => None,
226 BatchRecord::Invalid(error) => Some(error.clone()),
227 })
228 .collect::<Vec<_>>();
229 if !collected.is_empty() {
230 return Err(BatchValidationError { errors: collected });
231 }
232 }
233 Ok(Self { records })
234 }
235
236 #[must_use]
237 pub fn len(&self) -> usize {
238 self.records.len()
239 }
240
241 #[must_use]
242 pub fn is_empty(&self) -> bool {
243 self.records.is_empty()
244 }
245
246 #[must_use]
247 pub fn valid_mask(&self) -> Vec<bool> {
248 self.records
249 .iter()
250 .map(|record| matches!(record, BatchRecord::Valid(_)))
251 .collect()
252 }
253
254 #[must_use]
255 pub fn errors(&self) -> Vec<BatchRecordError> {
256 self.records
257 .iter()
258 .filter_map(|record| match record {
259 BatchRecord::Valid(_) => None,
260 BatchRecord::Invalid(error) => Some(error.clone()),
261 })
262 .collect()
263 }
264
265 #[must_use]
266 pub fn valid_count(&self) -> usize {
267 self.records
268 .iter()
269 .filter(|record| matches!(record, BatchRecord::Valid(_)))
270 .count()
271 }
272
273 #[must_use]
274 pub fn invalid_count(&self) -> usize {
275 self.records.len() - self.valid_count()
276 }
277
278 #[must_use]
279 pub fn filter_valid(&self) -> Self {
280 let records = self
281 .records
282 .iter()
283 .filter_map(|record| match record {
284 BatchRecord::Valid(molecule) => Some(BatchRecord::Valid(molecule.clone())),
285 BatchRecord::Invalid(_) => None,
286 })
287 .collect();
288 Self { records }
289 }
290
291 pub fn add_hydrogens(&self, errors: BatchErrorMode) -> Result<Self, BatchValidationError> {
292 self.add_hydrogens_with_progress(errors, None)
293 }
294
295 pub fn add_hydrogens_with_progress(
296 &self,
297 errors: BatchErrorMode,
298 progress: BatchProgress<'_>,
299 ) -> Result<Self, BatchValidationError> {
300 self.transform_valid(
301 "add_hydrogens",
302 "AddHydrogensError",
303 errors,
304 |molecule| {
305 molecule
306 .with_hydrogens()
307 .map_err(|error| format!("{error:?}"))
308 },
309 progress,
310 )
311 }
312
313 pub fn remove_hydrogens(&self, errors: BatchErrorMode) -> Result<Self, BatchValidationError> {
314 self.remove_hydrogens_with_progress(errors, None)
315 }
316
317 pub fn remove_hydrogens_with_progress(
318 &self,
319 errors: BatchErrorMode,
320 progress: BatchProgress<'_>,
321 ) -> Result<Self, BatchValidationError> {
322 self.transform_valid(
323 "remove_hydrogens",
324 "RemoveHydrogensError",
325 errors,
326 |molecule| {
327 molecule
328 .without_hydrogens()
329 .map_err(|error| format!("{error:?}"))
330 },
331 progress,
332 )
333 }
334
335 pub fn sanitize(&self, errors: BatchErrorMode) -> Result<Self, BatchValidationError> {
336 self.sanitize_with_progress(errors, None)
337 }
338
339 pub fn sanitize_with_progress(
340 &self,
341 errors: BatchErrorMode,
342 progress: BatchProgress<'_>,
343 ) -> Result<Self, BatchValidationError> {
344 self.transform_valid(
345 "sanitize",
346 "SanitizeError",
347 errors,
348 |molecule| molecule.sanitize().map_err(|error| error.to_string()),
349 progress,
350 )
351 }
352
353 pub fn kekulize(&self, errors: BatchErrorMode) -> Result<Self, BatchValidationError> {
354 self.kekulize_with_sanitize(false, errors)
355 }
356
357 pub fn kekulize_with_sanitize(
358 &self,
359 sanitize: bool,
360 errors: BatchErrorMode,
361 ) -> Result<Self, BatchValidationError> {
362 self.kekulize_with_sanitize_and_progress(sanitize, errors, None)
363 }
364
365 pub fn kekulize_with_sanitize_and_progress(
366 &self,
367 sanitize: bool,
368 errors: BatchErrorMode,
369 progress: BatchProgress<'_>,
370 ) -> Result<Self, BatchValidationError> {
371 self.transform_valid(
372 "kekulize",
373 "KekulizeError",
374 errors,
375 |molecule| {
376 molecule
377 .with_kekulized_bonds(sanitize)
378 .map_err(|error| format!("{error:?}"))
379 },
380 progress,
381 )
382 }
383
384 pub fn compute_2d_coords(&self, errors: BatchErrorMode) -> Result<Self, BatchValidationError> {
385 self.compute_2d_coords_with_progress(errors, None)
386 }
387
388 pub fn compute_2d_coords_with_progress(
389 &self,
390 errors: BatchErrorMode,
391 progress: BatchProgress<'_>,
392 ) -> Result<Self, BatchValidationError> {
393 self.transform_valid(
394 "compute_2d_coords",
395 "CoordinateGenerationError",
396 errors,
397 |molecule| molecule.with_2d_coords().map_err(|error| error.to_string()),
398 progress,
399 )
400 }
401
402 pub fn to_smiles_list(
403 &self,
404 isomeric_smiles: bool,
405 ) -> Result<Vec<Option<String>>, BatchValidationError> {
406 self.collect_optional_values(
407 "to_smiles",
408 "SmilesWriteError",
409 |molecule| {
410 molecule
411 .to_smiles(isomeric_smiles)
412 .map_err(|error| error.to_string())
413 },
414 None,
415 )
416 }
417
418 pub fn to_smiles_list_with_params(
419 &self,
420 params: &SmilesWriteParams,
421 ) -> Result<Vec<Option<String>>, BatchValidationError> {
422 self.to_smiles_list_with_params_and_progress(params, None)
423 }
424
425 pub fn to_smiles_list_with_params_and_progress(
426 &self,
427 params: &SmilesWriteParams,
428 progress: BatchProgress<'_>,
429 ) -> Result<Vec<Option<String>>, BatchValidationError> {
430 self.collect_optional_values(
431 "to_smiles",
432 "SmilesWriteError",
433 |molecule| {
434 molecule
435 .to_smiles_with_params(params)
436 .map_err(|error| error.to_string())
437 },
438 progress,
439 )
440 }
441
442 pub fn dg_bounds_matrix_list(
443 &self,
444 ) -> Result<Vec<Option<Vec<Vec<f64>>>>, BatchValidationError> {
445 self.dg_bounds_matrix_list_with_progress(None)
446 }
447
448 pub fn dg_bounds_matrix_list_with_progress(
449 &self,
450 progress: BatchProgress<'_>,
451 ) -> Result<Vec<Option<Vec<Vec<f64>>>>, BatchValidationError> {
452 self.collect_optional_values(
453 "dg_bounds_matrix",
454 "DistanceGeometryError",
455 |molecule| {
456 molecule
457 .dg_bounds_matrix()
458 .map_err(|error| error.to_string())
459 },
460 progress,
461 )
462 }
463
464 pub fn morgan_fingerprint_list(
465 &self,
466 params: &crate::MorganFingerprintParams,
467 ) -> Result<Vec<Option<crate::Fingerprint>>, BatchValidationError> {
468 self.morgan_fingerprint_list_with_progress(params, None)
469 }
470
471 pub fn morgan_fingerprint_list_with_progress(
472 &self,
473 params: &crate::MorganFingerprintParams,
474 progress: BatchProgress<'_>,
475 ) -> Result<Vec<Option<crate::Fingerprint>>, BatchValidationError> {
476 self.collect_optional_values(
477 "morgan_fingerprint",
478 "FingerprintError",
479 |molecule| {
480 molecule
481 .morgan_fingerprint(params)
482 .map_err(|error| error.to_string())
483 },
484 progress,
485 )
486 }
487
488 pub fn morgan_fingerprint_with_output_list(
489 &self,
490 params: &crate::MorganFingerprintParams,
491 ) -> Result<Vec<Option<crate::MorganFingerprintOutput>>, BatchValidationError> {
492 self.morgan_fingerprint_with_output_list_with_progress(params, None)
493 }
494
495 pub fn morgan_fingerprint_with_output_list_with_progress(
496 &self,
497 params: &crate::MorganFingerprintParams,
498 progress: BatchProgress<'_>,
499 ) -> Result<Vec<Option<crate::MorganFingerprintOutput>>, BatchValidationError> {
500 self.collect_optional_values(
501 "morgan_fingerprint",
502 "FingerprintError",
503 |molecule| {
504 molecule
505 .morgan_fingerprint_with_output(params)
506 .map_err(|error| error.to_string())
507 },
508 progress,
509 )
510 }
511
512 pub fn to_svg_list(
513 &self,
514 width: u32,
515 height: u32,
516 ) -> Result<Vec<Option<String>>, BatchValidationError> {
517 self.to_svg_list_with_progress(width, height, None)
518 }
519
520 pub fn to_svg_list_with_progress(
521 &self,
522 width: u32,
523 height: u32,
524 progress: BatchProgress<'_>,
525 ) -> Result<Vec<Option<String>>, BatchValidationError> {
526 self.collect_optional_values(
527 "to_svg",
528 "SvgDrawError",
529 |molecule| {
530 molecule
531 .to_svg(width, height)
532 .map_err(|error| error.to_string())
533 },
534 progress,
535 )
536 }
537
538 pub fn prepare_for_drawing_parity_list(
539 &self,
540 ) -> Result<Vec<Option<PreparedDrawMolecule>>, BatchValidationError> {
541 self.collect_optional_values(
542 "prepare_for_drawing_parity",
543 "PreparedDrawError",
544 |molecule| {
545 molecule
546 .prepare_for_drawing_parity()
547 .map_err(|error| error.to_string())
548 },
549 None,
550 )
551 }
552
553 fn transform_valid<F>(
554 &self,
555 stage: &'static str,
556 error_type: &'static str,
557 errors: BatchErrorMode,
558 transform: F,
559 progress: BatchProgress<'_>,
560 ) -> Result<Self, BatchValidationError>
561 where
562 F: Fn(&Molecule) -> Result<Molecule, String> + Sync,
563 {
564 let records: Vec<BatchRecord> = self
565 .records
566 .par_iter()
567 .enumerate()
568 .filter_map(|(index, record)| {
569 let out = match record {
570 BatchRecord::Valid(molecule) => match transform(molecule) {
571 Ok(molecule) => Some(BatchRecord::Valid(molecule)),
572 Err(message) => {
573 let error =
574 BatchRecordError::new(index, None, stage, error_type, message);
575 match errors {
576 BatchErrorMode::Raise | BatchErrorMode::Keep => {
577 Some(BatchRecord::Invalid(error))
578 }
579 BatchErrorMode::Skip => None,
580 }
581 }
582 },
583 BatchRecord::Invalid(error) => match errors {
584 BatchErrorMode::Raise | BatchErrorMode::Keep => {
585 Some(BatchRecord::Invalid(error.clone()))
586 }
587 BatchErrorMode::Skip => None,
588 },
589 };
590 tick_progress(progress);
591 out
592 })
593 .collect();
594 Self::from_records_with_mode(records, errors)
595 }
596
597 fn collect_optional_values<T, F>(
598 &self,
599 stage: &'static str,
600 error_type: &'static str,
601 collect: F,
602 progress: BatchProgress<'_>,
603 ) -> Result<Vec<Option<T>>, BatchValidationError>
604 where
605 T: Send,
606 F: Fn(&Molecule) -> Result<T, String> + Sync,
607 {
608 let pairs: Vec<(Option<T>, Option<BatchRecordError>)> = self
609 .records
610 .par_iter()
611 .enumerate()
612 .map(|(index, record)| {
613 let out = match record {
614 BatchRecord::Valid(molecule) => match collect(molecule) {
615 Ok(value) => (Some(value), None),
616 Err(message) => (
617 None,
618 Some(BatchRecordError::new(
619 index, None, stage, error_type, message,
620 )),
621 ),
622 },
623 BatchRecord::Invalid(_) => (None, None),
624 };
625 tick_progress(progress);
626 out
627 })
628 .collect();
629 let mut values = Vec::with_capacity(pairs.len());
630 let mut errors = Vec::new();
631 for (value, error) in pairs {
632 values.push(value);
633 if let Some(error) = error {
634 errors.push(error);
635 }
636 }
637 if errors.is_empty() {
638 Ok(values)
639 } else {
640 Err(BatchValidationError { errors })
641 }
642 }
643
644 pub fn write_images(
645 &self,
646 out_dir: &Path,
647 format: &str,
648 width: u32,
649 height: u32,
650 errors: BatchErrorMode,
651 filenames: Option<&[Option<String>]>,
652 ) -> Result<BatchExportReport, BatchValidationError> {
653 self.write_images_with_progress(out_dir, format, width, height, errors, filenames, None)
654 }
655
656 pub fn write_images_with_progress(
657 &self,
658 out_dir: &Path,
659 format: &str,
660 width: u32,
661 height: u32,
662 errors: BatchErrorMode,
663 filenames: Option<&[Option<String>]>,
664 progress: BatchProgress<'_>,
665 ) -> Result<BatchExportReport, BatchValidationError> {
666 let format = format.to_ascii_lowercase();
667 let paths = output_paths(out_dir, self.records.len(), &format, filenames, "to_images")?;
668 fs::create_dir_all(out_dir).map_err(|error| BatchValidationError {
669 errors: vec![BatchRecordError::new(
670 0,
671 Some(out_dir.display().to_string()),
672 "to_images",
673 "IoError",
674 format!("create output directory failed: {error}"),
675 )],
676 })?;
677 let outcomes: Vec<Result<(), BatchRecordError>> = self
678 .records
679 .par_iter()
680 .enumerate()
681 .filter_map(|(index, record)| {
682 let out = match record {
683 BatchRecord::Valid(molecule) => Some(
684 write_one_image(molecule, &paths[index], &format, width, height).map_err(
685 |message| {
686 BatchRecordError::new(
687 index,
688 Some(paths[index].display().to_string()),
689 "to_images",
690 "ImageExportError",
691 message,
692 )
693 },
694 ),
695 ),
696 BatchRecord::Invalid(error) => match errors {
697 BatchErrorMode::Raise | BatchErrorMode::Keep => Some(Err(error.clone())),
698 BatchErrorMode::Skip => None,
699 },
700 };
701 tick_progress(progress);
702 out
703 })
704 .collect();
705 let mut success = 0usize;
706 let mut report_errors = Vec::new();
707 for outcome in outcomes {
708 match outcome {
709 Ok(()) => success += 1,
710 Err(error) => report_errors.push(error),
711 }
712 }
713 if matches!(errors, BatchErrorMode::Raise) && !report_errors.is_empty() {
714 return Err(BatchValidationError {
715 errors: report_errors,
716 });
717 }
718 Ok(BatchExportReport {
719 total: self.records.len(),
720 success,
721 failed: report_errors.len(),
722 errors: report_errors,
723 })
724 }
725
726 pub fn write_sdf(
727 &self,
728 path: &Path,
729 format: SdfFormat,
730 errors: BatchErrorMode,
731 ) -> Result<BatchExportReport, BatchValidationError> {
732 self.write_sdf_with_progress(path, format, errors, None)
733 }
734
735 pub fn write_sdf_with_progress(
736 &self,
737 path: &Path,
738 format: SdfFormat,
739 errors: BatchErrorMode,
740 progress: BatchProgress<'_>,
741 ) -> Result<BatchExportReport, BatchValidationError> {
742 let outcomes: Vec<Result<String, BatchRecordError>> = self
743 .records
744 .par_iter()
745 .enumerate()
746 .filter_map(|(index, record)| {
747 let out = match record {
748 BatchRecord::Valid(molecule) => Some(
749 molecule_to_sdf_record_string(molecule, format).map_err(|error| {
750 BatchRecordError::new(index, None, "to_sdf", "SdfWriteError", error)
751 }),
752 ),
753 BatchRecord::Invalid(error) => match errors {
754 BatchErrorMode::Raise | BatchErrorMode::Keep => Some(Err(error.clone())),
755 BatchErrorMode::Skip => None,
756 },
757 };
758 tick_progress(progress);
759 out
760 })
761 .collect();
762 let mut blocks = Vec::new();
763 let mut report_errors = Vec::new();
764 for outcome in outcomes {
765 match outcome {
766 Ok(block) => blocks.push(block),
767 Err(error) => report_errors.push(error),
768 }
769 }
770 if matches!(errors, BatchErrorMode::Raise) && !report_errors.is_empty() {
771 return Err(BatchValidationError {
772 errors: report_errors,
773 });
774 }
775 let mut file = File::create(path).map_err(|error| BatchValidationError {
776 errors: vec![BatchRecordError::new(
777 0,
778 Some(path.display().to_string()),
779 "to_sdf",
780 "IoError",
781 format!("create SDF failed: {error}"),
782 )],
783 })?;
784 let success = blocks.len();
785 for block in blocks {
786 file.write_all(block.as_bytes())
787 .map_err(|error| BatchValidationError {
788 errors: vec![BatchRecordError::new(
789 0,
790 Some(path.display().to_string()),
791 "to_sdf",
792 "IoError",
793 format!("write SDF failed: {error}"),
794 )],
795 })?;
796 }
797 Ok(BatchExportReport {
798 total: self.records.len(),
799 success,
800 failed: report_errors.len(),
801 errors: report_errors,
802 })
803 }
804
805 pub fn write_sdf_files(
806 &self,
807 out_dir: &Path,
808 format: SdfFormat,
809 errors: BatchErrorMode,
810 filenames: Option<&[Option<String>]>,
811 ) -> Result<BatchExportReport, BatchValidationError> {
812 self.write_sdf_files_with_progress(out_dir, format, errors, filenames, None)
813 }
814
815 pub fn write_sdf_files_with_progress(
816 &self,
817 out_dir: &Path,
818 format: SdfFormat,
819 errors: BatchErrorMode,
820 filenames: Option<&[Option<String>]>,
821 progress: BatchProgress<'_>,
822 ) -> Result<BatchExportReport, BatchValidationError> {
823 let paths = output_paths(
824 out_dir,
825 self.records.len(),
826 "sdf",
827 filenames,
828 "to_sdf_files",
829 )?;
830 fs::create_dir_all(out_dir).map_err(|error| BatchValidationError {
831 errors: vec![BatchRecordError::new(
832 0,
833 Some(out_dir.display().to_string()),
834 "to_sdf_files",
835 "IoError",
836 format!("create output directory failed: {error}"),
837 )],
838 })?;
839 let outcomes: Vec<Result<(), BatchRecordError>> = self
840 .records
841 .par_iter()
842 .enumerate()
843 .filter_map(|(index, record)| {
844 let out = match record {
845 BatchRecord::Valid(molecule) => Some(
846 write_one_sdf_file(molecule, &paths[index], format).map_err(|message| {
847 BatchRecordError::new(
848 index,
849 Some(paths[index].display().to_string()),
850 "to_sdf_files",
851 "SdfWriteError",
852 message,
853 )
854 }),
855 ),
856 BatchRecord::Invalid(error) => match errors {
857 BatchErrorMode::Raise | BatchErrorMode::Keep => Some(Err(error.clone())),
858 BatchErrorMode::Skip => None,
859 },
860 };
861 tick_progress(progress);
862 out
863 })
864 .collect();
865 let mut success = 0usize;
866 let mut report_errors = Vec::new();
867 for outcome in outcomes {
868 match outcome {
869 Ok(()) => success += 1,
870 Err(error) => report_errors.push(error),
871 }
872 }
873 if matches!(errors, BatchErrorMode::Raise) && !report_errors.is_empty() {
874 return Err(BatchValidationError {
875 errors: report_errors,
876 });
877 }
878 Ok(BatchExportReport {
879 total: self.records.len(),
880 success,
881 failed: report_errors.len(),
882 errors: report_errors,
883 })
884 }
885}
886
887fn write_one_image(
888 molecule: &Molecule,
889 path: &Path,
890 format: &str,
891 width: u32,
892 height: u32,
893) -> Result<(), String> {
894 match format {
895 "svg" => {
896 let svg = molecule
897 .to_svg(width, height)
898 .map_err(|error| error.to_string())?;
899 fs::write(path, svg).map_err(|error| error.to_string())
900 }
901 "png" => {
902 let png = molecule
903 .to_png(width, height)
904 .map_err(|error| error.to_string())?;
905 fs::write(path, png).map_err(|error| error.to_string())
906 }
907 other => Err(format!(
908 "unsupported image format '{other}', expected 'png' or 'svg'"
909 )),
910 }
911}
912
913fn write_one_sdf_file(molecule: &Molecule, path: &Path, format: SdfFormat) -> Result<(), String> {
914 let block = molecule_to_sdf_record_string(molecule, format)?;
915 fs::write(path, block).map_err(|error| error.to_string())
916}
917
918fn molecule_to_sdf_record_string(molecule: &Molecule, format: SdfFormat) -> Result<String, String> {
919 if molecule.coords_2d().is_some() {
920 molblock::mol_to_2d_sdf_record(molecule, format).map_err(|error| error.to_string())
921 } else if molecule.coords_3d().is_some() {
922 molblock::mol_to_3d_sdf_record(molecule, format).map_err(|error| error.to_string())
923 } else {
924 Err(
925 "SDF writing requires coordinates; call with_2d_coords() or read a molecule with 2D/3D coordinates before writing SDF"
926 .to_owned(),
927 )
928 }
929}
930
931fn output_paths(
932 out_dir: &Path,
933 total: usize,
934 extension: &str,
935 filenames: Option<&[Option<String>]>,
936 stage: &'static str,
937) -> Result<Vec<PathBuf>, BatchValidationError> {
938 if let Some(filenames) = filenames
939 && filenames.len() != total
940 {
941 return Err(BatchValidationError {
942 errors: vec![BatchRecordError::new(
943 0,
944 None,
945 stage,
946 "FilenameError",
947 format!(
948 "filenames length {} must match batch length {total}",
949 filenames.len()
950 ),
951 )],
952 });
953 }
954
955 let mut seen = HashSet::new();
956 let mut paths = Vec::with_capacity(total);
957 for index in 0..total {
958 let filename = match filenames.and_then(|items| items[index].as_deref()) {
959 Some(raw) => normalize_output_filename(raw, extension).map_err(|message| {
960 BatchValidationError {
961 errors: vec![BatchRecordError::new(
962 index,
963 Some(raw.to_string()),
964 stage,
965 "FilenameError",
966 message,
967 )],
968 }
969 })?,
970 None => format!("{index:06}.{extension}"),
971 };
972 if !seen.insert(filename.clone()) {
973 return Err(BatchValidationError {
974 errors: vec![BatchRecordError::new(
975 index,
976 Some(filename),
977 stage,
978 "FilenameError",
979 "duplicate output filename",
980 )],
981 });
982 }
983 paths.push(out_dir.join(filename));
984 }
985 Ok(paths)
986}
987
/// Validates a user-supplied output file name and ensures it carries `extension`.
///
/// Surrounding whitespace is trimmed. The name must be a single plain path component: not
/// empty, not absolute, and without separators or `..`.
///
/// - No extension: `.{extension}` is appended.
/// - Matching extension (ASCII case-insensitive): the name is kept as written.
/// - Any other extension: the name is rejected.
fn normalize_output_filename(raw: &str, extension: &str) -> Result<String, String> {
    let candidate = raw.trim();
    if candidate.is_empty() {
        return Err("filename must not be empty".to_string());
    }

    let candidate_path = Path::new(candidate);
    if candidate_path.is_absolute() {
        return Err("filename must be relative to the output directory".to_string());
    }

    // Exactly one component, and it must be an ordinary name (not `..`, a root, or a prefix).
    let mut parts = candidate_path.components();
    let is_single_plain_name = matches!(
        (parts.next(), parts.next()),
        (Some(Component::Normal(_)), None)
    );
    if !is_single_plain_name {
        return Err("filename must not include path separators or '..'".to_string());
    }

    let base = candidate_path
        .file_name()
        .and_then(|name| name.to_str())
        .ok_or_else(|| "filename must be valid UTF-8".to_string())?;

    let Some(actual) = candidate_path.extension().and_then(|ext| ext.to_str()) else {
        return Ok(format!("{base}.{extension}"));
    };
    if actual.eq_ignore_ascii_case(extension) {
        Ok(base.to_string())
    } else {
        Err(format!(
            "filename extension '.{actual}' does not match expected '.{extension}'"
        ))
    }
}
1012
/// Splits multi-record SDF text into one string per record.
///
/// Each record runs up to and including its `$$$$` terminator line. Trailing whitespace on the
/// terminator line is ignored when detecting it, but kept in the output. Every line is
/// re-terminated with `\n`, so CRLF input comes out with LF endings. Text after the last
/// terminator becomes a final record unless it is entirely blank.
///
/// Blank-line heuristic: immediately after a terminator, a blank line is dropped only when the
/// following line is also blank. A lone blank line is kept as the next record's first line (in
/// molfile terms, its title line).
///
/// NOTE(review): a record whose first two header lines are both blank would lose its title
/// line under this rule. Confirm that trade-off is intended.
fn split_sdf_record_strings(sdf_text: &str) -> Vec<String> {
    let mut records = Vec::new();
    let mut pending = String::new();
    let mut just_closed = false;
    let mut lines = sdf_text.lines().peekable();

    while let Some(line) = lines.next() {
        let closed_previously = std::mem::replace(&mut just_closed, false);
        let drop_separator = closed_previously
            && pending.is_empty()
            && line.trim().is_empty()
            && lines.peek().is_some_and(|next| next.trim().is_empty());
        if drop_separator {
            continue;
        }

        pending.push_str(line);
        pending.push('\n');
        if line.trim_end() == "$$$$" {
            records.push(std::mem::take(&mut pending));
            just_closed = true;
        }
    }

    if !pending.trim().is_empty() {
        records.push(pending);
    }
    records
}
1045
#[cfg(test)]
mod tests {
    use super::{BatchErrorMode, BatchRecord, MoleculeBatch};
    use crate::Molecule;
    use crate::io::molblock::SdfFormat;
    use std::fs;
    use std::path::PathBuf;

    /// Returns a scratch directory path unique to this process and thread.
    ///
    /// Any directory already at that path is removed first. Such a directory would be left
    /// over from an earlier, aborted run whose PID and thread id happened to match. Removing
    /// it keeps stale files from satisfying the `exists()` assertions below when the code
    /// under test wrote nothing.
    fn unique_temp_dir(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "cosmolkit-batch-{name}-{}-{:?}",
            std::process::id(),
            std::thread::current().id()
        ));
        // Best effort: the directory normally does not exist, so an error here is expected.
        let _ = fs::remove_dir_all(&dir);
        dir
    }

    #[test]
    fn from_smiles_list_keeps_invalid_records_in_order() {
        // "C1CC" has an unclosed ring and must fail to parse.
        let smiles = vec!["CCO".to_string(), "C1CC".to_string(), "CC".to_string()];
        let batch = MoleculeBatch::from_smiles_list(&smiles, BatchErrorMode::Keep)
            .expect("keep mode should not raise");

        assert_eq!(batch.len(), 3);
        assert_eq!(batch.valid_mask(), vec![true, false, true]);
        assert_eq!(batch.errors()[0].index, 1);
        assert_eq!(
            batch
                .filter_valid()
                .to_smiles_list(true)
                .expect("valid records should serialize to SMILES"),
            vec![Some("CCO".to_string()), Some("CC".to_string())]
        );
    }

    #[test]
    fn from_smiles_list_raise_aggregates_errors() {
        let smiles = vec!["C1CC".to_string(), "N1".to_string()];
        let error = MoleculeBatch::from_smiles_list(&smiles, BatchErrorMode::Raise)
            .expect_err("raise mode should aggregate invalid inputs");

        // Both failures are reported, in input order.
        assert_eq!(error.errors.len(), 2);
        assert_eq!(error.errors[0].stage, "parse_smiles");
        assert_eq!(error.errors[1].index, 1);
    }

    #[test]
    fn transforms_skip_invalid_records_when_requested() {
        let smiles = vec!["CCO".to_string(), "C1CC".to_string()];
        let batch = MoleculeBatch::from_smiles_list(&smiles, BatchErrorMode::Keep)
            .expect("keep mode should not raise");
        let prepared = batch
            .add_hydrogens(BatchErrorMode::Skip)
            .expect("skip mode should drop invalid records")
            .compute_2d_coords(BatchErrorMode::Skip)
            .expect("2D coords should compute for valid record");

        assert_eq!(prepared.len(), 1);
        assert_eq!(prepared.valid_count(), 1);
        assert_eq!(prepared.invalid_count(), 0);
    }

    #[test]
    fn sanitize_transform_applies_pipeline_to_valid_records() {
        let raw = Molecule::from_smiles_with_sanitize("CN(=O)=O", false)
            .expect("unsanitized nitro SMILES should parse");
        let batch = MoleculeBatch::new(vec![BatchRecord::Valid(raw)]);

        let sanitized = batch
            .sanitize(BatchErrorMode::Raise)
            .expect("sanitize should transform valid records");

        let BatchRecord::Valid(mol) = &sanitized.records[0] else {
            panic!("record should remain valid");
        };
        // Sanitization should rewrite the pentavalent nitro group into charge-separated form.
        assert_eq!(
            mol.atoms()
                .iter()
                .map(|atom| atom.formal_charge)
                .collect::<Vec<_>>(),
            vec![0, 1, -1, 0]
        );
    }

    #[test]
    fn write_images_uses_custom_filenames_and_rejects_unsafe_names() {
        let smiles = vec!["CCO".to_string(), "C1CC".to_string(), "CC".to_string()];
        let batch = MoleculeBatch::from_smiles_list(&smiles, BatchErrorMode::Keep)
            .expect("keep mode should not raise");
        let out_dir = unique_temp_dir("images");
        // Record 0 gets a custom name (extension appended), record 1 is invalid and skipped,
        // record 2 falls back to the default zero-padded name.
        let filenames = vec![
            Some("ethanol".to_string()),
            Some("invalid.svg".to_string()),
            None,
        ];

        let report = batch
            .write_images(
                &out_dir,
                "svg",
                120,
                100,
                BatchErrorMode::Skip,
                Some(&filenames),
            )
            .expect("custom image filenames should write");

        assert_eq!(report.success, 2);
        assert!(out_dir.join("ethanol.svg").exists());
        assert!(out_dir.join("000002.svg").exists());
        let _ = fs::remove_dir_all(&out_dir);

        let bad = vec![Some("../escape".to_string()), None, None];
        let error = batch
            .write_images(
                &unique_temp_dir("bad-images"),
                "svg",
                120,
                100,
                BatchErrorMode::Skip,
                Some(&bad),
            )
            .expect_err("unsafe filename should be rejected");
        assert_eq!(error.errors[0].error_type, "FilenameError");
    }

    #[test]
    fn write_sdf_files_uses_custom_filenames() {
        let smiles = vec!["CCO".to_string(), "CC".to_string()];
        let batch = MoleculeBatch::from_smiles_list(&smiles, BatchErrorMode::Raise)
            .expect("SMILES should parse")
            .compute_2d_coords(BatchErrorMode::Raise)
            .expect("2D coords should compute");
        let out_dir = unique_temp_dir("sdf-files");
        let filenames = vec![Some("ethanol".to_string()), Some("ethane.sdf".to_string())];

        let report = batch
            .write_sdf_files(
                &out_dir,
                SdfFormat::V2000,
                BatchErrorMode::Raise,
                Some(&filenames),
            )
            .expect("custom SDF filenames should write");

        assert_eq!(report.success, 2);
        assert!(out_dir.join("ethanol.sdf").exists());
        assert!(out_dir.join("ethane.sdf").exists());
        let _ = fs::remove_dir_all(&out_dir);
    }
}