infiniloom_engine/embedding/
batch.rs1use std::path::{Path, PathBuf};
27use std::sync::atomic::{AtomicUsize, Ordering};
28use std::time::{Duration, Instant};
29
30use rayon::prelude::*;
31use serde::{Deserialize, Serialize};
32
33use super::chunker::EmbedChunker;
34use super::error::EmbedError;
35use super::progress::{ProgressReporter, QuietProgress};
36use super::types::{EmbedChunk, EmbedSettings, RepoIdentifier};
37use super::ResourceLimits;
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct BatchRepoConfig {
42 pub path: PathBuf,
44
45 #[serde(skip_serializing_if = "Option::is_none")]
47 pub namespace: Option<String>,
48
49 #[serde(skip_serializing_if = "Option::is_none")]
51 pub name: Option<String>,
52
53 #[serde(skip_serializing_if = "Option::is_none")]
55 pub version: Option<String>,
56
57 #[serde(skip_serializing_if = "Option::is_none")]
59 pub branch: Option<String>,
60
61 #[serde(default, skip_serializing_if = "Vec::is_empty")]
63 pub include_patterns: Vec<String>,
64
65 #[serde(default, skip_serializing_if = "Vec::is_empty")]
67 pub exclude_patterns: Vec<String>,
68}
69
70impl BatchRepoConfig {
71 pub fn new(path: impl Into<PathBuf>) -> Self {
73 Self {
74 path: path.into(),
75 namespace: None,
76 name: None,
77 version: None,
78 branch: None,
79 include_patterns: Vec::new(),
80 exclude_patterns: Vec::new(),
81 }
82 }
83
84 pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
86 self.namespace = Some(namespace.into());
87 self
88 }
89
90 pub fn with_name(mut self, name: impl Into<String>) -> Self {
92 self.name = Some(name.into());
93 self
94 }
95
96 pub fn with_version(mut self, version: impl Into<String>) -> Self {
98 self.version = Some(version.into());
99 self
100 }
101
102 pub fn with_branch(mut self, branch: impl Into<String>) -> Self {
104 self.branch = Some(branch.into());
105 self
106 }
107
108 pub fn with_include_patterns(mut self, patterns: Vec<String>) -> Self {
110 self.include_patterns = patterns;
111 self
112 }
113
114 pub fn with_exclude_patterns(mut self, patterns: Vec<String>) -> Self {
116 self.exclude_patterns = patterns;
117 self
118 }
119
120 pub fn to_repo_id(&self) -> RepoIdentifier {
122 let name = self
123 .name
124 .clone()
125 .or_else(|| {
126 self.path
127 .file_name()
128 .and_then(|n| n.to_str())
129 .map(String::from)
130 })
131 .unwrap_or_else(|| "unknown".to_owned());
132
133 RepoIdentifier {
134 namespace: self.namespace.clone().unwrap_or_default(),
135 name,
136 version: self.version.clone(),
137 branch: self.branch.clone(),
138 commit: None, }
140 }
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct BatchRepoResult {
146 pub repo_id: RepoIdentifier,
148
149 pub path: PathBuf,
151
152 pub chunks: Vec<EmbedChunk>,
154
155 pub elapsed: Duration,
157
158 #[serde(skip_serializing_if = "Option::is_none")]
160 pub error: Option<String>,
161}
162
163#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct BatchResult {
166 pub repos: Vec<BatchRepoResult>,
168
169 pub total_chunks: usize,
171
172 pub total_tokens: u64,
174
175 pub successful_repos: usize,
177
178 pub failed_repos: usize,
180
181 pub elapsed: Duration,
183}
184
185impl BatchResult {
186 pub fn all_chunks(&self) -> impl Iterator<Item = &EmbedChunk> {
188 self.repos.iter().flat_map(|r| r.chunks.iter())
189 }
190
191 pub fn into_chunks(self) -> Vec<EmbedChunk> {
193 self.repos.into_iter().flat_map(|r| r.chunks).collect()
194 }
195
196 pub fn has_failures(&self) -> bool {
198 self.failed_repos > 0
199 }
200
201 pub fn failures(&self) -> impl Iterator<Item = (&Path, &str)> {
203 self.repos
204 .iter()
205 .filter(|r| r.error.is_some())
206 .map(|r| (r.path.as_path(), r.error.as_deref().unwrap_or("unknown")))
207 }
208}
209
210pub struct BatchEmbedder {
212 settings: EmbedSettings,
213 limits: ResourceLimits,
214 max_parallel: usize,
216}
217
218impl BatchEmbedder {
219 pub fn new(settings: EmbedSettings) -> Self {
221 Self {
222 settings,
223 limits: ResourceLimits::default(),
224 max_parallel: std::thread::available_parallelism().map_or(4, |n| n.get()),
225 }
226 }
227
228 pub fn with_limits(settings: EmbedSettings, limits: ResourceLimits) -> Self {
230 Self {
231 settings,
232 limits,
233 max_parallel: std::thread::available_parallelism().map_or(4, |n| n.get()),
234 }
235 }
236
237 pub fn with_max_parallel(mut self, max: usize) -> Self {
239 self.max_parallel = max.max(1);
240 self
241 }
242
243 pub fn embed_batch(&self, repos: &[BatchRepoConfig]) -> Result<BatchResult, EmbedError> {
248 self.embed_batch_with_progress(repos, &QuietProgress)
249 }
250
251 pub fn embed_batch_with_progress(
253 &self,
254 repos: &[BatchRepoConfig],
255 progress: &dyn ProgressReporter,
256 ) -> Result<BatchResult, EmbedError> {
257 let start = Instant::now();
258
259 if repos.is_empty() {
260 return Ok(BatchResult {
261 repos: Vec::new(),
262 total_chunks: 0,
263 total_tokens: 0,
264 successful_repos: 0,
265 failed_repos: 0,
266 elapsed: start.elapsed(),
267 });
268 }
269
270 progress.set_phase(&format!("Processing {} repositories...", repos.len()));
271 progress.set_total(repos.len());
272
273 let processed = AtomicUsize::new(0);
274
275 let pool = rayon::ThreadPoolBuilder::new()
277 .num_threads(self.max_parallel)
278 .build()
279 .map_err(|e| EmbedError::SerializationError {
280 reason: format!("Failed to create thread pool: {}", e),
281 })?;
282
283 let results: Vec<BatchRepoResult> = pool.install(|| {
284 repos
285 .par_iter()
286 .map(|config| {
287 let result = self.process_single_repo(config);
288
289 let done = processed.fetch_add(1, Ordering::Relaxed) + 1;
290 progress.set_progress(done);
291
292 result
293 })
294 .collect()
295 });
296
297 let total_chunks: usize = results.iter().map(|r| r.chunks.len()).sum();
299 let total_tokens: u64 = results
300 .iter()
301 .flat_map(|r| r.chunks.iter())
302 .map(|c| c.tokens as u64)
303 .sum();
304 let successful_repos = results.iter().filter(|r| r.error.is_none()).count();
305 let failed_repos = results.iter().filter(|r| r.error.is_some()).count();
306
307 progress.set_phase("Batch complete");
308
309 Ok(BatchResult {
310 repos: results,
311 total_chunks,
312 total_tokens,
313 successful_repos,
314 failed_repos,
315 elapsed: start.elapsed(),
316 })
317 }
318
319 fn process_single_repo(&self, config: &BatchRepoConfig) -> BatchRepoResult {
321 let start = Instant::now();
322 let repo_id = config.to_repo_id();
323
324 let mut settings = self.settings.clone();
326 if !config.include_patterns.is_empty() {
327 settings.include_patterns = config.include_patterns.clone();
328 }
329 if !config.exclude_patterns.is_empty() {
330 settings.exclude_patterns = config.exclude_patterns.clone();
331 }
332
333 let chunker =
335 EmbedChunker::new(settings, self.limits.clone()).with_repo_id(repo_id.clone());
336
337 let quiet = QuietProgress;
339 match chunker.chunk_repository(&config.path, &quiet) {
340 Ok(chunks) => BatchRepoResult {
341 repo_id,
342 path: config.path.clone(),
343 chunks,
344 elapsed: start.elapsed(),
345 error: None,
346 },
347 Err(e) => BatchRepoResult {
348 repo_id,
349 path: config.path.clone(),
350 chunks: Vec::new(),
351 elapsed: start.elapsed(),
352 error: Some(e.to_string()),
353 },
354 }
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `content` to `dir/name`, creating any missing parent directories.
    fn create_test_file(dir: &Path, name: &str, content: &str) {
        let path = dir.join(name);
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).unwrap();
        }
        std::fs::write(path, content).unwrap();
    }

    #[test]
    fn test_batch_repo_config_builder() {
        let config = BatchRepoConfig::new("/path/to/repo")
            .with_namespace("github.com/org")
            .with_name("my-repo")
            .with_version("v1.0.0")
            .with_branch("main");

        assert_eq!(config.path, PathBuf::from("/path/to/repo"));
        assert_eq!(config.namespace.as_deref(), Some("github.com/org"));
        assert_eq!(config.name.as_deref(), Some("my-repo"));
        assert_eq!(config.version.as_deref(), Some("v1.0.0"));
        assert_eq!(config.branch.as_deref(), Some("main"));
    }

    #[test]
    fn test_batch_repo_config_to_repo_id() {
        let config = BatchRepoConfig::new("/path/to/my-repo")
            .with_namespace("github.com/org")
            .with_name("custom-name");

        let repo_id = config.to_repo_id();
        assert_eq!(repo_id.namespace, "github.com/org");
        assert_eq!(repo_id.name, "custom-name");
    }

    #[test]
    fn test_batch_repo_config_infer_name() {
        let config = BatchRepoConfig::new("/path/to/my-repo");

        let repo_id = config.to_repo_id();
        assert_eq!(repo_id.name, "my-repo");
    }

    #[test]
    fn test_batch_repo_config_unknown_name_fallback() {
        // `/` has no final component. With no explicit name the identifier falls
        // back to "unknown", and an unset namespace becomes empty.
        let repo_id = BatchRepoConfig::new("/").to_repo_id();
        assert_eq!(repo_id.name, "unknown");
        assert!(repo_id.namespace.is_empty());
    }

    #[test]
    fn test_batch_embedder_empty_batch() {
        let embedder = BatchEmbedder::new(EmbedSettings::default());
        let result = embedder.embed_batch(&[]).unwrap();

        assert_eq!(result.total_chunks, 0);
        assert_eq!(result.successful_repos, 0);
        assert_eq!(result.failed_repos, 0);
        assert!(!result.has_failures());
    }

    #[test]
    fn test_batch_embedder_single_repo() {
        let temp_dir = TempDir::new().unwrap();
        create_test_file(
            temp_dir.path(),
            "lib.rs",
            "/// A test function\npub fn hello() { println!(\"hello\"); }\n",
        );

        let repos = vec![BatchRepoConfig::new(temp_dir.path()).with_name("test-repo")];

        let embedder = BatchEmbedder::new(EmbedSettings::default());
        let result = embedder.embed_batch(&repos).unwrap();

        assert_eq!(result.successful_repos, 1);
        assert_eq!(result.failed_repos, 0);
        assert!(!result.has_failures());
        assert!(result.total_chunks > 0);
    }

    #[test]
    fn test_batch_embedder_multiple_repos() {
        let temp_dir1 = TempDir::new().unwrap();
        let temp_dir2 = TempDir::new().unwrap();

        create_test_file(temp_dir1.path(), "a.rs", "pub fn foo() { println!(\"foo\"); }\n");
        create_test_file(temp_dir2.path(), "b.rs", "pub fn bar() { println!(\"bar\"); }\n");

        let repos = vec![
            BatchRepoConfig::new(temp_dir1.path())
                .with_namespace("org")
                .with_name("repo1"),
            BatchRepoConfig::new(temp_dir2.path())
                .with_namespace("org")
                .with_name("repo2"),
        ];

        let embedder = BatchEmbedder::new(EmbedSettings::default()).with_max_parallel(2);
        let result = embedder.embed_batch(&repos).unwrap();

        assert_eq!(result.successful_repos, 2);
        assert_eq!(result.failed_repos, 0);

        // Results come back in input order even though repositories run in parallel.
        let names: Vec<&str> = result
            .repos
            .iter()
            .map(|r| r.repo_id.name.as_str())
            .collect();
        assert_eq!(names, ["repo1", "repo2"]);
    }

    #[test]
    fn test_batch_embedder_handles_failure() {
        let temp_dir = TempDir::new().unwrap();
        create_test_file(temp_dir.path(), "lib.rs", "pub fn ok() {}\n");

        let repos = vec![
            BatchRepoConfig::new(temp_dir.path()).with_name("good-repo"),
            BatchRepoConfig::new("/nonexistent/path").with_name("bad-repo"),
        ];

        let embedder = BatchEmbedder::new(EmbedSettings::default());
        let result = embedder.embed_batch(&repos).unwrap();

        assert_eq!(result.successful_repos, 1);
        assert_eq!(result.failed_repos, 1);
        assert!(result.has_failures());

        // Input order is preserved, so the failing repository is the second entry
        // and carries no chunks.
        assert!(result.repos[0].error.is_none());
        assert!(result.repos[1].error.is_some());
        assert!(result.repos[1].chunks.is_empty());

        let failures: Vec<_> = result.failures().collect();
        assert_eq!(failures.len(), 1);
        assert!(failures[0].0.to_str().unwrap().contains("nonexistent"));
    }

    #[test]
    fn test_batch_result_all_chunks() {
        let temp_dir1 = TempDir::new().unwrap();
        let temp_dir2 = TempDir::new().unwrap();

        create_test_file(temp_dir1.path(), "a.rs", "pub fn func1() {}\npub fn func2() {}\n");
        create_test_file(temp_dir2.path(), "b.rs", "pub fn func3() {}\n");

        let repos = vec![
            BatchRepoConfig::new(temp_dir1.path()).with_name("repo1"),
            BatchRepoConfig::new(temp_dir2.path()).with_name("repo2"),
        ];

        let embedder = BatchEmbedder::new(EmbedSettings::default());
        let result = embedder.embed_batch(&repos).unwrap();

        let all_chunks: Vec<_> = result.all_chunks().collect();
        assert_eq!(all_chunks.len(), result.total_chunks);
    }

    #[test]
    fn test_batch_result_into_chunks() {
        let temp_dir = TempDir::new().unwrap();
        create_test_file(temp_dir.path(), "lib.rs", "pub fn test() {}\n");

        let repos = vec![BatchRepoConfig::new(temp_dir.path())];

        let embedder = BatchEmbedder::new(EmbedSettings::default());
        let result = embedder.embed_batch(&repos).unwrap();

        let total = result.total_chunks;
        let chunks = result.into_chunks();
        assert_eq!(chunks.len(), total);
    }

    #[test]
    fn test_per_repo_pattern_override() {
        let temp_dir = TempDir::new().unwrap();
        create_test_file(temp_dir.path(), "src/lib.rs", "pub fn included() {}\n");
        create_test_file(temp_dir.path(), "tests/test.rs", "pub fn excluded() {}\n");

        let repos = vec![BatchRepoConfig::new(temp_dir.path())
            .with_name("filtered-repo")
            .with_include_patterns(vec!["src/**/*.rs".to_owned()])];

        let embedder = BatchEmbedder::new(EmbedSettings::default());
        let result = embedder.embed_batch(&repos).unwrap();

        // Every chunk must come from a file matched by the per-repository include pattern.
        for chunk in result.all_chunks() {
            assert!(
                chunk.source.file.starts_with("src/"),
                "Expected src/ prefix, got: {}",
                chunk.source.file
            );
        }
    }
}