use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::{Path, PathBuf};
use flate2::read::GzDecoder;
use super::aggregator::AggregatedNgram;
use super::parser::{parse_ngram_line, NgramRecord, ParseError};
#[cfg(feature = "google-books")]
use super::task_manager::RetryAfter;
/// An iterator over n-gram records that can also report a position and a total size.
///
/// Implementors yield `Result<NgramRecord, ReaderError>` items.
pub trait NgramReader: Iterator<Item = Result<NgramRecord, ReaderError>> {
/// Returns the reader's current position as a `u64`.
///
/// NOTE(review): the name implies a byte count; confirm that an implementor reports
/// the same unit as `total_bytes` before dividing one by the other.
fn byte_offset(&self) -> u64;
/// Returns the total size of the input, or `None` when the implementor does not report one.
fn total_bytes(&self) -> Option<u64>;
}
/// Errors reported by the readers in this module.
#[derive(Debug, thiserror::Error)]
pub enum ReaderError {
/// An I/O operation failed. `#[from]` lets `?` convert a `std::io::Error` into this variant.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// A line could not be parsed into a record.
#[error("Parse error at line {line}: {error}")]
Parse {
/// Line number supplied by the reader that hit the error.
line: u64,
/// The parser's error, exposed as this error's source.
#[source]
error: ParseError,
},
/// An HTTP-level failure, described by a message.
#[error("HTTP error: {0}")]
Http(String),
/// The response could not be treated as compressed data, described by a message.
#[error("Decompression error: {0}")]
Decompression(String),
/// The server answered with HTTP 429. Only present with the `google-books` feature.
#[cfg(feature = "google-books")]
#[error("Rate limited (HTTP 429) for {url}")]
RateLimited {
/// URL of the request that was rate limited.
url: String,
/// Parsed `Retry-After` header value, when one was present and parseable.
retry_after: Option<RetryAfter>,
},
}
/// Dumps a failed HTTP response to a file under the system temp directory for later inspection.
///
/// The dump contains the URL, a status line, every response header (values that are
/// not valid text are written as `<binary>`), a blank line, and then the raw body.
/// The file lives in `<temp>/grammstein-failed-responses/` and is named after the
/// sanitised URL plus the current Unix time in seconds.
///
/// Writing the preamble is best effort: failures there are ignored. The body write and
/// the final flush are checked, so `Some(path)` is only returned once the body has
/// actually been handed to the file.
///
/// Returns `None` (after logging a warning) when the directory or file cannot be
/// created, or when writing or flushing the body fails.
// Gated like its only caller, the feature-gated streaming code.
#[cfg(feature = "google-books")]
fn save_failed_response(
    url: &str,
    status: reqwest::StatusCode,
    headers: &reqwest::header::HeaderMap,
    body: &[u8],
) -> Option<std::path::PathBuf> {
    use std::io::Write;

    let dir = std::env::temp_dir().join("grammstein-failed-responses");
    if std::fs::create_dir_all(&dir).is_err() {
        tracing::warn!(
            "Failed to create directory for failed responses: {}",
            dir.display()
        );
        return None;
    }

    // Keep the scheme separator as a single `_`, then replace every character that is
    // not plainly file-name safe. Restricting the name to ASCII also bounds its byte
    // length, so the 100-character cap cannot exceed typical file-name limits.
    let filename: String = url
        .replace("://", "_")
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || matches!(c, '.' | '-' | '_') {
                c
            } else {
                '_'
            }
        })
        .take(100)
        .collect();

    // A clock before the Unix epoch falls back to 0 rather than failing the dump.
    let timestamp = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0);
    let path = dir.join(format!("{}_{}.response", filename, timestamp));

    let file = match std::fs::File::create(&path) {
        Ok(f) => f,
        Err(e) => {
            tracing::warn!("Failed to create response dump file: {}", e);
            return None;
        }
    };
    let mut writer = std::io::BufWriter::new(file);

    // Preamble is best effort; a failure here must not prevent the body from being saved.
    writeln!(writer, "URL: {}", url).ok();
    writeln!(writer, "HTTP/1.1 {}", status).ok();
    for (name, value) in headers.iter() {
        if let Ok(v) = value.to_str() {
            writeln!(writer, "{}: {}", name, v).ok();
        } else {
            writeln!(writer, "{}: <binary>", name).ok();
        }
    }
    writeln!(writer).ok();

    // Flush explicitly: dropping a `BufWriter` discards any error from the final write,
    // which would report success for a truncated dump.
    if let Err(e) = writer.write_all(body).and_then(|()| writer.flush()) {
        tracing::warn!("Failed to write response body: {}", e);
        return None;
    }

    tracing::info!("Saved failed response to: {}", path.display());
    Some(path)
}
/// Reads n-gram records line by line from a single gzip-compressed file.
pub struct FileNgramReader {
/// Buffered reader over the gzip decoder that wraps the opened file.
reader: BufReader<GzDecoder<File>>,
/// Scratch buffer that is cleared and refilled for every line read.
line_buffer: String,
/// Number of lines read so far; used as the line number in parse errors.
current_line: u64,
/// File size taken from the file's metadata at open time, if it could be read.
total_compressed_size: Option<u64>,
/// Path the reader was opened from.
path: PathBuf,
/// When true, records for which `contains_pos_tag` returns true are skipped.
skip_pos_tags: bool,
/// Records whose `match_count` is below this value are skipped.
min_count: u64,
}
impl FileNgramReader {
    /// Opens `path` with filtering disabled: POS-tagged records are kept and the
    /// minimum match count is `0`.
    ///
    /// # Errors
    /// Returns `ReaderError::Io` if the file cannot be opened.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, ReaderError> {
        Self::open_with_options(path, false, 0)
    }

    /// Opens `path` for gzip-decoded, line-by-line reading with the given filters.
    ///
    /// * `skip_pos_tags` - skip records for which `contains_pos_tag` returns true.
    /// * `min_count` - skip records whose `match_count` is below this value.
    ///
    /// The file size is read from the file's metadata and kept for progress reporting;
    /// a metadata failure leaves it as `None` instead of failing the open.
    ///
    /// # Errors
    /// Returns `ReaderError::Io` if the file cannot be opened.
    pub fn open_with_options<P: AsRef<Path>>(
        path: P,
        skip_pos_tags: bool,
        min_count: u64,
    ) -> Result<Self, ReaderError> {
        let source_path = path.as_ref().to_path_buf();
        let handle = File::open(&source_path)?;
        let total_compressed_size = handle.metadata().ok().map(|meta| meta.len());
        let reader = BufReader::with_capacity(64 * 1024, GzDecoder::new(handle));

        Ok(Self {
            reader,
            line_buffer: String::with_capacity(256),
            current_line: 0,
            total_compressed_size,
            path: source_path,
            skip_pos_tags,
            min_count,
        })
    }

    /// Builds a `MultiFileReader` over `paths`; shorthand for `MultiFileReader::new`.
    pub fn open_all<P: AsRef<Path>>(
        paths: &[P],
        skip_pos_tags: bool,
        min_count: u64,
    ) -> Result<MultiFileReader, ReaderError> {
        MultiFileReader::new(paths, skip_pos_tags, min_count)
    }

    /// Returns the path this reader was opened from.
    pub fn path(&self) -> &Path {
        self.path.as_path()
    }

    /// Reads lines until one yields a record that passes the filters.
    ///
    /// Returns `None` at end of input, `Some(Err(..))` for a read failure or an
    /// unparseable line (tagged with its line number), and `Some(Ok(record))` otherwise.
    /// Lines that are empty after trimming trailing whitespace are skipped but still counted.
    fn next_valid_record(&mut self) -> Option<Result<NgramRecord, ReaderError>> {
        loop {
            self.line_buffer.clear();
            let bytes_read = match self.reader.read_line(&mut self.line_buffer) {
                Ok(count) => count,
                Err(io_err) => return Some(Err(ReaderError::Io(io_err))),
            };
            // Zero bytes means end of input.
            if bytes_read == 0 {
                return None;
            }
            self.current_line += 1;

            let text = self.line_buffer.trim_end();
            if text.is_empty() {
                continue;
            }

            let record = match parse_ngram_line(text) {
                Ok(parsed) => parsed,
                Err(error) => {
                    return Some(Err(ReaderError::Parse {
                        line: self.current_line,
                        error,
                    }));
                }
            };

            // The count check comes first so the POS-tag check only runs for survivors.
            let filtered_out = record.match_count < self.min_count
                || (self.skip_pos_tags && super::parser::contains_pos_tag(&record.ngram));
            if filtered_out {
                continue;
            }
            return Some(Ok(record));
        }
    }
}
/// Iterating a `FileNgramReader` yields whatever `next_valid_record` produces:
/// records that pass the reader's filters, or the error hit while reading or parsing.
impl Iterator for FileNgramReader {
type Item = Result<NgramRecord, ReaderError>;
// Thin delegation to the inherent method that does the actual work.
fn next(&mut self) -> Option<Self::Item> {
self.next_valid_record()
}
}
impl NgramReader for FileNgramReader {
fn byte_offset(&self) -> u64 {
self.current_line
}
fn total_bytes(&self) -> Option<u64> {
self.total_compressed_size
}
}
/// Reads records from several files in sequence, presenting them as one iterator.
pub struct MultiFileReader {
/// All files to read, in the order given at construction.
paths: Vec<PathBuf>,
/// Index into `paths` of the next file to open.
current_index: usize,
/// Reader for the file currently being consumed, if any.
current_reader: Option<FileNgramReader>,
/// Passed to every `FileNgramReader` this reader opens.
skip_pos_tags: bool,
/// Passed to every `FileNgramReader` this reader opens.
min_count: u64,
/// Number of items yielded so far across all files.
total_lines: u64,
}
impl MultiFileReader {
    /// Creates a reader over `paths` and opens the first file immediately.
    ///
    /// `skip_pos_tags` and `min_count` are applied to every file. An empty `paths`
    /// slice is accepted and produces a reader that yields nothing.
    ///
    /// # Errors
    /// Returns the error from opening the first file, if that fails.
    pub fn new<P: AsRef<Path>>(
        paths: &[P],
        skip_pos_tags: bool,
        min_count: u64,
    ) -> Result<Self, ReaderError> {
        let paths: Vec<PathBuf> = paths.iter().map(|p| p.as_ref().to_path_buf()).collect();
        let mut reader = Self {
            paths,
            current_index: 0,
            current_reader: None,
            skip_pos_tags,
            min_count,
            total_lines: 0,
        };
        reader.open_next_file()?;
        Ok(reader)
    }

    /// Opens the next unread file and makes it the current reader.
    ///
    /// Returns `Ok(true)` when a file was opened and `Ok(false)` when every path has
    /// been consumed (the current reader is then cleared).
    ///
    /// # Errors
    /// Returns the open error for the file that failed. That file is still counted as
    /// consumed, so the next call moves on to the following path instead of retrying
    /// the same one forever. The previous reader is left in place on failure.
    fn open_next_file(&mut self) -> Result<bool, ReaderError> {
        if self.current_index >= self.paths.len() {
            self.current_reader = None;
            return Ok(false);
        }

        // Advance before opening: a file that cannot be opened must be reported once
        // and then skipped, otherwise a consumer that ignores errors loops forever.
        let index = self.current_index;
        self.current_index += 1;

        let reader = FileNgramReader::open_with_options(
            &self.paths[index],
            self.skip_pos_tags,
            self.min_count,
        )?;
        self.current_reader = Some(reader);
        Ok(true)
    }

    /// Returns the path of the file behind the current reader, if there is one.
    pub fn current_file(&self) -> Option<&Path> {
        self.current_reader.as_ref().map(|r| r.path())
    }

    /// Returns how many paths have not yet been opened (or attempted).
    pub fn files_remaining(&self) -> usize {
        self.paths.len().saturating_sub(self.current_index)
    }
}
impl Iterator for MultiFileReader {
type Item = Result<NgramRecord, ReaderError>;
fn next(&mut self) -> Option<Self::Item> {
loop {
if let Some(ref mut reader) = self.current_reader {
match reader.next() {
Some(result) => {
self.total_lines += 1;
return Some(result);
}
None => {
match self.open_next_file() {
Ok(true) => continue,
Ok(false) => return None,
Err(e) => return Some(Err(e)),
}
}
}
} else {
return None;
}
}
}
}
impl NgramReader for MultiFileReader {
/// Returns the value of the `total_lines` counter.
///
/// NOTE(review): despite the method name this is a count, not a number of bytes.
fn byte_offset(&self) -> u64 {
self.total_lines
}
/// Always `None`: no combined size is reported for the set of files.
fn total_bytes(&self) -> Option<u64> {
None }
}
/// Reads n-gram records from a gzip-compressed resource fetched over HTTP.
pub struct HttpNgramReader {
/// URL the data is downloaded from.
url: String,
/// Content length of the response; `None` until `stream_records` stores it.
content_length: Option<u64>,
/// When true, records for which `contains_pos_tag` returns true are skipped.
skip_pos_tags: bool,
/// Records whose `match_count` is below this value are skipped.
min_count: u64,
}
impl HttpNgramReader {
    /// Creates a reader for `url` with filtering disabled: POS-tagged records are kept
    /// and the minimum match count is `0`. No request is made until a stream is started.
    pub fn new(url: &str) -> Self {
        Self::with_options(url, false, 0)
    }

    /// Creates a reader for `url` with explicit filters.
    ///
    /// * `skip_pos_tags` - skip records for which `contains_pos_tag` returns true.
    /// * `min_count` - skip records whose `match_count` is below this value.
    ///
    /// No request is made until a stream is started.
    pub fn with_options(url: &str, skip_pos_tags: bool, min_count: u64) -> Self {
        Self {
            url: url.to_string(),
            content_length: None,
            skip_pos_tags,
            min_count,
        }
    }

    /// Returns the URL this reader downloads from.
    pub fn url(&self) -> &str {
        &self.url
    }

    /// Builds a short preview of a response body for log messages.
    ///
    /// At most the first 500 bytes are decoded (invalid UTF-8 becomes U+FFFD) and at
    /// most the first 200 characters of that are kept. The cut is made on a character
    /// boundary, so it cannot panic in the middle of a multi-byte character.
    // Only the feature-gated streaming code calls this helper, so it is unused otherwise.
    #[cfg_attr(not(feature = "google-books"), allow(dead_code))]
    fn body_preview(body: &[u8]) -> String {
        const MAX_BYTES: usize = 500;
        const MAX_CHARS: usize = 200;
        String::from_utf8_lossy(&body[..body.len().min(MAX_BYTES)])
            .chars()
            .take(MAX_CHARS)
            .collect()
    }

    /// Converts an error from the HTTP body stream into an `std::io::Error` so the
    /// stream can be fed to an async reader. The error kind reflects whether the
    /// failure was a timeout, a connection problem, or a body/decoding problem.
    #[cfg(feature = "google-books")]
    fn stream_io_error(url: &str, error: reqwest::Error) -> std::io::Error {
        let kind = if error.is_timeout() {
            std::io::ErrorKind::TimedOut
        } else if error.is_connect() {
            std::io::ErrorKind::ConnectionRefused
        } else if error.is_body() || error.is_decode() {
            std::io::ErrorKind::InvalidData
        } else {
            std::io::ErrorKind::Other
        };
        std::io::Error::new(kind, format!("HTTP stream error for {}: {}", url, error))
    }

    /// Requests the URL and returns a stream of the records in the gzip-decoded body.
    ///
    /// The request uses a 300 s overall timeout and a 30 s connect timeout. The
    /// response's content length is stored on the reader. Records are filtered by the
    /// reader's `min_count` and `skip_pos_tags` settings; empty lines are skipped.
    ///
    /// # Errors
    /// Returns `ReaderError::Http` if the client cannot be built, the request fails,
    /// or the status is not a success. Within the stream, read failures are yielded as
    /// `ReaderError::Io` and unparseable lines as `ReaderError::Parse` carrying the
    /// 1-based line number.
    #[cfg(feature = "google-books")]
    pub async fn stream_records(
        &mut self,
    ) -> Result<impl tokio_stream::Stream<Item = Result<NgramRecord, ReaderError>> + '_, ReaderError>
    {
        use async_compression::tokio::bufread::GzipDecoder;
        use std::time::Duration;
        use tokio::io::{AsyncBufReadExt, BufReader};
        use tokio_stream::StreamExt;
        use tokio_util::io::StreamReader;

        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(300))
            .connect_timeout(Duration::from_secs(30))
            .build()
            .map_err(|e| ReaderError::Http(format!("Failed to build HTTP client: {}", e)))?;
        let response = client
            .get(&self.url)
            .send()
            .await
            .map_err(|e| ReaderError::Http(e.to_string()))?;
        if !response.status().is_success() {
            return Err(ReaderError::Http(format!(
                "HTTP {} for {}",
                response.status(),
                self.url
            )));
        }
        self.content_length = response.content_length();

        // Body chunks -> AsyncRead -> gzip decoder -> lines.
        let byte_stream = response.bytes_stream();
        let url_for_errors = self.url.clone();
        let mapped_stream = byte_stream.map(move |chunk| {
            chunk.map_err(|e| HttpNgramReader::stream_io_error(&url_for_errors, e))
        });
        let stream_reader = StreamReader::new(mapped_stream);
        let decoder = GzipDecoder::new(BufReader::new(stream_reader));
        let buf_reader = BufReader::new(decoder);
        let lines = tokio_stream::wrappers::LinesStream::new(buf_reader.lines());

        let skip_pos = self.skip_pos_tags;
        let min_count = self.min_count;
        let mut line_num = 0u64;
        let record_stream = lines.filter_map(move |line_result| {
            // Every line counts, including skipped ones, so parse errors point at the
            // real position in the file.
            line_num += 1;
            match line_result {
                Ok(line) => {
                    if line.is_empty() {
                        return None;
                    }
                    match parse_ngram_line(&line) {
                        Ok(record) => {
                            if record.match_count < min_count {
                                return None;
                            }
                            if skip_pos && super::parser::contains_pos_tag(&record.ngram) {
                                return None;
                            }
                            Some(Ok(record))
                        }
                        Err(e) => Some(Err(ReaderError::Parse {
                            line: line_num,
                            error: e,
                        })),
                    }
                }
                Err(e) => Some(Err(ReaderError::Io(e))),
            }
        });
        Ok(record_stream)
    }

    /// Downloads the file and collects its records into a vector.
    ///
    /// Collection stops after 10,000,000 records; a warning is logged and the records
    /// gathered so far are returned.
    ///
    /// # Errors
    /// Returns the first error produced while starting or reading the stream.
    #[cfg(feature = "google-books")]
    #[deprecated(
        since = "0.2.0",
        note = "This method can cause OOM for large files. Use stream_records() or stream_aggregated() instead."
    )]
    pub async fn read_all(&mut self) -> Result<Vec<NgramRecord>, ReaderError> {
        use tokio_stream::StreamExt;

        const MAX_RECORDS: usize = 10_000_000;

        // Cloned up front because the stream keeps `self` borrowed.
        let url = self.url.clone();
        let stream = self.stream_records().await?;
        tokio::pin!(stream);

        let mut records = Vec::new();
        while let Some(result) = stream.next().await {
            records.push(result?);
            if records.len() >= MAX_RECORDS {
                tracing::warn!(
                    "read_all() hit {} record limit for {}. Use stream_records() for larger files.",
                    MAX_RECORDS,
                    url
                );
                break;
            }
        }
        Ok(records)
    }

    /// Downloads the file and collects the aggregator's output into a vector.
    ///
    /// Records are pushed through a `YearAggregator` built from `year_range`, and the
    /// aggregator is flushed at the end of the stream. Collection stops after
    /// 10,000,000 aggregated entries; a warning is logged and the entries gathered so
    /// far are returned without a final flush.
    ///
    /// # Errors
    /// Returns the first error produced while starting or reading the stream.
    #[cfg(feature = "google-books")]
    #[deprecated(
        since = "0.2.0",
        note = "This method can cause OOM for large files. Use stream_aggregated() instead."
    )]
    pub async fn read_aggregated(
        &mut self,
        year_range: Option<(u16, u16)>,
    ) -> Result<Vec<AggregatedNgram>, ReaderError> {
        use super::aggregator::YearAggregator;
        use tokio_stream::StreamExt;

        const MAX_AGGREGATED: usize = 10_000_000;

        // Cloned up front because the stream keeps `self` borrowed.
        let url = self.url.clone();
        let stream = self.stream_records().await?;
        tokio::pin!(stream);

        let mut aggregator = YearAggregator::new(year_range);
        let mut results = Vec::new();
        while let Some(result) = stream.next().await {
            let record = result?;
            if let Some(aggregated) = aggregator.push(record) {
                results.push(aggregated);
                if results.len() >= MAX_AGGREGATED {
                    tracing::warn!(
                        "read_aggregated() hit {} record limit for {}. Use stream_aggregated() for larger files.",
                        MAX_AGGREGATED,
                        url
                    );
                    return Ok(results);
                }
            }
        }
        if let Some(aggregated) = aggregator.flush() {
            results.push(aggregated);
        }
        Ok(results)
    }

    /// Streams aggregated n-grams using a default HTTP client.
    ///
    /// Equivalent to `stream_aggregated_with_client(year_range, None)`.
    #[cfg(feature = "google-books")]
    pub fn stream_aggregated(
        &mut self,
        year_range: Option<(u16, u16)>,
    ) -> impl tokio_stream::Stream<Item = Result<AggregatedNgram, ReaderError>> + '_ {
        self.stream_aggregated_with_client(year_range, None)
    }

    /// Streams aggregated n-grams, optionally reusing a caller-supplied HTTP client.
    ///
    /// When `client` is `None`, a client is built with a 300 s overall timeout, a 30 s
    /// connect timeout, a 60 s read timeout and this library's user agent. Nothing is
    /// requested until the returned stream is polled.
    ///
    /// Records are filtered by the reader's `min_count` and `skip_pos_tags` settings
    /// and pushed through a `YearAggregator` built from `year_range`; the aggregator
    /// is flushed when the body ends.
    ///
    /// # Errors
    /// The stream yields an error item for:
    /// * HTTP 429: `ReaderError::RateLimited`, with the parsed `Retry-After` value if any;
    /// * HTTP 503 and any other non-success status: `ReaderError::Http`;
    /// * an HTML response: `ReaderError::Http`, after saving the body for inspection;
    /// * a non-gzip `text/*` response: `ReaderError::Decompression`, after saving the body;
    /// * a read failure (`ReaderError::Io`) or an unparseable line (`ReaderError::Parse`
    ///   with its 1-based line number).
    #[cfg(feature = "google-books")]
    pub fn stream_aggregated_with_client(
        &mut self,
        year_range: Option<(u16, u16)>,
        client: Option<reqwest::Client>,
    ) -> impl tokio_stream::Stream<Item = Result<AggregatedNgram, ReaderError>> + '_ {
        use super::aggregator::YearAggregator;

        let skip_pos = self.skip_pos_tags;
        let min_count = self.min_count;
        let url = self.url.clone();

        async_stream::try_stream! {
            use async_compression::tokio::bufread::GzipDecoder;
            use std::time::Duration;
            use tokio::io::{AsyncBufReadExt, BufReader};
            use tokio_stream::StreamExt;
            use tokio_util::io::StreamReader;

            let client = match client {
                Some(c) => c,
                None => {
                    reqwest::Client::builder()
                        .timeout(Duration::from_secs(300))
                        .connect_timeout(Duration::from_secs(30))
                        .read_timeout(Duration::from_secs(60))
                        .user_agent("Mozilla/5.0 (compatible; libgrammstein/0.1; +https://github.com/vinary-tree/libgrammstein)")
                        .build()
                        .map_err(|e| ReaderError::Http(format!("Failed to build HTTP client: {}", e)))?
                }
            };
            let response = client
                .get(&url)
                .send()
                .await
                .map_err(|e| ReaderError::Http(e.to_string()))?;

            let status = response.status();
            if !status.is_success() {
                if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
                    let retry_after = response.headers()
                        .get("retry-after")
                        .and_then(|v| v.to_str().ok())
                        .and_then(RetryAfter::parse);
                    tracing::warn!(
                        "Rate limited (429) for {}, Retry-After: {:?}",
                        url,
                        retry_after
                    );
                    Err(ReaderError::RateLimited {
                        url: url.clone(),
                        retry_after,
                    })?;
                }
                if status == reqwest::StatusCode::SERVICE_UNAVAILABLE {
                    Err(ReaderError::Http(format!(
                        "Service unavailable (HTTP 503) for {} - Google may be throttling",
                        url
                    )))?;
                }
                Err(ReaderError::Http(format!("HTTP {} for {}", status, url)))?;
            }

            // Headers are cloned because reading the body consumes the response.
            let headers_clone = response.headers().clone();
            let content_length = response.content_length();
            if let Some(content_length) = content_length {
                tracing::debug!("Downloading {} ({} bytes compressed)", url, content_length);
            }
            let content_type = headers_clone
                .get("content-type")
                .and_then(|v| v.to_str().ok())
                .map(|s| s.to_lowercase());
            let content_encoding = headers_clone
                .get("content-encoding")
                .and_then(|v| v.to_str().ok())
                .map(|s| s.to_lowercase());
            tracing::debug!(
                "Response for {}: Content-Type={:?}, Content-Encoding={:?}, Length={:?}",
                url, content_type, content_encoding, content_length
            );

            // A successful status can still carry an error page or uncompressed text;
            // detect those before handing the body to the gzip decoder.
            let is_html = content_type.as_ref().map(|t| t.contains("text/html")).unwrap_or(false);
            let is_gzip = content_encoding.as_ref().map(|e| e.contains("gzip")).unwrap_or(false)
                || url.ends_with(".gz");
            let is_plain_text = !is_gzip && content_type.as_ref().map(|t| t.starts_with("text/")).unwrap_or(false);

            if is_html || is_plain_text {
                let bytes = response.bytes().await.map_err(|e| ReaderError::Http(e.to_string()))?;
                let saved_path = save_failed_response(&url, status, &headers_clone, &bytes);
                // Character-based truncation: slicing the decoded text at a fixed byte
                // index would panic inside a multi-byte character.
                let preview_short = HttpNgramReader::body_preview(&bytes);
                if is_html {
                    tracing::error!(
                        "HTML error page received for {}. Saved to: {:?}. Preview: {}",
                        url, saved_path, preview_short
                    );
                    Err(ReaderError::Http(format!(
                        "Server returned HTML error page for {} (saved to {:?})",
                        url, saved_path
                    )))?;
                } else {
                    tracing::error!(
                        "Plain text received for {}. Saved to: {:?}. Preview: {}",
                        url, saved_path, preview_short
                    );
                    Err(ReaderError::Decompression(format!(
                        "Server returned plain text for {} (saved to {:?})",
                        url, saved_path
                    )))?;
                }
            } else {
                // Body chunks -> AsyncRead -> gzip decoder -> lines.
                let byte_stream = response.bytes_stream();
                let url_for_errors = url.clone();
                let mapped_stream = byte_stream.map(move |chunk| {
                    chunk.map_err(|e| HttpNgramReader::stream_io_error(&url_for_errors, e))
                });
                let stream_reader = StreamReader::new(mapped_stream);
                let decoder = GzipDecoder::new(BufReader::new(stream_reader));
                let buf_reader = BufReader::new(decoder);
                let lines = tokio_stream::wrappers::LinesStream::new(buf_reader.lines());
                tokio::pin!(lines);

                let mut aggregator = YearAggregator::new(year_range);
                let mut line_num = 0u64;
                while let Some(line_result) = lines.next().await {
                    line_num += 1;
                    let line = line_result?;
                    if line.is_empty() {
                        continue;
                    }
                    match super::parser::parse_ngram_line_ref(&line) {
                        Ok(record) => {
                            if record.match_count < min_count {
                                continue;
                            }
                            if skip_pos && super::parser::contains_pos_tag(record.ngram) {
                                continue;
                            }
                            if let Some(aggregated) = aggregator.push_ref(&record) {
                                yield aggregated;
                            }
                        }
                        Err(e) => {
                            Err(ReaderError::Parse { line: line_num, error: e })?;
                        }
                    }
                }
                if let Some(aggregated) = aggregator.flush() {
                    yield aggregated;
                }
            }
        }
    }
}
/// Builder that collects reader options before a file or HTTP reader is created.
pub struct ReaderBuilder {
/// Forwarded to the reader as its `skip_pos_tags` option.
skip_pos_tags: bool,
/// Forwarded to the reader as its `min_count` option.
min_count: u64,
/// Year range set through `year_range`.
/// NOTE(review): `open_file`, `open_files` and `http_reader` do not pass this on;
/// confirm whether callers are expected to apply it themselves.
year_range: Option<(u16, u16)>,
}
impl ReaderBuilder {
    /// Creates a builder with filtering disabled: POS tags are kept, the minimum
    /// count is `0`, and no year range is set.
    pub fn new() -> Self {
        Self {
            skip_pos_tags: false,
            min_count: 0,
            year_range: None,
        }
    }

    /// Sets whether POS-tagged records should be skipped.
    pub fn skip_pos_tags(self, skip: bool) -> Self {
        Self {
            skip_pos_tags: skip,
            ..self
        }
    }

    /// Sets the minimum match count a record needs in order to be kept.
    pub fn min_count(self, count: u64) -> Self {
        Self {
            min_count: count,
            ..self
        }
    }

    /// Records the year range `(start, end)` on the builder.
    pub fn year_range(self, start: u16, end: u16) -> Self {
        Self {
            year_range: Some((start, end)),
            ..self
        }
    }

    /// Opens a single file with the configured `skip_pos_tags` and `min_count`.
    ///
    /// # Errors
    /// Returns whatever `FileNgramReader::open_with_options` returns on failure.
    pub fn open_file<P: AsRef<Path>>(self, path: P) -> Result<FileNgramReader, ReaderError> {
        let Self {
            skip_pos_tags,
            min_count,
            ..
        } = self;
        FileNgramReader::open_with_options(path, skip_pos_tags, min_count)
    }

    /// Opens several files as one reader with the configured `skip_pos_tags` and `min_count`.
    ///
    /// # Errors
    /// Returns whatever `MultiFileReader::new` returns on failure.
    pub fn open_files<P: AsRef<Path>>(self, paths: &[P]) -> Result<MultiFileReader, ReaderError> {
        let Self {
            skip_pos_tags,
            min_count,
            ..
        } = self;
        MultiFileReader::new(paths, skip_pos_tags, min_count)
    }

    /// Creates an HTTP reader for `url` with the configured `skip_pos_tags` and `min_count`.
    pub fn http_reader(self, url: &str) -> HttpNgramReader {
        let Self {
            skip_pos_tags,
            min_count,
            ..
        } = self;
        HttpNgramReader::with_options(url, skip_pos_tags, min_count)
    }
}
/// The default builder is the same as `ReaderBuilder::new()`.
impl Default for ReaderBuilder {
fn default() -> Self {
Self::new()
}
}
/// Extension trait that adds `aggregate` to iterators of `Result<NgramRecord, ReaderError>`.
pub trait AggregateReaderExt: Iterator<Item = Result<NgramRecord, ReaderError>> + Sized {
/// Consumes `self` and wraps it in an `AggregatingReaderIterator` built with `year_range`.
fn aggregate(self, year_range: Option<(u16, u16)>) -> AggregatingReaderIterator<Self> {
AggregatingReaderIterator::new(self, year_range)
}
}
// Blanket implementation: every sized iterator with the right item type gets `aggregate`.
impl<I: Iterator<Item = Result<NgramRecord, ReaderError>>> AggregateReaderExt for I {}
/// Iterator adapter that feeds records from `inner` through a `YearAggregator` and
/// yields the aggregated results.
pub struct AggregatingReaderIterator<I> {
/// The wrapped record iterator.
inner: I,
/// Aggregator that receives every successfully read record.
aggregator: super::aggregator::YearAggregator,
/// Set once `inner` has returned `None` and the aggregator has been flushed.
flushed: bool,
/// Errors yielded by `inner`; they are collected here instead of being returned.
errors: Vec<ReaderError>,
}
impl<I> AggregatingReaderIterator<I>
where
    I: Iterator<Item = Result<NgramRecord, ReaderError>>,
{
    /// Wraps `inner` with a fresh `YearAggregator` built from `year_range`,
    /// starting with no collected errors.
    fn new(inner: I, year_range: Option<(u16, u16)>) -> Self {
        let aggregator = super::aggregator::YearAggregator::new(year_range);
        Self {
            inner,
            aggregator,
            flushed: false,
            errors: Vec::new(),
        }
    }

    /// Returns the errors collected so far, in the order they were encountered.
    pub fn errors(&self) -> &[ReaderError] {
        self.errors.as_slice()
    }

    /// Removes and returns the collected errors, leaving the internal list empty.
    pub fn take_errors(&mut self) -> Vec<ReaderError> {
        let collected = &mut self.errors;
        std::mem::take(collected)
    }
}
/// Yields aggregated n-grams; errors from the inner iterator are collected, not returned.
impl<I> Iterator for AggregatingReaderIterator<I>
where
    I: Iterator<Item = Result<NgramRecord, ReaderError>>,
{
    type Item = AggregatedNgram;

    /// Pulls records from the inner iterator until the aggregator produces a result.
    ///
    /// Errors from the inner iterator are appended to the error list and skipped.
    /// When the inner iterator ends, the aggregator is flushed once and its final
    /// result (if any) is returned. After that this iterator is finished: it returns
    /// `None` and never polls the inner iterator again.
    fn next(&mut self) -> Option<Self::Item> {
        // An iterator may resume after returning `None`. Polling it again after the
        // final flush would push records into an already-flushed aggregator and could
        // yield items after this adapter reported the end.
        if self.flushed {
            return None;
        }

        loop {
            match self.inner.next() {
                Some(Ok(record)) => {
                    if let Some(aggregated) = self.aggregator.push(record) {
                        return Some(aggregated);
                    }
                }
                Some(Err(e)) => self.errors.push(e),
                None => {
                    self.flushed = true;
                    return self.aggregator.flush();
                }
            }
        }
    }
}
/// Streams aggregated n-grams from a gzip-compressed file on disk.
///
/// The file at `path` is opened asynchronously, gzip-decoded and read line by line.
/// Empty lines are skipped. Records with a `match_count` below `min_count` are dropped,
/// as are records for which `contains_pos_tag` is true when `skip_pos_tags` is set.
/// Remaining records go through a `YearAggregator` built from `year_range`, which is
/// flushed once the file ends.
///
/// # Errors
/// The stream yields `ReaderError::Io` if the file cannot be opened (the message names
/// the path) or a line cannot be read, and `ReaderError::Parse` with the 1-based line
/// number if a line cannot be parsed.
#[cfg(feature = "google-books")]
pub fn stream_aggregated_from_cached_file(
    path: &std::path::Path,
    year_range: Option<(u16, u16)>,
    skip_pos_tags: bool,
    min_count: u64,
) -> impl tokio_stream::Stream<Item = Result<AggregatedNgram, ReaderError>> + '_ {
    use super::aggregator::YearAggregator;

    async_stream::try_stream! {
        use async_compression::tokio::bufread::GzipDecoder;
        use tokio::io::{AsyncBufReadExt, BufReader};
        use tokio_stream::StreamExt;

        let cached = tokio::fs::File::open(path).await.map_err(|open_err| {
            let described = std::io::Error::new(
                open_err.kind(),
                format!("Failed to open cached file {}: {}", path.display(), open_err),
            );
            ReaderError::Io(described)
        })?;

        // File -> gzip decoder -> lines.
        let gunzip = GzipDecoder::new(BufReader::new(cached));
        let text_lines = tokio_stream::wrappers::LinesStream::new(BufReader::new(gunzip).lines());
        tokio::pin!(text_lines);

        let mut by_year = YearAggregator::new(year_range);
        let mut line_no = 0u64;
        while let Some(next_line) = text_lines.next().await {
            // Every line counts, including skipped ones, so errors report the real position.
            line_no += 1;
            let text = next_line?;
            if text.is_empty() {
                continue;
            }

            let record = super::parser::parse_ngram_line_ref(&text)
                .map_err(|error| ReaderError::Parse { line: line_no, error })?;

            // The count check comes first so the POS-tag check only runs for survivors.
            let rejected = record.match_count < min_count
                || (skip_pos_tags && super::parser::contains_pos_tag(record.ngram));
            if rejected {
                continue;
            }

            if let Some(completed) = by_year.push_ref(&record) {
                yield completed;
            }
        }

        if let Some(last) = by_year.flush() {
            yield last;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder setter must store its argument on the builder.
    #[test]
    fn test_reader_builder() {
        let ReaderBuilder {
            skip_pos_tags,
            min_count,
            year_range,
        } = ReaderBuilder::new()
            .skip_pos_tags(true)
            .min_count(100)
            .year_range(2000, 2020);

        assert!(skip_pos_tags);
        assert_eq!(min_count, 100);
        assert_eq!(year_range, Some((2000, 2020)));
    }

    /// `with_options` must keep the URL and both filter settings.
    #[test]
    fn test_http_reader_creation() {
        const TEST_URL: &str = "https://storage.googleapis.com/books/ngrams/books/test.gz";

        let http = HttpNgramReader::with_options(TEST_URL, true, 40);

        assert!(http.url().contains("storage.googleapis.com"));
        assert!(http.skip_pos_tags);
        assert_eq!(http.min_count, 40);
    }
}