kode_bridge/
stream_client.rs

use crate::errors::{KodeBridgeError, Result};
use bytes::Bytes;
use futures::stream::StreamExt;
use http::{header, HeaderMap, StatusCode};
use pin_project_lite::pin_project;
use serde::de::DeserializeOwned;
use serde_json::Value;
use std::collections::HashMap;
use std::pin::Pin;
use std::str::FromStr;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
use tokio_stream::Stream;
use tokio_util::codec::{FramedRead, LinesCodec};
use tracing::{debug, trace, warn};
17
18pin_project! {
19    /// Streaming HTTP response that yields data as it arrives
20    pub struct StreamingResponse {
21        pub status: StatusCode,
22        pub headers: HeaderMap,
23        #[pin]
24        pub stream: Pin<Box<dyn Stream<Item = std::result::Result<String, std::io::Error>> + Send>>,
25    }
26}
27
28impl StreamingResponse {
29    pub fn new(
30        status: StatusCode,
31        headers: HeaderMap,
32        stream: Pin<Box<dyn Stream<Item = std::result::Result<String, std::io::Error>> + Send>>,
33    ) -> Self {
34        Self {
35            status,
36            headers,
37            stream,
38        }
39    }
40
41    /// Get HTTP status code
42    pub fn status(&self) -> StatusCode {
43        self.status
44    }
45
46    /// Get status code as u16
47    pub fn status_code(&self) -> u16 {
48        self.status.as_u16()
49    }
50
51    /// Get response headers
52    pub fn headers(&self) -> &HeaderMap {
53        &self.headers
54    }
55
56    /// Check if response indicates success (2xx status)
57    pub fn is_success(&self) -> bool {
58        self.status.is_success()
59    }
60
61    /// Check if response indicates client error (4xx status)
62    pub fn is_client_error(&self) -> bool {
63        self.status.is_client_error()
64    }
65
66    /// Check if response indicates server error (5xx status)
67    pub fn is_server_error(&self) -> bool {
68        self.status.is_server_error()
69    }
70
71    /// Get content length from headers
72    pub fn content_length(&self) -> Option<u64> {
73        self.headers
74            .get(header::CONTENT_LENGTH)?
75            .to_str()
76            .ok()?
77            .parse()
78            .ok()
79    }
80
81    /// Get content type from headers
82    pub fn content_type(&self) -> Option<&str> {
83        self.headers.get(header::CONTENT_TYPE)?.to_str().ok()
84    }
85
86    /// Elegant JSON stream processing - automatically parse and filter valid data
87    pub async fn json<T>(mut self, timeout: Duration) -> Result<Vec<T>>
88    where
89        T: DeserializeOwned + Send,
90    {
91        let mut results = Vec::new();
92        let timeout_future = tokio::time::sleep(timeout);
93        tokio::pin!(timeout_future);
94
95        loop {
96            tokio::select! {
97                line_result = self.stream.next() => {
98                    match line_result {
99                        Some(Ok(line)) => {
100                            if line.trim().is_empty() {
101                                continue;
102                            }
103                            // Auto-parse JSON, ignore failures for robustness
104                            if let Ok(parsed) = serde_json::from_str::<T>(&line) {
105                                results.push(parsed);
106                            } else {
107                                trace!("Failed to parse JSON line: {}", line);
108                            }
109                        }
110                        Some(Err(e)) => {
111                            warn!("Stream error: {}", e);
112                            break;
113                        }
114                        None => break,
115                    }
116                }
117                _ = &mut timeout_future => {
118                    debug!("Stream timeout reached after {}ms", timeout.as_millis());
119                    break;
120                }
121            }
122        }
123
124        Ok(results)
125    }
126
127    /// Process stream with custom JSON handler
128    pub async fn process_json<F, T>(mut self, timeout: Duration, mut handler: F) -> Result<Vec<T>>
129    where
130        F: FnMut(&str) -> Option<T>,
131        T: Send + 'static,
132    {
133        let mut results = Vec::new();
134        let timeout_future = tokio::time::sleep(timeout);
135        tokio::pin!(timeout_future);
136
137        loop {
138            tokio::select! {
139                line_result = self.stream.next() => {
140                    match line_result {
141                        Some(Ok(line)) => {
142                            if let Some(parsed) = handler(&line) {
143                                results.push(parsed);
144                            }
145                        }
146                        Some(Err(e)) => {
147                            warn!("Stream error: {}", e);
148                            break;
149                        }
150                        None => break,
151                    }
152                }
153                _ = &mut timeout_future => break,
154            }
155        }
156
157        Ok(results)
158    }
159
160    /// Process stream data in real-time with error handling
161    pub async fn process_lines<F>(mut self, mut handler: F) -> Result<()>
162    where
163        F: FnMut(&str) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>>,
164    {
165        while let Some(line_result) = self.stream.next().await {
166            match line_result {
167                Ok(line) => {
168                    if let Err(e) = handler(&line) {
169                        warn!("Handler error: {}", e);
170                        return Err(KodeBridgeError::custom(format!("Handler error: {}", e)));
171                    }
172                }
173                Err(e) => {
174                    warn!("Stream error: {}", e);
175                    return Err(KodeBridgeError::from(e));
176                }
177            }
178        }
179        Ok(())
180    }
181
182    /// Process stream with timeout and error handling - optimized for better performance
183    pub async fn process_lines_with_timeout<F>(
184        mut self,
185        timeout: Duration,
186        mut handler: F,
187    ) -> Result<()>
188    where
189        F: FnMut(&str) -> std::result::Result<bool, Box<dyn std::error::Error + Send + Sync>>, // Return false to stop
190    {
191        // 使用更短的超时避免长时间的waker等待
192        let optimized_timeout = std::cmp::min(timeout, Duration::from_secs(5));
193        let timeout_future = tokio::time::sleep(optimized_timeout);
194        tokio::pin!(timeout_future);
195
196        loop {
197            tokio::select! {
198                line_result = self.stream.next() => {
199                    match line_result {
200                        Some(Ok(line)) => {
201                            match handler(&line) {
202                                Ok(continue_processing) => {
203                                    if !continue_processing {
204                                        break;
205                                    }
206                                    // 重置超时计时器以避免不必要的超时
207                                    timeout_future.as_mut().reset(tokio::time::Instant::now() + optimized_timeout);
208                                }
209                                Err(e) => {
210                                    warn!("Handler error: {}", e);
211                                    return Err(KodeBridgeError::custom(format!("Handler error: {}", e)));
212                                }
213                            }
214                        }
215                        Some(Err(e)) => {
216                            warn!("Stream error: {}", e);
217                            return Err(KodeBridgeError::from(e));
218                        }
219                        None => break,
220                    }
221                }
222                _ = &mut timeout_future => {
223                    debug!("Processing timeout reached ({:?})", optimized_timeout);
224                    break;
225                }
226            }
227        }
228
229        Ok(())
230    }
231
232    /// Collect all stream data into a string
233    pub async fn collect_text(mut self) -> Result<String> {
234        let mut body_lines = Vec::new();
235
236        while let Some(line_result) = self.stream.next().await {
237            match line_result {
238                Ok(line) => body_lines.push(line),
239                Err(e) => return Err(KodeBridgeError::from(e)),
240            }
241        }
242
243        Ok(body_lines.join("\n"))
244    }
245
246    /// Collect stream data with a timeout - optimized for better performance
247    pub async fn collect_text_with_timeout(mut self, timeout: Duration) -> Result<String> {
248        let mut body_lines = Vec::new();
249
250        // 限制最大超时时间避免长时间waker等待
251        let optimized_timeout = std::cmp::min(timeout, Duration::from_secs(30));
252        let timeout_future = tokio::time::sleep(optimized_timeout);
253        tokio::pin!(timeout_future);
254
255        loop {
256            tokio::select! {
257                line_result = self.stream.next() => {
258                    match line_result {
259                        Some(Ok(line)) => {
260                            body_lines.push(line);
261                            // 收到数据后重置超时,避免不必要的超时
262                            timeout_future.as_mut().reset(tokio::time::Instant::now() + optimized_timeout);
263                        }
264                        Some(Err(e)) => return Err(KodeBridgeError::from(e)),
265                        None => break, // Stream ended
266                    }
267                }
268                _ = &mut timeout_future => {
269                    debug!("Collection timeout reached");
270                    break; // Timeout reached
271                }
272            }
273        }
274
275        Ok(body_lines.join("\n"))
276    }
277
278    /// Convert to legacy format for compatibility
279    pub fn status_u16(&self) -> u16 {
280        self.status.as_u16()
281    }
282
283    /// Get headers as JSON value for compatibility
284    pub fn headers_json(&self) -> Value {
285        let headers_map: HashMap<String, String> = self
286            .headers
287            .iter()
288            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
289            .collect();
290        serde_json::to_value(headers_map).unwrap_or(Value::Null)
291    }
292}
293
294impl Stream for StreamingResponse {
295    type Item = std::result::Result<String, std::io::Error>;
296
297    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
298        let this = self.project();
299        this.stream.poll_next(cx)
300    }
301}
302
303/// Parse HTTP response headers and create streaming response
304pub async fn parse_streaming_response<S>(stream: S) -> Result<StreamingResponse>
305where
306    S: AsyncRead + Unpin + Send + 'static,
307{
308    let mut reader = BufReader::new(stream);
309    let mut buffer = Vec::new();
310
311    // Read until we have the complete headers
312    let mut headers_end = None;
313    loop {
314        let mut line = Vec::new();
315        let n = reader.read_until(b'\n', &mut line).await?;
316        if n == 0 {
317            return Err(KodeBridgeError::protocol("Unexpected end of stream"));
318        }
319
320        buffer.extend_from_slice(&line);
321
322        // Check for end of headers (\r\n\r\n)
323        if buffer.len() >= 4 {
324            for i in 0..buffer.len() - 3 {
325                if &buffer[i..i + 4] == b"\r\n\r\n" {
326                    headers_end = Some(i + 4);
327                    break;
328                }
329            }
330        }
331
332        if headers_end.is_some() {
333            break;
334        }
335    }
336
337    let headers_end = headers_end
338        .ok_or_else(|| KodeBridgeError::protocol("Could not find end of HTTP headers"))?;
339
340    // Parse the headers using httparse
341    let mut headers = [httparse::EMPTY_HEADER; 64];
342    let mut response = httparse::Response::new(&mut headers);
343
344    let status = match response.parse(&buffer[..headers_end])? {
345        httparse::Status::Complete(_) => response
346            .code
347            .ok_or_else(|| KodeBridgeError::protocol("HTTP response missing status code"))?,
348        httparse::Status::Partial => {
349            return Err(KodeBridgeError::protocol("Incomplete HTTP response"));
350        }
351    };
352
353    // Build HeaderMap
354    let mut header_map = HeaderMap::new();
355    for header in response.headers {
356        let name =
357            http::HeaderName::from_str(header.name).map_err(|e| KodeBridgeError::Http(e.into()))?;
358        let value = http::HeaderValue::from_bytes(header.value)
359            .map_err(|e| KodeBridgeError::Http(e.into()))?;
360        header_map.insert(name, value);
361    }
362
363    // Create line stream from the remaining reader
364    let framed = FramedRead::new(reader, LinesCodec::new());
365    let line_stream = framed.map(|result| result.map_err(std::io::Error::other));
366
367    Ok(StreamingResponse::new(
368        StatusCode::from_u16(status)?,
369        header_map,
370        Box::pin(line_stream),
371    ))
372}
373
374/// Send HTTP request and get streaming response
375pub async fn send_streaming_request<S>(mut stream: S, request: Bytes) -> Result<StreamingResponse>
376where
377    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
378{
379    // Send request
380    stream.write_all(&request).await?;
381    stream.flush().await?;
382
383    trace!("Sent HTTP streaming request ({} bytes)", request.len());
384
385    // Parse response
386    let response = parse_streaming_response(stream).await?;
387
388    debug!(
389        "Received HTTP streaming response: {} {}",
390        response.status(),
391        response.content_length().unwrap_or(0)
392    );
393
394    Ok(response)
395}