pub mod error;
use std::{pin::Pin, time::Duration};
use async_stream::try_stream;
use reqwest::{
Response, StatusCode,
header::{CONTENT_TYPE, HeaderValue},
};
use tokio::io::AsyncBufReadExt;
use tokio_stream::{Stream, StreamExt};
use tokio_util::io::StreamReader;
use crate::error::{EventError, EventSourceError};
/// Heap-allocated, pinned stream whose items are either a parsed [`Event`] or an
/// [`EventError`].
///
/// The trait object carries no `Send` bound.
pub type ServerSentEvents = Pin<Box<dyn Stream<Item = Result<Event, EventError>>>>;
/// Media type, as raw bytes, that a `Content-Type` header is compared against by
/// `is_event_stream`.
pub static MIME_EVENT_STREAM: &[u8] = b"text/event-stream";
/// Reports whether a `Content-Type` header value denotes an event stream.
///
/// Only the part of the header before the first `;` is considered (so parameters
/// such as `charset` are ignored). That part is trimmed of ASCII whitespace and
/// compared against [`MIME_EVENT_STREAM`] without regard to ASCII case.
fn is_event_stream(value: &HeaderValue) -> bool {
    let raw = value.as_bytes();
    // Everything up to (not including) the first ';', or the whole value if none.
    let media_type_end = raw.iter().position(|&b| b == b';').unwrap_or(raw.len());
    let media_type = raw[..media_type_end].trim_ascii();
    media_type.eq_ignore_ascii_case(MIME_EVENT_STREAM)
}
/// Accumulates the fields of the event currently being parsed.
///
/// `event_type` and `data` are per-event and are cleared by `produce_event`;
/// `last_event_id` and `retry` are only ever overwritten, never cleared, so they
/// carry over to later events.
struct EventBuffer {
// Value of the latest `event` field; empty means none was given.
event_type: String,
// Text collected from the `data` fields seen since the last dispatch.
data: String,
// Value of the latest `id` field, if any has been seen.
last_event_id: Option<String>,
// Value of the latest accepted `retry` field, if any has been seen.
retry: Option<Duration>,
}
impl EventBuffer {
    /// Creates a buffer with no pending event fields, no last event ID and no
    /// retry value.
    #[allow(clippy::new_without_default)]
    fn new() -> Self {
        Self {
            event_type: String::new(),
            data: String::new(),
            last_event_id: None,
            retry: None,
        }
    }

    /// Dispatches the buffered event, if there is one, and resets the per-event
    /// state (`event_type` and `data`).
    ///
    /// Returns `None` when no `data` field has been pushed since the previous
    /// call. Otherwise returns an [`Event`] whose type falls back to `"message"`
    /// when no `event` field was set. `last_event_id` and `retry` are copied into
    /// the event and deliberately left untouched, so they persist across events.
    fn produce_event(&mut self) -> Option<Event> {
        let event = if self.data.is_empty() {
            None
        } else {
            // `push_data` terminates every line with a line feed; the final one
            // is a terminator, not part of the payload.
            self.data.pop();
            Some(Event {
                event_type: if self.event_type.is_empty() {
                    "message".to_string()
                } else {
                    self.event_type.clone()
                },
                data: self.data.clone(),
                last_event_id: self.last_event_id.clone(),
                retry: self.retry,
            })
        };
        self.event_type.clear();
        self.data.clear();
        event
    }

    /// Replaces the pending event type with `event_type`.
    fn set_event_type(&mut self, event_type: &str) {
        self.event_type.clear();
        self.event_type.push_str(event_type);
    }

    /// Appends one `data` field value to the pending payload.
    ///
    /// The value is always followed by a line feed, as in the event stream
    /// processing model. Deciding on the separator from "is the buffer empty"
    /// would lose empty `data` lines: a lone `data:` must still dispatch an
    /// event with empty data, and two empty `data` lines must yield `"\n"`.
    fn push_data(&mut self, data: &str) {
        self.data.push_str(data);
        self.data.push('\n');
    }

    /// Records `id` as the last event ID.
    fn set_id(&mut self, id: &str) {
        self.last_event_id = Some(id.to_string());
    }

    /// Records `retry` as the latest retry value.
    fn set_retry(&mut self, retry: Duration) {
        self.retry = Some(retry);
    }
}
/// Splits one event-stream line into its field name and value.
///
/// The split happens at the first `:`. A single space directly after that colon
/// is dropped from the value; any further spaces are kept. A line without a
/// colon is returned whole as the field name, with an empty value.
fn parse_line(line: &str) -> (&str, &str) {
    match line.find(':') {
        Some(colon) => {
            let name = &line[..colon];
            let rest = &line[colon + 1..];
            (name, rest.strip_prefix(' ').unwrap_or(rest))
        }
        None => (line, ""),
    }
}
/// A single event parsed from an event stream.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Event {
/// Event type taken from the `event` field; `"message"` when the event did not
/// specify one.
pub event_type: String,
/// Payload built from the event's `data` fields, with successive fields
/// separated by `\n`.
pub data: String,
/// Most recent `id` field value seen on the stream up to this event, if any.
pub last_event_id: Option<String>,
/// Most recent accepted `retry` field value seen on the stream up to this
/// event, if any.
pub retry: Option<Duration>,
}
/// Conversion of a value into a stream of server-sent events.
pub trait EventSource {
/// Consumes `self` and resolves to a [`ServerSentEvents`] stream, or to an
/// [`EventSourceError`] if the source cannot be used as an event stream.
///
/// The returned future is required to be `Send`.
fn events(self) -> impl Future<Output = Result<ServerSentEvents, EventSourceError>> + Send;
}
impl EventSource for Response {
    /// Validates the response and exposes its body as a stream of parsed events.
    ///
    /// The body is read line by line. `event`, `data`, `id` and `retry` fields
    /// are collected into a buffer, a blank line dispatches the buffered event,
    /// and every other field name (including comment lines, whose field name is
    /// empty) is ignored. Data buffered when the body ends without a final blank
    /// line is discarded.
    ///
    /// # Errors
    ///
    /// * [`EventSourceError::BadStatus`] if the status is not `200 OK`.
    /// * [`EventSourceError::BadContentType`] if the `Content-Type` header is
    ///   missing (`None`) or is not `text/event-stream` (`Some(value)`).
    ///
    /// Failures while reading the body surface later, as
    /// [`EventError::IoError`] items of the returned stream.
    async fn events(self) -> Result<ServerSentEvents, EventSourceError> {
        let status = self.status();
        if status != StatusCode::OK {
            return Err(EventSourceError::BadStatus(status));
        }
        match self.headers().get(CONTENT_TYPE) {
            Some(content_type) => {
                if !is_event_stream(content_type) {
                    return Err(EventSourceError::BadContentType(Some(
                        content_type.to_owned(),
                    )));
                }
            }
            None => return Err(EventSourceError::BadContentType(None)),
        }
        // Adapt the byte-chunk stream into a buffered reader so it can be
        // consumed one line at a time.
        let mut stream = StreamReader::new(
            self.bytes_stream()
                .map(|result| result.map_err(std::io::Error::other)),
        );
        let mut line_buffer = String::new();
        let mut event_buffer = EventBuffer::new();
        let stream = Box::pin(try_stream! {
            loop {
                line_buffer.clear();
                let count = stream.read_line(&mut line_buffer).await.map_err(EventError::IoError)?;
                if count == 0 {
                    // End of body: anything still buffered is an incomplete
                    // event and is dropped.
                    break;
                }
                // Event streams may terminate lines with LF or CRLF. Without
                // removing the CR, the blank separator line of a CRLF stream
                // would be "\r", never empty, and no event would be dispatched.
                // NOTE: bare-CR line endings are not recognised, because
                // `read_line` only splits on LF.
                let line = line_buffer.strip_suffix('\n').unwrap_or(line_buffer.as_str());
                let line = line.strip_suffix('\r').unwrap_or(line);
                if line.is_empty() {
                    if let Some(event) = event_buffer.produce_event() {
                        yield event;
                    }
                    continue;
                }
                let (field, value) = parse_line(line);
                match field {
                    "event" => {
                        event_buffer.set_event_type(value);
                    }
                    "data" => {
                        event_buffer.push_data(value);
                    }
                    "id" => {
                        // An id containing U+0000 NULL must be ignored.
                        if !value.contains('\0') {
                            event_buffer.set_id(value);
                        }
                    }
                    "retry" => {
                        // Only plain ASCII digits are valid; the integer parser
                        // alone would also accept a leading '+'. Empty or
                        // overflowing values fail the parse and are ignored.
                        if value.bytes().all(|b| b.is_ascii_digit()) {
                            if let Ok(millis) = value.parse() {
                                event_buffer.set_retry(Duration::from_millis(millis));
                            }
                        }
                    }
                    _ => {}
                }
            }
        });
        Ok(stream)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `parse_line` splits at the first colon and drops one leading space of the value.
    #[test]
    fn parse_line_properly() {
        let cases = [
            ("event: message", ("event", "message")),
            ("non-standard field", ("non-standard field", "")),
            ("data:data with : inside", ("data", "data with : inside")),
        ];
        for (input, (expected_field, expected_value)) in cases {
            let (field, value) = parse_line(input);
            assert_eq!(field, expected_field);
            assert_eq!(value, expected_value);
        }
    }

    /// Parameters, surrounding whitespace and letter case do not prevent a match.
    #[test]
    fn is_event_stream_accept_valid_values() {
        let accepted = [
            "text/event-stream",
            "text/event-stream; charset=utf-8",
            " TEXT/event-stream ; charset=utf-8",
        ];
        for header in accepted {
            assert!(is_event_stream(&HeaderValue::from_static(header)));
        }
    }

    /// Other media types, including near misses, are rejected.
    #[test]
    fn is_event_stream_reject_invalid_values() {
        let rejected = ["plain/text", "text/event-but-not-realy"];
        for header in rejected {
            assert!(!is_event_stream(&HeaderValue::from_static(header)));
        }
    }
}