use candle_transformers::models::quantized_llama::ModelWeights;
use std::time::Duration;
use tokenizers::Tokenizer;
use super::generate::{GenerationConfig, GenerationOutput, generate_tokens};
use super::template::ChatTemplate;
use crate::error::LlmError;
use crate::provider::Message;
/// Default inference timeout, expressed in seconds.
///
/// NOTE(review): this file only defines the value; presumably callers convert it to the
/// `Duration` handed to `InferenceWorker::spawn` — confirm at the call site.
pub(crate) const DEFAULT_INFERENCE_TIMEOUT_SECS: u64 = 120;
/// A unit of work sent to the inference worker over its request channel.
///
/// The worker loop renders `messages` into a prompt, runs generation, and sends the
/// outcome back through `reply`.
pub struct InferenceRequest {
/// Conversation messages that are rendered into the prompt by the chat template.
pub messages: Vec<Message>,
/// One-shot channel on which the worker delivers the generation output or the error.
pub reply: tokio::sync::oneshot::Sender<Result<GenerationOutput, LlmError>>,
}
/// Everything the worker loop needs to serve requests; moved into the worker when it is spawned.
pub(crate) struct WorkerConfig {
/// Quantized llama model weights; borrowed mutably for every forward pass.
pub weights: ModelWeights,
/// Tokenizer used to encode the rendered prompt and passed on to `generate_tokens`.
pub tokenizer: std::sync::Arc<Tokenizer>,
/// Token id forwarded to `generate_tokens` as the end-of-sequence id.
pub eos_token_id: u32,
/// Chat template that renders the request messages into a single prompt.
pub template: ChatTemplate,
/// Generation settings forwarded unchanged to `generate_tokens`.
pub generation_config: GenerationConfig,
/// Candle device forwarded to `generate_tokens`.
pub device: candle_core::Device,
}
/// Handle to the background inference worker created by `InferenceWorker::spawn`.
pub(crate) struct InferenceWorker {
/// Sending side of the bounded request channel consumed by the worker loop.
pub tx: tokio::sync::mpsc::Sender<InferenceRequest>,
/// Timeout stored at spawn time. It is not enforced anywhere in this file; presumably
/// callers apply it while awaiting the reply — confirm at the call site.
pub inference_timeout: Duration,
/// Join handle of the blocking worker task; stored but never awaited in this file.
/// NOTE(review): `None` is permitted, and a test name suggests clones carry no handle —
/// verify against the `Clone` implementation, which is not visible here.
pub(super) _handle: Option<tokio::task::JoinHandle<()>>,
}
impl InferenceWorker {
    /// Starts the inference worker on Tokio's blocking thread pool and returns a handle to it.
    ///
    /// * `config` – model, tokenizer and generation settings; moved into the worker.
    /// * `channel_capacity` – how many requests may queue before senders have to wait.
    ///   `tokio::sync::mpsc::channel` panics when given a capacity of zero, so zero is
    ///   treated as one instead of crashing the caller.
    /// * `inference_timeout` – stored on the returned worker; it is not enforced here.
    ///
    /// The worker runs until every clone of the returned `tx` has been dropped.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because it relies on
    /// `tokio::task::spawn_blocking`.
    pub fn spawn(
        config: WorkerConfig,
        channel_capacity: usize,
        inference_timeout: Duration,
    ) -> Self {
        // A bounded Tokio channel requires a buffer of at least one slot.
        let capacity = channel_capacity.max(1);
        let (tx, rx) = tokio::sync::mpsc::channel::<InferenceRequest>(capacity);
        // Inference is CPU-bound and uses `blocking_recv`, so it must not run on an async
        // worker thread.
        let handle = tokio::task::spawn_blocking(move || {
            worker_loop(config, rx);
        });
        Self {
            tx,
            inference_timeout,
            _handle: Some(handle),
        }
    }
}
/// Serves inference requests one at a time until every request sender has been dropped.
///
/// Each request is rendered, tokenized and generated via `generate_sync`, and the outcome
/// is sent back on the request's one-shot reply channel. Requests are processed strictly
/// in arrival order because the model weights are borrowed mutably for each generation.
///
/// A request whose reply receiver is already gone (for example because the caller gave up
/// waiting) is skipped before any model work is done, so an abandoned request cannot delay
/// the ones queued behind it.
fn worker_loop(mut config: WorkerConfig, mut rx: tokio::sync::mpsc::Receiver<InferenceRequest>) {
    while let Some(req) = rx.blocking_recv() {
        // Nobody can observe the result any more; do not spend a full generation on it.
        if req.reply.is_closed() {
            tracing::debug!("candle inference worker skipping request: reply receiver dropped");
            continue;
        }
        let result = generate_sync(
            &mut config.weights,
            &config.tokenizer,
            config.eos_token_id,
            config.template,
            &config.generation_config,
            &config.device,
            &req.messages,
        );
        // The receiver may have been dropped while generation was running; that is not an
        // error for the worker, so the failed send is deliberately ignored.
        let _ = req.reply.send(result);
    }
    tracing::debug!("candle inference worker exiting: all senders dropped");
}
/// Runs one blocking generation for `messages`.
///
/// Steps: render the messages with `template`, encode the prompt with `tokenizer`
/// (without adding special tokens), then hand the token ids to `generate_tokens` together
/// with a closure that performs a forward pass on `weights`.
///
/// # Errors
///
/// * `LlmError::Inference` when the tokenizer fails to encode the prompt.
/// * `LlmError::Candle` when a forward pass fails.
/// * Whatever error `generate_tokens` itself returns.
fn generate_sync(
    weights: &mut ModelWeights,
    tokenizer: &Tokenizer,
    eos_token_id: u32,
    template: ChatTemplate,
    generation_config: &GenerationConfig,
    device: &candle_core::Device,
    messages: &[Message],
) -> Result<GenerationOutput, LlmError> {
    let rendered = template.format(messages);

    let encoded = match tokenizer.encode(rendered.as_str(), false) {
        Ok(enc) => enc,
        Err(e) => {
            return Err(LlmError::Inference(format!(
                "tokenizer encode failed: {e}"
            )));
        }
    };

    // Adapter that lets the generic token generator drive this particular model.
    let mut step =
        |input: &candle_core::Tensor, pos: usize| -> Result<candle_core::Tensor, LlmError> {
            weights.forward(input, pos).map_err(LlmError::Candle)
        };

    let prompt_ids = encoded.get_ids();
    generate_tokens(
        &mut step,
        tokenizer,
        prompt_ids,
        generation_config,
        eos_token_id,
        device,
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Guards against the default timeout being set to zero.
    #[test]
    fn default_inference_timeout_is_nonzero() {
        assert!(DEFAULT_INFERENCE_TIMEOUT_SECS > 0);
    }

    /// A result sent through the request's reply channel arrives intact on the receiver.
    #[test]
    fn inference_request_oneshot_roundtrip() {
        use crate::provider::Message;
        let (tx, mut rx) = tokio::sync::oneshot::channel::<Result<GenerationOutput, LlmError>>();
        let req = InferenceRequest {
            messages: vec![Message::from_legacy(crate::provider::Role::User, "hello")],
            reply: tx,
        };
        let expected_text = "world";
        let output = GenerationOutput {
            text: expected_text.into(),
            tokens_generated: 1,
        };
        req.reply
            .send(Ok(output))
            .expect("send must succeed when receiver is live");
        let result = rx.try_recv().expect("reply must be immediately available");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().text, expected_text);
    }

    /// A worker that replies to a request whose receiver is already gone must finish
    /// normally instead of hanging or panicking.
    #[tokio::test]
    async fn dropped_reply_receiver_does_not_block_worker() {
        let (req_tx, mut req_rx) = tokio::sync::mpsc::channel::<InferenceRequest>(1);
        let worker = tokio::task::spawn_blocking(move || {
            if let Some(req) = req_rx.blocking_recv() {
                let result: Result<GenerationOutput, LlmError> = Ok(GenerationOutput {
                    text: "ok".into(),
                    tokens_generated: 1,
                });
                // Sending to a dropped receiver fails; the worker ignores that on purpose.
                let _ = req.reply.send(result);
            }
        });

        let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
        drop(reply_rx);
        let req = InferenceRequest {
            messages: vec![],
            reply: reply_tx,
        };
        req_tx
            .send(req)
            .await
            .expect("channel must accept the request");

        // Join the worker under a deadline so a hang or a panic fails the test, rather
        // than sleeping for a fixed interval and asserting nothing.
        tokio::time::timeout(Duration::from_secs(5), worker)
            .await
            .expect("worker must finish promptly after the reply send fails")
            .expect("worker must not panic when the reply receiver is gone");
    }

    /// A worker value can exist without a join handle.
    ///
    /// NOTE(review): despite its name this test never clones a worker; it only checks that
    /// a manually built value keeps `_handle == None`. Exercise the real clone path once
    /// it is reachable from here.
    #[test]
    fn cloned_worker_has_no_handle() {
        let (tx, _rx) = tokio::sync::mpsc::channel::<InferenceRequest>(1);
        let worker = InferenceWorker {
            tx,
            inference_timeout: Duration::from_secs(1),
            _handle: None,
        };
        assert!(worker._handle.is_none());
    }
}