hanzo-api 1.1.12

HTTP API for the Hanzo AI platform
use async_channel::Sender;
use serde::Deserialize;
use hanzo_messages::schemas::custom_prompt::CustomPrompt;
use warp::Filter;

use crate::node_commands::NodeCommand;

use super::api_v2_router::with_sender;

pub fn prompt_routes(
    node_commands_sender: Sender<NodeCommand>,
) -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
    let add_custom_prompt_route = warp::path("add_custom_prompt")
        .and(warp::post())
        .and(with_sender(node_commands_sender.clone()))
        .and(warp::header::<String>("authorization"))
        .and(warp::body::json())
        .and_then(add_custom_prompt_handler);

    let delete_custom_prompt_route = warp::path("delete_custom_prompt")
        .and(warp::post())
        .and(with_sender(node_commands_sender.clone()))
        .and(warp::header::<String>("authorization"))
        .and(warp::body::json())
        .and_then(delete_custom_prompt_handler);

    let get_all_custom_prompts_route = warp::path("get_all_custom_prompts")
        .and(warp::get())
        .and(with_sender(node_commands_sender.clone()))
        .and(warp::header::<String>("authorization"))
        .and_then(get_all_custom_prompts_handler);

    let get_custom_prompt_route = warp::path("get_custom_prompt")
        .and(warp::get())
        .and(with_sender(node_commands_sender.clone()))
        .and(warp::header::<String>("authorization"))
        .and(warp::query::<GetCustomPromptRequest>())
        .and_then(get_custom_prompt_handler);

    let search_custom_prompts_route = warp::path("search_custom_prompts")
        .and(warp::get())
        .and(with_sender(node_commands_sender.clone()))
        .and(warp::header::<String>("authorization"))
        .and(warp::query::<SearchCustomPromptsRequest>())
        .and_then(search_custom_prompts_handler);

    let update_custom_prompt_route = warp::path("update_custom_prompt")
        .and(warp::post())
        .and(with_sender(node_commands_sender.clone()))
        .and(warp::header::<String>("authorization"))
        .and(warp::body::json())
        .and_then(update_custom_prompt_handler);

    add_custom_prompt_route
        .or(delete_custom_prompt_route)
        .or(get_all_custom_prompts_route)
        .or(get_custom_prompt_route)
        .or(search_custom_prompts_route)
        .or(update_custom_prompt_route)
}

/// Request parameters that identify a single custom prompt by name.
///
/// Deserialized from the query string of `get_custom_prompt` and from the
/// JSON body of `delete_custom_prompt`.
#[derive(Deserialize)]
pub struct GetCustomPromptRequest {
    /// Name of the custom prompt to look up or delete.
    pub prompt_name: String,
}

/// Query-string parameters accepted by `search_custom_prompts`.
#[derive(Deserialize)]
pub struct SearchCustomPromptsRequest {
    /// Search text forwarded unchanged to the node as the search query.
    pub query: String,
}

// Handler for `POST /v2/add_custom_prompt`.
//
// Forwards the submitted `CustomPrompt` to the node as a
// `NodeCommand::V2ApiAddCustomPrompt` and replies with the node's answer
// serialized as JSON; an error answer becomes a custom rejection.
//
// Plain `//` comments are used on purpose: utoipa would turn `///` doc
// comments on this function into the OpenAPI summary/description.
#[utoipa::path(
    post,
    path = "/v2/add_custom_prompt",
    request_body = CustomPrompt,
    responses(
        (status = 200, description = "Successfully added custom prompt", body = CustomPrompt),
        (status = 500, description = "Internal server error", body = APIError)
    )
)]
pub async fn add_custom_prompt_handler(
    sender: Sender<NodeCommand>,
    authorization: String,
    payload: CustomPrompt,
) -> Result<impl warp::Reply, warp::Rejection> {
    // A header without the "Bearer " prefix is forwarded as an empty token
    // instead of being rejected here.
    let token = match authorization.strip_prefix("Bearer ") {
        Some(rest) => rest.to_string(),
        None => String::new(),
    };

    // Single-slot channel on which the node sends back its answer.
    let (reply_tx, reply_rx) = async_channel::bounded(1);
    let command = NodeCommand::V2ApiAddCustomPrompt {
        bearer: token,
        prompt: payload,
        res: reply_tx,
    };

    // Failure to enqueue the command yields a plain `warp::reject::reject()`.
    if sender.send(command).await.is_err() {
        return Err(warp::reject::reject());
    }

    // Same plain rejection when the answer channel closes without a value.
    let outcome = match reply_rx.recv().await {
        Ok(outcome) => outcome,
        Err(_) => return Err(warp::reject::reject()),
    };

    outcome
        .map(|body| warp::reply::json(&body))
        .map_err(warp::reject::custom)
}

// Handler for `POST /v2/delete_custom_prompt`.
//
// Sends the prompt name from the JSON body to the node as a
// `NodeCommand::V2ApiDeleteCustomPrompt` and replies with the node's answer
// serialized as JSON; an error answer becomes a custom rejection.
//
// Plain `//` comments are used on purpose: utoipa would turn `///` doc
// comments on this function into the OpenAPI summary/description.
#[utoipa::path(
    post,
    path = "/v2/delete_custom_prompt",
    request_body = GetCustomPromptRequest,
    responses(
        (status = 200, description = "Successfully deleted custom prompt", body = CustomPrompt),
        (status = 500, description = "Internal server error", body = APIError)
    )
)]
pub async fn delete_custom_prompt_handler(
    sender: Sender<NodeCommand>,
    authorization: String,
    payload: GetCustomPromptRequest,
) -> Result<impl warp::Reply, warp::Rejection> {
    // A header without the "Bearer " prefix is forwarded as an empty token
    // instead of being rejected here.
    let token = match authorization.strip_prefix("Bearer ") {
        Some(rest) => rest.to_string(),
        None => String::new(),
    };

    // Single-slot channel on which the node sends back its answer.
    let (reply_tx, reply_rx) = async_channel::bounded(1);
    let command = NodeCommand::V2ApiDeleteCustomPrompt {
        bearer: token,
        prompt_name: payload.prompt_name,
        res: reply_tx,
    };

    // Failure to enqueue the command yields a plain `warp::reject::reject()`.
    if sender.send(command).await.is_err() {
        return Err(warp::reject::reject());
    }

    // Same plain rejection when the answer channel closes without a value.
    let outcome = match reply_rx.recv().await {
        Ok(outcome) => outcome,
        Err(_) => return Err(warp::reject::reject()),
    };

    outcome
        .map(|body| warp::reply::json(&body))
        .map_err(warp::reject::custom)
}

// Handler for `GET /v2/get_all_custom_prompts`.
//
// Issues a `NodeCommand::V2ApiGetAllCustomPrompts` to the node and replies
// with the node's answer serialized as JSON; an error answer becomes a
// custom rejection.
//
// Plain `//` comments are used on purpose: utoipa would turn `///` doc
// comments on this function into the OpenAPI summary/description.
#[utoipa::path(
    get,
    path = "/v2/get_all_custom_prompts",
    responses(
        (status = 200, description = "Successfully retrieved all custom prompts", body = Vec<CustomPrompt>),
        (status = 500, description = "Internal server error", body = APIError)
    )
)]
pub async fn get_all_custom_prompts_handler(
    sender: Sender<NodeCommand>,
    authorization: String,
) -> Result<impl warp::Reply, warp::Rejection> {
    // A header without the "Bearer " prefix is forwarded as an empty token
    // instead of being rejected here.
    let token = match authorization.strip_prefix("Bearer ") {
        Some(rest) => rest.to_string(),
        None => String::new(),
    };

    // Single-slot channel on which the node sends back its answer.
    let (reply_tx, reply_rx) = async_channel::bounded(1);
    let command = NodeCommand::V2ApiGetAllCustomPrompts {
        bearer: token,
        res: reply_tx,
    };

    // Failure to enqueue the command yields a plain `warp::reject::reject()`.
    if sender.send(command).await.is_err() {
        return Err(warp::reject::reject());
    }

    // Same plain rejection when the answer channel closes without a value.
    let outcome = match reply_rx.recv().await {
        Ok(outcome) => outcome,
        Err(_) => return Err(warp::reject::reject()),
    };

    outcome
        .map(|body| warp::reply::json(&body))
        .map_err(warp::reject::custom)
}

// Handler for `GET /v2/get_custom_prompt`.
//
// Sends the `prompt_name` query parameter to the node as a
// `NodeCommand::V2ApiGetCustomPrompt` and replies with the node's answer
// serialized as JSON; an error answer becomes a custom rejection.
//
// Plain `//` comments are used on purpose: utoipa would turn `///` doc
// comments on this function into the OpenAPI summary/description.
#[utoipa::path(
    get,
    path = "/v2/get_custom_prompt",
    params(
        ("prompt_name" = String, Query, description = "Name of the custom prompt to retrieve")
    ),
    responses(
        (status = 200, description = "Successfully retrieved custom prompt", body = CustomPrompt),
        (status = 500, description = "Internal server error", body = APIError)
    )
)]
pub async fn get_custom_prompt_handler(
    sender: Sender<NodeCommand>,
    authorization: String,
    query: GetCustomPromptRequest,
) -> Result<impl warp::Reply, warp::Rejection> {
    // A header without the "Bearer " prefix is forwarded as an empty token
    // instead of being rejected here.
    let token = match authorization.strip_prefix("Bearer ") {
        Some(rest) => rest.to_string(),
        None => String::new(),
    };

    // Single-slot channel on which the node sends back its answer.
    let (reply_tx, reply_rx) = async_channel::bounded(1);
    let command = NodeCommand::V2ApiGetCustomPrompt {
        bearer: token,
        prompt_name: query.prompt_name,
        res: reply_tx,
    };

    // Failure to enqueue the command yields a plain `warp::reject::reject()`.
    if sender.send(command).await.is_err() {
        return Err(warp::reject::reject());
    }

    // Same plain rejection when the answer channel closes without a value.
    let outcome = match reply_rx.recv().await {
        Ok(outcome) => outcome,
        Err(_) => return Err(warp::reject::reject()),
    };

    outcome
        .map(|body| warp::reply::json(&body))
        .map_err(warp::reject::custom)
}

// Handler for `GET /v2/search_custom_prompts`.
//
// Sends the `query` query parameter to the node as a
// `NodeCommand::V2ApiSearchCustomPrompts` and replies with the node's answer
// serialized as JSON; an error answer becomes a custom rejection.
//
// Plain `//` comments are used on purpose: utoipa would turn `///` doc
// comments on this function into the OpenAPI summary/description.
#[utoipa::path(
    get,
    path = "/v2/search_custom_prompts",
    params(
        ("query" = String, Query, description = "Search query for custom prompts")
    ),
    responses(
        (status = 200, description = "Successfully searched custom prompts", body = Vec<CustomPrompt>),
        (status = 500, description = "Internal server error", body = APIError)
    )
)]
pub async fn search_custom_prompts_handler(
    sender: Sender<NodeCommand>,
    authorization: String,
    query: SearchCustomPromptsRequest,
) -> Result<impl warp::Reply, warp::Rejection> {
    // A header without the "Bearer " prefix is forwarded as an empty token
    // instead of being rejected here.
    let token = match authorization.strip_prefix("Bearer ") {
        Some(rest) => rest.to_string(),
        None => String::new(),
    };

    // Single-slot channel on which the node sends back its answer.
    let (reply_tx, reply_rx) = async_channel::bounded(1);
    let command = NodeCommand::V2ApiSearchCustomPrompts {
        bearer: token,
        query: query.query,
        res: reply_tx,
    };

    // Failure to enqueue the command yields a plain `warp::reject::reject()`.
    if sender.send(command).await.is_err() {
        return Err(warp::reject::reject());
    }

    // Same plain rejection when the answer channel closes without a value.
    let outcome = match reply_rx.recv().await {
        Ok(outcome) => outcome,
        Err(_) => return Err(warp::reject::reject()),
    };

    outcome
        .map(|body| warp::reply::json(&body))
        .map_err(warp::reject::custom)
}

// Handler for `POST /v2/update_custom_prompt`.
//
// Forwards the submitted `CustomPrompt` to the node as a
// `NodeCommand::V2ApiUpdateCustomPrompt` and replies with the node's answer
// serialized as JSON; an error answer becomes a custom rejection.
//
// Plain `//` comments are used on purpose: utoipa would turn `///` doc
// comments on this function into the OpenAPI summary/description.
#[utoipa::path(
    post,
    path = "/v2/update_custom_prompt",
    request_body = CustomPrompt,
    responses(
        (status = 200, description = "Successfully updated custom prompt", body = CustomPrompt),
        (status = 500, description = "Internal server error", body = APIError)
    )
)]
pub async fn update_custom_prompt_handler(
    sender: Sender<NodeCommand>,
    authorization: String,
    payload: CustomPrompt,
) -> Result<impl warp::Reply, warp::Rejection> {
    // A header without the "Bearer " prefix is forwarded as an empty token
    // instead of being rejected here.
    let token = match authorization.strip_prefix("Bearer ") {
        Some(rest) => rest.to_string(),
        None => String::new(),
    };

    // Single-slot channel on which the node sends back its answer.
    let (reply_tx, reply_rx) = async_channel::bounded(1);
    let command = NodeCommand::V2ApiUpdateCustomPrompt {
        bearer: token,
        prompt: payload,
        res: reply_tx,
    };

    // Failure to enqueue the command yields a plain `warp::reject::reject()`.
    if sender.send(command).await.is_err() {
        return Err(warp::reject::reject());
    }

    // Same plain rejection when the answer channel closes without a value.
    let outcome = match reply_rx.recv().await {
        Ok(outcome) => outcome,
        Err(_) => return Err(warp::reject::reject()),
    };

    outcome
        .map(|body| warp::reply::json(&body))
        .map_err(warp::reject::custom)
}