use crate::config::ForgeConfig;
use crate::error::Result;
use crate::transport;
use crate::types::{
ChatCompletion, ChatCompletionChunk, ChatCompletionRequest, EmbeddingRequest,
EmbeddingResponse, Message, ModelList,
};
use futures::Stream;
use reqwest::Method;
use std::pin::Pin;
/// Blocking (synchronous) client.
///
/// Wraps an [`AsyncForgeClient`] together with a dedicated Tokio runtime; each
/// blocking method drives the matching async method to completion on that
/// runtime.
#[derive(Debug)]
pub struct ForgeClient {
/// Async client that every blocking method delegates to.
inner: AsyncForgeClient,
/// Tokio runtime used to `block_on` the futures produced by `inner`.
runtime: tokio::runtime::Runtime,
}
impl ForgeClient {
    /// Creates a blocking client from the configuration returned by
    /// `ForgeConfig::from_env()`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`ForgeClient::with_config`].
    pub fn new() -> Self {
        let config = ForgeConfig::from_env();
        Self::with_config(config)
    }

    /// Creates a blocking client from an explicit configuration.
    ///
    /// The async client is constructed first, then a fresh Tokio runtime.
    ///
    /// # Panics
    ///
    /// Panics if the Tokio runtime cannot be created, or if
    /// [`AsyncForgeClient::with_config`] panics while building the HTTP client.
    pub fn with_config(config: ForgeConfig) -> Self {
        // Struct fields are evaluated in the order written: async client first,
        // runtime second.
        Self {
            inner: AsyncForgeClient::with_config(config),
            runtime: tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime"),
        }
    }

    /// Returns a default-initialised [`ForgeClientBuilder`].
    pub fn builder() -> ForgeClientBuilder {
        Default::default()
    }

    /// Returns the default model name of the wrapped async client.
    pub fn model(&self) -> &str {
        AsyncForgeClient::model(&self.inner)
    }

    /// Returns the base URL of the wrapped async client.
    pub fn base_url(&self) -> &str {
        AsyncForgeClient::base_url(&self.inner)
    }

    /// Blocking counterpart of [`AsyncForgeClient::complete`].
    ///
    /// Blocks the calling thread until the underlying future resolves.
    ///
    /// # Panics
    ///
    /// Per Tokio's documentation, `Runtime::block_on` panics when called from
    /// within an asynchronous execution context; this applies to every blocking
    /// method on this type.
    pub fn complete(&self, messages: Vec<Message>) -> Result<ChatCompletion> {
        let pending = self.inner.complete(messages);
        self.runtime.block_on(pending)
    }

    /// Blocking counterpart of [`AsyncForgeClient::complete_with_model`].
    pub fn complete_with_model(
        &self,
        model: &str,
        messages: Vec<Message>,
    ) -> Result<ChatCompletion> {
        let pending = self.inner.complete_with_model(model, messages);
        self.runtime.block_on(pending)
    }

    /// Blocking counterpart of [`AsyncForgeClient::chat_completions`].
    pub fn chat_completions(&self, request: ChatCompletionRequest) -> Result<ChatCompletion> {
        let pending = self.inner.chat_completions(request);
        self.runtime.block_on(pending)
    }

    /// Blocking counterpart of [`AsyncForgeClient::list_models`].
    pub fn list_models(&self) -> Result<ModelList> {
        let pending = self.inner.list_models();
        self.runtime.block_on(pending)
    }

    /// Blocking counterpart of [`AsyncForgeClient::embed`].
    pub fn embed(&self, text: impl Into<String>) -> Result<EmbeddingResponse> {
        let pending = self.inner.embed(text);
        self.runtime.block_on(pending)
    }

    /// Blocking counterpart of [`AsyncForgeClient::embed_batch`].
    pub fn embed_batch(&self, texts: Vec<String>) -> Result<EmbeddingResponse> {
        let pending = self.inner.embed_batch(texts);
        self.runtime.block_on(pending)
    }

    /// Blocking counterpart of [`AsyncForgeClient::embeddings`].
    pub fn embeddings(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        let pending = self.inner.embeddings(request);
        self.runtime.block_on(pending)
    }
}
impl Default for ForgeClient {
fn default() -> Self {
Self::new()
}
}
/// Asynchronous client.
///
/// Holds the configuration and the HTTP client that are handed to the
/// `transport` functions on every request.
#[derive(Debug, Clone)]
pub struct AsyncForgeClient {
/// Configuration; supplies the default model, base URL and timeout.
config: ForgeConfig,
/// HTTP client built once in `with_config` and reused for all requests.
http: reqwest::Client,
}
impl AsyncForgeClient {
pub fn new() -> Self {
Self::with_config(ForgeConfig::from_env())
}
pub fn with_config(config: ForgeConfig) -> Self {
let mut builder = reqwest::Client::builder().timeout(config.timeout);
if let Ok(path) = std::env::var("LITEFORGE_EXTRA_CA_FILE") {
match std::fs::read(&path) {
Ok(pem) => match reqwest::Certificate::from_pem_bundle(&pem) {
Ok(certs) => {
for c in certs {
builder = builder.add_root_certificate(c);
}
}
Err(e) => {
tracing::warn!(
"LITEFORGE_EXTRA_CA_FILE={} could not be parsed as a PEM bundle: {}",
path,
e
);
}
},
Err(e) => {
tracing::warn!(
"LITEFORGE_EXTRA_CA_FILE={} could not be read: {}",
path,
e
);
}
}
}
let http = builder.build().expect("Failed to build HTTP client");
Self { config, http }
}
pub fn model(&self) -> &str {
&self.config.default_model
}
pub fn base_url(&self) -> &str {
&self.config.base_url
}
pub async fn complete(&self, messages: Vec<Message>) -> Result<ChatCompletion> {
self.complete_with_model(&self.config.default_model, messages)
.await
}
pub async fn complete_with_model(
&self,
model: &str,
messages: Vec<Message>,
) -> Result<ChatCompletion> {
let request = ChatCompletionRequest::new(model, messages);
self.chat_completions(request).await
}
pub async fn chat_completions(&self, request: ChatCompletionRequest) -> Result<ChatCompletion> {
transport::request_with_body(
&self.http,
&self.config,
Method::POST,
"/chat/completions",
&request,
)
.await
}
pub async fn chat_completions_stream(
&self,
request: ChatCompletionRequest,
) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>> {
let mut request = request;
request.stream = Some(true);
transport::request_stream(&self.http, &self.config, "/chat/completions", &request).await
}
pub async fn complete_stream(
&self,
messages: Vec<Message>,
) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>> {
let request = ChatCompletionRequest::new(&self.config.default_model, messages);
self.chat_completions_stream(request).await
}
pub async fn list_models(&self) -> Result<ModelList> {
transport::request_no_body(&self.http, &self.config, Method::GET, "/models").await
}
pub async fn embed(&self, text: impl Into<String>) -> Result<EmbeddingResponse> {
let request = EmbeddingRequest::new(&self.config.default_model, text);
self.embeddings(request).await
}
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<EmbeddingResponse> {
let request = EmbeddingRequest::batch(&self.config.default_model, texts);
self.embeddings(request).await
}
pub async fn embeddings(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
transport::request_with_body(
&self.http,
&self.config,
Method::POST,
"/embeddings",
&request,
)
.await
}
pub async fn post<T, R>(&self, path: &str, body: &T) -> Result<R>
where
T: serde::Serialize,
R: serde::de::DeserializeOwned,
{
let path = if path.starts_with('/') {
path.to_string()
} else {
format!("/{}", path)
};
transport::request_with_body(&self.http, &self.config, Method::POST, &path, body).await
}
}
impl Default for AsyncForgeClient {
fn default() -> Self {
Self::new()
}
}
/// Builder for [`ForgeClient`] and [`AsyncForgeClient`].
///
/// Each setter forwards to the wrapped `ForgeConfigBuilder`; `build` and
/// `build_async` turn the accumulated configuration into a client.
#[derive(Debug, Default)]
pub struct ForgeClientBuilder {
/// Configuration builder that accumulates the settings chosen so far.
config_builder: crate::config::ForgeConfigBuilder,
}
impl ForgeClientBuilder {
    /// Sets the API key by forwarding to `ForgeConfigBuilder::api_key`.
    pub fn api_key(self, api_key: impl Into<String>) -> Self {
        Self {
            config_builder: self.config_builder.api_key(api_key),
        }
    }

    /// Sets the default model by forwarding to
    /// `ForgeConfigBuilder::default_model`.
    pub fn default_model(self, model: impl Into<String>) -> Self {
        Self {
            config_builder: self.config_builder.default_model(model),
        }
    }

    /// Sets the base URL by forwarding to `ForgeConfigBuilder::base_url`.
    pub fn base_url(self, url: impl Into<String>) -> Self {
        Self {
            config_builder: self.config_builder.base_url(url),
        }
    }

    /// Sets the timeout in seconds by forwarding to
    /// `ForgeConfigBuilder::timeout_secs`.
    pub fn timeout_secs(self, secs: u64) -> Self {
        Self {
            config_builder: self.config_builder.timeout_secs(secs),
        }
    }

    /// Finishes the configuration and creates a blocking [`ForgeClient`].
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`ForgeClient::with_config`].
    pub fn build(self) -> ForgeClient {
        let config = self.config_builder.build();
        ForgeClient::with_config(config)
    }

    /// Finishes the configuration and creates an [`AsyncForgeClient`].
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`AsyncForgeClient::with_config`].
    pub fn build_async(self) -> AsyncForgeClient {
        let config = self.config_builder.build();
        AsyncForgeClient::with_config(config)
    }
}