use std::pin::Pin;
use std::task::{Context, Poll};
use jiff::Timestamp;
use quick_xml::NsReader;
use quick_xml::events::Event;
use rama_core::futures::Stream;
use rama_core::futures::StreamExt as _;
use rama_core::futures::async_stream::stream_fn;
use rama_core::futures::stream::BoxStream;
use rama_core::telemetry::tracing;
use rama_net::uri::Uri;
use tokio::io::AsyncBufRead;
use super::names::elem;
use super::{Rss2Category, Rss2Feed, Rss2Guid, Rss2Image, Rss2Item, Rss2Source};
use crate::protocols::rss::atom::AtomLink;
use crate::protocols::rss::atom::names::elem as atom_elem;
use crate::protocols::rss::error::{CollectError, FeedParseError, Rss2CollectError};
use crate::protocols::rss::feed_ext::FeedExtensions;
use crate::protocols::rss::feed_ext::names::attr;
use crate::protocols::rss::feed_ext::parse::{FeedExtAcc, ItemExtAcc, Ns, classify_ns};
use crate::protocols::rss::parse_util::{
attr_uri, attr_value, enclosure_from_attrs, end_event_parts, parse_rss2_date, parse_uri,
push_general_ref, push_text,
};
/// Channel-level metadata of an RSS 2.0 feed, without its items.
///
/// [`Rss2FeedStream`] fills this in while reading up to the first `<item>`
/// start tag (or end of input). Channel elements that appear after the first
/// item are not reflected here. Use [`Rss2Channel::into_feed_with_items`] to
/// combine it with parsed items into an [`Rss2Feed`].
#[derive(Debug, Clone, PartialEq)]
pub struct Rss2Channel {
    /// Text of the channel `<title>`; empty when absent.
    pub title: String,
    /// Parsed channel `<link>`. It stays at the default placeholder `/` when
    /// the element is absent or, in lenient mode, cannot be parsed as a URI.
    pub link: Uri,
    /// Text of the channel `<description>`; empty when absent.
    pub description: String,
    /// Channel language, as written in the feed.
    pub language: Option<String>,
    /// Channel copyright notice.
    pub copyright: Option<String>,
    /// Managing-editor contact, as written in the feed.
    pub managing_editor: Option<String>,
    /// Webmaster contact, as written in the feed.
    pub web_master: Option<String>,
    /// Publication date; `None` when absent or when the date could not be parsed.
    pub pub_date: Option<Timestamp>,
    /// Last build date; `None` when absent or when the date could not be parsed.
    pub last_build_date: Option<Timestamp>,
    /// Channel categories, each with its optional `domain` attribute.
    pub categories: Vec<Rss2Category>,
    /// Generator string.
    pub generator: Option<String>,
    /// Documentation URL text, kept as a plain string.
    pub docs: Option<String>,
    /// Time-to-live; `None` when absent or not a valid `u32`.
    pub ttl: Option<u32>,
    /// Channel image. Only set when the `<image>` block had both a parsable url
    /// and a parsable link.
    pub image: Option<Rss2Image>,
    /// Atom `link` elements found at channel level (outside items).
    pub atom_links: Vec<AtomLink>,
    /// Namespaced extension data collected at channel level.
    pub extensions: FeedExtensions,
}
impl Default for Rss2Channel {
    /// Builds an empty channel: empty strings, empty collections, no optional
    /// values, default extensions, and the placeholder link `/`.
    fn default() -> Self {
        Self {
            link: Uri::from_static("/"),
            title: String::default(),
            description: String::default(),
            categories: Vec::default(),
            atom_links: Vec::default(),
            extensions: FeedExtensions::default(),
            language: Option::default(),
            copyright: Option::default(),
            managing_editor: Option::default(),
            web_master: Option::default(),
            pub_date: Option::default(),
            last_build_date: Option::default(),
            generator: Option::default(),
            docs: Option::default(),
            ttl: Option::default(),
            image: Option::default(),
        }
    }
}
impl Rss2Channel {
    /// Turns this channel metadata plus `items` into a complete [`Rss2Feed`].
    ///
    /// Every channel field is moved across unchanged. `items` is collected in
    /// iteration order.
    #[must_use]
    pub fn into_feed_with_items<I>(self, items: I) -> Rss2Feed
    where
        I: IntoIterator<Item = Rss2Item>,
    {
        // Destructure exhaustively (no `..`). If a field is added to the
        // channel but not forwarded here, compilation fails instead of the
        // value being silently dropped.
        let Self {
            title,
            link,
            description,
            language,
            copyright,
            managing_editor,
            web_master,
            pub_date,
            last_build_date,
            categories,
            generator,
            docs,
            ttl,
            image,
            atom_links,
            extensions,
        } = self;

        Rss2Feed {
            title,
            link,
            description,
            language,
            copyright,
            managing_editor,
            web_master,
            pub_date,
            last_build_date,
            categories,
            generator,
            docs,
            ttl,
            image,
            atom_links,
            items: items.into_iter().collect(),
            extensions,
        }
    }
}
/// Streaming RSS 2.0 parser: channel metadata is read eagerly, items lazily.
///
/// Construct with [`Rss2FeedStream::new`] (lenient) or
/// [`Rss2FeedStream::new_strict`]. Then either poll it as a [`Stream`] of
/// items or use one of the `collect*` helpers.
pub struct Rss2FeedStream {
    /// Channel metadata read before the first `<item>` (or end of input).
    channel: Rss2Channel,
    /// The remaining items. The stream ends right after yielding its first error.
    items: BoxStream<'static, Result<Rss2Item, FeedParseError>>,
}
impl Rss2FeedStream {
    /// Starts parsing an RSS 2.0 document in lenient mode.
    ///
    /// Reads `reader` up to the first `<item>` start tag (or end of input).
    /// Returns the parsed channel together with a lazy stream over the
    /// remaining items. In lenient mode, XML syntax errors end the document
    /// quietly instead of failing, and required channel elements are not
    /// enforced.
    ///
    /// # Errors
    ///
    /// Returns [`FeedParseError`] when the channel phase fails, e.g. when no
    /// `<rss>`/`<channel>` root element was encountered.
    pub async fn new<R>(reader: R) -> Result<Self, FeedParseError>
    where
        R: AsyncBufRead + Unpin + Send + 'static,
    {
        Self::new_with_mode(reader, false).await
    }

    /// Starts parsing an RSS 2.0 document in strict mode.
    ///
    /// Like [`Rss2FeedStream::new`], but XML syntax errors and spec violations
    /// are reported instead of tolerated. Item-level violations (a nested or
    /// re-opened `<item>`, or an item with neither `<title>` nor
    /// `<description>`) are yielded as errors by the item stream.
    ///
    /// # Errors
    ///
    /// Returns [`FeedParseError`] for:
    /// - XML errors and truncated documents;
    /// - a missing root element;
    /// - a channel missing `<title>`, `<link>` or `<description>`;
    /// - a channel `<link>` that is not a valid URI.
    pub async fn new_strict<R>(reader: R) -> Result<Self, FeedParseError>
    where
        R: AsyncBufRead + Unpin + Send + 'static,
    {
        Self::new_with_mode(reader, true).await
    }

    /// Shared constructor. Parses the channel eagerly, then moves the reader
    /// into a boxed stream that parses one item per poll-step.
    ///
    /// # Errors
    ///
    /// Propagates any error from the channel phase.
    pub(in crate::protocols::rss) async fn new_with_mode<R>(
        reader: R,
        strict: bool,
    ) -> Result<Self, FeedParseError>
    where
        R: AsyncBufRead + Unpin + Send + 'static,
    {
        let mut state = Rss2Reader::new(reader, strict);
        let channel = state.read_channel().await?;
        let items: BoxStream<'static, Result<Rss2Item, FeedParseError>> =
            Box::pin(stream_fn(move |mut yielder| async move {
                let mut state = state;
                loop {
                    match state.read_next_item().await {
                        Ok(Some(item)) => yielder.yield_item(Ok(item)).await,
                        Ok(None) => return,
                        Err(e) => {
                            // Parser state is unreliable after an error; yield
                            // it once and end the stream.
                            yielder.yield_item(Err(e)).await;
                            return;
                        }
                    }
                }
            }));
        Ok(Self { channel, items })
    }

    /// Returns the channel metadata parsed before the first item.
    #[must_use]
    pub fn channel(&self) -> &Rss2Channel {
        &self.channel
    }

    /// Splits the stream into its channel metadata and the raw item stream.
    #[must_use]
    pub fn drain(
        self,
    ) -> (
        Rss2Channel,
        BoxStream<'static, Result<Rss2Item, FeedParseError>>,
    ) {
        (self.channel, self.items)
    }

    /// Collects every item into a complete [`Rss2Feed`].
    ///
    /// # Errors
    ///
    /// On the first item error, returns a [`CollectError`] carrying that error
    /// and a partial feed with the items collected so far.
    pub async fn collect(self) -> Result<Rss2Feed, Rss2CollectError> {
        self.collect_filtered(|_| true).await
    }

    /// Collects every successfully parsed item, dropping errors.
    ///
    /// Dropped errors are logged at debug level. The item stream ends right
    /// after it yields an error, so no items after the first error are
    /// produced.
    pub async fn collect_lossy(mut self) -> Rss2Feed {
        let mut items = Vec::new();
        while let Some(item) = self.items.next().await {
            match item {
                Ok(it) => items.push(it),
                Err(err) => tracing::debug!(error = %err, "rss item dropped by collect_lossy"),
            }
        }
        self.channel.into_feed_with_items(items)
    }

    /// Collects the items for which `predicate` returns `true`.
    ///
    /// # Errors
    ///
    /// On the first item error, returns a [`CollectError`] carrying that error
    /// and a partial feed with the accepted items collected so far.
    pub async fn collect_filtered<F>(
        mut self,
        mut predicate: F,
    ) -> Result<Rss2Feed, Rss2CollectError>
    where
        F: FnMut(&Rss2Item) -> bool + Send,
    {
        let mut items = Vec::new();
        while let Some(item) = self.items.next().await {
            match item {
                Ok(it) => {
                    if predicate(&it) {
                        items.push(it);
                    }
                }
                Err(error) => {
                    return Err(CollectError {
                        error,
                        partial: self.channel.into_feed_with_items(items),
                    });
                }
            }
        }
        Ok(self.channel.into_feed_with_items(items))
    }
}
impl Stream for Rss2FeedStream {
    type Item = Result<Rss2Item, FeedParseError>;

    /// Forwards polling to the boxed item stream built by the constructor.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let inner = &mut self.get_mut().items;
        inner.as_mut().poll_next(cx)
    }
}
/// Outcome of processing one XML event in `Rss2Reader::step`.
enum Action {
    /// The event was handled; keep reading.
    Continue,
    /// An `<item>` start tag was seen while no item was open. This fires for
    /// every such item, not only the first. During the channel phase it ends
    /// channel parsing; the item phase ignores it.
    FirstItemStarted,
    /// An `<item>` was closed; carries the completed item (boxed —
    /// presumably to keep `Action` small).
    ItemFinished(Box<Rss2Item>),
    /// End of input. In lenient mode an XML error is also reported this way.
    Eof,
}
struct Rss2Reader<R: AsyncBufRead + Unpin + Send> {
nsr: NsReader<R>,
buf: Vec<u8>,
strict: bool,
text_buf: String,
depth: i32,
saw_root: bool,
channel_link_set: bool,
channel: Rss2Channel,
feed_acc: FeedExtAcc,
in_image_block: bool,
image_url: Option<Uri>,
image_title: String,
image_link: Option<Uri>,
image_width: Option<u32>,
image_height: Option<u32>,
image_description: Option<String>,
in_item: bool,
current_item: Rss2Item,
item_acc: ItemExtAcc,
pending_category_domain: Option<String>,
pending_source_url: Option<Uri>,
}
impl<R: AsyncBufRead + Unpin + Send> Rss2Reader<R> {
fn new(reader: R, strict: bool) -> Self {
let mut nsr = NsReader::from_reader(reader);
nsr.config_mut().trim_text(false);
Self {
nsr,
buf: Vec::with_capacity(4096),
strict,
text_buf: String::new(),
depth: 0,
saw_root: false,
channel_link_set: false,
channel: Rss2Channel::default(),
feed_acc: FeedExtAcc::default(),
in_image_block: false,
image_url: None,
image_title: String::new(),
image_link: None,
image_width: None,
image_height: None,
image_description: None,
in_item: false,
current_item: Rss2Item::default(),
item_acc: ItemExtAcc::default(),
pending_category_domain: None,
pending_source_url: None,
}
}
async fn read_channel(&mut self) -> Result<Rss2Channel, FeedParseError> {
loop {
match self.step().await? {
Action::Continue => {}
Action::FirstItemStarted | Action::Eof => {
return self.take_channel();
}
Action::ItemFinished(_) => {
return Err(FeedParseError::new(
"internal: item finished during channel phase",
));
}
}
}
}
async fn read_next_item(&mut self) -> Result<Option<Rss2Item>, FeedParseError> {
loop {
match self.step().await? {
Action::Continue | Action::FirstItemStarted => {}
Action::ItemFinished(item) => return Ok(Some(*item)),
Action::Eof => return Ok(None),
}
}
}
fn take_channel(&mut self) -> Result<Rss2Channel, FeedParseError> {
if !self.saw_root {
return Err(FeedParseError::new("no <rss>/<channel> root encountered"));
}
let mut channel = std::mem::take(&mut self.channel);
channel.extensions = std::mem::take(&mut self.feed_acc).finish();
if self.strict {
if channel.title.is_empty() {
return Err(FeedParseError::new(
"RSS 2.0 channel missing required <title>",
));
}
if !self.channel_link_set {
return Err(FeedParseError::new(
"RSS 2.0 channel missing required <link>",
));
}
if channel.description.is_empty() {
return Err(FeedParseError::new(
"RSS 2.0 channel missing required <description>",
));
}
}
Ok(channel)
}
async fn step(&mut self) -> Result<Action, FeedParseError> {
self.buf.clear();
let (rr, ev) = match self.nsr.read_resolved_event_into_async(&mut self.buf).await {
Ok(p) => p,
Err(e) => {
if self.strict {
return Err(FeedParseError::new(format!("xml error: {e}")));
}
tracing::debug!("rss2 stream xml error (lenient): {e}");
return Ok(Action::Eof);
}
};
match ev {
Event::Start(e) => {
self.depth += 1;
let ns = classify_ns(&rr);
let local_name = e.local_name();
let local = std::str::from_utf8(local_name.as_ref()).unwrap_or("");
self.text_buf.clear();
let consumed = if self.in_item {
self.item_acc.on_start(ns, local, &e)
} else {
self.feed_acc.on_start(ns, local, &e)
};
if !consumed
&& !self.in_item
&& ns == Ns::Atom
&& local == atom_elem::LINK
&& let Some(link) = crate::protocols::rss::parse_util::atom_link_from_attrs(&e)
{
self.channel.atom_links.push(link);
return Ok(Action::Continue);
}
if consumed || ns != Ns::None {
return Ok(Action::Continue);
}
match local {
elem::RSS | elem::CHANNEL => {
self.saw_root = true;
Ok(Action::Continue)
}
elem::ITEM => {
let first_item = !self.in_item;
if !first_item {
if self.strict {
return Err(FeedParseError::new(format!(
"RSS 2.0: nested or re-opened <item> at depth {}",
self.depth,
)));
}
tracing::debug!(
"rss2: nested or re-opened <item> at depth {} — \
partial outer item discarded",
self.depth,
);
}
self.in_item = true;
self.current_item = Rss2Item::default();
self.item_acc = ItemExtAcc::default();
if first_item {
Ok(Action::FirstItemStarted)
} else {
Ok(Action::Continue)
}
}
elem::IMAGE if !self.in_item => {
self.in_image_block = true;
Ok(Action::Continue)
}
elem::ENCLOSURE if self.in_item => {
if let Some(enclosure) = enclosure_from_attrs(&e) {
self.current_item.enclosures.push(enclosure);
}
Ok(Action::Continue)
}
elem::GUID if self.in_item => {
let permalink = attr_value(&e, attr::IS_PERMALINK)
.map(|v| v != "false")
.unwrap_or(true);
self.current_item.guid = Some(Rss2Guid {
value: String::new(),
permalink,
});
Ok(Action::Continue)
}
elem::SOURCE if self.in_item => {
self.pending_source_url = attr_uri(&e, attr::URL);
Ok(Action::Continue)
}
elem::CATEGORY => {
self.pending_category_domain = attr_value(&e, attr::DOMAIN);
Ok(Action::Continue)
}
_ => Ok(Action::Continue),
}
}
Event::Empty(e) => {
let ns = classify_ns(&rr);
let local_name = e.local_name();
let local = std::str::from_utf8(local_name.as_ref()).unwrap_or("");
let consumed = if self.in_item {
self.item_acc.on_empty(ns, local, &e)
} else {
self.feed_acc.on_empty(ns, local, &e)
};
if consumed {
return Ok(Action::Continue);
}
if !self.in_item
&& ns == Ns::Atom
&& local == atom_elem::LINK
&& let Some(link) = crate::protocols::rss::parse_util::atom_link_from_attrs(&e)
{
self.channel.atom_links.push(link);
return Ok(Action::Continue);
}
if ns == Ns::None
&& self.in_item
&& local == elem::ENCLOSURE
&& let Some(enclosure) = enclosure_from_attrs(&e)
{
self.current_item.enclosures.push(enclosure);
}
Ok(Action::Continue)
}
Event::Text(e) => {
push_text(&mut self.text_buf, &e, self.strict)?;
Ok(Action::Continue)
}
Event::GeneralRef(e) => {
push_general_ref(&mut self.text_buf, &e, self.strict)?;
Ok(Action::Continue)
}
Event::CData(e) => {
match std::str::from_utf8(e.as_ref()) {
Ok(t) => self.text_buf.push_str(t),
Err(err) => {
if self.strict {
return Err(FeedParseError::new(format!("invalid CDATA: {err}")));
}
tracing::debug!("rss2 stream CDATA utf8 error (lenient): {err}");
self.text_buf.push_str(&String::from_utf8_lossy(e.as_ref()));
}
}
Ok(Action::Continue)
}
Event::End(e) => {
self.depth -= 1;
let ns = classify_ns(&rr);
let mut name = [0u8; 64];
let (local, text) = end_event_parts(e, &mut name, &mut self.text_buf);
self.handle_end(ns, local, text)
}
Event::Eof => {
if self.strict && self.depth > 0 {
return Err(FeedParseError::new(format!(
"truncated RSS 2.0 document ({} unclosed elements at EOF)",
self.depth
)));
}
Ok(Action::Eof)
}
_ => Ok(Action::Continue),
}
}
fn handle_end(&mut self, ns: Ns, local: &str, text: String) -> Result<Action, FeedParseError> {
if self.in_item {
let Some(text) = self.item_acc.on_end(ns, local, text) else {
return Ok(Action::Continue);
};
if ns != Ns::None {
return Ok(Action::Continue);
}
match local {
elem::TITLE => self.current_item.title = Some(text),
elem::LINK => self.current_item.link = parse_uri(&text),
elem::DESCRIPTION => self.current_item.description = Some(text),
elem::AUTHOR => self.current_item.author = Some(text),
elem::COMMENTS => self.current_item.comments = Some(text),
elem::PUB_DATE => self.current_item.pub_date = parse_rss2_date(&text),
elem::GUID => {
if let Some(guid) = &mut self.current_item.guid {
guid.value = text;
}
}
elem::CATEGORY => self.current_item.categories.push(Rss2Category {
name: text,
domain: self.pending_category_domain.take(),
}),
elem::SOURCE => {
if let Some(url) = self.pending_source_url.take() {
self.current_item.source = Some(Rss2Source { title: text, url });
}
}
elem::ITEM => {
self.current_item.extensions = std::mem::take(&mut self.item_acc).finish();
let item = std::mem::take(&mut self.current_item);
self.in_item = false;
if self.strict
&& item.title.as_deref().is_none_or(str::is_empty)
&& item.description.as_deref().is_none_or(str::is_empty)
{
return Err(FeedParseError::new(
"RSS 2.0 item must carry at least one of <title> or <description>",
));
}
return Ok(Action::ItemFinished(Box::new(item)));
}
_ => {}
}
} else if self.in_image_block {
match local {
elem::URL => self.image_url = parse_uri(&text),
elem::TITLE => self.image_title = text,
elem::LINK => self.image_link = parse_uri(&text),
elem::WIDTH => self.image_width = text.parse().ok(),
elem::HEIGHT => self.image_height = text.parse().ok(),
elem::DESCRIPTION => self.image_description = Some(text),
elem::IMAGE => {
self.in_image_block = false;
if let (Some(url), Some(link)) = (self.image_url.take(), self.image_link.take())
{
self.channel.image = Some(Rss2Image {
url,
title: std::mem::take(&mut self.image_title),
link,
width: self.image_width.take(),
height: self.image_height.take(),
description: self.image_description.take(),
});
}
}
_ => {}
}
} else {
let Some(text) = self.feed_acc.on_end(ns, local, text) else {
return Ok(Action::Continue);
};
if ns != Ns::None {
return Ok(Action::Continue);
}
match local {
elem::TITLE => self.channel.title = text,
elem::LINK => {
if let Some(link) = parse_uri(&text) {
self.channel.link = link;
self.channel_link_set = true;
} else if self.strict {
return Err(FeedParseError::new(format!(
"RSS 2.0 channel <link> could not be parsed as URI: {text:?}"
)));
}
}
elem::DESCRIPTION => self.channel.description = text,
elem::LANGUAGE => self.channel.language = Some(text),
elem::COPYRIGHT => self.channel.copyright = Some(text),
elem::MANAGING_EDITOR => self.channel.managing_editor = Some(text),
elem::WEB_MASTER => self.channel.web_master = Some(text),
elem::PUB_DATE => self.channel.pub_date = parse_rss2_date(&text),
elem::LAST_BUILD_DATE => self.channel.last_build_date = parse_rss2_date(&text),
elem::GENERATOR => self.channel.generator = Some(text),
elem::TTL => self.channel.ttl = text.parse().ok(),
elem::DOCS => self.channel.docs = Some(text),
elem::CATEGORY => self.channel.categories.push(Rss2Category {
name: text,
domain: self.pending_category_domain.take(),
}),
_ => {}
}
}
Ok(Action::Continue)
}
}