llmvm_backend_util/
lib.rs

1use std::error::Error;
2use std::process::exit;
3use std::sync::Arc;
4
5use clap::{Args, Subcommand};
6use llmvm_protocol::{
7    http::server::{HttpServer, HttpServerConfig},
8    service::BackendService,
9    stdio::server::{StdioServer, StdioServerConfig},
10    Backend, BackendGenerationRequest,
11};
12use tracing::error;
13
14#[derive(Args, Clone)]
15pub struct GenerateModelArgs {
16    #[arg(long)]
17    model: String,
18
19    #[arg(long)]
20    prompt: String,
21
22    #[arg(long)]
23    max_tokens: u64,
24}
25
26#[derive(Args, Clone)]
27pub struct HttpServerArgs {
28    #[arg(short, long)]
29    port: Option<u16>,
30}
31
32impl Into<BackendGenerationRequest> for GenerateModelArgs {
33    fn into(self) -> BackendGenerationRequest {
34        BackendGenerationRequest {
35            model: self.model,
36            prompt: self.prompt,
37            max_tokens: self.max_tokens,
38            ..Default::default()
39        }
40    }
41}
42
// Subcommands a backend binary accepts on its command line.
//
// `run_backend` receives this wrapped in an `Option`; when no subcommand is
// given it serves the backend over stdio instead.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macros
// turn `///` doc comments into help text, which would change the CLI output.
#[derive(Subcommand)]
pub enum BackendCommand {
    // Run one generation request and print the response text.
    Generate(GenerateModelArgs),
    // Serve the backend through the HTTP server.
    Http(HttpServerArgs),
}
48
49pub async fn run_backend<B: Backend + 'static>(
50    command: Option<BackendCommand>,
51    backend: Arc<B>,
52    stdio_config: Option<StdioServerConfig>,
53    http_config: Option<HttpServerConfig>,
54) -> Result<(), Box<dyn Error + Send + Sync>> {
55    // TODO: require a command to be specified, create --stdio switch
56    // TODO: show error if stdio or http server features are not enabled
57    match command {
58        Some(command) => match command {
59            BackendCommand::Generate(args) => {
60                let result = backend
61                    .generate(BackendGenerationRequest {
62                        model: args.model,
63                        prompt: args.prompt,
64                        max_tokens: args.max_tokens,
65                        thread_messages: None,
66                        model_parameters: None,
67                    })
68                    .await;
69                match result {
70                    Ok(response) => {
71                        println!("{}", response.response);
72                    }
73                    Err(e) => {
74                        error!("Failed to process request: {}", e);
75                        exit(1);
76                    }
77                };
78            }
79            BackendCommand::Http(args) => {
80                let mut config = http_config.unwrap_or_default();
81                if let Some(port) = args.port {
82                    config.port = port;
83                }
84                HttpServer::new(BackendService::new(backend), config)
85                    .run()
86                    .await?;
87            }
88        },
89        None => {
90            let config = stdio_config.unwrap_or_default();
91            StdioServer::new(BackendService::new(backend), config)
92                .run()
93                .await?;
94        }
95    };
96    Ok(())
97}