use crate::{CacheManager, EmbeddingModel};
use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Instant;
use tokio::fs;
use tokio::sync::{RwLock, Semaphore};
use tokio::task::JoinHandle;
use tracing::{debug, info, warn};
use uuid::Uuid;

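/// Iterator that yields clones of the underlying items in fixed-size batches
/// while keeping an approximate per-batch memory bound.
///
/// A minimal usage sketch (illustrative only; the `u64` element type and the
/// sizes are arbitrary choices for the example):
///
/// ```ignore
/// let data: Vec<u64> = (0..10_000).collect();
/// // Batches of at most 256 items, capped at ~1 MiB of element data each.
/// let mut it = MemoryOptimizedBatchIterator::new(data, 256, 1);
/// while let Some(batch) = it.next_batch() {
///     println!("{} items, {:.0}% done", batch.len(), it.get_progress() * 100.0);
/// }
/// ```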
pub struct MemoryOptimizedBatchIterator<T> {
    /// Items remaining to be batched.
    data: Vec<T>,
    /// Index of the next item to yield.
    position: usize,
    /// Maximum number of items per batch.
    batch_size: usize,
    /// Approximate memory held by the most recent batch, in bytes.
    memory_usage: usize,
    /// Per-batch memory budget, in bytes.
    max_memory_bytes: usize,
}

impl<T> MemoryOptimizedBatchIterator<T> {
    pub fn new(data: Vec<T>, batch_size: usize, max_memory_mb: usize) -> Self {
        Self {
            data,
            position: 0,
            batch_size,
            memory_usage: 0,
            max_memory_bytes: max_memory_mb * 1024 * 1024,
        }
    }

    /// Returns the next batch, or `None` once all items have been yielded.
    pub fn next_batch(&mut self) -> Option<Vec<T>>
    where
        T: Clone,
    {
        if self.position >= self.data.len() {
            return None;
        }

        let mut batch = Vec::new();
        let mut current_memory = 0;
        // Shallow estimate: size_of::<T>() ignores any heap data owned by T.
        let item_size = std::mem::size_of::<T>();

        while self.position < self.data.len() && batch.len() < self.batch_size {
            // Always admit at least one item so that an element larger than
            // the memory budget cannot stall the iterator forever.
            if !batch.is_empty() && current_memory + item_size > self.max_memory_bytes {
                break;
            }
            batch.push(self.data[self.position].clone());
            self.position += 1;
            current_memory += item_size;
        }

        self.memory_usage = current_memory;

        if batch.is_empty() {
            None
        } else {
            Some(batch)
        }
    }

    /// Approximate memory held by the most recently yielded batch, in bytes.
    pub fn get_memory_usage(&self) -> usize {
        self.memory_usage
    }

    /// Fraction of items yielded so far, in `[0.0, 1.0]`.
    pub fn get_progress(&self) -> f64 {
        if self.data.is_empty() {
            1.0
        } else {
            self.position as f64 / self.data.len() as f64
        }
    }

    pub fn is_finished(&self) -> bool {
        self.position >= self.data.len()
    }
}

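/// Orchestrates batch embedding jobs: submission, execution, progress
/// tracking, checkpointing, and cancellation.
///
/// Typical flow (a sketch; `config`, `cache_manager`, `persistence_dir`,
/// `job`, and `model` are assumed to be built by the caller):
///
/// ```ignore
/// let manager = BatchProcessingManager::new(config, cache_manager, persistence_dir);
/// let job_id = manager.submit_job(job).await?;
/// let handle = manager.start_job(job_id, model).await?;
/// let result = handle.await??; // JoinHandle<Result<BatchProcessingResult>>
/// ```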
pub struct BatchProcessingManager {
    /// Jobs currently known to this manager, keyed by job id.
    active_jobs: Arc<RwLock<HashMap<Uuid, BatchJob>>>,
    /// Manager-wide configuration.
    config: BatchProcessingConfig,
    /// Shared embedding cache.
    cache_manager: Arc<CacheManager>,
    /// Bounds the number of chunks processed concurrently.
    semaphore: Arc<Semaphore>,
    /// Directory where jobs and checkpoints are persisted as JSON.
    persistence_dir: PathBuf,
}

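/// Tuning knobs for the batch processing pipeline.
///
/// A sketch of overriding a couple of defaults via struct-update syntax:
///
/// ```ignore
/// let config = BatchProcessingConfig {
///     max_workers: 8,
///     chunk_size: 500,
///     ..Default::default()
/// };
/// ```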
#[derive(Debug, Clone)]
pub struct BatchProcessingConfig {
    /// Maximum number of concurrently running workers.
    pub max_workers: usize,
    /// Number of entities per processing chunk.
    pub chunk_size: usize,
    /// Skip entities that already have embeddings.
    pub enable_incremental: bool,
    /// Write a checkpoint every N chunks.
    pub checkpoint_frequency: usize,
    /// Allow jobs to resume from their last checkpoint.
    pub enable_resume: bool,
    /// Memory budget per worker, in megabytes.
    pub max_memory_per_worker_mb: usize,
    /// Emit progress notifications.
    pub enable_notifications: bool,
    /// Retry/backoff policy for failed entities.
    pub retry_config: RetryConfig,
    /// Output formatting and partitioning options.
    pub output_config: OutputConfig,
}

impl Default for BatchProcessingConfig {
    fn default() -> Self {
        Self {
            max_workers: num_cpus::get(),
            chunk_size: 1000,
            enable_incremental: true,
            checkpoint_frequency: 10,
            enable_resume: true,
            max_memory_per_worker_mb: 512,
            enable_notifications: true,
            retry_config: RetryConfig::default(),
            output_config: OutputConfig::default(),
        }
    }
}

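/// Retry policy with exponential backoff. Note that this module does not yet
/// consume these fields; one plausible interpretation of the schedule
/// (a sketch, not the implemented behavior) is:
///
/// ```ignore
/// fn backoff_ms(cfg: &RetryConfig, attempt: u32) -> u64 {
///     let raw = cfg.initial_backoff_ms as f64 * cfg.backoff_multiplier.powi(attempt as i32);
///     (raw as u64).min(cfg.max_backoff_ms)
/// }
/// // With the defaults, attempts 0..=3 would wait 1000, 2000, 4000,
/// // and 8000 ms, capped at 30_000 ms.
/// ```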
#[derive(Debug, Clone)]
pub struct RetryConfig {
    pub max_retries: usize,
    pub initial_backoff_ms: u64,
    pub max_backoff_ms: u64,
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_backoff_ms: 1000,
            max_backoff_ms: 30000,
            backoff_multiplier: 2.0,
        }
    }
}

#[derive(Debug, Clone)]
pub struct OutputConfig {
    pub format: OutputFormat,
    /// Compression level passed to the output writer (codec-specific scale).
    pub compression_level: u32,
    pub include_metadata: bool,
    pub batch_output: bool,
    /// Split output so no single file exceeds this many entities.
    pub max_entities_per_file: usize,
}

impl Default for OutputConfig {
    fn default() -> Self {
        Self {
            format: OutputFormat::Parquet,
            compression_level: 6,
            include_metadata: true,
            batch_output: true,
            max_entities_per_file: 100_000,
        }
    }
}

/// Supported serialization formats for embedding output.
#[derive(Debug, Clone)]
pub enum OutputFormat {
    Parquet,
    JsonLines,
    Binary,
    HDF5,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchJob {
    pub job_id: Uuid,
    pub name: String,
    pub status: JobStatus,
    pub input: BatchInput,
    pub output: BatchOutput,
    pub config: BatchJobConfig,
    /// Identifier of the embedding model used for this job.
    pub model_id: Uuid,
    pub created_at: DateTime<Utc>,
    pub started_at: Option<DateTime<Utc>>,
    pub completed_at: Option<DateTime<Utc>>,
    pub progress: JobProgress,
    /// Last error message, if the job failed.
    pub error: Option<String>,
    /// Most recent checkpoint, if any has been written.
    pub checkpoint: Option<JobCheckpoint>,
}

/// Lifecycle states of a batch job.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum JobStatus {
    Pending,
    Running,
    Completed,
    Failed,
    Cancelled,
    Paused,
}

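/// Describes where a job's input entities come from. For
/// `InputType::EntityList`, `source` holds a JSON array of entity names
/// (see `load_entities`); a sketch:
///
/// ```ignore
/// let input = BatchInput {
///     input_type: InputType::EntityList,
///     source: r#"["entity1", "entity2"]"#.to_string(),
///     filters: None,
///     incremental: None,
/// };
/// ```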
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchInput {
    pub input_type: InputType,
    /// Interpretation depends on `input_type`: a JSON array, a file path,
    /// a query string, or a stream identifier.
    pub source: String,
    pub filters: Option<HashMap<String, String>>,
    pub incremental: Option<IncrementalConfig>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InputType {
    EntityList,
    EntityFile,
    SparqlQuery,
    DatabaseQuery,
    StreamSource,
}

/// Settings for incremental (delta) processing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncrementalConfig {
    pub enabled: bool,
    pub last_processed: Option<DateTime<Utc>>,
    /// Field used to detect entities changed since `last_processed`.
    pub timestamp_field: String,
    pub check_deletions: bool,
    /// Newline-delimited list of entities that already have embeddings.
    pub existing_embeddings_path: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchOutput {
    pub path: String,
    pub format: String,
    pub compression: Option<String>,
    pub partitioning: Option<PartitioningStrategy>,
}

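/// How output embeddings are distributed across partitions.
///
/// For `ByHash`, a typical assignment (a sketch; partitioned writing is not
/// implemented in this module) would be:
///
/// ```ignore
/// use std::collections::hash_map::DefaultHasher;
/// use std::hash::{Hash, Hasher};
///
/// fn partition_for(entity: &str, num_partitions: usize) -> usize {
///     let mut hasher = DefaultHasher::new();
///     entity.hash(&mut hasher);
///     (hasher.finish() as usize) % num_partitions
/// }
/// ```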
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PartitioningStrategy {
    None,
    ByEntityType,
    ByDate,
    ByHash { num_partitions: usize },
    Custom { field: String },
}

/// Per-job overrides of the manager-wide configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchJobConfig {
    pub chunk_size: usize,
    pub num_workers: usize,
    pub max_retries: usize,
    pub use_cache: bool,
    pub custom_params: HashMap<String, String>,
}

/// Live progress counters for a running job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JobProgress {
    pub total_entities: usize,
    pub processed_entities: usize,
    pub failed_entities: usize,
    pub current_chunk: usize,
    pub total_chunks: usize,
    /// Entities processed per second.
    pub processing_rate: f64,
    /// Estimated seconds remaining, if a rate is available.
    pub eta_seconds: Option<u64>,
    pub memory_usage_mb: f64,
}

impl Default for JobProgress {
    fn default() -> Self {
        Self {
            total_entities: 0,
            processed_entities: 0,
            failed_entities: 0,
            current_chunk: 0,
            total_chunks: 0,
            processing_rate: 0.0,
            eta_seconds: None,
            memory_usage_mb: 0.0,
        }
    }
}

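/// Snapshot written every `checkpoint_frequency` chunks so an interrupted
/// job can be resumed. Since it derives `Serialize`/`Deserialize`, a resume
/// path could reload it like this (a sketch; resume is not wired up yet):
///
/// ```ignore
/// let json = std::fs::read_to_string(checkpoint_path)?;
/// let checkpoint: JobCheckpoint = serde_json::from_str(&json)?;
/// let remaining: Vec<String> = all_entities
///     .into_iter()
///     .filter(|e| !checkpoint.processed_entities.contains(e))
///     .collect();
/// ```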
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JobCheckpoint {
    pub timestamp: DateTime<Utc>,
    pub last_processed_index: usize,
    pub processed_entities: HashSet<String>,
    /// Entities that failed, mapped to their error messages.
    pub failed_entities: HashMap<String, String>,
    pub intermediate_results_path: String,
    /// Hash of the model state at checkpoint time (currently a placeholder).
    pub model_state_hash: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchProcessingResult {
    pub job_id: Uuid,
    pub stats: BatchProcessingStats,
    pub output_info: OutputInfo,
    pub quality_metrics: Option<QualityMetrics>,
}

/// Timing, throughput, and cache statistics for a completed job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchProcessingStats {
    pub total_time_seconds: f64,
    pub total_entities: usize,
    pub successful_embeddings: usize,
    pub failed_embeddings: usize,
    pub cache_hits: usize,
    pub cache_misses: usize,
    pub avg_time_per_entity_ms: f64,
    pub peak_memory_mb: f64,
    pub cpu_utilization: f64,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputInfo {
    pub output_files: Vec<String>,
    pub total_size_bytes: u64,
    pub compression_ratio: f64,
    pub num_partitions: usize,
}

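/// Aggregate sanity metrics over produced embeddings. This module does not
/// populate them yet; a sketch of deriving the norm statistics from raw
/// vectors (an `embeddings: &[Vec<f64>]` input is assumed):
///
/// ```ignore
/// let norms: Vec<f64> = embeddings
///     .iter()
///     .map(|v| v.iter().map(|x| x * x).sum::<f64>().sqrt())
///     .collect();
/// let avg = norms.iter().sum::<f64>() / norms.len() as f64;
/// let var = norms.iter().map(|n| (n - avg).powi(2)).sum::<f64>() / norms.len() as f64;
/// let std_dev = var.sqrt();
/// ```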
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityMetrics {
    pub avg_embedding_norm: f64,
    pub embedding_norm_std: f64,
    pub avg_cosine_similarity: f64,
    pub embedding_dimension: usize,
    pub zero_embeddings: usize,
    pub nan_embeddings: usize,
}

impl BatchProcessingManager {
    pub fn new(
        config: BatchProcessingConfig,
        cache_manager: Arc<CacheManager>,
        persistence_dir: PathBuf,
    ) -> Self {
        Self {
            active_jobs: Arc::new(RwLock::new(HashMap::new())),
            semaphore: Arc::new(Semaphore::new(config.max_workers)),
            config,
            cache_manager,
            persistence_dir,
        }
    }

    /// Validates, registers, and persists a job; returns its id.
    pub async fn submit_job(&self, job: BatchJob) -> Result<Uuid> {
        let job_id = job.job_id;

        self.validate_job(&job).await?;

        {
            let mut jobs = self.active_jobs.write().await;
            jobs.insert(job_id, job.clone());
        }

        self.persist_job(&job).await?;

        info!("Submitted batch job: {} ({})", job.name, job_id);
        Ok(job_id)
    }

    /// Marks a pending or paused job as running and spawns its processing
    /// task, returning the join handle.
    pub async fn start_job(
        &self,
        job_id: Uuid,
        model: Arc<dyn EmbeddingModel + Send + Sync>,
    ) -> Result<JoinHandle<Result<BatchProcessingResult>>> {
        let job = {
            let mut jobs = self.active_jobs.write().await;
            let job = jobs
                .get_mut(&job_id)
                .ok_or_else(|| anyhow!("Job not found: {}", job_id))?;

            if !matches!(job.status, JobStatus::Pending | JobStatus::Paused) {
                return Err(anyhow!("Job {} is not in a startable state", job_id));
            }

            job.status = JobStatus::Running;
            job.started_at = Some(Utc::now());
            job.clone()
        };

        let manager = self.clone();
        let handle = tokio::spawn(async move { manager.process_job(job, model).await });

        Ok(handle)
    }

    async fn process_job(
        &self,
        job: BatchJob,
        model: Arc<dyn EmbeddingModel + Send + Sync>,
    ) -> Result<BatchProcessingResult> {
        let start_time = Instant::now();
        info!(
            "Starting batch job processing: {} ({})",
            job.name, job.job_id
        );

        let entities = self.load_entities(&job).await?;

        // Drop already-embedded entities when incremental mode is enabled.
        let entities_to_process = if job
            .input
            .incremental
            .as_ref()
            .map(|inc| inc.enabled)
            .unwrap_or(false)
        {
            self.filter_incremental_entities(&job, entities).await?
        } else {
            entities
        };

        {
            let mut jobs = self.active_jobs.write().await;
            if let Some(active_job) = jobs.get_mut(&job.job_id) {
                active_job.progress.total_entities = entities_to_process.len();
                // Ceiling division: the last chunk may be smaller.
                active_job.progress.total_chunks =
                    (entities_to_process.len() + job.config.chunk_size - 1) / job.config.chunk_size;
            }
        }

        let chunks: Vec<_> = entities_to_process
            .chunks(job.config.chunk_size)
            .map(|chunk| chunk.to_vec())
            .collect();

        let mut successful_embeddings = 0;
        let mut failed_embeddings = 0;
        let mut cache_hits = 0;
        let mut cache_misses = 0;
        let mut processed_entities = HashSet::new();
        let mut failed_entities = HashMap::new();

        for (chunk_idx, chunk) in chunks.iter().enumerate() {
            // Honor cancellation requests between chunks.
            {
                let jobs = self.active_jobs.read().await;
                if let Some(active_job) = jobs.get(&job.job_id) {
                    if matches!(active_job.status, JobStatus::Cancelled) {
                        info!("Job {} was cancelled", job.job_id);
                        return Err(anyhow!("Job was cancelled"));
                    }
                }
            }

            let chunk_result = self
                .process_chunk(&job, chunk, chunk_idx, model.clone())
                .await?;

            successful_embeddings += chunk_result.successful;
            failed_embeddings += chunk_result.failed;
            cache_hits += chunk_result.cache_hits;
            cache_misses += chunk_result.cache_misses;

            for entity in chunk {
                processed_entities.insert(entity.clone());
            }
            for (entity, error) in chunk_result.failures {
                failed_entities.insert(entity, error);
            }

            self.update_job_progress(
                &job.job_id,
                chunk_idx + 1,
                successful_embeddings + failed_embeddings,
            )
            .await?;

            if chunk_idx % self.config.checkpoint_frequency == 0 {
                self.create_checkpoint(&job.job_id, &processed_entities, &failed_entities)
                    .await?;
            }

            info!(
                "Processed chunk {}/{} for job {}",
                chunk_idx + 1,
                chunks.len(),
                job.job_id
            );
        }

        let processing_time = start_time.elapsed().as_secs_f64();
        let result = self
            .finalize_job_processing(
                &job,
                processing_time,
                successful_embeddings,
                failed_embeddings,
                cache_hits,
                cache_misses,
            )
            .await?;

        {
            let mut jobs = self.active_jobs.write().await;
            if let Some(active_job) = jobs.get_mut(&job.job_id) {
                active_job.status = JobStatus::Completed;
                active_job.completed_at = Some(Utc::now());
            }
        }

        info!(
            "Completed batch job: {} in {:.2}s",
            job.job_id, processing_time
        );
        Ok(result)
    }

    async fn process_chunk(
        &self,
        job: &BatchJob,
        entities: &[String],
        chunk_idx: usize,
        model: Arc<dyn EmbeddingModel + Send + Sync>,
    ) -> Result<ChunkResult> {
        // Bound concurrent chunk processing; the permit is released on drop.
        let _permit = self.semaphore.acquire().await?;

        let mut successful = 0;
        let mut failed = 0;
        let mut cache_hits = 0;
        let mut cache_misses = 0;
        let mut failures = HashMap::new();

        for entity in entities {
            match self
                .process_single_entity(entity, model.clone(), job.config.use_cache)
                .await
            {
                Ok(from_cache) => {
                    successful += 1;
                    if from_cache {
                        cache_hits += 1;
                    } else {
                        cache_misses += 1;
                    }
                }
                Err(e) => {
                    failed += 1;
                    failures.insert(entity.clone(), e.to_string());
                    warn!("Failed to process entity {}: {}", entity, e);
                }
            }
        }

        Ok(ChunkResult {
            chunk_idx,
            successful,
            failed,
            cache_hits,
            cache_misses,
            failures,
        })
    }

    /// Embeds one entity; returns `Ok(true)` on a cache hit and `Ok(false)`
    /// when the embedding was freshly computed.
    async fn process_single_entity(
        &self,
        entity: &str,
        model: Arc<dyn EmbeddingModel + Send + Sync>,
        use_cache: bool,
    ) -> Result<bool> {
        if use_cache {
            if let Some(_embedding) = self.cache_manager.get_embedding(entity) {
                return Ok(true);
            }
        }

        let embedding = model.get_entity_embedding(entity)?;

        if use_cache {
            self.cache_manager
                .put_embedding(entity.to_string(), embedding);
        }

        Ok(false)
    }

    async fn load_entities(&self, job: &BatchJob) -> Result<Vec<String>> {
        match &job.input.input_type {
            InputType::EntityList => {
                // `source` is expected to be a JSON array of entity names.
                let entities: Vec<String> = serde_json::from_str(&job.input.source)?;
                Ok(entities)
            }
            InputType::EntityFile => {
                // `source` is a path to a newline-delimited entity file.
                let content = fs::read_to_string(&job.input.source).await?;
                let entities: Vec<String> = content
                    .lines()
                    .map(|line| line.trim().to_string())
                    .filter(|line| !line.is_empty())
                    .collect();
                Ok(entities)
            }
            InputType::SparqlQuery => {
                warn!("SPARQL query input type not yet implemented");
                Ok(Vec::new())
            }
            InputType::DatabaseQuery => {
                warn!("Database query input type not yet implemented");
                Ok(Vec::new())
            }
            InputType::StreamSource => {
                warn!("Stream source input type not yet implemented");
                Ok(Vec::new())
            }
        }
    }

    async fn filter_incremental_entities(
        &self,
        job: &BatchJob,
        entities: Vec<String>,
    ) -> Result<Vec<String>> {
        if let Some(incremental) = &job.input.incremental {
            if !incremental.enabled {
                return Ok(entities);
            }

            let existing_entities =
                if let Some(existing_path) = &incremental.existing_embeddings_path {
                    self.load_existing_entities(existing_path).await?
                } else {
                    HashSet::new()
                };

            // Keep only entities that do not already have embeddings.
            let filtered: Vec<String> = entities
                .into_iter()
                .filter(|entity| !existing_entities.contains(entity))
                .collect();

            info!(
                "Incremental filtering: {} entities remaining after filtering",
                filtered.len()
            );
            Ok(filtered)
        } else {
            Ok(entities)
        }
    }

    async fn load_existing_entities(&self, path: &str) -> Result<HashSet<String>> {
        if Path::new(path).exists() {
            let content = fs::read_to_string(path).await?;
            let entities: HashSet<String> = content
                .lines()
                .map(|line| line.trim().to_string())
                .filter(|line| !line.is_empty())
                .collect();
            Ok(entities)
        } else {
            Ok(HashSet::new())
        }
    }

    async fn update_job_progress(
        &self,
        job_id: &Uuid,
        current_chunk: usize,
        processed_entities: usize,
    ) -> Result<()> {
        let mut jobs = self.active_jobs.write().await;
        if let Some(job) = jobs.get_mut(job_id) {
            job.progress.current_chunk = current_chunk;
            job.progress.processed_entities = processed_entities;

            if let Some(started_at) = job.started_at {
                let elapsed = Utc::now().signed_duration_since(started_at);
                let elapsed_seconds = elapsed.num_seconds() as f64;
                if elapsed_seconds > 0.0 {
                    job.progress.processing_rate = processed_entities as f64 / elapsed_seconds;

                    // saturating_sub guards against a transient over-count.
                    let remaining_entities = job
                        .progress
                        .total_entities
                        .saturating_sub(processed_entities);
                    if job.progress.processing_rate > 0.0 {
                        let eta = remaining_entities as f64 / job.progress.processing_rate;
                        job.progress.eta_seconds = Some(eta as u64);
                    }
                }
            }
        }
        Ok(())
    }

    async fn create_checkpoint(
        &self,
        job_id: &Uuid,
        processed_entities: &HashSet<String>,
        failed_entities: &HashMap<String, String>,
    ) -> Result<()> {
        let checkpoint = JobCheckpoint {
            timestamp: Utc::now(),
            last_processed_index: processed_entities.len(),
            processed_entities: processed_entities.clone(),
            failed_entities: failed_entities.clone(),
            intermediate_results_path: format!(
                "{}/checkpoint_{}.json",
                self.persistence_dir.display(),
                job_id
            ),
            // TODO: hash the actual model state.
            model_state_hash: "placeholder".to_string(),
        };

        let checkpoint_path = self
            .persistence_dir
            .join(format!("checkpoint_{job_id}.json"));
        let checkpoint_json = serde_json::to_string_pretty(&checkpoint)?;
        fs::write(checkpoint_path, checkpoint_json).await?;

        let mut jobs = self.active_jobs.write().await;
        if let Some(job) = jobs.get_mut(job_id) {
            job.checkpoint = Some(checkpoint);
        }

        debug!("Created checkpoint for job {}", job_id);
        Ok(())
    }

    async fn finalize_job_processing(
        &self,
        job: &BatchJob,
        processing_time: f64,
        successful_embeddings: usize,
        failed_embeddings: usize,
        cache_hits: usize,
        cache_misses: usize,
    ) -> Result<BatchProcessingResult> {
        let total_entities = successful_embeddings + failed_embeddings;
        let avg_time_per_entity_ms = if total_entities > 0 {
            (processing_time * 1000.0) / total_entities as f64
        } else {
            0.0
        };

        let stats = BatchProcessingStats {
            total_time_seconds: processing_time,
            total_entities,
            successful_embeddings,
            failed_embeddings,
            cache_hits,
            cache_misses,
            avg_time_per_entity_ms,
            peak_memory_mb: 0.0,  // not yet measured
            cpu_utilization: 0.0, // not yet measured
        };

        let output_info = OutputInfo {
            output_files: vec![job.output.path.clone()],
            total_size_bytes: 0, // not yet measured
            compression_ratio: 1.0,
            num_partitions: 1,
        };

        Ok(BatchProcessingResult {
            job_id: job.job_id,
            stats,
            output_info,
            quality_metrics: None, // not yet computed
        })
    }

    async fn validate_job(&self, job: &BatchJob) -> Result<()> {
        if let InputType::EntityFile = &job.input.input_type {
            if !Path::new(&job.input.source).exists() {
                return Err(anyhow!("Input file does not exist: {}", job.input.source));
            }
        }

        // Ensure the output directory exists before the job starts writing.
        if let Some(parent) = Path::new(&job.output.path).parent() {
            if !parent.exists() {
                fs::create_dir_all(parent).await?;
            }
        }

        Ok(())
    }

    async fn persist_job(&self, job: &BatchJob) -> Result<()> {
        let job_path = self
            .persistence_dir
            .join(format!("job_{}.json", job.job_id));
        let job_json = serde_json::to_string_pretty(job)?;
        fs::write(job_path, job_json).await?;
        Ok(())
    }

    pub async fn get_job_status(&self, job_id: &Uuid) -> Option<JobStatus> {
        let jobs = self.active_jobs.read().await;
        jobs.get(job_id).map(|job| job.status.clone())
    }

    pub async fn get_job_progress(&self, job_id: &Uuid) -> Option<JobProgress> {
        let jobs = self.active_jobs.read().await;
        jobs.get(job_id).map(|job| job.progress.clone())
    }

    pub async fn cancel_job(&self, job_id: &Uuid) -> Result<()> {
        let mut jobs = self.active_jobs.write().await;
        if let Some(job) = jobs.get_mut(job_id) {
            job.status = JobStatus::Cancelled;
            info!("Cancelled job: {}", job_id);
            Ok(())
        } else {
            Err(anyhow!("Job not found: {}", job_id))
        }
    }

    pub async fn list_jobs(&self) -> Vec<BatchJob> {
        let jobs = self.active_jobs.read().await;
        jobs.values().cloned().collect()
    }
}

impl Clone for BatchProcessingManager {
    // Manual impl: clones share the job map, cache, and semaphore via Arc.
    fn clone(&self) -> Self {
        Self {
            active_jobs: Arc::clone(&self.active_jobs),
            config: self.config.clone(),
            cache_manager: Arc::clone(&self.cache_manager),
            semaphore: Arc::clone(&self.semaphore),
            persistence_dir: self.persistence_dir.clone(),
        }
    }
}

/// Per-chunk outcome counters returned by `process_chunk`.
#[derive(Debug)]
#[allow(dead_code)]
struct ChunkResult {
    chunk_idx: usize,
    successful: usize,
    failed: usize,
    cache_hits: usize,
    cache_misses: usize,
    failures: HashMap<String, String>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_batch_job_creation() {
        let job = BatchJob {
            job_id: Uuid::new_v4(),
            name: "test_job".to_string(),
            status: JobStatus::Pending,
            input: BatchInput {
                input_type: InputType::EntityList,
                source: r#"["entity1", "entity2", "entity3"]"#.to_string(),
                filters: None,
                incremental: None,
            },
            output: BatchOutput {
                path: "/tmp/output".to_string(),
                format: "parquet".to_string(),
                compression: Some("gzip".to_string()),
                partitioning: Some(PartitioningStrategy::None),
            },
            config: BatchJobConfig {
                chunk_size: 100,
                num_workers: 4,
                max_retries: 3,
                use_cache: true,
                custom_params: HashMap::new(),
            },
            model_id: Uuid::new_v4(),
            created_at: Utc::now(),
            started_at: None,
            completed_at: None,
            progress: JobProgress::default(),
            error: None,
            checkpoint: None,
        };

        assert_eq!(job.status, JobStatus::Pending);
        assert_eq!(job.name, "test_job");
    }

    #[tokio::test]
    async fn test_batch_processing_manager_creation() {
        let config = BatchProcessingConfig::default();
        let cache_config = crate::CacheConfig::default();
        let cache_manager = Arc::new(CacheManager::new(cache_config));
        let temp_dir = tempdir().unwrap();

        let manager =
            BatchProcessingManager::new(config, cache_manager, temp_dir.path().to_path_buf());

        assert_eq!(manager.config.max_workers, num_cpus::get());
        assert_eq!(manager.config.chunk_size, 1000);
    }

    #[test]
    fn test_incremental_config() {
        let incremental = IncrementalConfig {
            enabled: true,
            last_processed: Some(Utc::now()),
            timestamp_field: "updated_at".to_string(),
            check_deletions: true,
            existing_embeddings_path: Some("/path/to/existing".to_string()),
        };

        assert!(incremental.enabled);
        assert!(incremental.last_processed.is_some());
        assert_eq!(incremental.timestamp_field, "updated_at");
    }

    #[test]
    fn test_memory_optimized_batch_iterator() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        // The 1 MB budget is far above what ten i32s need, so only
        // `batch_size` limits each batch here.
        let mut iterator = MemoryOptimizedBatchIterator::new(data.clone(), 3, 1);

        let batch1 = iterator.next_batch().unwrap();
        assert_eq!(batch1.len(), 3);
        assert_eq!(batch1, vec![1, 2, 3]);
        assert_eq!(iterator.get_progress(), 0.3);
        assert!(!iterator.is_finished());

        let batch2 = iterator.next_batch().unwrap();
        assert_eq!(batch2.len(), 3);
        assert_eq!(batch2, vec![4, 5, 6]);
        assert_eq!(iterator.get_progress(), 0.6);

        let batch3 = iterator.next_batch().unwrap();
        assert_eq!(batch3.len(), 3);
        assert_eq!(batch3, vec![7, 8, 9]);
        assert_eq!(iterator.get_progress(), 0.9);

        let batch4 = iterator.next_batch().unwrap();
        assert_eq!(batch4.len(), 1);
        assert_eq!(batch4, vec![10]);
        assert_eq!(iterator.get_progress(), 1.0);
        assert!(iterator.is_finished());

        let batch5 = iterator.next_batch();
        assert!(batch5.is_none());
    }

    #[test]
    fn test_memory_optimized_batch_iterator_empty() {
        let data: Vec<i32> = vec![];
        let mut iterator = MemoryOptimizedBatchIterator::new(data, 3, 1);

        assert_eq!(iterator.get_progress(), 1.0);
        assert!(iterator.is_finished());
        assert!(iterator.next_batch().is_none());
    }

    #[test]
    fn test_memory_optimized_batch_iterator_single_item() {
        let data = vec![42];
        let mut iterator = MemoryOptimizedBatchIterator::new(data, 5, 1);

        let batch = iterator.next_batch().unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0], 42);
        assert_eq!(iterator.get_progress(), 1.0);
        assert!(iterator.is_finished());
    }

    #[test]
    fn test_memory_optimized_batch_iterator_memory_tracking() {
        let data = vec![1, 2, 3, 4, 5];
        let mut iterator = MemoryOptimizedBatchIterator::new(data, 3, 1);

        let _batch = iterator.next_batch().unwrap();
        let memory_usage = iterator.get_memory_usage();
        assert!(memory_usage > 0);

        // Three i32s were yielded, so the shallow estimate is 3 * 4 bytes.
        let expected_memory = 3 * std::mem::size_of::<i32>();
        assert_eq!(memory_usage, expected_memory);
    }

    #[test]
    fn test_parallel_batch_processor() {
        let processor = ParallelBatchProcessor::new(ParallelBatchConfig::default()).unwrap();

        assert!(processor.num_workers() > 0);
        assert!(processor.num_workers() <= num_cpus::get());
    }
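
    /// Sketch: end-to-end submit → status → cancel flow using only APIs
    /// defined in this module (temp dir for persistence; no model needed).
    #[tokio::test]
    async fn test_job_submit_and_cancel() {
        let cache_manager = Arc::new(CacheManager::new(crate::CacheConfig::default()));
        let temp_dir = tempdir().unwrap();
        let manager = BatchProcessingManager::new(
            BatchProcessingConfig::default(),
            cache_manager,
            temp_dir.path().to_path_buf(),
        );

        let job = BatchJob {
            job_id: Uuid::new_v4(),
            name: "cancel_me".to_string(),
            status: JobStatus::Pending,
            input: BatchInput {
                input_type: InputType::EntityList,
                source: r#"["entity1"]"#.to_string(),
                filters: None,
                incremental: None,
            },
            output: BatchOutput {
                path: temp_dir.path().join("out.parquet").display().to_string(),
                format: "parquet".to_string(),
                compression: None,
                partitioning: None,
            },
            config: BatchJobConfig {
                chunk_size: 10,
                num_workers: 1,
                max_retries: 0,
                use_cache: false,
                custom_params: HashMap::new(),
            },
            model_id: Uuid::new_v4(),
            created_at: Utc::now(),
            started_at: None,
            completed_at: None,
            progress: JobProgress::default(),
            error: None,
            checkpoint: None,
        };

        let job_id = manager.submit_job(job).await.unwrap();
        assert_eq!(
            manager.get_job_status(&job_id).await,
            Some(JobStatus::Pending)
        );

        manager.cancel_job(&job_id).await.unwrap();
        assert_eq!(
            manager.get_job_status(&job_id).await,
            Some(JobStatus::Cancelled)
        );
    }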
}

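/// CPU-parallel batch processor built on rayon's global thread pool.
///
/// Sketch of the basic call (the closure runs on rayon worker threads):
///
/// ```ignore
/// let processor = ParallelBatchProcessor::new(ParallelBatchConfig::default())?;
/// let squares = processor.process_parallel(vec![1u64, 2, 3], |x| x * x)?;
/// assert_eq!(squares, vec![1, 4, 9]);
/// ```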
pub struct ParallelBatchProcessor {
    config: ParallelBatchConfig,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelBatchConfig {
    pub num_workers: usize,
    pub chunk_size: usize,
    /// Adapt chunk sizes to observed throughput (not consumed yet).
    pub adaptive_balancing: bool,
    /// Shallow memory budget per chunk, in megabytes.
    pub memory_threshold_mb: usize,
    /// Prefer NUMA-local memory placement (not consumed yet).
    pub numa_aware: bool,
    /// Allow work stealing between workers (rayon's default behavior).
    pub work_stealing: bool,
}

impl Default for ParallelBatchConfig {
    fn default() -> Self {
        Self {
            num_workers: num_cpus::get(),
            chunk_size: 1000,
            adaptive_balancing: true,
            memory_threshold_mb: 512,
            numa_aware: true,
            work_stealing: true,
        }
    }
}

impl ParallelBatchProcessor {
    pub fn new(config: ParallelBatchConfig) -> Result<Self> {
        // Configure rayon's global pool; ignore the error if another caller
        // already initialized it (build_global can only succeed once).
        rayon::ThreadPoolBuilder::new()
            .num_threads(config.num_workers)
            .build_global()
            .ok();

        Ok(Self { config })
    }

    pub fn num_workers(&self) -> usize {
        self.config.num_workers
    }

    /// Applies `process_fn` to every item on rayon's thread pool.
    pub fn process_parallel<T, F, R>(&self, items: Vec<T>, process_fn: F) -> Result<Vec<R>>
    where
        T: Send + Sync,
        F: Fn(&T) -> R + Send + Sync,
        R: Send,
    {
        let results: Vec<R> = items.par_iter().map(process_fn).collect();

        Ok(results)
    }

    /// Currently identical to `process_parallel`: rayon's work-stealing
    /// scheduler already balances load across workers.
    pub fn process_with_load_balancing<T, F, R>(
        &self,
        items: Vec<T>,
        process_fn: F,
    ) -> Result<Vec<R>>
    where
        T: Send + Sync,
        F: Fn(&T) -> R + Send + Sync,
        R: Send,
    {
        let results: Vec<R> = items.par_iter().map(process_fn).collect();

        Ok(results)
    }

    /// Processes items in chunks sized to stay under `memory_threshold_mb`.
    pub fn process_memory_efficient<T, F, R>(&self, items: Vec<T>, process_fn: F) -> Result<Vec<R>>
    where
        T: Send + Sync,
        F: Fn(&T) -> R + Send + Sync,
        R: Send,
    {
        // Items that fit the budget by shallow size. With the defaults
        // (512 MiB, 4-byte items) this is ~134M, far above `chunk_size`,
        // so the clamp below usually wins.
        let chunk_size =
            (self.config.memory_threshold_mb * 1024 * 1024) / (std::mem::size_of::<T>().max(1));

        // At most `config.chunk_size`, but never below a floor of 100.
        let chunk_size = chunk_size.min(self.config.chunk_size).max(100);

        let results: Vec<R> = items
            .par_chunks(chunk_size)
            .flat_map(|chunk| chunk.iter().map(&process_fn).collect::<Vec<_>>())
            .collect();

        Ok(results)
    }

    /// Parallelizes over the outer batches; items within a batch run
    /// sequentially.
    pub fn process_nested_parallel<T, F, R>(
        &self,
        items: Vec<Vec<T>>,
        process_fn: F,
    ) -> Result<Vec<Vec<R>>>
    where
        T: Send + Sync,
        F: Fn(&T) -> R + Send + Sync,
        R: Send,
    {
        let results: Vec<Vec<R>> = items
            .par_iter()
            .map(|batch| batch.iter().map(&process_fn).collect())
            .collect();

        Ok(results)
    }

    pub fn get_stats(&self) -> ParallelProcessingStats {
        ParallelProcessingStats {
            num_workers: self.config.num_workers,
            profiler_report: "Stats available".to_string(),
            memory_usage: 0,
        }
    }
}

#[derive(Debug, Clone)]
pub struct ParallelProcessingStats {
    pub num_workers: usize,
    pub profiler_report: String,
    pub memory_usage: usize,
}