use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;
use rig::client::CompletionClient;
use rig::completion::{
CompletionError, CompletionModel, CompletionRequest, CompletionResponse, Usage,
};
use rig::streaming::StreamingCompletionResponse;
use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::UnboundedReceiverStream;
use crate::error::LoadError;
use crate::request::prepare_request;
use crate::types::{
CheckpointParams, FitParams, InferenceCommand, InferenceParams, InferenceRequest,
KvCacheParams, RawResponse, ReloadRequest, ResponseChannel, SamplingParams, StreamChunk,
};
use crate::worker::{WorkerInit, inference_worker};
/// Default context window size (in tokens) used when the builder is not given one.
const DEFAULT_N_CTX: u32 = 4096;
/// Capacity of the bounded client-to-worker command channel.
const COMMAND_CHANNEL_CAPACITY: usize = 8;
/// Builder for [`Client`]; configure optional parameters, then call `build`.
#[must_use]
pub struct ClientBuilder {
    /// Path to the GGUF model file to load.
    model_path: String,
    /// Optional multimodal projector path (only with the `mtmd` feature).
    #[cfg(feature = "mtmd")]
    mmproj_path: Option<String>,
    /// Context window size in tokens; defaults to `DEFAULT_N_CTX`.
    n_ctx: u32,
    /// Sampling defaults applied to requests that leave them unset.
    sampling: SamplingParams,
    /// Model-fitting parameters passed through to the worker.
    fit: FitParams,
    /// KV-cache configuration passed through to the worker.
    kv_cache: KvCacheParams,
    /// Checkpoint configuration passed through to the worker.
    checkpoint: CheckpointParams,
}
impl ClientBuilder {
    /// Starts a builder for the model at `model_path` with default parameters.
    fn new(model_path: impl Into<String>) -> Self {
        Self {
            model_path: model_path.into(),
            #[cfg(feature = "mtmd")]
            mmproj_path: None,
            n_ctx: DEFAULT_N_CTX,
            sampling: SamplingParams::default(),
            fit: FitParams::default(),
            kv_cache: KvCacheParams::default(),
            checkpoint: CheckpointParams::default(),
        }
    }

    /// Sets the context window size in tokens.
    pub fn n_ctx(self, n_ctx: u32) -> Self {
        Self { n_ctx, ..self }
    }

    /// Sets the sampling parameters used as request defaults.
    pub fn sampling(self, sampling: SamplingParams) -> Self {
        Self { sampling, ..self }
    }

    /// Sets the model-fitting parameters.
    pub fn fit(self, fit: FitParams) -> Self {
        Self { fit, ..self }
    }

    /// Sets the KV-cache parameters.
    pub fn kv_cache(self, kv_cache: KvCacheParams) -> Self {
        Self { kv_cache, ..self }
    }

    /// Sets the checkpoint parameters.
    pub fn checkpoints(self, checkpoint: CheckpointParams) -> Self {
        Self { checkpoint, ..self }
    }

    /// Sets the multimodal projector path. Only available with the `mtmd` feature.
    #[cfg(feature = "mtmd")]
    pub fn mmproj(self, mmproj_path: impl Into<String>) -> Self {
        Self {
            mmproj_path: Some(mmproj_path.into()),
            ..self
        }
    }

    /// Spawns the worker thread and loads the model, returning a ready [`Client`].
    ///
    /// # Errors
    /// Returns a [`LoadError`] if the worker fails to initialize the model.
    pub fn build(self) -> Result<Client, LoadError> {
        // Without the `mtmd` feature there is no projector field; normalize to None.
        #[cfg(not(feature = "mtmd"))]
        let mmproj_path: Option<String> = None;
        #[cfg(feature = "mtmd")]
        let mmproj_path = self.mmproj_path;
        Client::spawn(
            self.model_path,
            mmproj_path,
            self.n_ctx,
            self.sampling,
            self.fit,
            self.kv_cache,
            self.checkpoint,
        )
    }
}
/// Handle to a model running on a dedicated inference worker thread.
///
/// Dropping the client sets the cancel flag, sends a best-effort shutdown
/// command, and joins the worker thread.
pub struct Client {
    /// Sender for commands (requests, reloads, shutdown) to the worker.
    request_tx: mpsc::Sender<InferenceCommand>,
    /// Cooperative cancellation flag shared with the worker thread.
    cancel: Arc<AtomicBool>,
    /// Current sampling defaults; snapshotted into each [`Model`] by `make`,
    /// and replaced on a successful `reload`.
    sampling_params: std::sync::RwLock<SamplingParams>,
    /// Worker thread join handle; taken (set to `None`) during `Drop`.
    worker_handle: Option<thread::JoinHandle<()>>,
}
impl Client {
    /// Returns a [`ClientBuilder`] for the GGUF model at `model_path`.
    pub fn builder(model_path: impl Into<String>) -> ClientBuilder {
        ClientBuilder::new(model_path)
    }

    /// Loads a GGUF model on a dedicated worker thread, blocking the current
    /// thread until the worker reports whether initialization succeeded.
    ///
    /// # Errors
    /// Returns a [`LoadError`] if the worker cannot load the model or exits
    /// before reporting a result.
    pub fn from_gguf(
        model_path: impl Into<String>,
        n_ctx: u32,
        sampling_params: SamplingParams,
        fit_params: FitParams,
        kv_cache_params: KvCacheParams,
        checkpoint_params: CheckpointParams,
    ) -> Result<Self, LoadError> {
        Self::spawn(
            model_path.into(),
            None,
            n_ctx,
            sampling_params,
            fit_params,
            kv_cache_params,
            checkpoint_params,
        )
    }

    /// Like [`Client::from_gguf`], but also loads a multimodal projector
    /// (`mmproj`) alongside the model. Only available with the `mtmd` feature.
    ///
    /// # Errors
    /// Returns a [`LoadError`] if the worker cannot initialize.
    #[cfg(feature = "mtmd")]
    pub fn from_gguf_with_mmproj(
        model_path: impl Into<String>,
        mmproj_path: impl Into<String>,
        n_ctx: u32,
        sampling_params: SamplingParams,
        fit_params: FitParams,
        kv_cache_params: KvCacheParams,
        checkpoint_params: CheckpointParams,
    ) -> Result<Self, LoadError> {
        Self::spawn(
            model_path.into(),
            Some(mmproj_path.into()),
            n_ctx,
            sampling_params,
            fit_params,
            kv_cache_params,
            checkpoint_params,
        )
    }

    /// Spawns the inference worker thread and blocks until it reports its
    /// one-time initialization result over a std channel.
    fn spawn(
        model_path: String,
        mmproj_path: Option<String>,
        n_ctx: u32,
        sampling_params: SamplingParams,
        fit_params: FitParams,
        kv_cache_params: KvCacheParams,
        checkpoint_params: CheckpointParams,
    ) -> Result<Self, LoadError> {
        // Bounded command channel: async senders await when the worker lags.
        let (request_tx, mut request_rx) =
            mpsc::channel::<InferenceCommand>(COMMAND_CHANNEL_CAPACITY);
        // std (sync) channel for the init handshake — the worker thread is not async.
        let (init_tx, init_rx) = std::sync::mpsc::channel::<Result<(), LoadError>>();
        let cancel = Arc::new(AtomicBool::new(false));
        let worker_cancel = Arc::clone(&cancel);
        let worker_handle = thread::spawn(move || {
            let init = WorkerInit {
                model_path: &model_path,
                mmproj_path: mmproj_path.as_deref(),
                n_ctx,
                fit_params: &fit_params,
                kv_cache_params: &kv_cache_params,
                checkpoint_params,
                cancel: worker_cancel,
            };
            inference_worker(init, init_tx, &mut request_rx);
        });
        // Outer `?`: worker dropped `init_tx` without reporting (thread died);
        // inner `?`: worker reported a load failure.
        init_rx
            .recv()
            .map_err(|_| LoadError::WorkerInitDisconnected)??;
        Ok(Self {
            request_tx,
            cancel,
            sampling_params: std::sync::RwLock::new(sampling_params),
            worker_handle: Some(worker_handle),
        })
    }

    /// Replaces the loaded model in place, blocking until the worker finishes
    /// the reload. On success the stored sampling defaults are replaced with
    /// `sampling`; existing [`Model`] handles keep their earlier snapshot.
    ///
    /// NOTE(review): this uses `blocking_send` and a blocking `recv`, which
    /// must not be called from a thread driving an async runtime
    /// (`blocking_send` panics there) — confirm callers only invoke `reload`
    /// from synchronous contexts.
    ///
    /// # Errors
    /// Returns [`LoadError::WorkerNotRunning`] if the worker is gone,
    /// [`LoadError::WorkerInitDisconnected`] if it dies mid-reload, or the
    /// worker's own load error.
    #[allow(clippy::too_many_arguments)]
    pub fn reload(
        &self,
        model_path: String,
        mmproj_path: Option<String>,
        n_ctx: u32,
        sampling: SamplingParams,
        fit_params: FitParams,
        kv_cache_params: KvCacheParams,
        checkpoint_params: CheckpointParams,
    ) -> Result<(), LoadError> {
        let (result_tx, result_rx) = std::sync::mpsc::channel();
        self.request_tx
            .blocking_send(InferenceCommand::Reload(ReloadRequest {
                model_path,
                mmproj_path,
                n_ctx,
                fit_params,
                kv_cache_params,
                checkpoint_params,
                result_tx,
            }))
            .map_err(|_| LoadError::WorkerNotRunning)?;
        let result = result_rx
            .recv()
            .map_err(|_| LoadError::WorkerInitDisconnected)?;
        if result.is_ok() {
            // Recover from a poisoned lock instead of propagating the panic:
            // sampling params stay usable even if a writer panicked.
            let mut guard = self
                .sampling_params
                .write()
                .unwrap_or_else(|p| p.into_inner());
            *guard = sampling;
        }
        result
    }
}
impl Drop for Client {
    /// Signals the worker to stop and waits for its thread to finish.
    fn drop(&mut self) {
        // Flip the shared cancel flag first so an in-flight generation can bail out.
        self.cancel.store(true, Ordering::Relaxed);
        // Best-effort shutdown; `try_send` never blocks inside `drop`.
        // NOTE(review): if the channel is full, the Shutdown command is silently
        // dropped and the join below relies on the worker observing `cancel` —
        // confirm the worker polls the flag while waiting for commands,
        // otherwise this join could hang.
        let _ = self.request_tx.try_send(InferenceCommand::Shutdown);
        if let Some(worker_handle) = self.worker_handle.take() {
            // A worker panic is deliberately ignored while tearing down.
            let _ = worker_handle.join();
        }
    }
}
// Wires `Client` into rig's provider-agnostic client interface; completion
// requests are served through `Model`.
impl CompletionClient for Client {
    type CompletionModel = Model;
}
/// Cloneable completion-model handle backed by the shared worker thread.
///
/// Holds a copy of the client's sampling parameters taken at `make` time.
#[derive(Clone)]
pub struct Model {
    /// Cloned sender into the worker's command channel.
    request_tx: mpsc::Sender<InferenceCommand>,
    /// Sampling defaults copied from the client when this handle was created.
    sampling_params: SamplingParams,
    /// Model identifier supplied by the caller; currently unused.
    #[allow(dead_code)]
    model_id: String,
}
impl CompletionModel for Model {
type Response = RawResponse;
type StreamingResponse = StreamChunk;
type Client = Client;
fn make(client: &Client, model: impl Into<String>) -> Self {
let sampling_params = *client
.sampling_params
.read()
.unwrap_or_else(|p| p.into_inner());
Self {
request_tx: client.request_tx.clone(),
sampling_params,
model_id: model.into(),
}
}
async fn completion(
&self,
request: CompletionRequest,
) -> Result<CompletionResponse<Self::Response>, CompletionError> {
let prepared_request = prepare_request(&request).map_err(CompletionError::ProviderError)?;
let max_tokens = request.max_tokens.unwrap_or(512) as u32;
let temperature = request.temperature.unwrap_or(0.7) as f32;
let (response_tx, response_rx) = oneshot::channel();
self.request_tx
.send(InferenceCommand::Request(InferenceRequest {
params: InferenceParams {
prepared_request,
max_tokens,
temperature,
top_p: self.sampling_params.top_p,
top_k: self.sampling_params.top_k,
min_p: self.sampling_params.min_p,
presence_penalty: self.sampling_params.presence_penalty,
repetition_penalty: self.sampling_params.repetition_penalty,
},
response_channel: ResponseChannel::Completion(response_tx),
}))
.await
.map_err(|_| CompletionError::ProviderError("Inference thread shut down".into()))?;
let result = response_rx
.await
.map_err(|_| CompletionError::ProviderError("Response channel closed".into()))?
.map_err(CompletionError::ProviderError)?;
Ok(CompletionResponse {
choice: result.choice,
usage: Usage {
input_tokens: result.prompt_tokens,
output_tokens: result.completion_tokens,
total_tokens: result.prompt_tokens + result.completion_tokens,
cached_input_tokens: result.cached_input_tokens,
cache_creation_input_tokens: 0,
},
raw_response: RawResponse { text: result.text },
message_id: None,
})
}
async fn stream(
&self,
request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let prepared_request = prepare_request(&request).map_err(CompletionError::ProviderError)?;
let max_tokens = request.max_tokens.unwrap_or(512) as u32;
let temperature = request.temperature.unwrap_or(0.7) as f32;
let (stream_tx, stream_rx) = mpsc::unbounded_channel();
self.request_tx
.send(InferenceCommand::Request(InferenceRequest {
params: InferenceParams {
prepared_request,
max_tokens,
temperature,
top_p: self.sampling_params.top_p,
top_k: self.sampling_params.top_k,
min_p: self.sampling_params.min_p,
presence_penalty: self.sampling_params.presence_penalty,
repetition_penalty: self.sampling_params.repetition_penalty,
},
response_channel: ResponseChannel::Streaming(stream_tx),
}))
.await
.map_err(|_| CompletionError::ProviderError("Inference thread shut down".into()))?;
Ok(StreamingCompletionResponse::stream(Box::pin(
UnboundedReceiverStream::new(stream_rx),
)))
}
}