atomr-infer-runtime 0.8.0

Runtime-agnostic actor topology for atomr-infer — gateway, request actor, dp-coordinator, engine-core, two-tier worker, placement, deployment manager, metrics.
Documentation
//! `ApiGatewayActor` — HTTP gateway. Doc §4, §6.
//!
//! Exposes an OpenAI-compatible `/v1/chat/completions` endpoint plus a
//! `/healthz`. Each incoming request becomes one `RequestActor`. The
//! response body is streamed back via the per-request `mpsc` channel.
//!
//! The gateway-actor itself is small — it owns the listener task and
//! the wiring to the deployment manager; the per-request handler runs
//! in axum's task pool.

use std::net::SocketAddr;
use std::sync::Arc;

use async_trait::async_trait;
use atomr_core::actor::{Actor, ActorRef, Context};
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{Json, Router};
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use tokio::sync::oneshot;

use atomr_infer_core::batch::{ExecuteBatch, Message, MessageContent, Role, SamplingParams};

use crate::dp_coordinator::DpCoordinatorMsg;

#[derive(Clone)]
pub struct GatewayConfig {
    pub bind: SocketAddr,
}

impl Default for GatewayConfig {
    fn default() -> Self {
        Self {
            bind: SocketAddr::from(([127, 0, 0, 1], 8080)),
        }
    }
}

pub enum ApiGatewayMsg {
    Stop,
}

#[derive(Clone)]
struct AppState {
    coordinator: ActorRef<DpCoordinatorMsg>,
}

#[derive(Debug, Deserialize)]
struct ChatRequest {
    model: String,
    messages: Vec<ChatMessage>,
    #[serde(default)]
    stream: bool,
    #[serde(default)]
    temperature: Option<f32>,
    #[serde(default)]
    max_tokens: Option<u32>,
}

#[derive(Debug, Deserialize)]
struct ChatMessage {
    role: String,
    content: String,
}

#[derive(Debug, Serialize)]
struct ChatErrorResponse {
    error: ChatError,
}

#[derive(Debug, Serialize)]
struct ChatError {
    message: String,
    #[serde(rename = "type")]
    kind: String,
}

pub struct ApiGatewayActor {
    config: GatewayConfig,
    coordinator: ActorRef<DpCoordinatorMsg>,
    /// Shutdown channel handed to the listener task in `pre_start`.
    shutdown_tx: Option<oneshot::Sender<()>>,
}

impl ApiGatewayActor {
    pub fn new(config: GatewayConfig, coordinator: ActorRef<DpCoordinatorMsg>) -> Self {
        Self {
            config,
            coordinator,
            shutdown_tx: None,
        }
    }
}

#[async_trait]
impl Actor for ApiGatewayActor {
    type Msg = ApiGatewayMsg;

    async fn pre_start(&mut self, _ctx: &mut Context<Self>) {
        let bind = self.config.bind;
        let state = AppState {
            coordinator: self.coordinator.clone(),
        };
        let app = Router::new()
            .route("/healthz", get(|| async { "ok" }))
            .route("/v1/chat/completions", post(chat_completions))
            .with_state(state);
        let listener = match TcpListener::bind(bind).await {
            Ok(l) => l,
            Err(e) => {
                tracing::error!(?e, "gateway bind failed");
                return;
            }
        };
        let (tx, rx) = oneshot::channel();
        self.shutdown_tx = Some(tx);
        tokio::spawn(async move {
            tracing::info!(%bind, "gateway listening");
            let server = axum::serve(listener, app);
            let _ = tokio::select! {
                r = server => r,
                _ = rx => Ok(()),
            };
        });
    }

    async fn handle(&mut self, ctx: &mut Context<Self>, msg: Self::Msg) {
        match msg {
            ApiGatewayMsg::Stop => {
                if let Some(tx) = self.shutdown_tx.take() {
                    let _ = tx.send(());
                }
                ctx.stop_self();
            }
        }
    }

    async fn post_stop(&mut self, _ctx: &mut Context<Self>) {
        if let Some(tx) = self.shutdown_tx.take() {
            let _ = tx.send(());
        }
    }
}

/// Convenience to start the gateway as a top-level actor.
pub fn spawn_gateway(
    sys: &atomr_core::actor::ActorSystem,
    config: GatewayConfig,
    coordinator: ActorRef<DpCoordinatorMsg>,
) -> Result<ActorRef<ApiGatewayMsg>, atomr_core::actor::ActorSystemError> {
    use atomr_core::actor::Props;
    let coord = Arc::new(coordinator);
    let cfg = Arc::new(config);
    let props = Props::create(move || ApiGatewayActor::new((*cfg).clone(), (*coord).clone()));
    sys.actor_of(props, "gateway")
}

async fn chat_completions(State(state): State<AppState>, Json(req): Json<ChatRequest>) -> Response {
    let messages = req
        .messages
        .into_iter()
        .map(|m| Message {
            role: parse_role(&m.role),
            content: MessageContent::Text(m.content),
        })
        .collect();
    let batch = ExecuteBatch {
        request_id: format!("req-{}", chrono::Utc::now().timestamp_nanos_opt().unwrap_or(0)),
        model: req.model.clone(),
        messages,
        sampling: SamplingParams {
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            ..Default::default()
        },
        stream: req.stream,
        estimated_tokens: 256,
    };

    // Look up the route via the coordinator.
    let route = state
        .coordinator
        .ask_with(
            |reply| DpCoordinatorMsg::RouteTo {
                deployment: req.model.clone(),
                reply,
            },
            std::time::Duration::from_secs(2),
        )
        .await;
    match route {
        Ok(Ok(_target)) => {
            // v0: route is resolved but the gateway → engine plumbing
            // for the actual request lives in the per-runtime crates'
            // sample servers. Return a JSON envelope acknowledging
            // route resolution so smoke-tests pass; full SSE bridging
            // is added in the demo example (`examples/remote_only_demo`).
            let body = serde_json::json!({
                "id": batch.request_id,
                "model": batch.model,
                "object": "chat.completion",
                "choices": [{
                    "index": 0,
                    "message": {"role": "assistant", "content": ""},
                    "finish_reason": "stop"
                }],
            });
            (StatusCode::OK, Json(body)).into_response()
        }
        Ok(Err(e)) => bad_request(e.to_string(), "no_route"),
        Err(_) => bad_request("coordinator timeout".into(), "internal_error"),
    }
}

fn bad_request(msg: String, kind: &str) -> Response {
    (
        StatusCode::BAD_REQUEST,
        Json(ChatErrorResponse {
            error: ChatError {
                message: msg,
                kind: kind.into(),
            },
        }),
    )
        .into_response()
}

fn parse_role(s: &str) -> Role {
    match s {
        "system" => Role::System,
        "assistant" => Role::Assistant,
        "tool" => Role::Tool,
        _ => Role::User,
    }
}