oxirs_embed/
batch_processing.rs

1//! Offline batch embedding generation with incremental updates
2//!
3//! This module provides comprehensive batch processing capabilities for generating
4//! embeddings offline, with support for incremental updates, resumable jobs,
5//! and efficient resource utilization with SciRS2 integration.
6
7use crate::{CacheManager, EmbeddingModel};
8use anyhow::{anyhow, Result};
9use chrono::{DateTime, Utc};
10use rayon::prelude::*;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13use std::path::{Path, PathBuf};
14use std::sync::Arc;
15use std::time::Instant;
16use tokio::fs;
17use tokio::sync::{RwLock, Semaphore};
18use tokio::task::JoinHandle;
19use tracing::{debug, info, warn};
20use uuid::Uuid;
21
/// Memory-optimized batch iterator for large datasets.
///
/// Yields clones of consecutive runs of `data`. Each batch is capped both by
/// `batch_size` items and by an approximate byte budget. The byte estimate is
/// `size_of::<T>()` per item, i.e. only the inline size of `T`; heap memory
/// owned by an item (such as a `String`'s buffer) is not counted.
pub struct MemoryOptimizedBatchIterator<T> {
    /// Items to iterate over.
    data: Vec<T>,
    /// Index of the next item to yield.
    position: usize,
    /// Maximum number of items per batch.
    batch_size: usize,
    /// Estimated size in bytes of the most recently returned batch.
    memory_usage: usize,
    /// Byte budget for a single batch.
    max_memory_bytes: usize,
}

impl<T> MemoryOptimizedBatchIterator<T> {
    /// Create a new memory-optimized batch iterator.
    ///
    /// `max_memory_mb` is converted to bytes with saturating arithmetic, so a
    /// very large value cannot overflow `usize` (notably on 32-bit targets).
    pub fn new(data: Vec<T>, batch_size: usize, max_memory_mb: usize) -> Self {
        Self {
            data,
            position: 0,
            batch_size,
            memory_usage: 0,
            max_memory_bytes: max_memory_mb.saturating_mul(1024 * 1024),
        }
    }

    /// Get the next batch, or `None` once every item has been yielded.
    ///
    /// A batch holds at most `batch_size` items and at most as many items as
    /// fit in the byte budget. While data remains, it always holds at least
    /// one item. A `batch_size` of 0, or a budget smaller than a single item,
    /// therefore cannot stall the iterator, and `is_finished` is guaranteed to
    /// become true eventually.
    pub fn next_batch(&mut self) -> Option<Vec<T>>
    where
        T: Clone,
    {
        let remaining = self.data.len().saturating_sub(self.position);
        if remaining == 0 {
            return None;
        }

        let item_size = std::mem::size_of::<T>();
        // Zero-sized types never consume the budget (checked_div(0) is None).
        let fits_in_budget = self
            .max_memory_bytes
            .checked_div(item_size)
            .unwrap_or(usize::MAX);
        // Always make progress: at least one item per call while data remains.
        let take = remaining.min(self.batch_size).min(fits_in_budget).max(1);

        let end = self.position + take;
        let batch = self.data[self.position..end].to_vec();
        self.position = end;
        self.memory_usage = take.saturating_mul(item_size);
        Some(batch)
    }

    /// Estimated size in bytes of the most recently returned batch.
    pub fn get_memory_usage(&self) -> usize {
        self.memory_usage
    }

    /// Fraction of items already yielded, in `0.0..=1.0` (1.0 for empty data).
    pub fn get_progress(&self) -> f64 {
        if self.data.is_empty() {
            1.0
        } else {
            self.position as f64 / self.data.len() as f64
        }
    }

    /// Whether every item has been yielded.
    pub fn is_finished(&self) -> bool {
        self.position >= self.data.len()
    }
}
99
100/// Batch processing manager for offline embedding generation
101pub struct BatchProcessingManager {
102    /// Active batch jobs
103    active_jobs: Arc<RwLock<HashMap<Uuid, BatchJob>>>,
104    /// Configuration
105    config: BatchProcessingConfig,
106    /// Cache manager for optimization
107    cache_manager: Arc<CacheManager>,
108    /// Concurrency semaphore
109    semaphore: Arc<Semaphore>,
110    /// Job persistence directory
111    persistence_dir: PathBuf,
112}
113
114/// Configuration for batch processing
115#[derive(Debug, Clone)]
116pub struct BatchProcessingConfig {
117    /// Maximum concurrent workers
118    pub max_workers: usize,
119    /// Chunk size for processing
120    pub chunk_size: usize,
121    /// Enable incremental updates
122    pub enable_incremental: bool,
123    /// Checkpoint frequency (number of chunks)
124    pub checkpoint_frequency: usize,
125    /// Enable resume from checkpoint
126    pub enable_resume: bool,
127    /// Maximum memory usage per worker (MB)
128    pub max_memory_per_worker_mb: usize,
129    /// Enable progress notifications
130    pub enable_notifications: bool,
131    /// Retry configuration
132    pub retry_config: RetryConfig,
133    /// Output format configuration
134    pub output_config: OutputConfig,
135}
136
137impl Default for BatchProcessingConfig {
138    fn default() -> Self {
139        Self {
140            max_workers: num_cpus::get(),
141            chunk_size: 1000,
142            enable_incremental: true,
143            checkpoint_frequency: 10,
144            enable_resume: true,
145            max_memory_per_worker_mb: 512,
146            enable_notifications: true,
147            retry_config: RetryConfig::default(),
148            output_config: OutputConfig::default(),
149        }
150    }
151}
152
/// Retry policy: an attempt limit plus backoff delay parameters.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retry attempts.
    pub max_retries: usize,
    /// Delay before the first retry, in milliseconds.
    pub initial_backoff_ms: u64,
    /// Upper bound on a single backoff delay, in milliseconds.
    pub max_backoff_ms: u64,
    /// Factor applied to the delay between successive attempts.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Three retries, starting at one second, doubling, capped at thirty seconds.
    fn default() -> Self {
        const ONE_SECOND_MS: u64 = 1_000;

        Self {
            max_retries: 3,
            initial_backoff_ms: ONE_SECOND_MS,
            max_backoff_ms: 30 * ONE_SECOND_MS,
            backoff_multiplier: 2.0,
        }
    }
}
176
/// Settings describing how generated embeddings are written out.
#[derive(Debug, Clone)]
pub struct OutputConfig {
    /// Serialization format of the output files.
    pub format: OutputFormat,
    /// Compression level (0-9).
    pub compression_level: u32,
    /// Whether metadata is written alongside the embeddings.
    pub include_metadata: bool,
    /// Whether output is split across several files.
    pub batch_output: bool,
    /// Maximum number of entities stored in one output file.
    pub max_entities_per_file: usize,
}

impl Default for OutputConfig {
    /// Parquet at compression level 6, with metadata, split into files of at
    /// most 100 000 entities.
    fn default() -> Self {
        let format = OutputFormat::Parquet;
        Self {
            format,
            compression_level: 6,
            include_metadata: true,
            batch_output: true,
            max_entities_per_file: 100_000,
        }
    }
}

/// File formats embeddings can be written in.
#[derive(Debug, Clone)]
pub enum OutputFormat {
    /// Apache Parquet.
    Parquet,
    /// Compressed JSON Lines.
    JsonLines,
    /// Custom binary layout.
    Binary,
    /// HDF5.
    HDF5,
}
216
217/// Batch job definition
218#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct BatchJob {
220    /// Unique job ID
221    pub job_id: Uuid,
222    /// Job name
223    pub name: String,
224    /// Job status
225    pub status: JobStatus,
226    /// Input specification
227    pub input: BatchInput,
228    /// Output specification
229    pub output: BatchOutput,
230    /// Processing configuration
231    pub config: BatchJobConfig,
232    /// Model information
233    pub model_id: Uuid,
234    /// Created timestamp
235    pub created_at: DateTime<Utc>,
236    /// Started timestamp
237    pub started_at: Option<DateTime<Utc>>,
238    /// Completed timestamp
239    pub completed_at: Option<DateTime<Utc>>,
240    /// Progress information
241    pub progress: JobProgress,
242    /// Error information
243    pub error: Option<String>,
244    /// Checkpoint data
245    pub checkpoint: Option<JobCheckpoint>,
246}
247
248/// Job status
249#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
250pub enum JobStatus {
251    Pending,
252    Running,
253    Completed,
254    Failed,
255    Cancelled,
256    Paused,
257}
258
259/// Batch input specification
260#[derive(Debug, Clone, Serialize, Deserialize)]
261pub struct BatchInput {
262    /// Input type
263    pub input_type: InputType,
264    /// Input source
265    pub source: String,
266    /// Filter criteria
267    pub filters: Option<HashMap<String, String>>,
268    /// Incremental mode settings
269    pub incremental: Option<IncrementalConfig>,
270}
271
272/// Input types
273#[derive(Debug, Clone, Serialize, Deserialize)]
274pub enum InputType {
275    /// List of entity IDs
276    EntityList,
277    /// File containing entity IDs
278    EntityFile,
279    /// SPARQL query result
280    SparqlQuery,
281    /// Database query
282    DatabaseQuery,
283    /// Stream source
284    StreamSource,
285}
286
287/// Incremental processing configuration
288#[derive(Debug, Clone, Serialize, Deserialize)]
289pub struct IncrementalConfig {
290    /// Enable incremental processing
291    pub enabled: bool,
292    /// Last processed timestamp
293    pub last_processed: Option<DateTime<Utc>>,
294    /// Timestamp field name
295    pub timestamp_field: String,
296    /// Check for deletions
297    pub check_deletions: bool,
298    /// Existing embeddings source
299    pub existing_embeddings_path: Option<String>,
300}
301
302/// Batch output specification
303#[derive(Debug, Clone, Serialize, Deserialize)]
304pub struct BatchOutput {
305    /// Output path
306    pub path: String,
307    /// Output format
308    pub format: String,
309    /// Compression settings
310    pub compression: Option<String>,
311    /// Partitioning strategy
312    pub partitioning: Option<PartitioningStrategy>,
313}
314
315/// Partitioning strategy
316#[derive(Debug, Clone, Serialize, Deserialize)]
317pub enum PartitioningStrategy {
318    /// No partitioning
319    None,
320    /// Partition by entity type
321    ByEntityType,
322    /// Partition by date
323    ByDate,
324    /// Partition by hash
325    ByHash { num_partitions: usize },
326    /// Custom partitioning
327    Custom { field: String },
328}
329
330/// Job-specific configuration
331#[derive(Debug, Clone, Serialize, Deserialize)]
332pub struct BatchJobConfig {
333    /// Chunk size for this job
334    pub chunk_size: usize,
335    /// Number of workers
336    pub num_workers: usize,
337    /// Retry configuration
338    pub max_retries: usize,
339    /// Enable caching
340    pub use_cache: bool,
341    /// Custom parameters
342    pub custom_params: HashMap<String, String>,
343}
344
345/// Job progress information
346#[derive(Debug, Clone, Serialize, Deserialize)]
347pub struct JobProgress {
348    /// Total entities to process
349    pub total_entities: usize,
350    /// Entities processed
351    pub processed_entities: usize,
352    /// Entities failed
353    pub failed_entities: usize,
354    /// Current chunk being processed
355    pub current_chunk: usize,
356    /// Total chunks
357    pub total_chunks: usize,
358    /// Processing rate (entities/second)
359    pub processing_rate: f64,
360    /// Estimated time remaining
361    pub eta_seconds: Option<u64>,
362    /// Memory usage (MB)
363    pub memory_usage_mb: f64,
364}
365
366impl Default for JobProgress {
367    fn default() -> Self {
368        Self {
369            total_entities: 0,
370            processed_entities: 0,
371            failed_entities: 0,
372            current_chunk: 0,
373            total_chunks: 0,
374            processing_rate: 0.0,
375            eta_seconds: None,
376            memory_usage_mb: 0.0,
377        }
378    }
379}
380
381/// Job checkpoint for resumability
382#[derive(Debug, Clone, Serialize, Deserialize)]
383pub struct JobCheckpoint {
384    /// Checkpoint timestamp
385    pub timestamp: DateTime<Utc>,
386    /// Last processed entity index
387    pub last_processed_index: usize,
388    /// Processed entity IDs
389    pub processed_entities: HashSet<String>,
390    /// Failed entity IDs with error messages
391    pub failed_entities: HashMap<String, String>,
392    /// Intermediate results path
393    pub intermediate_results_path: String,
394    /// Model state hash
395    pub model_state_hash: String,
396}
397
398/// Batch processing result
399#[derive(Debug, Clone, Serialize, Deserialize)]
400pub struct BatchProcessingResult {
401    /// Job ID
402    pub job_id: Uuid,
403    /// Processing statistics
404    pub stats: BatchProcessingStats,
405    /// Output information
406    pub output_info: OutputInfo,
407    /// Quality metrics
408    pub quality_metrics: Option<QualityMetrics>,
409}
410
411/// Batch processing statistics
412#[derive(Debug, Clone, Serialize, Deserialize)]
413pub struct BatchProcessingStats {
414    /// Total processing time
415    pub total_time_seconds: f64,
416    /// Total entities processed
417    pub total_entities: usize,
418    /// Successful embeddings
419    pub successful_embeddings: usize,
420    /// Failed embeddings
421    pub failed_embeddings: usize,
422    /// Cache hits
423    pub cache_hits: usize,
424    /// Cache misses
425    pub cache_misses: usize,
426    /// Average processing time per entity (ms)
427    pub avg_time_per_entity_ms: f64,
428    /// Peak memory usage (MB)
429    pub peak_memory_mb: f64,
430    /// CPU utilization
431    pub cpu_utilization: f64,
432}
433
434/// Output information
435#[derive(Debug, Clone, Serialize, Deserialize)]
436pub struct OutputInfo {
437    /// Output files created
438    pub output_files: Vec<String>,
439    /// Total output size (bytes)
440    pub total_size_bytes: u64,
441    /// Compression ratio
442    pub compression_ratio: f64,
443    /// Number of partitions
444    pub num_partitions: usize,
445}
446
447/// Quality metrics for batch processing
448#[derive(Debug, Clone, Serialize, Deserialize)]
449pub struct QualityMetrics {
450    /// Average embedding norm
451    pub avg_embedding_norm: f64,
452    /// Embedding norm standard deviation
453    pub embedding_norm_std: f64,
454    /// Average cosine similarity to centroid
455    pub avg_cosine_similarity: f64,
456    /// Embedding dimension
457    pub embedding_dimension: usize,
458    /// Number of zero embeddings
459    pub zero_embeddings: usize,
460    /// Number of NaN embeddings
461    pub nan_embeddings: usize,
462}
463
464impl BatchProcessingManager {
465    /// Create a new batch processing manager
466    pub fn new(
467        config: BatchProcessingConfig,
468        cache_manager: Arc<CacheManager>,
469        persistence_dir: PathBuf,
470    ) -> Self {
471        Self {
472            active_jobs: Arc::new(RwLock::new(HashMap::new())),
473            semaphore: Arc::new(Semaphore::new(config.max_workers)),
474            config,
475            cache_manager,
476            persistence_dir,
477        }
478    }
479
480    /// Submit a new batch job
481    pub async fn submit_job(&self, job: BatchJob) -> Result<Uuid> {
482        let job_id = job.job_id;
483
484        // Validate job
485        self.validate_job(&job).await?;
486
487        // Store job
488        {
489            let mut jobs = self.active_jobs.write().await;
490            jobs.insert(job_id, job.clone());
491        }
492
493        // Persist job configuration
494        self.persist_job(&job).await?;
495
496        info!("Submitted batch job: {} ({})", job.name, job_id);
497        Ok(job_id)
498    }
499
500    /// Start processing a batch job
501    pub async fn start_job(
502        &self,
503        job_id: Uuid,
504        model: Arc<dyn EmbeddingModel + Send + Sync>,
505    ) -> Result<JoinHandle<Result<BatchProcessingResult>>> {
506        let job = {
507            let mut jobs = self.active_jobs.write().await;
508            let job = jobs
509                .get_mut(&job_id)
510                .ok_or_else(|| anyhow!("Job not found: {}", job_id))?;
511
512            if !matches!(job.status, JobStatus::Pending | JobStatus::Paused) {
513                return Err(anyhow!("Job {} is not in a startable state", job_id));
514            }
515
516            job.status = JobStatus::Running;
517            job.started_at = Some(Utc::now());
518            job.clone()
519        };
520
521        let manager = self.clone();
522        let handle = tokio::spawn(async move { manager.process_job(job, model).await });
523
524        Ok(handle)
525    }
526
527    /// Process a batch job
528    async fn process_job(
529        &self,
530        job: BatchJob,
531        model: Arc<dyn EmbeddingModel + Send + Sync>,
532    ) -> Result<BatchProcessingResult> {
533        let start_time = Instant::now();
534        info!(
535            "Starting batch job processing: {} ({})",
536            job.name, job.job_id
537        );
538
539        // Load entities to process
540        let entities = self.load_entities(&job).await?;
541
542        // Filter entities for incremental processing
543        let entities_to_process = if job
544            .input
545            .incremental
546            .as_ref()
547            .map(|inc| inc.enabled)
548            .unwrap_or(false)
549        {
550            self.filter_incremental_entities(&job, entities).await?
551        } else {
552            entities
553        };
554
555        // Update job progress
556        {
557            let mut jobs = self.active_jobs.write().await;
558            if let Some(active_job) = jobs.get_mut(&job.job_id) {
559                active_job.progress.total_entities = entities_to_process.len();
560                active_job.progress.total_chunks =
561                    (entities_to_process.len() + job.config.chunk_size - 1) / job.config.chunk_size;
562            }
563        }
564
565        // Process entities in chunks
566        let chunks: Vec<_> = entities_to_process
567            .chunks(job.config.chunk_size)
568            .map(|chunk| chunk.to_vec())
569            .collect();
570
571        let mut successful_embeddings = 0;
572        let mut failed_embeddings = 0;
573        let mut cache_hits = 0;
574        let mut cache_misses = 0;
575        let mut processed_entities = HashSet::new();
576        let mut failed_entities = HashMap::new();
577
578        for (chunk_idx, chunk) in chunks.iter().enumerate() {
579            // Check if job was cancelled
580            {
581                let jobs = self.active_jobs.read().await;
582                if let Some(active_job) = jobs.get(&job.job_id) {
583                    if matches!(active_job.status, JobStatus::Cancelled) {
584                        info!("Job {} was cancelled", job.job_id);
585                        return Err(anyhow!("Job was cancelled"));
586                    }
587                }
588            }
589
590            // Process chunk
591            let chunk_result = self
592                .process_chunk(&job, chunk, chunk_idx, model.clone())
593                .await?;
594
595            // Update statistics
596            successful_embeddings += chunk_result.successful;
597            failed_embeddings += chunk_result.failed;
598            cache_hits += chunk_result.cache_hits;
599            cache_misses += chunk_result.cache_misses;
600
601            // Track processed entities
602            for entity in chunk {
603                processed_entities.insert(entity.clone());
604            }
605            for (entity, error) in chunk_result.failures {
606                failed_entities.insert(entity, error);
607            }
608
609            // Update progress
610            self.update_job_progress(
611                &job.job_id,
612                chunk_idx + 1,
613                successful_embeddings + failed_embeddings,
614            )
615            .await?;
616
617            // Create checkpoint
618            if chunk_idx % self.config.checkpoint_frequency == 0 {
619                self.create_checkpoint(&job.job_id, &processed_entities, &failed_entities)
620                    .await?;
621            }
622
623            info!(
624                "Processed chunk {}/{} for job {}",
625                chunk_idx + 1,
626                chunks.len(),
627                job.job_id
628            );
629        }
630
631        // Finalize processing
632        let processing_time = start_time.elapsed().as_secs_f64();
633        let result = self
634            .finalize_job_processing(
635                &job,
636                processing_time,
637                successful_embeddings,
638                failed_embeddings,
639                cache_hits,
640                cache_misses,
641            )
642            .await?;
643
644        // Update job status
645        {
646            let mut jobs = self.active_jobs.write().await;
647            if let Some(active_job) = jobs.get_mut(&job.job_id) {
648                active_job.status = JobStatus::Completed;
649                active_job.completed_at = Some(Utc::now());
650            }
651        }
652
653        info!(
654            "Completed batch job: {} in {:.2}s",
655            job.job_id, processing_time
656        );
657        Ok(result)
658    }
659
660    /// Process a single chunk of entities
661    async fn process_chunk(
662        &self,
663        job: &BatchJob,
664        entities: &[String],
665        chunk_idx: usize,
666        model: Arc<dyn EmbeddingModel + Send + Sync>,
667    ) -> Result<ChunkResult> {
668        let _permit = self.semaphore.acquire().await?;
669
670        let mut successful = 0;
671        let mut failed = 0;
672        let mut cache_hits = 0;
673        let mut cache_misses = 0;
674        let mut failures = HashMap::new();
675
676        for entity in entities {
677            match self
678                .process_single_entity(entity, model.clone(), job.config.use_cache)
679                .await
680            {
681                Ok(from_cache) => {
682                    successful += 1;
683                    if from_cache {
684                        cache_hits += 1;
685                    } else {
686                        cache_misses += 1;
687                    }
688                }
689                Err(e) => {
690                    failed += 1;
691                    failures.insert(entity.clone(), e.to_string());
692                    warn!("Failed to process entity {}: {}", entity, e);
693                }
694            }
695        }
696
697        Ok(ChunkResult {
698            chunk_idx,
699            successful,
700            failed,
701            cache_hits,
702            cache_misses,
703            failures,
704        })
705    }
706
707    /// Process a single entity
708    async fn process_single_entity(
709        &self,
710        entity: &str,
711        model: Arc<dyn EmbeddingModel + Send + Sync>,
712        use_cache: bool,
713    ) -> Result<bool> {
714        if use_cache {
715            // Check cache first
716            if let Some(_embedding) = self.cache_manager.get_embedding(entity) {
717                return Ok(true);
718            }
719        }
720
721        // Generate embedding
722        let embedding = model.get_entity_embedding(entity)?;
723
724        // Cache the result
725        if use_cache {
726            self.cache_manager
727                .put_embedding(entity.to_string(), embedding);
728        }
729
730        Ok(false)
731    }
732
733    /// Load entities to process based on input specification
734    async fn load_entities(&self, job: &BatchJob) -> Result<Vec<String>> {
735        match &job.input.input_type {
736            InputType::EntityList => {
737                // Parse entity list from source
738                let entities: Vec<String> = serde_json::from_str(&job.input.source)?;
739                Ok(entities)
740            }
741            InputType::EntityFile => {
742                // Read entities from file
743                let content = fs::read_to_string(&job.input.source).await?;
744                let entities: Vec<String> = content
745                    .lines()
746                    .map(|line| line.trim().to_string())
747                    .filter(|line| !line.is_empty())
748                    .collect();
749                Ok(entities)
750            }
751            InputType::SparqlQuery => {
752                // Execute SPARQL query and extract entities
753                // This would need to be implemented based on SPARQL engine
754                warn!("SPARQL query input type not yet implemented");
755                Ok(Vec::new())
756            }
757            InputType::DatabaseQuery => {
758                // Execute database query and extract entities
759                warn!("Database query input type not yet implemented");
760                Ok(Vec::new())
761            }
762            InputType::StreamSource => {
763                // Read from stream source
764                warn!("Stream source input type not yet implemented");
765                Ok(Vec::new())
766            }
767        }
768    }
769
770    /// Filter entities for incremental processing
771    async fn filter_incremental_entities(
772        &self,
773        job: &BatchJob,
774        entities: Vec<String>,
775    ) -> Result<Vec<String>> {
776        if let Some(incremental) = &job.input.incremental {
777            if !incremental.enabled {
778                return Ok(entities);
779            }
780
781            // Load existing embeddings if specified
782            let existing_entities =
783                if let Some(existing_path) = &incremental.existing_embeddings_path {
784                    self.load_existing_entities(existing_path).await?
785                } else {
786                    HashSet::new()
787                };
788
789            // Filter out entities that already have embeddings
790            let filtered: Vec<String> = entities
791                .into_iter()
792                .filter(|entity| !existing_entities.contains(entity))
793                .collect();
794
795            info!(
796                "Incremental filtering: {} entities remaining after filtering",
797                filtered.len()
798            );
799            Ok(filtered)
800        } else {
801            Ok(entities)
802        }
803    }
804
805    /// Load existing entities from embeddings file
806    async fn load_existing_entities(&self, path: &str) -> Result<HashSet<String>> {
807        // This would depend on the output format
808        // For now, assume a simple text file with entity IDs
809        if Path::new(path).exists() {
810            let content = fs::read_to_string(path).await?;
811            let entities: HashSet<String> = content
812                .lines()
813                .map(|line| line.trim().to_string())
814                .filter(|line| !line.is_empty())
815                .collect();
816            Ok(entities)
817        } else {
818            Ok(HashSet::new())
819        }
820    }
821
822    /// Update job progress
823    async fn update_job_progress(
824        &self,
825        job_id: &Uuid,
826        current_chunk: usize,
827        processed_entities: usize,
828    ) -> Result<()> {
829        let mut jobs = self.active_jobs.write().await;
830        if let Some(job) = jobs.get_mut(job_id) {
831            job.progress.current_chunk = current_chunk;
832            job.progress.processed_entities = processed_entities;
833
834            // Calculate processing rate
835            if let Some(started_at) = job.started_at {
836                let elapsed = Utc::now().signed_duration_since(started_at);
837                let elapsed_seconds = elapsed.num_seconds() as f64;
838                if elapsed_seconds > 0.0 {
839                    job.progress.processing_rate = processed_entities as f64 / elapsed_seconds;
840
841                    // Estimate time remaining
842                    let remaining_entities = job.progress.total_entities - processed_entities;
843                    if job.progress.processing_rate > 0.0 {
844                        let eta = remaining_entities as f64 / job.progress.processing_rate;
845                        job.progress.eta_seconds = Some(eta as u64);
846                    }
847                }
848            }
849        }
850        Ok(())
851    }
852
853    /// Create a checkpoint for job resumability
854    async fn create_checkpoint(
855        &self,
856        job_id: &Uuid,
857        processed_entities: &HashSet<String>,
858        failed_entities: &HashMap<String, String>,
859    ) -> Result<()> {
860        let checkpoint = JobCheckpoint {
861            timestamp: Utc::now(),
862            last_processed_index: processed_entities.len(),
863            processed_entities: processed_entities.clone(),
864            failed_entities: failed_entities.clone(),
865            intermediate_results_path: format!(
866                "{}/checkpoint_{}.json",
867                self.persistence_dir.display(),
868                job_id
869            ),
870            model_state_hash: "placeholder".to_string(), // Would calculate actual hash
871        };
872
873        // Save checkpoint to disk
874        let checkpoint_path = self
875            .persistence_dir
876            .join(format!("checkpoint_{job_id}.json"));
877        let checkpoint_json = serde_json::to_string_pretty(&checkpoint)?;
878        fs::write(checkpoint_path, checkpoint_json).await?;
879
880        // Update job with checkpoint
881        let mut jobs = self.active_jobs.write().await;
882        if let Some(job) = jobs.get_mut(job_id) {
883            job.checkpoint = Some(checkpoint);
884        }
885
886        debug!("Created checkpoint for job {}", job_id);
887        Ok(())
888    }
889
890    /// Finalize job processing and create result
891    async fn finalize_job_processing(
892        &self,
893        job: &BatchJob,
894        processing_time: f64,
895        successful_embeddings: usize,
896        failed_embeddings: usize,
897        cache_hits: usize,
898        cache_misses: usize,
899    ) -> Result<BatchProcessingResult> {
900        let total_entities = successful_embeddings + failed_embeddings;
901        let avg_time_per_entity_ms = if total_entities > 0 {
902            (processing_time * 1000.0) / total_entities as f64
903        } else {
904            0.0
905        };
906
907        let stats = BatchProcessingStats {
908            total_time_seconds: processing_time,
909            total_entities,
910            successful_embeddings,
911            failed_embeddings,
912            cache_hits,
913            cache_misses,
914            avg_time_per_entity_ms,
915            peak_memory_mb: 0.0,  // Would measure actual memory usage
916            cpu_utilization: 0.0, // Would measure actual CPU usage
917        };
918
919        let output_info = OutputInfo {
920            output_files: vec![job.output.path.clone()],
921            total_size_bytes: 0, // Would calculate actual size
922            compression_ratio: 1.0,
923            num_partitions: 1,
924        };
925
926        Ok(BatchProcessingResult {
927            job_id: job.job_id,
928            stats,
929            output_info,
930            quality_metrics: None, // Would calculate if requested
931        })
932    }
933
934    /// Validate a batch job before submission
935    async fn validate_job(&self, job: &BatchJob) -> Result<()> {
936        // Validate input source exists
937        if let InputType::EntityFile = &job.input.input_type {
938            if !Path::new(&job.input.source).exists() {
939                return Err(anyhow!("Input file does not exist: {}", job.input.source));
940            }
941        } // Other validations would be implemented
942
943        // Validate output path is writable
944        if let Some(parent) = Path::new(&job.output.path).parent() {
945            if !parent.exists() {
946                fs::create_dir_all(parent).await?;
947            }
948        }
949
950        Ok(())
951    }
952
953    /// Persist job configuration to disk
954    async fn persist_job(&self, job: &BatchJob) -> Result<()> {
955        let job_path = self
956            .persistence_dir
957            .join(format!("job_{}.json", job.job_id));
958        let job_json = serde_json::to_string_pretty(job)?;
959        fs::write(job_path, job_json).await?;
960        Ok(())
961    }
962
963    /// Get job status
964    pub async fn get_job_status(&self, job_id: &Uuid) -> Option<JobStatus> {
965        let jobs = self.active_jobs.read().await;
966        jobs.get(job_id).map(|job| job.status.clone())
967    }
968
969    /// Get job progress
970    pub async fn get_job_progress(&self, job_id: &Uuid) -> Option<JobProgress> {
971        let jobs = self.active_jobs.read().await;
972        jobs.get(job_id).map(|job| job.progress.clone())
973    }
974
975    /// Cancel a job
976    pub async fn cancel_job(&self, job_id: &Uuid) -> Result<()> {
977        let mut jobs = self.active_jobs.write().await;
978        if let Some(job) = jobs.get_mut(job_id) {
979            job.status = JobStatus::Cancelled;
980            info!("Cancelled job: {}", job_id);
981            Ok(())
982        } else {
983            Err(anyhow!("Job not found: {}", job_id))
984        }
985    }
986
987    /// List all jobs
988    pub async fn list_jobs(&self) -> Vec<BatchJob> {
989        let jobs = self.active_jobs.read().await;
990        jobs.values().cloned().collect()
991    }
992}
993
994impl Clone for BatchProcessingManager {
995    fn clone(&self) -> Self {
996        Self {
997            active_jobs: Arc::clone(&self.active_jobs),
998            config: self.config.clone(),
999            cache_manager: Arc::clone(&self.cache_manager),
1000            semaphore: Arc::clone(&self.semaphore),
1001            persistence_dir: self.persistence_dir.clone(),
1002        }
1003    }
1004}
1005
/// Result of processing a single chunk of a batch job.
#[derive(Debug)]
// NOTE(review): silences dead-code warnings, presumably because some fields are filled in
// but never read afterwards; confirm and drop this once every field is consumed.
#[allow(dead_code)]
struct ChunkResult {
    /// Index of the chunk this result belongs to.
    chunk_idx: usize,
    /// Number of items in the chunk that were processed successfully.
    successful: usize,
    /// Number of items in the chunk that failed.
    failed: usize,
    /// Number of cache hits while processing the chunk.
    cache_hits: usize,
    /// Number of cache misses while processing the chunk.
    cache_misses: usize,
    /// Failure details; presumably keyed by entity identifier with the error message as
    /// the value (TODO confirm against the code that builds this map).
    failures: HashMap<String, String>,
}
1017
#[cfg(test)]
mod tests {
    // Unit tests for job/config construction, the memory-optimized batch iterator and
    // the rayon-backed `ParallelBatchProcessor`.
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_batch_job_creation() {
        let job = BatchJob {
            job_id: Uuid::new_v4(),
            name: "test_job".to_string(),
            status: JobStatus::Pending,
            input: BatchInput {
                input_type: InputType::EntityList,
                source: r#"["entity1", "entity2", "entity3"]"#.to_string(),
                filters: None,
                incremental: None,
            },
            output: BatchOutput {
                path: "/tmp/output".to_string(),
                format: "parquet".to_string(),
                compression: Some("gzip".to_string()),
                partitioning: Some(PartitioningStrategy::None),
            },
            config: BatchJobConfig {
                chunk_size: 100,
                num_workers: 4,
                max_retries: 3,
                use_cache: true,
                custom_params: HashMap::new(),
            },
            model_id: Uuid::new_v4(),
            created_at: Utc::now(),
            started_at: None,
            completed_at: None,
            progress: JobProgress::default(),
            error: None,
            checkpoint: None,
        };

        assert_eq!(job.status, JobStatus::Pending);
        assert_eq!(job.name, "test_job");
    }

    #[tokio::test]
    async fn test_batch_processing_manager_creation() {
        let config = BatchProcessingConfig::default();
        let cache_config = crate::CacheConfig::default();
        let cache_manager = Arc::new(CacheManager::new(cache_config));
        let temp_dir = tempdir().unwrap();

        let manager =
            BatchProcessingManager::new(config, cache_manager, temp_dir.path().to_path_buf());

        assert_eq!(manager.config.max_workers, num_cpus::get());
        assert_eq!(manager.config.chunk_size, 1000);
    }

    #[test]
    fn test_incremental_config() {
        let incremental = IncrementalConfig {
            enabled: true,
            last_processed: Some(Utc::now()),
            timestamp_field: "updated_at".to_string(),
            check_deletions: true,
            existing_embeddings_path: Some("/path/to/existing".to_string()),
        };

        assert!(incremental.enabled);
        assert!(incremental.last_processed.is_some());
        assert_eq!(incremental.timestamp_field, "updated_at");
    }

    #[test]
    fn test_memory_optimized_batch_iterator() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut iterator = MemoryOptimizedBatchIterator::new(data.clone(), 3, 1); // 1MB limit

        // Test first batch
        let batch1 = iterator.next_batch().unwrap();
        assert_eq!(batch1.len(), 3);
        assert_eq!(batch1, vec![1, 2, 3]);
        assert_eq!(iterator.get_progress(), 0.3);
        assert!(!iterator.is_finished());

        // Test second batch
        let batch2 = iterator.next_batch().unwrap();
        assert_eq!(batch2.len(), 3);
        assert_eq!(batch2, vec![4, 5, 6]);
        assert_eq!(iterator.get_progress(), 0.6);

        // Test third batch
        let batch3 = iterator.next_batch().unwrap();
        assert_eq!(batch3.len(), 3);
        assert_eq!(batch3, vec![7, 8, 9]);
        assert_eq!(iterator.get_progress(), 0.9);

        // Test final batch
        let batch4 = iterator.next_batch().unwrap();
        assert_eq!(batch4.len(), 1);
        assert_eq!(batch4, vec![10]);
        assert_eq!(iterator.get_progress(), 1.0);
        assert!(iterator.is_finished());

        // Test empty batch
        let batch5 = iterator.next_batch();
        assert!(batch5.is_none());
    }

    #[test]
    fn test_memory_optimized_batch_iterator_empty() {
        let data: Vec<i32> = vec![];
        let mut iterator = MemoryOptimizedBatchIterator::new(data, 3, 1);

        assert_eq!(iterator.get_progress(), 1.0);
        assert!(iterator.is_finished());
        assert!(iterator.next_batch().is_none());
    }

    #[test]
    fn test_memory_optimized_batch_iterator_single_item() {
        let data = vec![42];
        let mut iterator = MemoryOptimizedBatchIterator::new(data, 5, 1);

        let batch = iterator.next_batch().unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0], 42);
        assert_eq!(iterator.get_progress(), 1.0);
        assert!(iterator.is_finished());
    }

    #[test]
    fn test_memory_optimized_batch_iterator_memory_tracking() {
        let data = vec![1, 2, 3, 4, 5];
        let mut iterator = MemoryOptimizedBatchIterator::new(data, 3, 1);

        // Process one batch and check memory usage
        let _batch = iterator.next_batch().unwrap();
        let memory_usage = iterator.get_memory_usage();
        assert!(memory_usage > 0);

        // Memory usage should be roughly 3 * size_of::<i32>()
        let expected_memory = 3 * std::mem::size_of::<i32>();
        assert_eq!(memory_usage, expected_memory);
    }

    #[test]
    fn test_parallel_batch_processor() {
        // Test basic functionality
        let processor = ParallelBatchProcessor::new(ParallelBatchConfig::default()).unwrap();
        // Should use system's num_cpus
        assert!(processor.num_workers() > 0);
        assert!(processor.num_workers() <= num_cpus::get());
    }

    #[test]
    fn test_parallel_batch_processor_preserves_order() {
        let processor = ParallelBatchProcessor::new(ParallelBatchConfig::default()).unwrap();
        let items: Vec<u32> = (0..2_500).collect();
        let expected: Vec<u32> = items.iter().map(|x| x * 2).collect();

        assert_eq!(
            processor.process_parallel(items.clone(), |x| x * 2).unwrap(),
            expected
        );
        assert_eq!(
            processor
                .process_with_load_balancing(items.clone(), |x| x * 2)
                .unwrap(),
            expected
        );
        // With the default chunk size of 1000, 2500 items span three chunks, which
        // exercises the chunk boundaries (including a short final chunk).
        assert_eq!(
            processor.process_memory_efficient(items, |x| x * 2).unwrap(),
            expected
        );
    }

    #[test]
    fn test_parallel_batch_processor_nested_and_empty() {
        let processor = ParallelBatchProcessor::new(ParallelBatchConfig::default()).unwrap();

        let nested: Vec<Vec<i32>> = vec![vec![1, 2], vec![], vec![3]];
        let expected: Vec<Vec<i32>> = vec![vec![2, 3], vec![], vec![4]];
        assert_eq!(
            processor.process_nested_parallel(nested, |x| x + 1).unwrap(),
            expected
        );

        let empty: Vec<i32> = Vec::new();
        assert!(processor
            .process_parallel(empty.clone(), |x| *x)
            .unwrap()
            .is_empty());
        assert!(processor
            .process_memory_efficient(empty, |x| *x)
            .unwrap()
            .is_empty());
    }
}
1172
1173/// Advanced parallel batch processor using SciRS2 and Rayon
1174///
1175/// This processor leverages parallel operations for:
1176/// - Optimal work distribution across cores
1177/// - Adaptive load balancing
1178/// - Memory-efficient chunking
1179/// - NUMA-aware processing
1180pub struct ParallelBatchProcessor {
1181    config: ParallelBatchConfig,
1182}
1183
/// Configuration for [`ParallelBatchProcessor`].
///
/// Default values come from this type's `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelBatchConfig {
    /// Number of worker threads requested for rayon's global thread pool when a
    /// [`ParallelBatchProcessor`] is created. It only takes effect if that pool has not
    /// been initialized yet.
    pub num_workers: usize,
    /// Upper bound on the number of items per chunk in
    /// `ParallelBatchProcessor::process_memory_efficient`, which also enforces a floor
    /// of 100 items.
    pub chunk_size: usize,
    /// Enable adaptive load balancing.
    ///
    /// NOTE(review): not read by `ParallelBatchProcessor`'s methods; confirm whether
    /// anything consumes this flag.
    pub adaptive_balancing: bool,
    /// Memory budget in megabytes, used to derive the chunk length in
    /// `ParallelBatchProcessor::process_memory_efficient`.
    pub memory_threshold_mb: usize,
    /// Enable NUMA optimization.
    ///
    /// NOTE(review): not read by `ParallelBatchProcessor`'s methods; confirm whether
    /// anything consumes this flag.
    pub numa_aware: bool,
    /// Enable work stealing.
    ///
    /// NOTE(review): not read by `ParallelBatchProcessor`'s methods; rayon's scheduler
    /// steals work regardless of this setting.
    pub work_stealing: bool,
}
1200
1201impl Default for ParallelBatchConfig {
1202    fn default() -> Self {
1203        Self {
1204            num_workers: num_cpus::get(),
1205            chunk_size: 1000,
1206            adaptive_balancing: true,
1207            memory_threshold_mb: 512,
1208            numa_aware: true,
1209            work_stealing: true,
1210        }
1211    }
1212}
1213
1214impl ParallelBatchProcessor {
1215    /// Create new parallel batch processor
1216    pub fn new(config: ParallelBatchConfig) -> Result<Self> {
1217        // Configure rayon thread pool
1218        rayon::ThreadPoolBuilder::new()
1219            .num_threads(config.num_workers)
1220            .build_global()
1221            .ok(); // Ignore error if already initialized
1222
1223        Ok(Self { config })
1224    }
1225
1226    /// Get number of worker threads
1227    pub fn num_workers(&self) -> usize {
1228        self.config.num_workers
1229    }
1230
1231    /// Process batch in parallel with automatic load balancing
1232    pub fn process_parallel<T, F, R>(&self, items: Vec<T>, process_fn: F) -> Result<Vec<R>>
1233    where
1234        T: Send + Sync,
1235        F: Fn(&T) -> R + Send + Sync,
1236        R: Send,
1237    {
1238        // Use rayon for parallel processing
1239        let results: Vec<R> = items.par_iter().map(process_fn).collect();
1240
1241        Ok(results)
1242    }
1243
1244    /// Process batch with dynamic load balancing (using rayon's work stealing)
1245    ///
1246    /// Uses Rayon's work stealing for:
1247    /// - Automatic work stealing between threads
1248    /// - Dynamic work distribution
1249    /// - Optimal CPU utilization
1250    pub fn process_with_load_balancing<T, F, R>(
1251        &self,
1252        items: Vec<T>,
1253        process_fn: F,
1254    ) -> Result<Vec<R>>
1255    where
1256        T: Send + Sync,
1257        F: Fn(&T) -> R + Send + Sync,
1258        R: Send,
1259    {
1260        // Rayon automatically uses work stealing
1261        let results: Vec<R> = items.par_iter().map(process_fn).collect();
1262
1263        Ok(results)
1264    }
1265
1266    /// Process very large batches with memory-efficient chunking
1267    ///
1268    /// Uses chunked parallel processing to:
1269    /// - Respect memory limits
1270    /// - Optimize cache locality
1271    /// - Minimize memory allocations
1272    pub fn process_memory_efficient<T, F, R>(&self, items: Vec<T>, process_fn: F) -> Result<Vec<R>>
1273    where
1274        T: Send + Sync,
1275        F: Fn(&T) -> R + Send + Sync,
1276        R: Send,
1277    {
1278        // Process in chunks to respect memory limits
1279        let chunk_size =
1280            (self.config.memory_threshold_mb * 1024 * 1024) / (std::mem::size_of::<T>().max(1));
1281
1282        let chunk_size = chunk_size.min(self.config.chunk_size).max(100);
1283
1284        let results: Vec<R> = items
1285            .par_chunks(chunk_size)
1286            .flat_map(|chunk| chunk.iter().map(&process_fn).collect::<Vec<_>>())
1287            .collect();
1288
1289        Ok(results)
1290    }
1291
1292    /// Process with nested parallelism using rayon
1293    ///
1294    /// Enables safe nested parallel processing:
1295    /// - Automatic thread lifetime management
1296    /// - Rayon's work stealing
1297    /// - Cache-friendly execution
1298    pub fn process_nested_parallel<T, F, R>(
1299        &self,
1300        items: Vec<Vec<T>>,
1301        process_fn: F,
1302    ) -> Result<Vec<Vec<R>>>
1303    where
1304        T: Send + Sync,
1305        F: Fn(&T) -> R + Send + Sync,
1306        R: Send,
1307    {
1308        let results: Vec<Vec<R>> = items
1309            .par_iter()
1310            .map(|batch| batch.iter().map(&process_fn).collect())
1311            .collect();
1312
1313        Ok(results)
1314    }
1315
1316    /// Get processing statistics
1317    pub fn get_stats(&self) -> ParallelProcessingStats {
1318        ParallelProcessingStats {
1319            num_workers: self.config.num_workers,
1320            profiler_report: "Stats available".to_string(),
1321            memory_usage: 0,
1322        }
1323    }
1324}
1325
/// Statistics for parallel batch processing, as returned by
/// `ParallelBatchProcessor::get_stats`.
#[derive(Debug, Clone)]
pub struct ParallelProcessingStats {
    /// Configured number of worker threads.
    pub num_workers: usize,
    /// Human-readable profiler report.
    ///
    /// NOTE(review): `get_stats` currently fills this with a fixed placeholder string.
    pub profiler_report: String,
    /// Memory usage. Units are not specified (presumably bytes; TODO confirm).
    ///
    /// NOTE(review): `get_stats` currently always reports 0.
    pub memory_usage: usize,
}