1use anyhow::Result;
9use async_stream::try_stream;
10use futures_core::Stream;
11use futures_util::StreamExt;
12use reqwest::{
13 Client, Method,
14 header::{self, HeaderMap, HeaderName, HeaderValue},
15};
16use serde::Serialize;
17use wcore::model::{Response, StreamChunk};
18
19#[derive(Clone)]
24pub struct HttpProvider {
25 client: Client,
26 headers: HeaderMap,
27 endpoint: String,
28}
29
30impl HttpProvider {
31 pub fn bearer(client: Client, key: &str, endpoint: &str) -> Result<Self> {
33 let mut headers = HeaderMap::new();
34 headers.insert(
35 header::CONTENT_TYPE,
36 HeaderValue::from_static("application/json"),
37 );
38 headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
39 headers.insert(header::AUTHORIZATION, format!("Bearer {key}").parse()?);
40 Ok(Self {
41 client,
42 headers,
43 endpoint: endpoint.to_owned(),
44 })
45 }
46
47 pub fn no_auth(client: Client, endpoint: &str) -> Self {
49 let mut headers = HeaderMap::new();
50 headers.insert(
51 header::CONTENT_TYPE,
52 HeaderValue::from_static("application/json"),
53 );
54 headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
55 Self {
56 client,
57 headers,
58 endpoint: endpoint.to_owned(),
59 }
60 }
61
62 pub fn custom_header(
67 client: Client,
68 header_name: &str,
69 header_value: &str,
70 endpoint: &str,
71 ) -> Result<Self> {
72 let mut headers = HeaderMap::new();
73 headers.insert(
74 header::CONTENT_TYPE,
75 HeaderValue::from_static("application/json"),
76 );
77 headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
78 headers.insert(
79 header_name.parse::<HeaderName>()?,
80 header_value.parse::<HeaderValue>()?,
81 );
82 Ok(Self {
83 client,
84 headers,
85 endpoint: endpoint.to_owned(),
86 })
87 }
88
89 pub async fn send(&self, body: &impl Serialize) -> Result<Response> {
91 tracing::trace!("request: {}", serde_json::to_string(body)?);
92 let text = self
93 .client
94 .request(Method::POST, &self.endpoint)
95 .headers(self.headers.clone())
96 .json(body)
97 .send()
98 .await?
99 .text()
100 .await?;
101
102 serde_json::from_str(&text).map_err(Into::into)
103 }
104
105 pub fn stream_sse(
110 &self,
111 body: &impl Serialize,
112 ) -> impl Stream<Item = Result<StreamChunk>> + Send {
113 if let Ok(body) = serde_json::to_string(body) {
114 tracing::trace!("request: {}", body);
115 }
116 let request = self
117 .client
118 .request(Method::POST, &self.endpoint)
119 .headers(self.headers.clone())
120 .json(body);
121
122 try_stream! {
123 let response = request.send().await?;
124 let mut stream = response.bytes_stream();
125 while let Some(next) = stream.next().await {
126 let bytes = next?;
127 let text = String::from_utf8_lossy(&bytes);
128 tracing::trace!("chunk: {}", text);
129 for data in text.split("data: ").skip(1).filter(|s| !s.starts_with("[DONE]")) {
130 let trimmed = data.trim();
131 if trimmed.is_empty() {
132 continue;
133 }
134 match serde_json::from_str::<StreamChunk>(trimmed) {
135 Ok(chunk) => yield chunk,
136 Err(e) => tracing::warn!("failed to parse chunk: {e}, data: {trimmed}"),
137 }
138 }
139 }
140 }
141 }
142
143 pub fn endpoint(&self) -> &str {
145 &self.endpoint
146 }
147
148 pub fn headers(&self) -> &HeaderMap {
150 &self.headers
151 }
152}