research_master/utils/
streaming.rs1use crate::models::{Paper, SearchQuery};
7use crate::sources::Source;
8use async_stream::stream;
9use futures_util::stream::{Stream, StreamExt};
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use tokio::sync::mpsc;
13use tokio::time::{sleep, Duration};
14use tracing::warn;
15
16pub fn paper_stream<T: Source + Clone + 'static>(
22 source: T,
23 query: SearchQuery,
24 page_size: usize,
25) -> impl Stream<Item = Paper> + Send {
26 stream! {
27 let rate_limit_delay = Duration::from_millis(200);
28 loop {
29 let mut page_query = query.clone();
31 page_query.max_results = page_size;
32
33 match source.search(&page_query).await {
34 Ok(response) => {
35 let papers = response.papers;
36 let count = papers.len();
37
38 if count == 0 {
39 break;
41 }
42
43 for paper in papers {
45 yield paper;
46 }
47
48 if rate_limit_delay > Duration::ZERO {
50 sleep(rate_limit_delay).await;
51 }
52 }
53 Err(e) => {
54 warn!("Error fetching papers: {}", e);
55 break;
56 }
57 }
58 }
59 }
60}
61
62pub fn filter_by_year<S: Stream<Item = Paper> + Send + 'static>(
64 stream: S,
65 min_year: Option<i32>,
66 max_year: Option<i32>,
67) -> FilterByYearStream<S> {
68 FilterByYearStream::new(stream, min_year, max_year)
69}
70
71pub async fn collect_papers<S: Stream<Item = Paper> + Send + Unpin>(mut stream: S) -> Vec<Paper> {
73 let mut papers = Vec::new();
74 while let Some(paper) = stream.next().await {
75 papers.push(paper);
76 }
77 papers
78}
79
80#[allow(dead_code)]
85pub struct ConcurrentPaperStream {
86 receiver: mpsc::Receiver<Paper>,
87 pending: usize,
88}
89
90impl ConcurrentPaperStream {
91 pub async fn from_sources<S: Source + Clone + 'static>(
96 sources: Vec<S>,
97 query: &SearchQuery,
98 max_concurrent: usize,
99 ) -> Self {
100 let (sender, receiver) = mpsc::channel(100);
101 let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(max_concurrent));
102 let sources_len = sources.len();
103
104 for source in sources {
105 let query = query.clone();
106 let sender = sender.clone();
107 let permit = semaphore.clone().acquire_owned().await.unwrap();
108 let source = source.clone();
109
110 tokio::spawn(async move {
111 match source.search(&query).await {
113 Ok(response) => {
114 for paper in response.papers {
115 if sender.send(paper).await.is_err() {
116 break; }
118 }
119 }
120 Err(e) => {
121 warn!("Source search failed: {}", e);
122 }
123 }
124 drop(permit);
125 });
126 }
127
128 drop(sender);
130
131 Self {
132 receiver,
133 pending: sources_len,
134 }
135 }
136
137 pub async fn next(&mut self) -> Option<Paper> {
139 self.receiver.recv().await
140 }
141
142 pub fn is_done(&self) -> bool {
144 self.receiver.is_closed()
145 }
146}
147
148#[derive(Debug)]
150pub struct TakeStream<S: Stream<Item = Paper>> {
151 stream: S,
152 remaining: usize,
153}
154
155impl<S: Stream<Item = Paper> + Unpin> TakeStream<S> {
156 pub fn new(stream: S, limit: usize) -> Self {
158 Self {
159 stream,
160 remaining: limit,
161 }
162 }
163}
164
165impl<S: Stream<Item = Paper> + Unpin> Stream for TakeStream<S> {
166 type Item = Paper;
167
168 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
169 if self.remaining == 0 {
170 return Poll::Ready(None);
171 }
172
173 match Pin::new(&mut self.stream).poll_next(cx) {
174 Poll::Ready(Some(item)) => {
175 self.remaining -= 1;
176 Poll::Ready(Some(item))
177 }
178 Poll::Ready(None) => Poll::Ready(None),
179 Poll::Pending => Poll::Pending,
180 }
181 }
182}
183
184#[derive(Debug)]
186pub struct SkipStream<S: Stream<Item = Paper>> {
187 stream: S,
188 to_skip: usize,
189}
190
191impl<S: Stream<Item = Paper>> SkipStream<S> {
192 pub fn new(stream: S, n: usize) -> Self {
194 Self { stream, to_skip: n }
195 }
196}
197
198impl<S: Stream<Item = Paper> + Unpin> Stream for SkipStream<S> {
199 type Item = Paper;
200
201 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
202 loop {
203 match Pin::new(&mut self.stream).poll_next(cx) {
204 Poll::Ready(Some(item)) => {
205 if self.to_skip > 0 {
206 self.to_skip -= 1;
207 continue;
208 }
209 return Poll::Ready(Some(item));
210 }
211 Poll::Ready(None) => return Poll::Ready(None),
212 Poll::Pending => return Poll::Pending,
213 }
214 }
215 }
216}
217
218#[derive(Debug)]
220pub struct FilterByYearStream<S: Stream<Item = Paper>> {
221 stream: S,
222 min_year: Option<i32>,
223 max_year: Option<i32>,
224}
225
226impl<S: Stream<Item = Paper>> FilterByYearStream<S> {
227 pub fn new(stream: S, min_year: Option<i32>, max_year: Option<i32>) -> Self {
229 Self {
230 stream,
231 min_year,
232 max_year,
233 }
234 }
235}
236
237impl<S: Stream<Item = Paper> + Unpin> Stream for FilterByYearStream<S> {
238 type Item = Paper;
239
240 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
241 let this = self.get_mut();
242 loop {
243 match Pin::new(&mut this.stream).poll_next(cx) {
244 Poll::Ready(Some(paper)) => {
245 if let Some(year) = extract_year(&paper.published_date) {
247 if let Some(min) = this.min_year {
248 if year < min {
249 continue;
250 }
251 }
252 if let Some(max) = this.max_year {
253 if year > max {
254 continue;
255 }
256 }
257 }
258 return Poll::Ready(Some(paper));
259 }
260 Poll::Ready(None) => return Poll::Ready(None),
261 Poll::Pending => return Poll::Pending,
262 }
263 }
264 }
265}
266
/// Extracts the year from a date string whose components are separated by
/// `-` or `/`.
///
/// The first component is used when it is a four-digit number (for example
/// `2023-05-15`, `2023/05/15`, or a bare `2023`). Otherwise the first later
/// component made of exactly four digits is taken, so year-last dates such as
/// `05/15/2023` report 2023 rather than the month. If no component has four
/// digits, the first component is parsed as-is. Whitespace around the whole
/// string and around each component is ignored.
///
/// Returns `None` when `published_date` is `None` or the chosen component is
/// not a valid `i32`.
fn extract_year(published_date: &Option<String>) -> Option<i32> {
    /// True when `s` consists of exactly four ASCII digits.
    fn is_four_digits(s: &str) -> bool {
        s.len() == 4 && s.bytes().all(|b| b.is_ascii_digit())
    }

    let date = published_date.as_deref()?.trim();
    let mut parts = date.split(['-', '/']).map(str::trim);

    // `split` always yields at least one (possibly empty) component.
    let first = parts.next()?;
    let candidate = if is_four_digits(first) {
        first
    } else {
        parts.find(|part| is_four_digits(*part)).unwrap_or(first)
    };

    candidate.parse::<i32>().ok()
}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{Paper, SearchResponse, SourceType};
    use crate::sources::mock::MockSource;
    use futures_util::StreamExt;

    /// Builds a minimal paper whose URL is derived from its id.
    fn make_paper(paper_id: &str, title: &str, source_type: SourceType) -> Paper {
        let url = format!("http://example.com/{}", paper_id);
        Paper::new(paper_id.to_string(), title.to_string(), url, source_type)
    }

    /// Runs `paper_stream` over `mock` (query "test", page size 10) and
    /// gathers everything the stream yields.
    async fn drain_paper_stream(mock: MockSource) -> Vec<Paper> {
        let mut stream = Box::pin(paper_stream(mock, SearchQuery::new("test"), 10));
        let mut seen = Vec::new();
        while let Some(paper) = stream.next().await {
            seen.push(paper);
        }
        seen
    }

    #[tokio::test]
    async fn test_paper_stream_basic() {
        let fixtures = vec![
            make_paper("1", "Paper 1", SourceType::Arxiv),
            make_paper("2", "Paper 2", SourceType::Arxiv),
            make_paper("3", "Paper 3", SourceType::Arxiv),
        ];
        let mock = MockSource::new();
        mock.set_search_response(SearchResponse::new(fixtures, "Mock Source", "test"));

        let papers = drain_paper_stream(mock).await;

        assert_eq!(papers.len(), 3);
        // Papers must come out in the order the source returned them.
        for (paper, expected_id) in papers.iter().zip(["1", "2", "3"]) {
            assert_eq!(paper.paper_id, expected_id);
        }
    }

    #[tokio::test]
    async fn test_paper_stream_empty() {
        let mock = MockSource::new();
        mock.set_search_response(SearchResponse::new(Vec::new(), "Mock Source", "test"));

        let papers = drain_paper_stream(mock).await;

        assert!(papers.is_empty());
    }

    #[test]
    fn test_extract_year() {
        let year_of = |raw: &str| extract_year(&Some(raw.to_string()));

        assert_eq!(year_of("2023-05-15"), Some(2023));
        assert_eq!(year_of("2023"), Some(2023));
        assert_eq!(year_of("2023/05/15"), Some(2023));
        assert_eq!(extract_year(&None), None);
        assert_eq!(year_of("invalid"), None);
    }
}