use std::pin::Pin;
use std::task::{Context, Poll};
use syncable_ag_ui_core::{Event, JsonValue};
use futures::Stream;
use reqwest::Client;
use reqwest_eventsource::{Event as SseEvent, EventSource};
use crate::error::{ClientError, Result};
/// Settings consulted when an SSE connection is opened.
///
/// Values are assembled builder-style: every setter consumes the
/// configuration and hands back the updated value so calls can be chained.
#[derive(Debug, Clone)]
pub struct SseConfig {
    /// Timeout handed to the HTTP client when a connection is opened.
    pub connect_timeout: std::time::Duration,
    /// Additional request headers as `(name, value)` pairs, kept in the
    /// order in which they were added.
    pub headers: Vec<(String, String)>,
}

impl Default for SseConfig {
    /// Produces a configuration with a 30-second timeout and no extra headers.
    fn default() -> Self {
        let connect_timeout = std::time::Duration::from_secs(30);
        let headers = Vec::new();
        Self {
            connect_timeout,
            headers,
        }
    }
}

impl SseConfig {
    /// Creates a configuration populated with the default values.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Replaces the timeout and returns the updated configuration.
    pub fn connect_timeout(self, timeout: std::time::Duration) -> Self {
        Self {
            connect_timeout: timeout,
            ..self
        }
    }

    /// Appends one `(name, value)` header pair and returns the updated
    /// configuration. Existing entries are kept; nothing is de-duplicated.
    pub fn header(self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let mut updated = self;
        let entry = (name.into(), value.into());
        updated.headers.push(entry);
        updated
    }

    /// Appends an `Authorization` header whose value is `Bearer <token>`.
    pub fn bearer_token(self, token: impl Into<String>) -> Self {
        let credential = format!("Bearer {}", token.into());
        self.header("Authorization", credential)
    }
}
/// Client handle that owns the `EventSource` built for one SSE endpoint.
///
/// Created through `connect` / `connect_with_config` and turned into a
/// stream of parsed events with `into_stream`.
pub struct SseClient {
/// Event source constructed from the configured GET request.
event_source: EventSource,
}
impl SseClient {
    /// Creates a client for `url` using [`SseConfig::default`].
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`SseClient::connect_with_config`].
    pub async fn connect(url: &str) -> Result<Self> {
        Self::connect_with_config(url, SseConfig::default()).await
    }

    /// Creates a client for `url`, applying the timeout and headers in `config`.
    ///
    /// A GET request to `url` is prepared with every configured header and
    /// wrapped in an `EventSource`.
    ///
    /// # Errors
    ///
    /// Returns a connection error (built with `ClientError::connection`) when
    /// the HTTP client cannot be built or when the event source cannot be
    /// created from the request.
    pub async fn connect_with_config(url: &str, config: SseConfig) -> Result<Self> {
        // Bound only the connect phase. `ClientBuilder::timeout` is a deadline
        // for the entire request, response body included, so using it here
        // would abort a long-lived SSE stream once `connect_timeout` elapsed.
        let client = Client::builder()
            .connect_timeout(config.connect_timeout)
            .build()
            .map_err(|e| ClientError::connection(e.to_string()))?;

        let mut request = client.get(url);
        for (name, value) in config.headers {
            request = request.header(&name, &value);
        }

        let event_source =
            EventSource::new(request).map_err(|e| ClientError::connection(e.to_string()))?;
        Ok(Self { event_source })
    }

    /// Consumes the client and returns a stream of parsed events.
    pub fn into_stream(self) -> SseEventStream {
        SseEventStream {
            event_source: self.event_source,
        }
    }
}
/// Stream adapter that yields events parsed from the JSON payload of each
/// SSE message delivered by the wrapped `EventSource`.
pub struct SseEventStream {
/// Underlying event source that is polled for raw SSE events.
event_source: EventSource,
}
impl Stream for SseEventStream {
    type Item = Result<Event<JsonValue>>;

    /// Polls the wrapped event source and converts what it yields.
    ///
    /// * `Open` notifications are skipped and the source is polled again.
    /// * `Message` payloads are deserialized from JSON; a failure is yielded
    ///   as a parse error and does not end the stream by itself.
    /// * Errors reported by the source are yielded as SSE errors.
    /// * `Pending` and end-of-stream are passed through unchanged.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            let message = match Pin::new(&mut self.event_source).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(Err(source_err))) => {
                    return Poll::Ready(Some(Err(ClientError::sse(source_err.to_string()))));
                }
                // An open notification carries no payload; ask for the next item.
                Poll::Ready(Some(Ok(SseEvent::Open))) => continue,
                Poll::Ready(Some(Ok(SseEvent::Message(message)))) => message,
            };

            let parsed =
                serde_json::from_str::<Event<JsonValue>>(&message.data).map_err(|parse_err| {
                    ClientError::parse(format!("failed to parse event: {}", parse_err))
                });
            return Poll::Ready(Some(parsed));
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration has a 30-second timeout and no headers.
    #[test]
    fn test_sse_config_default() {
        let SseConfig {
            connect_timeout,
            headers,
        } = SseConfig::default();

        assert_eq!(connect_timeout, std::time::Duration::from_secs(30));
        assert!(headers.is_empty());
    }

    /// Chained setters override the timeout and append headers in call order.
    #[test]
    fn test_sse_config_builder() {
        let sixty_secs = std::time::Duration::from_secs(60);
        let config = SseConfig::new()
            .connect_timeout(sixty_secs)
            .header("X-Custom", "value")
            .bearer_token("token123");

        let expected_headers = vec![
            ("X-Custom".to_string(), "value".to_string()),
            ("Authorization".to_string(), "Bearer token123".to_string()),
        ];

        assert_eq!(config.connect_timeout, sixty_secs);
        assert_eq!(config.headers.len(), 2);
        assert_eq!(config.headers, expected_headers);
    }
}