use std::net::SocketAddr;
use std::sync::Arc;
use async_trait::async_trait;
use atomr_core::actor::{Actor, ActorRef, Context};
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{Json, Router};
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use atomr_infer_core::batch::{ExecuteBatch, Message, MessageContent, Role, SamplingParams};
use crate::dp_coordinator::DpCoordinatorMsg;
#[derive(Clone)]
pub struct GatewayConfig {
pub bind: SocketAddr,
}
impl Default for GatewayConfig {
fn default() -> Self {
Self {
bind: SocketAddr::from(([127, 0, 0, 1], 8080)),
}
}
}
pub enum ApiGatewayMsg {
Stop,
}
#[derive(Clone)]
struct AppState {
coordinator: ActorRef<DpCoordinatorMsg>,
}
#[derive(Debug, Deserialize)]
struct ChatRequest {
model: String,
messages: Vec<ChatMessage>,
#[serde(default)]
stream: bool,
#[serde(default)]
temperature: Option<f32>,
#[serde(default)]
max_tokens: Option<u32>,
}
#[derive(Debug, Deserialize)]
struct ChatMessage {
role: String,
content: String,
}
#[derive(Debug, Serialize)]
struct ChatErrorResponse {
error: ChatError,
}
#[derive(Debug, Serialize)]
struct ChatError {
message: String,
#[serde(rename = "type")]
kind: String,
}
pub struct ApiGatewayActor {
config: GatewayConfig,
coordinator: ActorRef<DpCoordinatorMsg>,
shutdown_tx: Option<oneshot::Sender<()>>,
}
impl ApiGatewayActor {
pub fn new(config: GatewayConfig, coordinator: ActorRef<DpCoordinatorMsg>) -> Self {
Self {
config,
coordinator,
shutdown_tx: None,
}
}
}
#[async_trait]
impl Actor for ApiGatewayActor {
type Msg = ApiGatewayMsg;
async fn pre_start(&mut self, _ctx: &mut Context<Self>) {
let bind = self.config.bind;
let state = AppState {
coordinator: self.coordinator.clone(),
};
let app = Router::new()
.route("/healthz", get(|| async { "ok" }))
.route("/v1/chat/completions", post(chat_completions))
.with_state(state);
let listener = match TcpListener::bind(bind).await {
Ok(l) => l,
Err(e) => {
tracing::error!(?e, "gateway bind failed");
return;
}
};
let (tx, rx) = oneshot::channel();
self.shutdown_tx = Some(tx);
tokio::spawn(async move {
tracing::info!(%bind, "gateway listening");
let server = axum::serve(listener, app);
let _ = tokio::select! {
r = server => r,
_ = rx => Ok(()),
};
});
}
async fn handle(&mut self, ctx: &mut Context<Self>, msg: Self::Msg) {
match msg {
ApiGatewayMsg::Stop => {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}
ctx.stop_self();
}
}
}
async fn post_stop(&mut self, _ctx: &mut Context<Self>) {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}
}
}
pub fn spawn_gateway(
sys: &atomr_core::actor::ActorSystem,
config: GatewayConfig,
coordinator: ActorRef<DpCoordinatorMsg>,
) -> Result<ActorRef<ApiGatewayMsg>, atomr_core::actor::ActorSystemError> {
use atomr_core::actor::Props;
let coord = Arc::new(coordinator);
let cfg = Arc::new(config);
let props = Props::create(move || ApiGatewayActor::new((*cfg).clone(), (*coord).clone()));
sys.actor_of(props, "gateway")
}
async fn chat_completions(State(state): State<AppState>, Json(req): Json<ChatRequest>) -> Response {
let messages = req
.messages
.into_iter()
.map(|m| Message {
role: parse_role(&m.role),
content: MessageContent::Text(m.content),
})
.collect();
let batch = ExecuteBatch {
request_id: format!("req-{}", chrono::Utc::now().timestamp_nanos_opt().unwrap_or(0)),
model: req.model.clone(),
messages,
sampling: SamplingParams {
temperature: req.temperature,
max_tokens: req.max_tokens,
..Default::default()
},
stream: req.stream,
estimated_tokens: 256,
};
let route = state
.coordinator
.ask_with(
|reply| DpCoordinatorMsg::RouteTo {
deployment: req.model.clone(),
reply,
},
std::time::Duration::from_secs(2),
)
.await;
match route {
Ok(Ok(_target)) => {
let body = serde_json::json!({
"id": batch.request_id,
"model": batch.model,
"object": "chat.completion",
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": ""},
"finish_reason": "stop"
}],
});
(StatusCode::OK, Json(body)).into_response()
}
Ok(Err(e)) => bad_request(e.to_string(), "no_route"),
Err(_) => bad_request("coordinator timeout".into(), "internal_error"),
}
}
fn bad_request(msg: String, kind: &str) -> Response {
(
StatusCode::BAD_REQUEST,
Json(ChatErrorResponse {
error: ChatError {
message: msg,
kind: kind.into(),
},
}),
)
.into_response()
}
fn parse_role(s: &str) -> Role {
match s {
"system" => Role::System,
"assistant" => Role::Assistant,
"tool" => Role::Tool,
_ => Role::User,
}
}