use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
routing::{get, post},
Json, Router,
};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use crate::errors::{IicpError, Result};
/// How long, in seconds, a request nonce is remembered for replay detection
/// (five minutes).
const NONCE_TTL_SECS: u64 = 5 * 60;

/// Seconds to wait between background heartbeats sent while serving.
const HEARTBEAT_INTERVAL_SECS: u64 = 30;

/// Base URL of the public IICP directory; `NodeConfig::new` uses it as the
/// initial `directory_url`.
const DEFAULT_DIRECTORY: &str = "https://iicp.network/api";
/// Static configuration of an IICP node.
///
/// Build one with `NodeConfig::new` and adjust the public fields as needed.
#[derive(Debug, Clone)]
pub struct NodeConfig {
    /// Identifier of this node; sent to the directory on register/heartbeat
    /// and reported by `/iicp/health`.
    pub node_id: String,
    /// Endpoint advertised to the directory during registration.
    pub endpoint: String,
    /// Intent served by this node; sent on registration and reported by
    /// `/iicp/health`.
    pub intent: String,
    /// Optional model name; included in registration only when set. Health
    /// reports an empty string when `None`.
    pub model: Option<String>,
    /// Optional region; included in registration only when set. Health
    /// reports `"unknown"` when `None`.
    pub region: Option<String>,
    /// Capabilities sent on registration (omitted when empty).
    pub capabilities: Vec<String>,
    /// Base URL of the directory; trailing `/` characters are trimmed before
    /// endpoint paths are appended.
    pub directory_url: String,
    /// Base timeout in milliseconds; the outbound HTTP client timeout is this
    /// value plus 2 000 ms.
    pub timeout_ms: u64,
    /// Maximum number of tasks processed at once; further task requests are
    /// rejected with `429 Too Many Requests`.
    pub max_concurrent: usize,
}
impl NodeConfig {
    /// Creates a configuration for node `node_id`, reachable at `endpoint`,
    /// serving `intent`.
    ///
    /// Defaults: no model, region or capabilities; `directory_url` set to
    /// `DEFAULT_DIRECTORY`; `timeout_ms = 5_000`; `max_concurrent = 4`.
    /// Override any of them by assigning the public fields afterwards.
    pub fn new(
        node_id: impl Into<String>,
        endpoint: impl Into<String>,
        intent: impl Into<String>,
    ) -> Self {
        // Convert in declaration order, then assemble with field shorthand.
        let node_id: String = node_id.into();
        let endpoint: String = endpoint.into();
        let intent: String = intent.into();
        Self {
            node_id,
            endpoint,
            intent,
            model: None,
            region: None,
            capabilities: Vec::new(),
            directory_url: String::from(DEFAULT_DIRECTORY),
            timeout_ms: 5_000,
            max_concurrent: 4,
        }
    }
}
/// JSON body of a `POST /v1/task` request.
#[derive(Debug, Deserialize)]
pub struct TaskRequest {
    /// Caller-chosen task identifier; echoed back in the `TaskResponse`.
    pub task_id: String,
    /// Intent requested by the caller.
    /// NOTE(review): `task_endpoint` does not compare this with the node's
    /// configured intent — confirm whether the handler is expected to check it.
    pub intent: String,
    /// Task input, handed to the handler unchanged as part of this request.
    pub payload: Value,
    /// Optional constraints; not inspected by the server code in this file.
    pub constraints: Option<Value>,
    /// Optional auth data; not verified by `task_endpoint` — presumably the
    /// handler's responsibility (TODO confirm).
    pub auth: Option<Value>,
    /// Optional replay-protection nonce. A nonce already seen within the last
    /// `NONCE_TTL_SECS` seconds is rejected with `409 Conflict`.
    pub nonce: Option<String>,
    /// Trace context filled from the incoming `traceparent` header; never
    /// deserialized from the request body.
    #[serde(skip_deserializing)]
    pub _trace: Option<Value>,
}
/// JSON body returned by `POST /v1/task` once the handler has run.
#[derive(Debug, Serialize)]
pub struct TaskResponse {
    /// Echo of the request's `task_id`.
    pub task_id: String,
    /// `"completed"` on success, `"error"` when the handler returned an error
    /// (values set by `task_endpoint`).
    pub status: String,
    /// Handler output; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    /// Error details (`task_endpoint` uses `{"message": ...}`); omitted from
    /// the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<Value>,
}
/// Type-erased, shareable task handler.
///
/// It takes ownership of a `TaskRequest` and returns a boxed `Send` future that
/// resolves to the task's JSON result. `IicpNode::serve` wraps the caller's
/// closure into this form so it can be stored in the shared handler state.
pub type TaskHandlerFn = Arc<
    dyn Fn(
            TaskRequest,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Value>> + Send>>
        + Send
        + Sync,
>;
/// State shared (via `Arc`) by every HTTP handler of a running node.
///
/// Only the job counter (atomic) and the nonce cache (behind an async mutex)
/// change after construction.
struct AppState {
    // Identity reported by `/iicp/health`.
    node_id: String,
    region: String,
    model: String,
    intent: String,

    // Concurrency accounting: number of in-flight tasks and the limit.
    active_jobs: Arc<AtomicUsize>,
    max_concurrent: usize,

    // Task processing: the user handler and nonce -> first-seen time for
    // replay detection.
    handler: TaskHandlerFn,
    nonce_cache: Arc<Mutex<HashMap<String, Instant>>>,
}
async fn health_endpoint(State(state): State<Arc<AppState>>) -> impl IntoResponse {
let active = state.active_jobs.load(Ordering::Relaxed);
Json(json!({
"status": "ok",
"node_id": state.node_id,
"region": state.region,
"load": (active as f64 / state.max_concurrent.max(1) as f64),
"active_jobs": active,
"max_concurrent": state.max_concurrent,
"available": active < state.max_concurrent,
"model": state.model,
"intent": state.intent,
}))
}
/// `GET /metrics`: Prometheus metrics in the text exposition format.
///
/// Responds `200` with the encoded default registry when the crate is built
/// with the `metrics` feature, `500` if encoding fails, and `503` when the
/// feature is compiled out.
async fn metrics_endpoint() -> Response {
    render_metrics()
}

/// Encodes the default Prometheus registry (`metrics` feature enabled).
#[cfg(feature = "metrics")]
fn render_metrics() -> Response {
    use prometheus::{Encoder, TextEncoder};
    let encoder = TextEncoder::new();
    let families = prometheus::gather();
    let mut buf = Vec::new();
    match encoder.encode(&families, &mut buf) {
        Ok(()) => (
            StatusCode::OK,
            [(
                axum::http::header::CONTENT_TYPE,
                "text/plain; version=0.0.4",
            )],
            buf,
        )
            .into_response(),
        // An encoding failure is a server-side error; do not report it as the
        // feature being disabled.
        Err(e) => {
            tracing::error!("failed to encode metrics: {e}");
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                "failed to encode metrics",
            )
                .into_response()
        }
    }
}

/// Fallback reply when the `metrics` feature is compiled out.
#[cfg(not(feature = "metrics"))]
fn render_metrics() -> Response {
    (
        StatusCode::SERVICE_UNAVAILABLE,
        "metrics feature not enabled",
    )
        .into_response()
}
async fn task_endpoint(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(mut req): Json<TaskRequest>,
) -> Response {
let prev = state.active_jobs.fetch_add(1, Ordering::Relaxed);
if prev >= state.max_concurrent {
state.active_jobs.fetch_sub(1, Ordering::Relaxed);
return (
StatusCode::TOO_MANY_REQUESTS,
[("Retry-After", "2"), ("Content-Type", "application/json")],
Json(json!({
"error": {
"code": "IICP-E021",
"message": "capacity_exceeded",
"retry_after_ms": 2000,
}
})),
)
.into_response();
}
if let Some(ref nonce) = req.nonce {
let mut cache = state.nonce_cache.lock().await;
cache.retain(|_, inserted_at| inserted_at.elapsed().as_secs() < NONCE_TTL_SECS);
if cache.contains_key(nonce) {
state.active_jobs.fetch_sub(1, Ordering::Relaxed);
return (
StatusCode::CONFLICT,
Json(json!({
"error": { "code": "IICP-E011", "message": "replay_detected" }
})),
)
.into_response();
}
cache.insert(nonce.clone(), Instant::now());
}
if let Some(tp) = headers.get("traceparent").and_then(|v| v.to_str().ok()) {
req._trace = Some(json!({ "traceparent": tp }));
}
let task_id = req.task_id.clone();
let result = (state.handler)(req).await;
state.active_jobs.fetch_sub(1, Ordering::Relaxed);
match result {
Ok(value) => Json(TaskResponse {
task_id,
status: "completed".into(),
result: Some(value),
error: None,
})
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(TaskResponse {
task_id,
status: "error".into(),
result: None,
error: Some(json!({ "message": e.to_string() })),
}),
)
.into_response(),
}
}
/// An IICP node: registers with a directory, sends heartbeats and serves the
/// node's HTTP API.
pub struct IicpNode {
    cfg: NodeConfig,
    http: Client,
}

impl IicpNode {
    /// Creates a node from `cfg`.
    ///
    /// The outbound HTTP client uses rustls and a request timeout of
    /// `cfg.timeout_ms` plus a 2 000 ms margin.
    ///
    /// # Panics
    ///
    /// Panics if the `reqwest` client cannot be built.
    pub fn new(cfg: NodeConfig) -> Self {
        let http = Client::builder()
            // Saturate so an extreme `timeout_ms` cannot overflow.
            .timeout(Duration::from_millis(cfg.timeout_ms.saturating_add(2_000)))
            .use_rustls_tls()
            .build()
            .expect("failed to build HTTP client");
        Self { cfg, http }
    }

    /// Registers this node with the directory (`POST {directory_url}/v1/register`)
    /// and returns the node token issued by it.
    ///
    /// Sends `node_id`, `endpoint` and `intent`, plus `model` and `region` when
    /// set and `capabilities` when non-empty.
    ///
    /// # Errors
    ///
    /// Returns [`IicpError::Node`] on transport failure, a non-2xx status, an
    /// unparseable body, or a body containing neither `node_token` nor `token`.
    pub async fn register(&self) -> Result<String> {
        let mut payload = json!({
            "node_id": self.cfg.node_id,
            "endpoint": self.cfg.endpoint,
            "intent": self.cfg.intent,
        });
        if let Some(model) = &self.cfg.model {
            payload["model"] = json!(model);
        }
        if let Some(region) = &self.cfg.region {
            payload["region"] = json!(region);
        }
        if !self.cfg.capabilities.is_empty() {
            payload["capabilities"] = json!(self.cfg.capabilities);
        }
        let resp = self
            .http
            .post(Self::directory_endpoint(&self.cfg.directory_url, "/v1/register"))
            .json(&payload)
            .send()
            .await
            .map_err(|e| IicpError::Node(e.to_string()))?;
        if !resp.status().is_success() {
            return Err(IicpError::Node(format!(
                "register failed: {}",
                resp.status()
            )));
        }
        let data: Value = resp
            .json()
            .await
            .map_err(|e| IicpError::Node(e.to_string()))?;
        let token = data["node_token"]
            .as_str()
            .or_else(|| data["token"].as_str())
            .ok_or_else(|| IicpError::Node(format!("no node_token in response: {data}")))?;
        Ok(token.to_string())
    }

    /// Sends a single heartbeat for this node, authenticated by `node_token`
    /// (as returned by [`register`](Self::register)).
    ///
    /// # Errors
    ///
    /// Returns [`IicpError::Node`] on transport failure or a non-2xx status.
    pub async fn heartbeat(&self, node_token: &str) -> Result<()> {
        Self::send_heartbeat(
            &self.http,
            &self.cfg.directory_url,
            &self.cfg.node_id,
            node_token,
        )
        .await
    }

    /// Serves the node API on `addr` until the server stops.
    ///
    /// Routes: `POST /v1/task` (dispatches to `handler`), `GET /iicp/health`,
    /// and `GET /metrics`. When `node_token` is given, a background task sends a
    /// heartbeat every `HEARTBEAT_INTERVAL_SECS` seconds (the first one after a
    /// full interval). That task is aborted when this future completes or is dropped.
    ///
    /// # Errors
    ///
    /// Returns [`IicpError::Node`] if `addr` is not a valid socket address, the
    /// listener cannot be bound, or the server itself fails.
    pub async fn serve<F, Fut>(
        &self,
        handler: F,
        addr: &str,
        node_token: Option<String>,
    ) -> Result<()>
    where
        F: Fn(TaskRequest) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<Value>> + Send + 'static,
    {
        /// Aborts the wrapped task when dropped, so the heartbeat loop cannot
        /// outlive the server.
        struct AbortOnDrop(tokio::task::JoinHandle<()>);

        impl Drop for AbortOnDrop {
            fn drop(&mut self) {
                self.0.abort();
            }
        }

        let handler: TaskHandlerFn = Arc::new(move |req| Box::pin(handler(req)));
        let state = Arc::new(AppState {
            handler,
            node_id: self.cfg.node_id.clone(),
            region: self.cfg.region.clone().unwrap_or_else(|| "unknown".into()),
            intent: self.cfg.intent.clone(),
            model: self.cfg.model.clone().unwrap_or_default(),
            active_jobs: Arc::new(AtomicUsize::new(0)),
            max_concurrent: self.cfg.max_concurrent,
            nonce_cache: Arc::new(Mutex::new(HashMap::new())),
        });
        let app = Router::new()
            .route("/v1/task", post(task_endpoint))
            .route("/iicp/health", get(health_endpoint))
            .route("/metrics", get(metrics_endpoint))
            .with_state(state);
        let addr: SocketAddr = addr
            .parse()
            .map_err(|e| IicpError::Node(format!("invalid addr: {e}")))?;
        let listener = TcpListener::bind(addr)
            .await
            .map_err(|e| IicpError::Node(e.to_string()))?;
        tracing::info!("IICP node {} listening on {}", self.cfg.node_id, addr);

        let _heartbeat = node_token.map(|token| {
            let node_id = self.cfg.node_id.clone();
            let dir = self.cfg.directory_url.clone();
            let http = self.http.clone();
            AbortOnDrop(tokio::spawn(async move {
                loop {
                    tokio::time::sleep(Duration::from_secs(HEARTBEAT_INTERVAL_SECS)).await;
                    // Best effort: log (including non-2xx replies) and retry next tick.
                    if let Err(e) = Self::send_heartbeat(&http, &dir, &node_id, &token).await {
                        tracing::warn!("heartbeat failed: {e}");
                    }
                }
            }))
        });

        axum::serve(listener, app)
            .await
            .map_err(|e| IicpError::Node(e.to_string()))
    }

    /// Joins the directory base URL with an absolute `path`, ignoring trailing
    /// slashes on the base.
    fn directory_endpoint(base: &str, path: &str) -> String {
        format!("{}{path}", base.trim_end_matches('/'))
    }

    /// One heartbeat POST; shared by [`heartbeat`](Self::heartbeat) and the
    /// background loop in [`serve`](Self::serve), which cannot borrow `self`.
    ///
    /// Uses the same `/v1/...` layout as registration. `DEFAULT_DIRECTORY`
    /// already ends in `/api`, so an extra `/api` prefix would target
    /// `/api/api/v1/heartbeat`.
    async fn send_heartbeat(
        http: &Client,
        directory_url: &str,
        node_id: &str,
        node_token: &str,
    ) -> Result<()> {
        let resp = http
            .post(Self::directory_endpoint(directory_url, "/v1/heartbeat"))
            .json(&json!({
                "node_id": node_id,
                "node_token": node_token,
                "status": "available",
            }))
            .send()
            .await
            .map_err(|e| IicpError::Node(e.to_string()))?;
        if !resp.status().is_success() {
            return Err(IicpError::Node(format!(
                "heartbeat failed: {}",
                resp.status()
            )));
        }
        Ok(())
    }
}