1use std::future::Future;
4use std::path::PathBuf;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use futures::stream::{self, StreamExt};
9use rand::SeedableRng;
10use rand::seq::SliceRandom;
11use rand_chacha::ChaCha8Rng;
12use serde::{Deserialize, Serialize};
13use url::Url;
14
15use crate::extractor::options::{SampleStrategy, TablesMode};
16use crate::extractor::output::OutputPaths;
17use crate::extractor::pipeline::ExtractorError;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct TableTransform {
21 pub ordinal: usize,
22 pub mode: String,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub path: Option<PathBuf>,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub kept_rows: Option<usize>,
27 #[serde(skip_serializing_if = "Option::is_none")]
28 pub truncated_rows: Option<usize>,
29 #[serde(skip_serializing_if = "Option::is_none")]
33 pub summary_md: Option<String>,
34 #[serde(skip_serializing_if = "Option::is_none")]
41 pub fallback_reason: Option<String>,
42 #[serde(skip_serializing_if = "Option::is_none")]
50 pub fallback_from: Option<String>,
51}
52
/// Extra detail a summarize hook can return next to a successful summary to signal
/// that the summary was produced through a fallback.
///
/// `apply_with_summarizer` copies both fields into the table's `TableTransform`
/// (`fallback_from` and `fallback_reason`).
#[derive(Clone, Debug)]
pub struct FallbackInfo {
    /// What was fallen back from — presumably the originally requested backend or
    /// tier; confirm against the hook implementations.
    pub from: String,
    /// Why the fallback happened.
    pub reason: String,
}
62
63pub type TableSummarizeHook = Arc<
69 dyn for<'a> Fn(
70 &'a str,
71 ) -> Pin<
72 Box<dyn Future<Output = Result<(String, Option<FallbackInfo>), String>> + Send + 'a>,
73 > + Send
74 + Sync,
75>;
76
/// Item reported by `iterate_tables` while it walks a markdown document.
enum TableEvent<'md> {
    /// A line that is not part of any table, borrowed from the source text.
    Line(&'md str),
    /// A complete table: its raw rows (header first) and its zero-based ordinal.
    Table(Vec<String>, usize),
}
85
86fn iterate_tables<F>(markdown: &str, mut sink: F) -> Result<(), ExtractorError>
96where
97 F: FnMut(TableEvent<'_>) -> Result<(), ExtractorError>,
98{
99 let mut ordinal: usize = 0;
100 let mut iter = markdown.lines().peekable();
101 while let Some(line) = iter.next() {
102 if is_pipe_table_start(line, iter.peek().copied()) {
103 let mut rows: Vec<String> = vec![line.to_string()];
104 while let Some(next) = iter.peek().copied() {
105 if next.trim_start().starts_with('|') {
106 rows.push(next.to_string());
107 iter.next();
108 } else {
109 break;
110 }
111 }
112 sink(TableEvent::Table(rows, ordinal))?;
113 ordinal += 1;
114 } else {
115 sink(TableEvent::Line(line))?;
116 }
117 }
118 Ok(())
119}
120
121pub fn apply(
130 markdown: &str,
131 mode: &TablesMode,
132 output_paths: &OutputPaths,
133 base_url: &Url,
134) -> Result<(String, Vec<TableTransform>), ExtractorError> {
135 let mut out = String::with_capacity(markdown.len());
136 let mut records = Vec::new();
137
138 iterate_tables(markdown, |ev| {
139 match ev {
140 TableEvent::Line(line) => {
141 out.push_str(line);
142 out.push('\n');
143 }
144 TableEvent::Table(rows, ordinal) => {
145 let (replacement, record) =
146 transform_table(rows, ordinal, mode, output_paths, base_url)?;
147 out.push_str(&replacement);
148 out.push('\n');
149 if let Some(r) = record {
150 records.push(r);
151 }
152 }
153 }
154 Ok(())
155 })?;
156
157 if !markdown.ends_with('\n') && out.ends_with('\n') {
158 out.pop();
159 }
160 Ok((out, records))
161}
162
163pub async fn apply_with_summarizer(
176 markdown: &str,
177 mode: &TablesMode,
178 output_paths: &OutputPaths,
179 base_url: &Url,
180 hook: Option<&TableSummarizeHook>,
181) -> Result<(String, Vec<TableTransform>), ExtractorError> {
182 if !matches!(mode, TablesMode::Summarize) {
183 return apply(markdown, mode, output_paths, base_url);
184 }
185 let Some(hook) = hook else {
186 return Err(ExtractorError::Metadata(
187 "internal: apply_with_summarizer requires a hook for TablesMode::Summarize".into(),
188 ));
189 };
190
191 enum OwnedEvent {
196 Line(String),
197 Table(Vec<String>, usize, usize),
201 }
202 let mut events: Vec<OwnedEvent> = Vec::new();
203 let mut tables: Vec<(Vec<String>, usize)> = Vec::new();
204 iterate_tables(markdown, |ev| {
205 match ev {
206 TableEvent::Line(line) => events.push(OwnedEvent::Line(line.to_string())),
207 TableEvent::Table(rows, ordinal) => {
208 let idx = tables.len();
209 tables.push((rows.clone(), ordinal));
210 events.push(OwnedEvent::Table(rows, ordinal, idx));
211 }
212 }
213 Ok(())
214 })?;
215
216 let hook_results: Vec<Result<(String, Option<FallbackInfo>), String>> = stream::iter(tables)
224 .map(|(rows, _ordinal)| async move {
225 let table_text = rows.join("\n");
226 hook(&table_text).await
227 })
228 .buffered(4)
229 .collect()
230 .await;
231
232 let mut out = String::with_capacity(markdown.len());
235 let mut records = Vec::new();
236 for ev in events {
237 match ev {
238 OwnedEvent::Line(line) => {
239 out.push_str(&line);
240 out.push('\n');
241 }
242 OwnedEvent::Table(rows, ordinal, idx) => {
243 let table_text = rows.join("\n");
244 match &hook_results[idx] {
245 Ok((summary, fallback)) => {
246 out.push_str(summary);
247 out.push('\n');
248 records.push(TableTransform {
249 ordinal,
250 mode: "summarize".into(),
251 path: None,
252 kept_rows: None,
253 truncated_rows: None,
254 summary_md: Some(summary.clone()),
255 fallback_reason: fallback.as_ref().map(|f| f.reason.clone()),
256 fallback_from: fallback.as_ref().map(|f| f.from.clone()),
257 });
258 }
259 Err(reason) => {
260 out.push_str(&table_text);
261 out.push('\n');
262 records.push(TableTransform {
263 ordinal,
264 mode: "summarize".into(),
265 path: None,
266 kept_rows: None,
267 truncated_rows: None,
268 summary_md: None,
269 fallback_reason: Some(reason.clone()),
270 fallback_from: None,
271 });
272 }
273 }
274 }
275 }
276 }
277 if !markdown.ends_with('\n') && out.ends_with('\n') {
278 out.pop();
279 }
280 Ok((out, records))
281}
282
/// Reports whether `line` is the header row of a pipe table.
///
/// That is the case when `line` starts with `|` (after leading whitespace) and `next`
/// is a delimiter row: it starts with `|`, consists only of `|`, `-`, `:` and spaces,
/// and contains at least one `-`. Requiring the hyphen keeps a row of empty cells such
/// as `|   |   |` from being mistaken for a delimiter row.
///
/// Returns `false` when there is no following line.
fn is_pipe_table_start(line: &str, next: Option<&str>) -> bool {
    if !line.trim_start().starts_with('|') {
        return false;
    }
    let Some(n) = next else {
        return false;
    };
    let nt = n.trim_start();
    nt.starts_with('|')
        && nt.contains('-')
        && nt.chars().all(|c| matches!(c, '|' | '-' | ':' | ' '))
}
294
295fn transform_table(
296 rows: Vec<String>,
297 ordinal: usize,
298 mode: &TablesMode,
299 output_paths: &OutputPaths,
300 base_url: &Url,
301) -> Result<(String, Option<TableTransform>), ExtractorError> {
302 match mode {
303 TablesMode::Embed => Ok((
304 rows.join("\n"),
305 Some(TableTransform {
306 ordinal,
307 mode: "embed".into(),
308 path: None,
309 kept_rows: None,
310 truncated_rows: None,
311 summary_md: None,
312 fallback_reason: None,
313 fallback_from: None,
314 }),
315 )),
316 TablesMode::Drop => Ok((
317 format!("_Table {ordinal} omitted_"),
318 Some(TableTransform {
319 ordinal,
320 mode: "drop".into(),
321 path: None,
322 kept_rows: None,
323 truncated_rows: None,
324 summary_md: None,
325 fallback_reason: None,
326 fallback_from: None,
327 }),
328 )),
329 TablesMode::Sample(strategy) => {
330 if rows.len() < 3 {
332 return Ok((rows.join("\n"), None));
333 }
334 let header = &rows[0];
335 let sep = &rows[1];
336 let data: Vec<&String> = rows[2..].iter().collect();
337 let (kept, truncated) = sample_rows(&data, strategy);
338 let mut out = vec![header.clone(), sep.clone()];
339 for r in &kept {
340 out.push((*r).clone());
341 }
342 if truncated > 0 {
343 out.push(format!("_… {truncated} rows truncated …_"));
344 }
345 Ok((
346 out.join("\n"),
347 Some(TableTransform {
348 ordinal,
349 mode: "sample".into(),
350 path: None,
351 kept_rows: Some(kept.len()),
352 truncated_rows: Some(truncated),
353 summary_md: None,
354 fallback_reason: None,
355 fallback_from: None,
356 }),
357 ))
358 }
359 TablesMode::CsvFile => {
360 let path = output_paths.table_path(base_url, ordinal);
361 if let Some(parent) = path.parent() {
362 std::fs::create_dir_all(parent).map_err(|source| ExtractorError::TableWrite {
363 ordinal,
364 path: parent.display().to_string(),
365 source,
366 })?;
367 }
368 write_csv(&path, &rows, ordinal)?;
369 let abs = path.canonicalize().unwrap_or_else(|_| path.clone());
370 Ok((
371 format!("_Table {ordinal} saved to {}_", abs.display()),
372 Some(TableTransform {
373 ordinal,
374 mode: "csv_file".into(),
375 path: Some(abs),
376 kept_rows: None,
377 truncated_rows: None,
378 summary_md: None,
379 fallback_reason: None,
380 fallback_from: None,
381 }),
382 ))
383 }
384 TablesMode::Summarize => Err(ExtractorError::Metadata(
385 "internal: TablesMode::Summarize must go through apply_with_summarizer".into(),
386 )),
387 }
388}
389
390fn sample_rows<'a>(data: &[&'a String], strategy: &SampleStrategy) -> (Vec<&'a String>, usize) {
391 let total = data.len();
392 match strategy {
393 SampleStrategy::HeadTail { head, tail } => {
394 if total <= head + tail {
395 return (data.to_vec(), 0);
396 }
397 let mut kept: Vec<&String> = data.iter().take(*head).copied().collect();
398 kept.extend(data.iter().rev().take(*tail).rev().copied());
399 let truncated = total - kept.len();
400 (kept, truncated)
401 }
402 SampleStrategy::RandomSeed { rows, seed } => {
403 if total <= *rows {
404 return (data.to_vec(), 0);
405 }
406 let mut rng = ChaCha8Rng::seed_from_u64(*seed);
407 let mut indices: Vec<usize> = (0..total).collect();
408 indices.shuffle(&mut rng);
409 indices.truncate(*rows);
410 indices.sort();
411 let kept: Vec<&String> = indices.iter().map(|i| data[*i]).collect();
412 let truncated = total - kept.len();
413 (kept, truncated)
414 }
415 }
416}
417
/// Splits one pipe-table row into trimmed cell strings.
///
/// Surrounding whitespace is removed, then at most one leading and one trailing `|`
/// are stripped before splitting on `|`. Stripping only a single pipe per side keeps
/// empty first and last cells (`|| b |` yields `["", "b"]`) instead of dropping them
/// and shifting the remaining columns.
///
/// Escaped pipes (`\|`) are not treated specially and still split the cell.
fn parse_pipe_row(line: &str) -> Vec<String> {
    let row = line.trim();
    let row = row.strip_prefix('|').unwrap_or(row);
    let row = row.strip_suffix('|').unwrap_or(row);
    row.split('|').map(|cell| cell.trim().to_string()).collect()
}
423
424fn write_csv(
425 path: &std::path::Path,
426 rows: &[String],
427 ordinal: usize,
428) -> Result<(), ExtractorError> {
429 let file = std::fs::File::create(path).map_err(|source| ExtractorError::TableWrite {
430 ordinal,
431 path: path.display().to_string(),
432 source,
433 })?;
434 let mut wtr = csv::Writer::from_writer(file);
435 for (i, row) in rows.iter().enumerate() {
436 if i == 1 {
437 continue; }
439 let cells = parse_pipe_row(row);
440 wtr.write_record(&cells)
441 .map_err(|e| ExtractorError::TableWrite {
442 ordinal,
443 path: path.display().to_string(),
444 source: std::io::Error::other(e.to_string()),
445 })?;
446 }
447 wtr.flush().map_err(|source| ExtractorError::TableWrite {
448 ordinal,
449 path: path.display().to_string(),
450 source,
451 })?;
452 Ok(())
453}
454
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extractor::OUTPUT_DIR_TEST_MUTEX as TEST_MUTEX;

    /// Creates a fresh temporary directory, points `ROVER_OUTPUT_DIR` at it and
    /// resolves `OutputPaths` (presumably from that variable — confirm in
    /// `OutputPaths::resolve`).
    ///
    /// Mutates process-wide environment state; every test in this module takes
    /// `TEST_MUTEX` before calling it.
    fn paths() -> OutputPaths {
        let tmp = tempfile::tempdir().unwrap();
        let dir = tmp.path().to_path_buf();
        // Leak the handle so the directory is not removed when `tmp` would be dropped.
        std::mem::forget(tmp);
        // SAFETY: NOTE(review): relies on callers holding `TEST_MUTEX` so that no other
        // test in this module touches the environment concurrently — confirm that all
        // other readers/writers of `ROVER_OUTPUT_DIR` in the crate use the same lock.
        unsafe { std::env::set_var("ROVER_OUTPUT_DIR", &dir) };
        OutputPaths::resolve(None).unwrap()
    }

    /// Base URL shared by all tests.
    fn url() -> Url {
        Url::parse("https://example.com/").unwrap()
    }

    // Header, delimiter row and three data rows; no trailing newline.
    const TABLE_3ROWS: &str = "| A | B |\n|---|---|\n| 1 | 2 |\n| 3 | 4 |\n| 5 | 6 |";

    // Embed mode keeps the table text and still emits a record.
    #[test]
    fn embed_mode_passes_through() {
        let _g = TEST_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
        let (out, recs) = apply(TABLE_3ROWS, &TablesMode::Embed, &paths(), &url()).unwrap();
        assert!(out.contains("| 1 | 2 |"));
        assert_eq!(recs.len(), 1);
        assert_eq!(recs[0].mode, "embed");
    }

    // Drop mode replaces the table with the "omitted" marker.
    #[test]
    fn drop_mode_replaces_with_marker() {
        let _g = TEST_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
        let (out, recs) = apply(TABLE_3ROWS, &TablesMode::Drop, &paths(), &url()).unwrap();
        assert!(out.contains("_Table 0 omitted_"));
        assert!(!out.contains("| 1 | 2 |"));
        assert_eq!(recs[0].mode, "drop");
    }

    // With head = 1 and tail = 1 the first and last data rows survive and the middle
    // row is counted as truncated.
    #[test]
    fn sample_head_tail_keeps_head_plus_tail() {
        let _g = TEST_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
        let strategy = SampleStrategy::HeadTail { head: 1, tail: 1 };
        let (out, recs) =
            apply(TABLE_3ROWS, &TablesMode::Sample(strategy), &paths(), &url()).unwrap();
        assert!(out.contains("| 1 | 2 |"));
        assert!(out.contains("| 5 | 6 |"));
        assert!(out.contains("_… 1 rows truncated …_"));
        assert_eq!(recs[0].kept_rows, Some(2));
        assert_eq!(recs[0].truncated_rows, Some(1));
    }

    // The same seed must produce the same sampled output on repeated runs.
    #[test]
    fn sample_random_seed_is_deterministic() {
        let _g = TEST_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
        let strat = SampleStrategy::RandomSeed { rows: 2, seed: 42 };
        let (out_a, _) = apply(
            TABLE_3ROWS,
            &TablesMode::Sample(strat.clone()),
            &paths(),
            &url(),
        )
        .unwrap();
        let (out_b, _) = apply(TABLE_3ROWS, &TablesMode::Sample(strat), &paths(), &url()).unwrap();
        assert_eq!(out_a, out_b);
    }

    // CsvFile mode writes header and data rows to the recorded path and leaves a
    // marker in the markdown.
    #[test]
    fn csv_file_writes_table_to_disk_and_replaces_markdown() {
        let _g = TEST_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
        let (out, recs) = apply(TABLE_3ROWS, &TablesMode::CsvFile, &paths(), &url()).unwrap();
        assert!(out.contains("_Table 0 saved to "));
        let p = recs[0].path.as_ref().unwrap();
        let csv = std::fs::read_to_string(p).unwrap();
        assert!(csv.contains("A,B"));
        assert!(csv.contains("1,2"));
        assert!(csv.contains("5,6"));
    }

    // The hook runs once per table and each table is replaced by its summary.
    #[tokio::test]
    async fn summarize_mode_invokes_hook_per_table() {
        let paths = {
            // The guard lives only inside this block, so it is released before any
            // `.await` below.
            let _g = TEST_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
            paths()
        };
        let md = "Intro.\n\n| A | B |\n|---|---|\n| 1 | 2 |\n\nMiddle.\n\n| X | Y |\n|---|---|\n| 9 | 8 |\n";
        let counter = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let counter_clone = counter.clone();
        let hook: TableSummarizeHook = std::sync::Arc::new(move |_text: &str| {
            let counter_clone = counter_clone.clone();
            Box::pin(async move {
                counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                Ok::<(String, Option<FallbackInfo>), String>(("(summary)".to_string(), None))
            })
        });
        let (out, recs) =
            apply_with_summarizer(md, &TablesMode::Summarize, &paths, &url(), Some(&hook))
                .await
                .unwrap();
        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 2);
        assert_eq!(recs.len(), 2);
        assert!(recs.iter().all(|r| r.mode == "summarize"));
        assert!(
            recs.iter()
                .all(|r| r.summary_md.as_deref() == Some("(summary)"))
        );
        assert!(recs.iter().all(|r| r.fallback_reason.is_none()));
        assert!(recs.iter().all(|r| r.fallback_from.is_none()));
        assert!(out.contains("(summary)"));
        assert!(!out.contains("| 1 | 2 |"));
        assert!(!out.contains("| 9 | 8 |"));
    }

    // A failing hook keeps the original table and records the error as the
    // fallback reason.
    #[tokio::test]
    async fn summarize_mode_records_fallback_when_hook_fails() {
        let paths = {
            // Guard is dropped at the end of this block, before awaiting.
            let _g = TEST_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
            paths()
        };
        let hook: TableSummarizeHook = std::sync::Arc::new(|_text: &str| {
            Box::pin(async move {
                Err::<(String, Option<FallbackInfo>), String>("auth_failed".to_string())
            })
        });
        let (out, recs) = apply_with_summarizer(
            TABLE_3ROWS,
            &TablesMode::Summarize,
            &paths,
            &url(),
            Some(&hook),
        )
        .await
        .unwrap();
        assert_eq!(recs.len(), 1);
        assert_eq!(recs[0].mode, "summarize");
        assert!(recs[0].summary_md.is_none());
        assert_eq!(recs[0].fallback_reason.as_deref(), Some("auth_failed"));
        assert!(recs[0].fallback_from.is_none());
        assert!(out.contains("| 1 | 2 |"));
        assert!(out.contains("| 5 | 6 |"));
    }

    // A successful hook that reports `FallbackInfo` yields both the summary and the
    // fallback fields on the record.
    #[tokio::test]
    async fn summarize_mode_records_internal_fallback_when_hook_returns_fallback_info() {
        let paths = {
            // Guard is dropped at the end of this block, before awaiting.
            let _g = TEST_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
            paths()
        };
        let hook: TableSummarizeHook = std::sync::Arc::new(|_text: &str| {
            Box::pin(async move {
                Ok::<(String, Option<FallbackInfo>), String>((
                    "(extractive summary)".to_string(),
                    Some(FallbackInfo {
                        from: "fast".to_string(),
                        reason: "backend_unavailable".to_string(),
                    }),
                ))
            })
        });
        let (out, recs) = apply_with_summarizer(
            TABLE_3ROWS,
            &TablesMode::Summarize,
            &paths,
            &url(),
            Some(&hook),
        )
        .await
        .unwrap();
        assert_eq!(recs.len(), 1);
        assert_eq!(recs[0].summary_md.as_deref(), Some("(extractive summary)"));
        assert_eq!(recs[0].fallback_from.as_deref(), Some("fast"));
        assert_eq!(
            recs[0].fallback_reason.as_deref(),
            Some("backend_unavailable")
        );
        assert!(out.contains("(extractive summary)"));
        assert!(!out.contains("| 1 | 2 |"));
    }

    // Markdown without tables comes back byte-for-byte, trailing newline included.
    #[test]
    fn non_table_content_passes_through_unchanged() {
        let _g = TEST_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
        let md = "Just some text\n\nNo tables here.\n";
        let (out, recs) = apply(md, &TablesMode::Drop, &paths(), &url()).unwrap();
        assert_eq!(out, md);
        assert!(recs.is_empty());
    }
}
643}