use json::{object, JsonValue};
use native_tls::{TlsConnector, TlsStream};
use rand::distr::Alphanumeric;
use rand::Rng;
use std::io::{Error, ErrorKind, Read, Write};
use std::net::TcpStream;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use std::{fs, io};
use url::Url;
/// Builder-style blocking HTTP/1.x client.
///
/// Configure the target with the verb methods (`get`, `post`, ...), add headers
/// and a body, then call `send` or `send_relay`.
#[derive(Debug, Clone)]
pub struct Client {
/// When true, request/response heads are printed to stdout.
debug: bool,
/// Target URL; scheme, host, port, path and query are taken from it.
url: Url,
/// Request method written in the request line.
method: Method,
/// Request headers as a JSON object; each entry is written as `key: value`.
pub header: JsonValue,
/// HTTP version written in the request line.
version: Version,
/// Request body bytes, filled by `raw_json`, `body`, `form_data` or `form_urlencoded`.
params: Vec<u8>,
/// Number of extra attempts made when a response cannot be read/parsed.
retry: usize,
/// Chunk size in bytes for ranged downloads configured by `range` (0 = not set).
range_size: usize,
}
impl Default for Client {
fn default() -> Self {
Self::new()
}
}
impl Client {
pub fn new() -> Self {
Self {
debug: false,
method: Method::None,
url: Url::parse("http://127.0.0.1").unwrap(),
header: object! {},
version: Version::Http11,
params: vec![],
retry: 0,
range_size: 0,
}
}
pub fn debug(&mut self) -> &mut Self {
self.debug = true;
self
}
pub fn version(&mut self) -> String {
self.version.as_str().to_string()
}
pub fn url(&mut self, url: &str) -> &mut Self {
self.url = Url::parse(url).unwrap();
self
}
pub fn method(&mut self, method: &str) -> &mut Self {
self.method = Method::from_str(method);
self
}
pub fn head(&mut self, url: &str) -> &mut Self {
self.method = Method::Head;
self.url = Url::parse(url).unwrap();
self
}
pub fn get(&mut self, url: &str) -> &mut Self {
self.method = Method::Get;
self.url = Url::parse(url).unwrap();
self
}
pub fn post(&mut self, url: &str) -> &mut Self {
self.method = Method::Post;
self.url = Url::parse(url).unwrap();
self
}
pub fn put(&mut self, url: &str) -> &mut Self {
self.method = Method::Put;
self.url = Url::parse(url).unwrap();
self
}
pub fn delete(&mut self, url: &str) -> &mut Self {
self.method = Method::Delete;
self.url = Url::parse(url).unwrap();
self
}
pub fn patch(&mut self, url: &str) -> &mut Self {
self.method = Method::Patch;
self.url = Url::parse(url).unwrap();
self
}
pub fn options(&mut self, url: &str) -> &mut Self {
self.method = Method::Options;
self.url = Url::parse(url).unwrap();
self
}
pub fn retry(&mut self, count: usize) -> &mut Self {
self.retry = count;
self
}
pub fn range(&mut self, range_size: usize) -> &mut Self {
self.range_size = range_size;
self.make_range_header(Some(0), Some(range_size - 1));
self
}
fn make_range_header(&mut self, start: Option<usize>, end: Option<usize>) {
let res = match (start, end) {
(Some(s), Some(e)) => Some(format!("bytes={}-{}", s, e)),
(Some(s), None) => Some(format!("bytes={}-", s)),
(None, Some(n)) => Some(format!("bytes=-{}", n)),
(None, None) => None,
};
if let Some(e)=res {
self.header("Range", e.as_str());
self.header("Accept-Encoding", "identity");
self.header("connection", "keep-alive");
}
}
pub fn query(&mut self, data: JsonValue) -> &mut Self {
let mut query = vec![];
for (key, value) in data.entries() {
query.push(format!("{}={}", key, value));
}
for (key, value) in self.url.query_pairs() {
query.push(format!("{}={}", key, value));
}
self.url.set_query(Some(&query.join("&")));
self
}
pub fn raw_json(&mut self, data: JsonValue) -> &mut Self {
let _ = self.header.insert("Content-Type", "application/json");
self.params = data.to_string().into_bytes();
let _ = self.header.insert("Content-Length", self.params.len());
self
}
pub fn body(&mut self, data: Vec<u8>) -> &mut Self {
self.params = data;
let _ = self.header.insert("Content-Length", self.params.len());
self
}
pub fn form_data(&mut self, data: JsonValue) -> &mut Self {
let rand_str: String = rand::rng()
.sample_iter(&Alphanumeric)
.take(30) .map(char::from)
.collect();
let boundary = format!("----RustBoundary{}", rand_str);
let _ = self.header.insert(
"Content-Type",
format!("multipart/form-data; boundary={boundary}"),
);
let mut params = vec![];
for (key, value) in data.entries() {
let res = PathBuf::from(value.to_string().as_str());
if res.is_file() {
let filename = res.file_name().unwrap().to_string_lossy();
let value_b = fs::read(value.to_string()).unwrap();
params.extend(format!("--{boundary}\r\n").as_bytes());
params.extend(
format!(
r#"Content-Disposition: form-data; name="{key}"; filename="{filename}""#
)
.as_bytes(),
);
params.extend("\r\nContent-Type: application/octet-stream\r\n\r\n".as_bytes());
params.extend(value_b);
params.extend("\r\n".as_bytes());
} else {
params.extend(format!("--{boundary}\r\n").as_bytes());
params
.extend(format!(r#"Content-Disposition: form-data; name="{key}""#).as_bytes());
params.extend(b"Content-Type: text/plain; charset=utf-8");
params.extend(format!("\r\n\r\n{value}\r\n").as_bytes());
}
}
params.extend(format!("--{boundary}--\r\n").bytes());
self.params = params.to_vec();
let _ = self.header.insert("Content-Length", self.params.len());
self
}
pub fn form_urlencoded(&mut self, data: JsonValue) -> &mut Self {
let _ = self
.header
.insert("Content-Type", "application/x-www-form-urlencoded");
let mut params = vec![];
for (key, value) in data.entries() {
params.push(format!("{}={}", key, value));
}
let params = params.join("&");
self.params = params.as_bytes().to_vec();
let _ = self.header.insert("Content-Length", self.params.len());
self
}
pub fn header(&mut self, key: &str, value: &str) -> &mut Self {
self.header.insert(key, value).expect("TODO: panic message");
self
}
fn stream(&mut self) -> Result<HttpStream, Box<dyn std::error::Error>> {
let port = self.url.port().unwrap_or_else(|| {
if self.url.scheme() == "https" {
443
} else {
80
}
});
let host = self.url.host().unwrap().to_string();
let mut stream = if port == 443 {
let tcp = TcpStream::connect((host.clone(), port))?;
let connector = TlsConnector::new()?;
let stream = connector.connect(&host, tcp)?;
HttpStream::Https(Arc::new(Mutex::new(stream)))
} else {
let tcp = TcpStream::connect((host.clone(), port))?;
HttpStream::Http(Arc::new(Mutex::new(tcp)))
};
stream.set_read_timeout(Duration::from_secs(30))?;
stream.set_write_timeout(Duration::from_secs(30))?;
stream.set_nonblocking(false)?;
stream.set_nodelay(true)?;
Ok(stream)
}
fn request_txt(&mut self) -> io::Result<Vec<u8>> {
let port = self.url.port().unwrap_or_else(|| {
if self.url.scheme() == "https" {
443
} else {
80
}
});
let host = self.url.host().unwrap().to_string();
let uri = if self.url.query().is_some() {
format!("{}?{}", self.url.path(), self.url.query().unwrap())
} else {
self.url.path().to_string()
};
let mut header = vec![];
header.push(format!(
"{} {} {}",
self.method.as_str(),
uri,
self.version.as_str()
));
for (k, v) in self.header.entries() {
header.push(format!("{k}: {v}"));
}
if !self.header.has_key("host") {
let host_header = if (self.url.scheme() == "https" && port != 443)
|| (self.url.scheme() == "http" && port != 80)
{
format!("{}:{}", host, port)
} else {
host.clone()
};
header.push(format!("Host: {host_header}"));
}
header.push("\r\n".to_string());
let request = header.join("\r\n");
if self.debug {
println!("================请求内容==============\r\n{}", request);
}
Ok(request.as_bytes().to_vec())
}
pub fn send_relay(
&mut self,
mut client: impl Write,
) -> Result<Response, Box<dyn std::error::Error>> {
let mut retry = 0;
let mut index = 0;
let mut stream = self.stream()?;
loop {
stream.write_all(self.request_txt()?.as_slice())?;
stream.flush()?;
return match Response::new(stream.clone(), self.clone()) {
Ok(mut e) => {
match e.code {
206 => {
if index == 0 {
let mut res = Response::new_protocol(
Version::format(e.version.as_str()),
200,
"200 OK",
e.header.clone(),
);
res.header["content-type"] = e.content_type.clone().into();
res.header["content-length"] = e.ranges_len.into();
res.header["connection"] = "keep-alive".into();
res.header["keep-alive"] = "timeout=120, max=1000".into();
res.header["accept-ranges"] = "bytes".into();
res.header["accept-encoding"] = "identity".into();
res.header.remove("content-range");
let res = res.generate_response_protocol();
client.write_all(res.as_slice())?;
index += 1;
}
match client.write_all(&e.body()) {
Ok(()) => {
println!("发送成功: {}", e.body().len());
}
Err(err) => {
println!("发送失败: {} {}", e.body().len(), err);
return Err(Box::new(err));
}
};
if e.ranges_end + 1 < e.ranges_len {
let next_start = e.ranges_end + 1;
let mut next_end =
next_start.saturating_add(self.range_size.saturating_sub(1));
if next_end >= e.ranges_len {
next_end = e.ranges_len - 1;
}
self.make_range_header(Some(next_start), Some(next_end));
continue;
}
e.status = "OK".into();
e.code = 200;
Ok(e)
}
_ => Ok(e),
}
}
Err(e) => {
if self.retry > retry {
retry += 1;
println!("响应解析错误重试1: {}", e);
continue;
}
println!("响应解析错误: {}", e);
Err(Box::from(e.to_string()))
}
};
}
}
pub fn send(&mut self) -> Result<Response, Box<dyn std::error::Error>> {
let mut retry = 0;
self.request_txt()?;
let mut stream = self.stream()?;
loop {
let mut request = self.request_txt()?;
if !self.params.is_empty() {
request.extend(self.params.clone());
}
stream.write_all(request.as_slice())?;
return match Response::new(stream.clone(), self.clone()) {
Ok(mut e) => {
Ok(e)
},
Err(e) => {
if self.retry > retry {
retry += 1;
println!("响应解析错误重试2: {}", e);
continue;
}
println!("响应解析错误: {}", e);
Err(Box::from(e.to_string()))
}
};
}
}
}
/// HTTP request method written at the start of the request line.
/// `None` means "not set" and renders as an empty token.
#[derive(Debug, Clone)]
enum Method {
    Head,
    Get,
    Post,
    Put,
    Patch,
    Delete,
    Options,
    None,
}
impl Method {
    /// Upper-case token of every concrete method, used for parsing.
    const KNOWN: [(&'static str, Method); 7] = [
        ("HEAD", Method::Head),
        ("GET", Method::Get),
        ("POST", Method::Post),
        ("PUT", Method::Put),
        ("PATCH", Method::Patch),
        ("DELETE", Method::Delete),
        ("OPTIONS", Method::Options),
    ];

    /// Token for the request line; `None` yields `""`.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Head => "HEAD",
            Self::Get => "GET",
            Self::Post => "POST",
            Self::Put => "PUT",
            Self::Patch => "PATCH",
            Self::Delete => "DELETE",
            Self::Options => "OPTIONS",
            Self::None => "",
        }
    }

    /// Parses a method name case-insensitively (Unicode upper-casing);
    /// anything unrecognised, including `"none"`, maps to `Method::None`.
    fn from_str(method: &str) -> Method {
        let wanted = method.to_uppercase();
        Self::KNOWN
            .iter()
            .find(|(token, _)| *token == wanted)
            .map_or(Method::None, |(_, variant)| variant.clone())
    }
}
/// A blocking connection to the server: plain TCP or TLS over TCP.
///
/// The socket sits behind `Arc<Mutex<_>>`, so clones share one connection.
/// `None` holds no connection; every operation on it fails with
/// `ErrorKind::NotConnected`.
#[derive(Debug, Clone)]
pub enum HttpStream {
    Http(Arc<Mutex<TcpStream>>),
    Https(Arc<Mutex<TlsStream<TcpStream>>>),
    None,
}
impl HttpStream {
    /// Error returned for the connection-less `None` variant. The kind was
    /// previously `TimedOut`, which misreported the cause.
    fn not_connected() -> Error {
        Error::new(ErrorKind::NotConnected, "未知")
    }

    /// Locks the shared socket. A poisoned mutex (another holder panicked
    /// mid-I/O) is reported as an error instead of panicking again.
    fn lock<T>(shared: &Arc<Mutex<T>>) -> io::Result<std::sync::MutexGuard<'_, T>> {
        shared
            .lock()
            .map_err(|_| Error::other("stream mutex poisoned"))
    }

    /// Runs `op` on the underlying TCP socket (beneath TLS when present).
    fn with_tcp<R>(&self, op: impl FnOnce(&TcpStream) -> io::Result<R>) -> io::Result<R> {
        match self {
            HttpStream::Http(s) => op(&*Self::lock(s)?),
            HttpStream::Https(s) => op(Self::lock(s)?.get_ref()),
            HttpStream::None => Err(Self::not_connected()),
        }
    }

    /// Switches the socket between blocking and non-blocking mode.
    pub fn set_nonblocking(&mut self, nonblocking: bool) -> Result<(), Error> {
        self.with_tcp(|tcp| tcp.set_nonblocking(nonblocking))
    }

    /// Enables or disables Nagle's algorithm (`TCP_NODELAY`).
    pub fn set_nodelay(&mut self, nodelay: bool) -> Result<(), Error> {
        self.with_tcp(|tcp| tcp.set_nodelay(nodelay))
    }

    /// Writes all of `buf` (through TLS when applicable).
    pub fn write_all(&mut self, buf: &[u8]) -> Result<(), Error> {
        match self {
            HttpStream::Http(s) => Self::lock(s)?.write_all(buf),
            HttpStream::Https(s) => Self::lock(s)?.write_all(buf),
            HttpStream::None => Err(Self::not_connected()),
        }
    }

    /// Reads up to `buf.len()` bytes; `Ok(0)` means the peer closed the connection.
    pub fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self {
            HttpStream::Http(s) => Self::lock(s)?.read(buf),
            HttpStream::Https(s) => Self::lock(s)?.read(buf),
            HttpStream::None => Err(Self::not_connected()),
        }
    }

    /// Writes some of `buf`, returning how many bytes were accepted.
    pub fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self {
            HttpStream::Http(s) => Self::lock(s)?.write(buf),
            HttpStream::Https(s) => Self::lock(s)?.write(buf),
            HttpStream::None => Err(Self::not_connected()),
        }
    }

    /// Flushes buffered output (relevant for TLS).
    pub fn flush(&mut self) -> io::Result<()> {
        match self {
            HttpStream::Http(s) => Self::lock(s)?.flush(),
            HttpStream::Https(s) => Self::lock(s)?.flush(),
            HttpStream::None => Err(Self::not_connected()),
        }
    }

    /// Sets the socket read timeout.
    pub fn set_read_timeout(&mut self, timeout: Duration) -> io::Result<()> {
        self.with_tcp(|tcp| tcp.set_read_timeout(Some(timeout)))
    }

    /// Sets the socket write timeout.
    pub fn set_write_timeout(&mut self, timeout: Duration) -> io::Result<()> {
        self.with_tcp(|tcp| tcp.set_write_timeout(Some(timeout)))
    }
}
/// An HTTP response read from an `HttpStream`, or synthesized via `new_protocol`.
#[derive(Debug)]
pub struct Response {
/// Protocol token from the status line, e.g. `HTTP/1.1`.
version: String,
/// Response headers as a JSON object keyed by lower-cased header name.
header: JsonValue,
/// Body bytes received after the head.
body: Vec<u8>,
/// Connection the response was read from (`HttpStream::None` when synthesized).
stream: HttpStream,
/// Raw head bytes, excluding the terminating blank line.
header_data: Vec<u8>,
/// Numeric status code from the status line.
pub code: u16,
/// Reason phrase from the status line (words after the code, space-joined).
pub status: String,
/// `Content-Type` value up to the first `;` (parameters dropped).
pub content_type: String,
/// Total entity length from `Content-Range` (the part after `/`).
pub ranges_len: usize,
/// Last byte index of this chunk from `Content-Range`.
pub ranges_end: usize,
}
impl Response {
pub fn new(stream: HttpStream, request: Client) -> Result<Response, Error> {
let mut response = Response {
version: "".to_string(),
body: vec![],
stream,
header_data: vec![],
header: object! {},
code: 0,
status: "".to_string(),
content_type: "".to_string(),
ranges_end: 0,
ranges_len: 0,
};
loop {
let mut buf = [0; 1024];
match response.stream.read(&mut buf) {
Ok(0) => {
if response.header_data.is_empty() {
return Err(Error::other("无请求头数据"));
}
break;
}
Ok(n) => {
response.header_data.extend(&buf[..n]);
if let Some(pos) = response
.header_data
.windows(4)
.position(|w| w == [13, 10, 13, 10])
{
response.body = response.header_data[pos + 4..].to_vec();
response.header_data = response.header_data[..pos].to_vec();
response.handle_header()?;
break;
}
}
Err(e) => return Err(e),
}
}
if request.debug {
println!(
"================响应内容==============\r\n{}\r\n",
String::from_utf8_lossy(&response.header_data)
);
}
if let Method::Head = request.method {
return Ok(response);
}
if let Ok(e) = response.get_header("content-length") {
let len = e.parse::<usize>().unwrap();
if len > 0 {
if response.body.len() == len {
return Ok(response);
}
loop {
let mut buf = [0; 1024 * 1024];
match response.stream.read(&mut buf) {
Ok(0) => {
if len > response.body.len() {
continue;
};
if response.body.len() == len {
return Ok(response);
}
}
Ok(n) => {
response.body.extend(&buf[..n]);
if response.body.len() == len {
break;
}
}
Err(e) => return Err(e),
}
}
}
}
Ok(response)
}
pub fn new_protocol(version: Version, code: u16, status: &str, header: JsonValue) -> Response {
Self {
version: version.as_str().to_string(),
code,
status: status.to_string(),
content_type: "".to_string(),
ranges_len: 0,
header,
body: vec![],
stream: HttpStream::None,
header_data: vec![],
ranges_end: 0,
}
}
pub fn generate_response_protocol(&self) -> Vec<u8> {
let mut res = vec![];
res.push(format!("{} {}", self.version, self.status));
for (key, value) in self.header.entries() {
res.push(format!("{key}: {value}"));
}
res.push("\r\n".to_string());
let res = res.join("\r\n").as_bytes().to_vec();
res
}
pub fn new_header(stream: HttpStream) -> Result<Response, Error> {
let mut response = Response {
version: "".to_string(),
body: vec![],
stream,
header_data: vec![],
header: object! {},
code: 0,
status: "".to_string(),
content_type: "".to_string(),
ranges_end: 0,
ranges_len: 0,
};
let mut buf = [0; 1024];
loop {
match response.stream.read(&mut buf) {
Ok(0) => {
if response.header_data.is_empty() {
return Err(Error::other("无请求头数据"));
}
break;
}
Ok(n) => {
response.header_data.extend(&buf[..n]);
if let Some(pos) = response
.header_data
.windows(4)
.position(|w| w == [13, 10, 13, 10])
{
response.body = response.header_data[pos + 4..].to_vec();
response.header_data = response.header_data[..pos].to_vec();
response.handle_header()?;
break;
}
}
Err(e) => return Err(e),
}
}
Ok(response)
}
pub fn handle_header(&mut self) -> Result<(), Error> {
let res = match std::str::from_utf8(self.header_data.as_slice()) {
Ok(e) => e,
Err(e) => {
return Err(Error::other(e.to_string()));
}
};
let request_line = res.lines().next().unwrap();
let mut parts = request_line.split_whitespace();
self.version = parts.next().ok_or("缺少版本").unwrap_or("").to_string();
self.code = parts
.next()
.ok_or("缺少状态码")
.unwrap_or("")
.parse::<u16>()
.unwrap();
self.status = parts.clone().collect::<Vec<&str>>().join(" ").clone();
for line in res.lines().skip(1) {
match line.find(":") {
None => {}
Some(e) => {
let key = &line[..e];
let value = &line[e + 1..].trim();
self.header[key.to_lowercase().as_str()] = value.trim().into();
}
}
}
if self.get_header("content-type").is_ok() {
let content = self.get_header("content-type").unwrap().to_string();
match content.find(";") {
None => {
self.content_type = content;
}
Some(e) => {
self.content_type = content[..e].to_string();
}
}
}
if self.get_header("content-range").is_ok() {
let content = self.get_header("content-range").unwrap().to_string();
match content.find("/") {
None => {}
Some(e) => {
self.ranges_len = content[e + 1..].parse::<usize>().unwrap();
let bytes = content[..e].to_string();
match bytes.find("-") {
None => {}
Some(e) => {
self.ranges_end = bytes[e + 1..].parse::<usize>().unwrap();
}
}
}
}
}
Ok(())
}
pub fn get_header(&self, key: &str) -> Result<String, String> {
if self.header[key].is_null() {
Err("请求头不存在".to_string())
} else {
Ok(self.header[key].to_string())
}
}
pub fn version(&self) -> String {
self.version.clone()
}
pub fn status(&self) -> String {
format!("{} {}", self.code, self.status)
}
pub fn headers(&self) -> JsonValue {
self.header.clone()
}
pub fn content_type(&self) -> String {
self.content_type.clone()
}
pub fn json(&self) -> Result<JsonValue, Error> {
match String::from_utf8(self.body.clone()) {
Ok(e) => match json::parse(&e) {
Ok(e) => Ok(e),
Err(e) => Err(Error::other(e.to_string())),
},
Err(e) => Err(Error::other(e.to_string())),
}
}
pub fn txt(&self) -> Result<String, Error> {
match String::from_utf8(self.body.clone()) {
Ok(e) => Ok(e),
Err(e) => Err(Error::other(e.to_string())),
}
}
pub fn body(&self) -> Vec<u8> {
self.body.clone()
}
}
/// HTTP protocol version used in request lines and relayed response heads.
/// `None` stands for an unrecognised version and renders as `""`.
// The per-variant `allow(dead_code)` attributes are kept from the original.
#[derive(Debug, Clone)]
pub enum Version {
    #[allow(dead_code)]
    Http10,
    #[allow(dead_code)]
    Http11,
    #[allow(dead_code)]
    Http2,
    #[allow(dead_code)]
    Http3,
    #[allow(dead_code)]
    None,
}
impl Version {
    /// Wire token for this version, e.g. `"HTTP/1.1"`.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Http10 => "HTTP/1.0",
            Self::Http11 => "HTTP/1.1",
            Self::Http2 => "HTTP/2",
            Self::Http3 => "HTTP/3",
            Self::None => "",
        }
    }

    /// Maps an exact (case-sensitive) wire token back to a version;
    /// unknown tokens give `Version::None`.
    pub fn format(version: &str) -> Version {
        [Version::Http10, Version::Http11, Version::Http2, Version::Http3]
            .iter()
            .find(|candidate| candidate.as_str() == version)
            .cloned()
            .unwrap_or(Version::None)
    }
}