Skip to main content

cake_core/cake/api/mod.rs

1mod image;
2pub mod text;
3mod ui;
4
5use std::sync::Arc;
6
7use actix_web::web;
8use actix_web::App;
9use actix_web::HttpResponse;
10use actix_web::HttpServer;
11use serde::Serialize;
12use tokio::sync::RwLock;
13
14use crate::models::{ImageGenerator, TextGenerator};
15
16use image::*;
17use text::*;
18
19use super::Master;
20
21#[derive(Serialize)]
22struct ModelObject {
23    id: String,
24    object: String,
25    owned_by: String,
26}
27
28#[derive(Serialize)]
29struct ModelsResponse {
30    object: String,
31    data: Vec<ModelObject>,
32}
33
34pub async fn list_models<TG, IG>(
35    _state: web::Data<Arc<RwLock<Master<TG, IG>>>>,
36) -> HttpResponse
37where
38    TG: TextGenerator + Send + Sync + 'static,
39    IG: ImageGenerator + Send + Sync + 'static,
40{
41    let response = ModelsResponse {
42        object: "list".to_string(),
43        data: vec![ModelObject {
44            id: TG::MODEL_NAME.to_string(),
45            object: "model".to_string(),
46            owned_by: "cake".to_string(),
47        }],
48    };
49    HttpResponse::Ok().json(response)
50}
51
52async fn not_found() -> actix_web::Result<HttpResponse> {
53    Ok(HttpResponse::NotFound().body("nope"))
54}
55
56pub(crate) async fn start<TG, IG>(master: Master<TG, IG>) -> anyhow::Result<()>
57where
58    TG: TextGenerator + Send + Sync + 'static,
59    IG: ImageGenerator + Send + Sync + 'static,
60{
61    let address = master.ctx.args.api.as_ref().unwrap().to_string();
62
63    log::info!("starting api on http://{} ...", &address);
64
65    let state = Arc::new(RwLock::new(master));
66
67    HttpServer::new(
68        move || {
69            App::new()
70                .app_data(web::Data::new(state.clone()))
71                .route(
72                    "/v1/chat/completions",
73                    web::post().to(generate_text::<TG, IG>),
74                )
75                .route(
76                    "/api/v1/chat/completions",
77                    web::post().to(generate_text::<TG, IG>),
78                )
79                .route("/v1/models", web::get().to(list_models::<TG, IG>))
80                .route("/api/v1/image", web::post().to(generate_image::<TG, IG>))
81                .route("/api/v1/topology", web::get().to(ui::topology::<TG, IG>))
82                .route("/", web::get().to(ui::index::<TG, IG>))
83                .default_service(web::route().to(not_found))
84        }, //.wrap(actix_web::middleware::Logger::default()))
85    )
86    .bind(&address)
87    .map_err(|e| anyhow!(e))?
88    .run()
89    .await
90    .map_err(|e| anyhow!(e))
91}