orign 0.2.3

A globally distributed container orchestrator
Documentation
// src/handlers/chat_new.rs

use axum::extract::State;
use serde_json::{json, Value};
use sqlx::any::AnyRow;
use uuid::Uuid;
use chrono::Utc;
use reqwest::Client;
use std::path::Path;
use k8s_openapi::api::core::v1::Service;
use kube::{Api, Client as KubeClient};
use tokio::net::TcpListener;
use axum::{
    extract::Json,
    http::StatusCode,
    response::IntoResponse,
};
use sqlx::Row;
use crate::models::{ChatCompletionRequest, ChatCompletionResponse, Choice, Message};
use crate::db::DbPool;

pub async fn chat_completions(
    State(pool): State<DbPool>,
    Json(payload): Json<ChatCompletionRequest>,
) -> impl IntoResponse {
    // Look up the model in the database
    let model_result = sqlx::query("SELECT * FROM model_instances WHERE model_name = ?")
        .bind(&payload.model)
        .fetch_one(&pool)
        .await;

    let model_row = match model_result {
        Ok(row) => row,
        Err(_) => {
            // Model not found
            let error_response = json!({
                "error": {
                    "message": format!("The model '{}' does not exist.", payload.model),
                    "type": "invalid_request_error",
                    "param": "model",
                    "code": null
                }
            });
            return (StatusCode::BAD_REQUEST, Json(error_response));
        }
    };

    // Route the request to the model's backend
    match call_model_backend(&model_row, &payload).await {
        Ok(response_message) => {
            // Record the prompt and response
            let id = Uuid::new_v4().to_string();
            let created_at = Utc::now().timestamp() as u64;

            // Serialize messages
            let prompt_text = serde_json::to_string(&payload.messages).unwrap_or_default();
            let response_text = response_message.content.clone();

            sqlx::query(
                "INSERT INTO chat_histories (id, owner_id, prompt, response, created_at) VALUES (?, ?, ?, ?, ?)",
            )
            .bind(&id)
            .bind("default_owner")
            .bind(&prompt_text)
            .bind(&response_text)
            .bind(created_at as i64)
            .execute(&pool)
            .await
            .ok();

            // Construct the response
            let response = ChatCompletionResponse {
                id: format!("chatcmpl-{}", id),
                object: "chat.completion".to_string(),
                created: created_at,
                model: payload.model.clone(),
                choices: vec![Choice {
                    index: 0,
                    message: response_message,
                    finish_reason: Some("stop".to_string()),
                }],
                usage: None, // Fill in if you have token usage data
            };
            let response_value = serde_json::to_value(response)
            .unwrap_or_else(|_| json!({"error": "Failed to serialize response"}));

            (StatusCode::OK, Json(response_value))
        }
        Err(error_message) => {
            // Handle backend error
            let error_response = json!({
                "error": {
                    "message": error_message,
                    "type": "internal_server_error",
                    "param": null,
                    "code": null
                }
            });
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
        }
    }
}

/// Forwards a chat completion request to the backend that serves `model_row`'s model.
///
/// Reads `backend_url`, `resource_name` and `resource_namespace` (defaulting to `"default"`)
/// from the `model_instances` row. If a Kubernetes service-account token file exists, the
/// process is treated as running in-cluster and `backend_url` is called directly. Otherwise a
/// local port-forward to the `resource_name` service is set up first.
///
/// `backend_url` may be a bare `host:port` (an `http://` scheme is prepended) or a full
/// `http(s)://` URL; a trailing `/` is ignored.
///
/// # Errors
/// Returns a human-readable message in any of these cases:
/// - `backend_url` is missing, or `resource_name` is missing when port-forwarding is needed.
/// - Port forwarding cannot be set up.
/// - The HTTP call fails or the backend returns a non-2xx status.
/// - The reply is not JSON or lacks a string `choices[0].message.content`.
async fn call_model_backend(
    model_row: &AnyRow,
    payload: &ChatCompletionRequest,
) -> Result<Message, String> {
    // Extract backend connection details from the model row.
    let backend_url: String = model_row.try_get("backend_url").unwrap_or_default();
    let resource_name: String = model_row.try_get("resource_name").unwrap_or_default();
    let resource_namespace: String = model_row
        .try_get("resource_namespace")
        .unwrap_or_else(|_| "default".to_string());

    if backend_url.is_empty() {
        return Err("Backend URL not specified for the model.".to_string());
    }

    // Presence of the mounted service-account token is used as the in-cluster signal.
    let in_cluster = Path::new("/var/run/secrets/kubernetes.io/serviceaccount/token").exists();

    let backend_endpoint = if in_cluster {
        // Accept both bare `host:port` values and full URLs that already carry a scheme.
        let base = backend_url.trim_end_matches('/');
        if base.starts_with("http://") || base.starts_with("https://") {
            format!("{}/chat/completions", base)
        } else {
            format!("http://{}/chat/completions", base)
        }
    } else {
        if resource_name.is_empty() {
            return Err(
                "Resource name not specified for the model; cannot set up port forwarding."
                    .to_string(),
            );
        }
        let port = setup_port_forwarding(&resource_name, &resource_namespace)
            .await
            .map_err(|e| format!("Failed to set up port forwarding: {}", e))?;
        format!("http://localhost:{}/chat/completions", port)
    };

    // Forward the OpenAI-compatible sampling parameters to the backend.
    let backend_payload = json!({
        "messages": payload.messages,
        "temperature": payload.temperature,
        "top_p": payload.top_p,
        "n": payload.n,
        "stop": payload.stop,
        "max_tokens": payload.max_tokens,
        "presence_penalty": payload.presence_penalty,
        "frequency_penalty": payload.frequency_penalty,
    });

    let client = Client::new();
    let response = client
        .post(&backend_endpoint)
        .json(&backend_payload)
        .send()
        .await
        .map_err(|e| format!("Failed to send request to model backend: {}", e))?;

    if !response.status().is_success() {
        return Err(format!(
            "Model backend returned error status: {}",
            response.status()
        ));
    }

    let response_json: Value = response
        .json()
        .await
        .map_err(|e| format!("Failed to parse backend response: {}", e))?;

    // A reply without assistant content is a backend error. Returning an empty message here
    // would look like a successful completion to the caller and be recorded in history.
    let content = response_json["choices"][0]["message"]["content"]
        .as_str()
        .ok_or_else(|| {
            "Model backend response is missing choices[0].message.content".to_string()
        })?
        .to_string();

    Ok(Message {
        role: "assistant".to_string(),
        content,
    })
}

async fn setup_port_forwarding(
    service_name: &str,
    namespace: &str,
) -> Result<u16, String> {
    // Create a Kubernetes client
    let kube_client = KubeClient::try_default()
        .await
        .map_err(|e| format!("Failed to create Kubernetes client: {}", e))?;

    // Get the service
    let services: Api<Service> = Api::namespaced(kube_client.clone(), namespace);
    let service = services
        .get(service_name)
        .await
        .map_err(|e| format!("Failed to get service: {}", e))?;

    // Get the target port
    let target_port = service
        .spec
        .as_ref()
        .and_then(|spec| spec.ports.as_ref())
        .and_then(|ports| ports.get(0))
        .and_then(|port| Some(port.port))
        .ok_or_else(|| "Service has no ports defined".to_string())?;

    // Set up a local TCP listener
    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .map_err(|e| format!("Failed to bind local port: {}", e))?;
    let local_port = listener.local_addr().unwrap().port();

    // **Create owned variables**
    let service_name_owned = service_name.to_string();
    let namespace_owned = namespace.to_string();
    let kube_client_owned = kube_client.clone();
    let listener_owned = listener;
    let target_port_owned = target_port as u16;

    // Port-forward to the service
    tokio::spawn(async move {
        if let Err(e) = port_forward(
            kube_client_owned,
            namespace_owned,
            service_name_owned,
            listener_owned,
            target_port_owned,
        )
        .await
        {
            eprintln!("Port forwarding error: {}", e);
        }
    });

    Ok(local_port)
}

/// Serves one local TCP connection by forwarding it to `target_port` on a pod selected by
/// the Kubernetes service `service_name` in `namespace`.
///
/// Steps:
/// 1. Resolve the service's label selector and pick the first pod whose phase is `Running`
///    and that is not being deleted.
/// 2. Wait for one connection on `listener`.
/// 3. Open a kube port-forward to the pod and copy bytes in both directions on a spawned task.
///
/// Only one connection is served; the listener is dropped when this function returns.
///
/// # Errors
/// Returns a message in any of these cases:
/// - The service or its pods cannot be read.
/// - The service has a missing or empty selector, or no running pod matches it.
/// - Accepting the local connection fails, or the port-forward cannot be established.
///
/// Errors raised during the copy are printed to stderr.
async fn port_forward(
    kube_client: KubeClient,
    namespace: String,
    service_name: String,
    listener: TcpListener,
    target_port: u16,
) -> Result<(), String> {
    let service_api: Api<Service> = Api::namespaced(kube_client.clone(), &namespace);
    let service = service_api
        .get(&service_name)
        .await
        .map_err(|e| format!("Failed to get service: {}", e))?;

    // An empty selector would become an empty label query, which matches every pod in the
    // namespace, so treat it the same as a missing selector.
    let selector = service
        .spec
        .and_then(|spec| spec.selector)
        .filter(|selector| !selector.is_empty())
        .ok_or_else(|| "Service has no selector".to_string())?;

    let pods_api: Api<k8s_openapi::api::core::v1::Pod> =
        Api::namespaced(kube_client, &namespace);
    let pods = pods_api
        .list(&kube::api::ListParams::default().labels(&selector_to_string(&selector)))
        .await
        .map_err(|e| format!("Failed to list pods: {}", e))?;

    // Forwarding requires a running container, so skip pending, finished and terminating
    // pods instead of taking the first one listed.
    let pod_name = pods
        .items
        .into_iter()
        .filter(|pod| pod.metadata.deletion_timestamp.is_none())
        .find(|pod| {
            pod.status.as_ref().and_then(|status| status.phase.as_deref()) == Some("Running")
        })
        .and_then(|pod| pod.metadata.name)
        .ok_or_else(|| "No running pods found for service".to_string())?;

    // Wait for the local client before opening the port-forward, so no pod connection is
    // opened for a client that never connects.
    let (mut local_stream, _) = listener
        .accept()
        .await
        .map_err(|e| format!("Failed to accept local connection: {}", e))?;

    let mut forwarder = pods_api
        .portforward(&pod_name, &[target_port])
        .await
        .map_err(|e| format!("Failed to create portforwarder: {}", e))?;
    let mut upstream = forwarder
        .take_stream(target_port)
        .ok_or_else(|| "Failed to get portforward stream".to_string())?;

    tokio::spawn(async move {
        if let Err(e) = tokio::io::copy_bidirectional(&mut local_stream, &mut upstream).await {
            eprintln!("Port forward error: {}", e);
        }
        // Close our end of the stream, then wait for the forwarder task so its errors surface.
        drop(upstream);
        if let Err(e) = forwarder.join().await {
            eprintln!("Port forward error: {}", e);
        }
    });

    Ok(())
}



/// Renders a label selector map as a Kubernetes label query such as `"app=web,tier=api"`.
///
/// Pairs appear in the map's key order (`BTreeMap` iterates sorted by key). An empty map
/// yields an empty string.
fn selector_to_string(selector: &std::collections::BTreeMap<String, String>) -> String {
    let mut label_query = String::new();
    for (position, (key, value)) in selector.iter().enumerate() {
        // Separate consecutive `key=value` pairs with a comma.
        if position > 0 {
            label_query.push(',');
        }
        label_query.push_str(key);
        label_query.push('=');
        label_query.push_str(value);
    }
    label_query
}