1use crate::{
2 distributed,
3 paged_attention::block_hash::compute_block_hashes,
4 pipeline::{
5 llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
6 text_models_inputs_processor::PagedAttentionMeta,
7 CacheBackendMetadata, CacheInstruction,
8 },
9 prefix_cacher::PrefixCacheManagerV2,
10 response::CompletionChoice,
11 scheduler::{Scheduler, SchedulerOutput},
12 search::{self, rag::SearchPipeline},
13 sequence::{SeqStepType, StopReason},
14 tools, CompletionResponse, SchedulerConfig, DEBUG,
15};
16use interprocess::local_socket::{traits::Listener, ListenerOptions};
17use llguidance::ParserFactory;
18pub use logger::IntervalLogger;
19use mistralrs_quant::RingConfig;
20use rand::SeedableRng;
21use rand_isaac::Isaac64Rng;
22use serde::{Deserialize, Serialize};
23use std::{
24 collections::HashMap,
25 fmt,
26 io::{BufWriter, Write},
27 net::TcpListener,
28 ops::Deref,
29 str::FromStr,
30 sync::{
31 atomic::{AtomicBool, Ordering},
32 Arc, LazyLock,
33 },
34 time::{Duration, Instant, SystemTime, UNIX_EPOCH},
35};
36use tokio::{
37 select,
38 sync::{
39 mpsc::{error::TryRecvError, Receiver, Sender},
40 Mutex, Notify,
41 },
42 task::JoinHandle,
43};
44
45use crate::{
46 get_mut_arcmutex, handle_pipeline_forward_error,
47 pipeline::{ModelCategory, Pipeline},
48 request::Request,
49 response::{ChatCompletionResponse, Choice, ResponseMessage},
50 sequence::{SequenceRecognizer, SequenceState},
51 Constraint,
52};
53
54mod add_request;
55mod logger;
56mod search_request;
57
/// Out-of-band control instructions for a running [`Engine`], delivered via the
/// global `ENGINE_INSTRUCTIONS` map keyed by engine id.
pub enum EngineInstruction {
    /// Request that the engine's main loop exits at the next check point.
    Terminate,
}
61
/// Embedding models supported for web-search retrieval (RAG) ranking.
///
/// Serialized in snake_case; `embedding_gemma` is the default and currently the
/// only supported variant.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SearchEmbeddingModel {
    /// Google's EmbeddingGemma 300M model.
    #[default]
    #[serde(rename = "embedding_gemma")]
    EmbeddingGemma300M,
}
70
71impl SearchEmbeddingModel {
72 pub fn hf_model_id(&self) -> &'static str {
73 match self {
74 Self::EmbeddingGemma300M => "google/embeddinggemma-300m",
75 }
76 }
77
78 pub fn variants() -> &'static [&'static str] {
79 &["embedding_gemma"]
80 }
81}
82
83impl fmt::Display for SearchEmbeddingModel {
84 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85 match self {
86 Self::EmbeddingGemma300M => f.write_str("embedding_gemma"),
87 }
88 }
89}
90
91impl FromStr for SearchEmbeddingModel {
92 type Err = String;
93
94 fn from_str(s: &str) -> Result<Self, Self::Err> {
95 match s.trim().to_ascii_lowercase().as_str() {
96 "embedding_gemma" => Ok(Self::EmbeddingGemma300M),
97 other => Err(format!(
98 "Unknown search embedding model `{other}`. Supported values: {}",
99 Self::variants().join(", ")
100 )),
101 }
102 }
103}
104
// Fixed RNG seed so sampling is reproducible across engine runs.
const SEED: u64 = 0;
/// Global flag: when set, every engine terminates all of its sequences at the
/// start of its next scheduling step.
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);
109
// Per-engine-thread terminate flags; each engine thread lazily registers its
// own flag via `get_engine_terminate_flag`, keyed by `ThreadId`.
static ENGINE_TERMINATE_FLAGS: LazyLock<
    std::sync::Mutex<HashMap<std::thread::ThreadId, Arc<AtomicBool>>>,
> = LazyLock::new(|| std::sync::Mutex::new(HashMap::new()));
114
115pub fn get_engine_terminate_flag() -> Arc<AtomicBool> {
117 let thread_id = std::thread::current().id();
118 let mut flags = ENGINE_TERMINATE_FLAGS.lock().unwrap();
119 flags
120 .entry(thread_id)
121 .or_insert_with(|| Arc::new(AtomicBool::new(false)))
122 .clone()
123}
124
125pub fn should_terminate_engine_sequences() -> bool {
127 if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
129 return true;
130 }
131 let thread_id = std::thread::current().id();
133 if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
134 if let Some(flag) = flags.get(&thread_id) {
135 return flag.load(Ordering::SeqCst);
136 }
137 }
138 false
139}
140
141pub fn reset_engine_terminate_flag() {
143 let thread_id = std::thread::current().id();
144 if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
145 if let Some(flag) = flags.get(&thread_id) {
146 flag.store(false, Ordering::SeqCst);
147 }
148 }
149}
150
/// Control-plane mailbox: maps an engine id to a pending [`EngineInstruction`].
/// The engine's run loop polls its own entry each iteration.
pub static ENGINE_INSTRUCTIONS: LazyLock<
    std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>,
> = LazyLock::new(|| std::sync::Mutex::new(HashMap::new()));
155
/// The core inference engine: owns the request channel, pipeline, scheduler and
/// prefix cache, and drives the generation loop in [`Engine::run`].
pub struct Engine {
    /// Sender half of the request channel (kept so it can be cloned for callers).
    tx: Sender<Request>,
    /// Receiver half; wrapped in `Arc<Mutex<_>>` so the run loop and idle-wait
    /// future can both lock it.
    rx: Arc<Mutex<Receiver<Request>>>,
    /// The model pipeline this engine steps.
    pipeline: Arc<Mutex<dyn Pipeline>>,
    /// Optional embedding pipeline used for web-search (RAG) ranking.
    search_pipeline: Arc<Mutex<Option<SearchPipeline>>>,
    /// Optional user-provided callback overriding the built-in web search.
    search_callback: Option<Arc<search::SearchCallback>>,
    /// Named tool callbacks invoked on tool calls.
    tool_callbacks: tools::ToolCallbacks,
    /// Tool callbacks bundled with their tool schemas.
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    /// Scheduler deciding which sequences run each step.
    scheduler: Arc<Mutex<dyn Scheduler>>,
    /// This engine's id; used to look up entries in `ENGINE_INSTRUCTIONS`.
    id: Arc<Mutex<usize>>,
    /// True when KV caching is disabled (forced on by some pipelines).
    no_kv_cache: bool,
    /// Prefix cache for reusing previously computed prompt KV state.
    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
    /// Snapshot of the global DEBUG flag at construction time.
    is_debug: bool,
    /// When true, EOS tokens do not stop generation.
    disable_eos_stop: bool,
    /// Whether throughput logging should be enabled when the loop starts.
    throughput_logging_enabled: bool,
    /// Interval logger for throughput metrics.
    logger: Arc<IntervalLogger>,
    /// Handles of tasks spawned by this engine; aborted on drop.
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
    /// Notified to wake the run loop from its idle wait (e.g. new pending work).
    pending_notify: Arc<Notify>,
}
175
impl Drop for Engine {
    /// Abort all tasks this engine spawned so they do not outlive it.
    fn drop(&mut self) {
        for handle in &*get_mut_arcmutex!(self.handles) {
            handle.abort();
        }
    }
}
183
184impl Engine {
    /// Build a new engine around `pipeline` with the given scheduler `config`.
    ///
    /// KV caching is force-disabled when the pipeline's metadata requires it;
    /// prefix caching is disabled when KV caching is off, when the pipeline
    /// forbids it, or when `prefix_cache_n == 0`. If `search_embedding_model`
    /// is set, a search (RAG) pipeline is loaded on the model's device.
    ///
    /// Returns an error if the search pipeline fails to load.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        tx: Sender<Request>,
        rx: Receiver<Request>,
        pipeline: Arc<Mutex<dyn Pipeline>>,
        config: SchedulerConfig,
        mut no_kv_cache: bool,
        mut no_prefix_cache: bool,
        prefix_cache_n: usize,
        disable_eos_stop: bool,
        throughput_logging_enabled: bool,
        search_embedding_model: Option<SearchEmbeddingModel>,
        search_callback: Option<Arc<search::SearchCallback>>,
        tool_callbacks: tools::ToolCallbacks,
        tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
        logger: Arc<IntervalLogger>,
    ) -> anyhow::Result<Self> {
        // The pipeline can force KV caching off regardless of the caller's wish.
        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;

        // Prefix caching is meaningless without a KV cache or with 0 slots.
        no_prefix_cache = no_prefix_cache
            || no_kv_cache
            || get_mut_arcmutex!(pipeline).get_metadata().no_prefix_cache
            || prefix_cache_n == 0;

        let search_pipeline = match search_embedding_model {
            Some(search_embedding_model) => Some(SearchPipeline::new(
                search_embedding_model,
                &get_mut_arcmutex!(pipeline).device(),
            )?),
            None => None,
        };

        let scheduler = config.into_scheduler();

        // Keep the scheduler's view of prefix caching in sync with ours.
        get_mut_arcmutex!(scheduler).set_prefix_caching_enabled(!no_prefix_cache);

        // Paged attention is in play iff the scheduler manages a KV cache.
        let has_paged_attention = get_mut_arcmutex!(scheduler).kv_cache_manager().is_some();

        Ok(Self {
            tx,
            rx: Arc::new(Mutex::new(rx)),
            pipeline,
            search_pipeline: Arc::new(Mutex::new(search_pipeline)),
            search_callback,
            tool_callbacks,
            tool_callbacks_with_tools,
            scheduler: scheduler.clone(),
            id: Arc::new(Mutex::new(0)),
            no_kv_cache,
            prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
                prefix_cache_n,
                no_prefix_cache,
                has_paged_attention,
            ))),
            is_debug: DEBUG.load(Ordering::Relaxed),
            disable_eos_stop,
            throughput_logging_enabled,
            logger,
            handles: Arc::new(Mutex::new(Vec::new())),
            pending_notify: Arc::new(Notify::new()),
        })
    }
249
250 #[allow(dead_code)]
252 pub fn max_sequence_length(&self) -> Option<usize> {
253 let pipeline = get_mut_arcmutex!(self.pipeline);
254 let category = pipeline.category();
255
256 if matches!(category, ModelCategory::Diffusion | ModelCategory::Speech) {
257 None
258 } else {
259 Some(pipeline.get_metadata().max_seq_len)
260 }
261 }
262
    /// The engine's main loop: drain incoming requests, schedule sequences,
    /// step the pipeline (default or paged-attention backend), and maintain
    /// the prefix / hybrid recurrent caches. Runs until the request channel
    /// closes or a `Terminate` instruction/request is observed.
    pub async fn run(self: Arc<Self>) {
        if self.throughput_logging_enabled {
            self.logger.enable_logging();
        }

        // Deterministic sampling RNG (seeded with the fixed SEED constant).
        let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
        // IDs of sequences run in the previous completion step; used to decide
        // whether the KV cache must be swapped in (`CacheInstruction::In`).
        let mut last_completion_ids: Vec<usize> = vec![];
        'lp: loop {
            // Checks the control-plane mailbox for a Terminate instruction
            // addressed to this engine's id.
            let should_terminate = || {
                matches!(
                    ENGINE_INSTRUCTIONS
                        .lock()
                        .expect("`ENGINE_INSTRUCTIONS` was poisoned")
                        .get(get_mut_arcmutex!(self.id).deref()),
                    Some(Some(EngineInstruction::Terminate))
                )
            };

            if should_terminate() {
                self.replicate_request_to_daemons(&Request::Terminate);
                break 'lp;
            }

            // Drain all currently queued requests without blocking.
            let mut channel_disconnected = false;
            loop {
                let next_request = {
                    let mut rx = self.rx.lock().await;
                    rx.try_recv()
                };

                match next_request {
                    Ok(request) => {
                        self.replicate_request_to_daemons(&request);
                        if matches!(request, Request::Terminate) {
                            break 'lp;
                        }
                        self.clone().handle_request(request).await;
                    }
                    Err(TryRecvError::Empty) => break,
                    Err(TryRecvError::Disconnected) => {
                        channel_disconnected = true;
                        break;
                    }
                }
            }

            // All senders dropped: nothing more can ever arrive.
            if channel_disconnected {
                break 'lp;
            }

            let (waiting_len, running_len) = {
                let scheduler = get_mut_arcmutex!(self.scheduler);
                (scheduler.waiting_len(), scheduler.running_len())
            };
            let scheduler_idle = waiting_len == 0 && running_len == 0;

            if scheduler_idle {
                // Re-check terminate before parking: an instruction may have
                // arrived while we drained the channel.
                if should_terminate() {
                    self.replicate_request_to_daemons(&Request::Terminate);
                    break 'lp;
                }
                enum WaitEvent {
                    Request(Option<Request>),
                    Wake,
                }
                // Park until either a new request arrives or someone pings
                // `pending_notify` (e.g. work became pending elsewhere).
                let wait_for_request = async {
                    let mut rx = self.rx.lock().await;
                    rx.recv().await
                };
                tokio::pin!(wait_for_request);
                let wait_for_wake = self.pending_notify.notified();
                tokio::pin!(wait_for_wake);

                let event = select! {
                    res = &mut wait_for_request => WaitEvent::Request(res),
                    _ = &mut wait_for_wake => WaitEvent::Wake,
                };

                match event {
                    WaitEvent::Request(Some(request)) => {
                        self.replicate_request_to_daemons(&request);
                        if matches!(request, Request::Terminate) {
                            break 'lp;
                        }
                        self.clone().handle_request(request).await;
                        continue;
                    }
                    // Channel closed while idle: shut down.
                    WaitEvent::Request(None) => break 'lp,
                    WaitEvent::Wake => {
                        continue;
                    }
                }
            }

            // Propagate a pending "terminate all sequences" to the daemons so
            // they stop in lockstep with this process.
            if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
                self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
            }

            let run_start = Instant::now();
            let mut scheduler = get_mut_arcmutex!(self.scheduler);
            let scheduled = scheduler.schedule(&self.logger);

            match scheduled {
                SchedulerOutput::DefaultScheduler {
                    output: mut scheduled,
                } => {
                    // --- Completion (decode) step ---
                    if !scheduled.completion.is_empty() {
                        let current_completion_ids: Vec<usize> =
                            scheduled.completion.iter().map(|seq| *seq.id()).collect();
                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
                            // Only reload the KV cache if the running batch
                            // changed since the last completion step.
                            let pre_op = if !self.no_kv_cache
                                && last_completion_ids != current_completion_ids
                            {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Nothing
                            };
                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            // Raw-logit return is a batch-wide property.
                            let return_raw_logits = scheduled.completion[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .completion
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut scheduled.completion,
                                    false,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        handle_pipeline_forward_error!(
                            "completion step",
                            res,
                            &mut scheduled.completion,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        // One new token per completion sequence this step.
                        self.logger.add_tokens_processed(scheduled.completion.len());

                        last_completion_ids = current_completion_ids;
                    }

                    // --- Prompt (prefill) step ---
                    if !scheduled.prompt.is_empty() {
                        let prompt_exec_time = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let post_op = if !self.no_kv_cache {
                                CacheInstruction::Out
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: false,
                                    reset_non_granular: false,
                                }
                            };

                            let return_raw_logits = scheduled.prompt[0].return_raw_logits;
                            assert!(
                                scheduled
                                    .prompt
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            // A nonzero token offset means part of the prompt
                            // was already processed (e.g. prefix-cache hit), so
                            // load the existing cache instead of resetting.
                            let pre_op = if scheduled.prompt[0].token_offset() != 0 {
                                CacheInstruction::In
                            } else {
                                CacheInstruction::Reset {
                                    load_preallocated_cache: true,
                                    reset_non_granular: false,
                                }
                            };

                            pipeline
                                .step(
                                    &mut scheduled.prompt,
                                    true,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                                )
                                .await
                        };

                        let prompt_exec_time = handle_pipeline_forward_error!(
                            "prompt step",
                            prompt_exec_time,
                            &mut scheduled.prompt,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        let total_processed_tokens: usize = scheduled
                            .prompt
                            .iter()
                            .map(|seq| seq.get_toks().len())
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        // Transition each sequence out of the prompt phase and
                        // record its prefill timing statistics.
                        for seq in scheduled.prompt.iter_mut() {
                            match seq.sequence_stepping_type() {
                                // One-shot (e.g. image generation) finishes here.
                                SeqStepType::OneShot => {
                                    seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
                                }
                                SeqStepType::PromptAndDecode => {
                                    seq.set_state(SequenceState::RunningCompletion)
                                }
                            }
                            let now = SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .expect("Time travel has occurred!")
                                .as_millis();
                            #[allow(clippy::cast_precision_loss)]
                            let prompt_tok_per_sec =
                                seq.len() as f32 / prompt_exec_time.as_secs_f32();
                            seq.prompt_tok_per_sec = prompt_tok_per_sec;
                            seq.prompt_timestamp = Some(now);
                            seq.total_prompt_time = Some(prompt_exec_time.as_millis());
                            seq.step_start_instant = None;
                        }
                        // Force a KV cache reload on the next completion step.
                        last_completion_ids = vec![];
                    }

                    // Optional per-step debug trace of batch composition.
                    if self.is_debug {
                        // Value is in seconds; converted to ms at the log site.
                        let ms_from_last_run = run_start.elapsed().as_secs_f64();
                        let total_len = scheduled.prompt.len() + scheduled.completion.len();
                        if total_len > 0 {
                            let prompt_lengths = scheduled
                                .prompt
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            let completion_lengths = scheduled
                                .completion
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            tracing::info!(
                                "Prompt[{}] Completion[{}] - {}ms",
                                prompt_lengths,
                                completion_lengths,
                                ms_from_last_run * 1000.,
                            );
                        }
                    }
                }
                SchedulerOutput::PagedAttention { mut output } => {
                    if !output.scheduled.is_empty() {
                        // Paged-attention batches are homogeneous: either all
                        // prompt or all completion sequences.
                        let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

                        if is_prompt {
                            let now = SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .expect("Time travel has occurred!")
                                .as_millis();
                            for seq in output.scheduled.iter() {
                                let mut seq_guard = get_mut_arcmutex!(seq);
                                seq_guard.prompt_timestamp = Some(now);
                                seq_guard.set_step_start_instant();
                            }
                        }

                        // Lock every scheduled sequence for the whole step.
                        let mut guards = output
                            .scheduled
                            .iter_mut()
                            .map(|seq| seq.lock().unwrap())
                            .collect::<Vec<_>>();

                        let mut guards_mut =
                            guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();

                        let res = {
                            let mut pipeline = get_mut_arcmutex!(self.pipeline);

                            let block_size = scheduler.block_size().unwrap();

                            // Hybrid (attention + recurrent) models: try to
                            // restore cached recurrent state for prefix-cache
                            // hits; on any mismatch fall back to recomputing
                            // the full prompt from scratch.
                            if is_prompt && pipeline.cache().is_hybrid() {
                                let mut hybrid_cache = pipeline.cache().hybrid();
                                let mut prefix_cacher = get_mut_arcmutex!(self.prefix_cacher);
                                let kv_cache_manager = scheduler.kv_cache_manager().unwrap();

                                for seq in guards_mut.iter_mut() {
                                    let cached_prefix_len = seq.prefix_cache_len();
                                    if cached_prefix_len == 0 {
                                        continue;
                                    }

                                    let mut fallback_to_full_prompt = false;

                                    let slot_idx = match seq.recurrent_state_idx() {
                                        Some(idx) => idx,
                                        None => {
                                            tracing::warn!("Sequence {} has paged prefix hit but no recurrent_state_idx; recomputing full prompt.", seq.id());
                                            fallback_to_full_prompt = true;
                                            // Placeholder; unused on the fallback path.
                                            0usize
                                        }
                                    };

                                    if !fallback_to_full_prompt {
                                        // Recurrent snapshots are stored at
                                        // block granularity; the prefix must
                                        // align to a block boundary.
                                        if cached_prefix_len % block_size != 0 {
                                            tracing::warn!(
                                                "Sequence {} has non-aligned paged prefix len {}; recomputing full prompt.",
                                                seq.id(),
                                                cached_prefix_len
                                            );
                                            fallback_to_full_prompt = true;
                                        } else {
                                            let num_prefix_blocks = cached_prefix_len / block_size;
                                            let block_hashes = compute_block_hashes(
                                                seq.get_toks(),
                                                block_size,
                                                seq.mm_features(),
                                                &[],
                                            );
                                            if block_hashes.len() < num_prefix_blocks {
                                                fallback_to_full_prompt = true;
                                            } else if let Some(snapshots) = prefix_cacher
                                                .get_paged_recurrent_prefix(
                                                    &block_hashes[..num_prefix_blocks],
                                                )
                                            {
                                                if let Err(e) = hybrid_cache
                                                    .restore_recurrent_state(slot_idx, &snapshots)
                                                {
                                                    tracing::warn!(
                                                        "Failed restoring paged recurrent prefix state for sequence {}: {e}",
                                                        seq.id()
                                                    );
                                                    fallback_to_full_prompt = true;
                                                }
                                            } else {
                                                tracing::warn!(
                                                    "No recurrent prefix snapshot for sequence {} at cached prefix length {}; recomputing full prompt.",
                                                    seq.id(),
                                                    cached_prefix_len
                                                );
                                                fallback_to_full_prompt = true;
                                            }
                                        }
                                    }

                                    if fallback_to_full_prompt {
                                        // Drop the (partially reused) blocks
                                        // and allocate a fresh set for a full
                                        // prompt recompute.
                                        let seq_id = *seq.id();
                                        let num_tokens = seq.get_toks().len();
                                        let mut kv_mgr = get_mut_arcmutex!(kv_cache_manager);
                                        kv_mgr.free(seq_id);
                                        let realloc_ok = kv_mgr
                                            .allocate_slots(seq_id, num_tokens, &[])
                                            .is_some();
                                        drop(kv_mgr);

                                        if !realloc_ok {
                                            tracing::warn!(
                                                "Failed to reallocate fresh paged KV blocks for sequence {} after recurrent-prefix fallback.",
                                                seq_id
                                            );
                                            seq.set_state(SequenceState::FinishedIgnored);
                                        }
                                        seq.set_prefix_cache_len(0);
                                    }
                                }

                                // Drop sequences finished during fallback so
                                // they are not stepped.
                                guards_mut.retain(|seq| !seq.is_finished_paged_attn());
                            }

                            if guards_mut.is_empty() {
                                // Everything fell out above; nothing to step.
                                Ok(Duration::ZERO)
                            } else {
                                let metadata = PagedAttentionMeta {
                                    block_size,
                                    sliding_window: pipeline.get_metadata().sliding_window,
                                    kv_cache_manager: scheduler.kv_cache_manager().unwrap(),
                                };

                                let return_raw_logits = guards_mut[0].return_raw_logits;
                                assert!(
                                    guards_mut
                                        .iter()
                                        .all(|seq| seq.return_raw_logits == return_raw_logits),
                                    "All sequences must either return raw logits, or not."
                                );

                                pipeline
                                    .step(
                                        &mut guards_mut,
                                        is_prompt,
                                        return_raw_logits,
                                        &mut *get_mut_arcmutex!(self.prefix_cacher),
                                        self.disable_eos_stop,
                                        rng.clone(),
                                        CacheBackendMetadata::PagedAttention { metadata },
                                    )
                                    .await
                            }
                        };

                        handle_pipeline_forward_error!(
                            "step",
                            res,
                            &mut guards_mut,
                            self.pipeline,
                            'lp,
                            self.prefix_cacher
                        );

                        // Prompt sequences process all their tokens; decode
                        // sequences process exactly one.
                        let total_processed_tokens: usize = guards_mut
                            .iter()
                            .map(|seq| {
                                if seq.is_prompt() {
                                    seq.get_toks().len()
                                } else {
                                    1
                                }
                            })
                            .sum();
                        self.logger.add_tokens_processed(total_processed_tokens);

                        // After stepping, snapshot recurrent state at block
                        // boundaries so later requests can reuse it via the
                        // prefix cache.
                        {
                            let pipeline = get_mut_arcmutex!(self.pipeline);
                            if pipeline.cache().is_hybrid() {
                                let block_size = scheduler.block_size().unwrap();
                                let hybrid_cache = pipeline.cache().hybrid();
                                let mut prefix_cacher = get_mut_arcmutex!(self.prefix_cacher);

                                for seq in guards_mut.iter() {
                                    let seq_len = seq.get_toks().len();
                                    // Only snapshot exactly at block boundaries.
                                    if seq_len == 0 || seq_len % block_size != 0 {
                                        continue;
                                    }

                                    let Some(slot_idx) = seq.recurrent_state_idx() else {
                                        continue;
                                    };

                                    let snapshots = match hybrid_cache
                                        .snapshot_recurrent_state(slot_idx)
                                    {
                                        Ok(snapshots) => snapshots,
                                        Err(e) => {
                                            tracing::warn!(
                                                "Failed snapshotting recurrent state for sequence {}: {e}",
                                                seq.id()
                                            );
                                            continue;
                                        }
                                    };
                                    if snapshots.is_empty() {
                                        continue;
                                    }

                                    let num_blocks = seq_len / block_size;
                                    let block_hashes = compute_block_hashes(
                                        seq.get_toks(),
                                        block_size,
                                        seq.mm_features(),
                                        &[],
                                    );
                                    if block_hashes.len() < num_blocks {
                                        continue;
                                    }
                                    prefix_cacher.add_paged_recurrent_prefix(
                                        block_hashes[..num_blocks].to_vec(),
                                        snapshots,
                                    );
                                }
                            }
                        }

                        // Optional per-step debug trace of batch composition.
                        if self.is_debug {
                            // Value is in seconds; converted to ms at the log site.
                            let ms_from_last_run = run_start.elapsed().as_secs_f64();
                            let total_len = guards.len();
                            if total_len > 0 {
                                let lengths = guards
                                    .iter()
                                    .map(|seq| seq.len().to_string())
                                    .collect::<Vec<_>>()
                                    .join(", ");

                                let (prompt_lengths, completion_lengths) = if is_prompt {
                                    (lengths, "".to_string())
                                } else {
                                    ("".to_string(), lengths)
                                };

                                tracing::info!(
                                    "Prompt[{}] Completion[{}] - {}ms",
                                    prompt_lengths,
                                    completion_lengths,
                                    ms_from_last_run * 1000.,
                                );
                            }
                        }

                        // Record prefill timing stats for prompt sequences.
                        if is_prompt {
                            #[allow(clippy::cast_precision_loss)]
                            for mut seq in guards {
                                if let Some(start) = seq.step_start_instant {
                                    let duration = start.elapsed();
                                    seq.prompt_tok_per_sec =
                                        seq.len() as f32 / duration.as_secs_f32();
                                    seq.total_prompt_time = Some(duration.as_millis());
                                    seq.step_start_instant = None;
                                }
                                let now = SystemTime::now()
                                    .duration_since(UNIX_EPOCH)
                                    .expect("Time travel has occurred!")
                                    .as_millis();
                                seq.prompt_timestamp = Some(now);
                            }
                        }
                    }
                }
            }

            // Release recurrent-state slots held by sequences that finished
            // this step (hybrid models with a KV cache only).
            {
                let pipeline = get_mut_arcmutex!(self.pipeline);
                if !pipeline.get_metadata().no_kv_cache && pipeline.cache().is_hybrid() {
                    let recurrent_indices = scheduler.get_finished_recurrent_indices();
                    if !recurrent_indices.is_empty() {
                        let mut hybrid_cache = pipeline.cache().hybrid();
                        for idx in recurrent_indices {
                            hybrid_cache.free_seq(idx);
                        }
                    }
                }
            }
            scheduler.free_finished_sequence_groups();
        }
    }
833
834 fn build_sequence_recognizer(
835 factory: &Option<Arc<ParserFactory>>,
836 constraint: &Constraint,
837 ) -> anyhow::Result<SequenceRecognizer> {
838 if let Some(grm) = llg_grammar_from_constraint(constraint)? {
839 let factory = factory
840 .as_ref()
841 .ok_or_else(|| anyhow::anyhow!("No token environment (llg_factory) found."))?;
842 let llg = constraint_from_llg_grammar(factory, grm)?;
843 Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
844 } else {
845 Ok(SequenceRecognizer::None)
846 }
847 }
848
849 fn replicate_request_to_daemons(&self, request: &Request) {
850 if !distributed::is_daemon() && mistralrs_quant::distributed::use_nccl() {
851 let name = distributed::ipc_name().unwrap();
852 let num_workers =
853 mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
854 let listener = ListenerOptions::new().name(name).create_sync().unwrap();
855
856 for _ in 0..num_workers {
857 let stream = listener.accept().unwrap();
858 let mut writer = BufWriter::new(stream);
859 let req = format!("{}\n", serde_json::to_string(&request).unwrap());
860 writer.write_all(req.as_bytes()).unwrap();
861 }
862 } else if !distributed::is_daemon() && cfg!(feature = "ring") {
863 let num_workers =
864 mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
865 let master_port = RingConfig::load().master_port;
866 let listener =
867 TcpListener::bind(format!("0.0.0.0:{master_port}")).expect("bind replicator");
868
869 for _ in 0..num_workers {
870 let (stream, _) = listener.accept().unwrap();
871 let mut writer = BufWriter::new(stream);
872 let req = format!("{}\n", serde_json::to_string(&request).unwrap());
873 writer.write_all(req.as_bytes()).unwrap();
874 }
875 }
876 }
877}