llm 1.3.8

A Rust library unifying multiple LLM backends.
Documentation
use axum::{http::StatusCode, Json};
use uuid::Uuid;

use super::helpers::{bad_request, internal_error, parse_model, transform_response, ApiResult};
use crate::api::types::{ChainStepRequest, ChatRequest, ChatResponse, Choice, Message};
use crate::api::ServerState;
use crate::chain::{MultiChainStep, MultiChainStepBuilder, MultiChainStepMode, MultiPromptChain};

/// Sampling temperature applied to a step when the request leaves `temperature` unset.
const DEFAULT_TEMPERATURE: f32 = 0.7_f32;
/// Token limit applied to a step when the request leaves `max_tokens` unset.
const DEFAULT_MAX_TOKENS: u32 = 1_000;

/// Executes the multi-step prompt chain described by `req` and returns the output of its
/// final step wrapped in a `ChatResponse`.
///
/// When `req.model` is set, an implicit step with id `"initial"` runs first, using the
/// content of the last message in `req.messages` as its template. Every entry of
/// `req.steps` is appended after it, in order. The response's `model` field lists the
/// provider ids of all steps, comma-separated.
///
/// # Errors
///
/// Propagates validation errors from `resolve_last_step_id`, `add_initial_step` and
/// `build_steps`. Returns an internal error if the chain fails to run, or an
/// `INTERNAL_SERVER_ERROR` if the chain produced no output for the final step.
pub async fn handle_chain_request(
    state: ServerState,
    req: ChatRequest,
) -> ApiResult<Json<ChatResponse>> {
    // Fail fast on requests with nothing to execute, before building any step.
    let final_step_id = resolve_last_step_id(&req)?;

    let mut used_providers: Vec<String> = Vec::new();
    let mut pipeline = MultiPromptChain::new(&state.llms);

    if let Some(model) = req.model.as_deref() {
        pipeline = add_initial_step(pipeline, &mut used_providers, &req, model)?;
    }

    let user_steps = build_steps(req.steps, &mut used_providers)?;
    pipeline = pipeline.chain(user_steps);

    let outputs = pipeline
        .run()
        .await
        .map_err(|err| internal_error(err.to_string()))?;

    let final_output = outputs.get(&final_step_id).ok_or_else(|| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("No response found for step {final_step_id}"),
        )
    })?;

    let model_label = used_providers.join(",");
    let response = build_response(model_label, final_output.to_string());
    Ok(Json(response))
}

/// Determines which step's output becomes the response body.
///
/// This is the id of the last entry in `req.steps`. When no explicit steps were given but
/// `req.model` is set, it is the implicit `"initial"` step.
///
/// # Errors
///
/// Returns a bad-request error when the request has neither steps nor a model.
fn resolve_last_step_id(req: &ChatRequest) -> ApiResult<String> {
    match (req.steps.last(), &req.model) {
        (Some(step), _) => Ok(step.id.clone()),
        (None, Some(_)) => Ok("initial".to_string()),
        (None, None) => Err(bad_request("No steps provided")),
    }
}

/// Appends the implicit `"initial"` chat step to `chain` for a request that names a
/// top-level model.
///
/// The provider id parsed from `model` is recorded in `provider_ids`. The content of the
/// last message in `req.messages` becomes the step's template. Unset `max_tokens` and
/// `temperature` fall back to `DEFAULT_MAX_TOKENS` and `DEFAULT_TEMPERATURE`.
///
/// # Errors
///
/// Returns an error when `model` cannot be parsed or when the request carries no messages.
/// Also returns an error (as a bad request) when the step builder rejects the configuration.
fn add_initial_step<'a>(
    chain: MultiPromptChain<'a>,
    provider_ids: &mut Vec<String>,
    req: &ChatRequest,
    model: &str,
) -> ApiResult<MultiPromptChain<'a>> {
    let (provider, _) = parse_model(model)?;
    provider_ids.push(provider.clone());

    let template = match last_message(req) {
        Some(text) => text,
        None => return Err(bad_request("Initial model requires at least one message")),
    };

    // Owned copy moved into the transform closure, which must outlive this borrow of `req`.
    let transform_name = req.response_transform.clone().unwrap_or_default();
    let initial_step = MultiChainStepBuilder::new(MultiChainStepMode::Chat)
        .provider_id(provider)
        .id("initial")
        .template(template)
        .max_tokens(req.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS))
        .temperature(req.temperature.unwrap_or(DEFAULT_TEMPERATURE))
        .response_transform(move |resp| transform_response(resp, &transform_name))
        .build()
        .map_err(|err| bad_request(err.to_string()))?;

    Ok(chain.step(initial_step))
}

/// Converts every requested step into a `MultiChainStep`, preserving request order.
///
/// Each step's provider id is appended to `provider_ids` as that step is processed.
///
/// # Errors
///
/// Stops at the first error produced by `build_step` and returns it. Later steps are not
/// processed.
fn build_steps(
    steps: Vec<ChainStepRequest>,
    provider_ids: &mut Vec<String>,
) -> ApiResult<Vec<MultiChainStep>> {
    let mut built = Vec::with_capacity(steps.len());
    for request_step in steps {
        built.push(build_step(request_step, provider_ids)?);
    }
    Ok(built)
}

fn build_step(step: ChainStepRequest, provider_ids: &mut Vec<String>) -> ApiResult<MultiChainStep> {
    provider_ids.push(step.provider_id.clone());
    let transform = step.response_transform.unwrap_or_default();

    MultiChainStepBuilder::new(MultiChainStepMode::Chat)
        .provider_id(step.provider_id)
        .id(step.id)
        .template(step.template)
        .temperature(step.temperature.unwrap_or(DEFAULT_TEMPERATURE))
        .max_tokens(step.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS))
        .response_transform(move |resp| transform_response(resp, &transform))
        .build()
        .map_err(|e| bad_request(e.to_string()))
}

/// Returns a copy of the content of the final message in `req.messages`.
///
/// Returns `None` when the request has no messages, whether the list is absent or empty.
fn last_message(req: &ChatRequest) -> Option<String> {
    let newest = req.messages.as_ref()?.last()?;
    Some(newest.content.clone())
}

/// Wraps `content` in a single-choice assistant `ChatResponse` attributed to `model`.
///
/// The response gets a fresh `chatcmpl-<uuid v4>` id and a `created` timestamp in whole
/// seconds since the Unix epoch. The timestamp is 0 if the system clock reads earlier than
/// the epoch. The single choice has index 0 and finish reason `"stop"`.
fn build_response(model: String, content: String) -> ChatResponse {
    use std::time::{SystemTime, UNIX_EPOCH};

    let created = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0);

    let reply = Message {
        role: "assistant".to_string(),
        content,
    };

    ChatResponse {
        id: format!("chatcmpl-{}", Uuid::new_v4()),
        object: "chat.completion".to_string(),
        created,
        model,
        choices: vec![Choice {
            index: 0,
            message: reply,
            finish_reason: "stop".to_string(),
        }],
    }
}