1use crate::{CacheManager, EmbeddingModel};
8use anyhow::{anyhow, Result};
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12use std::path::{Path, PathBuf};
13use std::sync::Arc;
14use std::time::Instant;
15use tokio::fs;
16use tokio::sync::{RwLock, Semaphore};
17use tokio::task::JoinHandle;
18use tracing::{debug, info, warn};
19use uuid::Uuid;
20
/// Iterator that hands out clones of the underlying data in batches while
/// respecting both a batch-size cap and an approximate memory budget.
pub struct MemoryOptimizedBatchIterator<T> {
    /// Backing data; items are cloned out, never removed.
    data: Vec<T>,
    /// Index of the next item to yield.
    position: usize,
    /// Maximum number of items per batch.
    batch_size: usize,
    /// Approximate bytes consumed by the most recent batch.
    memory_usage: usize,
    /// Per-batch memory budget, in bytes.
    max_memory_bytes: usize,
}

impl<T> MemoryOptimizedBatchIterator<T> {
    /// Creates a new iterator over `data`.
    ///
    /// `max_memory_mb` is a per-batch budget. The estimate uses
    /// `std::mem::size_of::<T>()`, so heap-owned payloads (e.g. the bytes
    /// behind a `String`) are not counted.
    pub fn new(data: Vec<T>, batch_size: usize, max_memory_mb: usize) -> Self {
        Self {
            data,
            position: 0,
            batch_size,
            memory_usage: 0,
            max_memory_bytes: max_memory_mb * 1024 * 1024,
        }
    }

    /// Returns the next batch, or `None` once every item has been yielded.
    ///
    /// Always yields at least one item while data remains, even when a
    /// single item exceeds the memory budget; otherwise the iterator would
    /// stall (return `None` with `is_finished() == false`) and silently
    /// drop the remaining data.
    pub fn next_batch(&mut self) -> Option<Vec<T>>
    where
        T: Clone,
    {
        if self.position >= self.data.len() {
            return None;
        }

        let item_size = std::mem::size_of::<T>();
        let mut batch = Vec::new();
        let mut current_memory = 0;

        while self.position < self.data.len() && batch.len() < self.batch_size {
            // Enforce the memory budget, but never produce an empty batch
            // while data remains (guarantees forward progress).
            if !batch.is_empty() && current_memory + item_size > self.max_memory_bytes {
                break;
            }
            batch.push(self.data[self.position].clone());
            self.position += 1;
            current_memory += item_size;
        }

        self.memory_usage = current_memory;

        // `batch` is non-empty here: we checked `position < len` above and
        // the loop always admits the first item.
        Some(batch)
    }

    /// Approximate bytes held by the most recently returned batch.
    pub fn get_memory_usage(&self) -> usize {
        self.memory_usage
    }

    /// Fraction of items yielded so far, in `[0.0, 1.0]` (1.0 for empty data).
    pub fn get_progress(&self) -> f64 {
        if self.data.is_empty() {
            1.0
        } else {
            self.position as f64 / self.data.len() as f64
        }
    }

    /// True once every item has been yielded.
    pub fn is_finished(&self) -> bool {
        self.position >= self.data.len()
    }
}
98
/// Coordinates batch embedding jobs: tracks them in memory, bounds worker
/// concurrency with a semaphore, and persists job records/checkpoints to disk.
pub struct BatchProcessingManager {
    /// All known jobs keyed by ID; shared with spawned processing tasks.
    active_jobs: Arc<RwLock<HashMap<Uuid, BatchJob>>>,
    /// Manager-wide processing settings.
    config: BatchProcessingConfig,
    /// Embedding cache consulted before invoking the model.
    cache_manager: Arc<CacheManager>,
    /// Limits concurrent chunk processing to `config.max_workers` permits.
    semaphore: Arc<Semaphore>,
    /// Directory where job records and checkpoints are written.
    persistence_dir: PathBuf,
}
112
/// Manager-wide configuration for batch processing.
#[derive(Debug, Clone)]
pub struct BatchProcessingConfig {
    /// Maximum concurrent chunk workers (sizes the manager's semaphore).
    pub max_workers: usize,
    /// Default entities per chunk.
    pub chunk_size: usize,
    /// Whether incremental (skip-already-embedded) mode may be used.
    /// NOTE(review): not read anywhere in this file — confirm consumers.
    pub enable_incremental: bool,
    /// A checkpoint is written every N chunks during processing.
    pub checkpoint_frequency: usize,
    /// Whether jobs may resume from checkpoints.
    /// NOTE(review): not read anywhere in this file — confirm consumers.
    pub enable_resume: bool,
    /// Per-worker memory budget in MiB (not consulted in this file).
    pub max_memory_per_worker_mb: usize,
    /// Whether to emit notifications (not consulted in this file).
    pub enable_notifications: bool,
    /// Retry/backoff policy for failures.
    pub retry_config: RetryConfig,
    /// Output file/format settings.
    pub output_config: OutputConfig,
}
135
136impl Default for BatchProcessingConfig {
137 fn default() -> Self {
138 Self {
139 max_workers: num_cpus::get(),
140 chunk_size: 1000,
141 enable_incremental: true,
142 checkpoint_frequency: 10,
143 enable_resume: true,
144 max_memory_per_worker_mb: 512,
145 enable_notifications: true,
146 retry_config: RetryConfig::default(),
147 output_config: OutputConfig::default(),
148 }
149 }
150}
151
/// Retry/backoff policy for failed entity processing.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum retry attempts per entity.
    pub max_retries: usize,
    /// Delay before the first retry, in milliseconds.
    pub initial_backoff_ms: u64,
    /// Upper bound on the backoff delay, in milliseconds.
    pub max_backoff_ms: u64,
    /// Factor applied to the delay after each failed attempt.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Three retries with exponential backoff: 1 s, doubling, capped at 30 s.
    fn default() -> Self {
        RetryConfig {
            backoff_multiplier: 2.0,
            max_backoff_ms: 30_000,
            initial_backoff_ms: 1_000,
            max_retries: 3,
        }
    }
}
175
/// Settings controlling how embedding output is written.
#[derive(Debug, Clone)]
pub struct OutputConfig {
    /// On-disk serialization format.
    pub format: OutputFormat,
    /// Compression level passed to the codec.
    pub compression_level: u32,
    /// Whether to embed metadata alongside the vectors.
    pub include_metadata: bool,
    /// Whether to batch writes into grouped files.
    pub batch_output: bool,
    /// Maximum entities written to a single output file.
    pub max_entities_per_file: usize,
}

impl Default for OutputConfig {
    /// Defaults: Parquet at compression level 6, metadata included,
    /// batched output, at most 100k entities per file.
    fn default() -> Self {
        OutputConfig {
            max_entities_per_file: 100_000,
            batch_output: true,
            include_metadata: true,
            compression_level: 6,
            format: OutputFormat::Parquet,
        }
    }
}

/// Supported on-disk serialization formats.
#[derive(Debug, Clone)]
pub enum OutputFormat {
    Parquet,
    JsonLines,
    Binary,
    HDF5,
}
215
/// A single batch embedding job and all of its state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchJob {
    /// Unique job identifier.
    pub job_id: Uuid,
    /// Human-readable job name (used in logs).
    pub name: String,
    /// Current lifecycle state.
    pub status: JobStatus,
    /// Where the job's entities come from.
    pub input: BatchInput,
    /// Where results are written.
    pub output: BatchOutput,
    /// Per-job processing settings.
    pub config: BatchJobConfig,
    /// Embedding model identifier — NOTE(review): not resolved in this
    /// file; the model instance is passed separately to `start_job`.
    pub model_id: Uuid,
    /// Submission time.
    pub created_at: DateTime<Utc>,
    /// Set when the job transitions to `Running`.
    pub started_at: Option<DateTime<Utc>>,
    /// Set when the job completes.
    pub completed_at: Option<DateTime<Utc>>,
    /// Live progress counters.
    pub progress: JobProgress,
    /// Last error message, if any (not populated in this file).
    pub error: Option<String>,
    /// Latest checkpoint, if one has been taken.
    pub checkpoint: Option<JobCheckpoint>,
}
246
/// Lifecycle states of a batch job.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum JobStatus {
    /// Submitted but not yet started (startable).
    Pending,
    /// Actively processing chunks.
    Running,
    /// Finished successfully.
    Completed,
    /// Terminated with an error (not set in this file).
    Failed,
    /// Cancelled by request; the processing loop observes this between chunks.
    Cancelled,
    /// Suspended; may be restarted via `start_job`.
    Paused,
}
257
/// Specification of where a job's entities come from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchInput {
    /// How to interpret `source`.
    pub input_type: InputType,
    /// Inline JSON array (`EntityList`) or file path (`EntityFile`);
    /// meaning depends on `input_type`.
    pub source: String,
    /// Optional key/value filters (not consulted in this file).
    pub filters: Option<HashMap<String, String>>,
    /// Incremental-processing settings, if any.
    pub incremental: Option<IncrementalConfig>,
}
270
/// Supported input source kinds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InputType {
    /// `source` is an inline JSON array of entity IDs.
    EntityList,
    /// `source` is a path to a file with one entity ID per line.
    EntityFile,
    /// SPARQL query — not yet implemented; loads no entities.
    SparqlQuery,
    /// Database query — not yet implemented; loads no entities.
    DatabaseQuery,
    /// Streaming source — not yet implemented; loads no entities.
    StreamSource,
}
285
/// Settings for incremental (delta) processing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncrementalConfig {
    /// Master switch; when false the full input is processed.
    pub enabled: bool,
    /// Presumably the timestamp of the last completed run — not read in
    /// this file; confirm against consumers.
    pub last_processed: Option<DateTime<Utc>>,
    /// Field name for change detection (not read in this file).
    pub timestamp_field: String,
    /// Whether deletions should be detected (not read in this file).
    pub check_deletions: bool,
    /// File listing already-embedded entity IDs (one per line); entities
    /// found there are skipped by `filter_incremental_entities`.
    pub existing_embeddings_path: Option<String>,
}
300
/// Where and how a job's results are written.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchOutput {
    /// Destination path; its parent directory is created during validation.
    pub path: String,
    /// Output format name (free-form string, e.g. "parquet").
    pub format: String,
    /// Optional compression codec name.
    pub compression: Option<String>,
    /// Optional partitioning scheme for the output.
    pub partitioning: Option<PartitioningStrategy>,
}
313
/// How output files may be partitioned.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PartitioningStrategy {
    /// Single unpartitioned output.
    None,
    /// One partition per entity type.
    ByEntityType,
    /// Partitioned by date.
    ByDate,
    /// Hash-partitioned into a fixed number of buckets.
    ByHash { num_partitions: usize },
    /// Partitioned on an arbitrary, caller-chosen field.
    Custom { field: String },
}
328
/// Per-job processing knobs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchJobConfig {
    /// Entities per chunk for this job.
    pub chunk_size: usize,
    /// Requested worker count — not read in this file; actual concurrency
    /// is governed by the manager's semaphore.
    pub num_workers: usize,
    /// Retry attempts per entity (not read in this file).
    pub max_retries: usize,
    /// Whether to consult/populate the embedding cache.
    pub use_cache: bool,
    /// Free-form extra parameters.
    pub custom_params: HashMap<String, String>,
}
343
344#[derive(Debug, Clone, Serialize, Deserialize)]
346pub struct JobProgress {
347 pub total_entities: usize,
349 pub processed_entities: usize,
351 pub failed_entities: usize,
353 pub current_chunk: usize,
355 pub total_chunks: usize,
357 pub processing_rate: f64,
359 pub eta_seconds: Option<u64>,
361 pub memory_usage_mb: f64,
363}
364
365impl Default for JobProgress {
366 fn default() -> Self {
367 Self {
368 total_entities: 0,
369 processed_entities: 0,
370 failed_entities: 0,
371 current_chunk: 0,
372 total_chunks: 0,
373 processing_rate: 0.0,
374 eta_seconds: None,
375 memory_usage_mb: 0.0,
376 }
377 }
378}
379
/// Resumable snapshot of a job's progress, persisted as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JobCheckpoint {
    /// When the checkpoint was taken.
    pub timestamp: DateTime<Utc>,
    /// Number of entities processed at checkpoint time.
    pub last_processed_index: usize,
    /// IDs of all entities attempted so far.
    pub processed_entities: HashSet<String>,
    /// Entity ID -> error message for failures so far.
    pub failed_entities: HashMap<String, String>,
    /// Path of the checkpoint file on disk.
    pub intermediate_results_path: String,
    /// Model-state fingerprint; currently always the literal "placeholder".
    pub model_state_hash: String,
}
396
/// Final outcome of a completed batch job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchProcessingResult {
    /// The job this result belongs to.
    pub job_id: Uuid,
    /// Timing, throughput, and cache statistics.
    pub stats: BatchProcessingStats,
    /// Description of the produced output files.
    pub output_info: OutputInfo,
    /// Optional embedding-quality metrics (not computed in this file).
    pub quality_metrics: Option<QualityMetrics>,
}
409
/// Aggregate timing and throughput statistics for a job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchProcessingStats {
    /// Wall-clock processing time in seconds.
    pub total_time_seconds: f64,
    /// Entities attempted (successes + failures).
    pub total_entities: usize,
    /// Entities embedded successfully.
    pub successful_embeddings: usize,
    /// Entities that failed.
    pub failed_embeddings: usize,
    /// Successful lookups served from the cache.
    pub cache_hits: usize,
    /// Embeddings that had to be computed fresh.
    pub cache_misses: usize,
    /// Mean per-entity latency in milliseconds.
    pub avg_time_per_entity_ms: f64,
    /// Peak memory in MiB — currently always 0, not measured.
    pub peak_memory_mb: f64,
    /// CPU utilization — currently always 0, not measured.
    pub cpu_utilization: f64,
}
432
/// Description of the files a job produced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputInfo {
    /// Paths of the output files.
    pub output_files: Vec<String>,
    /// Combined size in bytes — currently always 0, not measured.
    pub total_size_bytes: u64,
    /// Compression ratio — currently always reported as 1.0.
    pub compression_ratio: f64,
    /// Number of output partitions — currently always 1.
    pub num_partitions: usize,
}
445
/// Statistical health checks over produced embeddings.
/// NOTE(review): never computed in this file — results always carry `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityMetrics {
    /// Mean embedding norm.
    pub avg_embedding_norm: f64,
    /// Standard deviation of embedding norms.
    pub embedding_norm_std: f64,
    /// Mean cosine similarity (sampling scheme unspecified here).
    pub avg_cosine_similarity: f64,
    /// Dimensionality of the embeddings.
    pub embedding_dimension: usize,
    /// Count of all-zero embeddings.
    pub zero_embeddings: usize,
    /// Count of embeddings containing NaN.
    pub nan_embeddings: usize,
}
462
463impl BatchProcessingManager {
464 pub fn new(
466 config: BatchProcessingConfig,
467 cache_manager: Arc<CacheManager>,
468 persistence_dir: PathBuf,
469 ) -> Self {
470 Self {
471 active_jobs: Arc::new(RwLock::new(HashMap::new())),
472 semaphore: Arc::new(Semaphore::new(config.max_workers)),
473 config,
474 cache_manager,
475 persistence_dir,
476 }
477 }
478
479 pub async fn submit_job(&self, job: BatchJob) -> Result<Uuid> {
481 let job_id = job.job_id;
482
483 self.validate_job(&job).await?;
485
486 {
488 let mut jobs = self.active_jobs.write().await;
489 jobs.insert(job_id, job.clone());
490 }
491
492 self.persist_job(&job).await?;
494
495 info!("Submitted batch job: {} ({})", job.name, job_id);
496 Ok(job_id)
497 }
498
499 pub async fn start_job(
501 &self,
502 job_id: Uuid,
503 model: Arc<dyn EmbeddingModel + Send + Sync>,
504 ) -> Result<JoinHandle<Result<BatchProcessingResult>>> {
505 let job = {
506 let mut jobs = self.active_jobs.write().await;
507 let job = jobs
508 .get_mut(&job_id)
509 .ok_or_else(|| anyhow!("Job not found: {}", job_id))?;
510
511 if !matches!(job.status, JobStatus::Pending | JobStatus::Paused) {
512 return Err(anyhow!("Job {} is not in a startable state", job_id));
513 }
514
515 job.status = JobStatus::Running;
516 job.started_at = Some(Utc::now());
517 job.clone()
518 };
519
520 let manager = self.clone();
521 let handle = tokio::spawn(async move { manager.process_job(job, model).await });
522
523 Ok(handle)
524 }
525
    /// Runs a job end to end: loads entities, optionally filters for
    /// incremental mode, processes them chunk by chunk (checkpointing and
    /// checking for cancellation between chunks), then finalizes the
    /// result and marks the job `Completed`.
    async fn process_job(
        &self,
        job: BatchJob,
        model: Arc<dyn EmbeddingModel + Send + Sync>,
    ) -> Result<BatchProcessingResult> {
        let start_time = Instant::now();
        info!(
            "Starting batch job processing: {} ({})",
            job.name, job.job_id
        );

        let entities = self.load_entities(&job).await?;

        // Apply incremental filtering only when explicitly enabled.
        let entities_to_process = if job
            .input
            .incremental
            .as_ref()
            .map(|inc| inc.enabled)
            .unwrap_or(false)
        {
            self.filter_incremental_entities(&job, entities).await?
        } else {
            entities
        };

        // Record totals on the shared job record; scoped so the write lock
        // drops before chunk processing begins.
        {
            let mut jobs = self.active_jobs.write().await;
            if let Some(active_job) = jobs.get_mut(&job.job_id) {
                active_job.progress.total_entities = entities_to_process.len();
                // Ceiling division. NOTE(review): panics if chunk_size == 0;
                // there is no guard against a zero chunk_size here.
                active_job.progress.total_chunks =
                    (entities_to_process.len() + job.config.chunk_size - 1) / job.config.chunk_size;
            }
        }

        let chunks: Vec<_> = entities_to_process
            .chunks(job.config.chunk_size)
            .map(|chunk| chunk.to_vec())
            .collect();

        let mut successful_embeddings = 0;
        let mut failed_embeddings = 0;
        let mut cache_hits = 0;
        let mut cache_misses = 0;
        let mut processed_entities = HashSet::new();
        let mut failed_entities = HashMap::new();

        for (chunk_idx, chunk) in chunks.iter().enumerate() {
            // Cooperative cancellation check between chunks (read lock only).
            {
                let jobs = self.active_jobs.read().await;
                if let Some(active_job) = jobs.get(&job.job_id) {
                    if matches!(active_job.status, JobStatus::Cancelled) {
                        info!("Job {} was cancelled", job.job_id);
                        return Err(anyhow!("Job was cancelled"));
                    }
                }
            }

            let chunk_result = self
                .process_chunk(&job, chunk, chunk_idx, model.clone())
                .await?;

            successful_embeddings += chunk_result.successful;
            failed_embeddings += chunk_result.failed;
            cache_hits += chunk_result.cache_hits;
            cache_misses += chunk_result.cache_misses;

            // Every attempted entity counts as processed; failures are
            // additionally recorded with their error message.
            for entity in chunk {
                processed_entities.insert(entity.clone());
            }
            for (entity, error) in chunk_result.failures {
                failed_entities.insert(entity, error);
            }

            self.update_job_progress(
                &job.job_id,
                chunk_idx + 1,
                successful_embeddings + failed_embeddings,
            )
            .await?;

            // Checkpoint every N chunks (also fires at chunk_idx == 0).
            if chunk_idx % self.config.checkpoint_frequency == 0 {
                self.create_checkpoint(&job.job_id, &processed_entities, &failed_entities)
                    .await?;
            }

            info!(
                "Processed chunk {}/{} for job {}",
                chunk_idx + 1,
                chunks.len(),
                job.job_id
            );
        }

        let processing_time = start_time.elapsed().as_secs_f64();
        let result = self
            .finalize_job_processing(
                &job,
                processing_time,
                successful_embeddings,
                failed_embeddings,
                cache_hits,
                cache_misses,
            )
            .await?;

        // Mark the shared record completed.
        {
            let mut jobs = self.active_jobs.write().await;
            if let Some(active_job) = jobs.get_mut(&job.job_id) {
                active_job.status = JobStatus::Completed;
                active_job.completed_at = Some(Utc::now());
            }
        }

        info!(
            "Completed batch job: {} in {:.2}s",
            job.job_id, processing_time
        );
        Ok(result)
    }
658
659 async fn process_chunk(
661 &self,
662 job: &BatchJob,
663 entities: &[String],
664 chunk_idx: usize,
665 model: Arc<dyn EmbeddingModel + Send + Sync>,
666 ) -> Result<ChunkResult> {
667 let _permit = self.semaphore.acquire().await?;
668
669 let mut successful = 0;
670 let mut failed = 0;
671 let mut cache_hits = 0;
672 let mut cache_misses = 0;
673 let mut failures = HashMap::new();
674
675 for entity in entities {
676 match self
677 .process_single_entity(entity, model.clone(), job.config.use_cache)
678 .await
679 {
680 Ok(from_cache) => {
681 successful += 1;
682 if from_cache {
683 cache_hits += 1;
684 } else {
685 cache_misses += 1;
686 }
687 }
688 Err(e) => {
689 failed += 1;
690 failures.insert(entity.clone(), e.to_string());
691 warn!("Failed to process entity {}: {}", entity, e);
692 }
693 }
694 }
695
696 Ok(ChunkResult {
697 chunk_idx,
698 successful,
699 failed,
700 cache_hits,
701 cache_misses,
702 failures,
703 })
704 }
705
706 async fn process_single_entity(
708 &self,
709 entity: &str,
710 model: Arc<dyn EmbeddingModel + Send + Sync>,
711 use_cache: bool,
712 ) -> Result<bool> {
713 if use_cache {
714 if let Some(_embedding) = self.cache_manager.get_embedding(entity) {
716 return Ok(true);
717 }
718 }
719
720 let embedding = model.get_entity_embedding(entity)?;
722
723 if use_cache {
725 self.cache_manager
726 .put_embedding(entity.to_string(), embedding);
727 }
728
729 Ok(false)
730 }
731
732 async fn load_entities(&self, job: &BatchJob) -> Result<Vec<String>> {
734 match &job.input.input_type {
735 InputType::EntityList => {
736 let entities: Vec<String> = serde_json::from_str(&job.input.source)?;
738 Ok(entities)
739 }
740 InputType::EntityFile => {
741 let content = fs::read_to_string(&job.input.source).await?;
743 let entities: Vec<String> = content
744 .lines()
745 .map(|line| line.trim().to_string())
746 .filter(|line| !line.is_empty())
747 .collect();
748 Ok(entities)
749 }
750 InputType::SparqlQuery => {
751 warn!("SPARQL query input type not yet implemented");
754 Ok(Vec::new())
755 }
756 InputType::DatabaseQuery => {
757 warn!("Database query input type not yet implemented");
759 Ok(Vec::new())
760 }
761 InputType::StreamSource => {
762 warn!("Stream source input type not yet implemented");
764 Ok(Vec::new())
765 }
766 }
767 }
768
769 async fn filter_incremental_entities(
771 &self,
772 job: &BatchJob,
773 entities: Vec<String>,
774 ) -> Result<Vec<String>> {
775 if let Some(incremental) = &job.input.incremental {
776 if !incremental.enabled {
777 return Ok(entities);
778 }
779
780 let existing_entities =
782 if let Some(existing_path) = &incremental.existing_embeddings_path {
783 self.load_existing_entities(existing_path).await?
784 } else {
785 HashSet::new()
786 };
787
788 let filtered: Vec<String> = entities
790 .into_iter()
791 .filter(|entity| !existing_entities.contains(entity))
792 .collect();
793
794 info!(
795 "Incremental filtering: {} entities remaining after filtering",
796 filtered.len()
797 );
798 Ok(filtered)
799 } else {
800 Ok(entities)
801 }
802 }
803
804 async fn load_existing_entities(&self, path: &str) -> Result<HashSet<String>> {
806 if Path::new(path).exists() {
809 let content = fs::read_to_string(path).await?;
810 let entities: HashSet<String> = content
811 .lines()
812 .map(|line| line.trim().to_string())
813 .filter(|line| !line.is_empty())
814 .collect();
815 Ok(entities)
816 } else {
817 Ok(HashSet::new())
818 }
819 }
820
821 async fn update_job_progress(
823 &self,
824 job_id: &Uuid,
825 current_chunk: usize,
826 processed_entities: usize,
827 ) -> Result<()> {
828 let mut jobs = self.active_jobs.write().await;
829 if let Some(job) = jobs.get_mut(job_id) {
830 job.progress.current_chunk = current_chunk;
831 job.progress.processed_entities = processed_entities;
832
833 if let Some(started_at) = job.started_at {
835 let elapsed = Utc::now().signed_duration_since(started_at);
836 let elapsed_seconds = elapsed.num_seconds() as f64;
837 if elapsed_seconds > 0.0 {
838 job.progress.processing_rate = processed_entities as f64 / elapsed_seconds;
839
840 let remaining_entities = job.progress.total_entities - processed_entities;
842 if job.progress.processing_rate > 0.0 {
843 let eta = remaining_entities as f64 / job.progress.processing_rate;
844 job.progress.eta_seconds = Some(eta as u64);
845 }
846 }
847 }
848 }
849 Ok(())
850 }
851
    /// Serializes a checkpoint (processed/failed entity sets) to
    /// `<persistence_dir>/checkpoint_<job_id>.json` and attaches it to the
    /// in-memory job record after the file write succeeds.
    async fn create_checkpoint(
        &self,
        job_id: &Uuid,
        processed_entities: &HashSet<String>,
        failed_entities: &HashMap<String, String>,
    ) -> Result<()> {
        let checkpoint = JobCheckpoint {
            timestamp: Utc::now(),
            last_processed_index: processed_entities.len(),
            processed_entities: processed_entities.clone(),
            failed_entities: failed_entities.clone(),
            intermediate_results_path: format!(
                "{}/checkpoint_{}.json",
                self.persistence_dir.display(),
                job_id
            ),
            // NOTE(review): placeholder — presumably meant to fingerprint
            // the model so resume can detect a model change; confirm before
            // relying on resume correctness.
            model_state_hash: "placeholder".to_string(),
        };

        let checkpoint_path = self
            .persistence_dir
            .join(format!("checkpoint_{job_id}.json"));
        let checkpoint_json = serde_json::to_string_pretty(&checkpoint)?;
        fs::write(checkpoint_path, checkpoint_json).await?;

        // Only attach the checkpoint to the job after the write succeeded.
        let mut jobs = self.active_jobs.write().await;
        if let Some(job) = jobs.get_mut(job_id) {
            job.checkpoint = Some(checkpoint);
        }

        debug!("Created checkpoint for job {}", job_id);
        Ok(())
    }
888
889 async fn finalize_job_processing(
891 &self,
892 job: &BatchJob,
893 processing_time: f64,
894 successful_embeddings: usize,
895 failed_embeddings: usize,
896 cache_hits: usize,
897 cache_misses: usize,
898 ) -> Result<BatchProcessingResult> {
899 let total_entities = successful_embeddings + failed_embeddings;
900 let avg_time_per_entity_ms = if total_entities > 0 {
901 (processing_time * 1000.0) / total_entities as f64
902 } else {
903 0.0
904 };
905
906 let stats = BatchProcessingStats {
907 total_time_seconds: processing_time,
908 total_entities,
909 successful_embeddings,
910 failed_embeddings,
911 cache_hits,
912 cache_misses,
913 avg_time_per_entity_ms,
914 peak_memory_mb: 0.0, cpu_utilization: 0.0, };
917
918 let output_info = OutputInfo {
919 output_files: vec![job.output.path.clone()],
920 total_size_bytes: 0, compression_ratio: 1.0,
922 num_partitions: 1,
923 };
924
925 Ok(BatchProcessingResult {
926 job_id: job.job_id,
927 stats,
928 output_info,
929 quality_metrics: None, })
931 }
932
933 async fn validate_job(&self, job: &BatchJob) -> Result<()> {
935 if let InputType::EntityFile = &job.input.input_type {
937 if !Path::new(&job.input.source).exists() {
938 return Err(anyhow!("Input file does not exist: {}", job.input.source));
939 }
940 } if let Some(parent) = Path::new(&job.output.path).parent() {
944 if !parent.exists() {
945 fs::create_dir_all(parent).await?;
946 }
947 }
948
949 Ok(())
950 }
951
952 async fn persist_job(&self, job: &BatchJob) -> Result<()> {
954 let job_path = self
955 .persistence_dir
956 .join(format!("job_{}.json", job.job_id));
957 let job_json = serde_json::to_string_pretty(job)?;
958 fs::write(job_path, job_json).await?;
959 Ok(())
960 }
961
962 pub async fn get_job_status(&self, job_id: &Uuid) -> Option<JobStatus> {
964 let jobs = self.active_jobs.read().await;
965 jobs.get(job_id).map(|job| job.status.clone())
966 }
967
968 pub async fn get_job_progress(&self, job_id: &Uuid) -> Option<JobProgress> {
970 let jobs = self.active_jobs.read().await;
971 jobs.get(job_id).map(|job| job.progress.clone())
972 }
973
974 pub async fn cancel_job(&self, job_id: &Uuid) -> Result<()> {
976 let mut jobs = self.active_jobs.write().await;
977 if let Some(job) = jobs.get_mut(job_id) {
978 job.status = JobStatus::Cancelled;
979 info!("Cancelled job: {}", job_id);
980 Ok(())
981 } else {
982 Err(anyhow!("Job not found: {}", job_id))
983 }
984 }
985
986 pub async fn list_jobs(&self) -> Vec<BatchJob> {
988 let jobs = self.active_jobs.read().await;
989 jobs.values().cloned().collect()
990 }
991}
992
993impl Clone for BatchProcessingManager {
994 fn clone(&self) -> Self {
995 Self {
996 active_jobs: Arc::clone(&self.active_jobs),
997 config: self.config.clone(),
998 cache_manager: Arc::clone(&self.cache_manager),
999 semaphore: Arc::clone(&self.semaphore),
1000 persistence_dir: self.persistence_dir.clone(),
1001 }
1002 }
1003}
1004
/// Per-chunk processing outcome returned by `process_chunk`.
#[derive(Debug)]
#[allow(dead_code)]
struct ChunkResult {
    // Index of the chunk within the job (kept for diagnostics).
    chunk_idx: usize,
    // Entities embedded successfully (cache hits + fresh computes).
    successful: usize,
    // Entities that failed; details in `failures`.
    failed: usize,
    // Successful lookups served from the embedding cache.
    cache_hits: usize,
    // Successful embeddings that had to be computed fresh.
    cache_misses: usize,
    // Entity ID -> error message for each failure.
    failures: HashMap<String, String>,
}
1016
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// A fully-populated `BatchJob` constructs and keeps the submitted fields.
    #[test]
    fn test_batch_job_creation() {
        let job = BatchJob {
            job_id: Uuid::new_v4(),
            name: "test_job".to_string(),
            status: JobStatus::Pending,
            input: BatchInput {
                input_type: InputType::EntityList,
                source: r#"["entity1", "entity2", "entity3"]"#.to_string(),
                filters: None,
                incremental: None,
            },
            output: BatchOutput {
                path: "/tmp/output".to_string(),
                format: "parquet".to_string(),
                compression: Some("gzip".to_string()),
                partitioning: Some(PartitioningStrategy::None),
            },
            config: BatchJobConfig {
                chunk_size: 100,
                num_workers: 4,
                max_retries: 3,
                use_cache: true,
                custom_params: HashMap::new(),
            },
            model_id: Uuid::new_v4(),
            created_at: Utc::now(),
            started_at: None,
            completed_at: None,
            progress: JobProgress::default(),
            error: None,
            checkpoint: None,
        };

        assert_eq!(job.status, JobStatus::Pending);
        assert_eq!(job.name, "test_job");
    }

    /// The manager picks up the default config (CPU-count workers, 1000 chunk).
    #[tokio::test]
    async fn test_batch_processing_manager_creation() {
        let config = BatchProcessingConfig::default();
        let cache_config = crate::CacheConfig::default();
        let cache_manager = Arc::new(CacheManager::new(cache_config));
        let temp_dir = tempdir().unwrap();

        let manager =
            BatchProcessingManager::new(config, cache_manager, temp_dir.path().to_path_buf());

        assert_eq!(manager.config.max_workers, num_cpus::get());
        assert_eq!(manager.config.chunk_size, 1000);
    }

    /// `IncrementalConfig` round-trips its constructor fields.
    #[test]
    fn test_incremental_config() {
        let incremental = IncrementalConfig {
            enabled: true,
            last_processed: Some(Utc::now()),
            timestamp_field: "updated_at".to_string(),
            check_deletions: true,
            existing_embeddings_path: Some("/path/to/existing".to_string()),
        };

        assert!(incremental.enabled);
        assert!(incremental.last_processed.is_some());
        assert_eq!(incremental.timestamp_field, "updated_at");
    }

    /// Happy path: 10 items, batch size 3 -> batches of 3,3,3,1 then None.
    #[test]
    fn test_memory_optimized_batch_iterator() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        // 1 MiB budget is far above 10 * size_of::<i32>(), so only the
        // batch-size cap limits each batch here.
        let mut iterator = MemoryOptimizedBatchIterator::new(data.clone(), 3, 1);
        let batch1 = iterator.next_batch().unwrap();
        assert_eq!(batch1.len(), 3);
        assert_eq!(batch1, vec![1, 2, 3]);
        assert_eq!(iterator.get_progress(), 0.3);
        assert!(!iterator.is_finished());

        let batch2 = iterator.next_batch().unwrap();
        assert_eq!(batch2.len(), 3);
        assert_eq!(batch2, vec![4, 5, 6]);
        assert_eq!(iterator.get_progress(), 0.6);

        let batch3 = iterator.next_batch().unwrap();
        assert_eq!(batch3.len(), 3);
        assert_eq!(batch3, vec![7, 8, 9]);
        assert_eq!(iterator.get_progress(), 0.9);

        let batch4 = iterator.next_batch().unwrap();
        assert_eq!(batch4.len(), 1);
        assert_eq!(batch4, vec![10]);
        assert_eq!(iterator.get_progress(), 1.0);
        assert!(iterator.is_finished());

        let batch5 = iterator.next_batch();
        assert!(batch5.is_none());
    }

    /// Empty input reports finished immediately with progress 1.0.
    #[test]
    fn test_memory_optimized_batch_iterator_empty() {
        let data: Vec<i32> = vec![];
        let mut iterator = MemoryOptimizedBatchIterator::new(data, 3, 1);

        assert_eq!(iterator.get_progress(), 1.0);
        assert!(iterator.is_finished());
        assert!(iterator.next_batch().is_none());
    }

    /// A single item comes back as a one-element batch.
    #[test]
    fn test_memory_optimized_batch_iterator_single_item() {
        let data = vec![42];
        let mut iterator = MemoryOptimizedBatchIterator::new(data, 5, 1);

        let batch = iterator.next_batch().unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0], 42);
        assert_eq!(iterator.get_progress(), 1.0);
        assert!(iterator.is_finished());
    }

    /// `get_memory_usage` reflects items-in-batch * size_of::<T>().
    #[test]
    fn test_memory_optimized_batch_iterator_memory_tracking() {
        let data = vec![1, 2, 3, 4, 5];
        let mut iterator = MemoryOptimizedBatchIterator::new(data, 3, 1);

        let _batch = iterator.next_batch().unwrap();
        let memory_usage = iterator.get_memory_usage();
        assert!(memory_usage > 0);

        let expected_memory = 3 * std::mem::size_of::<i32>();
        assert_eq!(memory_usage, expected_memory);
    }
}
1161}