// mistralrs_core/engine/mod.rs
1use crate::{
2    distributed,
3    paged_attention::block_hash::compute_block_hashes,
4    pipeline::{
5        llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
6        text_models_inputs_processor::PagedAttentionMeta,
7        CacheBackendMetadata, CacheInstruction,
8    },
9    prefix_cacher::PrefixCacheManagerV2,
10    response::CompletionChoice,
11    scheduler::{Scheduler, SchedulerOutput},
12    search::{self, rag::SearchPipeline},
13    sequence::{SeqStepType, StopReason},
14    tools, CompletionResponse, SchedulerConfig, DEBUG,
15};
16use interprocess::local_socket::{traits::Listener, ListenerOptions};
17use llguidance::ParserFactory;
18pub use logger::IntervalLogger;
19use mistralrs_quant::RingConfig;
20use rand::SeedableRng;
21use rand_isaac::Isaac64Rng;
22use serde::{Deserialize, Serialize};
23use std::{
24    collections::HashMap,
25    fmt,
26    io::{BufWriter, Write},
27    net::TcpListener,
28    ops::Deref,
29    str::FromStr,
30    sync::{
31        atomic::{AtomicBool, Ordering},
32        Arc, LazyLock,
33    },
34    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
35};
36use tokio::{
37    select,
38    sync::{
39        mpsc::{error::TryRecvError, Receiver, Sender},
40        Mutex, Notify,
41    },
42    task::JoinHandle,
43};
44
45use crate::{
46    get_mut_arcmutex, handle_pipeline_forward_error,
47    pipeline::{ModelCategory, Pipeline},
48    request::Request,
49    response::{ChatCompletionResponse, Choice, ResponseMessage},
50    sequence::{SequenceRecognizer, SequenceState},
51    Constraint,
52};
53
// Submodules: request admission (`add_request`), interval throughput logging
// (`logger`, re-exported above as `IntervalLogger`), and web-search request
// handling (`search_request`).
mod add_request;
mod logger;
mod search_request;
57
/// Control-plane instructions that can be posted to a running engine via
/// [`ENGINE_INSTRUCTIONS`]; the run loop polls for them each step.
pub enum EngineInstruction {
    /// Stop the engine's run loop.
    Terminate,
}
61
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
/// Embedding model used for ranking web search results internally.
pub enum SearchEmbeddingModel {
    /// `google/embeddinggemma-300m` (the default).
    #[default]
    #[serde(rename = "embedding_gemma")]
    EmbeddingGemma300M,
}
70
71impl SearchEmbeddingModel {
72    pub fn hf_model_id(&self) -> &'static str {
73        match self {
74            Self::EmbeddingGemma300M => "google/embeddinggemma-300m",
75        }
76    }
77
78    pub fn variants() -> &'static [&'static str] {
79        &["embedding_gemma"]
80    }
81}
82
83impl fmt::Display for SearchEmbeddingModel {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        match self {
86            Self::EmbeddingGemma300M => f.write_str("embedding_gemma"),
87        }
88    }
89}
90
91impl FromStr for SearchEmbeddingModel {
92    type Err = String;
93
94    fn from_str(s: &str) -> Result<Self, Self::Err> {
95        match s.trim().to_ascii_lowercase().as_str() {
96            "embedding_gemma" => Ok(Self::EmbeddingGemma300M),
97            other => Err(format!(
98                "Unknown search embedding model `{other}`. Supported values: {}",
99                Self::variants().join(", ")
100            )),
101        }
102    }
103}
104
/// Seed for the engine's sampling RNG (see `Isaac64Rng::seed_from_u64` in `run`).
const SEED: u64 = 0;
/// Terminate all sequences on the next scheduling step. Be sure to reset this.
/// This is a global flag for terminating all engines at once (e.g., Ctrl+C).
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);

/// Engine-specific termination flags, per Engine thread ID.
///
/// NOTE(review): entries are created lazily and never removed, so this map
/// grows by one entry per engine thread; assumes engine threads are few and
/// long-lived — confirm if engines are spawned dynamically.
static ENGINE_TERMINATE_FLAGS: LazyLock<
    std::sync::Mutex<HashMap<std::thread::ThreadId, Arc<AtomicBool>>>,
> = LazyLock::new(|| std::sync::Mutex::new(HashMap::new()));

/// Get or create a termination flag for the current engine thread.
///
/// The returned flag may be shared across threads; storing `true` requests
/// that this engine terminate its sequences on the next step.
pub fn get_engine_terminate_flag() -> Arc<AtomicBool> {
    let thread_id = std::thread::current().id();
    // Use `expect` with a message, consistent with how `ENGINE_INSTRUCTIONS`
    // is locked elsewhere in this module, so a poisoned lock is diagnosable.
    let mut flags = ENGINE_TERMINATE_FLAGS
        .lock()
        .expect("`ENGINE_TERMINATE_FLAGS` was poisoned");
    flags
        .entry(thread_id)
        .or_insert_with(|| Arc::new(AtomicBool::new(false)))
        .clone()
}

/// Check if the current engine should terminate sequences.
///
/// Returns `true` if either the global [`TERMINATE_ALL_NEXT_STEP`] flag or
/// the current thread's engine-specific flag is set.
pub fn should_terminate_engine_sequences() -> bool {
    // Check global flag first
    if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
        return true;
    }
    // Then check engine-specific flag. A poisoned lock is deliberately treated
    // as "no termination requested" here rather than panicking.
    let thread_id = std::thread::current().id();
    if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
        if let Some(flag) = flags.get(&thread_id) {
            return flag.load(Ordering::SeqCst);
        }
    }
    false
}

/// Reset termination flags for the current engine.
///
/// Note: this clears only the engine-specific flag for the current thread; it
/// does NOT touch the global [`TERMINATE_ALL_NEXT_STEP`] flag.
pub fn reset_engine_terminate_flag() {
    let thread_id = std::thread::current().id();
    if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
        if let Some(flag) = flags.get(&thread_id) {
            flag.store(false, Ordering::SeqCst);
        }
    }
}
150
/// Engine instructions, per Engine (MistralRs) ID.
///
/// The engine's run loop polls its own entry each step and exits when it is
/// `Some(EngineInstruction::Terminate)`.
pub static ENGINE_INSTRUCTIONS: LazyLock<
    std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>,
> = LazyLock::new(|| std::sync::Mutex::new(HashMap::new()));
155
/// A single inference engine: owns a request channel, a model [`Pipeline`],
/// and a [`Scheduler`], and drives sequence scheduling and stepping in `run`.
pub struct Engine {
    /// Sender half of the request channel (the receiver is `rx`).
    tx: Sender<Request>,
    /// Receiver polled by the run loop for incoming [`Request`]s.
    rx: Arc<Mutex<Receiver<Request>>>,
    /// The model pipeline that executes prompt/completion steps.
    pipeline: Arc<Mutex<dyn Pipeline>>,
    /// Optional embedding pipeline used to rank web-search results.
    search_pipeline: Arc<Mutex<Option<SearchPipeline>>>,
    /// Optional search callback; consumed by request handling outside this view.
    search_callback: Option<Arc<search::SearchCallback>>,
    // Tool-call callback registries (see the `tools` module), with and
    // without associated tool definitions.
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    /// Scheduler deciding which sequences run on each step.
    scheduler: Arc<Mutex<dyn Scheduler>>,
    /// Engine ID used to look up per-engine [`ENGINE_INSTRUCTIONS`].
    id: Arc<Mutex<usize>>,
    /// When true, KV caching is disabled and caches are reset after each step.
    no_kv_cache: bool,
    /// Prefix-cache manager shared with pipeline steps.
    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
    /// Snapshot of the global `DEBUG` flag taken at construction.
    is_debug: bool,
    /// Passed to `Pipeline::step`; presumably disables EOS-based stopping.
    disable_eos_stop: bool,
    /// Enables the interval logger when the run loop starts.
    throughput_logging_enabled: bool,
    /// Interval logger for token-throughput statistics.
    logger: Arc<IntervalLogger>,
    /// Handles of spawned background tasks; aborted when the engine drops.
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
    /// Notified to wake the run loop when it is idle.
    pending_notify: Arc<Notify>,
}
175
176impl Drop for Engine {
177    fn drop(&mut self) {
178        for handle in &*get_mut_arcmutex!(self.handles) {
179            handle.abort();
180        }
181    }
182}
183
184impl Engine {
185    #[allow(clippy::too_many_arguments)]
186    pub fn new(
187        tx: Sender<Request>,
188        rx: Receiver<Request>,
189        pipeline: Arc<Mutex<dyn Pipeline>>,
190        config: SchedulerConfig,
191        mut no_kv_cache: bool,
192        mut no_prefix_cache: bool,
193        prefix_cache_n: usize,
194        disable_eos_stop: bool,
195        throughput_logging_enabled: bool,
196        search_embedding_model: Option<SearchEmbeddingModel>,
197        search_callback: Option<Arc<search::SearchCallback>>,
198        tool_callbacks: tools::ToolCallbacks,
199        tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
200        logger: Arc<IntervalLogger>,
201    ) -> anyhow::Result<Self> {
202        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;
203
204        no_prefix_cache = no_prefix_cache
205            || no_kv_cache
206            || get_mut_arcmutex!(pipeline).get_metadata().no_prefix_cache
207            || prefix_cache_n == 0;
208
209        let search_pipeline = match search_embedding_model {
210            Some(search_embedding_model) => Some(SearchPipeline::new(
211                search_embedding_model,
212                &get_mut_arcmutex!(pipeline).device(),
213            )?),
214            None => None,
215        };
216
217        let scheduler = config.into_scheduler();
218
219        // Configure prefix caching on the scheduler based on the global no_prefix_cache flag
220        // This ensures PagedAttention prefix caching respects the same setting
221        get_mut_arcmutex!(scheduler).set_prefix_caching_enabled(!no_prefix_cache);
222
223        let has_paged_attention = get_mut_arcmutex!(scheduler).kv_cache_manager().is_some();
224
225        Ok(Self {
226            tx,
227            rx: Arc::new(Mutex::new(rx)),
228            pipeline,
229            search_pipeline: Arc::new(Mutex::new(search_pipeline)),
230            search_callback,
231            tool_callbacks,
232            tool_callbacks_with_tools,
233            scheduler: scheduler.clone(),
234            id: Arc::new(Mutex::new(0)),
235            no_kv_cache,
236            prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
237                prefix_cache_n,
238                no_prefix_cache,
239                has_paged_attention,
240            ))),
241            is_debug: DEBUG.load(Ordering::Relaxed),
242            disable_eos_stop,
243            throughput_logging_enabled,
244            logger,
245            handles: Arc::new(Mutex::new(Vec::new())),
246            pending_notify: Arc::new(Notify::new()),
247        })
248    }
249
250    /// Returns the maximum supported sequence length for the underlying model, if applicable.
251    #[allow(dead_code)]
252    pub fn max_sequence_length(&self) -> Option<usize> {
253        let pipeline = get_mut_arcmutex!(self.pipeline);
254        let category = pipeline.category();
255
256        if matches!(category, ModelCategory::Diffusion | ModelCategory::Speech) {
257            None
258        } else {
259            Some(pipeline.get_metadata().max_seq_len)
260        }
261    }
262
263    pub async fn run(self: Arc<Self>) {
264        if self.throughput_logging_enabled {
265            self.logger.enable_logging();
266        }
267
268        let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
269        let mut last_completion_ids: Vec<usize> = vec![];
270        'lp: loop {
271            let should_terminate = || {
272                matches!(
273                    ENGINE_INSTRUCTIONS
274                        .lock()
275                        .expect("`ENGINE_INSTRUCTIONS` was poisoned")
276                        .get(get_mut_arcmutex!(self.id).deref()),
277                    Some(Some(EngineInstruction::Terminate))
278                )
279            };
280
281            if should_terminate() {
282                self.replicate_request_to_daemons(&Request::Terminate);
283                break 'lp;
284            }
285
286            let mut channel_disconnected = false;
287            loop {
288                let next_request = {
289                    let mut rx = self.rx.lock().await;
290                    rx.try_recv()
291                };
292
293                match next_request {
294                    Ok(request) => {
295                        self.replicate_request_to_daemons(&request);
296                        if matches!(request, Request::Terminate) {
297                            break 'lp;
298                        }
299                        self.clone().handle_request(request).await;
300                    }
301                    Err(TryRecvError::Empty) => break,
302                    Err(TryRecvError::Disconnected) => {
303                        channel_disconnected = true;
304                        break;
305                    }
306                }
307            }
308
309            if channel_disconnected {
310                break 'lp;
311            }
312
313            let (waiting_len, running_len) = {
314                let scheduler = get_mut_arcmutex!(self.scheduler);
315                (scheduler.waiting_len(), scheduler.running_len())
316            };
317            let scheduler_idle = waiting_len == 0 && running_len == 0;
318
319            if scheduler_idle {
320                if should_terminate() {
321                    self.replicate_request_to_daemons(&Request::Terminate);
322                    break 'lp;
323                }
324                enum WaitEvent {
325                    Request(Option<Request>),
326                    Wake,
327                }
328                let wait_for_request = async {
329                    let mut rx = self.rx.lock().await;
330                    rx.recv().await
331                };
332                tokio::pin!(wait_for_request);
333                let wait_for_wake = self.pending_notify.notified();
334                tokio::pin!(wait_for_wake);
335
336                let event = select! {
337                    res = &mut wait_for_request => WaitEvent::Request(res),
338                    _ = &mut wait_for_wake => WaitEvent::Wake,
339                };
340
341                match event {
342                    WaitEvent::Request(Some(request)) => {
343                        self.replicate_request_to_daemons(&request);
344                        if matches!(request, Request::Terminate) {
345                            break 'lp;
346                        }
347                        self.clone().handle_request(request).await;
348                        continue;
349                    }
350                    WaitEvent::Request(None) => break 'lp,
351                    WaitEvent::Wake => {
352                        continue;
353                    }
354                }
355            }
356
357            if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
358                self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
359            }
360
361            let run_start = Instant::now();
362            let mut scheduler = get_mut_arcmutex!(self.scheduler);
363            let scheduled = scheduler.schedule(&self.logger);
364
365            match scheduled {
366                SchedulerOutput::DefaultScheduler {
367                    output: mut scheduled,
368                } => {
369                    if !scheduled.completion.is_empty() {
370                        let current_completion_ids: Vec<usize> =
371                            scheduled.completion.iter().map(|seq| *seq.id()).collect();
372                        let res = {
373                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
374                            let pre_op = if !self.no_kv_cache
375                                && last_completion_ids != current_completion_ids
376                            {
377                                CacheInstruction::In
378                            } else {
379                                CacheInstruction::Nothing
380                            };
381                            let post_op = if !self.no_kv_cache {
382                                CacheInstruction::Out
383                            } else {
384                                CacheInstruction::Reset {
385                                    load_preallocated_cache: false,
386                                    reset_non_granular: false,
387                                }
388                            };
389
390                            let return_raw_logits = scheduled.completion[0].return_raw_logits;
391                            assert!(
392                                scheduled
393                                    .completion
394                                    .iter()
395                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
396                                "All sequences must either return raw logits, or not."
397                            );
398
399                            pipeline
400                                .step(
401                                    &mut scheduled.completion,
402                                    false,
403                                    return_raw_logits,
404                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
405                                    self.disable_eos_stop,
406                                    rng.clone(),
407                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
408                                )
409                                .await
410                        };
411
412                        handle_pipeline_forward_error!(
413                            "completion step",
414                            res,
415                            &mut scheduled.completion,
416                            self.pipeline,
417                            'lp,
418                            self.prefix_cacher
419                        );
420
421                        self.logger.add_tokens_processed(scheduled.completion.len());
422
423                        last_completion_ids = current_completion_ids;
424                    }
425
426                    if !scheduled.prompt.is_empty() {
427                        let prompt_exec_time = {
428                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
429
430                            // Run the prompt seqs
431                            let post_op = if !self.no_kv_cache {
432                                CacheInstruction::Out
433                            } else {
434                                CacheInstruction::Reset {
435                                    load_preallocated_cache: false,
436                                    reset_non_granular: false,
437                                }
438                            };
439
440                            let return_raw_logits = scheduled.prompt[0].return_raw_logits;
441                            assert!(
442                                scheduled
443                                    .prompt
444                                    .iter()
445                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
446                                "All sequences must either return raw logits, or not."
447                            );
448
449                            // This comes from prefix caching
450                            // The invariant where all token offsets are the same is handled by the scheduler
451                            let pre_op = if scheduled.prompt[0].token_offset() != 0 {
452                                CacheInstruction::In
453                            } else {
454                                CacheInstruction::Reset {
455                                    load_preallocated_cache: true,
456                                    reset_non_granular: false,
457                                }
458                            };
459
460                            pipeline
461                                .step(
462                                    &mut scheduled.prompt,
463                                    true,
464                                    return_raw_logits,
465                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
466                                    self.disable_eos_stop,
467                                    rng.clone(),
468                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
469                                )
470                                .await
471                        };
472
473                        let prompt_exec_time = handle_pipeline_forward_error!(
474                            "prompt step",
475                            prompt_exec_time,
476                            &mut scheduled.prompt,
477                            self.pipeline,
478                            'lp,
479                            self.prefix_cacher
480                        );
481
482                        let total_processed_tokens: usize = scheduled
483                            .prompt
484                            .iter()
485                            .map(|seq| seq.get_toks().len())
486                            .sum();
487                        self.logger.add_tokens_processed(total_processed_tokens);
488
489                        for seq in scheduled.prompt.iter_mut() {
490                            match seq.sequence_stepping_type() {
491                                SeqStepType::OneShot => {
492                                    seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
493                                }
494                                SeqStepType::PromptAndDecode => {
495                                    seq.set_state(SequenceState::RunningCompletion)
496                                }
497                            }
498                            let now = SystemTime::now()
499                                .duration_since(UNIX_EPOCH)
500                                .expect("Time travel has occurred!")
501                                .as_millis();
502                            #[allow(clippy::cast_precision_loss)]
503                            let prompt_tok_per_sec =
504                                seq.len() as f32 / prompt_exec_time.as_secs_f32();
505                            seq.prompt_tok_per_sec = prompt_tok_per_sec;
506                            seq.prompt_timestamp = Some(now);
507                            seq.total_prompt_time = Some(prompt_exec_time.as_millis());
508                            seq.step_start_instant = None;
509                        }
510                        last_completion_ids = vec![];
511                    }
512
513                    if self.is_debug {
514                        let ms_from_last_run = run_start.elapsed().as_secs_f64();
515                        let total_len = scheduled.prompt.len() + scheduled.completion.len();
516                        if total_len > 0 {
517                            let prompt_lengths = scheduled
518                                .prompt
519                                .iter()
520                                .map(|seq| seq.len().to_string())
521                                .collect::<Vec<_>>()
522                                .join(", ");
523
524                            let completion_lengths = scheduled
525                                .completion
526                                .iter()
527                                .map(|seq| seq.len().to_string())
528                                .collect::<Vec<_>>()
529                                .join(", ");
530
531                            tracing::info!(
532                                "Prompt[{}] Completion[{}] - {}ms",
533                                prompt_lengths,
534                                completion_lengths,
535                                ms_from_last_run * 1000.,
536                            );
537                        }
538                    }
539                }
540                SchedulerOutput::PagedAttention { mut output } => {
541                    if !output.scheduled.is_empty() {
542                        let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();
543
544                        // Record prompt timing BEFORE step() so it's available if response is sent inside step()
545                        if is_prompt {
546                            let now = SystemTime::now()
547                                .duration_since(UNIX_EPOCH)
548                                .expect("Time travel has occurred!")
549                                .as_millis();
550                            for seq in output.scheduled.iter() {
551                                let mut seq_guard = get_mut_arcmutex!(seq);
552                                seq_guard.prompt_timestamp = Some(now);
553                                // Start the timer using Instant for accurate duration measurement
554                                seq_guard.set_step_start_instant();
555                            }
556                        }
557
558                        let mut guards = output
559                            .scheduled
560                            .iter_mut()
561                            .map(|seq| seq.lock().unwrap())
562                            .collect::<Vec<_>>();
563
564                        let mut guards_mut =
565                            guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();
566
567                        let res = {
568                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
569
570                            let block_size = scheduler.block_size().unwrap();
571
572                            // For hybrid models under paged attention, restore recurrent state
573                            // from block-hash keyed prefix snapshots before prompt prefill.
574                            if is_prompt && pipeline.cache().is_hybrid() {
575                                let mut hybrid_cache = pipeline.cache().hybrid();
576                                let mut prefix_cacher = get_mut_arcmutex!(self.prefix_cacher);
577                                let kv_cache_manager = scheduler.kv_cache_manager().unwrap();
578
579                                for seq in guards_mut.iter_mut() {
580                                    let cached_prefix_len = seq.prefix_cache_len();
581                                    if cached_prefix_len == 0 {
582                                        continue;
583                                    }
584
585                                    let mut fallback_to_full_prompt = false;
586
587                                    let slot_idx = match seq.recurrent_state_idx() {
588                                        Some(idx) => idx,
589                                        None => {
590                                            tracing::warn!("Sequence {} has paged prefix hit but no recurrent_state_idx; recomputing full prompt.", seq.id());
591                                            fallback_to_full_prompt = true;
592                                            // Dummy value, unused in fallback path.
593                                            0usize
594                                        }
595                                    };
596
597                                    if !fallback_to_full_prompt {
598                                        if cached_prefix_len % block_size != 0 {
599                                            tracing::warn!(
600                                                "Sequence {} has non-aligned paged prefix len {}; recomputing full prompt.",
601                                                seq.id(),
602                                                cached_prefix_len
603                                            );
604                                            fallback_to_full_prompt = true;
605                                        } else {
606                                            let num_prefix_blocks = cached_prefix_len / block_size;
607                                            let block_hashes = compute_block_hashes(
608                                                seq.get_toks(),
609                                                block_size,
610                                                seq.mm_features(),
611                                                &[],
612                                            );
613                                            if block_hashes.len() < num_prefix_blocks {
614                                                fallback_to_full_prompt = true;
615                                            } else if let Some(snapshots) = prefix_cacher
616                                                .get_paged_recurrent_prefix(
617                                                    &block_hashes[..num_prefix_blocks],
618                                                )
619                                            {
620                                                if let Err(e) = hybrid_cache
621                                                    .restore_recurrent_state(slot_idx, &snapshots)
622                                                {
623                                                    tracing::warn!(
624                                                        "Failed restoring paged recurrent prefix state for sequence {}: {e}",
625                                                        seq.id()
626                                                    );
627                                                    fallback_to_full_prompt = true;
628                                                }
629                                            } else {
630                                                tracing::warn!(
631                                                    "No recurrent prefix snapshot for sequence {} at cached prefix length {}; recomputing full prompt.",
632                                                    seq.id(),
633                                                    cached_prefix_len
634                                                );
635                                                fallback_to_full_prompt = true;
636                                            }
637                                        }
638                                    }
639
640                                    if fallback_to_full_prompt {
641                                        let seq_id = *seq.id();
642                                        let num_tokens = seq.get_toks().len();
643                                        let mut kv_mgr = get_mut_arcmutex!(kv_cache_manager);
644                                        kv_mgr.free(seq_id);
645                                        let realloc_ok = kv_mgr
646                                            .allocate_slots(seq_id, num_tokens, &[])
647                                            .is_some();
648                                        drop(kv_mgr);
649
650                                        if !realloc_ok {
651                                            tracing::warn!(
652                                                "Failed to reallocate fresh paged KV blocks for sequence {} after recurrent-prefix fallback.",
653                                                seq_id
654                                            );
655                                            seq.set_state(SequenceState::FinishedIgnored);
656                                        }
657                                        seq.set_prefix_cache_len(0);
658                                    }
659                                }
660
661                                // Drop sequences that were canceled due fallback allocation failures.
662                                guards_mut.retain(|seq| !seq.is_finished_paged_attn());
663                            }
664
665                            if guards_mut.is_empty() {
666                                Ok(Duration::ZERO)
667                            } else {
668                                let metadata = PagedAttentionMeta {
669                                    block_size,
670                                    sliding_window: pipeline.get_metadata().sliding_window,
671                                    kv_cache_manager: scheduler.kv_cache_manager().unwrap(),
672                                };
673
674                                let return_raw_logits = guards_mut[0].return_raw_logits;
675                                assert!(
676                                    guards_mut
677                                        .iter()
678                                        .all(|seq| seq.return_raw_logits == return_raw_logits),
679                                    "All sequences must either return raw logits, or not."
680                                );
681
682                                pipeline
683                                    .step(
684                                        &mut guards_mut,
685                                        is_prompt,
686                                        return_raw_logits,
687                                        &mut *get_mut_arcmutex!(self.prefix_cacher),
688                                        self.disable_eos_stop,
689                                        rng.clone(),
690                                        CacheBackendMetadata::PagedAttention { metadata },
691                                    )
692                                    .await
693                            }
694                        };
695
696                        handle_pipeline_forward_error!(
697                            "step",
698                            res,
699                            &mut guards_mut,
700                            self.pipeline,
701                            'lp,
702                            self.prefix_cacher
703                        );
704
705                        let total_processed_tokens: usize = guards_mut
706                            .iter()
707                            .map(|seq| {
708                                if seq.is_prompt() {
709                                    seq.get_toks().len()
710                                } else {
711                                    1
712                                }
713                            })
714                            .sum();
715                        self.logger.add_tokens_processed(total_processed_tokens);
716
717                        // Capture recurrent states at full-block boundaries so hybrid models can
718                        // reuse recurrent prefix state when paged prefix caching hits.
719                        {
720                            let pipeline = get_mut_arcmutex!(self.pipeline);
721                            if pipeline.cache().is_hybrid() {
722                                let block_size = scheduler.block_size().unwrap();
723                                let hybrid_cache = pipeline.cache().hybrid();
724                                let mut prefix_cacher = get_mut_arcmutex!(self.prefix_cacher);
725
726                                for seq in guards_mut.iter() {
727                                    let seq_len = seq.get_toks().len();
728                                    if seq_len == 0 || seq_len % block_size != 0 {
729                                        continue;
730                                    }
731
732                                    let Some(slot_idx) = seq.recurrent_state_idx() else {
733                                        continue;
734                                    };
735
736                                    let snapshots = match hybrid_cache
737                                        .snapshot_recurrent_state(slot_idx)
738                                    {
739                                        Ok(snapshots) => snapshots,
740                                        Err(e) => {
741                                            tracing::warn!(
742                                                    "Failed snapshotting recurrent state for sequence {}: {e}",
743                                                    seq.id()
744                                                );
745                                            continue;
746                                        }
747                                    };
748                                    if snapshots.is_empty() {
749                                        continue;
750                                    }
751
752                                    let num_blocks = seq_len / block_size;
753                                    let block_hashes = compute_block_hashes(
754                                        seq.get_toks(),
755                                        block_size,
756                                        seq.mm_features(),
757                                        &[],
758                                    );
759                                    if block_hashes.len() < num_blocks {
760                                        continue;
761                                    }
762                                    prefix_cacher.add_paged_recurrent_prefix(
763                                        block_hashes[..num_blocks].to_vec(),
764                                        snapshots,
765                                    );
766                                }
767                            }
768                        }
769
770                        if self.is_debug {
771                            let ms_from_last_run = run_start.elapsed().as_secs_f64();
772                            let total_len = guards.len();
773                            if total_len > 0 {
774                                let lengths = guards
775                                    .iter()
776                                    .map(|seq| seq.len().to_string())
777                                    .collect::<Vec<_>>()
778                                    .join(", ");
779
780                                let (prompt_lengths, completion_lengths) = if is_prompt {
781                                    (lengths, "".to_string())
782                                } else {
783                                    ("".to_string(), lengths)
784                                };
785
786                                tracing::info!(
787                                    "Prompt[{}] Completion[{}] - {}ms",
788                                    prompt_lengths,
789                                    completion_lengths,
790                                    ms_from_last_run * 1000.,
791                                );
792                            }
793                        }
794
795                        if is_prompt {
796                            #[allow(clippy::cast_precision_loss)]
797                            for mut seq in guards {
798                                // Use Instant duration for accurate prompt timing
799                                if let Some(start) = seq.step_start_instant {
800                                    let duration = start.elapsed();
801                                    seq.prompt_tok_per_sec =
802                                        seq.len() as f32 / duration.as_secs_f32();
803                                    seq.total_prompt_time = Some(duration.as_millis());
804                                    seq.step_start_instant = None;
805                                }
806                                let now = SystemTime::now()
807                                    .duration_since(UNIX_EPOCH)
808                                    .expect("Time travel has occurred!")
809                                    .as_millis();
810                                seq.prompt_timestamp = Some(now);
811                            }
812                        }
813                    }
814                }
815            }
816
817            // Free recurrent state pool slots for finished sequences (hybrid models)
818            {
819                let pipeline = get_mut_arcmutex!(self.pipeline);
820                if !pipeline.get_metadata().no_kv_cache && pipeline.cache().is_hybrid() {
821                    let recurrent_indices = scheduler.get_finished_recurrent_indices();
822                    if !recurrent_indices.is_empty() {
823                        let mut hybrid_cache = pipeline.cache().hybrid();
824                        for idx in recurrent_indices {
825                            hybrid_cache.free_seq(idx);
826                        }
827                    }
828                }
829            }
830            scheduler.free_finished_sequence_groups();
831        }
832    }
833
834    fn build_sequence_recognizer(
835        factory: &Option<Arc<ParserFactory>>,
836        constraint: &Constraint,
837    ) -> anyhow::Result<SequenceRecognizer> {
838        if let Some(grm) = llg_grammar_from_constraint(constraint)? {
839            let factory = factory
840                .as_ref()
841                .ok_or_else(|| anyhow::anyhow!("No token environment (llg_factory) found."))?;
842            let llg = constraint_from_llg_grammar(factory, grm)?;
843            Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
844        } else {
845            Ok(SequenceRecognizer::None)
846        }
847    }
848
849    fn replicate_request_to_daemons(&self, request: &Request) {
850        if !distributed::is_daemon() && mistralrs_quant::distributed::use_nccl() {
851            let name = distributed::ipc_name().unwrap();
852            let num_workers =
853                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
854            let listener = ListenerOptions::new().name(name).create_sync().unwrap();
855
856            for _ in 0..num_workers {
857                let stream = listener.accept().unwrap();
858                let mut writer = BufWriter::new(stream);
859                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
860                writer.write_all(req.as_bytes()).unwrap();
861            }
862        } else if !distributed::is_daemon() && cfg!(feature = "ring") {
863            let num_workers =
864                mistralrs_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
865            let master_port = RingConfig::load().master_port;
866            let listener =
867                TcpListener::bind(format!("0.0.0.0:{master_port}")).expect("bind replicator");
868
869            for _ in 0..num_workers {
870                let (stream, _) = listener.accept().unwrap();
871                let mut writer = BufWriter::new(stream);
872                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
873                writer.write_all(req.as_bytes()).unwrap();
874            }
875        }
876    }
877}