pub use super::server::{HttpResponse, WsMessageType};
use crate::{get_blob, LazyLoadBlob as KiBlob, Message, Request as KiRequest};
use http::Method;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::str::FromStr;
use thiserror::Error;
/// Request body sent to the `http-client:distro:sys` process; the helper functions in
/// this module serialize it as JSON.
///
/// Payload bytes (an HTTP request body, a WebSocket message) are not part of this value:
/// the helpers attach them to the outgoing message as its blob.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HttpClientAction {
    /// Perform the outgoing HTTP request described by [`OutgoingHttpRequest`].
    Http(OutgoingHttpRequest),
    /// Open a client WebSocket connection to `url`. `channel_id` is the identifier that
    /// later `WebSocketPush` / `WebSocketClose` actions use to refer to it.
    /// NOTE(review): `headers` are presumably sent with the opening handshake — confirm
    /// against http-client.
    WebSocketOpen {
        url: String,
        headers: HashMap<String, String>,
        channel_id: u32,
    },
    /// Push a message of kind `message_type` on connection `channel_id`.
    WebSocketPush {
        channel_id: u32,
        message_type: WsMessageType,
    },
    /// Close connection `channel_id`.
    WebSocketClose {
        channel_id: u32,
    },
}
/// Method, URL, version and headers of an outgoing HTTP request, carried by
/// [`HttpClientAction::Http`].
///
/// The request body is not stored here; the helpers in this module attach it to the
/// message as its blob.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct OutgoingHttpRequest {
    /// HTTP method as text; the helpers here fill it from `http::Method::to_string`.
    pub method: String,
    /// HTTP version as text. Every request built in this module leaves it `None`,
    /// presumably letting http-client choose its default — confirm against http-client.
    pub version: Option<String>,
    /// Request URL as text; the helpers here fill it from `url::Url::to_string`.
    pub url: String,
    /// Request headers, mapping header name to header value.
    pub headers: HashMap<String, String>,
}
/// WebSocket push/close requests. Variant names and fields are identical to the
/// corresponding [`HttpClientAction`] variants.
///
/// Deriving `Copy` requires `WsMessageType: Copy`.
/// NOTE(review): nothing in this section constructs or sends this type; it is presumably
/// used by http-client itself or other callers — confirm before relying on it.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum HttpClientRequest {
    /// Push a message of kind `message_type` on connection `channel_id`.
    WebSocketPush {
        channel_id: u32,
        message_type: WsMessageType,
    },
    /// Close connection `channel_id`.
    WebSocketClose {
        channel_id: u32,
    },
}
/// Successful reply from `http-client:distro:sys`.
///
/// The helpers in this module decode reply bodies as JSON
/// `Result<HttpClientResponse, HttpClientError>`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HttpClientResponse {
    /// Status and headers of a completed HTTP request; the helpers read the response body
    /// from the reply's blob.
    Http(HttpResponse),
    /// Acknowledgement the helpers here expect in reply to WebSocket open and close
    /// requests.
    WebSocketAck,
}
/// Errors reported by `http-client:distro:sys`, or produced locally by the helpers in
/// this module.
///
/// Serializable because http-client returns errors inside the JSON reply body
/// (`Result<HttpClientResponse, HttpClientError>`). The `#[error]` strings only drive
/// `Display`; the serde representation uses the variant names.
#[derive(Clone, Debug, Error, Serialize, Deserialize)]
pub enum HttpClientError {
    /// A request body could not be deserialized.
    /// NOTE(review): the message names `HttpClientRequest`, but the helpers in this file
    /// send `HttpClientAction` bodies.
    #[error("request could not be deserialized to valid HttpClientRequest")]
    MalformedRequest,
    /// The HTTP method string was not accepted.
    #[error("http method not supported: {method}")]
    BadMethod { method: String },
    /// The URL string could not be parsed.
    #[error("url could not be parsed: {url}")]
    BadUrl { url: String },
    /// The HTTP version string was not accepted.
    #[error("http version not supported: {version}")]
    BadVersion { version: String },
    /// Building the request failed; the payload carries the underlying error text.
    #[error("client failed to build request: {0}")]
    BuildRequestFailed(String),
    /// Executing the request failed. The helpers here also use it for a missing reply
    /// ("http-client timed out") and for unexpected or unparseable replies.
    #[error("client failed to execute request: {0}")]
    ExecuteRequestFailed(String),
    /// Opening a WebSocket connection to `url` failed. `open_ws_connection` also returns
    /// this when no reply arrives or the reply is not a `WebSocketAck`.
    #[error("could not open connection to {url}")]
    WsOpenFailed { url: String },
    /// A push targeted a `channel_id` with no open connection.
    #[error("sent WebSocket push to unknown channel {channel_id}")]
    WsPushUnknownChannel { channel_id: u32 },
    /// A push arrived without the blob that should carry its payload.
    #[error("WebSocket push failed because message had no blob attached")]
    WsPushNoBlob,
    /// A `Text` push carried a blob that is not valid UTF-8.
    #[error("WebSocket push failed because message type was Text, but blob was not a valid UTF-8 string")]
    WsPushBadText,
    /// Closing connection `channel_id` failed. Despite the message text,
    /// `close_ws_connection` also returns this when no reply arrives or the reply is not
    /// a `WebSocketAck`.
    #[error("failed to close connection {channel_id} because it was not open")]
    WsCloseFailed { channel_id: u32 },
}
/// Send an HTTP request through `http-client:distro:sys` without waiting for the reply.
///
/// The method, URL and headers (an empty map when `headers` is `None`) are serialized as
/// JSON into an [`HttpClientAction::Http`] request body. `body` is attached as the
/// message blob.
///
/// With `timeout = Some(t)`, the message is marked as expecting a response within `t`
/// (via `KiRequest::expects_response`; presumably seconds — confirm). With `None`, it is
/// sent fire-and-forget. In both cases this function returns right after sending.
///
/// # Panics
///
/// Panics if serializing the request body or sending the message fails.
pub fn send_request(
    method: Method,
    url: url::Url,
    headers: Option<HashMap<String, String>>,
    timeout: Option<u64>,
    body: Vec<u8>,
) {
    let action = HttpClientAction::Http(OutgoingHttpRequest {
        method: method.to_string(),
        version: None,
        url: url.to_string(),
        headers: headers.unwrap_or_default(),
    });
    let outgoing = KiRequest::to(("our", "http-client", "distro", "sys"))
        .body(serde_json::to_vec(&action).unwrap())
        .blob_bytes(body);
    let outgoing = match timeout {
        Some(deadline) => outgoing.expects_response(deadline),
        None => outgoing,
    };
    outgoing.send().unwrap()
}
/// Send an HTTP request through `http-client:distro:sys` and block until it answers or
/// `timeout` elapses.
///
/// The method, URL and headers (an empty map when `headers` is `None`) travel as a
/// JSON-serialized [`HttpClientAction::Http`] body, and `body` is attached as the message
/// blob. The response body is read from the blob of the reply; it is empty if no blob is
/// attached.
///
/// Response headers whose name or value is not valid HTTP are skipped.
///
/// # Errors
///
/// - [`HttpClientError::BuildRequestFailed`] if the request cannot be serialized or
///   handed to the messaging layer.
/// - [`HttpClientError::ExecuteRequestFailed`] if no response message arrives (reported
///   as "http-client timed out"), if the reply cannot be parsed or is not an HTTP
///   response, or if the reported status code is not a valid HTTP status.
/// - Any [`HttpClientError`] reported by http-client itself is returned unchanged.
pub fn send_request_await_response(
    method: Method,
    url: url::Url,
    headers: Option<HashMap<String, String>>,
    timeout: u64,
    body: Vec<u8>,
) -> std::result::Result<http::Response<Vec<u8>>, HttpClientError> {
    // A serialization failure happens on our side while building the request. It is not
    // a `MalformedRequest`, which describes http-client failing to deserialize a request.
    let request_body = serde_json::to_vec(&HttpClientAction::Http(OutgoingHttpRequest {
        method: method.to_string(),
        version: None,
        url: url.to_string(),
        headers: headers.unwrap_or_default(),
    }))
    .map_err(|e| HttpClientError::BuildRequestFailed(e.to_string()))?;
    // The outer error means the message could not be sent at all; propagate it instead
    // of panicking inside a `Result`-returning API.
    let res = KiRequest::to(("our", "http-client", "distro", "sys"))
        .body(request_body)
        .blob_bytes(body)
        .send_and_await_response(timeout)
        .map_err(|e| HttpClientError::BuildRequestFailed(e.to_string()))?;
    let Ok(Message::Response { body, .. }) = res else {
        return Err(HttpClientError::ExecuteRequestFailed(
            "http-client timed out".to_string(),
        ));
    };
    let resp = match serde_json::from_slice::<
        std::result::Result<HttpClientResponse, HttpClientError>,
    >(&body)
    {
        Ok(Ok(HttpClientResponse::Http(resp))) => resp,
        Ok(Ok(HttpClientResponse::WebSocketAck)) => {
            return Err(HttpClientError::ExecuteRequestFailed(
                "http-client gave unexpected response".to_string(),
            ))
        }
        Ok(Err(e)) => return Err(e),
        Err(e) => {
            return Err(HttpClientError::ExecuteRequestFailed(format!(
                "http-client gave invalid response: {e:?}"
            )))
        }
    };
    // `StatusCode::default()` is 200 OK, so falling back to it would disguise a bogus
    // status as success. Report it instead.
    let status = http::StatusCode::from_u16(resp.status).map_err(|_| {
        HttpClientError::ExecuteRequestFailed(format!(
            "http-client gave invalid status code: {}",
            resp.status
        ))
    })?;
    let mut http_response = http::Response::builder().status(status);
    // Cannot be `None`: the builder only holds an error after an invalid part was set,
    // and `status` is already a validated `StatusCode`.
    let headers = http_response.headers_mut().unwrap();
    for (key, value) in &resp.headers {
        let (Ok(key), Ok(value)) = (
            http::header::HeaderName::from_str(key),
            http::header::HeaderValue::from_str(value),
        ) else {
            continue;
        };
        // `append`, not `insert`: header names are normalized to lowercase, so two source
        // keys differing only in case would otherwise overwrite each other.
        headers.append(key, value);
    }
    Ok(http_response
        .body(get_blob().unwrap_or_default().bytes)
        .unwrap())
}
pub fn open_ws_connection(
url: String,
headers: Option<HashMap<String, String>>,
channel_id: u32,
) -> std::result::Result<(), HttpClientError> {
let Ok(Ok(Message::Response { body, .. })) =
KiRequest::to(("our", "http-client", "distro", "sys"))
.body(
serde_json::to_vec(&HttpClientAction::WebSocketOpen {
url: url.clone(),
headers: headers.unwrap_or(HashMap::new()),
channel_id,
})
.unwrap(),
)
.send_and_await_response(5)
else {
return Err(HttpClientError::WsOpenFailed { url });
};
match serde_json::from_slice(&body) {
Ok(Ok(HttpClientResponse::WebSocketAck)) => Ok(()),
Ok(Err(e)) => Err(e),
_ => Err(HttpClientError::WsOpenFailed { url }),
}
}
/// Push a WebSocket message on client connection `channel_id` via
/// `http-client:distro:sys`, without waiting for any reply.
///
/// `channel_id` and `message_type` travel in the JSON request body, and `blob` carries
/// the message payload.
///
/// # Panics
///
/// Panics if serializing the request body or sending the message fails.
pub fn send_ws_client_push(channel_id: u32, message_type: WsMessageType, blob: KiBlob) {
    let push = HttpClientAction::WebSocketPush {
        channel_id,
        message_type,
    };
    let payload = serde_json::to_vec(&push).unwrap();
    KiRequest::to(("our", "http-client", "distro", "sys"))
        .body(payload)
        .blob(blob)
        .send()
        .unwrap()
}
pub fn close_ws_connection(channel_id: u32) -> std::result::Result<(), HttpClientError> {
let Ok(Ok(Message::Response { body, .. })) =
KiRequest::to(("our", "http-client", "distro", "sys"))
.body(
serde_json::json!(HttpClientAction::WebSocketClose { channel_id })
.to_string()
.as_bytes()
.to_vec(),
)
.send_and_await_response(5)
else {
return Err(HttpClientError::WsCloseFailed { channel_id });
};
match serde_json::from_slice(&body) {
Ok(Ok(HttpClientResponse::WebSocketAck)) => Ok(()),
Ok(Err(e)) => Err(e),
_ => Err(HttpClientError::WsCloseFailed { channel_id }),
}
}