pub mod completion;
pub mod embedding;
pub mod error;
pub mod image;
pub mod types;
use crate::core::types::{GenerateOptions, GenerateResult, Prompt, StreamPart};
use crate::openai::OpenAIModel;
use async_trait::async_trait;
use futures::stream::BoxStream;
use reqwest::Client;
/// Language model type exposed by this module.
///
/// It is a thin wrapper: its only state is an [`OpenAIModel`], which carries
/// the API key, the base URL and the HTTP client.
pub struct OpenAICompatibleModel {
/// The wrapped [`OpenAIModel`]; public, so callers can reach it directly.
pub inner: OpenAIModel,
}
impl OpenAICompatibleModel {
    /// Builds a model from an API key and a base URL.
    ///
    /// Both strings are moved into the wrapped [`OpenAIModel`] unchanged,
    /// together with a freshly created [`reqwest::Client`].
    ///
    /// # Panics
    ///
    /// `reqwest::Client::new` is documented to panic when the HTTP client
    /// cannot be initialised (for example, when the TLS backend fails to load).
    #[must_use]
    pub fn new(api_key: String, base_url: String) -> Self {
        let client = Client::new();
        let inner = OpenAIModel {
            api_key,
            base_url,
            client,
        };
        Self { inner }
    }
}
#[async_trait]
impl crate::core::LanguageModel for OpenAICompatibleModel {
    /// Runs a generation request by forwarding it to the wrapped model.
    ///
    /// The call is traced: the span carries a `model` field taken from
    /// `options.model_id`, while `self` and `prompt` are skipped.
    ///
    /// # Errors
    ///
    /// Returns whatever error the wrapped model's `generate` reports.
    #[tracing::instrument(skip(self, prompt), fields(model = options.model_id))]
    async fn generate(
        &self,
        prompt: Prompt,
        options: GenerateOptions,
    ) -> crate::core::Result<GenerateResult> {
        let result = self.inner.generate(prompt, options).await?;
        Ok(result)
    }

    /// Starts a streaming generation request by forwarding it to the wrapped model.
    ///
    /// On success the returned stream yields [`StreamPart`] items.
    ///
    /// # Errors
    ///
    /// Returns whatever error the wrapped model's `generate_stream` reports.
    async fn generate_stream(
        &self,
        prompt: Prompt,
        options: GenerateOptions,
    ) -> crate::core::Result<BoxStream<'static, StreamPart>> {
        let stream = self.inner.generate_stream(prompt, options).await?;
        Ok(stream)
    }
}
/// Configuration from which an OpenAI-compatible provider is built.
#[derive(Clone, Debug)]
pub struct OpenAICompatibleProviderSettings {
    /// Base URL handed to the models the provider constructs.
    pub base_url: String,
    /// Name of the provider.
    ///
    /// NOTE(review): no reader of this field is visible next to this
    /// definition — confirm where it is consumed.
    pub name: String,
    /// Optional API key. The provider's model constructors substitute an
    /// empty string when this is `None`.
    pub api_key: Option<String>,
    /// Optional string-to-string map, presumably extra HTTP headers.
    ///
    /// NOTE(review): no reader of this field is visible next to this
    /// definition — confirm that it is actually applied to requests.
    pub headers: Option<std::collections::HashMap<String, String>>,
}
/// Provider that builds chat, embedding, image and completion models from
/// one shared set of settings.
pub struct OpenAICompatibleProvider {
/// Settings the provider was created with; private to this module.
settings: OpenAICompatibleProviderSettings,
}
impl OpenAICompatibleProvider {
    /// Clones the two values every model constructor below needs.
    ///
    /// Returns `(api_key, base_url)`; a missing API key becomes an empty string.
    fn connection_parts(&self) -> (String, String) {
        let settings = &self.settings;
        let api_key = settings.api_key.clone().unwrap_or_default();
        (api_key, settings.base_url.clone())
    }

    /// Creates a chat model from the provider's API key and base URL.
    ///
    /// `_model_id` is currently not used by this method.
    #[must_use]
    pub fn chat(&self, _model_id: &str) -> OpenAICompatibleModel {
        let (api_key, base_url) = self.connection_parts();
        OpenAICompatibleModel::new(api_key, base_url)
    }

    /// Creates a language model; this simply forwards to [`Self::chat`].
    #[must_use]
    pub fn language_model(&self, model_id: &str) -> OpenAICompatibleModel {
        self.chat(model_id)
    }

    /// Creates an embedding model from the provider's API key and base URL.
    ///
    /// `_model_id` is currently not used by this method.
    #[must_use]
    pub fn embedding(&self, _model_id: &str) -> embedding::OpenAICompatibleEmbeddingModel {
        let (api_key, base_url) = self.connection_parts();
        embedding::OpenAICompatibleEmbeddingModel::new(api_key, base_url)
    }

    /// Creates an image model from the provider's API key and base URL.
    ///
    /// `_model_id` is currently not used by this method.
    #[must_use]
    pub fn image(&self, _model_id: &str) -> image::OpenAICompatibleImageModel {
        let (api_key, base_url) = self.connection_parts();
        image::OpenAICompatibleImageModel::new(api_key, base_url)
    }

    /// Creates a completion model from the provider's API key and base URL.
    ///
    /// `_model_id` is currently not used by this method.
    #[must_use]
    pub fn completion(&self, _model_id: &str) -> completion::OpenAICompatibleCompletionModel {
        let (api_key, base_url) = self.connection_parts();
        completion::OpenAICompatibleCompletionModel::new(api_key, base_url)
    }
}
/// Creates an [`OpenAICompatibleProvider`] that owns the given `settings`.
///
/// The settings are stored as passed in; this function performs no
/// validation and no I/O.
#[must_use]
pub fn create_openai_compatible(
settings: OpenAICompatibleProviderSettings,
) -> OpenAICompatibleProvider {
OpenAICompatibleProvider { settings }
}