Skip to main content

hanzo_engine/engine/
mod.rs

1use crate::{
2    distributed,
3    paged_attention::block_hash::compute_block_hashes,
4    pipeline::{
5        llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
6        text_models_inputs_processor::PagedAttentionMeta,
7        CacheBackendMetadata, CacheInstruction,
8    },
9    prefix_cacher::PrefixCacheManagerV2,
10    response::CompletionChoice,
11    scheduler::{Scheduler, SchedulerOutput},
12    search::{self, rag::SearchPipeline},
13    sequence::{SeqStepType, StopReason},
14    tools, CompletionResponse, SchedulerConfig, DEBUG,
15};
16use hanzo_quant::RingConfig;
17use interprocess::local_socket::{traits::Listener, ListenerOptions};
18use llguidance::ParserFactory;
19pub use logger::IntervalLogger;
20use rand::SeedableRng;
21use rand_isaac::Isaac64Rng;
22use serde::{Deserialize, Serialize};
23use std::{
24    collections::HashMap,
25    fmt,
26    io::{BufWriter, Write},
27    net::TcpListener,
28    ops::Deref,
29    str::FromStr,
30    sync::{
31        atomic::{AtomicBool, Ordering},
32        Arc, LazyLock,
33    },
34    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
35};
36use tokio::{
37    select,
38    sync::{
39        mpsc::{error::TryRecvError, Receiver, Sender},
40        Mutex, Notify,
41    },
42    task::JoinHandle,
43};
44
45use crate::{
46    get_mut_arcmutex, handle_pipeline_forward_error,
47    pipeline::{ModelCategory, Pipeline},
48    request::Request,
49    response::{ChatCompletionResponse, Choice, ResponseMessage},
50    sequence::{SequenceRecognizer, SequenceState},
51    Constraint,
52};
53
54mod add_request;
55pub(crate) mod agentic_loop;
56pub use agentic_loop::DEFAULT_MAX_TOOL_ROUNDS;
57pub(crate) mod agentic_session;
58mod file_tools;
59mod logger;
60mod tool_dispatch;
61
/// Out-of-band control instruction for a running engine, stored per engine ID in
/// `ENGINE_INSTRUCTIONS`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EngineInstruction {
    /// Stop the engine: `Engine::run` forwards `Request::Terminate` to its daemons and
    /// exits its main loop when it sees this instruction for its own ID.
    Terminate,
}
65
66#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
67#[serde(rename_all = "snake_case")]
68/// Embedding model used for ranking web search results internally.
69pub enum SearchEmbeddingModel {
70    #[default]
71    #[serde(rename = "embedding_gemma")]
72    EmbeddingGemma300M,
73}
74
75impl SearchEmbeddingModel {
76    pub fn hf_model_id(&self) -> &'static str {
77        match self {
78            Self::EmbeddingGemma300M => "google/embeddinggemma-300m",
79        }
80    }
81
82    pub fn variants() -> &'static [&'static str] {
83        &["embedding_gemma"]
84    }
85}
86
87impl fmt::Display for SearchEmbeddingModel {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        match self {
90            Self::EmbeddingGemma300M => f.write_str("embedding_gemma"),
91        }
92    }
93}
94
95impl FromStr for SearchEmbeddingModel {
96    type Err = String;
97
98    fn from_str(s: &str) -> Result<Self, Self::Err> {
99        match s.trim().to_ascii_lowercase().as_str() {
100            "embedding_gemma" => Ok(Self::EmbeddingGemma300M),
101            other => Err(format!(
102                "Unknown search embedding model `{other}`. Supported values: {}",
103                Self::variants().join(", ")
104            )),
105        }
106    }
107}
108
/// Seed for the `Isaac64Rng` that `Engine::run` creates and hands to every pipeline step.
const SEED: u64 = 0;

/// Terminate all sequences on the next scheduling step. Be sure to reset this.
/// This is a global flag for terminating all engines at once (e.g., Ctrl+C).
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);

/// Map from engine thread ID to that engine's termination flag.
type EngineTerminateFlags = HashMap<std::thread::ThreadId, Arc<AtomicBool>>;

/// Engine-specific termination flags, per Engine thread ID.
///
/// NOTE(review): flags are keyed by OS thread ID and nothing here removes entries, so one
/// flag is retained per thread that ever registered. If an engine's task can hop between
/// runtime worker threads, a flag registered on one thread is invisible from another —
/// confirm each engine runs on a dedicated thread.
static ENGINE_TERMINATE_FLAGS: LazyLock<std::sync::Mutex<EngineTerminateFlags>> =
    LazyLock::new(|| std::sync::Mutex::new(HashMap::new()));

/// Locks [`ENGINE_TERMINATE_FLAGS`], recovering the map if a previous holder panicked.
///
/// Every critical section in these helpers is a single lookup or insert, so a poisoned
/// map is still consistent. Recovering, rather than panicking or treating poison as
/// "no flag", means a poisoned lock can neither crash the caller nor hide a pending
/// termination request.
fn lock_engine_terminate_flags() -> std::sync::MutexGuard<'static, EngineTerminateFlags> {
    ENGINE_TERMINATE_FLAGS
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}

/// Get or create a termination flag for the current engine thread.
///
/// Repeated calls from the same thread return clones of the same `Arc`, so setting the
/// returned flag is observed by [`should_terminate_engine_sequences`] on that thread.
pub fn get_engine_terminate_flag() -> Arc<AtomicBool> {
    let thread_id = std::thread::current().id();
    let mut flags = lock_engine_terminate_flags();
    let flag = flags
        .entry(thread_id)
        .or_insert_with(|| Arc::new(AtomicBool::new(false)));
    Arc::clone(flag)
}

/// Check if the current engine should terminate sequences.
///
/// Returns `true` when the global [`TERMINATE_ALL_NEXT_STEP`] flag is set, or when the
/// current thread has a registered flag (see [`get_engine_terminate_flag`]) that is set.
pub fn should_terminate_engine_sequences() -> bool {
    // The global flag wins without touching the per-thread registry.
    if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
        return true;
    }
    let thread_id = std::thread::current().id();
    let flags = lock_engine_terminate_flags();
    flags
        .get(&thread_id)
        .is_some_and(|flag| flag.load(Ordering::SeqCst))
}

/// Reset termination flags for the current engine.
///
/// Clears only the current thread's flag (if one was registered); the global
/// [`TERMINATE_ALL_NEXT_STEP`] flag is left untouched.
pub fn reset_engine_terminate_flag() {
    let thread_id = std::thread::current().id();
    if let Some(flag) = lock_engine_terminate_flags().get(&thread_id) {
        flag.store(false, Ordering::SeqCst);
    }
}
154
155/// Engine instructions, per Engine (Hanzo) ID.
156pub static ENGINE_INSTRUCTIONS: LazyLock<
157    std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>,
158> = LazyLock::new(|| std::sync::Mutex::new(HashMap::new()));
159
/// Request-processing engine: drains incoming [`Request`]s, schedules sequences, and
/// steps the model pipeline (see `Engine::run`).
pub struct Engine {
    /// Request sender supplied at construction.
    /// NOTE(review): presumably the sending half of the same channel as `rx` — confirm
    /// at the call site of `Engine::new`.
    tx: Sender<Request>,
    /// Incoming request queue polled by `run`. Kept behind a tokio `Mutex` because `run`
    /// holds the lock across `recv().await` while the scheduler is idle.
    rx: Arc<Mutex<Receiver<Request>>>,
    /// Model pipeline that executes prompt and completion steps.
    pipeline: Arc<Mutex<dyn Pipeline>>,
    /// Embedding pipeline for ranking web search results; `None` unless a
    /// `SearchEmbeddingModel` was passed to `new`.
    search_pipeline: Arc<Mutex<Option<SearchPipeline>>>,
    /// Optional search callback supplied at construction.
    search_callback: Option<Arc<search::SearchCallback>>,
    /// Tool callbacks supplied at construction.
    tool_callbacks: tools::ToolCallbacksWithTools,
    /// Scheduler that selects which sequences run on each step.
    scheduler: Arc<Mutex<dyn Scheduler>>,
    /// This engine's ID, used as the key into `ENGINE_INSTRUCTIONS`; `new` sets it to 0.
    id: Arc<Mutex<usize>>,
    /// When `true`, the KV cache is reset after every step instead of being cached out.
    no_kv_cache: bool,
    /// Prefix cache passed to every pipeline `step` call.
    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
    /// Snapshot of the global `DEBUG` flag taken in `new`; enables per-step timing logs.
    is_debug: bool,
    /// Forwarded unchanged to every pipeline `step` call.
    disable_eos_stop: bool,
    /// When `true`, `run` enables throughput logging on `logger` at startup.
    throughput_logging_enabled: bool,
    /// Interval logger that accumulates processed-token counts.
    logger: Arc<IntervalLogger>,
    /// Task handles that are aborted when the engine is dropped.
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
    /// Notifying this wakes `run` while it is idle-waiting for a request.
    pending_notify: Arc<Notify>,
    /// Agentic session store supplied at construction (shared via `Arc`).
    pub(crate) session_store: Arc<std::sync::Mutex<agentic_session::AgenticSessionStore>>,
    /// File store supplied at construction.
    pub(crate) file_store: crate::files::FileStore,
}
180
181impl Drop for Engine {
182    fn drop(&mut self) {
183        for handle in &*get_mut_arcmutex!(self.handles) {
184            handle.abort();
185        }
186    }
187}
188
189impl Engine {
190    #[allow(clippy::too_many_arguments)]
191    pub fn new(
192        tx: Sender<Request>,
193        rx: Receiver<Request>,
194        pipeline: Arc<Mutex<dyn Pipeline>>,
195        config: SchedulerConfig,
196        mut no_kv_cache: bool,
197        mut no_prefix_cache: bool,
198        prefix_cache_n: usize,
199        disable_eos_stop: bool,
200        throughput_logging_enabled: bool,
201        search_embedding_model: Option<SearchEmbeddingModel>,
202        search_callback: Option<Arc<search::SearchCallback>>,
203        tool_callbacks: tools::ToolCallbacksWithTools,
204        logger: Arc<IntervalLogger>,
205        session_store: Arc<std::sync::Mutex<agentic_session::AgenticSessionStore>>,
206        file_store: crate::files::FileStore,
207    ) -> anyhow::Result<Self> {
208        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;
209
210        no_prefix_cache = no_prefix_cache
211            || no_kv_cache
212            || get_mut_arcmutex!(pipeline).get_metadata().no_prefix_cache
213            || prefix_cache_n == 0;
214
215        let search_pipeline = match search_embedding_model {
216            Some(search_embedding_model) => Some(SearchPipeline::new(
217                search_embedding_model,
218                &get_mut_arcmutex!(pipeline).device(),
219            )?),
220            None => None,
221        };
222
223        let scheduler = config.into_scheduler();
224
225        // Configure prefix caching on the scheduler based on the global no_prefix_cache flag
226        // This ensures PagedAttention prefix caching respects the same setting
227        get_mut_arcmutex!(scheduler).set_prefix_caching_enabled(!no_prefix_cache);
228
229        let has_paged_attention = get_mut_arcmutex!(scheduler).kv_cache_manager().is_some();
230
231        Ok(Self {
232            tx,
233            rx: Arc::new(Mutex::new(rx)),
234            pipeline,
235            search_pipeline: Arc::new(Mutex::new(search_pipeline)),
236            search_callback,
237            tool_callbacks,
238            scheduler: scheduler.clone(),
239            id: Arc::new(Mutex::new(0)),
240            no_kv_cache,
241            prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
242                prefix_cache_n,
243                no_prefix_cache,
244                has_paged_attention,
245            ))),
246            is_debug: DEBUG.load(Ordering::Relaxed),
247            disable_eos_stop,
248            throughput_logging_enabled,
249            logger,
250            handles: Arc::new(Mutex::new(Vec::new())),
251            pending_notify: Arc::new(Notify::new()),
252            session_store,
253            file_store,
254        })
255    }
256
257    /// Returns the maximum supported sequence length for the underlying model, if applicable.
258    #[allow(dead_code)]
259    pub fn max_sequence_length(&self) -> Option<usize> {
260        let pipeline = get_mut_arcmutex!(self.pipeline);
261        let category = pipeline.category();
262
263        if matches!(category, ModelCategory::Diffusion | ModelCategory::Speech) {
264            None
265        } else {
266            Some(pipeline.get_metadata().max_seq_len)
267        }
268    }
269
270    pub async fn run(self: Arc<Self>) {
271        if self.throughput_logging_enabled {
272            self.logger.enable_logging();
273        }
274
275        let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
276        let mut last_completion_ids: Vec<usize> = vec![];
277        'lp: loop {
278            let should_terminate = || {
279                matches!(
280                    ENGINE_INSTRUCTIONS
281                        .lock()
282                        .expect("`ENGINE_INSTRUCTIONS` was poisoned")
283                        .get(get_mut_arcmutex!(self.id).deref()),
284                    Some(Some(EngineInstruction::Terminate))
285                )
286            };
287
288            if should_terminate() {
289                self.replicate_request_to_daemons(&Request::Terminate);
290                break 'lp;
291            }
292
293            let mut channel_disconnected = false;
294            loop {
295                let next_request = {
296                    let mut rx = self.rx.lock().await;
297                    rx.try_recv()
298                };
299
300                match next_request {
301                    Ok(request) => {
302                        self.replicate_request_to_daemons(&request);
303                        if matches!(request, Request::Terminate) {
304                            break 'lp;
305                        }
306                        self.clone().handle_request(request).await;
307                    }
308                    Err(TryRecvError::Empty) => break,
309                    Err(TryRecvError::Disconnected) => {
310                        channel_disconnected = true;
311                        break;
312                    }
313                }
314            }
315
316            if channel_disconnected {
317                break 'lp;
318            }
319
320            let (waiting_len, running_len) = {
321                let scheduler = get_mut_arcmutex!(self.scheduler);
322                (scheduler.waiting_len(), scheduler.running_len())
323            };
324            let scheduler_idle = waiting_len == 0 && running_len == 0;
325
326            if scheduler_idle {
327                if should_terminate() {
328                    self.replicate_request_to_daemons(&Request::Terminate);
329                    break 'lp;
330                }
331                enum WaitEvent {
332                    Request(Option<Request>),
333                    Wake,
334                }
335                let wait_for_request = async {
336                    let mut rx = self.rx.lock().await;
337                    rx.recv().await
338                };
339                tokio::pin!(wait_for_request);
340                let wait_for_wake = self.pending_notify.notified();
341                tokio::pin!(wait_for_wake);
342
343                let event = select! {
344                    res = &mut wait_for_request => WaitEvent::Request(res),
345                    _ = &mut wait_for_wake => WaitEvent::Wake,
346                };
347
348                match event {
349                    WaitEvent::Request(Some(request)) => {
350                        self.replicate_request_to_daemons(&request);
351                        if matches!(request, Request::Terminate) {
352                            break 'lp;
353                        }
354                        self.clone().handle_request(request).await;
355                        continue;
356                    }
357                    WaitEvent::Request(None) => break 'lp,
358                    WaitEvent::Wake => {
359                        continue;
360                    }
361                }
362            }
363
364            if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
365                self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
366            }
367
368            let run_start = Instant::now();
369            let mut scheduler = get_mut_arcmutex!(self.scheduler);
370            let scheduled = scheduler.schedule(&self.logger);
371
372            match scheduled {
373                SchedulerOutput::DefaultScheduler {
374                    output: mut scheduled,
375                } => {
376                    if !scheduled.completion.is_empty() {
377                        let current_completion_ids: Vec<usize> =
378                            scheduled.completion.iter().map(|seq| *seq.id()).collect();
379                        let res = {
380                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
381                            let pre_op = if !self.no_kv_cache
382                                && last_completion_ids != current_completion_ids
383                            {
384                                CacheInstruction::In
385                            } else {
386                                CacheInstruction::Nothing
387                            };
388                            let post_op = if !self.no_kv_cache {
389                                CacheInstruction::Out
390                            } else {
391                                CacheInstruction::Reset {
392                                    load_preallocated_cache: false,
393                                    reset_non_granular: false,
394                                }
395                            };
396
397                            let return_raw_logits = scheduled.completion[0].return_raw_logits;
398                            assert!(
399                                scheduled
400                                    .completion
401                                    .iter()
402                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
403                                "All sequences must either return raw logits, or not."
404                            );
405
406                            pipeline
407                                .step(
408                                    &mut scheduled.completion,
409                                    false,
410                                    return_raw_logits,
411                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
412                                    self.disable_eos_stop,
413                                    rng.clone(),
414                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
415                                )
416                                .await
417                        };
418
419                        handle_pipeline_forward_error!(
420                            "completion step",
421                            res,
422                            &mut scheduled.completion,
423                            self.pipeline,
424                            'lp,
425                            self.prefix_cacher
426                        );
427
428                        self.logger.add_tokens_processed(scheduled.completion.len());
429
430                        last_completion_ids = current_completion_ids;
431                    }
432
433                    if !scheduled.prompt.is_empty() {
434                        // Mirror the paged-attn arm: prime timing fields before step()
435                        // so update_time_info called from inside sampling sees them.
436                        let pre_step_now = SystemTime::now()
437                            .duration_since(UNIX_EPOCH)
438                            .expect("Time travel has occurred!")
439                            .as_millis();
440                        for seq in scheduled.prompt.iter_mut() {
441                            seq.prompt_timestamp = Some(pre_step_now);
442                            seq.set_step_start_instant();
443                        }
444
445                        let prompt_exec_time = {
446                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
447
448                            // Run the prompt seqs
449                            let post_op = if !self.no_kv_cache {
450                                CacheInstruction::Out
451                            } else {
452                                CacheInstruction::Reset {
453                                    load_preallocated_cache: false,
454                                    reset_non_granular: false,
455                                }
456                            };
457
458                            let return_raw_logits = scheduled.prompt[0].return_raw_logits;
459                            assert!(
460                                scheduled
461                                    .prompt
462                                    .iter()
463                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
464                                "All sequences must either return raw logits, or not."
465                            );
466
467                            // This comes from prefix caching
468                            // The invariant where all token offsets are the same is handled by the scheduler
469                            let pre_op = if scheduled.prompt[0].token_offset() != 0 {
470                                CacheInstruction::In
471                            } else {
472                                CacheInstruction::Reset {
473                                    load_preallocated_cache: true,
474                                    reset_non_granular: false,
475                                }
476                            };
477
478                            pipeline
479                                .step(
480                                    &mut scheduled.prompt,
481                                    true,
482                                    return_raw_logits,
483                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
484                                    self.disable_eos_stop,
485                                    rng.clone(),
486                                    CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
487                                )
488                                .await
489                        };
490
491                        let prompt_exec_time = handle_pipeline_forward_error!(
492                            "prompt step",
493                            prompt_exec_time,
494                            &mut scheduled.prompt,
495                            self.pipeline,
496                            'lp,
497                            self.prefix_cacher
498                        );
499
500                        let total_processed_tokens: usize = scheduled
501                            .prompt
502                            .iter()
503                            .map(|seq| seq.get_toks().len())
504                            .sum();
505                        self.logger.add_tokens_processed(total_processed_tokens);
506
507                        for seq in scheduled.prompt.iter_mut() {
508                            match seq.sequence_stepping_type() {
509                                SeqStepType::OneShot => {
510                                    seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
511                                }
512                                SeqStepType::PromptAndDecode => {
513                                    seq.set_state(SequenceState::RunningCompletion)
514                                }
515                            }
516                            let now = SystemTime::now()
517                                .duration_since(UNIX_EPOCH)
518                                .expect("Time travel has occurred!")
519                                .as_millis();
520                            #[allow(clippy::cast_precision_loss)]
521                            let prompt_tok_per_sec =
522                                seq.len() as f32 / prompt_exec_time.as_secs_f32();
523                            seq.prompt_tok_per_sec = prompt_tok_per_sec;
524                            seq.prompt_timestamp = Some(now);
525                            seq.total_prompt_time = Some(prompt_exec_time.as_millis());
526                            seq.step_start_instant = None;
527                        }
528                        last_completion_ids = vec![];
529                    }
530
531                    if self.is_debug {
532                        let ms_from_last_run = run_start.elapsed().as_secs_f64();
533                        let total_len = scheduled.prompt.len() + scheduled.completion.len();
534                        if total_len > 0 {
535                            let prompt_lengths = scheduled
536                                .prompt
537                                .iter()
538                                .map(|seq| seq.len().to_string())
539                                .collect::<Vec<_>>()
540                                .join(", ");
541
542                            let completion_lengths = scheduled
543                                .completion
544                                .iter()
545                                .map(|seq| seq.len().to_string())
546                                .collect::<Vec<_>>()
547                                .join(", ");
548
549                            tracing::info!(
550                                "Prompt[{}] Completion[{}] - {}ms",
551                                prompt_lengths,
552                                completion_lengths,
553                                ms_from_last_run * 1000.,
554                            );
555                        }
556                    }
557                }
558                SchedulerOutput::PagedAttention { mut output } => {
559                    if !output.scheduled.is_empty() {
560                        let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();
561
562                        // Record prompt timing BEFORE step() so it's available if response is sent inside step()
563                        if is_prompt {
564                            let now = SystemTime::now()
565                                .duration_since(UNIX_EPOCH)
566                                .expect("Time travel has occurred!")
567                                .as_millis();
568                            for seq in output.scheduled.iter() {
569                                let mut seq_guard = get_mut_arcmutex!(seq);
570                                seq_guard.prompt_timestamp = Some(now);
571                                // Start the timer using Instant for accurate duration measurement
572                                seq_guard.set_step_start_instant();
573                            }
574                        }
575
576                        let mut guards = output
577                            .scheduled
578                            .iter_mut()
579                            .map(|seq| seq.lock().unwrap())
580                            .collect::<Vec<_>>();
581
582                        let mut guards_mut =
583                            guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();
584
585                        let res = {
586                            let mut pipeline = get_mut_arcmutex!(self.pipeline);
587
588                            let block_size = scheduler.block_size().unwrap();
589
590                            // For hybrid models under paged attention, restore recurrent state
591                            // from block-hash keyed prefix snapshots before prompt prefill.
592                            if is_prompt && pipeline.cache().is_hybrid() {
593                                let mut hybrid_cache = pipeline.cache().hybrid();
594                                let mut prefix_cacher = get_mut_arcmutex!(self.prefix_cacher);
595                                let kv_cache_manager = scheduler.kv_cache_manager().unwrap();
596
597                                for seq in guards_mut.iter_mut() {
598                                    let cached_prefix_len = seq.prefix_cache_len();
599                                    if cached_prefix_len == 0 {
600                                        continue;
601                                    }
602
603                                    let mut fallback_to_full_prompt = false;
604
605                                    let slot_idx = match seq.recurrent_state_idx() {
606                                        Some(idx) => idx,
607                                        None => {
608                                            tracing::warn!("Sequence {} has paged prefix hit but no recurrent_state_idx; recomputing full prompt.", seq.id());
609                                            fallback_to_full_prompt = true;
610                                            // Dummy value, unused in fallback path.
611                                            0usize
612                                        }
613                                    };
614
615                                    if !fallback_to_full_prompt {
616                                        if cached_prefix_len % block_size != 0 {
617                                            tracing::warn!(
618                                                "Sequence {} has non-aligned paged prefix len {}; recomputing full prompt.",
619                                                seq.id(),
620                                                cached_prefix_len
621                                            );
622                                            fallback_to_full_prompt = true;
623                                        } else {
624                                            let num_prefix_blocks = cached_prefix_len / block_size;
625                                            let block_hashes = compute_block_hashes(
626                                                seq.get_toks(),
627                                                block_size,
628                                                seq.mm_features(),
629                                                &[],
630                                            );
631                                            if block_hashes.len() < num_prefix_blocks {
632                                                fallback_to_full_prompt = true;
633                                            } else if let Some(snapshots) = prefix_cacher
634                                                .get_paged_recurrent_prefix(
635                                                    &block_hashes[..num_prefix_blocks],
636                                                )
637                                            {
638                                                if let Err(e) = hybrid_cache
639                                                    .restore_recurrent_state(slot_idx, &snapshots)
640                                                {
641                                                    tracing::warn!(
642                                                        "Failed restoring paged recurrent prefix state for sequence {}: {e}",
643                                                        seq.id()
644                                                    );
645                                                    fallback_to_full_prompt = true;
646                                                }
647                                            } else {
648                                                tracing::warn!(
649                                                    "No recurrent prefix snapshot for sequence {} at cached prefix length {}; recomputing full prompt.",
650                                                    seq.id(),
651                                                    cached_prefix_len
652                                                );
653                                                fallback_to_full_prompt = true;
654                                            }
655                                        }
656                                    }
657
658                                    if fallback_to_full_prompt {
659                                        let seq_id = *seq.id();
660                                        let num_tokens = seq.get_toks().len();
661                                        let mut kv_mgr = get_mut_arcmutex!(kv_cache_manager);
662                                        kv_mgr.free(seq_id);
663                                        let realloc_ok = kv_mgr
664                                            .allocate_slots(seq_id, num_tokens, &[])
665                                            .is_some();
666                                        drop(kv_mgr);
667
668                                        if !realloc_ok {
669                                            tracing::warn!(
670                                                "Failed to reallocate fresh paged KV blocks for sequence {} after recurrent-prefix fallback.",
671                                                seq_id
672                                            );
673                                            seq.set_state(SequenceState::FinishedIgnored);
674                                        }
675                                        seq.set_prefix_cache_len(0);
676                                    }
677                                }
678
679                                // Drop sequences that were canceled due fallback allocation failures.
680                                guards_mut.retain(|seq| !seq.is_finished_paged_attn());
681                            }
682
683                            if guards_mut.is_empty() {
684                                Ok(Duration::ZERO)
685                            } else {
686                                let metadata = PagedAttentionMeta {
687                                    block_size,
688                                    sliding_window: pipeline.get_metadata().sliding_window,
689                                    kv_cache_manager: scheduler.kv_cache_manager().unwrap(),
690                                };
691
692                                let return_raw_logits = guards_mut[0].return_raw_logits;
693                                assert!(
694                                    guards_mut
695                                        .iter()
696                                        .all(|seq| seq.return_raw_logits == return_raw_logits),
697                                    "All sequences must either return raw logits, or not."
698                                );
699
700                                pipeline
701                                    .step(
702                                        &mut guards_mut,
703                                        is_prompt,
704                                        return_raw_logits,
705                                        &mut *get_mut_arcmutex!(self.prefix_cacher),
706                                        self.disable_eos_stop,
707                                        rng.clone(),
708                                        CacheBackendMetadata::PagedAttention { metadata },
709                                    )
710                                    .await
711                            }
712                        };
713
714                        handle_pipeline_forward_error!(
715                            "step",
716                            res,
717                            &mut guards_mut,
718                            self.pipeline,
719                            'lp,
720                            self.prefix_cacher
721                        );
722
723                        let total_processed_tokens: usize = guards_mut
724                            .iter()
725                            .map(|seq| {
726                                if seq.is_prompt() {
727                                    seq.get_toks().len()
728                                } else {
729                                    1
730                                }
731                            })
732                            .sum();
733                        self.logger.add_tokens_processed(total_processed_tokens);
734
735                        // Capture recurrent states at full-block boundaries so hybrid models can
736                        // reuse recurrent prefix state when paged prefix caching hits.
737                        {
738                            let pipeline = get_mut_arcmutex!(self.pipeline);
739                            if pipeline.cache().is_hybrid() {
740                                let block_size = scheduler.block_size().unwrap();
741                                let hybrid_cache = pipeline.cache().hybrid();
742                                let mut prefix_cacher = get_mut_arcmutex!(self.prefix_cacher);
743
744                                for seq in guards_mut.iter() {
745                                    let seq_len = seq.get_toks().len();
746                                    if seq_len == 0 || seq_len % block_size != 0 {
747                                        continue;
748                                    }
749
750                                    let Some(slot_idx) = seq.recurrent_state_idx() else {
751                                        continue;
752                                    };
753
754                                    let snapshots = match hybrid_cache
755                                        .snapshot_recurrent_state(slot_idx)
756                                    {
757                                        Ok(snapshots) => snapshots,
758                                        Err(e) => {
759                                            tracing::warn!(
760                                                    "Failed snapshotting recurrent state for sequence {}: {e}",
761                                                    seq.id()
762                                                );
763                                            continue;
764                                        }
765                                    };
766                                    if snapshots.is_empty() {
767                                        continue;
768                                    }
769
770                                    let num_blocks = seq_len / block_size;
771                                    let block_hashes = compute_block_hashes(
772                                        seq.get_toks(),
773                                        block_size,
774                                        seq.mm_features(),
775                                        &[],
776                                    );
777                                    if block_hashes.len() < num_blocks {
778                                        continue;
779                                    }
780                                    prefix_cacher.add_paged_recurrent_prefix(
781                                        block_hashes[..num_blocks].to_vec(),
782                                        snapshots,
783                                    );
784                                }
785                            }
786                        }
787
788                        if self.is_debug {
789                            let ms_from_last_run = run_start.elapsed().as_secs_f64();
790                            let total_len = guards.len();
791                            if total_len > 0 {
792                                let lengths = guards
793                                    .iter()
794                                    .map(|seq| seq.len().to_string())
795                                    .collect::<Vec<_>>()
796                                    .join(", ");
797
798                                let (prompt_lengths, completion_lengths) = if is_prompt {
799                                    (lengths, "".to_string())
800                                } else {
801                                    ("".to_string(), lengths)
802                                };
803
804                                tracing::info!(
805                                    "Prompt[{}] Completion[{}] - {}ms",
806                                    prompt_lengths,
807                                    completion_lengths,
808                                    ms_from_last_run * 1000.,
809                                );
810                            }
811                        }
812
813                        if is_prompt {
814                            #[allow(clippy::cast_precision_loss)]
815                            for mut seq in guards {
816                                // Use Instant duration for accurate prompt timing
817                                if let Some(start) = seq.step_start_instant {
818                                    let duration = start.elapsed();
819                                    seq.prompt_tok_per_sec =
820                                        seq.len() as f32 / duration.as_secs_f32();
821                                    seq.total_prompt_time = Some(duration.as_millis());
822                                    seq.step_start_instant = None;
823                                }
824                                let now = SystemTime::now()
825                                    .duration_since(UNIX_EPOCH)
826                                    .expect("Time travel has occurred!")
827                                    .as_millis();
828                                seq.prompt_timestamp = Some(now);
829                            }
830                        }
831                    }
832                }
833            }
834
835            // Free recurrent state pool slots for finished sequences (hybrid models)
836            {
837                let pipeline = get_mut_arcmutex!(self.pipeline);
838                if !pipeline.get_metadata().no_kv_cache && pipeline.cache().is_hybrid() {
839                    let recurrent_indices = scheduler.get_finished_recurrent_indices();
840                    if !recurrent_indices.is_empty() {
841                        let mut hybrid_cache = pipeline.cache().hybrid();
842                        for idx in recurrent_indices {
843                            hybrid_cache.free_seq(idx);
844                        }
845                    }
846                }
847            }
848            scheduler.free_finished_sequence_groups();
849        }
850    }
851
852    fn build_sequence_recognizer(
853        factory: &Option<Arc<ParserFactory>>,
854        constraint: &Constraint,
855    ) -> anyhow::Result<SequenceRecognizer> {
856        if let Some(grm) = llg_grammar_from_constraint(constraint)? {
857            let factory = factory
858                .as_ref()
859                .ok_or_else(|| anyhow::anyhow!("No token environment (llg_factory) found."))?;
860            let llg = constraint_from_llg_grammar(factory, grm)?;
861            Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
862        } else {
863            Ok(SequenceRecognizer::None)
864        }
865    }
866
867    fn replicate_request_to_daemons(&self, request: &Request) {
868        if !distributed::is_daemon() && hanzo_quant::distributed::use_nccl() {
869            let name = distributed::ipc_name().unwrap();
870            let num_workers =
871                hanzo_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
872            let listener = ListenerOptions::new().name(name).create_sync().unwrap();
873
874            for _ in 0..num_workers {
875                let stream = listener.accept().unwrap();
876                let mut writer = BufWriter::new(stream);
877                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
878                writer.write_all(req.as_bytes()).unwrap();
879            }
880        } else if !distributed::is_daemon() && cfg!(feature = "ring") {
881            let num_workers =
882                hanzo_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
883            let master_port = RingConfig::load().master_port;
884            let listener =
885                TcpListener::bind(format!("0.0.0.0:{master_port}")).expect("bind replicator");
886
887            for _ in 0..num_workers {
888                let (stream, _) = listener.accept().unwrap();
889                let mut writer = BufWriter::new(stream);
890                let req = format!("{}\n", serde_json::to_string(&request).unwrap());
891                writer.write_all(req.as_bytes()).unwrap();
892            }
893        }
894    }
895}