use std::collections::HashMap;
use std::path::PathBuf;
use std::str::FromStr;
use aimo_client::types::providers::{RegisterProviderRequest, RegisterProviderResponse};
use aimo_core::utils::id::create_keypair_from_file;
use aimo_core::{
keys::SecretKeyV1,
transport::{Request, Response},
};
use anyhow::{Context, Result, anyhow};
use canonical_json;
use futures_util::{SinkExt, stream::StreamExt};
use reqwest::{Client, Method};
use serde::Deserialize;
use serde_json;
use solana_sdk::{pubkey::Pubkey, signature::Keypair, signer::Signer};
use tokio::sync::mpsc::{self, UnboundedSender};
use tokio_tungstenite::{
connect_async,
tungstenite::{self, Message},
};
use tracing::{debug, error, info, warn};
use url::Url;
use crate::config::ProxyConfig;
pub async fn serve_websocket(
node_url: String,
secret_key: String,
endpoint_url: String,
api_key: Option<String>,
) -> anyhow::Result<()> {
info!("Starting proxy service...");
info!("Node URL: {}", node_url);
info!("Endpoint URL: {}", endpoint_url);
let ws_url = build_websocket_url(&node_url, &secret_key)?;
info!("Connecting to WebSocket: {}", ws_url);
let url = url::Url::parse(&ws_url)?;
debug!("=== WEBSOCKET CONNECTION SETUP ===");
debug!("Target URL: {}", ws_url);
debug!("Host: {}", url.host_str().unwrap_or("localhost"));
debug!(
"Secret key prefix: {}...",
&secret_key[..std::cmp::min(10, secret_key.len())]
);
let websocket_key = tungstenite::handshake::client::generate_key();
debug!("Generated WebSocket key: {}", websocket_key);
let request = tungstenite::http::Request::builder()
.method("GET")
.uri(ws_url.as_str())
.header("Host", url.host_str().unwrap_or("localhost"))
.header("Upgrade", "websocket")
.header("Connection", "Upgrade")
.header("Sec-WebSocket-Key", &websocket_key)
.header("Sec-WebSocket-Version", "13")
.header("Authorization", format!("Bearer {}", secret_key))
.body(())?;
debug!("=== WEBSOCKET REQUEST HEADERS ===");
for (name, value) in request.headers() {
if name == "authorization" {
debug!(
"{}: Bearer {}...",
name,
&secret_key[..std::cmp::min(10, secret_key.len())]
);
} else {
debug!("{}: {:?}", name, value);
}
}
debug!("Attempting WebSocket connection...");
let connection_result = connect_async(request).await;
match connection_result {
Ok((ws_stream, response)) => {
info!("WebSocket connection established successfully");
debug!("=== WEBSOCKET RESPONSE ===");
debug!("Status: {}", response.status());
debug!("Response headers:");
for (name, value) in response.headers() {
debug!(" {}: {:?}", name, value);
}
let (mut ws_sender, mut ws_receiver) = ws_stream.split();
let http_client = Client::new();
let (response_tx, mut response_rx) = mpsc::unbounded_channel::<Message>();
let sender_task = tokio::spawn(async move {
while let Some(message) = response_rx.recv().await {
if let Err(e) = ws_sender.send(message).await {
error!("Failed to send message through WebSocket: {}", e);
break;
}
}
debug!("WebSocket sender task terminated");
});
loop {
match ws_receiver.next().await {
Some(Ok(Message::Text(text))) => {
debug!("Received message: {}", text);
let request: Request = match serde_json::from_str::<Request>(&text) {
Ok(req) => {
info!(
"Parsed request successfully - ID: {}, Method: {}, Type: {}",
req.request_id, req.method, req.request_type
);
req
}
Err(e) => {
warn!("Failed to parse request: {}", e);
warn!("Raw message was: {}", text);
continue;
}
}; let endpoint_url = endpoint_url.clone();
let api_key = api_key.clone();
let client = http_client.clone();
let response_sender = response_tx.clone();
tokio::spawn(async move {
if let Err(e) = handle_request(
client,
request,
endpoint_url,
api_key,
response_sender,
)
.await
{
error!("Error handling request: {}", e);
}
});
}
Some(Ok(Message::Close(_))) => {
info!("WebSocket connection closed by server");
break;
}
Some(Ok(_)) => {
debug!("Received non-text message, ignoring");
}
Some(Err(e)) => {
error!("WebSocket error: {}", e);
break;
}
None => {
info!("WebSocket stream ended");
break;
}
}
}
sender_task.abort();
info!("WebSocket connection terminated, proxy service stopped");
Err(anyhow!("WebSocket connection lost"))
}
Err(e) => {
error!("=== WEBSOCKET CONNECTION FAILED ===");
error!("Connection error: {}", e);
match &e {
tungstenite::Error::Http(http_response) => {
error!("HTTP error during WebSocket handshake:");
error!("Status: {}", http_response.status());
error!("Response headers:");
for (name, value) in http_response.headers() {
error!(" {}: {:?}", name, value);
}
if let Some(body) = http_response.body() {
if !body.is_empty() {
error!(
"Response body: {:?}",
std::str::from_utf8(body).unwrap_or("Invalid UTF-8")
);
}
}
}
tungstenite::Error::Url(url_error) => {
error!("URL error: {}", url_error);
}
tungstenite::Error::Tls(tls_error) => {
error!("TLS error: {}", tls_error);
}
tungstenite::Error::Io(io_error) => {
error!("IO error: {}", io_error);
}
_ => {
error!("Other WebSocket error: {}", e);
}
}
Err(anyhow!("Failed to connect to WebSocket: {}", e))
}
}
}
/// Derives the provider-subscription WebSocket URL from the node's base URL.
///
/// `http`/`https` are mapped to `ws`/`wss`, `ws`/`wss` are kept as they are,
/// and any other scheme is rejected. The path is always replaced with
/// `/api/v1/providers/subscribe`. `_secret_key` is accepted but not used.
fn build_websocket_url(node_url: &str, _secret_key: &str) -> Result<String> {
    let mut url = Url::parse(node_url)?;

    // Work out which scheme (if any) has to be swapped in before mutating.
    let replacement = match url.scheme() {
        "http" => Some("ws"),
        "https" => Some("wss"),
        "ws" | "wss" => None,
        other => return Err(anyhow!("Unsupported URL scheme: {}", other)),
    };
    if let Some(scheme) = replacement {
        url.set_scheme(scheme)
            .map_err(|_| anyhow!("Invalid scheme"))?;
    }

    url.set_path("/api/v1/providers/subscribe");
    Ok(url.to_string())
}
async fn handle_request(
client: Client,
request: Request,
endpoint_url: String,
api_key: Option<String>,
response_sender: UnboundedSender<Message>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
debug!("Handling request ID: {}", request.request_id);
let method_str = request.method.to_uppercase();
let method = match method_str.as_str() {
"GET" => Method::GET,
"POST" => Method::POST,
"PUT" => Method::PUT,
"DELETE" => Method::DELETE,
"PATCH" => Method::PATCH,
"HEAD" => Method::HEAD,
"OPTIONS" => Method::OPTIONS,
_ => {
warn!("Unsupported HTTP method: {}", request.method);
send_error_response(
&response_sender,
&request.request_id,
400,
"Unsupported HTTP method",
)?;
return Ok(());
}
};
let target_url = if let Some(endpoint) = &request.endpoint {
format!(
"{}/{}",
endpoint_url.trim_end_matches('/'),
endpoint.trim_start_matches('/')
)
} else {
endpoint_url.clone()
};
debug!("=== REQUEST DETAILS ===");
debug!("Method: {}", method_str);
debug!("Target URL: {}", target_url);
debug!("Original endpoint: {:?}", request.endpoint);
debug!("Request type: {}", request.request_type);
debug!("Request headers ({} total):", request.headers.len());
for (key, value) in &request.headers {
debug!(" {}: {}", key, value);
}
debug!("Request payload length: {} bytes", request.payload.len());
if !request.payload.is_empty() {
debug!("Request payload (full): {}", request.payload);
if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&request.payload) {
debug!(
"Request payload (formatted JSON): {}",
serde_json::to_string_pretty(&json_value)
.unwrap_or_else(|_| request.payload.clone())
);
}
}
let mut http_request = client.request(method, &target_url);
debug!("=== HTTP REQUEST CONSTRUCTION ===");
for (key, value) in &request.headers {
debug!("Adding header: {}: {}", key, value);
http_request = http_request.header(key, value);
}
if let Some(api_key) = &api_key {
debug!(
"Adding Authorization header with API key: Bearer {}",
api_key.clone()
);
http_request = http_request.header("Authorization", format!("Bearer {}", api_key));
}
if !request.payload.is_empty() {
let content_type = request
.headers
.get("content-type")
.or_else(|| request.headers.get("Content-Type"))
.map(|s| s.as_str())
.unwrap_or("application/json");
debug!("Setting content-type to: {}", content_type);
debug!("Setting request body: {}", request.payload);
http_request = http_request
.header("Content-Type", content_type)
.body(request.payload.clone());
}
debug!("Sending HTTP request...");
debug!("=== SENDING HTTP REQUEST ===");
debug!("About to send HTTP request...");
match http_request.send().await {
Ok(response) => {
let status = response.status();
debug!("=== HTTP RESPONSE RECEIVED ===");
debug!(
"Response status: {} {}",
status.as_u16(),
status.canonical_reason().unwrap_or("")
);
let headers: HashMap<String, String> = response
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
debug!("Response headers ({} total):", headers.len());
for (key, value) in &headers {
debug!(" {}: {}", key, value);
}
if status.is_client_error() || status.is_server_error() {
error!("=== HTTP ERROR RESPONSE ===");
error!(
"Status: {} {}",
status.as_u16(),
status.canonical_reason().unwrap_or("")
);
error!("Request ID: {}", request.request_id);
error!("Target URL: {}", target_url);
error!("Method: {}", method_str);
if status.as_u16() == 400 {
error!("=== 400 BAD REQUEST DETAILS ===");
let error_body = response
.text()
.await
.unwrap_or_else(|_| "Failed to read error body".to_string());
error!("Error response body: {}", error_body);
let error_response = Response {
request_id: request.request_id.to_string(),
status_code: 400,
content_type: "text/plain".to_string(),
payload: error_body,
headers,
is_stream_chunk: false,
stream_done: true,
};
let message = Message::text(serde_json::to_string(&error_response)?);
if let Err(e) = response_sender.send(message) {
warn!(
"Failed to send error response, WebSocket connection likely closed: {}",
e
);
}
return Ok(());
}
}
let content_type = headers
.get("content-type")
.or_else(|| headers.get("Content-Type"))
.unwrap_or(&"text/plain".to_string())
.clone();
if content_type.contains("text/event-stream") || content_type.contains("text/stream") {
debug!("Handling SSE stream for request {}", request.request_id);
handle_sse_stream(
response,
&response_sender,
&request.request_id,
&content_type,
headers,
)
.await?;
} else {
debug!(
"Handling regular HTTP response for request {}",
request.request_id
);
handle_regular_response(
response,
&response_sender,
&request.request_id,
&content_type,
headers,
)
.await?;
}
}
Err(e) => {
error!("HTTP request failed: {}", e);
error!("Error details: {:?}", e);
if e.is_connect() {
error!("Connection error - unable to connect to {}", target_url);
} else if e.is_timeout() {
error!("Request timeout");
} else if e.is_request() {
error!("Request construction error");
} else {
error!("Other HTTP error type");
}
send_error_response(
&response_sender,
&request.request_id,
500,
&format!("HTTP request failed: {}", e),
)?;
}
}
Ok(())
}
/// Reads a non-streaming upstream response in full and relays it as a single
/// `Response` message with `stream_done` set.
///
/// * `response` - the upstream HTTP response; its body is consumed here.
/// * `response_sender` - channel feeding the WebSocket writer task.
/// * `request_id` - identifier copied into the relayed message.
/// * `content_type` - content type reported in the relayed message.
/// * `headers` - upstream response headers to relay.
///
/// If the body cannot be read, an empty payload is relayed instead.
///
/// # Errors
///
/// Returns an error when the response cannot be serialized or when the
/// channel to the WebSocket writer is closed.
async fn handle_regular_response(
    response: reqwest::Response,
    response_sender: &UnboundedSender<Message>,
    request_id: &str,
    content_type: &str,
    headers: HashMap<String, String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    /// Upper bound, in bytes, on how much of the body is written to the log.
    const MAX_LOGGED_BODY_BYTES: usize = 1000;

    let status_code = response.status().as_u16();
    let body = match response.text().await {
        Ok(text) => text,
        Err(e) => {
            // Best effort: still answer the request, but leave a trace of why
            // the payload is empty.
            warn!(
                "Failed to read response body, relaying an empty payload: {}",
                e
            );
            String::new()
        }
    };
    info!(
        "Endpoint response - Status: {}, Content-Type: {}",
        status_code, content_type
    );
    info!("Response body length: {} bytes", body.len());

    if body.len() > MAX_LOGGED_BODY_BYTES {
        // Back off to the nearest char boundary: slicing through the middle
        // of a multi-byte character would panic. Index 0 is always a
        // boundary, so the loop terminates.
        let mut cut = MAX_LOGGED_BODY_BYTES;
        while !body.is_char_boundary(cut) {
            cut -= 1;
        }
        debug!("Response body: {}...", &body[..cut]);
    } else {
        debug!("Response body: {}", body);
    }

    let response = Response {
        request_id: request_id.to_string(),
        status_code,
        content_type: content_type.to_string(),
        payload: body,
        headers,
        is_stream_chunk: false,
        stream_done: true,
    };
    let message = Message::text(serde_json::to_string(&response)?);
    if let Err(e) = response_sender.send(message) {
        warn!(
            "Failed to send regular response, WebSocket connection likely closed: {}",
            e
        );
        return Err(e.into());
    }
    debug!("Sent regular response for request {}", request_id);
    Ok(())
}
/// Relays a server-sent-events response to the WebSocket, one message per
/// event.
///
/// The body is consumed incrementally. Consecutive `data: ` lines are joined
/// with `\n` into one event, a blank line ends the event and sends it as a
/// stream chunk, and `data: [DONE]` flushes any pending event, forwards the
/// literal `[DONE]` chunk and stops reading. Lines without the `data: ` prefix
/// are ignored. After the stream ends (normally, on `[DONE]`, or on a read
/// error) a final empty message with `stream_done` set is sent.
///
/// * `response` - the upstream HTTP response; its body is consumed here.
/// * `response_sender` - channel feeding the WebSocket writer task.
/// * `request_id` - identifier copied into every relayed message.
/// * `content_type` - content type reported in every relayed message.
/// * `headers` - upstream response headers, attached to every message.
///
/// # Errors
///
/// Returns an error when a message cannot be serialized or when the channel
/// to the WebSocket writer is closed.
async fn handle_sse_stream(
    response: reqwest::Response,
    response_sender: &UnboundedSender<Message>,
    request_id: &str,
    content_type: &str,
    headers: HashMap<String, String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let status_code = response.status().as_u16();
    info!(
        "Starting SSE stream - Status: {}, Content-Type: {}",
        status_code, content_type
    );
    let mut stream = response.bytes_stream();
    // Raw bytes received so far that are not yet terminated by '\n'. Buffering
    // bytes (not text) matters: a multi-byte UTF-8 character can be split
    // across two network chunks, and decoding each chunk on its own would
    // corrupt it. The byte 0x0A never occurs inside a multi-byte sequence, so
    // splitting on it is always safe.
    let mut pending: Vec<u8> = Vec::new();
    let mut current_event_data = String::new();

    // Minimal view of a chat-completion style chunk, used only to log the
    // text delta of each event.
    #[derive(Deserialize, Clone)]
    struct ChunkSchema {
        choices: Vec<ChunkChoice>,
    }
    #[derive(Deserialize, Clone)]
    struct ChunkChoice {
        delta: ChunkDelta,
    }
    #[derive(Deserialize, Clone)]
    struct ChunkDelta {
        content: String,
    }

    // Wraps one event payload in a stream-chunk `Response` and queues it.
    let send_payload = |payload: String| -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let chunk_delta = serde_json::from_str::<ChunkSchema>(&payload)
            .ok()
            .and_then(|data| data.choices.first().cloned())
            .map(|first| first.delta.content)
            .unwrap_or_else(|| "[null]".to_string());
        tracing::debug!("delta: {chunk_delta}");
        let response = Response {
            request_id: request_id.to_string(),
            status_code,
            content_type: content_type.to_string(),
            payload,
            headers: headers.clone(),
            is_stream_chunk: true,
            stream_done: false,
        };
        let message = Message::text(serde_json::to_string(&response)?);
        if response_sender.send(message).is_err() {
            warn!("Failed to send stream chunk, WebSocket connection likely closed");
            return Err("WebSocket connection likely closed".into());
        }
        debug!("Sent parsed stream chunk for request {}", request_id);
        Ok(())
    };

    'outer: while let Some(chunk_result) = stream.next().await {
        let chunk = match chunk_result {
            Ok(chunk) => chunk,
            Err(e) => {
                // A read error ends the stream; whatever was assembled so far
                // is still flushed below.
                error!("Error reading stream chunk: {}", e);
                break;
            }
        };
        pending.extend_from_slice(&chunk);

        // Consume every complete line currently buffered.
        while let Some(pos) = pending.iter().position(|&byte| byte == b'\n') {
            let decoded = String::from_utf8_lossy(&pending[..pos]).into_owned();
            pending.drain(..=pos);
            let line = decoded.trim_end_matches('\r');

            if line.trim().is_empty() {
                // Blank line: end of the current event.
                if !current_event_data.is_empty() {
                    send_payload(std::mem::take(&mut current_event_data))?;
                }
                continue;
            }

            if let Some(rest) = line.strip_prefix("data: ") {
                if rest == "[DONE]" {
                    if !current_event_data.is_empty() {
                        send_payload(std::mem::take(&mut current_event_data))?;
                    }
                    send_payload("[DONE]".to_string())?;
                    break 'outer;
                }
                if !current_event_data.is_empty() {
                    current_event_data.push('\n');
                }
                current_event_data.push_str(rest);
            }
        }
    }

    // The stream may end without a trailing blank line; flush what is left.
    if !current_event_data.is_empty() {
        send_payload(std::mem::take(&mut current_event_data))?;
    }
    info!("SSE stream completed for request {}", request_id);

    let final_response = Response {
        request_id: request_id.to_string(),
        status_code,
        content_type: content_type.to_string(),
        payload: "".to_string(),
        headers,
        is_stream_chunk: true,
        stream_done: true,
    };
    let message = Message::text(serde_json::to_string(&final_response)?);
    if let Err(e) = response_sender.send(message) {
        warn!(
            "Failed to send final stream response, WebSocket connection likely closed: {}",
            e
        );
        return Err(e.into());
    }
    debug!("Stream completed for request {}", request_id);
    Ok(())
}
/// Returns the payload of the first `data: ` line found in `chunk_data`.
///
/// Each line is trimmed of surrounding whitespace before the prefix is
/// checked. The `[DONE]` sentinel is returned verbatim like any other payload.
/// Returns `None` when no line carries a non-empty `data: ` payload.
#[allow(dead_code)]
fn parse_sse_chunk(chunk_data: &str) -> Option<String> {
    chunk_data
        .lines()
        .filter_map(|raw| raw.trim().strip_prefix("data: "))
        .find(|payload| !payload.is_empty())
        .map(str::to_owned)
}
/// Queues a plain-text error `Response` for `request_id` on the WebSocket
/// writer channel.
///
/// The message carries `status_code`, `error_message` as its payload, no
/// headers, and is marked as a complete (non-streaming) response. A closed
/// channel is only logged; it is not treated as a failure.
///
/// # Errors
///
/// Returns an error only when the response cannot be serialized to JSON.
fn send_error_response(
    response_sender: &UnboundedSender<Message>,
    request_id: &str,
    status_code: u16,
    error_message: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let serialized = serde_json::to_string(&Response {
        request_id: request_id.to_owned(),
        status_code,
        content_type: String::from("text/plain"),
        payload: error_message.to_owned(),
        headers: HashMap::new(),
        is_stream_chunk: false,
        stream_done: true,
    })?;

    match response_sender.send(Message::text(serialized)) {
        Ok(()) => {}
        Err(e) => {
            warn!(
                "Failed to send error response, WebSocket connection likely closed: {}",
                e
            );
        }
    }
    Ok(())
}
/// Loads the proxy configuration, registers the provider, and then keeps the
/// WebSocket proxy running, reconnecting one second after every disconnect.
///
/// * `config_path` - path of the proxy configuration file.
/// * `id` - optional path handed to `create_keypair_from_file` to load the
///   signing keypair.
///
/// Once registration has succeeded the function loops forever and does not
/// return.
///
/// # Errors
///
/// Fails if the configuration or the keypair cannot be loaded, or if provider
/// registration fails.
pub async fn serve_websocket_with_config(config_path: PathBuf, id: Option<PathBuf>) -> Result<()> {
    info!(
        "Loading proxy configuration from: {}",
        config_path.display()
    );
    let config = ProxyConfig::from_file(&config_path)?;
    let keypair = create_keypair_from_file(id)?;

    info!("Registering provider before starting websocket proxy");
    register_provider(&config, &keypair).await?;
    info!("Provider registered successfully, starting websocket proxy with retry logic");

    let reconnect_delay = tokio::time::Duration::from_secs(1);
    let mut attempt = 0;
    loop {
        attempt += 1;
        // The first pass is the initial connection, not a reconnection.
        if attempt > 1 {
            info!("Reconnection attempt #{}", attempt - 1);
        }

        let subscribe_url = format!("{}/api/v1/providers/subscribe", config.node_url());
        let outcome = serve_websocket(
            subscribe_url,
            config.secret_key().to_string(),
            config.endpoint_url().to_string(),
            Some(config.endpoint_api_key().to_string()),
        )
        .await;

        if let Err(e) = outcome {
            error!(
                "WebSocket connection failed: {}, reconnecting in 1 second...",
                e
            );
        } else {
            info!("WebSocket connection closed normally, reconnecting in 1 second...");
        }

        tokio::time::sleep(reconnect_delay).await;
        info!("Attempting to reconnect to WebSocket...");
    }
}
/// Registers this provider with the router before the proxy starts serving.
///
/// The configured secret key is decoded and its signature verified. The
/// provider metadata is then serialized to canonical JSON, signed with
/// `keypair`, and POSTed to `{router.url}/api/v1/providers/register` with the
/// secret key as a bearer token.
///
/// * `config` - proxy configuration supplying the secret key, the provider
///   metadata and the router URL.
/// * `keypair` - keypair used to sign the canonical metadata.
///
/// Whether `keypair` matches the signer recorded in the secret key is only
/// logged here, not enforced.
///
/// # Errors
///
/// Fails if the secret key cannot be decoded or has an invalid signature, if
/// the signer public key cannot be parsed, if the metadata cannot be
/// serialized, if the request cannot be sent, or if the router answers with a
/// non-success status or an unparsable body.
async fn register_provider(config: &ProxyConfig, keypair: &Keypair) -> Result<()> {
    debug!("=== PROVIDER REGISTRATION DEBUG ===");
    debug!("Verifying secret key validity...");
    let (_scope, secret_key_v1) =
        SecretKeyV1::decode(config.secret_key()).context("Failed to decode secret key")?;
    match secret_key_v1.verify_signature() {
        Ok(_) => debug!("Secret key signature is VALID"),
        Err(e) => {
            error!("Secret key signature is INVALID: {}", e);
            return Err(anyhow!("Secret key has invalid signature: {}", e));
        }
    }

    let signer_pubkey =
        Pubkey::from_str(&secret_key_v1.signer).context("Failed to parse signer pubkey")?;
    debug!("Secret key signer pubkey: {}", signer_pubkey);
    debug!("Keypair pubkey: {}", keypair.pubkey());
    debug!(
        "Keypair matches secret key: {}",
        keypair.pubkey() == signer_pubkey
    );

    let metadata_json =
        serde_json::to_value(&config.metadata).context("Failed to serialize provider metadata")?;
    debug!(
        "Metadata JSON: {}",
        serde_json::to_string_pretty(&metadata_json)?
    );
    // The signature is computed over the canonical form, so the exact bytes
    // signed do not depend on key order or whitespace.
    let canonical_metadata =
        canonical_json::to_string(&metadata_json).context("Failed to create canonical JSON")?;
    debug!("Canonical metadata string: {}", canonical_metadata);
    debug!(
        "Canonical metadata bytes: {:?}",
        canonical_metadata.as_bytes()
    );
    let signature = keypair.sign_message(canonical_metadata.as_bytes());
    debug!("Generated signature: {}", signature);

    let request = RegisterProviderRequest {
        metadata: config.metadata.clone(),
        signature: signature.to_string(),
    };
    debug!(
        "Registration request: {}",
        serde_json::to_string_pretty(&request)?
    );

    let base_url = config.router.url.clone();
    let register_url = format!("{}/api/v1/providers/register", base_url);
    debug!("Registration URL: {}", register_url);
    // Only a short prefix of the secret is logged. It is taken by characters
    // rather than bytes so a key containing multi-byte UTF-8 cannot cause a
    // slicing panic.
    let secret_prefix: String = config.secret_key().chars().take(10).collect();
    debug!("Authorization header: Bearer {}...", secret_prefix);

    let client = reqwest::Client::new();
    let response = client
        .post(&register_url)
        .header("Authorization", format!("Bearer {}", config.secret_key()))
        .header("Content-Type", "application/json")
        .json(&request)
        .send()
        .await
        .context("Failed to send registration request")?;

    let status = response.status();
    debug!("Received registration response with status: {}", status);
    debug!("Response headers:");
    for (name, value) in response.headers() {
        debug!(" {}: {:?}", name, value);
    }

    if status.is_success() {
        let result: RegisterProviderResponse = response
            .json()
            .await
            .context("Failed to parse registration response")?;
        info!("Provider registration: {}", result.message);
        Ok(())
    } else {
        let error_text = response
            .text()
            .await
            .unwrap_or_else(|_| "Failed to read error response".to_string());
        error!("=== REGISTRATION FAILED ===");
        error!("Status: {}", status);
        error!("Error response body: {}", error_text);
        Err(anyhow!(
            "Provider registration failed with status {}: {}",
            status,
            error_text
        ))
    }
}