infiniloom_engine/embedding/
batch.rs

1//! Batch embedding API for processing multiple repositories
2//!
3//! This module provides APIs for embedding multiple repositories in a single
4//! operation, with parallel processing and unified output.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use infiniloom_engine::embedding::{BatchEmbedder, BatchRepoConfig, EmbedSettings};
10//!
11//! let embedder = BatchEmbedder::new(EmbedSettings::default());
12//!
13//! let repos = vec![
14//!     BatchRepoConfig::new("/path/to/repo1")
15//!         .with_namespace("github.com/org")
16//!         .with_name("auth-service"),
17//!     BatchRepoConfig::new("/path/to/repo2")
18//!         .with_namespace("github.com/org")
19//!         .with_name("user-service"),
20//! ];
21//!
22//! let result = embedder.embed_batch(&repos)?;
23//! println!("Total chunks: {}", result.total_chunks);
24//! ```
25
26use std::path::{Path, PathBuf};
27use std::sync::atomic::{AtomicUsize, Ordering};
28use std::time::{Duration, Instant};
29
30use rayon::prelude::*;
31use serde::{Deserialize, Serialize};
32
33use super::chunker::EmbedChunker;
34use super::error::EmbedError;
35use super::progress::{ProgressReporter, QuietProgress};
36use super::types::{EmbedChunk, EmbedSettings, RepoIdentifier};
37use super::ResourceLimits;
38
39/// Configuration for a single repository in a batch operation
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct BatchRepoConfig {
42    /// Path to the repository (local path or URL for remote)
43    pub path: PathBuf,
44
45    /// Optional namespace override (e.g., "github.com/org")
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub namespace: Option<String>,
48
49    /// Optional repository name override (defaults to directory name)
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub name: Option<String>,
52
53    /// Optional version tag
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub version: Option<String>,
56
57    /// Optional branch name
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub branch: Option<String>,
60
61    /// Override include patterns for this repo only
62    #[serde(default, skip_serializing_if = "Vec::is_empty")]
63    pub include_patterns: Vec<String>,
64
65    /// Override exclude patterns for this repo only
66    #[serde(default, skip_serializing_if = "Vec::is_empty")]
67    pub exclude_patterns: Vec<String>,
68}
69
70impl BatchRepoConfig {
71    /// Create a new batch repo config from a path
72    pub fn new(path: impl Into<PathBuf>) -> Self {
73        Self {
74            path: path.into(),
75            namespace: None,
76            name: None,
77            version: None,
78            branch: None,
79            include_patterns: Vec::new(),
80            exclude_patterns: Vec::new(),
81        }
82    }
83
84    /// Set the namespace
85    pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
86        self.namespace = Some(namespace.into());
87        self
88    }
89
90    /// Set the repository name
91    pub fn with_name(mut self, name: impl Into<String>) -> Self {
92        self.name = Some(name.into());
93        self
94    }
95
96    /// Set the version tag
97    pub fn with_version(mut self, version: impl Into<String>) -> Self {
98        self.version = Some(version.into());
99        self
100    }
101
102    /// Set the branch name
103    pub fn with_branch(mut self, branch: impl Into<String>) -> Self {
104        self.branch = Some(branch.into());
105        self
106    }
107
108    /// Set include patterns (overrides global settings for this repo)
109    pub fn with_include_patterns(mut self, patterns: Vec<String>) -> Self {
110        self.include_patterns = patterns;
111        self
112    }
113
114    /// Set exclude patterns (overrides global settings for this repo)
115    pub fn with_exclude_patterns(mut self, patterns: Vec<String>) -> Self {
116        self.exclude_patterns = patterns;
117        self
118    }
119
120    /// Build the RepoIdentifier from this config
121    pub fn to_repo_id(&self) -> RepoIdentifier {
122        let name = self
123            .name
124            .clone()
125            .or_else(|| {
126                self.path
127                    .file_name()
128                    .and_then(|n| n.to_str())
129                    .map(String::from)
130            })
131            .unwrap_or_else(|| "unknown".to_owned());
132
133        RepoIdentifier {
134            namespace: self.namespace.clone().unwrap_or_default(),
135            name,
136            version: self.version.clone(),
137            branch: self.branch.clone(),
138            commit: None, // Would need git integration to get this
139        }
140    }
141}
142
143/// Result for a single repository in a batch operation
144#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct BatchRepoResult {
146    /// Repository identifier
147    pub repo_id: RepoIdentifier,
148
149    /// Path that was processed
150    pub path: PathBuf,
151
152    /// Generated chunks
153    pub chunks: Vec<EmbedChunk>,
154
155    /// Processing time for this repository
156    pub elapsed: Duration,
157
158    /// Error if processing failed (chunks will be empty)
159    #[serde(skip_serializing_if = "Option::is_none")]
160    pub error: Option<String>,
161}
162
163/// Summary of a batch embedding operation
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct BatchResult {
166    /// Results per repository
167    pub repos: Vec<BatchRepoResult>,
168
169    /// Total chunks across all repositories
170    pub total_chunks: usize,
171
172    /// Total tokens across all chunks
173    pub total_tokens: u64,
174
175    /// Number of successfully processed repositories
176    pub successful_repos: usize,
177
178    /// Number of failed repositories
179    pub failed_repos: usize,
180
181    /// Total elapsed time
182    pub elapsed: Duration,
183}
184
185impl BatchResult {
186    /// Get all chunks from all successful repositories
187    pub fn all_chunks(&self) -> impl Iterator<Item = &EmbedChunk> {
188        self.repos.iter().flat_map(|r| r.chunks.iter())
189    }
190
191    /// Get all chunks as owned vector
192    pub fn into_chunks(self) -> Vec<EmbedChunk> {
193        self.repos.into_iter().flat_map(|r| r.chunks).collect()
194    }
195
196    /// Check if there were any failures
197    pub fn has_failures(&self) -> bool {
198        self.failed_repos > 0
199    }
200
201    /// Get failed repository paths and errors
202    pub fn failures(&self) -> impl Iterator<Item = (&Path, &str)> {
203        self.repos
204            .iter()
205            .filter(|r| r.error.is_some())
206            .map(|r| (r.path.as_path(), r.error.as_deref().unwrap_or("unknown")))
207    }
208}
209
210/// Batch embedder for processing multiple repositories
211pub struct BatchEmbedder {
212    settings: EmbedSettings,
213    limits: ResourceLimits,
214    /// Maximum number of parallel repository processing
215    max_parallel: usize,
216}
217
218impl BatchEmbedder {
219    /// Create a new batch embedder with default limits
220    pub fn new(settings: EmbedSettings) -> Self {
221        Self {
222            settings,
223            limits: ResourceLimits::default(),
224            max_parallel: std::thread::available_parallelism().map_or(4, |n| n.get()),
225        }
226    }
227
228    /// Create with custom resource limits
229    pub fn with_limits(settings: EmbedSettings, limits: ResourceLimits) -> Self {
230        Self {
231            settings,
232            limits,
233            max_parallel: std::thread::available_parallelism().map_or(4, |n| n.get()),
234        }
235    }
236
237    /// Set maximum parallel repository processing
238    pub fn with_max_parallel(mut self, max: usize) -> Self {
239        self.max_parallel = max.max(1);
240        self
241    }
242
243    /// Process a batch of repositories
244    ///
245    /// Repositories are processed in parallel up to `max_parallel` at a time.
246    /// Each repository gets its own chunker instance for thread safety.
247    pub fn embed_batch(&self, repos: &[BatchRepoConfig]) -> Result<BatchResult, EmbedError> {
248        self.embed_batch_with_progress(repos, &QuietProgress)
249    }
250
251    /// Process a batch of repositories with progress reporting
252    pub fn embed_batch_with_progress(
253        &self,
254        repos: &[BatchRepoConfig],
255        progress: &dyn ProgressReporter,
256    ) -> Result<BatchResult, EmbedError> {
257        let start = Instant::now();
258
259        if repos.is_empty() {
260            return Ok(BatchResult {
261                repos: Vec::new(),
262                total_chunks: 0,
263                total_tokens: 0,
264                successful_repos: 0,
265                failed_repos: 0,
266                elapsed: start.elapsed(),
267            });
268        }
269
270        progress.set_phase(&format!("Processing {} repositories...", repos.len()));
271        progress.set_total(repos.len());
272
273        let processed = AtomicUsize::new(0);
274
275        // Configure rayon for controlled parallelism
276        let pool = rayon::ThreadPoolBuilder::new()
277            .num_threads(self.max_parallel)
278            .build()
279            .map_err(|e| EmbedError::SerializationError {
280                reason: format!("Failed to create thread pool: {}", e),
281            })?;
282
283        let results: Vec<BatchRepoResult> = pool.install(|| {
284            repos
285                .par_iter()
286                .map(|config| {
287                    let result = self.process_single_repo(config);
288
289                    let done = processed.fetch_add(1, Ordering::Relaxed) + 1;
290                    progress.set_progress(done);
291
292                    result
293                })
294                .collect()
295        });
296
297        // Calculate totals
298        let total_chunks: usize = results.iter().map(|r| r.chunks.len()).sum();
299        let total_tokens: u64 = results
300            .iter()
301            .flat_map(|r| r.chunks.iter())
302            .map(|c| c.tokens as u64)
303            .sum();
304        let successful_repos = results.iter().filter(|r| r.error.is_none()).count();
305        let failed_repos = results.iter().filter(|r| r.error.is_some()).count();
306
307        progress.set_phase("Batch complete");
308
309        Ok(BatchResult {
310            repos: results,
311            total_chunks,
312            total_tokens,
313            successful_repos,
314            failed_repos,
315            elapsed: start.elapsed(),
316        })
317    }
318
319    /// Process a single repository
320    fn process_single_repo(&self, config: &BatchRepoConfig) -> BatchRepoResult {
321        let start = Instant::now();
322        let repo_id = config.to_repo_id();
323
324        // Build settings with per-repo overrides
325        let mut settings = self.settings.clone();
326        if !config.include_patterns.is_empty() {
327            settings.include_patterns = config.include_patterns.clone();
328        }
329        if !config.exclude_patterns.is_empty() {
330            settings.exclude_patterns = config.exclude_patterns.clone();
331        }
332
333        // Create chunker for this repo
334        let chunker =
335            EmbedChunker::new(settings, self.limits.clone()).with_repo_id(repo_id.clone());
336
337        // Process the repository
338        let quiet = QuietProgress;
339        match chunker.chunk_repository(&config.path, &quiet) {
340            Ok(chunks) => BatchRepoResult {
341                repo_id,
342                path: config.path.clone(),
343                chunks,
344                elapsed: start.elapsed(),
345                error: None,
346            },
347            Err(e) => BatchRepoResult {
348                repo_id,
349                path: config.path.clone(),
350                chunks: Vec::new(),
351                elapsed: start.elapsed(),
352                error: Some(e.to_string()),
353            },
354        }
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `content` to `dir/name`, creating any missing parent directories.
    fn create_test_file(dir: &Path, name: &str, content: &str) {
        let target = dir.join(name);
        if let Some(parent_dir) = target.parent() {
            std::fs::create_dir_all(parent_dir).unwrap();
        }
        std::fs::write(&target, content).unwrap();
    }

    /// Runs `repos` through an embedder with default settings and unwraps the
    /// batch result.
    fn run_default_batch(repos: &[BatchRepoConfig]) -> BatchResult {
        BatchEmbedder::new(EmbedSettings::default())
            .embed_batch(repos)
            .unwrap()
    }

    /// Every builder method stores the value it was given.
    #[test]
    fn test_batch_repo_config_builder() {
        let built = BatchRepoConfig::new("/path/to/repo")
            .with_namespace("github.com/org")
            .with_name("my-repo")
            .with_version("v1.0.0")
            .with_branch("main");

        assert_eq!(built.path, PathBuf::from("/path/to/repo"));
        assert_eq!(built.namespace.as_deref(), Some("github.com/org"));
        assert_eq!(built.name.as_deref(), Some("my-repo"));
        assert_eq!(built.version.as_deref(), Some("v1.0.0"));
        assert_eq!(built.branch.as_deref(), Some("main"));
    }

    /// An explicit name and namespace are carried into the repo identifier.
    #[test]
    fn test_batch_repo_config_to_repo_id() {
        let id = BatchRepoConfig::new("/path/to/my-repo")
            .with_namespace("github.com/org")
            .with_name("custom-name")
            .to_repo_id();

        assert_eq!(id.namespace, "github.com/org");
        assert_eq!(id.name, "custom-name");
    }

    /// Without an explicit name, the last path component is used.
    #[test]
    fn test_batch_repo_config_infer_name() {
        let id = BatchRepoConfig::new("/path/to/my-repo").to_repo_id();

        assert_eq!(id.name, "my-repo");
    }

    /// An empty batch yields an all-zero, failure-free result.
    #[test]
    fn test_batch_embedder_empty_batch() {
        let outcome = run_default_batch(&[]);

        assert_eq!(outcome.total_chunks, 0);
        assert_eq!(outcome.successful_repos, 0);
        assert_eq!(outcome.failed_repos, 0);
        assert!(!outcome.has_failures());
    }

    /// A single valid repository succeeds and produces at least one chunk.
    #[test]
    fn test_batch_embedder_single_repo() {
        let workspace = TempDir::new().unwrap();
        create_test_file(
            workspace.path(),
            "lib.rs",
            "/// A test function\npub fn hello() { println!(\"hello\"); }\n",
        );

        let batch = vec![BatchRepoConfig::new(workspace.path()).with_name("test-repo")];
        let outcome = run_default_batch(&batch);

        assert_eq!(outcome.successful_repos, 1);
        assert_eq!(outcome.failed_repos, 0);
        assert!(!outcome.has_failures());
        assert!(outcome.total_chunks > 0);
    }

    /// Two repositories processed with two workers both succeed and keep
    /// distinct names.
    #[test]
    fn test_batch_embedder_multiple_repos() {
        let first_dir = TempDir::new().unwrap();
        let second_dir = TempDir::new().unwrap();

        create_test_file(first_dir.path(), "a.rs", "pub fn foo() { println!(\"foo\"); }\n");
        create_test_file(second_dir.path(), "b.rs", "pub fn bar() { println!(\"bar\"); }\n");

        let batch = vec![
            BatchRepoConfig::new(first_dir.path())
                .with_namespace("org")
                .with_name("repo1"),
            BatchRepoConfig::new(second_dir.path())
                .with_namespace("org")
                .with_name("repo2"),
        ];

        let outcome = BatchEmbedder::new(EmbedSettings::default())
            .with_max_parallel(2)
            .embed_batch(&batch)
            .unwrap();

        assert_eq!(outcome.successful_repos, 2);
        assert_eq!(outcome.failed_repos, 0);

        // The two results must not share a repository name.
        let names: Vec<&str> = outcome
            .repos
            .iter()
            .map(|repo| repo.repo_id.name.as_str())
            .collect();
        assert_ne!(names[0], names[1]);
    }

    /// A missing path is recorded as a failure without failing the batch.
    #[test]
    fn test_batch_embedder_handles_failure() {
        let workspace = TempDir::new().unwrap();
        create_test_file(workspace.path(), "lib.rs", "pub fn ok() {}\n");

        let batch = vec![
            BatchRepoConfig::new(workspace.path()).with_name("good-repo"),
            BatchRepoConfig::new("/nonexistent/path").with_name("bad-repo"),
        ];
        let outcome = run_default_batch(&batch);

        assert_eq!(outcome.successful_repos, 1);
        assert_eq!(outcome.failed_repos, 1);
        assert!(outcome.has_failures());

        let failed: Vec<_> = outcome.failures().collect();
        assert_eq!(failed.len(), 1);
        let (failed_path, _message) = failed[0];
        assert!(failed_path.to_str().unwrap().contains("nonexistent"));
    }

    /// `all_chunks` yields exactly `total_chunks` items across repositories.
    #[test]
    fn test_batch_result_all_chunks() {
        let first_dir = TempDir::new().unwrap();
        let second_dir = TempDir::new().unwrap();

        create_test_file(first_dir.path(), "a.rs", "pub fn func1() {}\npub fn func2() {}\n");
        create_test_file(second_dir.path(), "b.rs", "pub fn func3() {}\n");

        let batch = vec![
            BatchRepoConfig::new(first_dir.path()).with_name("repo1"),
            BatchRepoConfig::new(second_dir.path()).with_name("repo2"),
        ];
        let outcome = run_default_batch(&batch);

        let borrowed: Vec<_> = outcome.all_chunks().collect();
        assert_eq!(borrowed.len(), outcome.total_chunks);
    }

    /// `into_chunks` returns exactly `total_chunks` owned chunks.
    #[test]
    fn test_batch_result_into_chunks() {
        let workspace = TempDir::new().unwrap();
        create_test_file(workspace.path(), "lib.rs", "pub fn test() {}\n");

        let batch = vec![BatchRepoConfig::new(workspace.path())];
        let outcome = run_default_batch(&batch);

        let expected = outcome.total_chunks;
        let owned = outcome.into_chunks();
        assert_eq!(owned.len(), expected);
    }

    /// A per-repository include pattern restricts chunks to matching files.
    #[test]
    fn test_per_repo_pattern_override() {
        let workspace = TempDir::new().unwrap();
        create_test_file(workspace.path(), "src/lib.rs", "pub fn included() {}\n");
        create_test_file(workspace.path(), "tests/test.rs", "pub fn excluded() {}\n");

        // Only files under src/ are included for this repository.
        let filtered = BatchRepoConfig::new(workspace.path())
            .with_name("filtered-repo")
            .with_include_patterns(vec!["src/**/*.rs".to_owned()]);
        let outcome = run_default_batch(&[filtered]);

        // Every chunk that was produced must come from src/.
        for chunk in outcome.all_chunks() {
            assert!(
                chunk.source.file.starts_with("src/"),
                "Expected src/ prefix, got: {}",
                chunk.source.file
            );
        }
    }
}