aha 0.2.5

aha is a model inference library that now supports Qwen (2.5VL/3/3VL/3.5/ASR/3Embedding/3Reranker), MiniCPM4, VoxCPM/1.5, DeepSeek-OCR/2, Hunyuan-OCR, PaddleOCR-VL/1.5, RMBG2.0, GLM (ASR-Nano-2512/OCR), Fun-ASR-Nano-2512, and LFM (2/2.5/2VL/2.5VL).
Documentation
use rocket::{http::Status, post, serde::json::Json};
use serde_json::Value;

use crate::{
    params::rerank::{RerankRequest, RerankResponse, RerankResult},
    server::api::MODEL,
};

/// Checks that a rerank request carries usable input.
///
/// The checks run in this order and the first failure is returned:
/// 1. `query` must contain at least one non-whitespace character.
/// 2. `documents` must not be an empty slice.
/// 3. No entry of `documents` may be empty or whitespace-only.
///
/// # Errors
///
/// Returns an `anyhow` error whose message names the rule that was violated.
fn validate_rerank_input(query: &str, documents: &[String]) -> anyhow::Result<()> {
    anyhow::ensure!(!query.trim().is_empty(), "rerank query cannot be empty");
    anyhow::ensure!(!documents.is_empty(), "rerank documents cannot be empty");

    let every_document_has_text = documents.iter().all(|text| !text.trim().is_empty());
    anyhow::ensure!(
        every_document_has_text,
        "rerank documents cannot contain empty strings"
    );

    Ok(())
}

/// Handles `POST /rerank`: scores each request document against the query with the
/// globally registered model and returns the results ordered from the highest to the
/// lowest relevance score.
///
/// Responses:
/// - `400 Bad Request` when `validate_rerank_input` rejects the request, or when the
///   model's `rerank` call returns an error.
/// - `503 Service Unavailable` when `MODEL` has not been set yet.
/// - `500 Internal Server Error` when the response cannot be serialised to JSON.
/// - `200 OK` with the serialised `RerankResponse` otherwise. When `top_n` is present
///   at most that many results are returned.
///
/// Each result's `index` is the position of the document in the request's `documents`
/// list. Scores and documents are paired positionally; if the model returns a different
/// number of scores than there are documents, only the pairs that line up are reported.
#[post("/rerank", data = "<req>")]
pub(crate) async fn rerank(req: Json<RerankRequest>) -> (Status, Json<Value>) {
    let req = req.into_inner();
    if let Err(e) = validate_rerank_input(&req.query, &req.documents) {
        return (
            Status::BadRequest,
            Json(serde_json::json!({ "error": e.to_string() })),
        );
    }

    let model_ref = match MODEL.get().cloned() {
        Some(v) => v,
        None => {
            return (
                Status::ServiceUnavailable,
                Json(serde_json::json!({ "error": "model not init" })),
            );
        }
    };

    // Hold the write lock only for the inference call and for reading the model name, so
    // that sorting and serialisation below do not keep other requests waiting on the model.
    let mut guard = model_ref.write().await;
    let outcome = guard.instance.rerank(&req.query, &req.documents);
    let model_name = guard.which_model.as_string();
    drop(guard);

    let scores = match outcome {
        Ok(v) => v,
        Err(e) => {
            return (
                Status::BadRequest,
                Json(serde_json::json!({ "error": e.to_string() })),
            );
        }
    };

    // Pair documents with scores via `zip` rather than indexing `req.documents`, so a model
    // that returns more scores than documents cannot trigger an out-of-bounds panic.
    let mut results = req
        .documents
        .iter()
        .zip(scores)
        .enumerate()
        .map(|(index, (document, relevance_score))| RerankResult {
            index,
            relevance_score,
            document: document.clone(),
        })
        .collect::<Vec<_>>();

    // Descending by score; the sort is stable, so ties keep their request order.
    results.sort_by(|a, b| b.relevance_score.total_cmp(&a.relevance_score));
    if let Some(top_n) = req.top_n {
        // `truncate` is a no-op when `top_n` exceeds the current length.
        results.truncate(top_n);
    }

    let response = RerankResponse {
        object: "list".to_string(),
        model: model_name,
        results,
    };

    // Report a serialisation failure as a 500 instead of panicking inside the handler.
    match serde_json::to_value(response) {
        Ok(body) => (Status::Ok, Json(body)),
        Err(e) => (
            Status::InternalServerError,
            Json(serde_json::json!({ "error": e.to_string() })),
        ),
    }
}