Skip to main content

cosmolkit_core/
batch.rs

1use crate::io::molblock::{self, SdfFormat};
2use crate::io::sdf::{SdfCoordinateMode, read_sdf_from_str_with_coordinate_mode};
3use crate::{Molecule, PreparedDrawMolecule, SmilesWriteParams};
4use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle};
5use rayon::prelude::*;
6use std::collections::HashSet;
7use std::fs::{self, File};
8use std::io::Write;
9use std::path::{Component, Path, PathBuf};
10use thiserror::Error;
11
/// Policy applied to records that fail during a batch operation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BatchErrorMode {
    /// Fail the whole operation with a `BatchValidationError` when any record failed.
    Raise,
    /// Keep going and retain the failures alongside the successful records.
    Keep,
    /// Keep going and drop invalid records instead of carrying them forward.
    Skip,
}
18
/// Description of one failure recorded while processing a batch.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BatchRecordError {
    /// Position of the record within the batch being processed (file-level
    /// failures in this module use `0`).
    pub index: usize,
    /// Text tied to the failure when one was recorded, such as the input
    /// SMILES or the output path.
    pub input: Option<String>,
    /// Name of the processing stage that failed, e.g. `"parse_smiles"`.
    pub stage: String,
    /// Short category label for the failure, e.g. `"SmilesParseError"`.
    pub error_type: String,
    /// Human-readable failure description.
    pub message: String,
}
27
28impl BatchRecordError {
29    #[must_use]
30    pub fn new(
31        index: usize,
32        input: Option<String>,
33        stage: impl Into<String>,
34        error_type: impl Into<String>,
35        message: impl Into<String>,
36    ) -> Self {
37        Self {
38            index,
39            input,
40            stage: stage.into(),
41            error_type: error_type.into(),
42            message: message.into(),
43        }
44    }
45}
46
47#[derive(Debug, Clone, PartialEq)]
48pub enum BatchRecord {
49    Valid(Molecule),
50    Invalid(BatchRecordError),
51}
52
53#[derive(Debug, Clone, PartialEq, Default)]
54pub struct MoleculeBatch {
55    pub records: Vec<BatchRecord>,
56}
57
58#[derive(Debug, Clone, PartialEq, Eq)]
59pub struct BatchExportReport {
60    pub total: usize,
61    pub success: usize,
62    pub failed: usize,
63    pub errors: Vec<BatchRecordError>,
64}
65
66#[derive(Debug, Clone, PartialEq, Eq, Error)]
67#[error("batch validation failed with {} error(s)", .errors.len())]
68pub struct BatchValidationError {
69    pub errors: Vec<BatchRecordError>,
70}
71
/// Optional progress hook handed to the parallel batch operations.
///
/// `None` disables reporting. `Some(callback)` supplies a borrowed, argument-less
/// callback; the `Sync` bound lets it be shared across worker threads.
pub type BatchProgress<'a> = Option<&'a (dyn Fn() + Sync)>;
74
75/// Terminal progress bar for Rust batch workflows.
76pub struct BatchProgressBar {
77    inner: ProgressBar,
78}
79
80impl BatchProgressBar {
81    /// Create a stderr progress bar that adapts to terminal width.
82    #[must_use]
83    pub fn new(total: usize, message: impl Into<String>) -> Self {
84        let progress_bar = ProgressBar::with_draw_target(
85            Some(total as u64),
86            ProgressDrawTarget::term_like_with_hz(Box::new(console::Term::buffered_stderr()), 20),
87        );
88        let style = ProgressStyle::with_template(
89            "{spinner:.green} {msg} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len}",
90        )
91        .unwrap_or_else(|_| ProgressStyle::default_bar());
92        progress_bar.set_style(style);
93        progress_bar.set_message(message.into());
94        Self {
95            inner: progress_bar,
96        }
97    }
98
99    /// Return a callback suitable for `*_with_progress()` batch methods.
100    #[must_use]
101    pub fn callback(&self) -> Box<dyn Fn() + Sync + '_> {
102        Box::new(|| self.inner.inc(1))
103    }
104
105    /// Finish rendering the progress bar and emit a newline.
106    pub fn finish(&self) {
107        self.inner.finish();
108        eprintln!();
109    }
110}
111
112/// Create a progress bar when `enabled` is true and `total` is non-zero.
113#[must_use]
114pub fn batch_progress_bar(
115    enabled: bool,
116    total: usize,
117    message: impl Into<String>,
118) -> Option<BatchProgressBar> {
119    if !enabled || total == 0 {
120        return None;
121    }
122    Some(BatchProgressBar::new(total, message))
123}
124
125fn tick_progress(progress: BatchProgress<'_>) {
126    if let Some(progress) = progress {
127        progress();
128    }
129}
130
131impl MoleculeBatch {
132    #[must_use]
133    pub fn new(records: Vec<BatchRecord>) -> Self {
134        Self { records }
135    }
136
137    pub fn from_smiles_list(
138        smiles: &[String],
139        errors: BatchErrorMode,
140    ) -> Result<Self, BatchValidationError> {
141        Self::from_smiles_list_with_sanitize(smiles, true, errors)
142    }
143
144    pub fn from_smiles_list_with_sanitize(
145        smiles: &[String],
146        sanitize: bool,
147        errors: BatchErrorMode,
148    ) -> Result<Self, BatchValidationError> {
149        let records: Vec<BatchRecord> = smiles
150            .par_iter()
151            .enumerate()
152            .filter_map(|(index, smiles)| {
153                match Molecule::from_smiles_with_sanitize(smiles, sanitize) {
154                    Ok(molecule) => Some(BatchRecord::Valid(molecule)),
155                    Err(error) => {
156                        let record_error = BatchRecordError::new(
157                            index,
158                            Some(smiles.clone()),
159                            "parse_smiles",
160                            "SmilesParseError",
161                            error.to_string(),
162                        );
163                        match errors {
164                            BatchErrorMode::Raise | BatchErrorMode::Keep => {
165                                Some(BatchRecord::Invalid(record_error))
166                            }
167                            BatchErrorMode::Skip => None,
168                        }
169                    }
170                }
171            })
172            .collect();
173        Self::from_records_with_mode(records, errors)
174    }
175
176    fn from_sdf_record_strings(
177        records: &[String],
178        coordinate_mode: SdfCoordinateMode,
179        errors: BatchErrorMode,
180    ) -> Result<Self, BatchValidationError> {
181        let batch_records: Vec<BatchRecord> = records
182            .par_iter()
183            .enumerate()
184            .filter_map(|(index, sdf)| {
185                match read_sdf_from_str_with_coordinate_mode(sdf, coordinate_mode) {
186                    Ok(record) => Some(BatchRecord::Valid(record.molecule)),
187                    Err(error) => {
188                        let record_error = BatchRecordError::new(
189                            index,
190                            None,
191                            "read_sdf",
192                            "SdfReadError",
193                            error.to_string(),
194                        );
195                        match errors {
196                            BatchErrorMode::Raise | BatchErrorMode::Keep => {
197                                Some(BatchRecord::Invalid(record_error))
198                            }
199                            BatchErrorMode::Skip => None,
200                        }
201                    }
202                }
203            })
204            .collect();
205        Self::from_records_with_mode(batch_records, errors)
206    }
207
208    pub fn read_sdf_records_from_str(
209        sdf_text: &str,
210        coordinate_mode: SdfCoordinateMode,
211        errors: BatchErrorMode,
212    ) -> Result<Self, BatchValidationError> {
213        let records = split_sdf_record_strings(sdf_text);
214        Self::from_sdf_record_strings(&records, coordinate_mode, errors)
215    }
216
217    fn from_records_with_mode(
218        records: Vec<BatchRecord>,
219        errors: BatchErrorMode,
220    ) -> Result<Self, BatchValidationError> {
221        if matches!(errors, BatchErrorMode::Raise) {
222            let collected = records
223                .iter()
224                .filter_map(|record| match record {
225                    BatchRecord::Valid(_) => None,
226                    BatchRecord::Invalid(error) => Some(error.clone()),
227                })
228                .collect::<Vec<_>>();
229            if !collected.is_empty() {
230                return Err(BatchValidationError { errors: collected });
231            }
232        }
233        Ok(Self { records })
234    }
235
236    #[must_use]
237    pub fn len(&self) -> usize {
238        self.records.len()
239    }
240
241    #[must_use]
242    pub fn is_empty(&self) -> bool {
243        self.records.is_empty()
244    }
245
246    #[must_use]
247    pub fn valid_mask(&self) -> Vec<bool> {
248        self.records
249            .iter()
250            .map(|record| matches!(record, BatchRecord::Valid(_)))
251            .collect()
252    }
253
254    #[must_use]
255    pub fn errors(&self) -> Vec<BatchRecordError> {
256        self.records
257            .iter()
258            .filter_map(|record| match record {
259                BatchRecord::Valid(_) => None,
260                BatchRecord::Invalid(error) => Some(error.clone()),
261            })
262            .collect()
263    }
264
265    #[must_use]
266    pub fn valid_count(&self) -> usize {
267        self.records
268            .iter()
269            .filter(|record| matches!(record, BatchRecord::Valid(_)))
270            .count()
271    }
272
273    #[must_use]
274    pub fn invalid_count(&self) -> usize {
275        self.records.len() - self.valid_count()
276    }
277
278    #[must_use]
279    pub fn filter_valid(&self) -> Self {
280        let records = self
281            .records
282            .iter()
283            .filter_map(|record| match record {
284                BatchRecord::Valid(molecule) => Some(BatchRecord::Valid(molecule.clone())),
285                BatchRecord::Invalid(_) => None,
286            })
287            .collect();
288        Self { records }
289    }
290
291    pub fn add_hydrogens(&self, errors: BatchErrorMode) -> Result<Self, BatchValidationError> {
292        self.add_hydrogens_with_progress(errors, None)
293    }
294
295    pub fn add_hydrogens_with_progress(
296        &self,
297        errors: BatchErrorMode,
298        progress: BatchProgress<'_>,
299    ) -> Result<Self, BatchValidationError> {
300        self.transform_valid(
301            "add_hydrogens",
302            "AddHydrogensError",
303            errors,
304            |molecule| {
305                molecule
306                    .with_hydrogens()
307                    .map_err(|error| format!("{error:?}"))
308            },
309            progress,
310        )
311    }
312
313    pub fn remove_hydrogens(&self, errors: BatchErrorMode) -> Result<Self, BatchValidationError> {
314        self.remove_hydrogens_with_progress(errors, None)
315    }
316
317    pub fn remove_hydrogens_with_progress(
318        &self,
319        errors: BatchErrorMode,
320        progress: BatchProgress<'_>,
321    ) -> Result<Self, BatchValidationError> {
322        self.transform_valid(
323            "remove_hydrogens",
324            "RemoveHydrogensError",
325            errors,
326            |molecule| {
327                molecule
328                    .without_hydrogens()
329                    .map_err(|error| format!("{error:?}"))
330            },
331            progress,
332        )
333    }
334
335    pub fn sanitize(&self, errors: BatchErrorMode) -> Result<Self, BatchValidationError> {
336        self.sanitize_with_progress(errors, None)
337    }
338
339    pub fn sanitize_with_progress(
340        &self,
341        errors: BatchErrorMode,
342        progress: BatchProgress<'_>,
343    ) -> Result<Self, BatchValidationError> {
344        self.transform_valid(
345            "sanitize",
346            "SanitizeError",
347            errors,
348            |molecule| molecule.sanitize().map_err(|error| error.to_string()),
349            progress,
350        )
351    }
352
353    pub fn kekulize(&self, errors: BatchErrorMode) -> Result<Self, BatchValidationError> {
354        self.kekulize_with_sanitize(false, errors)
355    }
356
357    pub fn kekulize_with_sanitize(
358        &self,
359        sanitize: bool,
360        errors: BatchErrorMode,
361    ) -> Result<Self, BatchValidationError> {
362        self.kekulize_with_sanitize_and_progress(sanitize, errors, None)
363    }
364
365    pub fn kekulize_with_sanitize_and_progress(
366        &self,
367        sanitize: bool,
368        errors: BatchErrorMode,
369        progress: BatchProgress<'_>,
370    ) -> Result<Self, BatchValidationError> {
371        self.transform_valid(
372            "kekulize",
373            "KekulizeError",
374            errors,
375            |molecule| {
376                molecule
377                    .with_kekulized_bonds(sanitize)
378                    .map_err(|error| format!("{error:?}"))
379            },
380            progress,
381        )
382    }
383
384    pub fn compute_2d_coords(&self, errors: BatchErrorMode) -> Result<Self, BatchValidationError> {
385        self.compute_2d_coords_with_progress(errors, None)
386    }
387
388    pub fn compute_2d_coords_with_progress(
389        &self,
390        errors: BatchErrorMode,
391        progress: BatchProgress<'_>,
392    ) -> Result<Self, BatchValidationError> {
393        self.transform_valid(
394            "compute_2d_coords",
395            "CoordinateGenerationError",
396            errors,
397            |molecule| molecule.with_2d_coords().map_err(|error| error.to_string()),
398            progress,
399        )
400    }
401
402    pub fn to_smiles_list(
403        &self,
404        isomeric_smiles: bool,
405    ) -> Result<Vec<Option<String>>, BatchValidationError> {
406        self.collect_optional_values(
407            "to_smiles",
408            "SmilesWriteError",
409            |molecule| {
410                molecule
411                    .to_smiles(isomeric_smiles)
412                    .map_err(|error| error.to_string())
413            },
414            None,
415        )
416    }
417
418    pub fn to_smiles_list_with_params(
419        &self,
420        params: &SmilesWriteParams,
421    ) -> Result<Vec<Option<String>>, BatchValidationError> {
422        self.to_smiles_list_with_params_and_progress(params, None)
423    }
424
425    pub fn to_smiles_list_with_params_and_progress(
426        &self,
427        params: &SmilesWriteParams,
428        progress: BatchProgress<'_>,
429    ) -> Result<Vec<Option<String>>, BatchValidationError> {
430        self.collect_optional_values(
431            "to_smiles",
432            "SmilesWriteError",
433            |molecule| {
434                molecule
435                    .to_smiles_with_params(params)
436                    .map_err(|error| error.to_string())
437            },
438            progress,
439        )
440    }
441
442    pub fn dg_bounds_matrix_list(
443        &self,
444    ) -> Result<Vec<Option<Vec<Vec<f64>>>>, BatchValidationError> {
445        self.dg_bounds_matrix_list_with_progress(None)
446    }
447
448    pub fn dg_bounds_matrix_list_with_progress(
449        &self,
450        progress: BatchProgress<'_>,
451    ) -> Result<Vec<Option<Vec<Vec<f64>>>>, BatchValidationError> {
452        self.collect_optional_values(
453            "dg_bounds_matrix",
454            "DistanceGeometryError",
455            |molecule| {
456                molecule
457                    .dg_bounds_matrix()
458                    .map_err(|error| error.to_string())
459            },
460            progress,
461        )
462    }
463
464    pub fn morgan_fingerprint_list(
465        &self,
466        params: &crate::MorganFingerprintParams,
467    ) -> Result<Vec<Option<crate::Fingerprint>>, BatchValidationError> {
468        self.morgan_fingerprint_list_with_progress(params, None)
469    }
470
471    pub fn morgan_fingerprint_list_with_progress(
472        &self,
473        params: &crate::MorganFingerprintParams,
474        progress: BatchProgress<'_>,
475    ) -> Result<Vec<Option<crate::Fingerprint>>, BatchValidationError> {
476        self.collect_optional_values(
477            "morgan_fingerprint",
478            "FingerprintError",
479            |molecule| {
480                molecule
481                    .morgan_fingerprint(params)
482                    .map_err(|error| error.to_string())
483            },
484            progress,
485        )
486    }
487
488    pub fn morgan_fingerprint_with_output_list(
489        &self,
490        params: &crate::MorganFingerprintParams,
491    ) -> Result<Vec<Option<crate::MorganFingerprintOutput>>, BatchValidationError> {
492        self.morgan_fingerprint_with_output_list_with_progress(params, None)
493    }
494
495    pub fn morgan_fingerprint_with_output_list_with_progress(
496        &self,
497        params: &crate::MorganFingerprintParams,
498        progress: BatchProgress<'_>,
499    ) -> Result<Vec<Option<crate::MorganFingerprintOutput>>, BatchValidationError> {
500        self.collect_optional_values(
501            "morgan_fingerprint",
502            "FingerprintError",
503            |molecule| {
504                molecule
505                    .morgan_fingerprint_with_output(params)
506                    .map_err(|error| error.to_string())
507            },
508            progress,
509        )
510    }
511
512    pub fn to_svg_list(
513        &self,
514        width: u32,
515        height: u32,
516    ) -> Result<Vec<Option<String>>, BatchValidationError> {
517        self.to_svg_list_with_progress(width, height, None)
518    }
519
520    pub fn to_svg_list_with_progress(
521        &self,
522        width: u32,
523        height: u32,
524        progress: BatchProgress<'_>,
525    ) -> Result<Vec<Option<String>>, BatchValidationError> {
526        self.collect_optional_values(
527            "to_svg",
528            "SvgDrawError",
529            |molecule| {
530                molecule
531                    .to_svg(width, height)
532                    .map_err(|error| error.to_string())
533            },
534            progress,
535        )
536    }
537
538    pub fn prepare_for_drawing_parity_list(
539        &self,
540    ) -> Result<Vec<Option<PreparedDrawMolecule>>, BatchValidationError> {
541        self.collect_optional_values(
542            "prepare_for_drawing_parity",
543            "PreparedDrawError",
544            |molecule| {
545                molecule
546                    .prepare_for_drawing_parity()
547                    .map_err(|error| error.to_string())
548            },
549            None,
550        )
551    }
552
553    fn transform_valid<F>(
554        &self,
555        stage: &'static str,
556        error_type: &'static str,
557        errors: BatchErrorMode,
558        transform: F,
559        progress: BatchProgress<'_>,
560    ) -> Result<Self, BatchValidationError>
561    where
562        F: Fn(&Molecule) -> Result<Molecule, String> + Sync,
563    {
564        let records: Vec<BatchRecord> = self
565            .records
566            .par_iter()
567            .enumerate()
568            .filter_map(|(index, record)| {
569                let out = match record {
570                    BatchRecord::Valid(molecule) => match transform(molecule) {
571                        Ok(molecule) => Some(BatchRecord::Valid(molecule)),
572                        Err(message) => {
573                            let error =
574                                BatchRecordError::new(index, None, stage, error_type, message);
575                            match errors {
576                                BatchErrorMode::Raise | BatchErrorMode::Keep => {
577                                    Some(BatchRecord::Invalid(error))
578                                }
579                                BatchErrorMode::Skip => None,
580                            }
581                        }
582                    },
583                    BatchRecord::Invalid(error) => match errors {
584                        BatchErrorMode::Raise | BatchErrorMode::Keep => {
585                            Some(BatchRecord::Invalid(error.clone()))
586                        }
587                        BatchErrorMode::Skip => None,
588                    },
589                };
590                tick_progress(progress);
591                out
592            })
593            .collect();
594        Self::from_records_with_mode(records, errors)
595    }
596
597    fn collect_optional_values<T, F>(
598        &self,
599        stage: &'static str,
600        error_type: &'static str,
601        collect: F,
602        progress: BatchProgress<'_>,
603    ) -> Result<Vec<Option<T>>, BatchValidationError>
604    where
605        T: Send,
606        F: Fn(&Molecule) -> Result<T, String> + Sync,
607    {
608        let pairs: Vec<(Option<T>, Option<BatchRecordError>)> = self
609            .records
610            .par_iter()
611            .enumerate()
612            .map(|(index, record)| {
613                let out = match record {
614                    BatchRecord::Valid(molecule) => match collect(molecule) {
615                        Ok(value) => (Some(value), None),
616                        Err(message) => (
617                            None,
618                            Some(BatchRecordError::new(
619                                index, None, stage, error_type, message,
620                            )),
621                        ),
622                    },
623                    BatchRecord::Invalid(_) => (None, None),
624                };
625                tick_progress(progress);
626                out
627            })
628            .collect();
629        let mut values = Vec::with_capacity(pairs.len());
630        let mut errors = Vec::new();
631        for (value, error) in pairs {
632            values.push(value);
633            if let Some(error) = error {
634                errors.push(error);
635            }
636        }
637        if errors.is_empty() {
638            Ok(values)
639        } else {
640            Err(BatchValidationError { errors })
641        }
642    }
643
644    pub fn write_images(
645        &self,
646        out_dir: &Path,
647        format: &str,
648        width: u32,
649        height: u32,
650        errors: BatchErrorMode,
651        filenames: Option<&[Option<String>]>,
652    ) -> Result<BatchExportReport, BatchValidationError> {
653        self.write_images_with_progress(out_dir, format, width, height, errors, filenames, None)
654    }
655
656    pub fn write_images_with_progress(
657        &self,
658        out_dir: &Path,
659        format: &str,
660        width: u32,
661        height: u32,
662        errors: BatchErrorMode,
663        filenames: Option<&[Option<String>]>,
664        progress: BatchProgress<'_>,
665    ) -> Result<BatchExportReport, BatchValidationError> {
666        let format = format.to_ascii_lowercase();
667        let paths = output_paths(out_dir, self.records.len(), &format, filenames, "to_images")?;
668        fs::create_dir_all(out_dir).map_err(|error| BatchValidationError {
669            errors: vec![BatchRecordError::new(
670                0,
671                Some(out_dir.display().to_string()),
672                "to_images",
673                "IoError",
674                format!("create output directory failed: {error}"),
675            )],
676        })?;
677        let outcomes: Vec<Result<(), BatchRecordError>> = self
678            .records
679            .par_iter()
680            .enumerate()
681            .filter_map(|(index, record)| {
682                let out = match record {
683                    BatchRecord::Valid(molecule) => Some(
684                        write_one_image(molecule, &paths[index], &format, width, height).map_err(
685                            |message| {
686                                BatchRecordError::new(
687                                    index,
688                                    Some(paths[index].display().to_string()),
689                                    "to_images",
690                                    "ImageExportError",
691                                    message,
692                                )
693                            },
694                        ),
695                    ),
696                    BatchRecord::Invalid(error) => match errors {
697                        BatchErrorMode::Raise | BatchErrorMode::Keep => Some(Err(error.clone())),
698                        BatchErrorMode::Skip => None,
699                    },
700                };
701                tick_progress(progress);
702                out
703            })
704            .collect();
705        let mut success = 0usize;
706        let mut report_errors = Vec::new();
707        for outcome in outcomes {
708            match outcome {
709                Ok(()) => success += 1,
710                Err(error) => report_errors.push(error),
711            }
712        }
713        if matches!(errors, BatchErrorMode::Raise) && !report_errors.is_empty() {
714            return Err(BatchValidationError {
715                errors: report_errors,
716            });
717        }
718        Ok(BatchExportReport {
719            total: self.records.len(),
720            success,
721            failed: report_errors.len(),
722            errors: report_errors,
723        })
724    }
725
726    pub fn write_sdf(
727        &self,
728        path: &Path,
729        format: SdfFormat,
730        errors: BatchErrorMode,
731    ) -> Result<BatchExportReport, BatchValidationError> {
732        self.write_sdf_with_progress(path, format, errors, None)
733    }
734
735    pub fn write_sdf_with_progress(
736        &self,
737        path: &Path,
738        format: SdfFormat,
739        errors: BatchErrorMode,
740        progress: BatchProgress<'_>,
741    ) -> Result<BatchExportReport, BatchValidationError> {
742        let outcomes: Vec<Result<String, BatchRecordError>> = self
743            .records
744            .par_iter()
745            .enumerate()
746            .filter_map(|(index, record)| {
747                let out = match record {
748                    BatchRecord::Valid(molecule) => Some(
749                        molecule_to_sdf_record_string(molecule, format).map_err(|error| {
750                            BatchRecordError::new(index, None, "to_sdf", "SdfWriteError", error)
751                        }),
752                    ),
753                    BatchRecord::Invalid(error) => match errors {
754                        BatchErrorMode::Raise | BatchErrorMode::Keep => Some(Err(error.clone())),
755                        BatchErrorMode::Skip => None,
756                    },
757                };
758                tick_progress(progress);
759                out
760            })
761            .collect();
762        let mut blocks = Vec::new();
763        let mut report_errors = Vec::new();
764        for outcome in outcomes {
765            match outcome {
766                Ok(block) => blocks.push(block),
767                Err(error) => report_errors.push(error),
768            }
769        }
770        if matches!(errors, BatchErrorMode::Raise) && !report_errors.is_empty() {
771            return Err(BatchValidationError {
772                errors: report_errors,
773            });
774        }
775        let mut file = File::create(path).map_err(|error| BatchValidationError {
776            errors: vec![BatchRecordError::new(
777                0,
778                Some(path.display().to_string()),
779                "to_sdf",
780                "IoError",
781                format!("create SDF failed: {error}"),
782            )],
783        })?;
784        let success = blocks.len();
785        for block in blocks {
786            file.write_all(block.as_bytes())
787                .map_err(|error| BatchValidationError {
788                    errors: vec![BatchRecordError::new(
789                        0,
790                        Some(path.display().to_string()),
791                        "to_sdf",
792                        "IoError",
793                        format!("write SDF failed: {error}"),
794                    )],
795                })?;
796        }
797        Ok(BatchExportReport {
798            total: self.records.len(),
799            success,
800            failed: report_errors.len(),
801            errors: report_errors,
802        })
803    }
804
805    pub fn write_sdf_files(
806        &self,
807        out_dir: &Path,
808        format: SdfFormat,
809        errors: BatchErrorMode,
810        filenames: Option<&[Option<String>]>,
811    ) -> Result<BatchExportReport, BatchValidationError> {
812        self.write_sdf_files_with_progress(out_dir, format, errors, filenames, None)
813    }
814
815    pub fn write_sdf_files_with_progress(
816        &self,
817        out_dir: &Path,
818        format: SdfFormat,
819        errors: BatchErrorMode,
820        filenames: Option<&[Option<String>]>,
821        progress: BatchProgress<'_>,
822    ) -> Result<BatchExportReport, BatchValidationError> {
823        let paths = output_paths(
824            out_dir,
825            self.records.len(),
826            "sdf",
827            filenames,
828            "to_sdf_files",
829        )?;
830        fs::create_dir_all(out_dir).map_err(|error| BatchValidationError {
831            errors: vec![BatchRecordError::new(
832                0,
833                Some(out_dir.display().to_string()),
834                "to_sdf_files",
835                "IoError",
836                format!("create output directory failed: {error}"),
837            )],
838        })?;
839        let outcomes: Vec<Result<(), BatchRecordError>> = self
840            .records
841            .par_iter()
842            .enumerate()
843            .filter_map(|(index, record)| {
844                let out = match record {
845                    BatchRecord::Valid(molecule) => Some(
846                        write_one_sdf_file(molecule, &paths[index], format).map_err(|message| {
847                            BatchRecordError::new(
848                                index,
849                                Some(paths[index].display().to_string()),
850                                "to_sdf_files",
851                                "SdfWriteError",
852                                message,
853                            )
854                        }),
855                    ),
856                    BatchRecord::Invalid(error) => match errors {
857                        BatchErrorMode::Raise | BatchErrorMode::Keep => Some(Err(error.clone())),
858                        BatchErrorMode::Skip => None,
859                    },
860                };
861                tick_progress(progress);
862                out
863            })
864            .collect();
865        let mut success = 0usize;
866        let mut report_errors = Vec::new();
867        for outcome in outcomes {
868            match outcome {
869                Ok(()) => success += 1,
870                Err(error) => report_errors.push(error),
871            }
872        }
873        if matches!(errors, BatchErrorMode::Raise) && !report_errors.is_empty() {
874            return Err(BatchValidationError {
875                errors: report_errors,
876            });
877        }
878        Ok(BatchExportReport {
879            total: self.records.len(),
880            success,
881            failed: report_errors.len(),
882            errors: report_errors,
883        })
884    }
885}
886
887fn write_one_image(
888    molecule: &Molecule,
889    path: &Path,
890    format: &str,
891    width: u32,
892    height: u32,
893) -> Result<(), String> {
894    match format {
895        "svg" => {
896            let svg = molecule
897                .to_svg(width, height)
898                .map_err(|error| error.to_string())?;
899            fs::write(path, svg).map_err(|error| error.to_string())
900        }
901        "png" => {
902            let png = molecule
903                .to_png(width, height)
904                .map_err(|error| error.to_string())?;
905            fs::write(path, png).map_err(|error| error.to_string())
906        }
907        other => Err(format!(
908            "unsupported image format '{other}', expected 'png' or 'svg'"
909        )),
910    }
911}
912
913fn write_one_sdf_file(molecule: &Molecule, path: &Path, format: SdfFormat) -> Result<(), String> {
914    let block = molecule_to_sdf_record_string(molecule, format)?;
915    fs::write(path, block).map_err(|error| error.to_string())
916}
917
918fn molecule_to_sdf_record_string(molecule: &Molecule, format: SdfFormat) -> Result<String, String> {
919    if molecule.coords_2d().is_some() {
920        molblock::mol_to_2d_sdf_record(molecule, format).map_err(|error| error.to_string())
921    } else if molecule.coords_3d().is_some() {
922        molblock::mol_to_3d_sdf_record(molecule, format).map_err(|error| error.to_string())
923    } else {
924        Err(
925            "SDF writing requires coordinates; call with_2d_coords() or read a molecule with 2D/3D coordinates before writing SDF"
926                .to_owned(),
927        )
928    }
929}
930
931fn output_paths(
932    out_dir: &Path,
933    total: usize,
934    extension: &str,
935    filenames: Option<&[Option<String>]>,
936    stage: &'static str,
937) -> Result<Vec<PathBuf>, BatchValidationError> {
938    if let Some(filenames) = filenames
939        && filenames.len() != total
940    {
941        return Err(BatchValidationError {
942            errors: vec![BatchRecordError::new(
943                0,
944                None,
945                stage,
946                "FilenameError",
947                format!(
948                    "filenames length {} must match batch length {total}",
949                    filenames.len()
950                ),
951            )],
952        });
953    }
954
955    let mut seen = HashSet::new();
956    let mut paths = Vec::with_capacity(total);
957    for index in 0..total {
958        let filename = match filenames.and_then(|items| items[index].as_deref()) {
959            Some(raw) => normalize_output_filename(raw, extension).map_err(|message| {
960                BatchValidationError {
961                    errors: vec![BatchRecordError::new(
962                        index,
963                        Some(raw.to_string()),
964                        stage,
965                        "FilenameError",
966                        message,
967                    )],
968                }
969            })?,
970            None => format!("{index:06}.{extension}"),
971        };
972        if !seen.insert(filename.clone()) {
973            return Err(BatchValidationError {
974                errors: vec![BatchRecordError::new(
975                    index,
976                    Some(filename),
977                    stage,
978                    "FilenameError",
979                    "duplicate output filename",
980                )],
981            });
982        }
983        paths.push(out_dir.join(filename));
984    }
985    Ok(paths)
986}
987
/// Validates a caller-supplied output filename and makes sure it ends in `extension`.
///
/// Surrounding whitespace is trimmed first. The remaining text must be a
/// single plain path component: it may not be empty, absolute, contain path
/// separators, or be `.` / `..`. If the name already has an extension it must
/// match `extension` (ASCII case-insensitively) and the name is returned
/// unchanged; if it has none, `.{extension}` is appended.
///
/// # Errors
///
/// Returns a human-readable message describing which rule the name violated.
fn normalize_output_filename(raw: &str, extension: &str) -> Result<String, String> {
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        return Err("filename must not be empty".to_string());
    }
    let path = Path::new(trimmed);
    if path.is_absolute() {
        return Err("filename must be relative to the output directory".to_string());
    }
    let mut components = path.components();
    let (Some(Component::Normal(only)), None) = (components.next(), components.next()) else {
        return Err("filename must not include path separators or '..'".to_string());
    };
    let Some(file_name) = only.to_str() else {
        return Err("filename must be valid UTF-8".to_string());
    };
    // `Path::components` normalizes away trailing separators and trailing `.`
    // segments, so inputs such as "name/" or "name/." still yield exactly one
    // normal component. Requiring that component to be the whole input keeps
    // the "no path separators" rule honest instead of silently rewriting it.
    if file_name != trimmed {
        return Err("filename must not include path separators or '..'".to_string());
    }
    match path.extension().and_then(|value| value.to_str()) {
        Some(actual) if actual.eq_ignore_ascii_case(extension) => Ok(file_name.to_string()),
        Some(actual) => Err(format!(
            "filename extension '.{actual}' does not match expected '.{extension}'"
        )),
        None => Ok(format!("{file_name}.{extension}")),
    }
}
1012
/// Splits multi-record SDF text into one string per record.
///
/// A record ends at a line that equals `$$$$` once trailing whitespace is
/// ignored; the terminator line stays part of its record. Every kept line is
/// re-emitted with a single `\n` terminator. Trailing text without a `$$$$`
/// terminator becomes a final record unless it is whitespace-only.
///
/// Directly after a terminator, one blank line is dropped when the line that
/// follows it is blank as well; a lone blank line there is kept as the first
/// line of the next record.
fn split_sdf_record_strings(sdf_text: &str) -> Vec<String> {
    let mut finished = Vec::new();
    let mut pending = String::new();
    let mut just_closed = false;
    let mut remaining = sdf_text.lines().peekable();

    while let Some(line) = remaining.next() {
        // The flag only describes the line immediately after a terminator,
        // so read it and clear it in one step.
        let follows_terminator = std::mem::replace(&mut just_closed, false);
        let next_is_blank = remaining
            .peek()
            .is_some_and(|upcoming| upcoming.trim().is_empty());
        let is_padding =
            follows_terminator && pending.is_empty() && line.trim().is_empty() && next_is_blank;
        if is_padding {
            continue;
        }

        pending.push_str(line);
        pending.push('\n');
        if line.trim_end() == "$$$$" {
            finished.push(std::mem::take(&mut pending));
            just_closed = true;
        }
    }

    if !pending.trim().is_empty() {
        finished.push(pending);
    }

    finished
}
1045
#[cfg(test)]
mod tests {
    use super::{BatchErrorMode, BatchRecord, MoleculeBatch};
    use crate::Molecule;
    use crate::io::molblock::SdfFormat;
    use std::fs;
    use std::path::PathBuf;

    /// Returns a scratch directory path under the system temp dir whose name
    /// combines `name`, the process id and the current thread id.
    fn unique_temp_dir(name: &str) -> PathBuf {
        let pid = std::process::id();
        let thread_id = std::thread::current().id();
        let leaf = format!("cosmolkit-batch-{name}-{}-{:?}", pid, thread_id);
        std::env::temp_dir().join(leaf)
    }

    /// Turns string literals into the `Vec<String>` these tests pass to
    /// `MoleculeBatch::from_smiles_list`.
    fn smiles_list(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_string()).collect()
    }

    #[test]
    fn from_smiles_list_keeps_invalid_records_in_order() {
        let inputs = smiles_list(&["CCO", "C1CC", "CC"]);
        let batch = MoleculeBatch::from_smiles_list(&inputs, BatchErrorMode::Keep)
            .expect("keep mode should not raise");

        assert_eq!(batch.len(), 3);
        assert_eq!(batch.valid_mask(), vec![true, false, true]);
        assert_eq!(batch.errors()[0].index, 1);

        // Only the two parseable inputs should survive filtering, in input order.
        let expected = vec![Some("CCO".to_string()), Some("CC".to_string())];
        assert_eq!(
            batch
                .filter_valid()
                .to_smiles_list(true)
                .expect("valid records should serialize to SMILES"),
            expected
        );
    }

    #[test]
    fn from_smiles_list_raise_aggregates_errors() {
        let inputs = smiles_list(&["C1CC", "N1"]);
        let failure = MoleculeBatch::from_smiles_list(&inputs, BatchErrorMode::Raise)
            .expect_err("raise mode should aggregate invalid inputs");

        let reported = &failure.errors;
        assert_eq!(reported.len(), 2);
        assert_eq!(reported[0].stage, "parse_smiles");
        assert_eq!(reported[1].index, 1);
    }

    #[test]
    fn transforms_skip_invalid_records_when_requested() {
        let inputs = smiles_list(&["CCO", "C1CC"]);
        let batch = MoleculeBatch::from_smiles_list(&inputs, BatchErrorMode::Keep)
            .expect("keep mode should not raise");

        let hydrogenated = batch
            .add_hydrogens(BatchErrorMode::Skip)
            .expect("skip mode should drop invalid records");
        let prepared = hydrogenated
            .compute_2d_coords(BatchErrorMode::Skip)
            .expect("2D coords should compute for valid record");

        assert_eq!(prepared.len(), 1);
        assert_eq!(prepared.valid_count(), 1);
        assert_eq!(prepared.invalid_count(), 0);
    }

    #[test]
    fn sanitize_transform_applies_pipeline_to_valid_records() {
        let unsanitized = Molecule::from_smiles_with_sanitize("CN(=O)=O", false)
            .expect("unsanitized nitro SMILES should parse");
        let batch = MoleculeBatch::new(vec![BatchRecord::Valid(unsanitized)]);

        let sanitized = batch
            .sanitize(BatchErrorMode::Raise)
            .expect("sanitize should transform valid records");

        let mol = match &sanitized.records[0] {
            BatchRecord::Valid(mol) => mol,
            BatchRecord::Invalid(_) => panic!("record should remain valid"),
        };
        let charges: Vec<_> = mol.atoms().iter().map(|atom| atom.formal_charge).collect();
        assert_eq!(charges, vec![0, 1, -1, 0]);
    }

    #[test]
    fn write_images_uses_custom_filenames_and_rejects_unsafe_names() {
        let inputs = smiles_list(&["CCO", "C1CC", "CC"]);
        let batch = MoleculeBatch::from_smiles_list(&inputs, BatchErrorMode::Keep)
            .expect("keep mode should not raise");

        // Happy path: one name without extension, one for the invalid record,
        // and one record left to the index-based default name.
        let out_dir = unique_temp_dir("images");
        let filenames = vec![
            Some("ethanol".to_string()),
            Some("invalid.svg".to_string()),
            None,
        ];
        let report = batch
            .write_images(
                &out_dir,
                "svg",
                120,
                100,
                BatchErrorMode::Skip,
                Some(&filenames),
            )
            .expect("custom image filenames should write");

        assert_eq!(report.success, 2);
        let ethanol_path = out_dir.join("ethanol.svg");
        let default_path = out_dir.join("000002.svg");
        assert!(ethanol_path.exists());
        assert!(default_path.exists());
        let _ = fs::remove_dir_all(&out_dir);

        // A name that tries to leave the output directory must be rejected.
        let bad_dir = unique_temp_dir("bad-images");
        let bad = vec![Some("../escape".to_string()), None, None];
        let failure = batch
            .write_images(
                &bad_dir,
                "svg",
                120,
                100,
                BatchErrorMode::Skip,
                Some(&bad),
            )
            .expect_err("unsafe filename should be rejected");
        assert_eq!(failure.errors[0].error_type, "FilenameError");
    }

    #[test]
    fn write_sdf_files_uses_custom_filenames() {
        let inputs = smiles_list(&["CCO", "CC"]);
        let parsed = MoleculeBatch::from_smiles_list(&inputs, BatchErrorMode::Raise)
            .expect("SMILES should parse");
        let batch = parsed
            .compute_2d_coords(BatchErrorMode::Raise)
            .expect("2D coords should compute");

        let out_dir = unique_temp_dir("sdf-files");
        let filenames = vec![Some("ethanol".to_string()), Some("ethane.sdf".to_string())];
        let report = batch
            .write_sdf_files(
                &out_dir,
                SdfFormat::V2000,
                BatchErrorMode::Raise,
                Some(&filenames),
            )
            .expect("custom SDF filenames should write");

        assert_eq!(report.success, 2);
        for expected_name in ["ethanol.sdf", "ethane.sdf"] {
            assert!(out_dir.join(expected_name).exists());
        }
        let _ = fs::remove_dir_all(&out_dir);
    }
}