1use std::sync::Arc;
7use std::sync::atomic::{AtomicUsize, Ordering};
8
9use chrono::{DateTime, Utc};
10use futures::stream::{self, StreamExt};
11use serde::{Deserialize, Serialize};
12use tokio::sync::Semaphore;
13
/// A URL paired with its sitemap `<lastmod>` timestamp, when one is known.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UrlWithLastmod {
    /// Absolute URL of the page to scrape.
    pub url: String,
    /// Last-modified timestamp reported by the sitemap, if any.
    pub lastmod: Option<DateTime<Utc>>,
}
46
47impl UrlWithLastmod {
48 #[must_use]
50 pub const fn new(url: String) -> Self {
51 Self { url, lastmod: None }
52 }
53
54 #[must_use]
56 pub const fn with_lastmod(mut self, lastmod: Option<DateTime<Utc>>) -> Self {
57 self.lastmod = lastmod;
58 self
59 }
60}
61
/// Aggregated outcome of a batch scrape: cached pages vs. failed pages.
#[derive(Debug, Default)]
pub struct ScrapeResults {
    /// Pages scraped successfully, converted to cache entries.
    pub successful: Vec<PageCacheEntry>,
    /// Pages that could not be scraped, with error details.
    pub failed: Vec<FailedPage>,
}
83
84impl ScrapeResults {
85 #[must_use]
87 pub fn new() -> Self {
88 Self::default()
89 }
90
91 #[must_use]
93 pub fn total(&self) -> usize {
94 self.successful.len() + self.failed.len()
95 }
96
97 #[must_use]
99 #[allow(clippy::cast_precision_loss)] pub fn success_rate(&self) -> f64 {
101 let total = self.total();
102 if total == 0 {
103 0.0
104 } else {
105 (self.successful.len() as f64 / total as f64) * 100.0
106 }
107 }
108}
109
/// Progress hook invoked after each page finishes: `(completed_count, total_count)`.
pub type ProgressCallback = Arc<dyn Fn(usize, usize) + Send + Sync>;
114
/// Drives concurrent scraping of a URL list through a [`Scraper`] implementation.
pub struct GenerateOrchestrator<S: Scraper> {
    /// Backend used to fetch each page.
    scraper: S,
    /// Maximum number of in-flight scrapes (clamped at construction).
    concurrency: usize,
    /// Lower bound on concurrency; only exposed via the `min_concurrency()` getter here.
    min_concurrency: usize,
    /// Optional progress notification hook.
    progress_callback: Option<ProgressCallback>,
}
151
/// Abstraction over a page scraper, allowing the orchestrator to be
/// exercised with mock implementations in tests.
#[async_trait::async_trait]
pub trait Scraper: Send + Sync {
    /// Fetches `url` and returns its content, or a [`ScrapeError`] on failure.
    async fn scrape(&self, url: &str) -> Result<ScrapeResult, ScrapeError>;
}
161
/// Error produced by a [`Scraper`] for a single URL.
#[derive(Debug, Clone)]
pub struct ScrapeError {
    /// URL that failed to scrape.
    pub url: String,
    /// Description of the failure.
    pub message: String,
    /// Whether the failure was caused by rate limiting (e.g. an HTTP 429).
    pub is_rate_limited: bool,
}
172
173impl ScrapeError {
174 #[must_use]
176 pub const fn new(url: String, message: String) -> Self {
177 Self {
178 url,
179 message,
180 is_rate_limited: false,
181 }
182 }
183
184 #[must_use]
186 pub const fn with_rate_limit(mut self, is_rate_limited: bool) -> Self {
187 self.is_rate_limited = is_rate_limited;
188 self
189 }
190}
191
impl std::fmt::Display for ScrapeError {
    // Renders as: "scrape failed for <url>: <message>".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "scrape failed for {}: {}", self.url, self.message)
    }
}

// Marker impl so ScrapeError works with `?` and `Box<dyn Error>`.
impl std::error::Error for ScrapeError {}
199
/// Raw output of a single page scrape.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ScrapeResult {
    /// Page content converted to markdown.
    pub markdown: String,
    /// Page title, when the scraper could extract one; defaults to `None`
    /// when absent from serialized input.
    #[serde(default)]
    pub title: Option<String>,
    /// URL the content was fetched from.
    pub url: String,
}
218
/// A successfully scraped page as stored in the page cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PageCacheEntry {
    /// URL the page was fetched from.
    pub url: String,
    /// Page title, if one was extracted.
    pub title: Option<String>,
    /// When the page was fetched (stamped at entry construction).
    pub fetched_at: DateTime<Utc>,
    /// `<lastmod>` from the sitemap, if the sitemap provided one.
    pub sitemap_lastmod: Option<DateTime<Utc>>,
    /// Scraped markdown content.
    pub markdown: String,
    /// Number of lines in `markdown`.
    pub line_count: usize,
}
238
239impl PageCacheEntry {
240 #[must_use]
242 pub fn from_scrape_result(result: ScrapeResult, lastmod: Option<DateTime<Utc>>) -> Self {
243 let line_count = result.markdown.lines().count();
244 Self {
245 url: result.url,
246 title: result.title,
247 fetched_at: Utc::now(),
248 sitemap_lastmod: lastmod,
249 markdown: result.markdown,
250 line_count,
251 }
252 }
253}
254
/// A page that could not be scraped, with enough detail for later retries.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FailedPage {
    /// URL that failed.
    pub url: String,
    /// Human-readable error message from the last attempt.
    pub error: String,
    /// Number of scrape attempts made so far (starts at 1).
    pub attempts: u32,
    /// Timestamp of the most recent attempt.
    pub last_attempt: DateTime<Utc>,
}
270
271impl FailedPage {
272 #[must_use]
274 pub fn new(url: String, error: String) -> Self {
275 Self {
276 url,
277 error,
278 attempts: 1,
279 last_attempt: Utc::now(),
280 }
281 }
282}
283
284impl<S: Scraper> GenerateOrchestrator<S> {
289 const DEFAULT_CONCURRENCY: usize = 5;
291
292 const MIN_CONCURRENCY: usize = 1;
294
295 #[must_use]
302 pub fn new(scraper: S, concurrency: usize) -> Self {
303 Self {
304 scraper,
305 concurrency: concurrency.clamp(1, 50),
306 min_concurrency: Self::MIN_CONCURRENCY,
307 progress_callback: None,
308 }
309 }
310
311 #[must_use]
313 pub fn with_default_concurrency(scraper: S) -> Self {
314 Self::new(scraper, Self::DEFAULT_CONCURRENCY)
315 }
316
317 #[must_use]
321 pub fn with_progress<F>(mut self, callback: F) -> Self
322 where
323 F: Fn(usize, usize) + Send + Sync + 'static,
324 {
325 self.progress_callback = Some(Arc::new(callback));
326 self
327 }
328
329 pub async fn scrape_all(&self, urls: &[UrlWithLastmod]) -> ScrapeResults {
338 if urls.is_empty() {
339 return ScrapeResults::default();
340 }
341
342 let total = urls.len();
343 let completed = Arc::new(AtomicUsize::new(0));
344 let semaphore = Arc::new(Semaphore::new(self.concurrency));
345
346 let results: Vec<Result<PageCacheEntry, FailedPage>> = stream::iter(urls)
348 .map(|url_info| {
349 let semaphore = Arc::clone(&semaphore);
350 let completed = Arc::clone(&completed);
351 let progress = self.progress_callback.clone();
352
353 async move {
354 let _permit = semaphore.acquire().await;
356
357 let result = self.scrape_one(&url_info.url, url_info.lastmod).await;
359
360 let done = completed.fetch_add(1, Ordering::SeqCst) + 1;
362 if let Some(cb) = progress {
363 cb(done, total);
364 }
365
366 result
367 }
368 })
369 .buffer_unordered(self.concurrency)
370 .collect()
371 .await;
372
373 let mut scrape_results = ScrapeResults::default();
375 for result in results {
376 match result {
377 Ok(entry) => scrape_results.successful.push(entry),
378 Err(failed) => scrape_results.failed.push(failed),
379 }
380 }
381
382 scrape_results
383 }
384
385 async fn scrape_one(
387 &self,
388 url: &str,
389 lastmod: Option<DateTime<Utc>>,
390 ) -> Result<PageCacheEntry, FailedPage> {
391 match self.scraper.scrape(url).await {
392 Ok(result) => Ok(PageCacheEntry::from_scrape_result(result, lastmod)),
393 Err(e) => Err(FailedPage::new(url.to_string(), e.message)),
394 }
395 }
396
397 #[must_use]
399 pub const fn concurrency(&self) -> usize {
400 self.concurrency
401 }
402
403 #[must_use]
405 pub const fn min_concurrency(&self) -> usize {
406 self.min_concurrency
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use std::time::Duration;

    /// Scraper stub that replays a FIFO queue of canned responses.
    ///
    /// Once the queue is exhausted, `scrape` falls back to a generated
    /// success for any URL.
    struct MockScraper {
        // Mutex-wrapped so the mock satisfies `Scraper: Send + Sync`.
        responses: Mutex<Vec<Result<ScrapeResult, ScrapeError>>>,
    }

    impl MockScraper {
        fn new() -> Self {
            Self {
                responses: Mutex::new(Vec::new()),
            }
        }

        // Queues a canned success carrying the given URL and markdown.
        fn with_success(self, url: &str, markdown: &str) -> Self {
            let mut responses = self.responses.lock().expect("lock poisoned");
            responses.push(Ok(ScrapeResult {
                markdown: markdown.to_string(),
                title: Some("Test Page".to_string()),
                url: url.to_string(),
            }));
            drop(responses);
            self
        }

        // Queues a canned failure carrying the given URL and error message.
        fn with_failure(self, url: &str, error: &str) -> Self {
            let mut responses = self.responses.lock().expect("lock poisoned");
            responses.push(Err(ScrapeError::new(url.to_string(), error.to_string())));
            drop(responses);
            self
        }
    }

    #[async_trait::async_trait]
    impl Scraper for MockScraper {
        async fn scrape(&self, url: &str) -> Result<ScrapeResult, ScrapeError> {
            // Small delay so concurrent scrapes genuinely interleave.
            tokio::time::sleep(Duration::from_millis(10)).await;

            let mut responses = self.responses.lock().expect("lock poisoned");
            if responses.is_empty() {
                // Default success when no canned response is queued.
                Ok(ScrapeResult {
                    markdown: format!("# Content from {url}"),
                    title: Some("Default".to_string()),
                    url: url.to_string(),
                })
            } else {
                responses.remove(0)
            }
        }
    }

    // --- UrlWithLastmod ---

    #[test]
    fn test_url_with_lastmod_new() {
        let url = UrlWithLastmod::new("https://example.com/page".to_string());
        assert_eq!(url.url, "https://example.com/page");
        assert!(url.lastmod.is_none());
    }

    #[test]
    fn test_url_with_lastmod_builder() {
        let now = Utc::now();
        let url =
            UrlWithLastmod::new("https://example.com/page".to_string()).with_lastmod(Some(now));
        assert_eq!(url.lastmod, Some(now));
    }

    #[test]
    fn test_url_with_lastmod_serialization() {
        let url = UrlWithLastmod::new("https://example.com/page".to_string());
        let json = serde_json::to_string(&url).expect("serialize");
        assert!(json.contains("\"url\":\"https://example.com/page\""));

        let roundtrip: UrlWithLastmod = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(roundtrip.url, url.url);
    }

    // --- ScrapeResults ---

    #[test]
    fn test_scrape_results_default() {
        let results = ScrapeResults::default();
        assert!(results.successful.is_empty());
        assert!(results.failed.is_empty());
    }

    #[test]
    fn test_scrape_results_total() {
        let mut results = ScrapeResults::new();
        assert_eq!(results.total(), 0);

        results.successful.push(PageCacheEntry::from_scrape_result(
            ScrapeResult {
                markdown: "test".to_string(),
                title: None,
                url: "https://a.com".to_string(),
            },
            None,
        ));
        assert_eq!(results.total(), 1);

        results.failed.push(FailedPage::new(
            "https://b.com".to_string(),
            "error".to_string(),
        ));
        assert_eq!(results.total(), 2);
    }

    #[test]
    fn test_scrape_results_success_rate() {
        // Empty results report 0% rather than dividing by zero.
        let results = ScrapeResults::default();
        assert!((results.success_rate() - 0.0).abs() < f64::EPSILON);

        let mut results = ScrapeResults::new();
        results.successful.push(PageCacheEntry::from_scrape_result(
            ScrapeResult {
                markdown: "test".to_string(),
                title: None,
                url: "https://a.com".to_string(),
            },
            None,
        ));
        assert!((results.success_rate() - 100.0).abs() < f64::EPSILON);

        results.failed.push(FailedPage::new(
            "https://b.com".to_string(),
            "error".to_string(),
        ));
        assert!((results.success_rate() - 50.0).abs() < f64::EPSILON);
    }

    // --- PageCacheEntry ---

    #[test]
    fn test_page_cache_entry_from_scrape_result() {
        let result = ScrapeResult {
            markdown: "# Hello\n\nWorld".to_string(),
            title: Some("Hello".to_string()),
            url: "https://example.com/page".to_string(),
        };

        let entry = PageCacheEntry::from_scrape_result(result, None);

        assert_eq!(entry.url, "https://example.com/page");
        assert_eq!(entry.title, Some("Hello".to_string()));
        assert_eq!(entry.line_count, 3);
        assert!(entry.sitemap_lastmod.is_none());
    }

    #[test]
    fn test_page_cache_entry_with_lastmod() {
        let lastmod = Utc::now();
        let result = ScrapeResult {
            markdown: "content".to_string(),
            title: None,
            url: "https://example.com".to_string(),
        };

        let entry = PageCacheEntry::from_scrape_result(result, Some(lastmod));

        assert_eq!(entry.sitemap_lastmod, Some(lastmod));
    }

    // --- FailedPage ---

    #[test]
    fn test_failed_page_new() {
        let failed = FailedPage::new("https://example.com".to_string(), "timeout".to_string());
        assert_eq!(failed.url, "https://example.com");
        assert_eq!(failed.error, "timeout");
        assert_eq!(failed.attempts, 1);
    }

    // --- ScrapeError ---

    #[test]
    fn test_scrape_error_new() {
        let err = ScrapeError::new("https://example.com".to_string(), "timeout".to_string());
        assert_eq!(err.url, "https://example.com");
        assert_eq!(err.message, "timeout");
        assert!(!err.is_rate_limited);
    }

    #[test]
    fn test_scrape_error_rate_limit() {
        let err = ScrapeError::new("https://example.com".to_string(), "429".to_string())
            .with_rate_limit(true);
        assert!(err.is_rate_limited);
    }

    #[test]
    fn test_scrape_error_display() {
        let err = ScrapeError::new("https://example.com".to_string(), "timeout".to_string());
        assert_eq!(
            format!("{err}"),
            "scrape failed for https://example.com: timeout"
        );
    }

    // --- GenerateOrchestrator construction ---

    #[test]
    fn test_orchestrator_creation() {
        let scraper = MockScraper::new();
        let orchestrator = GenerateOrchestrator::new(scraper, 5);
        assert_eq!(orchestrator.concurrency(), 5);
    }

    #[test]
    fn test_orchestrator_default_concurrency() {
        let scraper = MockScraper::new();
        let orchestrator = GenerateOrchestrator::with_default_concurrency(scraper);
        assert_eq!(orchestrator.concurrency(), 5);
    }

    #[test]
    fn test_orchestrator_concurrency_clamped() {
        // Out-of-range values are clamped to the 1..=50 bounds.
        let scraper1 = MockScraper::new();
        let orchestrator1 = GenerateOrchestrator::new(scraper1, 0);
        assert_eq!(orchestrator1.concurrency(), 1);

        let scraper2 = MockScraper::new();
        let orchestrator2 = GenerateOrchestrator::new(scraper2, 100);
        assert_eq!(orchestrator2.concurrency(), 50);
    }

    #[test]
    fn test_orchestrator_min_concurrency() {
        let scraper = MockScraper::new();
        let orchestrator = GenerateOrchestrator::new(scraper, 5);
        assert_eq!(orchestrator.min_concurrency(), 1);
    }

    // --- scrape_all behavior ---

    #[tokio::test]
    async fn test_orchestrator_empty_urls() {
        let scraper = MockScraper::new();
        let orchestrator = GenerateOrchestrator::new(scraper, 5);

        let results = orchestrator.scrape_all(&[]).await;

        assert!(results.successful.is_empty());
        assert!(results.failed.is_empty());
    }

    #[tokio::test]
    async fn test_orchestrator_single_success() {
        let scraper = MockScraper::new().with_success("https://example.com/page", "# Content");
        let orchestrator = GenerateOrchestrator::new(scraper, 5);

        let urls = vec![UrlWithLastmod::new("https://example.com/page".to_string())];
        let results = orchestrator.scrape_all(&urls).await;

        assert_eq!(results.successful.len(), 1);
        assert!(results.failed.is_empty());
        assert_eq!(results.successful[0].url, "https://example.com/page");
    }

    #[tokio::test]
    async fn test_orchestrator_single_failure() {
        let scraper =
            MockScraper::new().with_failure("https://example.com/page", "connection refused");
        let orchestrator = GenerateOrchestrator::new(scraper, 5);

        let urls = vec![UrlWithLastmod::new("https://example.com/page".to_string())];
        let results = orchestrator.scrape_all(&urls).await;

        assert!(results.successful.is_empty());
        assert_eq!(results.failed.len(), 1);
        assert_eq!(results.failed[0].url, "https://example.com/page");
        assert_eq!(results.failed[0].error, "connection refused");
    }

    #[tokio::test]
    async fn test_orchestrator_mixed_results() {
        let scraper = MockScraper::new()
            .with_success("https://example.com/a", "# A")
            .with_failure("https://example.com/b", "timeout")
            .with_success("https://example.com/c", "# C");
        let orchestrator = GenerateOrchestrator::new(scraper, 5);

        let urls = vec![
            UrlWithLastmod::new("https://example.com/a".to_string()),
            UrlWithLastmod::new("https://example.com/b".to_string()),
            UrlWithLastmod::new("https://example.com/c".to_string()),
        ];
        let results = orchestrator.scrape_all(&urls).await;

        assert_eq!(results.successful.len(), 2);
        assert_eq!(results.failed.len(), 1);
    }

    #[tokio::test]
    async fn test_orchestrator_progress_callback() {
        let progress = Arc::new(Mutex::new(Vec::new()));
        let progress_clone = Arc::clone(&progress);

        let scraper = MockScraper::new();
        let orchestrator =
            GenerateOrchestrator::new(scraper, 5).with_progress(move |completed, total| {
                progress_clone
                    .lock()
                    .expect("lock")
                    .push((completed, total));
            });

        let urls = vec![
            UrlWithLastmod::new("https://example.com/a".to_string()),
            UrlWithLastmod::new("https://example.com/b".to_string()),
            UrlWithLastmod::new("https://example.com/c".to_string()),
        ];
        orchestrator.scrape_all(&urls).await;

        let calls = progress.lock().expect("lock");
        assert_eq!(calls.len(), 3);
        for (_, total) in calls.iter() {
            assert_eq!(*total, 3);
        }
        // Completion order is unspecified, so sort before comparing counts.
        let mut completed: Vec<_> = calls.iter().map(|(c, _)| *c).collect();
        drop(calls);
        completed.sort_unstable();
        assert_eq!(completed, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_orchestrator_preserves_lastmod() {
        let lastmod = Utc::now();
        let scraper = MockScraper::new().with_success("https://example.com/page", "# Content");
        let orchestrator = GenerateOrchestrator::new(scraper, 5);

        let urls = vec![
            UrlWithLastmod::new("https://example.com/page".to_string()).with_lastmod(Some(lastmod)),
        ];
        let results = orchestrator.scrape_all(&urls).await;

        assert_eq!(results.successful.len(), 1);
        assert_eq!(results.successful[0].sitemap_lastmod, Some(lastmod));
    }

    #[tokio::test]
    async fn test_orchestrator_respects_concurrency() {
        use std::sync::atomic::AtomicUsize;
        use std::sync::atomic::Ordering;

        /// Tracks the current and maximum number of simultaneous scrapes.
        struct ConcurrencyTracker {
            current: AtomicUsize,
            max_seen: AtomicUsize,
        }

        impl ConcurrencyTracker {
            fn new() -> Self {
                Self {
                    current: AtomicUsize::new(0),
                    max_seen: AtomicUsize::new(0),
                }
            }
        }

        #[async_trait::async_trait]
        impl Scraper for Arc<ConcurrencyTracker> {
            async fn scrape(&self, url: &str) -> Result<ScrapeResult, ScrapeError> {
                let current = self.current.fetch_add(1, Ordering::SeqCst) + 1;

                // CAS loop: raise max_seen to `current` unless another task
                // already recorded a higher value.
                let mut max = self.max_seen.load(Ordering::SeqCst);
                while current > max {
                    match self.max_seen.compare_exchange_weak(
                        max,
                        current,
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    ) {
                        Ok(_) => break,
                        Err(actual) => max = actual,
                    }
                }

                // Hold the "slot" long enough for other tasks to overlap.
                tokio::time::sleep(Duration::from_millis(50)).await;

                self.current.fetch_sub(1, Ordering::SeqCst);

                Ok(ScrapeResult {
                    markdown: "content".to_string(),
                    title: None,
                    url: url.to_string(),
                })
            }
        }

        let tracker = Arc::new(ConcurrencyTracker::new());
        let orchestrator = GenerateOrchestrator::new(Arc::clone(&tracker), 3);

        let urls: Vec<_> = (0..10)
            .map(|i| UrlWithLastmod::new(format!("https://example.com/page{i}")))
            .collect();

        orchestrator.scrape_all(&urls).await;

        let max_seen = tracker.max_seen.load(Ordering::SeqCst);
        assert!(
            max_seen <= 3,
            "Max concurrent was {max_seen}, should be <= 3"
        );
    }
}