use super::*;
use axum::{
extract::Path,
http::{HeaderValue, StatusCode},
routing::{get, post},
Extension, Json, Router,
};
use reqwest::{Client, Method};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::{collections::HashMap, sync::Arc};
use thiserror::Error;
use tower_http::cors::CorsLayer;
#[derive(Error, Debug)]
enum HttpError {
#[error("invalid request: {0}")]
InvalidRequest(String),
#[error("method not found: {0}")]
MethodNotFound(String),
}
/// Bundle of the traits an object needs in order to be served by [`run_server`].
///
/// Implementors must support dispatch with named arguments (a JSON object,
/// via `DispatchStringDictAsync`) and with positional arguments (a JSON array,
/// via `DispatchStringTupleAsync`), both using `serde_json`. The associated
/// types are pinned here so the trait can be used as `dyn HttpInterface`.
/// `Send + Sync + 'static` lets such objects be stored as
/// `Arc<dyn HttpInterface>` in the server's shared state.
pub trait HttpInterface:
DispatchStringDictAsync<Error = serde_json::Error, Poly = serde_json::Value>
+ DispatchStringTupleAsync<Error = serde_json::Error>
+ Send
+ Sync
+ 'static
{
}
/// Lets an `Arc`-wrapped implementor (possibly unsized) be used as an
/// [`HttpInterface`] itself.
impl<T: HttpInterface + ?Sized> HttpInterface for Arc<T> {}
/// Type-erases a shared implementor into an `Arc<dyn HttpInterface>` that can
/// be registered with [`run_server`].
///
/// `x` is wrapped in a second `Arc` because `T` may be unsized, and an
/// unsized `Arc<T>` cannot be coerced to `Arc<dyn HttpInterface>` directly.
/// The outer `Arc<Arc<T>>` is sized and implements `HttpInterface` through
/// the blanket `Arc` impl.
pub fn create_http_object<T: ?Sized + HttpInterface>(x: Arc<T>) -> Arc<dyn HttpInterface> {
    // Unsizing coercion to the trait object happens at the return site.
    Arc::new(x)
}
/// Server state shared with request handlers through an axum `Extension`.
#[derive(Clone)]
struct State {
    /// Callable objects, keyed by the path segment clients POST to.
    pub registered_objects: HashMap<String, Arc<dyn HttpInterface>>,
}
/// Handler for `GET /`: returns a short usage banner.
async fn root() -> &'static str {
    const BANNER: &str =
        "This is a serde-tc JSON RPC server. Please access to /<object-name> with POST, to use the API.";
    BANNER
}
/// Handler for `POST /:key`: dispatches the request body to the object
/// registered under `key`.
///
/// Responses:
/// - `200 OK` with the method's JSON result on success;
/// - `500 Internal Server Error` with `error`, `error_message` and the echoed
///   `request` when [`dispatch_raw`] fails;
/// - `404 Not Found` with `error` and `object` (the requested name) when no
///   object is registered under `key`.
async fn dispatch(
    Path(path): Path<String>,
    Json(args): Json<RawArg>,
    Extension(state): Extension<Arc<State>>,
) -> (StatusCode, Json<Value>) {
    let object = match state.registered_objects.get(&path) {
        Some(object) => object,
        None => {
            // `Path` yields the captured segment without a leading `/`, so the
            // name is reported verbatim. Byte-slicing it would drop a real
            // character and can panic when the first character is non-ASCII.
            return (
                StatusCode::NOT_FOUND,
                Json(json!({
                    "error": "object not found",
                    "object": path,
                })),
            );
        }
    };
    // `params` is cloned because the whole request is echoed back on failure.
    match dispatch_raw(object.as_ref(), &args.method, args.params.clone()).await {
        Ok(value) => (StatusCode::OK, Json(value)),
        Err(err) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({
                "error": "invalid http request",
                "error_message": err.to_string(),
                "request": args,
            })),
        ),
    }
}
pub async fn run_server(port: u16, objects: HashMap<String, Arc<dyn HttpInterface>>) {
let app = Router::new().route("/", get(root));
let app = app.route("/:key", post(dispatch));
let app = app
.layer(Extension(Arc::new(State {
registered_objects: objects,
})))
.layer(
CorsLayer::new()
.allow_origin("*".parse::<HeaderValue>().unwrap())
.allow_headers([axum::http::header::CONTENT_TYPE])
.allow_methods([Method::POST]),
);
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}
/// Body of a `POST /<object-name>` request: `{"method": ..., "params": ...}`.
///
/// Unknown fields are rejected during deserialization. The struct is also
/// serialized back into error responses.
#[derive(Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
struct RawArg {
    /// Name of the method to invoke on the target object.
    method: String,
    /// Call arguments. `dispatch_raw` accepts a JSON array (positional) or a
    /// JSON object (named) and rejects anything else.
    params: serde_json::Value,
}
/// Dispatches `method` on `api`, choosing the calling convention from the
/// shape of `arguments`.
///
/// A JSON array is passed as positional arguments (`DispatchStringTupleAsync`).
/// A JSON object is passed as named arguments (`DispatchStringDictAsync`).
/// The dispatcher's string output is parsed back into a JSON value.
///
/// # Errors
/// - [`HttpError::MethodNotFound`] if the dispatcher reports an unknown method.
/// - [`HttpError::InvalidRequest`] if `arguments` is neither an array nor an
///   object, if dispatch fails for any other reason, or if the dispatcher's
///   output is not valid JSON.
async fn dispatch_raw<T>(
    api: &T,
    method: &str,
    arguments: serde_json::Value,
) -> std::result::Result<serde_json::Value, HttpError>
where
    T: HttpInterface + ?Sized,
{
    let result = if arguments.is_array() {
        DispatchStringTupleAsync::dispatch(api, method, &arguments.to_string()).await
    } else if arguments.is_object() {
        DispatchStringDictAsync::dispatch(api, method, &arguments.to_string()).await
    } else {
        return Err(HttpError::InvalidRequest(format!(
            "invalid argument type: {}",
            arguments
        )));
    };
    match result {
        // Report malformed dispatcher output as an error instead of panicking
        // inside the request handler.
        Ok(x) => serde_json::from_str(&x).map_err(|err| {
            HttpError::InvalidRequest(format!("method returned invalid JSON: {}", err))
        }),
        Err(Error::MethodNotFound(x)) => Err(HttpError::MethodNotFound(x)),
        Err(x) => Err(HttpError::InvalidRequest(x.to_string())),
    }
}
/// Client-side transport that sends stub calls to a server started by
/// [`run_server`].
///
/// Requests are POSTed to `http://{addr}`. Since the server routes
/// `POST /<object-name>`, `addr` presumably needs to include the object path
/// (e.g. `host:port/<object-name>`). Confirm against callers.
pub struct HttpClient {
    /// Underlying `reqwest` client used to send requests.
    client: Client,
    /// Target address, without the `http://` scheme.
    addr: String,
}
impl HttpClient {
pub fn new(addr: String, client: Client) -> Self {
HttpClient { client, addr }
}
}
#[async_trait]
impl StubCall for HttpClient {
type Error = anyhow::Error;
async fn call(&self, method: &'static str, params: String) -> Result<String, Self::Error> {
let body = format!(
r#"{{"method": "{}",
"params": {}}}"#,
method, params
);
let response = self
.client
.request(Method::POST, &format!("http://{}", self.addr))
.header("content-type", "application/json")
.body(body)
.send()
.await?;
if response.status().as_u16() != 200 {
Err(anyhow::Error::msg(format!(
r#"HTTP request failed: "{}""#,
response.text().await?
)))
} else {
Ok(response.text().await?)
}
}
}