1use std::net::SocketAddr;
12use std::sync::Arc;
13
14use async_trait::async_trait;
15use axum::extract::State;
16use axum::http::StatusCode;
17use axum::response::{IntoResponse, Response};
18use axum::routing::{get, post};
19use axum::{Json, Router};
20use rakka_core::actor::{Actor, ActorRef, Context};
21use serde::{Deserialize, Serialize};
22use tokio::net::TcpListener;
23use tokio::sync::oneshot;
24
25use inference_core::batch::{ExecuteBatch, Message, MessageContent, Role, SamplingParams};
26
27use crate::dp_coordinator::DpCoordinatorMsg;
28
/// Network configuration for the HTTP API gateway.
///
/// Derives `Debug` so the configuration can be logged and embedded in other
/// `Debug` types.
#[derive(Debug, Clone)]
pub struct GatewayConfig {
    /// Socket address the gateway's TCP listener binds to.
    pub bind: SocketAddr,
}
33
34impl Default for GatewayConfig {
35 fn default() -> Self {
36 Self {
37 bind: SocketAddr::from(([127, 0, 0, 1], 8080)),
38 }
39 }
40}
41
/// Messages accepted by [`ApiGatewayActor`].
///
/// Derives `Debug` so messages can be logged or appear in diagnostics.
#[derive(Debug)]
pub enum ApiGatewayMsg {
    /// Signal the HTTP server task to shut down, then stop the actor.
    Stop,
}
45
46#[derive(Clone)]
47struct AppState {
48 coordinator: ActorRef<DpCoordinatorMsg>,
49}
50
51#[derive(Debug, Deserialize)]
52struct ChatRequest {
53 model: String,
54 messages: Vec<ChatMessage>,
55 #[serde(default)]
56 stream: bool,
57 #[serde(default)]
58 temperature: Option<f32>,
59 #[serde(default)]
60 max_tokens: Option<u32>,
61}
62
63#[derive(Debug, Deserialize)]
64struct ChatMessage {
65 role: String,
66 content: String,
67}
68
69#[derive(Debug, Serialize)]
70struct ChatErrorResponse {
71 error: ChatError,
72}
73
74#[derive(Debug, Serialize)]
75struct ChatError {
76 message: String,
77 #[serde(rename = "type")]
78 kind: String,
79}
80
81pub struct ApiGatewayActor {
82 config: GatewayConfig,
83 coordinator: ActorRef<DpCoordinatorMsg>,
84 shutdown_tx: Option<oneshot::Sender<()>>,
86}
87
88impl ApiGatewayActor {
89 pub fn new(config: GatewayConfig, coordinator: ActorRef<DpCoordinatorMsg>) -> Self {
90 Self {
91 config,
92 coordinator,
93 shutdown_tx: None,
94 }
95 }
96}
97
98#[async_trait]
99impl Actor for ApiGatewayActor {
100 type Msg = ApiGatewayMsg;
101
102 async fn pre_start(&mut self, _ctx: &mut Context<Self>) {
103 let bind = self.config.bind;
104 let state = AppState {
105 coordinator: self.coordinator.clone(),
106 };
107 let app = Router::new()
108 .route("/healthz", get(|| async { "ok" }))
109 .route("/v1/chat/completions", post(chat_completions))
110 .with_state(state);
111 let listener = match TcpListener::bind(bind).await {
112 Ok(l) => l,
113 Err(e) => {
114 tracing::error!(?e, "gateway bind failed");
115 return;
116 }
117 };
118 let (tx, rx) = oneshot::channel();
119 self.shutdown_tx = Some(tx);
120 tokio::spawn(async move {
121 tracing::info!(%bind, "gateway listening");
122 let server = axum::serve(listener, app);
123 let _ = tokio::select! {
124 r = server => r,
125 _ = rx => Ok(()),
126 };
127 });
128 }
129
130 async fn handle(&mut self, ctx: &mut Context<Self>, msg: Self::Msg) {
131 match msg {
132 ApiGatewayMsg::Stop => {
133 if let Some(tx) = self.shutdown_tx.take() {
134 let _ = tx.send(());
135 }
136 ctx.stop_self();
137 }
138 }
139 }
140
141 async fn post_stop(&mut self, _ctx: &mut Context<Self>) {
142 if let Some(tx) = self.shutdown_tx.take() {
143 let _ = tx.send(());
144 }
145 }
146}
147
148pub fn spawn_gateway(
150 sys: &rakka_core::actor::ActorSystem,
151 config: GatewayConfig,
152 coordinator: ActorRef<DpCoordinatorMsg>,
153) -> Result<ActorRef<ApiGatewayMsg>, rakka_core::actor::ActorSystemError> {
154 use rakka_core::actor::Props;
155 let coord = Arc::new(coordinator);
156 let cfg = Arc::new(config);
157 let props = Props::create(move || ApiGatewayActor::new((*cfg).clone(), (*coord).clone()));
158 sys.actor_of(props, "gateway")
159}
160
161async fn chat_completions(State(state): State<AppState>, Json(req): Json<ChatRequest>) -> Response {
162 let messages = req
163 .messages
164 .into_iter()
165 .map(|m| Message {
166 role: parse_role(&m.role),
167 content: MessageContent::Text(m.content),
168 })
169 .collect();
170 let batch = ExecuteBatch {
171 request_id: format!("req-{}", chrono::Utc::now().timestamp_nanos_opt().unwrap_or(0)),
172 model: req.model.clone(),
173 messages,
174 sampling: SamplingParams {
175 temperature: req.temperature,
176 max_tokens: req.max_tokens,
177 ..Default::default()
178 },
179 stream: req.stream,
180 estimated_tokens: 256,
181 };
182
183 let route = state
185 .coordinator
186 .ask_with(
187 |reply| DpCoordinatorMsg::RouteTo {
188 deployment: req.model.clone(),
189 reply,
190 },
191 std::time::Duration::from_secs(2),
192 )
193 .await;
194 match route {
195 Ok(Ok(_target)) => {
196 let body = serde_json::json!({
202 "id": batch.request_id,
203 "model": batch.model,
204 "object": "chat.completion",
205 "choices": [{
206 "index": 0,
207 "message": {"role": "assistant", "content": ""},
208 "finish_reason": "stop"
209 }],
210 });
211 (StatusCode::OK, Json(body)).into_response()
212 }
213 Ok(Err(e)) => bad_request(e.to_string(), "no_route"),
214 Err(_) => bad_request("coordinator timeout".into(), "internal_error"),
215 }
216}
217
218fn bad_request(msg: String, kind: &str) -> Response {
219 (
220 StatusCode::BAD_REQUEST,
221 Json(ChatErrorResponse {
222 error: ChatError {
223 message: msg,
224 kind: kind.into(),
225 },
226 }),
227 )
228 .into_response()
229}
230
231fn parse_role(s: &str) -> Role {
232 match s {
233 "system" => Role::System,
234 "assistant" => Role::Assistant,
235 "tool" => Role::Tool,
236 _ => Role::User,
237 }
238}