oxirs_embed/
batch_processing.rs

1//! Offline batch embedding generation with incremental updates
2//!
3//! This module provides comprehensive batch processing capabilities for generating
4//! embeddings offline, with support for incremental updates, resumable jobs,
5//! and efficient resource utilization.
6
7use crate::{CacheManager, EmbeddingModel};
8use anyhow::{anyhow, Result};
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12use std::path::{Path, PathBuf};
13use std::sync::Arc;
14use std::time::Instant;
15use tokio::fs;
16use tokio::sync::{RwLock, Semaphore};
17use tokio::task::JoinHandle;
18use tracing::{debug, info, warn};
19use uuid::Uuid;
20
/// Memory-aware batch iterator over an in-memory dataset.
///
/// Items are handed out in batches of at most `batch_size` elements; a batch is also cut
/// short once its estimated footprint would exceed the memory budget. The footprint is
/// estimated as `size_of::<T>()` per item, so heap memory owned by an item (for example
/// the bytes behind a `String`) is not counted.
///
/// Every call to [`next_batch`](Self::next_batch) yields at least one item while data
/// remains, so iteration always terminates, even when a single item is larger than the
/// budget or `batch_size` is `0`.
pub struct MemoryOptimizedBatchIterator<T> {
    /// Data source
    data: Vec<T>,
    /// Index of the next item to hand out
    position: usize,
    /// Maximum number of items per batch
    batch_size: usize,
    /// Estimated footprint (bytes) of the most recently returned batch
    memory_usage: usize,
    /// Per-batch memory budget (bytes)
    max_memory_bytes: usize,
}

impl<T> MemoryOptimizedBatchIterator<T> {
    /// Create a new memory-optimized batch iterator.
    ///
    /// `max_memory_mb` is the per-batch budget in MiB. The conversion to bytes saturates
    /// at `usize::MAX` instead of overflowing.
    pub fn new(data: Vec<T>, batch_size: usize, max_memory_mb: usize) -> Self {
        Self {
            data,
            position: 0,
            batch_size,
            memory_usage: 0,
            max_memory_bytes: max_memory_mb.saturating_mul(1024 * 1024),
        }
    }

    /// Return the next batch, or `None` once all items have been handed out.
    ///
    /// The batch length is the smallest of these three values, but never less than one:
    /// - the remaining item count,
    /// - `batch_size`,
    /// - the number of items that fit in the memory budget.
    pub fn next_batch(&mut self) -> Option<Vec<T>>
    where
        T: Clone,
    {
        if self.position >= self.data.len() {
            return None;
        }

        let item_size = std::mem::size_of::<T>();
        // Zero-sized types never consume budget.
        let fits_in_budget = if item_size == 0 {
            usize::MAX
        } else {
            self.max_memory_bytes / item_size
        };

        let remaining = &self.data[self.position..];
        // Forcing at least one item guarantees progress. Returning an empty batch or
        // `None` here would leave the iterator stuck with data still pending.
        let take = remaining
            .len()
            .min(self.batch_size)
            .min(fits_in_budget)
            .max(1);
        let batch = remaining[..take].to_vec();

        self.position += take;
        self.memory_usage = take * item_size;
        Some(batch)
    }

    /// Estimated footprint (bytes) of the most recently returned batch.
    pub fn get_memory_usage(&self) -> usize {
        self.memory_usage
    }

    /// Fraction of items handed out so far, in `0.0..=1.0`. An empty dataset reports `1.0`.
    pub fn get_progress(&self) -> f64 {
        if self.data.is_empty() {
            1.0
        } else {
            self.position as f64 / self.data.len() as f64
        }
    }

    /// Whether every item has been handed out.
    pub fn is_finished(&self) -> bool {
        self.position >= self.data.len()
    }
}
98
99/// Batch processing manager for offline embedding generation
100pub struct BatchProcessingManager {
101    /// Active batch jobs
102    active_jobs: Arc<RwLock<HashMap<Uuid, BatchJob>>>,
103    /// Configuration
104    config: BatchProcessingConfig,
105    /// Cache manager for optimization
106    cache_manager: Arc<CacheManager>,
107    /// Concurrency semaphore
108    semaphore: Arc<Semaphore>,
109    /// Job persistence directory
110    persistence_dir: PathBuf,
111}
112
113/// Configuration for batch processing
114#[derive(Debug, Clone)]
115pub struct BatchProcessingConfig {
116    /// Maximum concurrent workers
117    pub max_workers: usize,
118    /// Chunk size for processing
119    pub chunk_size: usize,
120    /// Enable incremental updates
121    pub enable_incremental: bool,
122    /// Checkpoint frequency (number of chunks)
123    pub checkpoint_frequency: usize,
124    /// Enable resume from checkpoint
125    pub enable_resume: bool,
126    /// Maximum memory usage per worker (MB)
127    pub max_memory_per_worker_mb: usize,
128    /// Enable progress notifications
129    pub enable_notifications: bool,
130    /// Retry configuration
131    pub retry_config: RetryConfig,
132    /// Output format configuration
133    pub output_config: OutputConfig,
134}
135
136impl Default for BatchProcessingConfig {
137    fn default() -> Self {
138        Self {
139            max_workers: num_cpus::get(),
140            chunk_size: 1000,
141            enable_incremental: true,
142            checkpoint_frequency: 10,
143            enable_resume: true,
144            max_memory_per_worker_mb: 512,
145            enable_notifications: true,
146            retry_config: RetryConfig::default(),
147            output_config: OutputConfig::default(),
148        }
149    }
150}
151
/// Retry policy with multiplicative backoff between attempts.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Upper bound on retry attempts
    pub max_retries: usize,
    /// Delay before the first retry, in milliseconds
    pub initial_backoff_ms: u64,
    /// Ceiling for the backoff delay, in milliseconds
    pub max_backoff_ms: u64,
    /// Factor applied to the delay after each attempt
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Defaults:
    /// - 3 retries,
    /// - 1 s initial delay,
    /// - ×2 multiplier,
    /// - 30 s cap.
    fn default() -> Self {
        const ONE_SECOND_MS: u64 = 1_000;
        Self {
            max_retries: 3,
            initial_backoff_ms: ONE_SECOND_MS,
            max_backoff_ms: 30 * ONE_SECOND_MS,
            backoff_multiplier: 2.0,
        }
    }
}
175
/// Settings controlling how generated embeddings are written out.
#[derive(Debug, Clone)]
pub struct OutputConfig {
    /// Serialization format
    pub format: OutputFormat,
    /// Compression level, expected in 0–9
    pub compression_level: u32,
    /// Whether metadata accompanies each embedding
    pub include_metadata: bool,
    /// Whether output is split across multiple files
    pub batch_output: bool,
    /// Upper bound on entities written to a single file
    pub max_entities_per_file: usize,
}

impl Default for OutputConfig {
    /// Defaults:
    /// - Parquet format,
    /// - compression level 6,
    /// - metadata included,
    /// - batched output of up to 100 000 entities per file.
    fn default() -> Self {
        let format = OutputFormat::Parquet;
        Self {
            format,
            compression_level: 6,
            include_metadata: true,
            batch_output: true,
            max_entities_per_file: 100_000,
        }
    }
}

/// Supported on-disk formats for embedding output.
#[derive(Debug, Clone)]
pub enum OutputFormat {
    /// Apache Parquet columnar format
    Parquet,
    /// JSON Lines (one record per line), compressed
    JsonLines,
    /// Project-specific binary layout
    Binary,
    /// HDF5 container
    HDF5,
}
215
216/// Batch job definition
217#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct BatchJob {
219    /// Unique job ID
220    pub job_id: Uuid,
221    /// Job name
222    pub name: String,
223    /// Job status
224    pub status: JobStatus,
225    /// Input specification
226    pub input: BatchInput,
227    /// Output specification
228    pub output: BatchOutput,
229    /// Processing configuration
230    pub config: BatchJobConfig,
231    /// Model information
232    pub model_id: Uuid,
233    /// Created timestamp
234    pub created_at: DateTime<Utc>,
235    /// Started timestamp
236    pub started_at: Option<DateTime<Utc>>,
237    /// Completed timestamp
238    pub completed_at: Option<DateTime<Utc>>,
239    /// Progress information
240    pub progress: JobProgress,
241    /// Error information
242    pub error: Option<String>,
243    /// Checkpoint data
244    pub checkpoint: Option<JobCheckpoint>,
245}
246
247/// Job status
248#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
249pub enum JobStatus {
250    Pending,
251    Running,
252    Completed,
253    Failed,
254    Cancelled,
255    Paused,
256}
257
258/// Batch input specification
259#[derive(Debug, Clone, Serialize, Deserialize)]
260pub struct BatchInput {
261    /// Input type
262    pub input_type: InputType,
263    /// Input source
264    pub source: String,
265    /// Filter criteria
266    pub filters: Option<HashMap<String, String>>,
267    /// Incremental mode settings
268    pub incremental: Option<IncrementalConfig>,
269}
270
271/// Input types
272#[derive(Debug, Clone, Serialize, Deserialize)]
273pub enum InputType {
274    /// List of entity IDs
275    EntityList,
276    /// File containing entity IDs
277    EntityFile,
278    /// SPARQL query result
279    SparqlQuery,
280    /// Database query
281    DatabaseQuery,
282    /// Stream source
283    StreamSource,
284}
285
286/// Incremental processing configuration
287#[derive(Debug, Clone, Serialize, Deserialize)]
288pub struct IncrementalConfig {
289    /// Enable incremental processing
290    pub enabled: bool,
291    /// Last processed timestamp
292    pub last_processed: Option<DateTime<Utc>>,
293    /// Timestamp field name
294    pub timestamp_field: String,
295    /// Check for deletions
296    pub check_deletions: bool,
297    /// Existing embeddings source
298    pub existing_embeddings_path: Option<String>,
299}
300
301/// Batch output specification
302#[derive(Debug, Clone, Serialize, Deserialize)]
303pub struct BatchOutput {
304    /// Output path
305    pub path: String,
306    /// Output format
307    pub format: String,
308    /// Compression settings
309    pub compression: Option<String>,
310    /// Partitioning strategy
311    pub partitioning: Option<PartitioningStrategy>,
312}
313
314/// Partitioning strategy
315#[derive(Debug, Clone, Serialize, Deserialize)]
316pub enum PartitioningStrategy {
317    /// No partitioning
318    None,
319    /// Partition by entity type
320    ByEntityType,
321    /// Partition by date
322    ByDate,
323    /// Partition by hash
324    ByHash { num_partitions: usize },
325    /// Custom partitioning
326    Custom { field: String },
327}
328
329/// Job-specific configuration
330#[derive(Debug, Clone, Serialize, Deserialize)]
331pub struct BatchJobConfig {
332    /// Chunk size for this job
333    pub chunk_size: usize,
334    /// Number of workers
335    pub num_workers: usize,
336    /// Retry configuration
337    pub max_retries: usize,
338    /// Enable caching
339    pub use_cache: bool,
340    /// Custom parameters
341    pub custom_params: HashMap<String, String>,
342}
343
344/// Job progress information
345#[derive(Debug, Clone, Serialize, Deserialize)]
346pub struct JobProgress {
347    /// Total entities to process
348    pub total_entities: usize,
349    /// Entities processed
350    pub processed_entities: usize,
351    /// Entities failed
352    pub failed_entities: usize,
353    /// Current chunk being processed
354    pub current_chunk: usize,
355    /// Total chunks
356    pub total_chunks: usize,
357    /// Processing rate (entities/second)
358    pub processing_rate: f64,
359    /// Estimated time remaining
360    pub eta_seconds: Option<u64>,
361    /// Memory usage (MB)
362    pub memory_usage_mb: f64,
363}
364
365impl Default for JobProgress {
366    fn default() -> Self {
367        Self {
368            total_entities: 0,
369            processed_entities: 0,
370            failed_entities: 0,
371            current_chunk: 0,
372            total_chunks: 0,
373            processing_rate: 0.0,
374            eta_seconds: None,
375            memory_usage_mb: 0.0,
376        }
377    }
378}
379
380/// Job checkpoint for resumability
381#[derive(Debug, Clone, Serialize, Deserialize)]
382pub struct JobCheckpoint {
383    /// Checkpoint timestamp
384    pub timestamp: DateTime<Utc>,
385    /// Last processed entity index
386    pub last_processed_index: usize,
387    /// Processed entity IDs
388    pub processed_entities: HashSet<String>,
389    /// Failed entity IDs with error messages
390    pub failed_entities: HashMap<String, String>,
391    /// Intermediate results path
392    pub intermediate_results_path: String,
393    /// Model state hash
394    pub model_state_hash: String,
395}
396
397/// Batch processing result
398#[derive(Debug, Clone, Serialize, Deserialize)]
399pub struct BatchProcessingResult {
400    /// Job ID
401    pub job_id: Uuid,
402    /// Processing statistics
403    pub stats: BatchProcessingStats,
404    /// Output information
405    pub output_info: OutputInfo,
406    /// Quality metrics
407    pub quality_metrics: Option<QualityMetrics>,
408}
409
410/// Batch processing statistics
411#[derive(Debug, Clone, Serialize, Deserialize)]
412pub struct BatchProcessingStats {
413    /// Total processing time
414    pub total_time_seconds: f64,
415    /// Total entities processed
416    pub total_entities: usize,
417    /// Successful embeddings
418    pub successful_embeddings: usize,
419    /// Failed embeddings
420    pub failed_embeddings: usize,
421    /// Cache hits
422    pub cache_hits: usize,
423    /// Cache misses
424    pub cache_misses: usize,
425    /// Average processing time per entity (ms)
426    pub avg_time_per_entity_ms: f64,
427    /// Peak memory usage (MB)
428    pub peak_memory_mb: f64,
429    /// CPU utilization
430    pub cpu_utilization: f64,
431}
432
433/// Output information
434#[derive(Debug, Clone, Serialize, Deserialize)]
435pub struct OutputInfo {
436    /// Output files created
437    pub output_files: Vec<String>,
438    /// Total output size (bytes)
439    pub total_size_bytes: u64,
440    /// Compression ratio
441    pub compression_ratio: f64,
442    /// Number of partitions
443    pub num_partitions: usize,
444}
445
446/// Quality metrics for batch processing
447#[derive(Debug, Clone, Serialize, Deserialize)]
448pub struct QualityMetrics {
449    /// Average embedding norm
450    pub avg_embedding_norm: f64,
451    /// Embedding norm standard deviation
452    pub embedding_norm_std: f64,
453    /// Average cosine similarity to centroid
454    pub avg_cosine_similarity: f64,
455    /// Embedding dimension
456    pub embedding_dimension: usize,
457    /// Number of zero embeddings
458    pub zero_embeddings: usize,
459    /// Number of NaN embeddings
460    pub nan_embeddings: usize,
461}
462
463impl BatchProcessingManager {
464    /// Create a new batch processing manager
465    pub fn new(
466        config: BatchProcessingConfig,
467        cache_manager: Arc<CacheManager>,
468        persistence_dir: PathBuf,
469    ) -> Self {
470        Self {
471            active_jobs: Arc::new(RwLock::new(HashMap::new())),
472            semaphore: Arc::new(Semaphore::new(config.max_workers)),
473            config,
474            cache_manager,
475            persistence_dir,
476        }
477    }
478
479    /// Submit a new batch job
480    pub async fn submit_job(&self, job: BatchJob) -> Result<Uuid> {
481        let job_id = job.job_id;
482
483        // Validate job
484        self.validate_job(&job).await?;
485
486        // Store job
487        {
488            let mut jobs = self.active_jobs.write().await;
489            jobs.insert(job_id, job.clone());
490        }
491
492        // Persist job configuration
493        self.persist_job(&job).await?;
494
495        info!("Submitted batch job: {} ({})", job.name, job_id);
496        Ok(job_id)
497    }
498
499    /// Start processing a batch job
500    pub async fn start_job(
501        &self,
502        job_id: Uuid,
503        model: Arc<dyn EmbeddingModel + Send + Sync>,
504    ) -> Result<JoinHandle<Result<BatchProcessingResult>>> {
505        let job = {
506            let mut jobs = self.active_jobs.write().await;
507            let job = jobs
508                .get_mut(&job_id)
509                .ok_or_else(|| anyhow!("Job not found: {}", job_id))?;
510
511            if !matches!(job.status, JobStatus::Pending | JobStatus::Paused) {
512                return Err(anyhow!("Job {} is not in a startable state", job_id));
513            }
514
515            job.status = JobStatus::Running;
516            job.started_at = Some(Utc::now());
517            job.clone()
518        };
519
520        let manager = self.clone();
521        let handle = tokio::spawn(async move { manager.process_job(job, model).await });
522
523        Ok(handle)
524    }
525
526    /// Process a batch job
527    async fn process_job(
528        &self,
529        job: BatchJob,
530        model: Arc<dyn EmbeddingModel + Send + Sync>,
531    ) -> Result<BatchProcessingResult> {
532        let start_time = Instant::now();
533        info!(
534            "Starting batch job processing: {} ({})",
535            job.name, job.job_id
536        );
537
538        // Load entities to process
539        let entities = self.load_entities(&job).await?;
540
541        // Filter entities for incremental processing
542        let entities_to_process = if job
543            .input
544            .incremental
545            .as_ref()
546            .map(|inc| inc.enabled)
547            .unwrap_or(false)
548        {
549            self.filter_incremental_entities(&job, entities).await?
550        } else {
551            entities
552        };
553
554        // Update job progress
555        {
556            let mut jobs = self.active_jobs.write().await;
557            if let Some(active_job) = jobs.get_mut(&job.job_id) {
558                active_job.progress.total_entities = entities_to_process.len();
559                active_job.progress.total_chunks =
560                    (entities_to_process.len() + job.config.chunk_size - 1) / job.config.chunk_size;
561            }
562        }
563
564        // Process entities in chunks
565        let chunks: Vec<_> = entities_to_process
566            .chunks(job.config.chunk_size)
567            .map(|chunk| chunk.to_vec())
568            .collect();
569
570        let mut successful_embeddings = 0;
571        let mut failed_embeddings = 0;
572        let mut cache_hits = 0;
573        let mut cache_misses = 0;
574        let mut processed_entities = HashSet::new();
575        let mut failed_entities = HashMap::new();
576
577        for (chunk_idx, chunk) in chunks.iter().enumerate() {
578            // Check if job was cancelled
579            {
580                let jobs = self.active_jobs.read().await;
581                if let Some(active_job) = jobs.get(&job.job_id) {
582                    if matches!(active_job.status, JobStatus::Cancelled) {
583                        info!("Job {} was cancelled", job.job_id);
584                        return Err(anyhow!("Job was cancelled"));
585                    }
586                }
587            }
588
589            // Process chunk
590            let chunk_result = self
591                .process_chunk(&job, chunk, chunk_idx, model.clone())
592                .await?;
593
594            // Update statistics
595            successful_embeddings += chunk_result.successful;
596            failed_embeddings += chunk_result.failed;
597            cache_hits += chunk_result.cache_hits;
598            cache_misses += chunk_result.cache_misses;
599
600            // Track processed entities
601            for entity in chunk {
602                processed_entities.insert(entity.clone());
603            }
604            for (entity, error) in chunk_result.failures {
605                failed_entities.insert(entity, error);
606            }
607
608            // Update progress
609            self.update_job_progress(
610                &job.job_id,
611                chunk_idx + 1,
612                successful_embeddings + failed_embeddings,
613            )
614            .await?;
615
616            // Create checkpoint
617            if chunk_idx % self.config.checkpoint_frequency == 0 {
618                self.create_checkpoint(&job.job_id, &processed_entities, &failed_entities)
619                    .await?;
620            }
621
622            info!(
623                "Processed chunk {}/{} for job {}",
624                chunk_idx + 1,
625                chunks.len(),
626                job.job_id
627            );
628        }
629
630        // Finalize processing
631        let processing_time = start_time.elapsed().as_secs_f64();
632        let result = self
633            .finalize_job_processing(
634                &job,
635                processing_time,
636                successful_embeddings,
637                failed_embeddings,
638                cache_hits,
639                cache_misses,
640            )
641            .await?;
642
643        // Update job status
644        {
645            let mut jobs = self.active_jobs.write().await;
646            if let Some(active_job) = jobs.get_mut(&job.job_id) {
647                active_job.status = JobStatus::Completed;
648                active_job.completed_at = Some(Utc::now());
649            }
650        }
651
652        info!(
653            "Completed batch job: {} in {:.2}s",
654            job.job_id, processing_time
655        );
656        Ok(result)
657    }
658
659    /// Process a single chunk of entities
660    async fn process_chunk(
661        &self,
662        job: &BatchJob,
663        entities: &[String],
664        chunk_idx: usize,
665        model: Arc<dyn EmbeddingModel + Send + Sync>,
666    ) -> Result<ChunkResult> {
667        let _permit = self.semaphore.acquire().await?;
668
669        let mut successful = 0;
670        let mut failed = 0;
671        let mut cache_hits = 0;
672        let mut cache_misses = 0;
673        let mut failures = HashMap::new();
674
675        for entity in entities {
676            match self
677                .process_single_entity(entity, model.clone(), job.config.use_cache)
678                .await
679            {
680                Ok(from_cache) => {
681                    successful += 1;
682                    if from_cache {
683                        cache_hits += 1;
684                    } else {
685                        cache_misses += 1;
686                    }
687                }
688                Err(e) => {
689                    failed += 1;
690                    failures.insert(entity.clone(), e.to_string());
691                    warn!("Failed to process entity {}: {}", entity, e);
692                }
693            }
694        }
695
696        Ok(ChunkResult {
697            chunk_idx,
698            successful,
699            failed,
700            cache_hits,
701            cache_misses,
702            failures,
703        })
704    }
705
706    /// Process a single entity
707    async fn process_single_entity(
708        &self,
709        entity: &str,
710        model: Arc<dyn EmbeddingModel + Send + Sync>,
711        use_cache: bool,
712    ) -> Result<bool> {
713        if use_cache {
714            // Check cache first
715            if let Some(_embedding) = self.cache_manager.get_embedding(entity) {
716                return Ok(true);
717            }
718        }
719
720        // Generate embedding
721        let embedding = model.get_entity_embedding(entity)?;
722
723        // Cache the result
724        if use_cache {
725            self.cache_manager
726                .put_embedding(entity.to_string(), embedding);
727        }
728
729        Ok(false)
730    }
731
732    /// Load entities to process based on input specification
733    async fn load_entities(&self, job: &BatchJob) -> Result<Vec<String>> {
734        match &job.input.input_type {
735            InputType::EntityList => {
736                // Parse entity list from source
737                let entities: Vec<String> = serde_json::from_str(&job.input.source)?;
738                Ok(entities)
739            }
740            InputType::EntityFile => {
741                // Read entities from file
742                let content = fs::read_to_string(&job.input.source).await?;
743                let entities: Vec<String> = content
744                    .lines()
745                    .map(|line| line.trim().to_string())
746                    .filter(|line| !line.is_empty())
747                    .collect();
748                Ok(entities)
749            }
750            InputType::SparqlQuery => {
751                // Execute SPARQL query and extract entities
752                // This would need to be implemented based on SPARQL engine
753                warn!("SPARQL query input type not yet implemented");
754                Ok(Vec::new())
755            }
756            InputType::DatabaseQuery => {
757                // Execute database query and extract entities
758                warn!("Database query input type not yet implemented");
759                Ok(Vec::new())
760            }
761            InputType::StreamSource => {
762                // Read from stream source
763                warn!("Stream source input type not yet implemented");
764                Ok(Vec::new())
765            }
766        }
767    }
768
769    /// Filter entities for incremental processing
770    async fn filter_incremental_entities(
771        &self,
772        job: &BatchJob,
773        entities: Vec<String>,
774    ) -> Result<Vec<String>> {
775        if let Some(incremental) = &job.input.incremental {
776            if !incremental.enabled {
777                return Ok(entities);
778            }
779
780            // Load existing embeddings if specified
781            let existing_entities =
782                if let Some(existing_path) = &incremental.existing_embeddings_path {
783                    self.load_existing_entities(existing_path).await?
784                } else {
785                    HashSet::new()
786                };
787
788            // Filter out entities that already have embeddings
789            let filtered: Vec<String> = entities
790                .into_iter()
791                .filter(|entity| !existing_entities.contains(entity))
792                .collect();
793
794            info!(
795                "Incremental filtering: {} entities remaining after filtering",
796                filtered.len()
797            );
798            Ok(filtered)
799        } else {
800            Ok(entities)
801        }
802    }
803
804    /// Load existing entities from embeddings file
805    async fn load_existing_entities(&self, path: &str) -> Result<HashSet<String>> {
806        // This would depend on the output format
807        // For now, assume a simple text file with entity IDs
808        if Path::new(path).exists() {
809            let content = fs::read_to_string(path).await?;
810            let entities: HashSet<String> = content
811                .lines()
812                .map(|line| line.trim().to_string())
813                .filter(|line| !line.is_empty())
814                .collect();
815            Ok(entities)
816        } else {
817            Ok(HashSet::new())
818        }
819    }
820
821    /// Update job progress
822    async fn update_job_progress(
823        &self,
824        job_id: &Uuid,
825        current_chunk: usize,
826        processed_entities: usize,
827    ) -> Result<()> {
828        let mut jobs = self.active_jobs.write().await;
829        if let Some(job) = jobs.get_mut(job_id) {
830            job.progress.current_chunk = current_chunk;
831            job.progress.processed_entities = processed_entities;
832
833            // Calculate processing rate
834            if let Some(started_at) = job.started_at {
835                let elapsed = Utc::now().signed_duration_since(started_at);
836                let elapsed_seconds = elapsed.num_seconds() as f64;
837                if elapsed_seconds > 0.0 {
838                    job.progress.processing_rate = processed_entities as f64 / elapsed_seconds;
839
840                    // Estimate time remaining
841                    let remaining_entities = job.progress.total_entities - processed_entities;
842                    if job.progress.processing_rate > 0.0 {
843                        let eta = remaining_entities as f64 / job.progress.processing_rate;
844                        job.progress.eta_seconds = Some(eta as u64);
845                    }
846                }
847            }
848        }
849        Ok(())
850    }
851
852    /// Create a checkpoint for job resumability
853    async fn create_checkpoint(
854        &self,
855        job_id: &Uuid,
856        processed_entities: &HashSet<String>,
857        failed_entities: &HashMap<String, String>,
858    ) -> Result<()> {
859        let checkpoint = JobCheckpoint {
860            timestamp: Utc::now(),
861            last_processed_index: processed_entities.len(),
862            processed_entities: processed_entities.clone(),
863            failed_entities: failed_entities.clone(),
864            intermediate_results_path: format!(
865                "{}/checkpoint_{}.json",
866                self.persistence_dir.display(),
867                job_id
868            ),
869            model_state_hash: "placeholder".to_string(), // Would calculate actual hash
870        };
871
872        // Save checkpoint to disk
873        let checkpoint_path = self
874            .persistence_dir
875            .join(format!("checkpoint_{job_id}.json"));
876        let checkpoint_json = serde_json::to_string_pretty(&checkpoint)?;
877        fs::write(checkpoint_path, checkpoint_json).await?;
878
879        // Update job with checkpoint
880        let mut jobs = self.active_jobs.write().await;
881        if let Some(job) = jobs.get_mut(job_id) {
882            job.checkpoint = Some(checkpoint);
883        }
884
885        debug!("Created checkpoint for job {}", job_id);
886        Ok(())
887    }
888
889    /// Finalize job processing and create result
890    async fn finalize_job_processing(
891        &self,
892        job: &BatchJob,
893        processing_time: f64,
894        successful_embeddings: usize,
895        failed_embeddings: usize,
896        cache_hits: usize,
897        cache_misses: usize,
898    ) -> Result<BatchProcessingResult> {
899        let total_entities = successful_embeddings + failed_embeddings;
900        let avg_time_per_entity_ms = if total_entities > 0 {
901            (processing_time * 1000.0) / total_entities as f64
902        } else {
903            0.0
904        };
905
906        let stats = BatchProcessingStats {
907            total_time_seconds: processing_time,
908            total_entities,
909            successful_embeddings,
910            failed_embeddings,
911            cache_hits,
912            cache_misses,
913            avg_time_per_entity_ms,
914            peak_memory_mb: 0.0,  // Would measure actual memory usage
915            cpu_utilization: 0.0, // Would measure actual CPU usage
916        };
917
918        let output_info = OutputInfo {
919            output_files: vec![job.output.path.clone()],
920            total_size_bytes: 0, // Would calculate actual size
921            compression_ratio: 1.0,
922            num_partitions: 1,
923        };
924
925        Ok(BatchProcessingResult {
926            job_id: job.job_id,
927            stats,
928            output_info,
929            quality_metrics: None, // Would calculate if requested
930        })
931    }
932
933    /// Validate a batch job before submission
934    async fn validate_job(&self, job: &BatchJob) -> Result<()> {
935        // Validate input source exists
936        if let InputType::EntityFile = &job.input.input_type {
937            if !Path::new(&job.input.source).exists() {
938                return Err(anyhow!("Input file does not exist: {}", job.input.source));
939            }
940        } // Other validations would be implemented
941
942        // Validate output path is writable
943        if let Some(parent) = Path::new(&job.output.path).parent() {
944            if !parent.exists() {
945                fs::create_dir_all(parent).await?;
946            }
947        }
948
949        Ok(())
950    }
951
952    /// Persist job configuration to disk
953    async fn persist_job(&self, job: &BatchJob) -> Result<()> {
954        let job_path = self
955            .persistence_dir
956            .join(format!("job_{}.json", job.job_id));
957        let job_json = serde_json::to_string_pretty(job)?;
958        fs::write(job_path, job_json).await?;
959        Ok(())
960    }
961
962    /// Get job status
963    pub async fn get_job_status(&self, job_id: &Uuid) -> Option<JobStatus> {
964        let jobs = self.active_jobs.read().await;
965        jobs.get(job_id).map(|job| job.status.clone())
966    }
967
968    /// Get job progress
969    pub async fn get_job_progress(&self, job_id: &Uuid) -> Option<JobProgress> {
970        let jobs = self.active_jobs.read().await;
971        jobs.get(job_id).map(|job| job.progress.clone())
972    }
973
974    /// Cancel a job
975    pub async fn cancel_job(&self, job_id: &Uuid) -> Result<()> {
976        let mut jobs = self.active_jobs.write().await;
977        if let Some(job) = jobs.get_mut(job_id) {
978            job.status = JobStatus::Cancelled;
979            info!("Cancelled job: {}", job_id);
980            Ok(())
981        } else {
982            Err(anyhow!("Job not found: {}", job_id))
983        }
984    }
985
986    /// List all jobs
987    pub async fn list_jobs(&self) -> Vec<BatchJob> {
988        let jobs = self.active_jobs.read().await;
989        jobs.values().cloned().collect()
990    }
991}
992
993impl Clone for BatchProcessingManager {
994    fn clone(&self) -> Self {
995        Self {
996            active_jobs: Arc::clone(&self.active_jobs),
997            config: self.config.clone(),
998            cache_manager: Arc::clone(&self.cache_manager),
999            semaphore: Arc::clone(&self.semaphore),
1000            persistence_dir: self.persistence_dir.clone(),
1001        }
1002    }
1003}
1004
/// Result of processing a single chunk.
///
/// The code that fills and consumes these values is outside this view. The field
/// descriptions below follow the field names — NOTE(review): confirm against the producer.
#[derive(Debug)]
// NOTE(review): the reason for allowing dead_code is not visible here; presumably some
// fields are written but never read. Confirm which ones, then narrow or drop this allow.
#[allow(dead_code)]
struct ChunkResult {
    /// Index of the chunk this result belongs to.
    chunk_idx: usize,
    /// Number of items in the chunk that were processed successfully.
    successful: usize,
    /// Number of items in the chunk that failed.
    failed: usize,
    /// Number of cache hits while processing the chunk.
    cache_hits: usize,
    /// Number of cache misses while processing the chunk.
    cache_misses: usize,
    /// Failure details; presumably item identifier -> error message (TODO confirm).
    failures: HashMap<String, String>,
}
1016
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Build a pending job with fixed metadata; callers choose the input and output location.
    fn sample_job(input_type: InputType, source: &str, output_path: &str) -> BatchJob {
        BatchJob {
            job_id: Uuid::new_v4(),
            name: "test_job".to_string(),
            status: JobStatus::Pending,
            input: BatchInput {
                input_type,
                source: source.to_string(),
                filters: None,
                incremental: None,
            },
            output: BatchOutput {
                path: output_path.to_string(),
                format: "parquet".to_string(),
                compression: Some("gzip".to_string()),
                partitioning: Some(PartitioningStrategy::None),
            },
            config: BatchJobConfig {
                chunk_size: 100,
                num_workers: 4,
                max_retries: 3,
                use_cache: true,
                custom_params: HashMap::new(),
            },
            model_id: Uuid::new_v4(),
            created_at: Utc::now(),
            started_at: None,
            completed_at: None,
            progress: JobProgress::default(),
            error: None,
            checkpoint: None,
        }
    }

    /// Build a manager with default configuration that persists into `dir`.
    fn sample_manager(dir: &Path) -> BatchProcessingManager {
        let cache_manager = Arc::new(CacheManager::new(crate::CacheConfig::default()));
        BatchProcessingManager::new(
            BatchProcessingConfig::default(),
            cache_manager,
            dir.to_path_buf(),
        )
    }

    #[test]
    fn test_batch_job_creation() {
        let job = sample_job(
            InputType::EntityList,
            r#"["entity1", "entity2", "entity3"]"#,
            "/tmp/output",
        );

        assert_eq!(job.status, JobStatus::Pending);
        assert_eq!(job.name, "test_job");
    }

    #[tokio::test]
    async fn test_batch_processing_manager_creation() {
        let temp_dir = tempdir().unwrap();
        let manager = sample_manager(temp_dir.path());

        assert_eq!(manager.config.max_workers, num_cpus::get());
        assert_eq!(manager.config.chunk_size, 1000);
    }

    #[tokio::test]
    async fn test_unknown_job_lookups() {
        let temp_dir = tempdir().unwrap();
        let manager = sample_manager(temp_dir.path());
        let missing = Uuid::new_v4();

        assert!(manager.get_job_status(&missing).await.is_none());
        assert!(manager.get_job_progress(&missing).await.is_none());
        assert!(manager.cancel_job(&missing).await.is_err());
    }

    #[tokio::test]
    async fn test_validate_job_rejects_missing_input_file() {
        let temp_dir = tempdir().unwrap();
        let manager = sample_manager(temp_dir.path());
        let missing_input = temp_dir.path().join("does_not_exist.txt");
        let output = temp_dir.path().join("out").join("embeddings.parquet");
        let job = sample_job(
            InputType::EntityFile,
            missing_input.to_str().unwrap(),
            output.to_str().unwrap(),
        );

        assert!(manager.validate_job(&job).await.is_err());
    }

    #[tokio::test]
    async fn test_validate_job_creates_output_parent() {
        let temp_dir = tempdir().unwrap();
        let manager = sample_manager(temp_dir.path());
        let parent = temp_dir.path().join("nested").join("dir");
        let output = parent.join("embeddings.parquet");
        let job = sample_job(InputType::EntityList, "[]", output.to_str().unwrap());

        manager.validate_job(&job).await.unwrap();
        assert!(parent.is_dir());
    }

    #[tokio::test]
    async fn test_persist_job_writes_json() {
        let temp_dir = tempdir().unwrap();
        let manager = sample_manager(temp_dir.path());
        let job = sample_job(InputType::EntityList, "[]", "/tmp/output");

        manager.persist_job(&job).await.unwrap();

        let job_path = temp_dir.path().join(format!("job_{}.json", job.job_id));
        let contents = std::fs::read_to_string(&job_path).unwrap();
        let value: serde_json::Value = serde_json::from_str(&contents).unwrap();
        assert_eq!(value["name"], "test_job");
    }

    #[test]
    fn test_incremental_config() {
        let incremental = IncrementalConfig {
            enabled: true,
            last_processed: Some(Utc::now()),
            timestamp_field: "updated_at".to_string(),
            check_deletions: true,
            existing_embeddings_path: Some("/path/to/existing".to_string()),
        };

        assert!(incremental.enabled);
        assert!(incremental.last_processed.is_some());
        assert_eq!(incremental.timestamp_field, "updated_at");
    }

    #[test]
    fn test_memory_optimized_batch_iterator() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut iterator = MemoryOptimizedBatchIterator::new(data, 3, 1); // 1MB limit

        // Test first batch
        let batch1 = iterator.next_batch().unwrap();
        assert_eq!(batch1.len(), 3);
        assert_eq!(batch1, vec![1, 2, 3]);
        assert_eq!(iterator.get_progress(), 0.3);
        assert!(!iterator.is_finished());

        // Test second batch
        let batch2 = iterator.next_batch().unwrap();
        assert_eq!(batch2.len(), 3);
        assert_eq!(batch2, vec![4, 5, 6]);
        assert_eq!(iterator.get_progress(), 0.6);

        // Test third batch
        let batch3 = iterator.next_batch().unwrap();
        assert_eq!(batch3.len(), 3);
        assert_eq!(batch3, vec![7, 8, 9]);
        assert_eq!(iterator.get_progress(), 0.9);

        // Test final (partial) batch
        let batch4 = iterator.next_batch().unwrap();
        assert_eq!(batch4.len(), 1);
        assert_eq!(batch4, vec![10]);
        assert_eq!(iterator.get_progress(), 1.0);
        assert!(iterator.is_finished());

        // Exhausted iterator yields nothing
        let batch5 = iterator.next_batch();
        assert!(batch5.is_none());
    }

    #[test]
    fn test_memory_optimized_batch_iterator_empty() {
        let data: Vec<i32> = vec![];
        let mut iterator = MemoryOptimizedBatchIterator::new(data, 3, 1);

        assert_eq!(iterator.get_progress(), 1.0);
        assert!(iterator.is_finished());
        assert!(iterator.next_batch().is_none());
    }

    #[test]
    fn test_memory_optimized_batch_iterator_single_item() {
        let data = vec![42];
        let mut iterator = MemoryOptimizedBatchIterator::new(data, 5, 1);

        let batch = iterator.next_batch().unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0], 42);
        assert_eq!(iterator.get_progress(), 1.0);
        assert!(iterator.is_finished());
    }

    #[test]
    fn test_memory_optimized_batch_iterator_memory_tracking() {
        let data = vec![1, 2, 3, 4, 5];
        let mut iterator = MemoryOptimizedBatchIterator::new(data, 3, 1);

        // Process one batch and check memory usage
        let _batch = iterator.next_batch().unwrap();
        let memory_usage = iterator.get_memory_usage();
        assert!(memory_usage > 0);

        // Memory usage should be exactly 3 * size_of::<i32>() after one 3-item batch
        let expected_memory = 3 * std::mem::size_of::<i32>();
        assert_eq!(memory_usage, expected_memory);
    }
}