clientix_core/client/asynchronous/stream/sse.rs

1use crate::client::asynchronous::stream::{ClientixStream, ClientixStreamInterface};
2use crate::client::response::{ClientixError, ClientixResult};
3use futures_core::Stream;
4use futures_util::{StreamExt, TryStreamExt};
5use http::{HeaderMap, StatusCode, Version};
6use reqwest::Url;
7use std::net::SocketAddr;
8use std::pin::Pin;
9use std::str::FromStr;
10use std::task::{Context, Poll};
11use encoding_rs::UTF_8;
12use serde::de::DeserializeOwned;
13
// Field prefixes recognised when parsing a `text/event-stream` body.

/// Prefix of an `event` field line (the event type).
const EVENT_PROPERTY: &str = "event:";
/// Prefix of a `data` field line (the event payload).
const DATA_PROPERTY: &str = "data:";
/// Prefix of an `id` field line (the event identifier).
const ID_PROPERTY: &str = "id:";
/// Prefix of a `retry` field line (the reconnection delay).
const RETRY_PROPERTY: &str = "retry:";
/// A line starting with a bare colon is a comment.
const COMMENT_PROPERTY: &str = ":";
19
/// A single server-sent event.
///
/// `T` is the type of the `data` payload: `String` for a raw event stream,
/// or a converted type after `object_stream` / `json_stream` / `xml_stream`.
#[derive(Clone, Debug)]
pub struct SSE<T> {
    /// Value of the `id:` field, if present.
    id: Option<String>,
    /// Value of the `event:` field (the event type), if present.
    event: Option<String>,
    /// Comment text (from lines starting with `:`), if present.
    comment: Option<String>,
    /// Value of the `retry:` field parsed as an integer, if present and valid.
    retry: Option<u64>,
    /// Payload carried by the `data:` field(s), if any.
    data: Option<T>,
}
28
29impl<T> SSE<T> {
30
31    fn new() -> Self {
32        Self {
33            id: None,
34            event: None,
35            comment: None,
36            retry: None,
37            data: None,
38        }
39    }
40
41    pub fn id(&self) -> &Option<String> {
42        &self.id
43    }
44
45    pub fn event(&self) -> &Option<String> {
46        &self.event
47    }
48
49    pub fn comment(&self) -> &Option<String> {
50        &self.comment
51    }
52
53    pub fn retry(&self) -> &Option<u64> {
54        &self.retry
55    }
56
57    pub fn data(&self) -> &Option<T> {
58        &self.data
59    }
60
61}
62
63pub struct ClientixSSEStream<T> {
64    version: Version,
65    content_length: Option<u64>,
66    status: StatusCode,
67    url: Url,
68    remote_addr: Option<SocketAddr>,
69    headers: HeaderMap,
70    stream: Pin<Box<dyn Stream<Item = ClientixResult<SSE<T>>>>>,
71}
72
73impl<T> ClientixSSEStream<T> {
74
75    pub fn new(
76        version: Version,
77        content_length: Option<u64>,
78        status: StatusCode,
79        url: Url,
80        remote_addr: Option<SocketAddr>,
81        headers: HeaderMap,
82        stream: impl Stream<Item = ClientixResult<SSE<T>>> + 'static
83    ) -> Self {
84        Self {
85            version,
86            content_length,
87            status,
88            url,
89            remote_addr,
90            headers,
91            stream: Box::pin(stream)
92        }
93    }
94
95}
96
97impl ClientixSSEStream<String> {
98
99    pub fn object_stream<T, F>(self, mut convert: F) -> ClientixSSEStream<T> 
100    where T: DeserializeOwned + Clone, F: FnMut(&str) -> ClientixResult<T> + 'static {
101        let version = self.version();
102        let content_length = self.content_length();
103        let status = self.status();
104        let url = self.url().clone();
105        let remote_addr = self.remote_addr();
106        let headers = self.headers().clone();
107        let stream = self
108            .filter(|line| match line {
109                Ok(line) if !line.data.clone().unwrap_or(String::new()).contains("[DONE]") => futures_util::future::ready(true),
110                _ => futures_util::future::ready(false)
111            })
112            .map(move |line| match line {
113                Ok(line) => {
114                    let mut sse = SSE::new();
115                    sse.id = line.id.clone();
116                    sse.event = line.event.clone();
117                    sse.comment = line.comment.clone();
118                    sse.retry = line.retry;
119                    sse.data =  Some(convert(line.data.clone().unwrap_or(String::new()).as_str())?);
120
121                    Ok(sse)
122                },
123                Err(err) => Err(err),
124            });
125        
126        ClientixSSEStream::new(version, content_length, status, url, remote_addr, headers, stream)
127    }
128
129    pub fn json_stream<T>(self) -> ClientixSSEStream<T> where T: DeserializeOwned + Clone {
130        self.object_stream(|string| {
131            serde_json::from_str::<T>(string).map_err(ClientixError::from)
132        })
133    }
134    
135    pub fn xml_stream<T>(self) -> ClientixSSEStream<T> where T: DeserializeOwned + Clone {
136        self.object_stream(|string| serde_xml_rs::from_str::<T>(string).map_err(ClientixError::from))
137    }
138    
139}
140
141impl<T> Stream for ClientixSSEStream<T> {
142    type Item = ClientixResult<SSE<T>>;
143
144    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
145        self.stream.poll_next_unpin(cx)
146    }
147}
148
149impl<T> ClientixStreamInterface<SSE<T>> for ClientixSSEStream<T> {
150
151    fn version(&self) -> Version {
152        self.version
153    }
154
155    fn content_length(&self) -> Option<u64> {
156        self.content_length
157    }
158
159    fn status(&self) -> StatusCode {
160        self.status
161    }
162
163    fn url(&self) -> &Url {
164        &self.url
165    }
166
167    fn remote_addr(&self) -> Option<SocketAddr> {
168        self.remote_addr
169    }
170
171    fn headers(&self) -> &HeaderMap {
172        &self.headers
173    }
174
175    async fn execute<F>(mut self, mut handle: F) where F: FnMut(ClientixResult<SSE<T>>) {
176        while let Some(result) = self.stream.next().await {
177            handle(result);
178        }
179    }
180
181    async fn collect(self) -> ClientixResult<Vec<SSE<T>>> {
182        self.stream.try_collect().await
183    }
184
185}
186
187impl From<ClientixStream> for ClientixSSEStream<String> {
188    fn from(stream: ClientixStream) -> Self {
189        let version = stream.version();
190        let content_length = stream.content_length();
191        let status = stream.status();
192        let url = stream.url().clone();
193        let remote_addr = stream.remote_addr();
194        let headers = stream.headers().clone();
195
196        let mut buffer = String::new();
197        let stream = stream
198            .map(|chunk| match chunk {
199                Ok(chunk) => {
200                    let (text, _, _) = UTF_8.decode(&chunk);
201                    Ok(text.to_string())
202                },
203                Err(error) => Err(error),
204            })
205            .flat_map(move |text| match text {
206                Ok(text) => {
207                    let mut events = Vec::new();
208                    for line in text.lines() {
209                        let mut sse = SSE::new();
210                        match line {
211                            line if line.starts_with(ID_PROPERTY) => {
212                                sse.id = line.strip_prefix(ID_PROPERTY)
213                                    .map(str::trim)
214                                    .map(str::to_string);
215                            }
216                            line if line.starts_with(EVENT_PROPERTY) => {
217                                sse.event = line.strip_prefix(EVENT_PROPERTY)
218                                    .map(str::trim)
219                                    .map(str::to_string);
220                            }
221                            line if line.starts_with(COMMENT_PROPERTY) => {
222                                sse.comment = line.strip_prefix(COMMENT_PROPERTY)
223                                    .map(str::trim)
224                                    .map(str::to_string);
225                            }
226                            line if line.starts_with(RETRY_PROPERTY) => {
227                                sse.retry = line.strip_prefix(RETRY_PROPERTY)
228                                    .map(str::trim)
229                                    .map(u64::from_str)
230                                    .map(|result| match result {
231                                        Ok(value) => Some(value),
232                                        Err(_) => None
233                                    })
234                                    .unwrap_or(None);
235                            }
236                            line if line.starts_with(DATA_PROPERTY) => {
237                                buffer.push_str(line.trim_start_matches(DATA_PROPERTY).trim());
238                            }
239                            _ => {
240                                if !buffer.is_empty() {
241                                    sse.data = Some(buffer.to_string());
242                                    buffer.clear();
243
244                                    events.push(Ok(sse))
245                                }
246                            }
247                        }
248                    }
249
250                    futures_util::stream::iter(events)
251                }
252                Err(error) => futures_util::stream::iter(vec![Err(error)])
253            });
254
255        ClientixSSEStream::new(
256            version,
257            content_length,
258            status,
259            url,
260            remote_addr,
261            headers,
262            stream
263        )
264    }
265}