llmvm_outsource_lib/
lib.rs

//! [llmvm](https://github.com/djandries/llmvm) backend which forwards generation requests
//! to known hosted providers.
//!
//! Currently supported providers:
//! - OpenAI (text and chat interface)
//! - Hugging Face (text interface)
//! - Anthropic (chat interface)
//! - Ollama (text interface)
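//!
//! Example (a minimal sketch; the key value is a placeholder):
//!
//! ```no_run
//! use llmvm_outsource_lib::{OutsourceBackend, OutsourceConfig};
//!
//! let config = OutsourceConfig {
//!     openai_api_key: Some("sk-...".into()),
//!     huggingface_api_key: None,
//!     ollama_endpoint: None,
//!     openai_endpoint: None,
//!     anthropic_api_key: None,
//! };
//! let backend = OutsourceBackend::new(config);
//! ```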

mod anthropic;
mod huggingface;
mod ollama;
mod openai;
mod util;

use std::str::FromStr;

use futures::{stream::once, StreamExt};
use llmvm_protocol::{
    async_trait, error::ProtocolErrorType, Backend, BackendGenerationRequest,
    BackendGenerationResponse, ConfigExampleSnippet, ModelDescription, NotificationStream,
    ProtocolError,
};
use reqwest::StatusCode;
use serde::Deserialize;
use strum_macros::{Display, EnumString};
use thiserror::Error;
use util::{get_api_key, get_openai_api_key};

pub type Result<T> = std::result::Result<T, OutsourceError>;

/// Error enum containing all possible backend errors.
#[derive(Debug, Error)]
pub enum OutsourceError {
    #[error("provider for model not found, assumed provider name is '{0}'")]
    ProviderNotFound(String),
    #[error("api key not defined")]
    APIKeyNotDefined,
    #[error("could not parse api host as url")]
    HostURLParse,
    #[error("http request error: {0}")]
    HttpRequestError(#[from] reqwest::Error),
    #[error("bad http status code: {status} body: {body}")]
    BadHttpStatusCode { status: StatusCode, body: String },
    #[error("json serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    #[error("no text in response")]
    NoTextInResponse,
    #[error("failed to parse model name")]
    ModelDescriptionParse,
    #[error("model parameters should be object")]
    ModelParamsNotObject,
}

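/// Hosted providers supported by this backend.
///
/// Parsed from the provider segment of the requested model ID; parsing is
/// ASCII case-insensitive via `strum`, so e.g. both `"openai-chat"` and
/// `"OpenAI-Chat"` resolve to `Provider::OpenAIChat`.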
#[derive(Display, EnumString)]
#[strum(ascii_case_insensitive)]
enum Provider {
    #[strum(serialize = "openai-text")]
    OpenAIText,
    #[strum(serialize = "openai-chat")]
    OpenAIChat,
    #[strum(serialize = "huggingface-text")]
    HuggingFaceText,
    #[strum(serialize = "ollama-text")]
    OllamaText,
    #[strum(serialize = "anthropic-chat")]
    AnthropicChat,
}

impl From<OutsourceError> for ProtocolError {
    fn from(error: OutsourceError) -> Self {
        let error_type = match &error {
            OutsourceError::ProviderNotFound(_) => ProtocolErrorType::BadRequest,
            OutsourceError::APIKeyNotDefined => ProtocolErrorType::BadRequest,
            OutsourceError::HostURLParse => ProtocolErrorType::BadRequest,
            OutsourceError::HttpRequestError(_) => ProtocolErrorType::Internal,
            OutsourceError::BadHttpStatusCode { .. } => ProtocolErrorType::Internal,
            OutsourceError::Serialization(_) => ProtocolErrorType::Internal,
            OutsourceError::NoTextInResponse => ProtocolErrorType::Internal,
            OutsourceError::ModelDescriptionParse => ProtocolErrorType::BadRequest,
            OutsourceError::ModelParamsNotObject => ProtocolErrorType::BadRequest,
        };
        ProtocolError {
            error_type,
            error: Box::new(error),
        }
    }
}

/// Configuration structure for the backend.
#[derive(Deserialize)]
pub struct OutsourceConfig {
    pub openai_api_key: Option<String>,
    pub huggingface_api_key: Option<String>,
    pub ollama_endpoint: Option<String>,
    pub openai_endpoint: Option<String>,
    pub anthropic_api_key: Option<String>,
}

impl ConfigExampleSnippet for OutsourceConfig {
    fn config_example_snippet() -> String {
        r#"# API key for OpenAI
# openai_api_key = ""

# API key for Hugging Face
# huggingface_api_key = ""

# API key for Anthropic
# anthropic_api_key = ""

# Endpoint for Ollama (defaults to http://127.0.0.1:11434/api/generate)
# ollama_endpoint = ""

# Endpoint for OpenAI; only specify if using a custom OpenAI-compatible
# server (e.g. FastChat)
# openai_endpoint = """#
            .into()
    }
}

/// An llmvm backend that forwards requests to known hosted providers.
pub struct OutsourceBackend {
    config: OutsourceConfig,
}

impl OutsourceBackend {
    pub fn new(config: OutsourceConfig) -> Self {
        Self { config }
    }

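    /// Splits the requested model ID into a `ModelDescription` and the
    /// matching `Provider`. Assuming llmvm's `<backend>/<provider>/<model name>`
    /// model ID format, e.g. `outsource/openai-chat/gpt-3.5-turbo` would
    /// select `Provider::OpenAIChat`.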
    fn get_model_description_and_provider(
        request: &BackendGenerationRequest,
    ) -> Result<(ModelDescription, Provider)> {
        let model_description = ModelDescription::from_str(&request.model)
            .map_err(|_| OutsourceError::ModelDescriptionParse)?;
        let provider = Provider::try_from(model_description.provider.as_str()).map_err(|_| {
            OutsourceError::ProviderNotFound(model_description.provider.to_string())
        })?;
        Ok((model_description, provider))
    }
}

#[async_trait]
impl Backend for OutsourceBackend {
    async fn generate(
        &self,
        request: BackendGenerationRequest,
    ) -> std::result::Result<BackendGenerationResponse, ProtocolError> {
        async {
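            // Run the fallible body in an inner `async` block so `?` can
            // short-circuit with `OutsourceError`, converting to
            // `ProtocolError` once at the end.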
            let (model_description, provider) = Self::get_model_description_and_provider(&request)?;
            match provider {
                Provider::OpenAIText | Provider::OpenAIChat => {
                    let api_key = get_openai_api_key(
                        self.config.openai_api_key.as_deref(),
                        self.config.openai_endpoint.is_some(),
                        &model_description,
                    )?;

                    openai::generate(
                        request,
                        model_description,
                        self.config.openai_endpoint.as_deref(),
                        api_key,
                    )
                    .await
                }
                Provider::HuggingFaceText => {
                    huggingface::generate(
                        request,
                        model_description,
                        get_api_key(self.config.huggingface_api_key.as_deref())?,
                    )
                    .await
                }
                Provider::OllamaText => {
                    ollama::generate(
                        request,
                        model_description,
                        self.config.ollama_endpoint.as_ref(),
                    )
                    .await
                }
                Provider::AnthropicChat => {
                    anthropic::generate(
                        request,
                        model_description,
                        get_api_key(self.config.anthropic_api_key.as_deref())?,
                    )
                    .await
                }
            }
        }
        .await
        .map_err(|e| e.into())
    }

    async fn generate_stream(
        &self,
        request: BackendGenerationRequest,
    ) -> std::result::Result<NotificationStream<BackendGenerationResponse>, ProtocolError> {
        async {
            let (model_description, provider) = Self::get_model_description_and_provider(&request)?;
            match provider {
                Provider::OpenAIText | Provider::OpenAIChat => {
                    let api_key = get_openai_api_key(
                        self.config.openai_api_key.as_deref(),
                        self.config.openai_endpoint.is_some(),
                        &model_description,
                    )?;

                    openai::generate_stream(
                        request,
                        model_description,
                        self.config.openai_endpoint.as_deref(),
                        api_key,
                    )
                    .await
                }
                Provider::HuggingFaceText => {
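                    // No streaming interface is used for Hugging Face here:
                    // generate the full response, then wrap it in a
                    // single-item stream.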
                    let api_key =
                        get_api_key(self.config.huggingface_api_key.as_deref())?.to_string();
                    Ok(once(async move {
                        huggingface::generate(request, model_description, &api_key)
                            .await
                            .map_err(|e| e.into())
                    })
                    .boxed())
                }
                Provider::OllamaText => {
                    ollama::generate_stream(
                        request,
                        model_description,
                        self.config.ollama_endpoint.as_ref(),
                    )
                    .await
                }
                Provider::AnthropicChat => {
                    anthropic::generate_stream(
                        request,
                        model_description,
                        get_api_key(self.config.anthropic_api_key.as_deref())?,
                    )
                    .await
                }
            }
        }
        .await
        .map_err(|e| e.into())
    }
}
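
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal sanity checks (a sketch, not from the original source):
    // `EnumString` with `ascii_case_insensitive` should accept any casing
    // of the serialized provider names.
    #[test]
    fn provider_names_parse_case_insensitively() {
        assert!(matches!(
            Provider::try_from("openai-chat"),
            Ok(Provider::OpenAIChat)
        ));
        assert!(matches!(
            Provider::try_from("Anthropic-Chat"),
            Ok(Provider::AnthropicChat)
        ));
        // Unknown names fail to parse; `generate` surfaces this as
        // `OutsourceError::ProviderNotFound`.
        assert!(Provider::try_from("nonexistent-provider").is_err());
    }
}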