use std::collections::HashMap;
use std::path::PathBuf;
use std::str::FromStr;
use aimo_client::types::providers::{RegisterProviderRequest, RegisterProviderResponse};
use aimo_core::utils::id::create_keypair_from_file;
use aimo_core::{
keys::SecretKeyV1,
transport::{Request, Response},
};
use anyhow::{anyhow, Context, Result};
use canonical_json;
use futures_util::{stream::StreamExt, SinkExt};
use reqwest::{Client, Method};
use serde_json;
use solana_sdk::{pubkey::Pubkey, signature::Keypair, signer::Signer};
use tokio::sync::mpsc::{self, UnboundedSender};
use tokio_tungstenite::{
connect_async,
tungstenite::{self, Message},
};
use tracing::{debug, error, info, warn};
use url::Url;
use crate::config::ProxyConfig;
/// Connects to the node's provider-subscription WebSocket and proxies every
/// request received over it to a local HTTP endpoint.
///
/// Each text frame is parsed as a [`Request`] and handled on its own spawned
/// task; responses are funnelled through an unbounded channel to a single
/// writer task that owns the WebSocket sink.
///
/// # Arguments
/// * `node_url` - Base URL of the node (`http`, `https`, `ws` or `wss`).
/// * `secret_key` - Sent as a `Bearer` token in the handshake `Authorization` header.
/// * `endpoint_url` - Base URL of the HTTP endpoint requests are forwarded to.
/// * `api_key` - Optional bearer token added to every forwarded HTTP request.
///
/// # Errors
/// Returns an error if the WebSocket URL cannot be built or parsed, the
/// handshake request cannot be constructed, or the connection attempt fails.
/// Once connected, a close frame or a read error ends the loop and the
/// function returns `Ok(())`.
pub async fn serve_websocket(
    node_url: String,
    secret_key: String,
    endpoint_url: String,
    api_key: Option<String>,
) -> anyhow::Result<()> {
    info!("Starting proxy service...");
    info!("Node URL: {}", node_url);
    info!("Endpoint URL: {}", endpoint_url);

    let ws_url = build_websocket_url(&node_url, &secret_key)?;
    info!("Connecting to WebSocket: {}", ws_url);

    let url = url::Url::parse(&ws_url)?;
    // The Host header must carry the port when it is not the scheme default;
    // `Url::port` yields `None` for default ports, so those stay port-less.
    let host_header = match (url.host_str(), url.port()) {
        (Some(host), Some(port)) => format!("{}:{}", host, port),
        (Some(host), None) => host.to_string(),
        (None, _) => "localhost".to_string(),
    };
    let request = tungstenite::http::Request::builder()
        .method("GET")
        .uri(ws_url.as_str())
        .header("Host", host_header)
        .header("Upgrade", "websocket")
        .header("Connection", "Upgrade")
        .header(
            "Sec-WebSocket-Key",
            tungstenite::handshake::client::generate_key(),
        )
        .header("Sec-WebSocket-Version", "13")
        .header("Authorization", format!("Bearer {}", secret_key))
        .body(())?;

    let (ws_stream, _) = connect_async(request)
        .await
        .map_err(|e| anyhow!("Failed to connect to WebSocket: {}", e))?;
    info!("WebSocket connection established");

    let (mut ws_sender, mut ws_receiver) = ws_stream.split();
    let http_client = Client::new();

    // All request handlers push their responses into this channel; a single
    // task drains it so only one owner ever writes to the WebSocket sink.
    let (response_tx, mut response_rx) = mpsc::unbounded_channel::<Message>();
    let sender_task = tokio::spawn(async move {
        while let Some(message) = response_rx.recv().await {
            if let Err(e) = ws_sender.send(message).await {
                error!("Failed to send message: {}", e);
                break;
            }
        }
    });

    while let Some(message) = ws_receiver.next().await {
        match message {
            Ok(Message::Text(text)) => {
                debug!("Received message: {}", text);
                let request: Request = match serde_json::from_str::<Request>(&text) {
                    Ok(req) => {
                        info!(
                            "Parsed request successfully - ID: {}, Method: {}, Type: {}",
                            req.request_id, req.method, req.request_type
                        );
                        req
                    }
                    Err(e) => {
                        // A malformed frame is skipped; the connection stays up.
                        warn!("Failed to parse request: {}", e);
                        warn!("Raw message was: {}", text);
                        continue;
                    }
                };

                // Each request runs on its own task so a slow endpoint call
                // does not block reading further frames.
                let endpoint_url = endpoint_url.clone();
                let api_key = api_key.clone();
                let client = http_client.clone();
                let response_sender = response_tx.clone();
                tokio::spawn(async move {
                    if let Err(e) =
                        handle_request(client, request, endpoint_url, api_key, response_sender)
                            .await
                    {
                        error!("Error handling request: {}", e);
                    }
                });
            }
            Ok(Message::Close(_)) => {
                info!("WebSocket connection closed by server");
                break;
            }
            Ok(_) => {
                debug!("Received non-text message, ignoring");
            }
            Err(e) => {
                error!("WebSocket error: {}", e);
                break;
            }
        }
    }

    // The read side is finished; cancel the writer task as well.
    sender_task.abort();
    info!("Proxy service stopped");
    Ok(())
}
/// Derives the provider-subscription WebSocket URL from the node URL.
///
/// `http` is mapped to `ws` and `https` to `wss`; `ws`/`wss` are kept as-is and
/// any other scheme is rejected. The path is always replaced with
/// `/api/v1/providers/subscribe`. `_secret_key` is accepted but not used.
fn build_websocket_url(node_url: &str, _secret_key: &str) -> Result<String> {
    let mut parsed = Url::parse(node_url)?;

    // Work out which WebSocket scheme the URL should end up with.
    let ws_scheme = match parsed.scheme() {
        "http" | "ws" => "ws",
        "https" | "wss" => "wss",
        other => return Err(anyhow!("Unsupported URL scheme: {}", other)),
    };

    // Only rewrite the scheme when it actually differs (http/https inputs).
    if parsed.scheme() != ws_scheme {
        parsed
            .set_scheme(ws_scheme)
            .map_err(|_| anyhow!("Invalid scheme"))?;
    }

    parsed.set_path("/api/v1/providers/subscribe");
    Ok(parsed.to_string())
}
/// Forwards one tunneled [`Request`] to the HTTP endpoint and relays the
/// outcome back through `response_sender`.
///
/// # Arguments
/// * `client` - HTTP client used for the outgoing call.
/// * `request` - The request received over the WebSocket.
/// * `endpoint_url` - Base URL; `request.endpoint`, when present, is joined to it
///   with exactly one `/` between the two parts.
/// * `api_key` - When set, sent as `Authorization: Bearer <key>` and takes
///   precedence over any `Authorization` header carried by the request.
/// * `response_sender` - Channel on which response frames are queued.
///
/// # Behavior
/// * An unsupported HTTP method produces a 400 error response.
/// * A failed HTTP call produces a 500 error response.
/// * Responses whose content type contains `text/event-stream` or `text/stream`
///   are relayed chunk by chunk; everything else is relayed as a single message.
///
/// # Errors
/// Returns an error when a response cannot be serialized or queued on the channel.
async fn handle_request(
    client: Client,
    request: Request,
    endpoint_url: String,
    api_key: Option<String>,
    response_sender: UnboundedSender<Message>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Maximum number of payload bytes echoed into the debug log.
    const PREVIEW_LIMIT: usize = 200;

    // Longest prefix of `text` that is at most `max_bytes` long and ends on a
    // character boundary; slicing at a raw byte offset would panic inside a
    // multi-byte character.
    fn utf8_prefix(text: &str, max_bytes: usize) -> &str {
        if text.len() <= max_bytes {
            return text;
        }
        let mut end = max_bytes;
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        &text[..end]
    }

    debug!("Handling request ID: {}", request.request_id);

    let method = match request.method.to_uppercase().as_str() {
        "GET" => Method::GET,
        "POST" => Method::POST,
        "PUT" => Method::PUT,
        "DELETE" => Method::DELETE,
        "PATCH" => Method::PATCH,
        "HEAD" => Method::HEAD,
        "OPTIONS" => Method::OPTIONS,
        _ => {
            warn!("Unsupported HTTP method: {}", request.method);
            send_error_response(
                &response_sender,
                &request.request_id,
                400,
                "Unsupported HTTP method",
            )?;
            return Ok(());
        }
    };

    let target_url = if let Some(endpoint) = &request.endpoint {
        format!(
            "{}/{}",
            endpoint_url.trim_end_matches('/'),
            endpoint.trim_start_matches('/')
        )
    } else {
        endpoint_url.clone()
    };

    debug!("Forwarding {} request to: {}", method, target_url);
    debug!("Request headers: {:?}", request.headers);
    debug!("Request payload length: {} bytes", request.payload.len());

    let has_payload = !request.payload.is_empty();
    if has_payload {
        debug!(
            "Request payload preview: {}",
            if request.payload.len() > PREVIEW_LIMIT {
                format!("{}...", utf8_prefix(&request.payload, PREVIEW_LIMIT))
            } else {
                request.payload.clone()
            }
        );
    }

    let mut http_request = client.request(method, &target_url);

    // `RequestBuilder::header` appends instead of replacing, so headers that are
    // set explicitly further down must not also be forwarded from the request,
    // otherwise the endpoint receives them twice.
    for (key, value) in &request.headers {
        if api_key.is_some() && key.eq_ignore_ascii_case("authorization") {
            debug!("Skipping forwarded Authorization header in favor of API key");
            continue;
        }
        if has_payload && key.eq_ignore_ascii_case("content-type") {
            continue;
        }
        debug!("Adding header: {}: {}", key, value);
        http_request = http_request.header(key, value);
    }

    if let Some(api_key) = &api_key {
        debug!("Adding Authorization header with API key");
        http_request = http_request.header("Authorization", format!("Bearer {}", api_key));
    }

    if has_payload {
        // Header names are case-insensitive, so accept any spelling.
        let content_type = request
            .headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
            .map(|(_, value)| value.as_str())
            .unwrap_or("application/json");
        debug!("Setting content-type to: {}", content_type);
        http_request = http_request
            .header("Content-Type", content_type)
            .body(request.payload);
    }

    debug!("Sending HTTP request...");
    match http_request.send().await {
        Ok(response) => {
            debug!("Received HTTP response with status: {}", response.status());

            // Header values that are not valid visible ASCII are relayed as "".
            let headers: HashMap<String, String> = response
                .headers()
                .iter()
                .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
                .collect();
            debug!("Response headers: {:?}", headers);

            // Response header names are normalized to lowercase by the HTTP
            // library, so a single lowercase lookup is sufficient.
            let content_type = headers
                .get("content-type")
                .cloned()
                .unwrap_or_else(|| "text/plain".to_string());

            if content_type.contains("text/event-stream") || content_type.contains("text/stream") {
                debug!("Handling SSE stream for request {}", request.request_id);
                handle_sse_stream(
                    response,
                    &response_sender,
                    &request.request_id,
                    &content_type,
                    headers,
                )
                .await?;
            } else {
                debug!(
                    "Handling regular HTTP response for request {}",
                    request.request_id
                );
                handle_regular_response(
                    response,
                    &response_sender,
                    &request.request_id,
                    &content_type,
                    headers,
                )
                .await?;
            }
        }
        Err(e) => {
            error!("HTTP request failed: {}", e);
            error!("Error details: {:?}", e);
            if e.is_connect() {
                error!("Connection error - unable to connect to {}", target_url);
            } else if e.is_timeout() {
                error!("Request timeout");
            } else if e.is_request() {
                error!("Request construction error");
            } else {
                error!("Other HTTP error type");
            }
            send_error_response(
                &response_sender,
                &request.request_id,
                500,
                &format!("HTTP request failed: {}", e),
            )?;
        }
    }

    Ok(())
}
async fn handle_regular_response(
response: reqwest::Response,
response_sender: &UnboundedSender<Message>,
request_id: &str,
content_type: &str,
headers: HashMap<String, String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let status_code = response.status().as_u16();
let body = response.text().await.unwrap_or_else(|_| "".to_string());
info!(
"Endpoint response - Status: {}, Content-Type: {}",
status_code, content_type
);
info!("Response body length: {} bytes", body.len());
debug!(
"Response body: {}",
if body.len() > 1000 {
format!("{}...", &body[..1000])
} else {
body.clone()
}
);
let response = Response {
request_id: request_id.to_string(),
status_code,
content_type: content_type.to_string(),
payload: body,
headers,
is_stream_chunk: false,
stream_done: true,
};
let message = Message::text(serde_json::to_string(&response)?);
response_sender.send(message)?;
debug!("Sent regular response for request {}", request_id);
Ok(())
}
async fn handle_sse_stream(
response: reqwest::Response,
response_sender: &UnboundedSender<Message>,
request_id: &str,
content_type: &str,
headers: HashMap<String, String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let status_code = response.status().as_u16();
info!(
"Starting SSE stream - Status: {}, Content-Type: {}",
status_code, content_type
);
let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
let chunk_data = String::from_utf8_lossy(&chunk).to_string();
debug!(
"Received stream chunk ({} bytes): {}",
chunk.len(),
if chunk_data.len() > 200 {
format!("{}...", &chunk_data[..200])
} else {
chunk_data.clone()
}
);
let response = Response {
request_id: request_id.to_string(),
status_code,
content_type: content_type.to_string(),
payload: chunk_data,
headers: headers.clone(),
is_stream_chunk: true,
stream_done: false,
};
let message = Message::text(serde_json::to_string(&response)?);
if response_sender.send(message).is_err() {
error!("Failed to send stream chunk, connection closed");
break;
}
debug!("Sent stream chunk for request {}", request_id);
}
Err(e) => {
error!("Error reading stream chunk: {}", e);
break;
}
}
}
info!("SSE stream completed for request {}", request_id);
let final_response = Response {
request_id: request_id.to_string(),
status_code,
content_type: content_type.to_string(),
payload: "".to_string(),
headers,
is_stream_chunk: true,
stream_done: true,
};
let message = Message::text(serde_json::to_string(&final_response)?);
response_sender.send(message)?;
debug!("Stream completed for request {}", request_id);
Ok(())
}
/// Queues a plain-text error [`Response`] for `request_id` on the response channel.
///
/// The message carries `status_code`, `error_message` as payload, no headers,
/// and is marked as a complete, non-streaming response.
///
/// # Errors
/// Returns an error if serialization fails or the channel is closed.
fn send_error_response(
    response_sender: &UnboundedSender<Message>,
    request_id: &str,
    status_code: u16,
    error_message: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let serialized = serde_json::to_string(&Response {
        request_id: request_id.to_owned(),
        status_code,
        content_type: String::from("text/plain"),
        payload: error_message.to_owned(),
        headers: HashMap::new(),
        is_stream_chunk: false,
        stream_done: true,
    })?;

    response_sender.send(Message::text(serialized))?;
    Ok(())
}
/// Loads the proxy configuration, registers the provider, then runs the
/// WebSocket proxy with the settings taken from that configuration.
///
/// # Arguments
/// * `config_path` - Path of the proxy configuration file.
/// * `id` - Optional path handed to `create_keypair_from_file` to obtain the
///   keypair used for registration.
///
/// # Errors
/// Propagates failures from loading the configuration, creating the keypair,
/// registering the provider, or running the proxy.
pub async fn serve_websocket_with_config(config_path: PathBuf, id: Option<PathBuf>) -> Result<()> {
    let shown_path = config_path.display();
    info!("Loading proxy configuration from: {}", shown_path);

    let config = ProxyConfig::from_file(&config_path)?;
    let keypair = create_keypair_from_file(id)?;

    info!("Registering provider before starting websocket proxy");
    register_provider(&config, &keypair).await?;
    info!("Provider registered successfully, starting websocket proxy");

    // Collect owned copies of everything the proxy loop needs.
    let subscribe_url = format!("{}/api/v1/providers/subscribe", config.node_url());
    let secret_key = config.secret_key().to_string();
    let endpoint_url = config.endpoint_url().to_string();
    let endpoint_api_key = Some(config.endpoint_api_key().to_string());

    serve_websocket(subscribe_url, secret_key, endpoint_url, endpoint_api_key).await
}
/// Registers this provider with the router before the proxy starts serving.
///
/// The provider metadata from `config` is serialized to canonical JSON and
/// signed with `keypair`; metadata and signature are then POSTed to
/// `<router url>/api/v1/providers/register`, authenticated with the configured
/// secret key as a bearer token.
///
/// # Errors
/// Returns an error if the secret key cannot be decoded, its signer is not a
/// valid public key, the metadata cannot be serialized, the request cannot be
/// sent, the success response cannot be parsed, or the server answers with a
/// non-success status (the status and response text are included).
async fn register_provider(config: &ProxyConfig, keypair: &Keypair) -> Result<()> {
    // The decoded key and signer are not used further; decoding them here only
    // rejects a malformed secret key before any network call is made.
    let (_scope, secret_key_v1) =
        SecretKeyV1::decode(config.secret_key()).context("Failed to decode secret key")?;
    let _signer_pubkey =
        Pubkey::from_str(&secret_key_v1.signer).context("Failed to parse signer pubkey")?;

    // The signature is computed over the canonical JSON form of the metadata.
    let metadata_json =
        serde_json::to_value(&config.metadata).context("Failed to serialize provider metadata")?;
    let canonical_metadata =
        canonical_json::to_string(&metadata_json).context("Failed to create canonical JSON")?;
    let signature = keypair.sign_message(canonical_metadata.as_bytes());

    let request = RegisterProviderRequest {
        metadata: config.metadata.clone(),
        signature: signature.to_string(),
    };

    // Trim trailing slashes so the joined URL never contains `//`.
    let base_url = config.router.url.to_string();
    let register_url = format!(
        "{}/api/v1/providers/register",
        base_url.trim_end_matches('/')
    );

    let client = reqwest::Client::new();
    let response = client
        .post(&register_url)
        .header("Authorization", format!("Bearer {}", config.secret_key()))
        .header("Content-Type", "application/json")
        .json(&request)
        .send()
        .await
        .context("Failed to send registration request")?;

    if response.status().is_success() {
        let result: RegisterProviderResponse = response
            .json()
            .await
            .context("Failed to parse registration response")?;
        info!("Provider registration: {}", result.message);
        Ok(())
    } else {
        let status = response.status();
        let error_text = response
            .text()
            .await
            .unwrap_or_else(|_| "Failed to read error response".to_string());
        Err(anyhow!(
            "Provider registration failed with status {}: {}",
            status,
            error_text
        ))
    }
}