use crate::api_sse;
use crate::api_v2;
use crate::api_ws;
use super::node_commands::NodeCommand;
use async_channel::Sender;
use hyper::server::conn::Http;
use reqwest::StatusCode;
use serde::Deserialize;
use serde::Serialize;
use serde_json::json;
use hanzo_messages::hanzo_utils::hanzo_logging::hanzo_log;
use hanzo_messages::hanzo_utils::hanzo_logging::HanzoLogLevel;
use hanzo_messages::hanzo_utils::hanzo_logging::HanzoLogOption;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio_rustls::rustls::{self, ServerConfig};
use tokio_rustls::TlsAcceptor;
use utoipa::ToSchema;
use warp::filters::compression;
use warp::Filter;
/// Details carried in the optional `data` field of [`SendResponseBody`].
// NOTE(review): the producer of this struct is outside this section of the file;
// the exact format of each string field should be confirmed there.
#[derive(serde::Serialize, ToSchema, Debug, Clone)]
pub struct SendResponseBodyData {
/// Identifier of the message.
pub message_id: String,
/// Identifier of the parent message, when there is one.
pub parent_message_id: Option<String>,
/// Inbox associated with the message.
pub inbox: String,
/// Scheduled time, transported as a string.
// NOTE(review): presumably a timestamp; the concrete format is not visible here — confirm.
pub scheduled_time: String,
}
/// Serializable response body made of a status, a message and optional details.
// NOTE(review): the set of values used for `status` is not visible in this section — confirm
// against the code that constructs this struct.
#[derive(serde::Serialize, ToSchema, Debug, Clone)]
pub struct SendResponseBody {
/// Status of the operation, as a string.
pub status: String,
/// Human-readable message accompanying the status.
pub message: String,
/// Additional details; `None` when there are none to report.
pub data: Option<SendResponseBodyData>,
}
/// Response body describing a successful use of a registration code.
// NOTE(review): the handler that builds this value is outside this section; key encodings
// (hex, base64, ...) are not visible here and should be confirmed there.
#[derive(Debug, Serialize, Deserialize, ToSchema, Clone)]
pub struct APIUseRegistrationCodeSuccessResponse {
/// Human-readable message.
pub message: String,
/// Name of the node.
pub node_name: String,
/// Encryption public key, as a string.
pub encryption_public_key: String,
/// Identity public key, as a string.
pub identity_public_key: String,
/// API v2 key.
pub api_v2_key: String,
/// API v2 certificate, when one is available.
pub api_v2_cert: Option<String>,
}
/// Response body carrying a signature public key and an encryption public key.
// NOTE(review): key encoding is not visible in this section — confirm against the producer.
#[derive(serde::Serialize, ToSchema, Debug, Clone)]
pub struct GetPublicKeysResponse {
/// Signature public key, as a string.
pub signature_public_key: String,
/// Encryption public key, as a string.
pub encryption_public_key: String,
}
/// Error payload returned by the API and usable as a warp rejection.
#[derive(Serialize, ToSchema, Debug, Clone)]
pub struct APIError {
/// HTTP status code in numeric form.
pub code: u16,
/// Short error label, for example "Bad Request" or "Internal Server Error".
pub error: String,
/// Detailed description of what went wrong.
pub message: String,
}
impl APIError {
    /// Builds an [`APIError`] from an HTTP status, a short error label and a detailed message.
    ///
    /// The status is stored in its numeric form; both string slices are copied into owned
    /// `String`s.
    pub fn new(code: StatusCode, error: &str, message: &str) -> Self {
        let code = code.as_u16();
        let error = String::from(error);
        let message = String::from(message);
        Self { code, error, message }
    }
}
impl From<&str> for APIError {
    /// Treats a bare string slice as the message of a `400 Bad Request` error.
    fn from(error: &str) -> Self {
        let status = StatusCode::BAD_REQUEST;
        Self {
            code: status.as_u16(),
            error: String::from("Bad Request"),
            message: error.to_owned(),
        }
    }
}
impl From<async_channel::SendError<NodeCommand>> for APIError {
    /// Converts a failed send on the node-command channel into a
    /// `500 Internal Server Error`, embedding the channel error text in the message.
    fn from(error: async_channel::SendError<NodeCommand>) -> Self {
        let message = format!("Failed with error: {}", error);
        let status = StatusCode::INTERNAL_SERVER_ERROR;
        Self {
            code: status.as_u16(),
            error: String::from("Internal Server Error"),
            message,
        }
    }
}
impl From<String> for APIError {
    /// Treats an owned string as the message of a `500 Internal Server Error`.
    ///
    /// The string is moved into the error without being copied.
    fn from(error: String) -> Self {
        let status = StatusCode::INTERNAL_SERVER_ERROR;
        Self {
            code: status.as_u16(),
            error: String::from("Internal Server Error"),
            message: error,
        }
    }
}
// Marker impl: lets an `APIError` travel inside a `warp::Rejection`, which is how
// `handle_rejection` is able to look it up with `err.find::<APIError>()`.
impl warp::reject::Reject for APIError {}
pub async fn run_api(
node_commands_sender: Sender<NodeCommand>,
address: SocketAddr,
https_address: SocketAddr,
ws_address: SocketAddr,
node_name: String,
private_https_certificate: Option<String>,
public_https_certificate: Option<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
hanzo_log(
HanzoLogOption::Api,
HanzoLogLevel::Info,
&format!("Starting Node API server at: {}", &address),
);
let log = warp::log::custom(|info| {
hanzo_log(
HanzoLogOption::Api,
HanzoLogLevel::Info,
&format!(
"ip: {:?}, method: {:?}, path: {:?}, status: {:?}, elapsed: {:?}",
info.remote_addr(),
info.method(),
info.path(),
info.status(),
info.elapsed(),
),
);
});
let cors = warp::cors()
.allow_any_origin()
.allow_methods(vec!["GET", "POST", "OPTIONS", "DELETE"])
.allow_headers(vec![
"Content-Type",
"Authorization",
"x-hanzo-tool-id",
"x-hanzo-app-id",
"x-hanzo-llm-provider",
"x-hanzo-original-tool-router-key",
"ngrok-skip-browser-warning",
]);
let v2_routes = warp::path("v2").and(
api_v2::api_v2_router::v2_routes(node_commands_sender.clone(), node_name.clone())
.recover(handle_rejection)
.with(log)
.with(cors.clone())
.with(compression::gzip()),
);
let mcp_routes = warp::path("mcp").and(
api_sse::api_sse_routes::mcp_sse_routes(node_commands_sender.clone(), node_name.clone())
.recover(handle_rejection)
.with(log)
.with(cors.clone()),
);
let ws_routes = warp::path("ws").and(
api_ws::api_ws_routes::ws_routes(ws_address)
.recover(handle_rejection)
.with(log)
.with(cors.clone()),
);
let routes = v2_routes.or(mcp_routes).or(ws_routes).with(log).with(cors);
let http_server = async {
warp::serve(routes.clone()).run(address).await;
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
};
if let (Some(cert_string), Some(key_string)) = (public_https_certificate, private_https_certificate) {
if !cert_string.is_empty() && !key_string.is_empty() {
eprintln!("cert_string: {}", cert_string);
eprintln!("key_string: {}", key_string);
let certs = rustls_pemfile::certs(&mut cert_string.as_bytes())
.map_err(|_| "Failed to parse certificates")?
.into_iter()
.map(rustls::Certificate)
.collect();
let key = rustls_pemfile::pkcs8_private_keys(&mut key_string.as_bytes())
.map_err(|_| "Failed to parse private key")?
.into_iter()
.next()
.ok_or("No private key found")
.map(rustls::PrivateKey)?;
let config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)?;
let tls_acceptor = TlsAcceptor::from(Arc::new(config));
let listener = TcpListener::bind(&https_address).await?;
let https_server = async {
loop {
let (stream, _) = listener.accept().await?;
let acceptor = tls_acceptor.clone();
let routes = routes.clone();
tokio::spawn(async move {
match acceptor.accept(stream).await {
Ok(stream) => {
let service = warp::service(routes);
if let Err(e) = Http::new().serve_connection(stream, service).await {
hanzo_log(
HanzoLogOption::Api,
HanzoLogLevel::Error,
&format!("Error serving connection: {:?}", e),
);
}
}
Err(e) => {
hanzo_log(
HanzoLogOption::Api,
HanzoLogLevel::Error,
&format!("TLS handshake failed: {:?}", e),
);
}
}
});
}
#[allow(unreachable_code)]
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
};
tokio::try_join!(http_server, https_server)?;
}
} else {
http_server.await?;
}
Ok(())
}
pub async fn handle_node_command<T, U, V>(
node_commands_sender: Sender<NodeCommand>,
message: V,
command: T,
) -> Result<impl warp::Reply, warp::reject::Rejection>
where
T: FnOnce(Sender<NodeCommand>, V, Sender<Result<U, APIError>>) -> NodeCommand,
U: Serialize,
V: Serialize,
{
let (res_sender, res_receiver) = async_channel::bounded(1);
node_commands_sender
.clone()
.send(command(node_commands_sender, message, res_sender))
.await
.map_err(|_| warp::reject::reject())?;
let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?;
match result {
Ok(message) => Ok(warp::reply::with_status(
warp::reply::json(&json!({"status": "success", "data": message})),
StatusCode::OK,
)),
Err(error) => Ok(warp::reply::with_status(
warp::reply::json(&json!({"status": "error", "error": error.message})),
StatusCode::from_u16(error.code).unwrap(),
)),
}
}
async fn handle_rejection(err: warp::Rejection) -> Result<impl warp::Reply, warp::Rejection> {
if let Some(api_error) = err.find::<APIError>() {
let json = warp::reply::json(api_error);
Ok(warp::reply::with_status(
json,
StatusCode::from_u16(api_error.code).unwrap(),
))
} else if err.is_not_found() {
let json = warp::reply::json(&APIError::new(
StatusCode::NOT_FOUND,
"Not Found",
"Please check your URL.",
));
Ok(warp::reply::with_status(json, StatusCode::NOT_FOUND))
} else if let Some(body_err) = err.find::<warp::filters::body::BodyDeserializeError>() {
let json = warp::reply::json(&APIError::new(
StatusCode::BAD_REQUEST,
"Invalid Body",
&format!("Deserialization error: {}", body_err),
));
Ok(warp::reply::with_status(json, StatusCode::BAD_REQUEST))
} else if err.find::<warp::reject::MethodNotAllowed>().is_some() {
let json = warp::reply::json(&APIError::new(
StatusCode::METHOD_NOT_ALLOWED,
"Method Not Allowed",
"Please check your request method.",
));
Ok(warp::reply::with_status(json, StatusCode::METHOD_NOT_ALLOWED))
} else if err.find::<warp::reject::PayloadTooLarge>().is_some() {
let json = warp::reply::json(&APIError::new(
StatusCode::PAYLOAD_TOO_LARGE,
"Payload Too Large",
"The request payload is too large.",
));
Ok(warp::reply::with_status(json, StatusCode::PAYLOAD_TOO_LARGE))
} else if err.find::<warp::reject::InvalidQuery>().is_some() {
let json = warp::reply::json(&APIError::new(
StatusCode::BAD_REQUEST,
"Invalid Query",
"The request query string is invalid.",
));
Ok(warp::reply::with_status(json, StatusCode::BAD_REQUEST))
} else {
let json = warp::reply::json(&APIError::new(
StatusCode::INTERNAL_SERVER_ERROR,
"Internal Server Error",
"An unexpected error occurred. Please try again.",
));
Ok(warp::reply::with_status(json, StatusCode::INTERNAL_SERVER_ERROR))
}
}