use std::str::FromStr;
use crate::prelude2::*;
use actix_web::dev::PeerAddr;
use futures_util::StreamExt as _;
use rand::seq::IndexedRandom;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use crate::core::config::UpStreams;
/// Builds the upstream URL `<server>/<tail>` (or `<server>/<location>/<tail>` when
/// `location` is `Some`), inserting exactly one `/` at each join, then appends `?<query>`.
///
/// Joining this way makes the result independent of whether the configured server ends
/// with `/` or the captured tail starts with `/`. An empty tail in rewrite mode yields
/// `<server>/<location>` with no trailing slash, mirroring the incoming path.
fn build_forward_url(server: &str, location: Option<&str>, tail: &str, query: Option<&str>) -> String {
    let base = server.trim_end_matches('/');
    // Strip at most one leading '/', so paths like `//x` keep their remaining slashes.
    let tail_rest = tail.strip_prefix('/').unwrap_or(tail);
    let mut url = match location {
        Some(loc) if tail.is_empty() => format!("{base}/{loc}"),
        Some(loc) => format!("{base}/{loc}/{tail_rest}"),
        None => format!("{base}/{tail_rest}"),
    };
    if let Some(q) = query {
        url.push('?');
        url.push_str(q);
    }
    url
}

/// Reverse-proxies the incoming request to an upstream selected by its first path segment.
///
/// The `forward` match-info segment is prefixed with `/` and compared with each configured
/// [`UpStreams`] entry's `location`. A server is then picked at random from the first
/// matching entry's `servers`. The forwarded URL is `<server>/<tail>`, or
/// `<server>/<forward>/<tail>` when the entry has `rewrite = Some(true)`, followed by the
/// original query string. Request and response bodies are streamed.
///
/// Request headers are copied byte-for-byte, except the hop-by-hop `connection` header.
/// `x-forwarded-for` is added when the peer address is known. The upstream's status and
/// headers (again minus `connection`) are relayed to the client; repeated headers such as
/// `set-cookie` are preserved.
///
/// # Errors
///
/// Returns the project's runtime error when the upstream request fails. Responds with
/// `404 Not Found` when no entry matches the location, and with `502 Bad Gateway` when the
/// matching entry lists no servers.
///
/// # Panics
///
/// Panics if the route does not define both the `forward` and `tail` segments.
pub async fn forward_reqwest(
    req: HttpRequest,
    mut payload: web::Payload,
    peer_addr: Option<PeerAddr>,
    client: web::Data<reqwest::Client>,
    upstream: web::Data<Option<Vec<UpStreams>>>,
) -> Result<HttpResponse> {
    let location = req
        .match_info()
        .get("forward")
        .expect("route must define a `forward` segment");
    let uri = req
        .match_info()
        .get("tail")
        .expect("route must define a `tail` segment");

    // Configured locations carry a leading '/', the captured segment does not.
    let prefixed_location = format!("/{location}");
    let Some(up) = upstream
        .get_ref()
        .iter()
        .flatten()
        .find(|up| prefixed_location == up.location)
    else {
        log::warn!("rewrite_forward: no upstream configured for location {prefixed_location}");
        return Ok(HttpResponse::NotFound().finish());
    };

    let servers = &up.servers;
    let Some(server) = servers.choose(&mut rand::rng()) else {
        log::warn!("rewrite_forward: upstream {prefixed_location} has no servers");
        return Ok(HttpResponse::BadGateway().finish());
    };
    let kept_location = if matches!(up.rewrite, Some(true)) {
        Some(location)
    } else {
        None
    };
    let new_url = build_forward_url(server, kept_location, uri, req.uri().query());

    // Pump the incoming body into a channel so reqwest can consume it as a stream.
    let (tx, rx) = mpsc::unbounded_channel();
    actix_web::rt::spawn(async move {
        while let Some(chunk) = payload.next().await {
            // The receiver is dropped once reqwest stops reading the body (e.g. the upstream
            // request failed); stop pumping instead of panicking.
            if tx.send(chunk).is_err() {
                break;
            }
        }
    });

    let method = reqwest::Method::from_str(req.method().as_str())
        .expect("actix-web only yields syntactically valid HTTP methods");
    log::info!(
        "rewrite_forward: method={}, url={}, servers={:?}",
        method,
        new_url,
        servers
    );
    let mut forwarded_req = client
        .request(method, new_url)
        .body(reqwest::Body::wrap_stream(UnboundedReceiverStream::new(rx)));
    // Copy raw bytes (not `to_str`) so non-visible-ASCII values are not blanked. Round-trip
    // through str/bytes because actix and reqwest may depend on different `http` versions.
    for (name, value) in req.headers().iter().filter(|(name, _)| *name != "connection") {
        forwarded_req = forwarded_req.header(name.as_str(), value.as_bytes());
    }
    let forwarded_req = match peer_addr {
        Some(PeerAddr(addr)) => forwarded_req.header("x-forwarded-for", addr.ip().to_string()),
        None => forwarded_req,
    };

    let res = forwarded_req.send().await.map_err(Error::run_time)?;
    let status = actix_http::StatusCode::from_u16(res.status().as_u16())
        .expect("reqwest status codes are always within 100..=999");
    let mut client_resp = HttpResponse::build(status);
    // `connection` is hop-by-hop and must not be relayed. `append_header` keeps every value
    // of repeated headers (e.g. `set-cookie`), and raw bytes avoid panicking on obs-text.
    for (name, value) in res.headers().iter().filter(|(name, _)| *name != "connection") {
        client_resp.append_header((name.as_str(), value.as_bytes()));
    }
    Ok(client_resp.streaming(res.bytes_stream()))
}