infiniloom_engine/embedding/
batch.rs

1//! Batch embedding API for processing multiple repositories
2//!
3//! This module provides APIs for embedding multiple repositories in a single
4//! operation, with parallel processing and unified output.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use infiniloom_engine::embedding::{BatchEmbedder, BatchRepoConfig, EmbedSettings};
10//!
11//! let embedder = BatchEmbedder::new(EmbedSettings::default());
12//!
13//! let repos = vec![
14//!     BatchRepoConfig::new("/path/to/repo1")
15//!         .with_namespace("github.com/org")
16//!         .with_name("auth-service"),
17//!     BatchRepoConfig::new("/path/to/repo2")
18//!         .with_namespace("github.com/org")
19//!         .with_name("user-service"),
20//! ];
21//!
22//! let result = embedder.embed_batch(&repos)?;
23//! println!("Total chunks: {}", result.total_chunks);
24//! ```
25
26use std::path::{Path, PathBuf};
27use std::sync::atomic::{AtomicUsize, Ordering};
28use std::time::{Duration, Instant};
29
30use rayon::prelude::*;
31use serde::{Deserialize, Serialize};
32
33use super::chunker::EmbedChunker;
34use super::error::EmbedError;
35use super::progress::{ProgressReporter, QuietProgress};
36use super::types::{EmbedChunk, EmbedSettings, RepoIdentifier};
37use super::ResourceLimits;
38
39/// Configuration for a single repository in a batch operation
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct BatchRepoConfig {
42    /// Path to the repository (local path or URL for remote)
43    pub path: PathBuf,
44
45    /// Optional namespace override (e.g., "github.com/org")
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub namespace: Option<String>,
48
49    /// Optional repository name override (defaults to directory name)
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub name: Option<String>,
52
53    /// Optional version tag
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub version: Option<String>,
56
57    /// Optional branch name
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub branch: Option<String>,
60
61    /// Override include patterns for this repo only
62    #[serde(default, skip_serializing_if = "Vec::is_empty")]
63    pub include_patterns: Vec<String>,
64
65    /// Override exclude patterns for this repo only
66    #[serde(default, skip_serializing_if = "Vec::is_empty")]
67    pub exclude_patterns: Vec<String>,
68}
69
70impl BatchRepoConfig {
71    /// Create a new batch repo config from a path
72    pub fn new(path: impl Into<PathBuf>) -> Self {
73        Self {
74            path: path.into(),
75            namespace: None,
76            name: None,
77            version: None,
78            branch: None,
79            include_patterns: Vec::new(),
80            exclude_patterns: Vec::new(),
81        }
82    }
83
84    /// Set the namespace
85    pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
86        self.namespace = Some(namespace.into());
87        self
88    }
89
90    /// Set the repository name
91    pub fn with_name(mut self, name: impl Into<String>) -> Self {
92        self.name = Some(name.into());
93        self
94    }
95
96    /// Set the version tag
97    pub fn with_version(mut self, version: impl Into<String>) -> Self {
98        self.version = Some(version.into());
99        self
100    }
101
102    /// Set the branch name
103    pub fn with_branch(mut self, branch: impl Into<String>) -> Self {
104        self.branch = Some(branch.into());
105        self
106    }
107
108    /// Set include patterns (overrides global settings for this repo)
109    pub fn with_include_patterns(mut self, patterns: Vec<String>) -> Self {
110        self.include_patterns = patterns;
111        self
112    }
113
114    /// Set exclude patterns (overrides global settings for this repo)
115    pub fn with_exclude_patterns(mut self, patterns: Vec<String>) -> Self {
116        self.exclude_patterns = patterns;
117        self
118    }
119
120    /// Build the RepoIdentifier from this config
121    pub fn to_repo_id(&self) -> RepoIdentifier {
122        let name = self
123            .name
124            .clone()
125            .or_else(|| {
126                self.path
127                    .file_name()
128                    .and_then(|n| n.to_str())
129                    .map(String::from)
130            })
131            .unwrap_or_else(|| "unknown".to_string());
132
133        RepoIdentifier {
134            namespace: self.namespace.clone().unwrap_or_default(),
135            name,
136            version: self.version.clone(),
137            branch: self.branch.clone(),
138            commit: None, // Would need git integration to get this
139        }
140    }
141}
142
143/// Result for a single repository in a batch operation
144#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct BatchRepoResult {
146    /// Repository identifier
147    pub repo_id: RepoIdentifier,
148
149    /// Path that was processed
150    pub path: PathBuf,
151
152    /// Generated chunks
153    pub chunks: Vec<EmbedChunk>,
154
155    /// Processing time for this repository
156    pub elapsed: Duration,
157
158    /// Error if processing failed (chunks will be empty)
159    #[serde(skip_serializing_if = "Option::is_none")]
160    pub error: Option<String>,
161}
162
163/// Summary of a batch embedding operation
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct BatchResult {
166    /// Results per repository
167    pub repos: Vec<BatchRepoResult>,
168
169    /// Total chunks across all repositories
170    pub total_chunks: usize,
171
172    /// Total tokens across all chunks
173    pub total_tokens: u64,
174
175    /// Number of successfully processed repositories
176    pub successful_repos: usize,
177
178    /// Number of failed repositories
179    pub failed_repos: usize,
180
181    /// Total elapsed time
182    pub elapsed: Duration,
183}
184
185impl BatchResult {
186    /// Get all chunks from all successful repositories
187    pub fn all_chunks(&self) -> impl Iterator<Item = &EmbedChunk> {
188        self.repos.iter().flat_map(|r| r.chunks.iter())
189    }
190
191    /// Get all chunks as owned vector
192    pub fn into_chunks(self) -> Vec<EmbedChunk> {
193        self.repos.into_iter().flat_map(|r| r.chunks).collect()
194    }
195
196    /// Check if there were any failures
197    pub fn has_failures(&self) -> bool {
198        self.failed_repos > 0
199    }
200
201    /// Get failed repository paths and errors
202    pub fn failures(&self) -> impl Iterator<Item = (&Path, &str)> {
203        self.repos
204            .iter()
205            .filter(|r| r.error.is_some())
206            .map(|r| (r.path.as_path(), r.error.as_deref().unwrap_or("unknown")))
207    }
208}
209
210/// Batch embedder for processing multiple repositories
211pub struct BatchEmbedder {
212    settings: EmbedSettings,
213    limits: ResourceLimits,
214    /// Maximum number of parallel repository processing
215    max_parallel: usize,
216}
217
218impl BatchEmbedder {
219    /// Create a new batch embedder with default limits
220    pub fn new(settings: EmbedSettings) -> Self {
221        Self {
222            settings,
223            limits: ResourceLimits::default(),
224            max_parallel: std::thread::available_parallelism().map_or(4, |n| n.get()),
225        }
226    }
227
228    /// Create with custom resource limits
229    pub fn with_limits(settings: EmbedSettings, limits: ResourceLimits) -> Self {
230        Self {
231            settings,
232            limits,
233            max_parallel: std::thread::available_parallelism().map_or(4, |n| n.get()),
234        }
235    }
236
237    /// Set maximum parallel repository processing
238    pub fn with_max_parallel(mut self, max: usize) -> Self {
239        self.max_parallel = max.max(1);
240        self
241    }
242
243    /// Process a batch of repositories
244    ///
245    /// Repositories are processed in parallel up to `max_parallel` at a time.
246    /// Each repository gets its own chunker instance for thread safety.
247    pub fn embed_batch(&self, repos: &[BatchRepoConfig]) -> Result<BatchResult, EmbedError> {
248        self.embed_batch_with_progress(repos, &QuietProgress)
249    }
250
251    /// Process a batch of repositories with progress reporting
252    pub fn embed_batch_with_progress(
253        &self,
254        repos: &[BatchRepoConfig],
255        progress: &dyn ProgressReporter,
256    ) -> Result<BatchResult, EmbedError> {
257        let start = Instant::now();
258
259        if repos.is_empty() {
260            return Ok(BatchResult {
261                repos: Vec::new(),
262                total_chunks: 0,
263                total_tokens: 0,
264                successful_repos: 0,
265                failed_repos: 0,
266                elapsed: start.elapsed(),
267            });
268        }
269
270        progress.set_phase(&format!("Processing {} repositories...", repos.len()));
271        progress.set_total(repos.len());
272
273        let processed = AtomicUsize::new(0);
274
275        // Configure rayon for controlled parallelism
276        let pool = rayon::ThreadPoolBuilder::new()
277            .num_threads(self.max_parallel)
278            .build()
279            .map_err(|e| EmbedError::SerializationError {
280                reason: format!("Failed to create thread pool: {}", e),
281            })?;
282
283        let results: Vec<BatchRepoResult> = pool.install(|| {
284            repos
285                .par_iter()
286                .map(|config| {
287                    let result = self.process_single_repo(config);
288
289                    let done = processed.fetch_add(1, Ordering::Relaxed) + 1;
290                    progress.set_progress(done);
291
292                    result
293                })
294                .collect()
295        });
296
297        // Calculate totals
298        let total_chunks: usize = results.iter().map(|r| r.chunks.len()).sum();
299        let total_tokens: u64 = results
300            .iter()
301            .flat_map(|r| r.chunks.iter())
302            .map(|c| c.tokens as u64)
303            .sum();
304        let successful_repos = results.iter().filter(|r| r.error.is_none()).count();
305        let failed_repos = results.iter().filter(|r| r.error.is_some()).count();
306
307        progress.set_phase("Batch complete");
308
309        Ok(BatchResult {
310            repos: results,
311            total_chunks,
312            total_tokens,
313            successful_repos,
314            failed_repos,
315            elapsed: start.elapsed(),
316        })
317    }
318
319    /// Process a single repository
320    fn process_single_repo(&self, config: &BatchRepoConfig) -> BatchRepoResult {
321        let start = Instant::now();
322        let repo_id = config.to_repo_id();
323
324        // Build settings with per-repo overrides
325        let mut settings = self.settings.clone();
326        if !config.include_patterns.is_empty() {
327            settings.include_patterns = config.include_patterns.clone();
328        }
329        if !config.exclude_patterns.is_empty() {
330            settings.exclude_patterns = config.exclude_patterns.clone();
331        }
332
333        // Create chunker for this repo
334        let chunker = EmbedChunker::new(settings, self.limits.clone()).with_repo_id(repo_id.clone());
335
336        // Process the repository
337        let quiet = QuietProgress;
338        match chunker.chunk_repository(&config.path, &quiet) {
339            Ok(chunks) => BatchRepoResult {
340                repo_id,
341                path: config.path.clone(),
342                chunks,
343                elapsed: start.elapsed(),
344                error: None,
345            },
346            Err(e) => BatchRepoResult {
347                repo_id,
348                path: config.path.clone(),
349                chunks: Vec::new(),
350                elapsed: start.elapsed(),
351                error: Some(e.to_string()),
352            },
353        }
354    }
355}
356
357#[cfg(test)]
358mod tests {
359    use super::*;
360    use tempfile::TempDir;
361
362    fn create_test_file(dir: &Path, name: &str, content: &str) {
363        let path = dir.join(name);
364        if let Some(parent) = path.parent() {
365            std::fs::create_dir_all(parent).unwrap();
366        }
367        std::fs::write(path, content).unwrap();
368    }
369
370    #[test]
371    fn test_batch_repo_config_builder() {
372        let config = BatchRepoConfig::new("/path/to/repo")
373            .with_namespace("github.com/org")
374            .with_name("my-repo")
375            .with_version("v1.0.0")
376            .with_branch("main");
377
378        assert_eq!(config.path, PathBuf::from("/path/to/repo"));
379        assert_eq!(config.namespace.as_deref(), Some("github.com/org"));
380        assert_eq!(config.name.as_deref(), Some("my-repo"));
381        assert_eq!(config.version.as_deref(), Some("v1.0.0"));
382        assert_eq!(config.branch.as_deref(), Some("main"));
383    }
384
385    #[test]
386    fn test_batch_repo_config_to_repo_id() {
387        let config = BatchRepoConfig::new("/path/to/my-repo")
388            .with_namespace("github.com/org")
389            .with_name("custom-name");
390
391        let repo_id = config.to_repo_id();
392        assert_eq!(repo_id.namespace, "github.com/org");
393        assert_eq!(repo_id.name, "custom-name");
394    }
395
396    #[test]
397    fn test_batch_repo_config_infer_name() {
398        let config = BatchRepoConfig::new("/path/to/my-repo");
399
400        let repo_id = config.to_repo_id();
401        assert_eq!(repo_id.name, "my-repo");
402    }
403
404    #[test]
405    fn test_batch_embedder_empty_batch() {
406        let embedder = BatchEmbedder::new(EmbedSettings::default());
407        let result = embedder.embed_batch(&[]).unwrap();
408
409        assert_eq!(result.total_chunks, 0);
410        assert_eq!(result.successful_repos, 0);
411        assert_eq!(result.failed_repos, 0);
412        assert!(!result.has_failures());
413    }
414
415    #[test]
416    fn test_batch_embedder_single_repo() {
417        let temp_dir = TempDir::new().unwrap();
418        create_test_file(
419            temp_dir.path(),
420            "lib.rs",
421            "/// A test function\npub fn hello() { println!(\"hello\"); }\n",
422        );
423
424        let repos = vec![
425            BatchRepoConfig::new(temp_dir.path()).with_name("test-repo"),
426        ];
427
428        let embedder = BatchEmbedder::new(EmbedSettings::default());
429        let result = embedder.embed_batch(&repos).unwrap();
430
431        assert_eq!(result.successful_repos, 1);
432        assert_eq!(result.failed_repos, 0);
433        assert!(!result.has_failures());
434        assert!(result.total_chunks > 0);
435    }
436
437    #[test]
438    fn test_batch_embedder_multiple_repos() {
439        let temp_dir1 = TempDir::new().unwrap();
440        let temp_dir2 = TempDir::new().unwrap();
441
442        create_test_file(
443            temp_dir1.path(),
444            "a.rs",
445            "pub fn foo() { println!(\"foo\"); }\n",
446        );
447        create_test_file(
448            temp_dir2.path(),
449            "b.rs",
450            "pub fn bar() { println!(\"bar\"); }\n",
451        );
452
453        let repos = vec![
454            BatchRepoConfig::new(temp_dir1.path())
455                .with_namespace("org")
456                .with_name("repo1"),
457            BatchRepoConfig::new(temp_dir2.path())
458                .with_namespace("org")
459                .with_name("repo2"),
460        ];
461
462        let embedder = BatchEmbedder::new(EmbedSettings::default()).with_max_parallel(2);
463        let result = embedder.embed_batch(&repos).unwrap();
464
465        assert_eq!(result.successful_repos, 2);
466        assert_eq!(result.failed_repos, 0);
467
468        // Verify repo IDs are different
469        let repo_ids: Vec<_> = result.repos.iter().map(|r| &r.repo_id).collect();
470        assert_ne!(repo_ids[0].name, repo_ids[1].name);
471    }
472
473    #[test]
474    fn test_batch_embedder_handles_failure() {
475        let temp_dir = TempDir::new().unwrap();
476        create_test_file(temp_dir.path(), "lib.rs", "pub fn ok() {}\n");
477
478        let repos = vec![
479            BatchRepoConfig::new(temp_dir.path()).with_name("good-repo"),
480            BatchRepoConfig::new("/nonexistent/path").with_name("bad-repo"),
481        ];
482
483        let embedder = BatchEmbedder::new(EmbedSettings::default());
484        let result = embedder.embed_batch(&repos).unwrap();
485
486        assert_eq!(result.successful_repos, 1);
487        assert_eq!(result.failed_repos, 1);
488        assert!(result.has_failures());
489
490        let failures: Vec<_> = result.failures().collect();
491        assert_eq!(failures.len(), 1);
492        assert!(failures[0].0.to_str().unwrap().contains("nonexistent"));
493    }
494
495    #[test]
496    fn test_batch_result_all_chunks() {
497        let temp_dir1 = TempDir::new().unwrap();
498        let temp_dir2 = TempDir::new().unwrap();
499
500        create_test_file(
501            temp_dir1.path(),
502            "a.rs",
503            "pub fn func1() {}\npub fn func2() {}\n",
504        );
505        create_test_file(temp_dir2.path(), "b.rs", "pub fn func3() {}\n");
506
507        let repos = vec![
508            BatchRepoConfig::new(temp_dir1.path()).with_name("repo1"),
509            BatchRepoConfig::new(temp_dir2.path()).with_name("repo2"),
510        ];
511
512        let embedder = BatchEmbedder::new(EmbedSettings::default());
513        let result = embedder.embed_batch(&repos).unwrap();
514
515        // Verify all_chunks returns chunks from all repos
516        let all_chunks: Vec<_> = result.all_chunks().collect();
517        assert_eq!(all_chunks.len(), result.total_chunks);
518    }
519
520    #[test]
521    fn test_batch_result_into_chunks() {
522        let temp_dir = TempDir::new().unwrap();
523        create_test_file(temp_dir.path(), "lib.rs", "pub fn test() {}\n");
524
525        let repos = vec![BatchRepoConfig::new(temp_dir.path())];
526
527        let embedder = BatchEmbedder::new(EmbedSettings::default());
528        let result = embedder.embed_batch(&repos).unwrap();
529
530        let total = result.total_chunks;
531        let chunks = result.into_chunks();
532        assert_eq!(chunks.len(), total);
533    }
534
535    #[test]
536    fn test_per_repo_pattern_override() {
537        let temp_dir = TempDir::new().unwrap();
538        create_test_file(temp_dir.path(), "src/lib.rs", "pub fn included() {}\n");
539        create_test_file(temp_dir.path(), "tests/test.rs", "pub fn excluded() {}\n");
540
541        // Test with include pattern override
542        let repos = vec![
543            BatchRepoConfig::new(temp_dir.path())
544                .with_name("filtered-repo")
545                .with_include_patterns(vec!["src/**/*.rs".to_string()]),
546        ];
547
548        let embedder = BatchEmbedder::new(EmbedSettings::default());
549        let result = embedder.embed_batch(&repos).unwrap();
550
551        // Should only have chunks from src/
552        for chunk in result.all_chunks() {
553            assert!(
554                chunk.source.file.starts_with("src/"),
555                "Expected src/ prefix, got: {}",
556                chunk.source.file
557            );
558        }
559    }
560}