1use std::collections::HashMap;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use chrono::{DateTime, Utc};
6use futures::{Stream, StreamExt};
7use tokio::sync::mpsc;
8use tracing::{debug, error, trace, warn};
9
10use wme_models::metadata::ArticleUpdate;
11use wme_models::RequestParams;
12
13use crate::{ClientError, Result, WmeClient};
14
15#[derive(Debug, Clone, Default)]
17pub struct RealtimeConnectOptions {
18 pub since: Option<DateTime<Utc>>,
20 pub since_per_partition: Option<HashMap<String, DateTime<Utc>>>,
22 pub partitions: Option<Vec<u32>>,
24 pub offsets: Option<HashMap<String, u64>>,
26}
27
28impl RealtimeConnectOptions {
29 pub fn since(timestamp: DateTime<Utc>) -> Self {
31 Self {
32 since: Some(timestamp),
33 since_per_partition: None,
34 partitions: None,
35 offsets: None,
36 }
37 }
38
39 pub fn since_per_partition(partitions: HashMap<String, DateTime<Utc>>) -> Self {
41 Self {
42 since: None,
43 since_per_partition: Some(partitions),
44 partitions: None,
45 offsets: None,
46 }
47 }
48
49 pub fn with_partitions(mut self, partitions: Vec<u32>) -> Self {
51 self.partitions = Some(partitions);
52 self
53 }
54
55 pub fn with_offsets(mut self, offsets: HashMap<String, u64>) -> Self {
57 self.offsets = Some(offsets);
58 self
59 }
60
61 fn to_request_body(&self, filters: Option<&wme_models::RequestParams>) -> serde_json::Value {
63 let mut body = serde_json::Map::new();
64
65 if let Some(since) = self.since {
66 body.insert("since".to_string(), serde_json::json!(since.to_rfc3339()));
67 }
68
69 if let Some(ref per_partition) = self.since_per_partition {
70 let map: serde_json::Map<String, serde_json::Value> = per_partition
71 .iter()
72 .map(|(k, v)| (k.clone(), serde_json::json!(v.to_rfc3339())))
73 .collect();
74 body.insert(
75 "since_per_partition".to_string(),
76 serde_json::Value::Object(map),
77 );
78 }
79
80 if let Some(ref partitions) = self.partitions {
81 body.insert("parts".to_string(), serde_json::json!(partitions));
82 }
83
84 if let Some(ref offsets) = self.offsets {
85 let map: serde_json::Map<String, serde_json::Value> = offsets
86 .iter()
87 .map(|(k, v)| (k.clone(), serde_json::json!(v)))
88 .collect();
89 body.insert("offsets".to_string(), serde_json::Value::Object(map));
90 }
91
92 if let Some(req) = filters {
93 if let Some(ref fields) = req.fields {
94 body.insert("fields".to_string(), serde_json::json!(fields));
95 }
96 if let Some(ref filters) = req.filters {
97 body.insert("filters".to_string(), serde_json::json!(filters));
98 }
99 }
100
101 serde_json::Value::Object(body)
102 }
103}
104
105pub struct RealtimeClient<'a> {
107 client: &'a WmeClient,
108}
109
110impl<'a> RealtimeClient<'a> {
111 pub(crate) fn new(client: &'a WmeClient) -> Self {
113 Self { client }
114 }
115
116 pub async fn connect(
149 &self,
150 options: &RealtimeConnectOptions,
151 filters: Option<&RequestParams>,
152 ) -> Result<Box<dyn Stream<Item = Result<ArticleUpdate>> + Send + Unpin>> {
153 let url = format!("{}/v2/articles", self.client.base_urls().realtime);
154 let body = options.to_request_body(filters).to_string();
155
156 let client = reqwest::Client::new();
158 let mut request = client
159 .post(&url)
160 .header("Accept", "text/event-stream")
161 .header("Content-Type", "application/json");
162
163 if let Some(headers) = self.client.auth_headers().await? {
165 if let Some(auth) = headers.get("Authorization") {
166 request = request.header("Authorization", auth);
167 }
168 }
169
170 let response = request.body(body).send().await.map_err(ClientError::from)?;
171
172 if !response.status().is_success() {
173 return Err(ClientError::Http(format!(
174 "Failed to connect to realtime stream: {}",
175 response.status()
176 )));
177 }
178
179 let (tx, rx) = mpsc::channel(100);
180
181 tokio::spawn(async move {
183 let mut stream = response.bytes_stream();
184 let mut buffer = String::new();
185 let mut current_event = String::new();
186
187 while let Some(chunk) = stream.next().await {
188 match chunk {
189 Ok(bytes) => {
190 buffer.push_str(&String::from_utf8_lossy(&bytes));
192
193 while let Some(pos) = buffer.find('\n') {
195 let line = buffer[..pos].to_string();
196 buffer = buffer[pos + 1..].to_string();
197
198 let line = line.trim_end_matches('\r');
199
200 if line.is_empty() {
201 if !current_event.is_empty() {
203 trace!(event_data = %current_event, "Processing SSE event");
204
205 match serde_json::from_str::<ArticleUpdate>(¤t_event) {
206 Ok(update) => {
207 if tx.send(Ok(update)).await.is_err() {
208 debug!(
209 "SSE stream receiver dropped, closing stream"
210 );
211 return;
212 }
213 }
214 Err(e) => {
215 warn!(error = %e, data = %current_event, "Failed to parse SSE event data");
216 if tx
217 .send(Err(ClientError::JsonParse(e.to_string())))
218 .await
219 .is_err()
220 {
221 return;
222 }
223 }
224 }
225 current_event.clear();
226 }
227 } else if let Some(data) = line.strip_prefix("data: ") {
228 current_event.push_str(data);
230 } else if line.starts_with("id: ") {
231 trace!(event_id = %line, "Received SSE id");
233 } else if line.starts_with("event: ") {
234 trace!(event_type = %line, "Received SSE event type");
236 } else if line.starts_with(":") {
237 trace!(comment = %line, "Received SSE comment");
239 }
240 }
241 }
242 Err(e) => {
243 error!(error = %e, "SSE stream error");
244 let _ = tx.send(Err(ClientError::Stream(e.to_string()))).await;
245 return;
246 }
247 }
248 }
249
250 if !buffer.is_empty() {
252 let line = buffer.trim_end_matches('\r');
253 if let Some(data) = line.strip_prefix("data: ") {
254 current_event.push_str(data);
255 }
256 }
257
258 if !current_event.is_empty() {
260 match serde_json::from_str::<ArticleUpdate>(¤t_event) {
261 Ok(update) => {
262 let _ = tx.send(Ok(update)).await;
263 }
264 Err(e) => {
265 let _ = tx.send(Err(ClientError::JsonParse(e.to_string()))).await;
266 }
267 }
268 }
269
270 debug!("SSE stream ended");
271 });
272
273 let stream = SseStream { receiver: rx };
275 Ok(Box::new(stream))
276 }
277
278 pub async fn list_batches(
280 &self,
281 date: &str,
282 hour: &str,
283 ) -> Result<Vec<wme_models::metadata::RealtimeBatchInfo>> {
284 self.list_batches_with_params(date, hour, None).await
285 }
286
287 pub async fn list_batches_with_params(
289 &self,
290 date: &str,
291 hour: &str,
292 params: Option<&RequestParams>,
293 ) -> Result<Vec<wme_models::metadata::RealtimeBatchInfo>> {
294 let url = format!(
295 "{}/v2/batches/{}/{}",
296 self.client.base_urls().api,
297 date,
298 hour
299 );
300 let headers = self.client.auth_headers().await?;
301
302 let response = if let Some(p) = params {
303 let body = serde_json::to_string(p)?;
304 self.client
305 .transport()
306 .request(reqwest::Method::POST, &url, headers, Some(body))
307 .await?
308 } else {
309 self.client
310 .transport()
311 .request(reqwest::Method::GET, &url, headers, None)
312 .await?
313 };
314
315 if !response.status().is_success() {
316 return Err(ClientError::Http(format!(
317 "Failed to list batches: {}",
318 response.status()
319 )));
320 }
321
322 let batches = response.json().await.map_err(ClientError::from)?;
323 Ok(batches)
324 }
325
326 pub async fn get_batch_info(
328 &self,
329 date: &str,
330 hour: &str,
331 identifier: &str,
332 ) -> Result<wme_models::metadata::RealtimeBatchInfo> {
333 self.get_batch_info_with_params(date, hour, identifier, None)
334 .await
335 }
336
337 pub async fn get_batch_info_with_params(
339 &self,
340 date: &str,
341 hour: &str,
342 identifier: &str,
343 params: Option<&RequestParams>,
344 ) -> Result<wme_models::metadata::RealtimeBatchInfo> {
345 let url = format!(
346 "{}/v2/batches/{}/{}/{}",
347 self.client.base_urls().api,
348 date,
349 hour,
350 identifier
351 );
352 let headers = self.client.auth_headers().await?;
353
354 let response = if let Some(p) = params {
355 let body = serde_json::to_string(p)?;
356 self.client
357 .transport()
358 .request(reqwest::Method::POST, &url, headers, Some(body))
359 .await?
360 } else {
361 self.client
362 .transport()
363 .request(reqwest::Method::GET, &url, headers, None)
364 .await?
365 };
366
367 if !response.status().is_success() {
368 return Err(ClientError::Http(format!(
369 "Failed to get batch info: {}",
370 response.status()
371 )));
372 }
373
374 let batch = response.json().await.map_err(ClientError::from)?;
375 Ok(batch)
376 }
377
378 pub async fn download_batch(
380 &self,
381 date: &str,
382 hour: &str,
383 identifier: &str,
384 range: Option<&str>,
385 ) -> Result<Box<dyn Stream<Item = Result<bytes::Bytes>> + Send + Unpin>> {
386 let url = format!(
387 "{}/v2/batches/{}/{}/{}/download",
388 self.client.base_urls().api,
389 date,
390 hour,
391 identifier
392 );
393
394 let mut headers = self.client.auth_headers().await?;
395
396 if let Some(range) = range {
397 headers = headers.or_else(|| Some(std::collections::HashMap::new()));
398 if let Some(ref mut h) = headers {
399 h.insert(reqwest::header::RANGE.to_string(), range.to_string());
400 }
401 }
402
403 let stream = self
404 .client
405 .transport()
406 .stream(reqwest::Method::GET, &url, headers)
407 .await?;
408
409 Ok(stream)
410 }
411
412 pub async fn stream_batch(
436 &self,
437 date: &str,
438 hour: &str,
439 identifier: &str,
440 ) -> Result<Box<dyn Stream<Item = Result<wme_models::Article>> + Send + Unpin>> {
441 use async_compression::tokio::bufread::GzipDecoder;
442 use tokio::io::{AsyncBufReadExt, BufReader as TokioBufReader};
443 use tokio_tar::Archive;
444
445 let byte_stream = self.download_batch(date, hour, identifier, None).await?;
447
448 let reader = tokio_util::io::StreamReader::new(
450 byte_stream.map(|result| result.map_err(std::io::Error::other)),
451 );
452
453 let gz_decoder = GzipDecoder::new(TokioBufReader::new(reader));
455
456 let mut archive = Archive::new(gz_decoder);
458
459 let (tx, rx) = mpsc::channel(100);
460
461 tokio::spawn(async move {
463 let mut entries = archive.entries().map_err(|e| {
464 error!(error = %e, "Failed to read tar archive entries");
465 ClientError::Io(format!("Failed to read tar entries: {}", e))
466 })?;
467
468 while let Some(entry) = entries.next().await {
469 match entry {
470 Ok(mut entry) => {
471 let mut lines = TokioBufReader::new(&mut entry).lines();
473
474 while let Ok(Some(line)) = lines.next_line().await {
475 if line.trim().is_empty() {
476 continue;
477 }
478
479 match serde_json::from_str::<wme_models::Article>(&line) {
480 Ok(article) => {
481 if tx.send(Ok(article)).await.is_err() {
482 debug!("Batch stream receiver dropped");
483 return Ok::<(), ClientError>(());
484 }
485 }
486 Err(e) => {
487 warn!(error = %e, line = %line, "Failed to parse NDJSON line");
488 if tx
489 .send(Err(ClientError::JsonParse(e.to_string())))
490 .await
491 .is_err()
492 {
493 return Ok::<(), ClientError>(());
494 }
495 }
496 }
497 }
498 }
499 Err(e) => {
500 error!(error = %e, "Error reading tar entry");
501 let _ = tx
502 .send(Err(ClientError::Io(format!("Tar entry error: {}", e))))
503 .await;
504 }
505 }
506 }
507
508 debug!("Batch stream completed");
509 Ok(())
510 });
511
512 let stream = BatchStream { receiver: rx };
514 Ok(Box::new(stream))
515 }
516
517 pub async fn head_batch_download(
519 &self,
520 date: &str,
521 hour: &str,
522 identifier: &str,
523 ) -> Result<reqwest::header::HeaderMap> {
524 let url = format!(
525 "{}/v2/batches/{}/{}/{}/download",
526 self.client.base_urls().api,
527 date,
528 hour,
529 identifier
530 );
531 let headers = self.client.auth_headers().await?;
532
533 let response = self
534 .client
535 .transport()
536 .request(reqwest::Method::HEAD, &url, headers, None)
537 .await?;
538
539 match response.status().as_u16() {
540 200 => Ok(response.headers().clone()),
541 status => Err(ClientError::Http(format!(
542 "Failed to get batch headers: {}",
543 status
544 ))),
545 }
546 }
547}
548
549struct SseStream {
551 receiver: mpsc::Receiver<Result<ArticleUpdate>>,
552}
553
554impl Stream for SseStream {
555 type Item = Result<ArticleUpdate>;
556
557 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
558 self.receiver.poll_recv(cx)
559 }
560}
561
562struct BatchStream {
564 receiver: mpsc::Receiver<Result<wme_models::Article>>,
565}
566
567impl Stream for BatchStream {
568 type Item = Result<wme_models::Article>;
569
570 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
571 self.receiver.poll_recv(cx)
572 }
573}