1use anyhow::{anyhow, Result};
12use async_trait::async_trait;
13use futures::Stream;
14use reqwest::{Client, Response, StatusCode};
15use std::collections::VecDeque;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::Arc;
18use std::time::Duration;
19
20use crate::streaming::providers::StreamEvent;
21use crate::streaming::sse::SseEvent;
22use crate::streaming::sse::SseParser;
23
/// Tunables for [`HttpAdapter`]: timeouts, retry policy and stream limits.
///
/// All durations are expressed in milliseconds.
#[derive(Clone, Debug)]
pub struct HttpConfig {
    /// Overall request timeout, configured on the underlying `reqwest` client.
    pub timeout_ms: u64,
    /// Maximum time allowed for establishing the TCP/TLS connection.
    pub connect_timeout_ms: u64,
    /// Number of retries allowed after the initial attempt for retryable failures.
    pub max_retries: u32,
    /// Base delay between retries; it doubles with every further attempt.
    pub retry_interval_ms: u64,
    /// Timeout budget intended for streaming requests.
    pub stream_timeout_ms: u64,
}
38
39impl Default for HttpConfig {
40 fn default() -> Self {
41 Self {
42 timeout_ms: 30000,
43 connect_timeout_ms: 10000,
44 max_retries: 3,
45 retry_interval_ms: 1000,
46 stream_timeout_ms: 60000,
47 }
48 }
49}
50
/// HTTP verb used when an [`HttpRequest`] is sent.
///
/// The enum is fieldless, so it is `Copy` and comparable; callers can use
/// `==` instead of `matches!` and pass it around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HttpMethod {
    /// `GET`
    Get,
    /// `POST`
    Post,
    /// `PUT`
    Put,
    /// `DELETE`
    Delete,
    /// `PATCH`
    Patch,
}
60
61#[derive(Debug, Clone)]
63pub struct HttpRequest {
64 pub url: String,
66 pub method: HttpMethod,
68 pub headers: Vec<(String, String)>,
70 pub body: Option<serde_json::Value>,
72}
73
74impl HttpRequest {
75 pub fn get(url: impl Into<String>) -> Self {
77 Self {
78 url: url.into(),
79 method: HttpMethod::Get,
80 headers: Vec::new(),
81 body: None,
82 }
83 }
84
85 pub fn post(url: impl Into<String>, body: serde_json::Value) -> Self {
87 Self {
88 url: url.into(),
89 method: HttpMethod::Post,
90 headers: Vec::new(),
91 body: Some(body),
92 }
93 }
94
95 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
97 self.headers.push((key.into(), value.into()));
98 self
99 }
100
101 pub fn bearer_auth(mut self, token: impl Into<String>) -> Self {
103 self.headers.push((
104 "Authorization".to_string(),
105 format!("Bearer {}", token.into()),
106 ));
107 self
108 }
109
110 pub fn api_key(mut self, key: impl Into<String>) -> Self {
112 self.headers.push(("x-api-key".to_string(), key.into()));
113 self
114 }
115}
116
117pub struct HttpResponseStream {
121 response: Response,
122 parser: SseParser,
123 pending: VecDeque<StreamEvent>,
124 done: bool,
125 abort_flag: Arc<AtomicBool>,
126}
127
128impl HttpResponseStream {
129 pub fn new(response: Response, abort_flag: Arc<AtomicBool>) -> Self {
131 Self {
132 response,
133 parser: SseParser::new(),
134 pending: VecDeque::new(),
135 done: false,
136 abort_flag,
137 }
138 }
139
140 pub fn status(&self) -> StatusCode {
142 self.response.status()
143 }
144
145 pub fn headers(&self) -> &reqwest::header::HeaderMap {
147 self.response.headers()
148 }
149
150 pub async fn next_event(&mut self) -> Result<Option<StreamEvent>> {
152 loop {
154 if self.abort_flag.load(Ordering::Relaxed) {
155 return Ok(None);
156 }
157
158 if let Some(event) = self.pending.pop_front() {
159 return Ok(Some(event));
160 }
161
162 if self.done {
163 let _remaining = self.parser.finish()?;
164 if let Some(event) = self.pending.pop_front() {
165 return Ok(Some(event));
166 }
167 return Ok(None);
168 }
169
170 match self.response.chunk().await? {
171 Some(chunk) => {
172 let sse_events = self.parser.push(&chunk)?;
173 for _sse_event in sse_events {
174 self.pending.push_back(StreamEvent::MessageStart {
176 id: String::new(),
177 model: String::new(),
178 });
179 }
180 }
181 None => {
182 self.done = true;
183 }
184 }
185 }
186 }
187
188 pub async fn collect_text(&mut self) -> Result<String> {
190 let mut result = String::new();
191 while let Some(event) = self.next_event().await? {
192 if let StreamEvent::ContentBlockDelta {
193 delta: crate::streaming::providers::ContentDelta::Text(t),
194 ..
195 } = event
196 {
197 result.push_str(&t);
198 }
199 }
200 Ok(result)
201 }
202
203 pub fn into_sse_stream(mut self) -> impl Stream<Item = Result<SseEvent>> {
205 async_stream::stream! {
206 loop {
207 if self.abort_flag.load(Ordering::Relaxed) {
208 break;
209 }
210
211 match self.response.chunk().await {
212 Ok(Some(chunk)) => {
213 let events = self.parser.push(&chunk)?;
214 for event in events {
215 yield Ok(event);
216 }
217 }
218 Ok(None) => {
219 let remaining = self.parser.finish()?;
220 for event in remaining {
221 yield Ok(event);
222 }
223 break;
224 }
225 Err(e) => {
226 yield Err(anyhow!("Stream error: {}", e));
227 break;
228 }
229 }
230 }
231 }
232 }
233}
234
235pub struct HttpAdapter {
239 client: Client,
241 config: HttpConfig,
243 abort_flag: Arc<AtomicBool>,
245}
246
247impl HttpAdapter {
248 pub fn new() -> Self {
250 Self::with_config(HttpConfig::default())
251 }
252
253 pub fn with_config(config: HttpConfig) -> Self {
255 let client = Client::builder()
256 .timeout(Duration::from_millis(config.timeout_ms))
257 .connect_timeout(Duration::from_millis(config.connect_timeout_ms))
258 .build()
259 .expect("Failed to create HTTP client");
260
261 Self {
262 client,
263 config,
264 abort_flag: Arc::new(AtomicBool::new(false)),
265 }
266 }
267
268 pub fn abort_flag(&self) -> Arc<AtomicBool> {
270 Arc::clone(&self.abort_flag)
271 }
272
273 pub fn abort(&self) {
275 self.abort_flag.store(true, Ordering::Relaxed);
276 }
277
278 pub fn reset(&self) {
280 self.abort_flag.store(false, Ordering::Relaxed);
281 }
282
283 pub fn is_aborted(&self) -> bool {
285 self.abort_flag.load(Ordering::Relaxed)
286 }
287
288 pub async fn request(&self, request: HttpRequest) -> Result<Response> {
290 self.request_with_retry(request, self.config.max_retries)
291 .await
292 }
293
294 async fn request_with_retry(&self, request: HttpRequest, max_retries: u32) -> Result<Response> {
296 let mut attempts = 0;
297
298 loop {
299 if self.is_aborted() {
300 return Err(anyhow!("Request aborted"));
301 }
302
303 attempts += 1;
304
305 let result = self.execute_request(&request).await;
306
307 match result {
308 Ok(response) => {
309 let status = response.status();
310 if status.is_success() {
311 return Ok(response);
312 }
313
314 if Self::is_retryable_status(status) && attempts <= max_retries {
316 tracing::warn!(
317 "HTTP request failed with status {}, attempt {}/{}",
318 status,
319 attempts,
320 max_retries
321 );
322 let delay = Duration::from_millis(
323 self.config.retry_interval_ms * (1 << (attempts - 1)),
324 );
325 tokio::time::sleep(delay).await;
326 continue;
327 }
328
329 let body = response.text().await.unwrap_or_default();
331 return Err(anyhow!("HTTP {}: {}", status, body));
332 }
333 Err(e) => {
334 if Self::is_retryable_error(&e) && attempts <= max_retries {
336 tracing::warn!(
337 "HTTP request error: {}, attempt {}/{}",
338 e,
339 attempts,
340 max_retries
341 );
342 let delay = Duration::from_millis(
343 self.config.retry_interval_ms * (1 << (attempts - 1)),
344 );
345 tokio::time::sleep(delay).await;
346 continue;
347 }
348 return Err(e);
349 }
350 }
351 }
352 }
353
354 async fn execute_request(&self, request: &HttpRequest) -> Result<Response> {
356 let builder = match request.method {
357 HttpMethod::Get => self.client.get(&request.url),
358 HttpMethod::Post => self.client.post(&request.url),
359 HttpMethod::Put => self.client.put(&request.url),
360 HttpMethod::Delete => self.client.delete(&request.url),
361 HttpMethod::Patch => self.client.patch(&request.url),
362 };
363
364 let builder = request
366 .headers
367 .iter()
368 .fold(builder, |b, (k, v)| b.header(k, v));
369
370 let builder = if let Some(body) = &request.body {
372 builder.json(body)
373 } else {
374 builder
375 };
376
377 let response = builder.send().await?;
378 Ok(response)
379 }
380
381 pub async fn request_stream(&self, request: HttpRequest) -> Result<HttpResponseStream> {
383 let response = self.request(request).await?;
384 Ok(HttpResponseStream::new(response, self.abort_flag.clone()))
385 }
386
387 pub async fn request_sse(&self, request: HttpRequest) -> Result<SseStream> {
389 let builder = self.client.post(&request.url);
390
391 let builder = request
392 .headers
393 .iter()
394 .fold(builder, |b, (k, v)| b.header(k, v));
395
396 let builder = if let Some(body) = &request.body {
397 builder.json(body)
398 } else {
399 builder
400 };
401
402 let builder = builder.header("Accept", "text/event-stream");
403
404 let response = builder.send().await?;
405
406 let status = response.status();
407 if !status.is_success() {
408 let body = response.text().await.unwrap_or_default();
409 return Err(anyhow!(
410 "SSE request failed with status {}: {}",
411 status,
412 body
413 ));
414 }
415
416 Ok(SseStream::new(response, self.abort_flag.clone()))
417 }
418
419 fn is_retryable_status(status: StatusCode) -> bool {
421 matches!(status.as_u16(), 429 | 500 | 502 | 503 | 504)
422 }
423
424 fn is_retryable_error(error: &anyhow::Error) -> bool {
426 let msg = error.to_string().to_lowercase();
427 msg.contains("timeout")
428 || msg.contains("connection")
429 || msg.contains("network")
430 || msg.contains("429")
431 || msg.contains("overloaded")
432 }
433}
434
435impl Default for HttpAdapter {
436 fn default() -> Self {
437 Self::new()
438 }
439}
440
441pub struct SseStream {
445 response: Response,
446 parser: SseParser,
447 abort_flag: Arc<AtomicBool>,
448 done: bool,
449}
450
451impl SseStream {
452 pub fn new(response: Response, abort_flag: Arc<AtomicBool>) -> Self {
454 Self {
455 response,
456 parser: SseParser::new(),
457 abort_flag,
458 done: false,
459 }
460 }
461
462 pub async fn next_event(&mut self) -> Result<Option<SseEvent>> {
464 loop {
465 if self.abort_flag.load(Ordering::Relaxed) {
466 return Ok(None);
467 }
468
469 if self.done {
470 let remaining = self.parser.finish()?;
471 if remaining.is_empty() {
472 return Ok(None);
473 }
474 return Ok(remaining.into_iter().next());
476 }
477
478 match self.response.chunk().await? {
479 Some(chunk) => {
480 let events = self.parser.push(&chunk)?;
481 if !events.is_empty() {
482 return Ok(Some(events.into_iter().next().unwrap()));
483 }
484 }
485 None => {
486 self.done = true;
487 }
488 }
489 }
490 }
491
492 pub async fn collect_events(&mut self) -> Result<Vec<SseEvent>> {
494 let mut events = Vec::new();
495 while let Some(event) = self.next_event().await? {
496 events.push(event);
497 }
498 Ok(events)
499 }
500}
501
502#[async_trait]
506pub trait HttpAdapterTrait: Send + Sync {
507 async fn get(&self, url: &str) -> Result<String>;
509
510 async fn post(&self, url: &str, body: serde_json::Value) -> Result<String>;
512
513 async fn post_stream(&self, url: &str, body: serde_json::Value) -> Result<HttpResponseStream>;
515
516 async fn post_sse(&self, url: &str, body: serde_json::Value) -> Result<SseStream>;
518}
519
520#[async_trait]
521impl HttpAdapterTrait for HttpAdapter {
522 async fn get(&self, url: &str) -> Result<String> {
523 let request = HttpRequest::get(url);
524 let response = self.request(request).await?;
525 let text = response.text().await?;
526 Ok(text)
527 }
528
529 async fn post(&self, url: &str, body: serde_json::Value) -> Result<String> {
530 let request = HttpRequest::post(url, body);
531 let response = self.request(request).await?;
532 let text = response.text().await?;
533 Ok(text)
534 }
535
536 async fn post_stream(&self, url: &str, body: serde_json::Value) -> Result<HttpResponseStream> {
537 let request = HttpRequest::post(url, body);
538 self.request_stream(request).await
539 }
540
541 async fn post_sse(&self, url: &str, body: serde_json::Value) -> Result<SseStream> {
542 let request = HttpRequest::post(url, body).header("Accept", "text/event-stream");
543 self.request_sse(request).await
544 }
545}
546
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_http_config_default() {
        let config = HttpConfig::default();
        assert_eq!(config.timeout_ms, 30000);
        assert_eq!(config.connect_timeout_ms, 10000);
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.retry_interval_ms, 1000);
        assert_eq!(config.stream_timeout_ms, 60000);
    }

    #[test]
    fn test_http_request_builder() {
        let request = HttpRequest::get("https://api.example.com")
            .bearer_auth("token123")
            .header("X-Custom", "value");

        assert_eq!(request.url, "https://api.example.com");
        assert!(matches!(request.method, HttpMethod::Get));
        assert_eq!(
            request.headers,
            vec![
                ("Authorization".to_string(), "Bearer token123".to_string()),
                ("X-Custom".to_string(), "value".to_string()),
            ]
        );
        assert!(request.body.is_none());
    }

    #[test]
    fn test_http_request_api_key() {
        let request = HttpRequest::get("https://api.example.com").api_key("secret");
        assert_eq!(
            request.headers,
            vec![("x-api-key".to_string(), "secret".to_string())]
        );
    }

    #[test]
    fn test_http_request_post() {
        let body = serde_json::json!({"key": "value"});
        let request = HttpRequest::post("https://api.example.com", body.clone());

        assert_eq!(request.url, "https://api.example.com");
        assert!(matches!(request.method, HttpMethod::Post));
        assert_eq!(request.body, Some(body));
    }

    #[test]
    fn test_is_retryable_status() {
        assert!(HttpAdapter::is_retryable_status(
            StatusCode::TOO_MANY_REQUESTS
        ));
        assert!(HttpAdapter::is_retryable_status(
            StatusCode::INTERNAL_SERVER_ERROR
        ));
        assert!(HttpAdapter::is_retryable_status(StatusCode::BAD_GATEWAY));
        assert!(HttpAdapter::is_retryable_status(
            StatusCode::SERVICE_UNAVAILABLE
        ));
        assert!(HttpAdapter::is_retryable_status(
            StatusCode::GATEWAY_TIMEOUT
        ));

        assert!(!HttpAdapter::is_retryable_status(StatusCode::BAD_REQUEST));
        assert!(!HttpAdapter::is_retryable_status(StatusCode::UNAUTHORIZED));
        assert!(!HttpAdapter::is_retryable_status(StatusCode::NOT_FOUND));
    }

    #[test]
    fn test_is_retryable_error() {
        assert!(HttpAdapter::is_retryable_error(&anyhow::anyhow!(
            "connection refused"
        )));
        assert!(HttpAdapter::is_retryable_error(&anyhow::anyhow!(
            "Request timeout"
        )));
        assert!(HttpAdapter::is_retryable_error(&anyhow::anyhow!(
            "server overloaded"
        )));
        assert!(!HttpAdapter::is_retryable_error(&anyhow::anyhow!(
            "invalid json"
        )));
    }

    #[test]
    fn test_http_adapter_creation() {
        let adapter = HttpAdapter::new();
        assert!(!adapter.is_aborted());
    }

    #[test]
    fn test_http_adapter_abort() {
        let adapter = HttpAdapter::new();
        adapter.abort();
        assert!(adapter.is_aborted());
        adapter.reset();
        assert!(!adapter.is_aborted());
    }

    #[test]
    fn test_abort_flag_is_shared() {
        let adapter = HttpAdapter::new();
        let flag = adapter.abort_flag();
        adapter.abort();
        assert!(flag.load(Ordering::Relaxed));
        adapter.reset();
        assert!(!flag.load(Ordering::Relaxed));
    }
}