llm 1.3.8

A Rust library unifying multiple LLM backends.
Documentation
use axum::{extract::State, http::HeaderMap, Json};
use uuid::Uuid;

use super::chain::handle_chain_request;
use super::helpers::{bad_request, internal_error, parse_model, unauthorized, ApiResult};
use crate::api::types::{ChatRequest, ChatResponse, Choice, Message};
use crate::api::ServerState;
use crate::chat::{ChatMessage, ChatRole, MessageType};

/// Axum handler for a chat-completion request.
///
/// The request is authenticated first (see `validate_auth`). If it carries any
/// `steps`, it is handed off to `handle_chain_request` unchanged. Otherwise the
/// `model` string is split by `parse_model` into a provider id and a model name.
/// The provider id is looked up in `state.llms`, and the converted conversation
/// is sent to that provider in a single `chat` call.
///
/// The returned `ChatResponse` echoes the model-name part produced by
/// `parse_model`, not the full requested `model` string. If the provider's reply
/// has no text, the content is empty.
///
/// # Errors
/// - an `unauthorized` error from `validate_auth`;
/// - `bad_request` when `model` is absent or names an unknown provider;
/// - whatever `parse_model` returns for a malformed model string;
/// - `internal_error` wrapping the provider's error message if the chat call fails.
pub async fn handle_chat(
    State(state): State<ServerState>,
    headers: HeaderMap,
    Json(req): Json<ChatRequest>,
) -> ApiResult<Json<ChatResponse>> {
    validate_auth(&state, &headers)?;

    // Multi-step requests are served by the chain handler instead.
    if !req.steps.is_empty() {
        return handle_chain_request(state, req).await;
    }

    let conversation = build_messages(req.messages);
    let requested_model = req
        .model
        .ok_or_else(|| bad_request("Model is required"))?;
    let (backend_id, model_name) = parse_model(&requested_model)?;

    let backend = state
        .llms
        .get(&backend_id)
        .ok_or_else(|| bad_request(format!("Unknown provider: {backend_id}")))?;

    let reply = backend
        .chat(&conversation)
        .await
        .map_err(|err| internal_error(err.to_string()))?;
    let content = reply.text().unwrap_or_default();

    Ok(Json(build_response(model_name, content)))
}

/// Checks the request's `Authorization` header against the configured API key.
///
/// When `state.auth_key` is `None`, authentication is disabled and every request
/// is accepted. Otherwise the header must be present, be convertible to `&str`
/// (`HeaderValue::to_str`), and have the exact form `Bearer <key>`.
///
/// # Errors
/// Returns an `unauthorized` error when:
/// - the header is missing (`"Missing authorization"`);
/// - the header value is not valid text (`"Invalid authorization header"`);
/// - the `Bearer ` prefix is absent or the token does not match (`"Invalid API key"`).
fn validate_auth(state: &ServerState, headers: &HeaderMap) -> ApiResult<()> {
    let Some(key) = &state.auth_key else {
        return Ok(());
    };

    let auth_header = headers
        .get("Authorization")
        .ok_or_else(|| unauthorized("Missing authorization"))?;
    let auth_str = auth_header
        .to_str()
        .map_err(|_| unauthorized("Invalid authorization header"))?;

    // `strip_prefix` replaces the manual `starts_with` + `[7..]` byte slice.
    let Some(token) = auth_str.strip_prefix("Bearer ") else {
        return Err(unauthorized("Invalid API key"));
    };

    // The token is attacker-supplied. A plain `!=` stops at the first differing
    // byte, so response timing could reveal how much of a guessed key is right.
    if !constant_time_eq(token.as_bytes(), key.as_bytes()) {
        return Err(unauthorized("Invalid API key"));
    }

    Ok(())
}

/// Compares two byte strings without short-circuiting on the first mismatch.
///
/// Inputs of different length are rejected immediately, so timing can reveal at
/// most the length of the configured key, not its contents. This is a
/// best-effort measure; it is not a formally verified constant-time primitive.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    // `black_box` discourages the optimizer from rewriting the fold into an
    // early-exit comparison.
    std::hint::black_box(diff) == 0
}

fn build_messages(messages: Option<Vec<Message>>) -> Vec<ChatMessage> {
    messages
        .unwrap_or_default()
        .into_iter()
        .map(|msg| ChatMessage {
            role: parse_role(&msg.role),
            message_type: MessageType::Text,
            content: msg.content,
        })
        .collect()
}

/// Maps a wire-format role string onto a `ChatRole`.
///
/// Only the exact string `"assistant"` selects `ChatRole::Assistant`. `"user"`
/// and every other value (for example `"system"` or differently-cased input)
/// fall back to `ChatRole::User`.
fn parse_role(role: &str) -> ChatRole {
    if role == "assistant" {
        ChatRole::Assistant
    } else {
        ChatRole::User
    }
}

/// Wraps a single assistant reply in a `chat.completion` response object.
///
/// The response gets a fresh `chatcmpl-<uuid v4>` id and a `created` timestamp in
/// whole seconds since the Unix epoch. It has one choice at index 0 with
/// `finish_reason` set to `"stop"`.
fn build_response(model: String, content: String) -> ChatResponse {
    // A system clock set before the epoch yields 0 instead of an error.
    let created = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0);

    let reply = Message {
        role: String::from("assistant"),
        content,
    };
    let only_choice = Choice {
        index: 0,
        message: reply,
        finish_reason: String::from("stop"),
    };

    ChatResponse {
        id: format!("chatcmpl-{}", Uuid::new_v4()),
        object: String::from("chat.completion"),
        created,
        model,
        choices: vec![only_choice],
    }
}