1use std::path::{Path, PathBuf};
27use std::sync::atomic::{AtomicUsize, Ordering};
28use std::time::{Duration, Instant};
29
30use rayon::prelude::*;
31use serde::{Deserialize, Serialize};
32
33use super::chunker::EmbedChunker;
34use super::error::EmbedError;
35use super::progress::{ProgressReporter, QuietProgress};
36use super::types::{EmbedChunk, EmbedSettings, RepoIdentifier};
37use super::ResourceLimits;
38
/// Configuration for one repository in a batch embedding run.
///
/// Construct with [`BatchRepoConfig::new`] and refine with the `with_*` builder
/// methods. The optional metadata fields are turned into a [`RepoIdentifier`] by
/// [`BatchRepoConfig::to_repo_id`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRepoConfig {
    /// Filesystem path of the repository that will be chunked.
    pub path: PathBuf,

    /// Namespace for the repository identifier; an empty string is used when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub namespace: Option<String>,

    /// Explicit repository name; when unset, the last component of `path` is used
    /// (or `"unknown"` if that component is missing or not valid UTF-8).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// Optional version label copied into the repository identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version: Option<String>,

    /// Optional branch name copied into the repository identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub branch: Option<String>,

    /// Per-repository include patterns; when non-empty they replace the batch-wide
    /// `EmbedSettings::include_patterns` for this repository only.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub include_patterns: Vec<String>,

    /// Per-repository exclude patterns; when non-empty they replace the batch-wide
    /// `EmbedSettings::exclude_patterns` for this repository only.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub exclude_patterns: Vec<String>,
}
69
70impl BatchRepoConfig {
71 pub fn new(path: impl Into<PathBuf>) -> Self {
73 Self {
74 path: path.into(),
75 namespace: None,
76 name: None,
77 version: None,
78 branch: None,
79 include_patterns: Vec::new(),
80 exclude_patterns: Vec::new(),
81 }
82 }
83
84 pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
86 self.namespace = Some(namespace.into());
87 self
88 }
89
90 pub fn with_name(mut self, name: impl Into<String>) -> Self {
92 self.name = Some(name.into());
93 self
94 }
95
96 pub fn with_version(mut self, version: impl Into<String>) -> Self {
98 self.version = Some(version.into());
99 self
100 }
101
102 pub fn with_branch(mut self, branch: impl Into<String>) -> Self {
104 self.branch = Some(branch.into());
105 self
106 }
107
108 pub fn with_include_patterns(mut self, patterns: Vec<String>) -> Self {
110 self.include_patterns = patterns;
111 self
112 }
113
114 pub fn with_exclude_patterns(mut self, patterns: Vec<String>) -> Self {
116 self.exclude_patterns = patterns;
117 self
118 }
119
120 pub fn to_repo_id(&self) -> RepoIdentifier {
122 let name = self
123 .name
124 .clone()
125 .or_else(|| {
126 self.path
127 .file_name()
128 .and_then(|n| n.to_str())
129 .map(String::from)
130 })
131 .unwrap_or_else(|| "unknown".to_string());
132
133 RepoIdentifier {
134 namespace: self.namespace.clone().unwrap_or_default(),
135 name,
136 version: self.version.clone(),
137 branch: self.branch.clone(),
138 commit: None, }
140 }
141}
142
/// Outcome of processing a single repository within a batch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRepoResult {
    /// Identifier derived from the repository's [`BatchRepoConfig`].
    pub repo_id: RepoIdentifier,

    /// Path the repository was read from.
    pub path: PathBuf,

    /// Chunks produced for this repository; empty when processing failed.
    pub chunks: Vec<EmbedChunk>,

    /// Wall-clock time spent processing this repository.
    pub elapsed: Duration,

    /// Error message (the chunker error's `Display` text) when processing failed;
    /// `None` on success.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
162
/// Aggregated outcome of a batch embedding run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchResult {
    /// Per-repository results, in the same order as the input configurations.
    pub repos: Vec<BatchRepoResult>,

    /// Total number of chunks across all repositories.
    pub total_chunks: usize,

    /// Sum of `tokens` over every chunk in the batch.
    pub total_tokens: u64,

    /// Number of repositories whose `error` is `None`.
    pub successful_repos: usize,

    /// Number of repositories whose `error` is `Some`.
    pub failed_repos: usize,

    /// Wall-clock time for the whole batch.
    pub elapsed: Duration,
}
184
185impl BatchResult {
186 pub fn all_chunks(&self) -> impl Iterator<Item = &EmbedChunk> {
188 self.repos.iter().flat_map(|r| r.chunks.iter())
189 }
190
191 pub fn into_chunks(self) -> Vec<EmbedChunk> {
193 self.repos.into_iter().flat_map(|r| r.chunks).collect()
194 }
195
196 pub fn has_failures(&self) -> bool {
198 self.failed_repos > 0
199 }
200
201 pub fn failures(&self) -> impl Iterator<Item = (&Path, &str)> {
203 self.repos
204 .iter()
205 .filter(|r| r.error.is_some())
206 .map(|r| (r.path.as_path(), r.error.as_deref().unwrap_or("unknown")))
207 }
208}
209
/// Chunks several repositories in parallel using shared settings and resource limits.
pub struct BatchEmbedder {
    /// Base settings; cloned per repository, with include/exclude patterns optionally
    /// overridden by that repository's [`BatchRepoConfig`].
    settings: EmbedSettings,
    /// Resource limits handed (cloned) to each repository's chunker.
    limits: ResourceLimits,
    /// Maximum number of worker threads. Defaults to the host's available parallelism
    /// (4 if unknown); `with_max_parallel` clamps it to at least 1.
    max_parallel: usize,
}
217
218impl BatchEmbedder {
219 pub fn new(settings: EmbedSettings) -> Self {
221 Self {
222 settings,
223 limits: ResourceLimits::default(),
224 max_parallel: std::thread::available_parallelism().map_or(4, |n| n.get()),
225 }
226 }
227
228 pub fn with_limits(settings: EmbedSettings, limits: ResourceLimits) -> Self {
230 Self {
231 settings,
232 limits,
233 max_parallel: std::thread::available_parallelism().map_or(4, |n| n.get()),
234 }
235 }
236
237 pub fn with_max_parallel(mut self, max: usize) -> Self {
239 self.max_parallel = max.max(1);
240 self
241 }
242
243 pub fn embed_batch(&self, repos: &[BatchRepoConfig]) -> Result<BatchResult, EmbedError> {
248 self.embed_batch_with_progress(repos, &QuietProgress)
249 }
250
251 pub fn embed_batch_with_progress(
253 &self,
254 repos: &[BatchRepoConfig],
255 progress: &dyn ProgressReporter,
256 ) -> Result<BatchResult, EmbedError> {
257 let start = Instant::now();
258
259 if repos.is_empty() {
260 return Ok(BatchResult {
261 repos: Vec::new(),
262 total_chunks: 0,
263 total_tokens: 0,
264 successful_repos: 0,
265 failed_repos: 0,
266 elapsed: start.elapsed(),
267 });
268 }
269
270 progress.set_phase(&format!("Processing {} repositories...", repos.len()));
271 progress.set_total(repos.len());
272
273 let processed = AtomicUsize::new(0);
274
275 let pool = rayon::ThreadPoolBuilder::new()
277 .num_threads(self.max_parallel)
278 .build()
279 .map_err(|e| EmbedError::SerializationError {
280 reason: format!("Failed to create thread pool: {}", e),
281 })?;
282
283 let results: Vec<BatchRepoResult> = pool.install(|| {
284 repos
285 .par_iter()
286 .map(|config| {
287 let result = self.process_single_repo(config);
288
289 let done = processed.fetch_add(1, Ordering::Relaxed) + 1;
290 progress.set_progress(done);
291
292 result
293 })
294 .collect()
295 });
296
297 let total_chunks: usize = results.iter().map(|r| r.chunks.len()).sum();
299 let total_tokens: u64 = results
300 .iter()
301 .flat_map(|r| r.chunks.iter())
302 .map(|c| c.tokens as u64)
303 .sum();
304 let successful_repos = results.iter().filter(|r| r.error.is_none()).count();
305 let failed_repos = results.iter().filter(|r| r.error.is_some()).count();
306
307 progress.set_phase("Batch complete");
308
309 Ok(BatchResult {
310 repos: results,
311 total_chunks,
312 total_tokens,
313 successful_repos,
314 failed_repos,
315 elapsed: start.elapsed(),
316 })
317 }
318
319 fn process_single_repo(&self, config: &BatchRepoConfig) -> BatchRepoResult {
321 let start = Instant::now();
322 let repo_id = config.to_repo_id();
323
324 let mut settings = self.settings.clone();
326 if !config.include_patterns.is_empty() {
327 settings.include_patterns = config.include_patterns.clone();
328 }
329 if !config.exclude_patterns.is_empty() {
330 settings.exclude_patterns = config.exclude_patterns.clone();
331 }
332
333 let chunker = EmbedChunker::new(settings, self.limits.clone()).with_repo_id(repo_id.clone());
335
336 let quiet = QuietProgress;
338 match chunker.chunk_repository(&config.path, &quiet) {
339 Ok(chunks) => BatchRepoResult {
340 repo_id,
341 path: config.path.clone(),
342 chunks,
343 elapsed: start.elapsed(),
344 error: None,
345 },
346 Err(e) => BatchRepoResult {
347 repo_id,
348 path: config.path.clone(),
349 chunks: Vec::new(),
350 elapsed: start.elapsed(),
351 error: Some(e.to_string()),
352 },
353 }
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `content` to `dir/name`, creating intermediate directories as needed.
    fn create_test_file(dir: &Path, name: &str, content: &str) {
        let path = dir.join(name);
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).unwrap();
        }
        std::fs::write(path, content).unwrap();
    }

    #[test]
    fn test_batch_repo_config_builder() {
        let config = BatchRepoConfig::new("/path/to/repo")
            .with_namespace("github.com/org")
            .with_name("my-repo")
            .with_version("v1.0.0")
            .with_branch("main");

        assert_eq!(config.path, PathBuf::from("/path/to/repo"));
        assert_eq!(config.namespace.as_deref(), Some("github.com/org"));
        assert_eq!(config.name.as_deref(), Some("my-repo"));
        assert_eq!(config.version.as_deref(), Some("v1.0.0"));
        assert_eq!(config.branch.as_deref(), Some("main"));
    }

    #[test]
    fn test_batch_repo_config_to_repo_id() {
        let config = BatchRepoConfig::new("/path/to/my-repo")
            .with_namespace("github.com/org")
            .with_name("custom-name");

        let repo_id = config.to_repo_id();
        assert_eq!(repo_id.namespace, "github.com/org");
        assert_eq!(repo_id.name, "custom-name");
    }

    #[test]
    fn test_batch_repo_config_infer_name() {
        let config = BatchRepoConfig::new("/path/to/my-repo");

        let repo_id = config.to_repo_id();
        assert_eq!(repo_id.name, "my-repo");
    }

    #[test]
    fn test_batch_repo_config_infer_name_trailing_slash() {
        let config = BatchRepoConfig::new("/path/to/my-repo/");

        assert_eq!(config.to_repo_id().name, "my-repo");
    }

    #[test]
    fn test_batch_repo_config_unknown_name_and_empty_namespace() {
        // `..` has no final normal component, so the name falls back to "unknown".
        let config = BatchRepoConfig::new("..");

        let repo_id = config.to_repo_id();
        assert_eq!(repo_id.name, "unknown");
        assert_eq!(repo_id.namespace, "");
    }

    #[test]
    fn test_with_max_parallel_clamps_zero() {
        let embedder = BatchEmbedder::new(EmbedSettings::default()).with_max_parallel(0);
        assert_eq!(embedder.max_parallel, 1);
    }

    #[test]
    fn test_batch_embedder_empty_batch() {
        let embedder = BatchEmbedder::new(EmbedSettings::default());
        let result = embedder.embed_batch(&[]).unwrap();

        assert_eq!(result.total_chunks, 0);
        assert_eq!(result.successful_repos, 0);
        assert_eq!(result.failed_repos, 0);
        assert!(!result.has_failures());
    }

    #[test]
    fn test_batch_embedder_single_repo() {
        let temp_dir = TempDir::new().unwrap();
        create_test_file(
            temp_dir.path(),
            "lib.rs",
            "/// A test function\npub fn hello() { println!(\"hello\"); }\n",
        );

        let repos = vec![BatchRepoConfig::new(temp_dir.path()).with_name("test-repo")];

        let embedder = BatchEmbedder::new(EmbedSettings::default());
        let result = embedder.embed_batch(&repos).unwrap();

        assert_eq!(result.successful_repos, 1);
        assert_eq!(result.failed_repos, 0);
        assert!(!result.has_failures());
        assert!(result.total_chunks > 0);
    }

    #[test]
    fn test_batch_embedder_multiple_repos() {
        let temp_dir1 = TempDir::new().unwrap();
        let temp_dir2 = TempDir::new().unwrap();

        create_test_file(
            temp_dir1.path(),
            "a.rs",
            "pub fn foo() { println!(\"foo\"); }\n",
        );
        create_test_file(
            temp_dir2.path(),
            "b.rs",
            "pub fn bar() { println!(\"bar\"); }\n",
        );

        let repos = vec![
            BatchRepoConfig::new(temp_dir1.path())
                .with_namespace("org")
                .with_name("repo1"),
            BatchRepoConfig::new(temp_dir2.path())
                .with_namespace("org")
                .with_name("repo2"),
        ];

        let embedder = BatchEmbedder::new(EmbedSettings::default()).with_max_parallel(2);
        let result = embedder.embed_batch(&repos).unwrap();

        assert_eq!(result.successful_repos, 2);
        assert_eq!(result.failed_repos, 0);

        // Results must come back in input order so callers can correlate them.
        let names: Vec<&str> = result.repos.iter().map(|r| r.repo_id.name.as_str()).collect();
        assert_eq!(names, ["repo1", "repo2"]);
        assert!(result.repos.iter().all(|r| r.repo_id.namespace == "org"));
    }

    #[test]
    fn test_batch_embedder_handles_failure() {
        let temp_dir = TempDir::new().unwrap();
        create_test_file(temp_dir.path(), "lib.rs", "pub fn ok() {}\n");

        let repos = vec![
            BatchRepoConfig::new(temp_dir.path()).with_name("good-repo"),
            BatchRepoConfig::new("/nonexistent/path").with_name("bad-repo"),
        ];

        let embedder = BatchEmbedder::new(EmbedSettings::default());
        let result = embedder.embed_batch(&repos).unwrap();

        assert_eq!(result.successful_repos, 1);
        assert_eq!(result.failed_repos, 1);
        assert!(result.has_failures());

        // The failure is attributed to the second (bad) repository, in input order.
        assert!(result.repos[0].error.is_none());
        assert!(result.repos[1].error.is_some());
        assert!(result.repos[1].chunks.is_empty());

        let failures: Vec<_> = result.failures().collect();
        assert_eq!(failures.len(), 1);
        assert!(failures[0].0.to_str().unwrap().contains("nonexistent"));
        assert!(!failures[0].1.is_empty());
    }

    #[test]
    fn test_batch_result_all_chunks() {
        let temp_dir1 = TempDir::new().unwrap();
        let temp_dir2 = TempDir::new().unwrap();

        create_test_file(
            temp_dir1.path(),
            "a.rs",
            "pub fn func1() {}\npub fn func2() {}\n",
        );
        create_test_file(temp_dir2.path(), "b.rs", "pub fn func3() {}\n");

        let repos = vec![
            BatchRepoConfig::new(temp_dir1.path()).with_name("repo1"),
            BatchRepoConfig::new(temp_dir2.path()).with_name("repo2"),
        ];

        let embedder = BatchEmbedder::new(EmbedSettings::default());
        let result = embedder.embed_batch(&repos).unwrap();

        let all_chunks: Vec<_> = result.all_chunks().collect();
        assert_eq!(all_chunks.len(), result.total_chunks);

        let token_sum: u64 = all_chunks.iter().map(|c| c.tokens as u64).sum();
        assert_eq!(token_sum, result.total_tokens);
    }

    #[test]
    fn test_batch_result_into_chunks() {
        let temp_dir = TempDir::new().unwrap();
        create_test_file(temp_dir.path(), "lib.rs", "pub fn test() {}\n");

        let repos = vec![BatchRepoConfig::new(temp_dir.path())];

        let embedder = BatchEmbedder::new(EmbedSettings::default());
        let result = embedder.embed_batch(&repos).unwrap();

        let total = result.total_chunks;
        let chunks = result.into_chunks();
        assert_eq!(chunks.len(), total);
    }

    #[test]
    fn test_per_repo_pattern_override() {
        let temp_dir = TempDir::new().unwrap();
        create_test_file(temp_dir.path(), "src/lib.rs", "pub fn included() {}\n");
        create_test_file(temp_dir.path(), "tests/test.rs", "pub fn excluded() {}\n");

        // Only files under src/ should be chunked for this repository.
        let repos = vec![BatchRepoConfig::new(temp_dir.path())
            .with_name("filtered-repo")
            .with_include_patterns(vec!["src/**/*.rs".to_string()])];

        let embedder = BatchEmbedder::new(EmbedSettings::default());
        let result = embedder.embed_batch(&repos).unwrap();

        for chunk in result.all_chunks() {
            assert!(
                chunk.source.file.starts_with("src/"),
                "Expected src/ prefix, got: {}",
                chunk.source.file
            );
        }
    }
}