//! llmvm backend that outsources text generation to hosted provider
//! APIs (OpenAI, Hugging Face, Ollama, Anthropic).

mod anthropic;
mod huggingface;
mod ollama;
mod openai;
mod util;

use std::str::FromStr;

use futures::{stream::once, StreamExt};
use llmvm_protocol::{
    async_trait, error::ProtocolErrorType, Backend, BackendGenerationRequest,
    BackendGenerationResponse, ConfigExampleSnippet, ModelDescription, NotificationStream,
    ProtocolError,
};
use reqwest::StatusCode;
use serde::Deserialize;
use strum_macros::{Display, EnumString};
use thiserror::Error;
use util::{get_api_key, get_openai_api_key};

pub type Result<T> = std::result::Result<T, OutsourceError>;

/// Errors that can occur while forwarding generation requests to an
/// outsourced provider API.
#[derive(Debug, Error)]
pub enum OutsourceError {
    #[error("provider for model not found, assumed provider name is '{0}'")]
    ProviderNotFound(String),
    #[error("api key not defined")]
    APIKeyNotDefined,
    #[error("could not parse api host as url")]
    HostURLParse,
    #[error("http request error: {0}")]
    HttpRequestError(#[from] reqwest::Error),
    #[error("bad http status code: {status} body: {body}")]
    BadHttpStatusCode { status: StatusCode, body: String },
    #[error("json serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    #[error("no text in response")]
    NoTextInResponse,
    #[error("failed to parse model name")]
    ModelDescriptionParse,
    #[error("model parameters should be object")]
    ModelParamsNotObject,
}

/// Remote providers supported by this backend. The provider name is
/// parsed case-insensitively from the model description.
#[derive(Display, EnumString)]
#[strum(ascii_case_insensitive)]
enum Provider {
    #[strum(serialize = "openai-text")]
    OpenAIText,
    #[strum(serialize = "openai-chat")]
    OpenAIChat,
    #[strum(serialize = "huggingface-text")]
    HuggingFaceText,
    #[strum(serialize = "ollama-text")]
    OllamaText,
    #[strum(serialize = "anthropic-chat")]
    AnthropicChat,
}

/// Map each error variant to a protocol-level error type so clients can
/// distinguish bad requests from internal failures.
impl From<OutsourceError> for ProtocolError {
    fn from(error: OutsourceError) -> Self {
        let error_type = match &error {
            OutsourceError::ProviderNotFound(_) => ProtocolErrorType::BadRequest,
            OutsourceError::APIKeyNotDefined => ProtocolErrorType::BadRequest,
            OutsourceError::HostURLParse => ProtocolErrorType::BadRequest,
            OutsourceError::HttpRequestError(_) => ProtocolErrorType::Internal,
            OutsourceError::BadHttpStatusCode { .. } => ProtocolErrorType::Internal,
            OutsourceError::Serialization(_) => ProtocolErrorType::Internal,
            OutsourceError::NoTextInResponse => ProtocolErrorType::Internal,
            OutsourceError::ModelDescriptionParse => ProtocolErrorType::BadRequest,
            OutsourceError::ModelParamsNotObject => ProtocolErrorType::BadRequest,
        };
        ProtocolError {
            error_type,
            error: Box::new(error),
        }
    }
}

/// Configuration for the outsource backend. All fields are optional; a
/// key or endpoint is only required when its provider is actually used.
#[derive(Deserialize)]
pub struct OutsourceConfig {
    pub openai_api_key: Option<String>,
    pub huggingface_api_key: Option<String>,
    pub ollama_endpoint: Option<String>,
    pub openai_endpoint: Option<String>,
    pub anthropic_api_key: Option<String>,
}

impl ConfigExampleSnippet for OutsourceConfig {
    fn config_example_snippet() -> String {
        r#"# API key for OpenAI
# openai_api_key = ""

# API key for Hugging Face
# huggingface_api_key = ""

# API key for Anthropic
# anthropic_api_key = ""

# Endpoint for ollama (defaults to http://127.0.0.1:11434/api/generate)
# ollama_endpoint = ""

# Endpoint for OpenAI; only specify if using a custom OpenAI-compatible
# server (e.g. FastChat)
# openai_endpoint = """#
        .into()
    }
}

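/// [`Backend`] implementation that forwards generation requests to the
/// hosted providers listed in [`Provider`].
///
/// A minimal construction sketch; the key value is a placeholder and the
/// crate path in the `use` line is assumed, so the example is not compiled:
///
/// ```ignore
/// use llmvm_outsource_lib::{OutsourceBackend, OutsourceConfig};
///
/// let backend = OutsourceBackend::new(OutsourceConfig {
///     openai_api_key: Some("sk-placeholder".to_string()),
///     huggingface_api_key: None,
///     ollama_endpoint: None,
///     openai_endpoint: None,
///     anthropic_api_key: None,
/// });
/// ```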
pub struct OutsourceBackend {
    config: OutsourceConfig,
}

impl OutsourceBackend {
    pub fn new(config: OutsourceConfig) -> Self {
        Self { config }
    }

    /// Parse the request's model string into a [`ModelDescription`] and
    /// resolve the [`Provider`] named by its provider component.
    fn get_model_description_and_provider(
        request: &BackendGenerationRequest,
    ) -> Result<(ModelDescription, Provider)> {
        let model_description = ModelDescription::from_str(&request.model)
            .map_err(|_| OutsourceError::ModelDescriptionParse)?;
        let provider = Provider::try_from(model_description.provider.as_str()).map_err(|_| {
            OutsourceError::ProviderNotFound(model_description.provider.to_string())
        })?;
        Ok((model_description, provider))
    }
}

#[async_trait]
impl Backend for OutsourceBackend {
    async fn generate(
        &self,
        request: BackendGenerationRequest,
    ) -> std::result::Result<BackendGenerationResponse, ProtocolError> {
        async {
            // Resolve the provider from the model string, then dispatch.
            let (model_description, provider) =
                Self::get_model_description_and_provider(&request)?;
            match provider {
                Provider::OpenAIText | Provider::OpenAIChat => {
                    let api_key = get_openai_api_key(
                        self.config.openai_api_key.as_deref(),
                        self.config.openai_endpoint.is_some(),
                        &model_description,
                    )?;

                    openai::generate(
                        request,
                        model_description,
                        self.config.openai_endpoint.as_deref(),
                        api_key,
                    )
                    .await
                }
                Provider::HuggingFaceText => {
                    huggingface::generate(
                        request,
                        model_description,
                        get_api_key(self.config.huggingface_api_key.as_deref())?,
                    )
                    .await
                }
                Provider::OllamaText => {
                    ollama::generate(
                        request,
                        model_description,
                        self.config.ollama_endpoint.as_ref(),
                    )
                    .await
                }
                Provider::AnthropicChat => {
                    anthropic::generate(
                        request,
                        model_description,
                        get_api_key(self.config.anthropic_api_key.as_deref())?,
                    )
                    .await
                }
            }
        }
        .await
        .map_err(|e| e.into())
    }

    async fn generate_stream(
        &self,
        request: BackendGenerationRequest,
    ) -> std::result::Result<NotificationStream<BackendGenerationResponse>, ProtocolError> {
        async {
            let (model_description, provider) =
                Self::get_model_description_and_provider(&request)?;
            match provider {
                Provider::OpenAIText | Provider::OpenAIChat => {
                    let api_key = get_openai_api_key(
                        self.config.openai_api_key.as_deref(),
                        self.config.openai_endpoint.is_some(),
                        &model_description,
                    )?;

                    openai::generate_stream(
                        request,
                        model_description,
                        self.config.openai_endpoint.as_deref(),
                        api_key,
                    )
                    .await
                }
                Provider::HuggingFaceText => {
                    // Hugging Face has no streaming path here; wrap the
                    // one-shot response in a single-item stream.
                    let api_key =
                        get_api_key(self.config.huggingface_api_key.as_deref())?.to_string();
                    Ok(once(async move {
                        huggingface::generate(request, model_description, &api_key)
                            .await
                            .map_err(|e| e.into())
                    })
                    .boxed())
                }
                Provider::OllamaText => {
                    ollama::generate_stream(
                        request,
                        model_description,
                        self.config.ollama_endpoint.as_ref(),
                    )
                    .await
                }
                Provider::AnthropicChat => {
                    anthropic::generate_stream(
                        request,
                        model_description,
                        get_api_key(self.config.anthropic_api_key.as_deref())?,
                    )
                    .await
                }
            }
        }
        .await
        .map_err(|e| e.into())
    }
}
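
// The sketch below illustrates the parsing and conversion behavior defined
// above: strum's case-insensitive `TryFrom<&str>` for `Provider`, the
// `OutsourceError` -> `ProtocolError` mapping, and backend construction.
// It assumes `ProtocolError::error_type` is a public field (as the struct
// literal above implies); no provider is actually contacted.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn provider_parses_case_insensitively() {
        // `#[strum(ascii_case_insensitive)]` lets any casing of the
        // serialized name resolve to the same variant.
        assert!(matches!(
            Provider::try_from("openai-chat"),
            Ok(Provider::OpenAIChat)
        ));
        assert!(matches!(
            Provider::try_from("OpenAI-Chat"),
            Ok(Provider::OpenAIChat)
        ));
        assert!(Provider::try_from("unknown-provider").is_err());
    }

    #[test]
    fn error_maps_to_protocol_error_type() {
        // Client-side mistakes surface as BadRequest...
        let err: ProtocolError = OutsourceError::APIKeyNotDefined.into();
        assert!(matches!(err.error_type, ProtocolErrorType::BadRequest));
        // ...while provider/transport failures surface as Internal.
        let err: ProtocolError = OutsourceError::NoTextInResponse.into();
        assert!(matches!(err.error_type, ProtocolErrorType::Internal));
    }

    #[test]
    fn backend_constructs_from_config() {
        // Construction alone performs no network activity.
        let _backend = OutsourceBackend::new(OutsourceConfig {
            openai_api_key: Some("sk-placeholder".to_string()),
            huggingface_api_key: None,
            ollama_endpoint: None,
            openai_endpoint: None,
            anthropic_api_key: None,
        });
    }
}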