use axum::extract::State;
use serde_json::{json, Value};
use sqlx::any::AnyRow;
use uuid::Uuid;
use chrono::Utc;
use reqwest::Client;
use std::path::Path;
use k8s_openapi::api::core::v1::Service;
use kube::{Api, Client as KubeClient};
use tokio::net::TcpListener;
use axum::{
extract::Json,
http::StatusCode,
response::IntoResponse,
};
use sqlx::Row;
use crate::models::{ChatCompletionRequest, ChatCompletionResponse, Choice, Message};
use crate::db::DbPool;
pub async fn chat_completions(
State(pool): State<DbPool>,
Json(payload): Json<ChatCompletionRequest>,
) -> impl IntoResponse {
let model_result = sqlx::query("SELECT * FROM model_instances WHERE model_name = ?")
.bind(&payload.model)
.fetch_one(&pool)
.await;
let model_row = match model_result {
Ok(row) => row,
Err(_) => {
let error_response = json!({
"error": {
"message": format!("The model '{}' does not exist.", payload.model),
"type": "invalid_request_error",
"param": "model",
"code": null
}
});
return (StatusCode::BAD_REQUEST, Json(error_response));
}
};
match call_model_backend(&model_row, &payload).await {
Ok(response_message) => {
let id = Uuid::new_v4().to_string();
let created_at = Utc::now().timestamp() as u64;
let prompt_text = serde_json::to_string(&payload.messages).unwrap_or_default();
let response_text = response_message.content.clone();
sqlx::query(
"INSERT INTO chat_histories (id, owner_id, prompt, response, created_at) VALUES (?, ?, ?, ?, ?)",
)
.bind(&id)
.bind("default_owner")
.bind(&prompt_text)
.bind(&response_text)
.bind(created_at as i64)
.execute(&pool)
.await
.ok();
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", id),
object: "chat.completion".to_string(),
created: created_at,
model: payload.model.clone(),
choices: vec![Choice {
index: 0,
message: response_message,
finish_reason: Some("stop".to_string()),
}],
usage: None, };
let response_value = serde_json::to_value(response)
.unwrap_or_else(|_| json!({"error": "Failed to serialize response"}));
(StatusCode::OK, Json(response_value))
}
Err(error_message) => {
let error_response = json!({
"error": {
"message": error_message,
"type": "internal_server_error",
"param": null,
"code": null
}
});
(StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
}
}
}
/// Sends `payload` to the backend serving the model described by `model_row`
/// and returns the assistant message it produced.
///
/// Columns read from `model_row`:
/// * `backend_url` — required; used as `http://{backend_url}/chat/completions`
///   when running inside a cluster.
/// * `resource_name` — Service name to port-forward to when running outside a
///   cluster; required in that case.
/// * `resource_namespace` — that Service's namespace; `"default"` when the
///   column is missing, NULL, undecodable, or empty.
///
/// "Inside a cluster" means the service-account token file exists.
///
/// # Errors
/// Returns a human-readable message when a required column is empty, port
/// forwarding cannot be set up, the HTTP request fails, the backend answers
/// with a non-success status, or its body lacks a string at
/// `choices[0].message.content`.
async fn call_model_backend(
    model_row: &AnyRow,
    payload: &ChatCompletionRequest,
) -> Result<Message, String> {
    let backend_url: String = model_row.try_get("backend_url").unwrap_or_default();
    let resource_name: String = model_row.try_get("resource_name").unwrap_or_default();
    let resource_namespace = model_row
        .try_get::<String, _>("resource_namespace")
        .ok()
        .filter(|namespace| !namespace.is_empty())
        .unwrap_or_else(|| "default".to_string());
    if backend_url.is_empty() {
        return Err("Backend URL not specified for the model.".to_string());
    }

    let in_cluster = Path::new("/var/run/secrets/kubernetes.io/serviceaccount/token").exists();
    let backend_endpoint = if in_cluster {
        format!("http://{}/chat/completions", backend_url)
    } else {
        // Without a Service name the Kubernetes lookup would fail with a
        // confusing API error; report the real cause instead.
        if resource_name.is_empty() {
            return Err("Resource name not specified for the model.".to_string());
        }
        let port = setup_port_forwarding(&resource_name, &resource_namespace)
            .await
            .map_err(|e| format!("Failed to set up port forwarding: {}", e))?;
        // The forwarder listens on 127.0.0.1 only; `localhost` may resolve to ::1 first.
        format!("http://127.0.0.1:{}/chat/completions", port)
    };

    let backend_payload = json!({
        "messages": payload.messages,
        "temperature": payload.temperature,
        "top_p": payload.top_p,
        "n": payload.n,
        "stop": payload.stop,
        "max_tokens": payload.max_tokens,
        "presence_penalty": payload.presence_penalty,
        "frequency_penalty": payload.frequency_penalty,
    });

    let client = Client::new();
    let response = client
        .post(&backend_endpoint)
        .json(&backend_payload)
        .send()
        .await
        .map_err(|e| format!("Failed to send request to model backend: {}", e))?;

    let status = response.status();
    if !status.is_success() {
        // Keep the backend's own explanation in the server log for debugging.
        if let Ok(body) = response.text().await {
            eprintln!("Model backend error body ({}): {}", status, body);
        }
        return Err(format!("Model backend returned error status: {}", status));
    }

    let response_json: Value = response
        .json()
        .await
        .map_err(|e| format!("Failed to parse backend response: {}", e))?;
    // Missing or non-string content means the backend used an unexpected format;
    // surfacing that beats returning an empty "successful" completion.
    let content = response_json["choices"][0]["message"]["content"]
        .as_str()
        .ok_or_else(|| {
            "Model backend response is missing choices[0].message.content".to_string()
        })?
        .to_string();

    Ok(Message {
        role: "assistant".to_string(),
        content,
    })
}
async fn setup_port_forwarding(
service_name: &str,
namespace: &str,
) -> Result<u16, String> {
let kube_client = KubeClient::try_default()
.await
.map_err(|e| format!("Failed to create Kubernetes client: {}", e))?;
let services: Api<Service> = Api::namespaced(kube_client.clone(), namespace);
let service = services
.get(service_name)
.await
.map_err(|e| format!("Failed to get service: {}", e))?;
let target_port = service
.spec
.as_ref()
.and_then(|spec| spec.ports.as_ref())
.and_then(|ports| ports.get(0))
.and_then(|port| Some(port.port))
.ok_or_else(|| "Service has no ports defined".to_string())?;
let listener = TcpListener::bind("127.0.0.1:0")
.await
.map_err(|e| format!("Failed to bind local port: {}", e))?;
let local_port = listener.local_addr().unwrap().port();
let service_name_owned = service_name.to_string();
let namespace_owned = namespace.to_string();
let kube_client_owned = kube_client.clone();
let listener_owned = listener;
let target_port_owned = target_port as u16;
tokio::spawn(async move {
if let Err(e) = port_forward(
kube_client_owned,
namespace_owned,
service_name_owned,
listener_owned,
target_port_owned,
)
.await
{
eprintln!("Port forwarding error: {}", e);
}
});
Ok(local_port)
}
/// Forwards a single local TCP connection to a pod behind `service_name`.
///
/// Resolves the Service's label selector and picks the first matching pod in
/// phase `Running`. It opens a kube port-forward to `target_port` on that pod,
/// waits for one connection on `listener`, and copies bytes in both directions
/// until either side closes. The listener is dropped after the first accept,
/// so later connections to the same local port are refused.
///
/// # Errors
/// Fails when the Service or pods cannot be fetched, the Service has a missing
/// or empty selector, no running pod matches, the port-forward cannot be
/// established, accepting the local connection fails, or the copy fails.
async fn port_forward(
    kube_client: KubeClient,
    namespace: String,
    service_name: String,
    listener: TcpListener,
    target_port: u16,
) -> Result<(), String> {
    let pods_api: Api<k8s_openapi::api::core::v1::Pod> =
        Api::namespaced(kube_client.clone(), &namespace);
    let service_api: Api<Service> = Api::namespaced(kube_client, &namespace);
    let service = service_api
        .get(&service_name)
        .await
        .map_err(|e| format!("Failed to get service: {}", e))?;
    // An empty selector renders to an empty label query, which matches every
    // pod in the namespace; treat it like a missing selector instead.
    let selector = service
        .spec
        .as_ref()
        .and_then(|spec| spec.selector.as_ref())
        .filter(|selector| !selector.is_empty())
        .ok_or_else(|| "Service has no selector".to_string())?;
    let pods = pods_api
        .list(&kube::api::ListParams::default().labels(&selector_to_string(selector)))
        .await
        .map_err(|e| format!("Failed to list pods: {}", e))?;
    // Only a running pod can accept a port-forward.
    let pod_name = pods
        .items
        .iter()
        .filter(|pod| {
            pod.status.as_ref().and_then(|status| status.phase.as_deref()) == Some("Running")
        })
        .find_map(|pod| pod.metadata.name.clone())
        .ok_or_else(|| "No running pods found for service".to_string())?;

    let mut forwarder = pods_api
        .portforward(&pod_name, &[target_port])
        .await
        .map_err(|e| format!("Failed to create portforwarder: {}", e))?;
    let mut upstream = forwarder
        .take_stream(target_port)
        .ok_or_else(|| "Failed to get portforward stream".to_string())?;

    let (mut local_stream, _) = listener
        .accept()
        .await
        .map_err(|e| format!("Failed to accept local connection: {}", e))?;
    // Only one connection is ever served; stop listening right away.
    drop(listener);

    // This function already runs in its own task, so copy inline rather than
    // spawning again; that keeps `forwarder` alive for the whole copy.
    tokio::io::copy_bidirectional(&mut local_stream, &mut upstream)
        .await
        .map_err(|e| format!("Port forward error: {}", e))?;
    // Close our side so the forwarder's background task can finish, then reap it.
    drop(upstream);
    forwarder
        .join()
        .await
        .map_err(|e| format!("Port forwarder failed: {}", e))
}
/// Renders a label selector map as a Kubernetes label query, e.g.
/// `{app: web, tier: backend}` -> `"app=web,tier=backend"`.
///
/// Pairs appear in the map's (sorted) key order; an empty map yields `""`.
fn selector_to_string(selector: &std::collections::BTreeMap<String, String>) -> String {
    let mut rendered = String::new();
    for (position, (label, value)) in selector.iter().enumerate() {
        if position > 0 {
            rendered.push(',');
        }
        rendered.push_str(label);
        rendered.push('=');
        rendered.push_str(value);
    }
    rendered
}