Skip to main content

research_master/utils/
streaming.rs

1//! Async streaming utilities for large result sets.
2//!
3//! This module provides streaming iterators for processing large
4//! search results incrementally without loading everything into memory.
5
6use crate::models::{Paper, SearchQuery};
7use crate::sources::Source;
8use async_stream::stream;
9use futures_util::stream::{Stream, StreamExt};
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use tokio::sync::mpsc;
13use tokio::time::{sleep, Duration};
14use tracing::warn;
15
16/// Create a stream that yields papers one at a time from paginated search results.
17///
18/// This allows processing large result sets without loading everything
19/// into memory at once. The stream automatically handles pagination
20/// and rate limiting.
21pub fn paper_stream<T: Source + Clone + 'static>(
22    source: T,
23    query: SearchQuery,
24    page_size: usize,
25) -> impl Stream<Item = Paper> + Send {
26    stream! {
27        let rate_limit_delay = Duration::from_millis(200);
28        loop {
29            // Clone query for this page
30            let mut page_query = query.clone();
31            page_query.max_results = page_size;
32
33            match source.search(&page_query).await {
34                Ok(response) => {
35                    let papers = response.papers;
36                    let count = papers.len();
37
38                    if count == 0 {
39                        // No more papers
40                        break;
41                    }
42
43                    // Yield each paper
44                    for paper in papers {
45                        yield paper;
46                    }
47
48                    // Apply rate limiting
49                    if rate_limit_delay > Duration::ZERO {
50                        sleep(rate_limit_delay).await;
51                    }
52                }
53                Err(e) => {
54                    warn!("Error fetching papers: {}", e);
55                    break;
56                }
57            }
58        }
59    }
60}
61
62/// Create a stream that filters papers by year range.
63pub fn filter_by_year<S: Stream<Item = Paper> + Send + 'static>(
64    stream: S,
65    min_year: Option<i32>,
66    max_year: Option<i32>,
67) -> FilterByYearStream<S> {
68    FilterByYearStream::new(stream, min_year, max_year)
69}
70
71/// Collect all papers from a stream into a Vec.
72pub async fn collect_papers<S: Stream<Item = Paper> + Send + Unpin>(mut stream: S) -> Vec<Paper> {
73    let mut papers = Vec::new();
74    while let Some(paper) = stream.next().await {
75        papers.push(paper);
76    }
77    papers
78}
79
80/// A channel-based concurrent stream for parallel source searches.
81///
82/// This allows searching multiple sources concurrently and
83/// yielding results as they arrive.
84#[allow(dead_code)]
85pub struct ConcurrentPaperStream {
86    receiver: mpsc::Receiver<Paper>,
87    pending: usize,
88}
89
90impl ConcurrentPaperStream {
91    /// Create a new concurrent stream from a list of sources
92    ///
93    /// Searches all sources concurrently and yields papers in the
94    /// order they complete.
95    pub async fn from_sources<S: Source + Clone + 'static>(
96        sources: Vec<S>,
97        query: &SearchQuery,
98        max_concurrent: usize,
99    ) -> Self {
100        let (sender, receiver) = mpsc::channel(100);
101        let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(max_concurrent));
102        let sources_len = sources.len();
103
104        for source in sources {
105            let query = query.clone();
106            let sender = sender.clone();
107            let permit = semaphore.clone().acquire_owned().await.unwrap();
108            let source = source.clone();
109
110            tokio::spawn(async move {
111                // permit is automatically dropped when this async block ends
112                match source.search(&query).await {
113                    Ok(response) => {
114                        for paper in response.papers {
115                            if sender.send(paper).await.is_err() {
116                                break; // Receiver dropped
117                            }
118                        }
119                    }
120                    Err(e) => {
121                        warn!("Source search failed: {}", e);
122                    }
123                }
124                drop(permit);
125            });
126        }
127
128        // Drop sender to signal completion when all tasks finish
129        drop(sender);
130
131        Self {
132            receiver,
133            pending: sources_len,
134        }
135    }
136
137    /// Get the next paper from any source
138    pub async fn next(&mut self) -> Option<Paper> {
139        self.receiver.recv().await
140    }
141
142    /// Check if more results are coming
143    pub fn is_done(&self) -> bool {
144        self.receiver.is_closed()
145    }
146}
147
148/// Stream that limits the number of items
149#[derive(Debug)]
150pub struct TakeStream<S: Stream<Item = Paper>> {
151    stream: S,
152    remaining: usize,
153}
154
155impl<S: Stream<Item = Paper> + Unpin> TakeStream<S> {
156    /// Create a new take stream
157    pub fn new(stream: S, limit: usize) -> Self {
158        Self {
159            stream,
160            remaining: limit,
161        }
162    }
163}
164
165impl<S: Stream<Item = Paper> + Unpin> Stream for TakeStream<S> {
166    type Item = Paper;
167
168    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
169        if self.remaining == 0 {
170            return Poll::Ready(None);
171        }
172
173        match Pin::new(&mut self.stream).poll_next(cx) {
174            Poll::Ready(Some(item)) => {
175                self.remaining -= 1;
176                Poll::Ready(Some(item))
177            }
178            Poll::Ready(None) => Poll::Ready(None),
179            Poll::Pending => Poll::Pending,
180        }
181    }
182}
183
184/// Stream that skips items
185#[derive(Debug)]
186pub struct SkipStream<S: Stream<Item = Paper>> {
187    stream: S,
188    to_skip: usize,
189}
190
191impl<S: Stream<Item = Paper>> SkipStream<S> {
192    /// Create a new skip stream
193    pub fn new(stream: S, n: usize) -> Self {
194        Self { stream, to_skip: n }
195    }
196}
197
198impl<S: Stream<Item = Paper> + Unpin> Stream for SkipStream<S> {
199    type Item = Paper;
200
201    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
202        loop {
203            match Pin::new(&mut self.stream).poll_next(cx) {
204                Poll::Ready(Some(item)) => {
205                    if self.to_skip > 0 {
206                        self.to_skip -= 1;
207                        continue;
208                    }
209                    return Poll::Ready(Some(item));
210                }
211                Poll::Ready(None) => return Poll::Ready(None),
212                Poll::Pending => return Poll::Pending,
213            }
214        }
215    }
216}
217
218/// Stream filter for year range
219#[derive(Debug)]
220pub struct FilterByYearStream<S: Stream<Item = Paper>> {
221    stream: S,
222    min_year: Option<i32>,
223    max_year: Option<i32>,
224}
225
226impl<S: Stream<Item = Paper>> FilterByYearStream<S> {
227    /// Create a new year filter stream
228    pub fn new(stream: S, min_year: Option<i32>, max_year: Option<i32>) -> Self {
229        Self {
230            stream,
231            min_year,
232            max_year,
233        }
234    }
235}
236
237impl<S: Stream<Item = Paper> + Unpin> Stream for FilterByYearStream<S> {
238    type Item = Paper;
239
240    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
241        let this = self.get_mut();
242        loop {
243            match Pin::new(&mut this.stream).poll_next(cx) {
244                Poll::Ready(Some(paper)) => {
245                    // Try to extract year from published_date (format: "YYYY-MM-DD" or "YYYY")
246                    if let Some(year) = extract_year(&paper.published_date) {
247                        if let Some(min) = this.min_year {
248                            if year < min {
249                                continue;
250                            }
251                        }
252                        if let Some(max) = this.max_year {
253                            if year > max {
254                                continue;
255                            }
256                        }
257                    }
258                    return Poll::Ready(Some(paper));
259                }
260                Poll::Ready(None) => return Poll::Ready(None),
261                Poll::Pending => return Poll::Pending,
262            }
263        }
264    }
265}
266
/// Extract a four-digit publication year from an optional `published_date` string.
///
/// The first run of exactly four ASCII digits is taken as the year. This covers `"YYYY"`,
/// `"YYYY-MM-DD"`, `"YYYY/MM/DD"`, timestamps like `"YYYY-MM-DDThh:mm:ssZ"`, and
/// month- or day-first forms such as `"MM/DD/YYYY"` or `"15 May 2023"`. The previous
/// "first token" approach read `"05/15/2023"` as year 5.
///
/// Returns `None` when the date is absent or contains no four-digit run.
fn extract_year(published_date: &Option<String>) -> Option<i32> {
    let date = published_date.as_deref()?;
    // Splitting on every non-digit character leaves only the digit runs (plus empty pieces).
    date.split(|c: char| !c.is_ascii_digit())
        .find(|run| run.len() == 4)
        .and_then(|run| run.parse().ok())
}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{Paper, SearchResponse, SourceType};
    use crate::sources::mock::MockSource;
    use futures_util::StreamExt;

    /// Build a fixture paper whose URL is derived from its id.
    fn make_paper(paper_id: &str, title: &str, source_type: SourceType) -> Paper {
        let url = format!("http://example.com/{}", paper_id);
        Paper::new(paper_id.to_string(), title.to_string(), url, source_type)
    }

    /// Run `paper_stream` over `mock` (query "test", page size 10) to completion.
    async fn drain_stream(mock: MockSource) -> Vec<Paper> {
        paper_stream(mock, SearchQuery::new("test"), 10)
            .collect::<Vec<Paper>>()
            .await
    }

    #[tokio::test]
    async fn test_paper_stream_basic() {
        let mock = MockSource::new();
        let fixtures: Vec<Paper> = ["1", "2", "3"]
            .iter()
            .map(|id| make_paper(id, &format!("Paper {}", id), SourceType::Arxiv))
            .collect();
        mock.set_search_response(SearchResponse::new(fixtures, "Mock Source", "test"));

        let papers = drain_stream(mock).await;

        assert_eq!(papers.len(), 3);
        for (paper, expected_id) in papers.iter().zip(["1", "2", "3"]) {
            assert_eq!(paper.paper_id, expected_id);
        }
    }

    #[tokio::test]
    async fn test_paper_stream_empty() {
        let mock = MockSource::new();
        mock.set_search_response(SearchResponse::new(Vec::new(), "Mock Source", "test"));

        assert!(drain_stream(mock).await.is_empty());
    }

    #[test]
    fn test_extract_year() {
        let cases = [
            (Some("2023-05-15"), Some(2023)),
            (Some("2023"), Some(2023)),
            (Some("2023/05/15"), Some(2023)),
            (None, None),
            (Some("invalid"), None),
        ];
        for (input, expected) in cases {
            let date = input.map(str::to_string);
            assert_eq!(extract_year(&date), expected, "input: {:?}", input);
        }
    }
}