Skip to main content

inference_runtime/
gateway.rs

//! `ApiGatewayActor` — HTTP gateway. Doc §4, §6.
//!
//! Exposes an OpenAI-compatible `/v1/chat/completions` endpoint plus a
//! `/healthz` endpoint. In the target design each incoming request becomes
//! one `RequestActor`, and the response body is streamed back via the
//! per-request `mpsc` channel. (In v0 the handler only resolves a route
//! through the DP coordinator and returns a placeholder, non-streamed
//! completion.)
//!
//! The gateway actor itself is small — it owns the listener task and
//! the wiring to the deployment manager; the per-request handler runs
//! in axum's task pool.
10
11use std::net::SocketAddr;
12use std::sync::Arc;
13
14use async_trait::async_trait;
15use axum::extract::State;
16use axum::http::StatusCode;
17use axum::response::{IntoResponse, Response};
18use axum::routing::{get, post};
19use axum::{Json, Router};
20use rakka_core::actor::{Actor, ActorRef, Context};
21use serde::{Deserialize, Serialize};
22use tokio::net::TcpListener;
23use tokio::sync::oneshot;
24
25use inference_core::batch::{ExecuteBatch, Message, MessageContent, Role, SamplingParams};
26
27use crate::dp_coordinator::DpCoordinatorMsg;
28
/// Settings for the HTTP gateway listener.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GatewayConfig {
    /// Socket address the gateway's `TcpListener` binds to.
    pub bind: SocketAddr,
}

impl Default for GatewayConfig {
    /// Binds to loopback only (`127.0.0.1:8080`), so the gateway is not
    /// reachable from other hosts unless explicitly configured.
    fn default() -> Self {
        Self {
            bind: SocketAddr::from(([127, 0, 0, 1], 8080)),
        }
    }
}
41
/// Messages accepted by [`ApiGatewayActor`].
#[derive(Debug)]
pub enum ApiGatewayMsg {
    /// Signal the listener task to shut down, then stop the actor.
    Stop,
}
45
46#[derive(Clone)]
47struct AppState {
48    coordinator: ActorRef<DpCoordinatorMsg>,
49}
50
51#[derive(Debug, Deserialize)]
52struct ChatRequest {
53    model: String,
54    messages: Vec<ChatMessage>,
55    #[serde(default)]
56    stream: bool,
57    #[serde(default)]
58    temperature: Option<f32>,
59    #[serde(default)]
60    max_tokens: Option<u32>,
61}
62
63#[derive(Debug, Deserialize)]
64struct ChatMessage {
65    role: String,
66    content: String,
67}
68
69#[derive(Debug, Serialize)]
70struct ChatErrorResponse {
71    error: ChatError,
72}
73
74#[derive(Debug, Serialize)]
75struct ChatError {
76    message: String,
77    #[serde(rename = "type")]
78    kind: String,
79}
80
81pub struct ApiGatewayActor {
82    config: GatewayConfig,
83    coordinator: ActorRef<DpCoordinatorMsg>,
84    /// Shutdown channel handed to the listener task in `pre_start`.
85    shutdown_tx: Option<oneshot::Sender<()>>,
86}
87
88impl ApiGatewayActor {
89    pub fn new(config: GatewayConfig, coordinator: ActorRef<DpCoordinatorMsg>) -> Self {
90        Self {
91            config,
92            coordinator,
93            shutdown_tx: None,
94        }
95    }
96}
97
98#[async_trait]
99impl Actor for ApiGatewayActor {
100    type Msg = ApiGatewayMsg;
101
102    async fn pre_start(&mut self, _ctx: &mut Context<Self>) {
103        let bind = self.config.bind;
104        let state = AppState {
105            coordinator: self.coordinator.clone(),
106        };
107        let app = Router::new()
108            .route("/healthz", get(|| async { "ok" }))
109            .route("/v1/chat/completions", post(chat_completions))
110            .with_state(state);
111        let listener = match TcpListener::bind(bind).await {
112            Ok(l) => l,
113            Err(e) => {
114                tracing::error!(?e, "gateway bind failed");
115                return;
116            }
117        };
118        let (tx, rx) = oneshot::channel();
119        self.shutdown_tx = Some(tx);
120        tokio::spawn(async move {
121            tracing::info!(%bind, "gateway listening");
122            let server = axum::serve(listener, app);
123            let _ = tokio::select! {
124                r = server => r,
125                _ = rx => Ok(()),
126            };
127        });
128    }
129
130    async fn handle(&mut self, ctx: &mut Context<Self>, msg: Self::Msg) {
131        match msg {
132            ApiGatewayMsg::Stop => {
133                if let Some(tx) = self.shutdown_tx.take() {
134                    let _ = tx.send(());
135                }
136                ctx.stop_self();
137            }
138        }
139    }
140
141    async fn post_stop(&mut self, _ctx: &mut Context<Self>) {
142        if let Some(tx) = self.shutdown_tx.take() {
143            let _ = tx.send(());
144        }
145    }
146}
147
148/// Convenience to start the gateway as a top-level actor.
149pub fn spawn_gateway(
150    sys: &rakka_core::actor::ActorSystem,
151    config: GatewayConfig,
152    coordinator: ActorRef<DpCoordinatorMsg>,
153) -> Result<ActorRef<ApiGatewayMsg>, rakka_core::actor::ActorSystemError> {
154    use rakka_core::actor::Props;
155    let coord = Arc::new(coordinator);
156    let cfg = Arc::new(config);
157    let props = Props::create(move || ApiGatewayActor::new((*cfg).clone(), (*coord).clone()));
158    sys.actor_of(props, "gateway")
159}
160
161async fn chat_completions(State(state): State<AppState>, Json(req): Json<ChatRequest>) -> Response {
162    let messages = req
163        .messages
164        .into_iter()
165        .map(|m| Message {
166            role: parse_role(&m.role),
167            content: MessageContent::Text(m.content),
168        })
169        .collect();
170    let batch = ExecuteBatch {
171        request_id: format!("req-{}", chrono::Utc::now().timestamp_nanos_opt().unwrap_or(0)),
172        model: req.model.clone(),
173        messages,
174        sampling: SamplingParams {
175            temperature: req.temperature,
176            max_tokens: req.max_tokens,
177            ..Default::default()
178        },
179        stream: req.stream,
180        estimated_tokens: 256,
181    };
182
183    // Look up the route via the coordinator.
184    let route = state
185        .coordinator
186        .ask_with(
187            |reply| DpCoordinatorMsg::RouteTo {
188                deployment: req.model.clone(),
189                reply,
190            },
191            std::time::Duration::from_secs(2),
192        )
193        .await;
194    match route {
195        Ok(Ok(_target)) => {
196            // v0: route is resolved but the gateway → engine plumbing
197            // for the actual request lives in the per-runtime crates'
198            // sample servers. Return a JSON envelope acknowledging
199            // route resolution so smoke-tests pass; full SSE bridging
200            // is added in the demo example (`examples/remote_only_demo`).
201            let body = serde_json::json!({
202                "id": batch.request_id,
203                "model": batch.model,
204                "object": "chat.completion",
205                "choices": [{
206                    "index": 0,
207                    "message": {"role": "assistant", "content": ""},
208                    "finish_reason": "stop"
209                }],
210            });
211            (StatusCode::OK, Json(body)).into_response()
212        }
213        Ok(Err(e)) => bad_request(e.to_string(), "no_route"),
214        Err(_) => bad_request("coordinator timeout".into(), "internal_error"),
215    }
216}
217
218fn bad_request(msg: String, kind: &str) -> Response {
219    (
220        StatusCode::BAD_REQUEST,
221        Json(ChatErrorResponse {
222            error: ChatError {
223                message: msg,
224                kind: kind.into(),
225            },
226        }),
227    )
228        .into_response()
229}
230
231fn parse_role(s: &str) -> Role {
232    match s {
233        "system" => Role::System,
234        "assistant" => Role::Assistant,
235        "tool" => Role::Tool,
236        _ => Role::User,
237    }
238}