use crate::embeddings::EmbeddingsBuilder;
use crate::engine::{build_backend, EngineBackend, TokenizerAdapter};
use crate::generation::{GenerationBuilder, GenerationConfig};
use crate::model::{ModelArtifacts, ModelManager};
use crate::rerank::RerankBuilder;
use crate::types::{ClientConfig, Result};
use log::warn;
/// Handle to a prepared model: the engine backend built from the model's
/// artifacts, plus the tokenizer the request builders share with it.
///
/// Constructed through `Client::create` (directly or via the `new` /
/// `with_config` constructors). The request builders returned by `embeddings`,
/// `rerank` and `generate` borrow `engine` and `tokenizer` from this struct.
pub struct Client {
    // Backend used by every request builder. Fields are dropped in declaration
    // order, so this is released before `tokenizer` and `artifacts`.
    engine: EngineBackend,
    // Cloned out of `artifacts` at construction time so builders can borrow it directly.
    tokenizer: TokenizerAdapter,
    // The prepared artifacts are kept alive for the client's lifetime but are not
    // read by any method in this file, hence the dead_code allowance.
    // NOTE(review): presumably retained so on-disk/model state outlives the engine — confirm.
    #[allow(dead_code)]
    artifacts: ModelArtifacts,
}
impl Client {
    /// Prepares `model` via [`ModelManager`], builds the engine backend on
    /// `config.device`, and bundles both into a `Client`.
    ///
    /// Model-file validation is best-effort: a validation failure is logged as a
    /// warning and construction continues.
    ///
    /// # Errors
    /// Returns an error if model preparation or backend construction fails.
    pub(crate) fn create(model: &str, config: ClientConfig) -> Result<Self> {
        // The manager takes ownership of its config; we still need `config.device` below.
        let manager = ModelManager::new(config.clone());
        let artifacts = manager.prepare(model)?;

        // Validate *before* building the backend. If the files are broken, the
        // backend build is the step most likely to fail, and running validation
        // afterwards meant its diagnostic was never emitted in exactly that case.
        if let Err(err) = manager.validate_model_files(&artifacts.model_dir) {
            warn!("Skipping model file validation: {err}");
        }

        let engine = build_backend(&artifacts.info, &artifacts.model_dir, &config.device)?;

        Ok(Self {
            engine,
            tokenizer: artifacts.tokenizer.clone(),
            artifacts,
        })
    }

    /// Starts an embeddings request over `input`.
    ///
    /// Every item is copied into an owned `String`. The returned builder borrows
    /// this client's engine and tokenizer and starts with no graph inputs.
    #[must_use]
    pub fn embeddings<I, S>(&self, input: I) -> EmbeddingsBuilder<'_>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let inputs = input.into_iter().map(|s| s.as_ref().to_string()).collect();
        EmbeddingsBuilder {
            engine: &self.engine,
            tokenizer: &self.tokenizer,
            inputs,
            graph_inputs: Vec::new(),
        }
    }

    /// Starts a rerank request scoring `documents` against `query`.
    ///
    /// Query and documents are copied into owned strings. The builder starts with
    /// `top_n` unset and `return_documents` disabled.
    #[must_use]
    pub fn rerank<I, S>(&self, query: &str, documents: I) -> RerankBuilder<'_>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let docs = documents
            .into_iter()
            .map(|s| s.as_ref().to_string())
            .collect();
        RerankBuilder {
            engine: &self.engine,
            tokenizer: &self.tokenizer,
            query: query.to_string(),
            documents: docs,
            top_n: None,
            return_documents: false,
        }
    }

    /// Starts a text-generation request for `prompt` using
    /// `GenerationConfig::default()` as its initial configuration.
    #[must_use]
    pub fn generate(&self, prompt: &str) -> GenerationBuilder<'_> {
        GenerationBuilder {
            engine: &self.engine,
            tokenizer: &self.tokenizer,
            prompt: prompt.to_string(),
            config: GenerationConfig::default(),
        }
    }
}
impl Client {
    /// Records a cleanup request for this client.
    ///
    /// The engine, tokenizer and artifacts are owned fields and are dropped
    /// together with the `Client`, so no explicit release step is performed here.
    /// The `Drop` impl calls this method, and calling it manually beforehand is
    /// harmless.
    ///
    /// This used to loop three "attempts" that did nothing but sleep 50 ms each.
    /// That stalled whichever thread dropped a `Client` for 150 ms, possibly an
    /// async runtime worker, without releasing anything: fields are only dropped
    /// after `Drop::drop` returns.
    pub fn cleanup(&mut self) {
        log::debug!("Client GPU cleanup requested");
        log::debug!("Client cleanup completed");
    }
}
impl Drop for Client {
fn drop(&mut self) {
self.cleanup();
}
}
#[cfg(not(feature = "tokio"))]
impl Client {
pub fn new(model: &str) -> Result<Self> {
Self::with_config(model, ClientConfig::default())
}
pub fn with_config(model: &str, config: ClientConfig) -> Result<Self> {
Self::create(model, config)
}
}
#[cfg(feature = "tokio")]
impl Client {
    /// Asynchronously loads `model` using `ClientConfig::default()`.
    ///
    /// # Errors
    /// Propagates any failure from model preparation or backend construction.
    pub async fn new(model: &str) -> Result<Self> {
        Self::with_config(model, ClientConfig::default()).await
    }

    /// Asynchronously loads `model` with the supplied configuration.
    ///
    /// Model preparation and backend construction are blocking operations.
    ///
    /// - On a multi-threaded Tokio runtime they run inside
    ///   `tokio::task::block_in_place`, so the runtime can move other tasks off
    ///   this worker in the meantime.
    /// - `block_in_place` panics on a current-thread runtime (e.g. the default
    ///   `#[tokio::test]` runtime). There, and when no runtime is active, the work
    ///   runs inline instead. That still blocks the calling thread but no longer
    ///   panics.
    ///
    /// # Errors
    /// Propagates any failure from model preparation or backend construction.
    pub async fn with_config(model: &str, config: ClientConfig) -> Result<Self> {
        let on_multi_thread_runtime = tokio::runtime::Handle::try_current()
            .map(|handle| {
                matches!(
                    handle.runtime_flavor(),
                    tokio::runtime::RuntimeFlavor::MultiThread
                )
            })
            .unwrap_or(false);

        if on_multi_thread_runtime {
            tokio::task::block_in_place(|| Self::create(model, config))
        } else {
            Self::create(model, config)
        }
    }
}